glass_box_umap.plotting

Overview

Classes

LiveEmbeddingCallback

PyTorch Lightning callback that serves a live-updating Bokeh scatter.

Functions

plot_embedding(Z, contributions, *, group_names, feature_names, feature_values, top_k_global, hover_images, hover_tooltips, hover_data, output_backend)

Interactive 2D embedding scatter linked to a feature-contribution bar chart.

plot_embedding_static(Z, group_ids, group_names, cmap, marker_size)

Static (matplotlib) scatter plot of a 2D embedding, optionally colored by group.

Classes

class LiveEmbeddingCallback(transform_fn: Callable[[Tensor], ndarray[tuple[Any, ...], dtype[floating]]], X: Tensor, labels: list[str] | None = None, port: int = 0, output_backend: Literal['canvas', 'webgl'] = 'webgl', hover_images: ndarray[tuple[Any, ...], dtype[uint8]] | None = None, block_after_fit: bool = True)[source]

PyTorch Lightning callback that serves a live-updating Bokeh scatter.

Spins up a Bokeh server on a background thread, opens a browser tab, and streams a fresh embedding (via transform_fn) to the page after each training epoch starts. Each session keeps a per-frame history that the user can scrub through with a slider, play back with a button, or export to a self-contained HTML file. Training keeps running on the main thread; updates cross to the Bokeh event loop via Document.add_next_tick_callback.
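The thread hand-off described above can be illustrated without Bokeh. The sketch below is an analogy only, not the library's implementation: it uses asyncio's loop.call_soon_threadsafe as a stand-in for Document.add_next_tick_callback, and the names frames, push_frame, and run_loop are hypothetical.

```python
import asyncio
import threading

# Hypothetical per-frame history; only the event-loop thread mutates it.
frames = []


def push_frame(epoch, done):
    """Runs on the event-loop thread and records one frame."""
    frames.append(epoch)
    if epoch == 2:
        done.set()  # signal the main thread that the last frame arrived


def run_loop(loop):
    """Background thread body: own the event loop and run it until stopped."""
    asyncio.set_event_loop(loop)
    loop.run_forever()


loop = asyncio.new_event_loop()
server = threading.Thread(target=run_loop, args=(loop,), daemon=True)
server.start()

done = threading.Event()
for epoch in range(3):
    # The main ("training") thread never touches `frames` directly; it only
    # schedules work on the loop, which is the thread-safe entry point.
    loop.call_soon_threadsafe(push_frame, epoch, done)

done.wait(timeout=5)
loop.call_soon_threadsafe(loop.stop)
server.join(timeout=5)
loop.close()
```

The point of the pattern is that state shown in the page is only ever modified on the event-loop thread, so the training loop does not need locks around it.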

Base Classes:

pytorch_lightning.Callback

Parameters:
  • transform_fn -- Callable that maps the high-dimensional X to a (n_samples, 2) array. Typically the embedder’s transform method.

  • X -- High-dimensional input fed to transform_fn after each epoch.

  • labels -- Optional per-sample categorical labels for coloring.

  • port -- Port the Bokeh server listens on. 0 (default) lets the OS pick a free port, which avoids EADDRINUSE collisions when the callback is re-instantiated within the same process (e.g. a Jupyter kernel that already hosts a previous run’s server).

  • output_backend -- Bokeh rendering backend for the scatter. Defaults to "webgl"; switch to "canvas" if the GPU/driver/browser combination renders the plot incorrectly.

  • hover_images -- Optional uint8 image array of shape (n_samples, H, W) or (n_samples, H, W, 3 | 4). When set, each tooltip shows the sample’s image above the index/label text.

  • block_after_fit -- When True (default), block at the end of training so the Bokeh server keeps serving until the user presses Ctrl-C. Set to False from interactive contexts (e.g. Jupyter) where the host process already keeps the server alive.
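The port=0 behavior described for the port parameter is an operating-system feature rather than anything specific to Bokeh. As a minimal sketch using only the standard library (the variable names are illustrative), two sockets bound to port 0 at the same time receive distinct free ports:

```python
import socket

# Binding to port 0 asks the OS to choose any free port.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as first, \
        socket.socket(socket.AF_INET, socket.SOCK_STREAM) as second:
    first.bind(("127.0.0.1", 0))
    second.bind(("127.0.0.1", 0))
    # getsockname() reveals which port was actually assigned.
    port_a = first.getsockname()[1]
    port_b = second.getsockname()[1]

print(f"OS-assigned ports: {port_a}, {port_b}")
```

A fixed port would fail on the second bind while the first socket is still open, which is the collision the default avoids.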

Methods:

on_fit_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) None[source]

Called when fit begins.

on_train_epoch_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) None[source]

Called when the train epoch begins.

on_train_epoch_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) None[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

import torch
import pytorch_lightning as L  # aliased to match the `L.` prefix used below


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # per-step losses collected during the current epoch
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_train_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) None[source]

Called when the train ends.

Functions

plot_embedding(Z: NDArray[floating], contributions: NDArray[floating], *, group_names: Sequence[Any] | NDArray | None = None, feature_names: list[str] | None = None, feature_values: NDArray[floating] | None = None, top_k_global: int = 200, hover_images: NDArray[uint8] | None = None, hover_tooltips: str | None = None, hover_data: Mapping[str, Sequence[Any]] | None = None, output_backend: glass_box_umap.plotting.bokeh._scatter.OutputBackend = 'webgl') bokeh.models.layouts.LayoutDOM[source]

Interactive 2D embedding scatter linked to a feature-contribution bar chart.

A single radio toggle above the scatter chooses how to color the points:

  • Group (only available when group_names is provided): categorical coloring by user-supplied labels.

  • Feature: a Viridis gradient over the L2-reduced contribution of one feature, picked via an autocomplete input that appears below the toggle (substring match, case-insensitive).

  • Top feature: each sample is colored by the kept feature with its largest L2-reduced contribution. A slider lets the user choose the top-N most-frequent top features to colorize; samples whose top feature isn’t in that set are drawn in gray underneath the colored points.

Lasso- or box-selecting points in the scatter updates the linked bar chart on the right (which has its own L2 | normed L2 | Dim 1 | Dim 2 view toggle); with no selection the bars summarize all samples.
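As a worked example of the Feature and Top feature color modes, the sketch below assumes that "L2-reduced" means the Euclidean norm of a feature's contribution across the two embedding dimensions; that reading is an assumption, not something this page states. Plain nested lists stand in for the (n_samples, 2, n_features) array so the sketch has no dependencies; with NumPy the same reduction would be np.linalg.norm(contributions, axis=1).

```python
import math
from collections import Counter

# contributions[sample][dim][feature]: 3 samples, 2 embedding dims, 3 features
contributions = [
    [[3.0, 0.0, 1.0], [4.0, 1.0, 0.0]],
    [[0.0, 6.0, 0.0], [1.0, 8.0, 2.0]],
    [[0.5, 0.0, 0.0], [0.0, 0.2, 0.1]],
]


def l2_reduce(sample):
    """Collapse the two embedding dims into one magnitude per feature."""
    dim1, dim2 = sample
    return [math.hypot(a, b) for a, b in zip(dim1, dim2)]


# Shape (n_samples, n_features): the values a Feature-mode gradient would use.
l2 = [l2_reduce(sample) for sample in contributions]

# Top feature per sample: index of the largest L2-reduced contribution.
top_feature = [row.index(max(row)) for row in l2]

# How often each feature is a sample's top feature; a top-N slider would
# keep the N most frequent entries and draw the rest in gray.
top_counts = Counter(top_feature)
```

Here sample 0 is dominated by feature 0 (magnitude 5.0), sample 1 by feature 1 (magnitude 10.0), and sample 2 by feature 0 again, so feature 0 is the most frequent top feature.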

Parameters:
  • Z : NDArray[floating]

    Embedding coordinates of shape (n_samples, 2).

  • contributions : NDArray[floating]

    Per-feature contributions of shape (n_samples, 2, n_features). Typically the output of compute_contributions() with reduction=None.

  • group_names : Sequence[Any] | NDArray | None

    Group label per sample. Any sequence of length n_samples; elements are stringified before use. When provided, the Group color mode is added to the radio and used as the default; when None (default), the radio shows only Feature / Top feature and starts in Feature mode.

  • feature_names : list[str] | None

    Human-readable name per feature; length must equal contributions.shape[2]. Defaults to "Feature {i}" (0-indexed).

  • feature_values : NDArray[floating] | None

    Per-sample feature values of shape (n_samples, n_features). When provided, the default tooltip for Feature mode adds value: <X> (the picker-selected feature’s value), and the default tooltip for Top feature mode adds value: <X> (the top feature’s value). Whatever scaling the caller passes is what the tooltip displays — pass raw values for human-readable tooltips, or the same standardized array fed to the embedder for consistency with contributions space. Ignored when hover_tooltips is set.

  • top_k_global : int

    How many features to ship to the browser, ranked by global L2 importance. Caps everything: the bar chart, the feature-picker autocomplete, and the candidate set for top-feature ranking.

  • hover_images : NDArray[uint8] | None

    Per-sample uint8 image array of shape (n_samples, H, W) or (n_samples, H, W, 3 | 4). When set, each tooltip shows the sample’s image above the default index/group text. Mutually exclusive with hover_tooltips and hover_data.

  • hover_tooltips : str | None

    Bokeh tooltip HTML template that fully replaces the default. May reference @index, @group (when group_names is provided), and any keys from hover_data.

  • hover_data : Mapping[str, Sequence[Any]] | None

    Extra columns merged into the scatter ColumnDataSource for reference from hover_tooltips. Each value must have length n_samples. Keys must not collide with the reserved columns x, y, index, group, color_value, top_feature_group, top_feature_name, top_data_value, picker_data_value, sample_rank.

  • output_backend : glass_box_umap.plotting.bokeh._scatter.OutputBackend

    Bokeh rendering backend for the scatter. Defaults to "webgl", which offloads rendering to the GPU and stays smooth at high sample counts. Switch to "canvas" if the GPU/driver/browser combination renders the plot incorrectly (e.g. blank canvas, wrong-sized points, or color banding) — canvas is slower but uses CPU rasterization and works on any setup that supports Bokeh at all.
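The hover_data constraints listed above can be checked before calling plot_embedding. The helper below is a hypothetical pre-flight check, not part of the library: check_hover_data and the example column names are assumptions, while the reserved names are copied from the hover_data description.

```python
# Reserved scatter columns, as listed in the hover_data description.
RESERVED_COLUMNS = frozenset({
    "x", "y", "index", "group", "color_value", "top_feature_group",
    "top_feature_name", "top_data_value", "picker_data_value", "sample_rank",
})


def check_hover_data(hover_data, n_samples):
    """Hypothetical helper: return a list of problems (empty means OK)."""
    problems = []
    for key, values in hover_data.items():
        if key in RESERVED_COLUMNS:
            problems.append(f"{key!r} collides with a reserved column")
        if len(values) != n_samples:
            problems.append(
                f"{key!r} has length {len(values)}, expected {n_samples}"
            )
    return problems


# A valid extra column, plus a tooltip template that references it.
ok = check_hover_data({"patient_id": ["a", "b", "c"]}, n_samples=3)
hover_tooltips = "<b>@patient_id</b> (sample @index)"

# One reserved-name collision and one length mismatch.
bad = check_hover_data({"group": ["a", "b", "c"], "age": [1, 2]}, n_samples=3)
```

The template uses Bokeh's @column syntax, so @patient_id resolves against the extra column and @index against the built-in one.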

Returns:

A Bokeh layout — color-by controls + scatter on the left, linked bar chart with view toggle on the right. Pass it to bokeh.io.show() or bokeh.io.save().

Return type:

bokeh.models.layouts.LayoutDOM

plot_embedding_static(Z: NDArray[floating], group_ids: NDArray[integer] | None = None, group_names: list[str] | None = None, cmap: matplotlib.colors.ListedColormap | None = None, marker_size: float = 2.0) matplotlib.figure.Figure[source]

Static (matplotlib) scatter plot of a 2D embedding, optionally colored by group.

Parameters:
  • Z : NDArray[floating]

    Embedding coordinates with shape (n_samples, 2).

  • group_ids : NDArray[integer] | None

    Integer group ID per point with shape (n_samples,). If None, points are uncolored.

  • group_names : list[str] | None

    Human-readable name for each group, indexed by group ID. If None and group_ids are provided, defaults to str(gid) for each group.

  • cmap : matplotlib.colors.ListedColormap | None

    Colormap for the scatter plot. If None and group_ids are provided, a colormap is generated with one color per unique group.

  • marker_size : float

    Size of scatter plot markers.
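The group_names default can be read as a lookup from group ID to legend label. The sketch below illustrates that reading only; legend_labels is a hypothetical helper, and how the library actually enumerates group IDs is an assumption.

```python
def legend_labels(group_ids, group_names=None):
    """Hypothetical helper: map each group ID present to its legend label."""
    unique_ids = sorted(set(group_ids))
    if group_names is None:
        # Default described above: each group is labeled str(gid).
        return {gid: str(gid) for gid in unique_ids}
    # group_names is indexed by group ID.
    return {gid: group_names[gid] for gid in unique_ids}


default_labels = legend_labels([2, 0, 2, 1])
named_labels = legend_labels([2, 0, 2, 1], ["control", "treated", "washout"])
```

Because names are indexed by ID, group_names needs an entry for every ID up to the largest one in group_ids.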

Return type:

matplotlib.figure.Figure