# Source code for glass_box_umap.plotting.mpl

import matplotlib.pyplot as plt
import numpy as np
from adjustText import adjust_text
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from matplotlib.pyplot import get_cmap
from numpy.typing import NDArray


def make_cmap(n_colors: int = 40, seed: int = 42) -> ListedColormap:
    """Build a randomly ordered qualitative colormap from the tab20 family.

    The color pool is tab20, tab20b and tab20c stacked in that order (60
    colors in total). The pool is repeated as many whole times as needed to
    cover ``n_colors``, cut down to exactly ``n_colors`` rows, and then
    permuted with a generator seeded by ``seed``, so identical arguments
    always produce the same colormap.

    Args:
        n_colors: Number of colors in the returned colormap.
        seed: Seed for the NumPy random generator that orders the colors.

    Returns:
        A ``ListedColormap`` holding ``n_colors`` shuffled colors.
    """
    palette_names = ("tab20", "tab20b", "tab20c")
    pool = np.vstack(
        tuple(
            get_cmap(name).colors  # type: ignore
            for name in palette_names
        )
    )
    pool_size = len(pool)

    # Ceiling division: number of whole copies of the pool that cover n_colors.
    n_copies = -(-n_colors // pool_size)
    repeated = np.tile(pool, (n_copies, 1))
    selected = repeated[:n_colors]

    order = np.random.default_rng(seed).permutation(n_colors)
    return ListedColormap(selected[order])


def plot_embedding_static(
    Z: NDArray[np.floating],
    group_ids: NDArray[np.integer] | None = None,
    group_names: list[str] | None = None,
    cmap: ListedColormap | None = None,
    marker_size: float = 2.0,
) -> Figure:
    """Static (matplotlib) scatter plot of a 2D embedding, optionally colored by group.

    Each group is drawn as its own scatter layer and annotated with a text
    label placed at the per-axis median of its points; labels are then passed
    to ``adjust_text`` for repositioning.

    Args:
        Z: Embedding coordinates with shape (n_samples, 2).
        group_ids: Integer group ID per point with shape (n_samples,). If None,
            points are uncolored and no labels are drawn.
        group_names: Human-readable name for each group, indexed by group ID.
            A group whose ID is not a valid non-negative index into this list
            (a negative ID such as -1, or an ID past the end) is labelled with
            ``str(gid)``. If None and group_ids are provided, every group is
            labelled with ``str(gid)``.
        cmap: Colormap for the scatter plot. When given, a group's color is
            ``cmap(gid % cmap.N)``, so a group ID maps to the same color in
            every plot. If None and group_ids are provided, a colormap is
            generated with one color per unique group and groups are colored
            by their rank among the sorted unique IDs, so distinct groups never
            share a color even when IDs are non-contiguous.
        marker_size: Size of scatter plot markers.

    Returns:
        The figure (created with ``plt.subplots``) containing the plot.
    """
    Z = np.asarray(Z)
    fig, ax = plt.subplots(figsize=(10, 8))

    if group_ids is None:
        ax.scatter(
            Z[:, 0],
            Z[:, 1],
            s=marker_size,
            alpha=0.5,
            rasterized=True,
        )
    else:
        # Coerce so that `group_ids == gid` below is an element-wise mask even
        # when the caller passes a plain list.
        group_ids = np.asarray(group_ids)
        unique_ids = np.unique(group_ids)

        # An auto-generated palette has exactly one entry per group, so it must
        # be indexed by rank; indexing it by raw ID modulo its length would
        # collide for non-contiguous IDs (e.g. 0, 3, 6 with three colors).
        color_by_rank = cmap is None
        if cmap is None:
            # max(..., 1) keeps the palette non-empty when there are no points.
            cmap = make_cmap(max(len(unique_ids), 1))
        # N is the size of the lookup table that integer calls index into.
        n_cmap_colors = cmap.N

        texts = []
        for rank, gid in enumerate(unique_ids):
            mask = group_ids == gid
            gid_int = int(gid)
            color_index = rank if color_by_rank else gid_int
            color = cmap(color_index % n_cmap_colors)
            ax.scatter(
                Z[mask, 0],
                Z[mask, 1],
                s=marker_size,
                alpha=0.5,
                color=color,
                rasterized=True,
            )

            # Only index group_names with an in-range, non-negative ID; a
            # negative ID would otherwise silently pick a name from the end
            # of the list.
            if group_names is not None and 0 <= gid_int < len(group_names):
                label = group_names[gid_int]
            else:
                label = str(gid_int)

            cx, cy = np.median(Z[mask, 0]), np.median(Z[mask, 1])
            texts.append(
                ax.text(
                    cx,
                    cy,
                    label,
                    fontsize=7,
                    fontweight="bold",
                    ha="center",
                    va="center",
                    bbox=dict(boxstyle="round,pad=0.2", fc="white", ec=color, alpha=0.7),
                )
            )

        # Nothing to reposition when there are no groups.
        if texts:
            adjust_text(texts, ax=ax)

    ax.set_xlabel("UMAP 1")
    ax.set_ylabel("UMAP 2")
    return fig