glass_box_umap.plotting
Overview
- LiveEmbeddingCallback: PyTorch Lightning callback that serves a live-updating Bokeh scatter.
- plot_embedding: Interactive 2D embedding scatter linked to a feature-contribution bar chart.
- plot_embedding_static: Static (matplotlib) scatter plot of a 2D embedding, optionally colored by group.
Classes
- class LiveEmbeddingCallback(transform_fn: Callable[[Tensor], ndarray[tuple[Any, ...], dtype[floating]]], X: Tensor, labels: list[str] | None = None, port: int = 0, output_backend: Literal['canvas', 'webgl'] = 'webgl', hover_images: ndarray[tuple[Any, ...], dtype[uint8]] | None = None, block_after_fit: bool = True)
PyTorch Lightning callback that serves a live-updating Bokeh scatter.
Spins up a Bokeh server on a background thread, opens a browser tab, and streams a fresh embedding (computed via transform_fn) to the page at the start of each training epoch. Each session keeps a per-frame history that the user can scrub through with a slider, play back with a button, or export to a self-contained HTML file. Training keeps running on the main thread; updates cross to the Bokeh event loop via Document.add_next_tick_callback.
Base Classes:
pytorch_lightning.Callback
- Parameters:
transform_fn -- Callable that maps the high-dimensional X to a (n_samples, 2) array. Typically the embedder's transform method.
X -- High-dimensional input fed to transform_fn after each epoch.
labels -- Optional per-sample categorical labels for coloring.
port -- Port the Bokeh server listens on. 0 (default) lets the OS pick a free port, which avoids EADDRINUSE collisions when the callback is re-instantiated within the same process (e.g. a Jupyter kernel that already hosts a previous run's server).
output_backend -- Bokeh rendering backend for the scatter. Defaults to "webgl"; switch to "canvas" if the GPU/driver/browser combination renders the plot incorrectly.
hover_images -- Optional uint8 image array of shape (n_samples, H, W) or (n_samples, H, W, 3 | 4). When set, each tooltip shows the sample's image above the index/label text.
block_after_fit -- When True (default), block at the end of training so the Bokeh server keeps serving until the user presses Ctrl-C. Set to False from interactive contexts (e.g. Jupyter) where the host process already keeps the server alive.
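The port=0 default relies on the operating system: binding to port 0 asks the kernel for any free port, so two servers in the same process never compete for one number, whereas re-binding a fixed port that is still held fails with EADDRINUSE. The standard-library sketch below shows that difference; listen_on and port_zero_demo are hypothetical helpers, not part of glass_box_umap, and no Bokeh server is started.

```python
import errno
import socket


def listen_on(port: int) -> socket.socket:
    """Bind a TCP listener on localhost; port 0 asks the OS for a free port."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("127.0.0.1", port))
        sock.listen()
    except OSError:
        sock.close()
        raise
    return sock


def port_zero_demo() -> tuple[int, int, bool]:
    """Return the two OS-assigned ports and whether re-binding a taken port collided."""
    first = listen_on(0)   # e.g. the server left over from a previous run
    second = listen_on(0)  # a fresh server in the same process
    try:
        p1 = first.getsockname()[1]
        p2 = second.getsockname()[1]
        # Reusing a fixed port that is still held is what port=0 sidesteps.
        try:
            listen_on(p1).close()
            collided = False
        except OSError as exc:
            collided = exc.errno == errno.EADDRINUSE
        return p1, p2, collided
    finally:
        first.close()
        second.close()


p1, p2, collided = port_zero_demo()
print(p1 != p2, collided)
```

Both port-0 listeners succeed with distinct ports, while the fixed-port attempt is the case that raises.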
Methods:
- on_fit_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) -> None
Called when fit begins.
- on_train_epoch_start(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) -> None
Called when the train epoch begins.
- on_train_epoch_end(trainer: pytorch_lightning.Trainer, pl_module: pytorch_lightning.LightningModule) -> None
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

    import pytorch_lightning as L
    import torch


    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...  # compute the loss for this batch
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
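The class description states that training stays on the main thread while each new frame crosses to the Bokeh event loop via Document.add_next_tick_callback. The sketch below reproduces that hand-off with plain asyncio, using loop.call_soon_threadsafe as a stand-in for Bokeh's call; FrameStream is a hypothetical illustration and does not mirror the callback's actual internals.

```python
import asyncio
import threading

import numpy as np


class FrameStream:
    """Hypothetical stand-in for the server side: an event loop on a background
    thread that owns the per-frame history (not a glass_box_umap class)."""

    def __init__(self) -> None:
        self.history: list[np.ndarray] = []
        self.loop = asyncio.new_event_loop()
        # The loop runs on a daemon thread, like the callback's Bokeh server.
        self.thread = threading.Thread(target=self.loop.run_forever, daemon=True)
        self.thread.start()

    def _append(self, frame: np.ndarray) -> None:
        # Executes on the loop thread, so only that thread mutates history.
        self.history.append(frame)

    def push(self, frame: np.ndarray) -> None:
        # Called from the training thread: schedule the update instead of
        # touching loop-owned state directly.
        self.loop.call_soon_threadsafe(self._append, frame)

    def close(self) -> None:
        # stop() is queued behind earlier pushes, so every frame lands first.
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join()
        self.loop.close()


stream = FrameStream()
rng = np.random.default_rng(0)
for epoch in range(3):
    # Stand-in for transform_fn(X) computed on the training thread.
    stream.push(rng.normal(size=(5, 2)))
stream.close()
print(len(stream.history))  # 3
```

Because callbacks scheduled with call_soon_threadsafe run in submission order, the history ends up in epoch order without any explicit lock.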
Functions
- plot_embedding(Z: NDArray[floating], contributions: NDArray[floating], *, group_names: Sequence[Any] | NDArray | None = None, feature_names: list[str] | None = None, feature_values: NDArray[floating] | None = None, top_k_global: int = 200, hover_images: NDArray[uint8] | None = None, hover_tooltips: str | None = None, hover_data: Mapping[str, Sequence[Any]] | None = None, output_backend: glass_box_umap.plotting.bokeh._scatter.OutputBackend = 'webgl') -> bokeh.models.layouts.LayoutDOM
Interactive 2D embedding scatter linked to a feature-contribution bar chart.
A single radio toggle above the scatter chooses how to color the points:
- Group (only available when group_names is provided): categorical coloring by user-supplied labels.
- Feature: a Viridis gradient over the L2-reduced contribution of one feature, picked via an autocomplete input that appears below the toggle (case-insensitive substring match).
- Top feature: each sample is colored by the kept feature with its largest L2-reduced contribution. A slider sets N, the number of most frequent top features to colorize; samples whose top feature isn't among those N are drawn in gray underneath the colored points.
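The Top feature mode combines two reductions: features are first capped by global L2 importance (top_k_global), then each sample's contributions are L2-reduced over the two embedding dimensions and the largest kept feature wins. A minimal NumPy sketch under stated assumptions: "global L2 importance" is taken to be the L2 norm over all samples and both dimensions, frequency ties are broken by feature index, and keep_top_k_global and top_feature_coloring are hypothetical names, not the library's code.

```python
import numpy as np


def keep_top_k_global(contributions: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k features with the largest global L2 importance.

    Assumption: importance is the L2 norm of a feature's contributions taken
    over all samples and both embedding dimensions.
    """
    importance = np.sqrt((contributions ** 2).sum(axis=(0, 1)))  # (n_features,)
    # Stable sort so equally important features keep their original order.
    return np.argsort(-importance, kind="stable")[:k]


def top_feature_coloring(
    contributions: np.ndarray, keep: np.ndarray, n_colored: int
) -> tuple[np.ndarray, np.ndarray]:
    """Top kept feature per sample, plus a mask of samples that get a color."""
    kept = contributions[:, :, keep]      # (n_samples, 2, k)
    l2 = np.linalg.norm(kept, axis=1)     # L2 over the two dims -> (n_samples, k)
    top = keep[l2.argmax(axis=1)]         # original feature index per sample
    # Rank features by how many samples they top, most frequent first.
    features, counts = np.unique(top, return_counts=True)
    colored = features[np.argsort(-counts, kind="stable")[:n_colored]]
    # False entries are the samples drawn in gray underneath.
    return top, np.isin(top, colored)
```

Raising n_colored plays the role of moving the slider up: more samples leave the gray layer.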
Lasso- or box-selecting points in the scatter updates the linked bar chart on the right, which has its own L2 | normed L2 | Dim 1 | Dim 2 view toggle; with no selection, the bars summarize all samples.
- Parameters:
Z : NDArray[floating]
Embedding coordinates of shape (n_samples, 2).
contributions : NDArray[floating]
Per-feature contributions of shape (n_samples, 2, n_features). Typically the output of compute_contributions() with reduction=None.
group_names : Sequence[Any] | NDArray | None
Group label per sample. Any sequence of length n_samples; elements are stringified before use. When provided, the Group color mode is added to the radio and used as the default; when None (default), the radio shows only Feature / Top feature and starts in Feature mode.
feature_names : list[str] | None
Human-readable name per feature; length must equal contributions.shape[2]. Defaults to "Feature {i}" (0-indexed).
feature_values : NDArray[floating] | None
Per-sample feature values of shape (n_samples, n_features). When provided, the default tooltip for Feature mode adds value: <X> (the picker-selected feature's value), and the default tooltip for Top feature mode adds value: <X> (the top feature's value). The tooltip shows values in whatever scaling the caller passes: pass raw values for human-readable tooltips, or the same standardized array fed to the embedder for consistency with the contribution space. Ignored when hover_tooltips is set.
top_k_global : int
How many features to ship to the browser, ranked by global L2 importance. This cap applies everywhere: the bar chart, the feature-picker autocomplete, and the candidate set for top-feature ranking.
hover_images : NDArray[uint8] | None
Per-sample uint8 image array of shape (n_samples, H, W) or (n_samples, H, W, 3 | 4). When set, each tooltip shows the sample's image above the default index/group text. Mutually exclusive with hover_tooltips and hover_data.
hover_tooltips : str | None
Bokeh tooltip HTML template that fully replaces the default. May reference @index, @group (when group_names is provided), and any keys from hover_data.
hover_data : Mapping[str, Sequence[Any]] | None
Extra columns merged into the scatter ColumnDataSource so that hover_tooltips can reference them. Each value must have length n_samples. Keys must not collide with the reserved columns x, y, index, group, color_value, top_feature_group, top_feature_name, top_data_value, picker_data_value, sample_rank.
output_backend : glass_box_umap.plotting.bokeh._scatter.OutputBackend
Bokeh rendering backend for the scatter. Defaults to "webgl", which offloads rendering to the GPU and stays smooth at high sample counts. Switch to "canvas" if the GPU/driver/browser combination renders the plot incorrectly (e.g. blank canvas, wrong-sized points, or color banding); canvas is slower because it rasterizes on the CPU, but it works on any setup that supports Bokeh at all.
- Returns:
A Bokeh layout: color-by controls and the scatter on the left, the linked bar chart with its view toggle on the right. Pass it to bokeh.io.show() or bokeh.io.save().
- Return type:
bokeh.models.layouts.LayoutDOM
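The parameter descriptions for plot_embedding impose several constraints that can be checked before any plotting: array shapes, the feature_names length, the hover_images format and its mutual exclusivity with hover_tooltips/hover_data, and the reserved hover_data keys. A hypothetical pre-flight helper that collects every violation at once might look like this; input_problems is not part of glass_box_umap, and plot_embedding itself may validate differently.

```python
from __future__ import annotations

from collections.abc import Mapping, Sequence
from typing import Any

import numpy as np

# Reserved ColumnDataSource columns, copied from the hover_data description.
RESERVED_COLUMNS = frozenset({
    "x", "y", "index", "group", "color_value", "top_feature_group",
    "top_feature_name", "top_data_value", "picker_data_value", "sample_rank",
})


def input_problems(
    Z: np.ndarray,
    contributions: np.ndarray,
    feature_names: list[str] | None = None,
    hover_images: np.ndarray | None = None,
    hover_tooltips: str | None = None,
    hover_data: Mapping[str, Sequence[Any]] | None = None,
) -> list[str]:
    """Collect every documented constraint violation instead of stopping at the first."""
    problems: list[str] = []
    n = Z.shape[0]
    if Z.shape != (n, 2):
        problems.append("Z must have shape (n_samples, 2)")
    if contributions.ndim != 3 or contributions.shape[:2] != (n, 2):
        problems.append("contributions must have shape (n_samples, 2, n_features)")
    elif feature_names is not None and len(feature_names) != contributions.shape[2]:
        problems.append("feature_names length must equal contributions.shape[2]")
    if hover_images is not None:
        if hover_tooltips is not None or hover_data is not None:
            problems.append("hover_images is mutually exclusive with hover_tooltips and hover_data")
        # Accept (n, H, W) grayscale or (n, H, W, 3 | 4) RGB/RGBA.
        valid_shape = hover_images.shape[:1] == (n,) and (
            hover_images.ndim == 3
            or (hover_images.ndim == 4 and hover_images.shape[3] in (3, 4))
        )
        if hover_images.dtype != np.uint8 or not valid_shape:
            problems.append("hover_images must be uint8 with shape (n_samples, H, W) or (n_samples, H, W, 3 | 4)")
    for key, values in (hover_data or {}).items():
        if key in RESERVED_COLUMNS:
            problems.append(f"hover_data key {key!r} collides with a reserved column")
        if len(values) != n:
            problems.append(f"hover_data[{key!r}] must have length n_samples")
    return problems
```

Returning a list rather than raising on the first error lets a notebook user fix all input issues in one pass.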
- plot_embedding_static(Z: NDArray[floating], group_ids: NDArray[integer] | None = None, group_names: list[str] | None = None, cmap: matplotlib.colors.ListedColormap | None = None, marker_size: float = 2.0) -> matplotlib.figure.Figure
Static (matplotlib) scatter plot of a 2D embedding, optionally colored by group.
- Parameters:
Z : NDArray[floating]
Embedding coordinates with shape (n_samples, 2).
group_ids : NDArray[integer] | None
Integer group ID per point with shape (n_samples,). If None, points are not colored by group.
group_names : list[str] | None
Human-readable name for each group, indexed by group ID. If None and group_ids is provided, defaults to str(gid) for each group.
cmap : matplotlib.colors.ListedColormap | None
Colormap for the scatter plot. If None and group_ids is provided, a colormap is generated with one color per unique group.
marker_size : float
Size of the scatter plot markers.
- Return type:
matplotlib.figure.Figure
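The documented defaults of plot_embedding_static (str(gid) labels, one generated color per unique group) can be approximated directly with matplotlib. The sketch below is a rough stand-in, not the library's implementation: static_scatter_sketch is hypothetical, and sampling viridis for the generated colormap and adding a legend are this sketch's own choices. It builds a Figure without pyplot, so no GUI backend is needed.

```python
from __future__ import annotations

import matplotlib
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure


def static_scatter_sketch(
    Z: np.ndarray,
    group_ids: np.ndarray | None = None,
    group_names: list[str] | None = None,
    cmap: ListedColormap | None = None,
    marker_size: float = 2.0,
) -> Figure:
    fig = Figure()  # no pyplot, so no GUI backend is required
    ax = fig.subplots()
    if group_ids is None:
        # No grouping: a single scatter in matplotlib's default color.
        ax.scatter(Z[:, 0], Z[:, 1], s=marker_size)
        return fig
    unique = np.unique(group_ids)
    if cmap is None:
        # One color per unique group, sampled evenly from viridis.
        cmap = ListedColormap(matplotlib.colormaps["viridis"](np.linspace(0, 1, len(unique))))
    for i, gid in enumerate(unique):
        mask = group_ids == gid
        # group_names is indexed by group ID; fall back to str(gid).
        label = group_names[gid] if group_names is not None else str(gid)
        # An integer argument indexes the colormap's lookup table directly.
        ax.scatter(Z[mask, 0], Z[mask, 1], s=marker_size, color=cmap(i % cmap.N), label=label)
    ax.legend(markerscale=4)
    return fig
```

Drawing one scatter call per group is what makes per-group legend entries straightforward; a single call with a color array would need a hand-built legend.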