import threading
import warnings
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
import numpy as np
import pytorch_lightning as pl
import torch
from bokeh.application import Application
from bokeh.application.handlers import FunctionHandler
from bokeh.document import Document
from bokeh.embed import file_html
from bokeh.layouts import column, row
from bokeh.models import (
Button,
CategoricalColorMapper,
ColumnDataSource,
CustomJS,
HoverTool,
LinearScale,
LogScale,
Slider,
Span,
Toggle,
)
from bokeh.models.layouts import Column
from bokeh.plotting import figure
from bokeh.resources import INLINE
from bokeh.server.server import Server
from bokeh.util.warnings import BokehUserWarning
from numpy.typing import NDArray
from tornado.ioloop import IOLoop
from ._colors import pick_palette
from ._hover import _HOVER_IMAGE_KEY, _to_hover_uri
from ._scatter import OutputBackend
# Silence BokehUserWarning messages matching "reference already known".
# NOTE(review): presumably raised when the same model ends up referenced by
# more than one document/layout — confirm against the Bokeh version in use.
warnings.filterwarnings("ignore", category=BokehUserWarning, message="reference already known")
# Height passed to the loss figure below the scatter plot.
LOSS_PLOT_HEIGHT = 200
# Widths passed to the "Play", "Log Y" and "Save HTML" widgets respectively.
PLAY_BUTTON_WIDTH = 80
LOG_BUTTON_WIDTH = 80
SAVE_BUTTON_WIDTH = 120
# (stray Sphinx "[docs]" viewcode link marker removed; it is not part of the source)
class LiveEmbeddingCallback(pl.Callback):
    """Pytorch Lightning callback that serves a live-updating Bokeh scatter.

    Spins up a Bokeh server on a background thread, opens a browser tab, and
    streams a fresh embedding (via ``transform_fn``) to the page at the start
    of each training epoch. Each session keeps a per-frame history that the
    user can scrub through with a slider, play back with a button, or export
    to a self-contained HTML file. Training keeps running on the main thread;
    updates cross to the Bokeh event loop via
    ``Document.add_next_tick_callback``.

    The recorded history (``frames``, ``losses``) and the list of open
    ``sessions`` are touched from both the training thread and the Bokeh
    thread, so every access to them goes through ``self._lock``.

    Args:
        transform_fn: Callable that maps the high-dimensional ``X`` to a
            ``(n_samples, 2)`` array. Typically the embedder's ``transform``
            method.
        X: High-dimensional input fed to ``transform_fn`` at the start of
            each epoch.
        labels: Optional per-sample categorical labels for coloring.
        port: Port the Bokeh server listens on. ``0`` (default) lets the OS
            pick a free port, which avoids ``EADDRINUSE`` collisions when the
            callback is re-instantiated within the same process (e.g. a Jupyter
            kernel that already hosts a previous run's server).
        output_backend: Bokeh rendering backend for the scatter. Defaults to
            ``"webgl"``; switch to ``"canvas"`` if the GPU/driver/browser
            combination renders the plot incorrectly.
        hover_images: Optional uint8 image array of shape
            ``(n_samples, H, W)`` or ``(n_samples, H, W, 3 | 4)``. When set,
            each tooltip shows the sample's image above the index/label text.
        block_after_fit: When ``True`` (default), block at the end of training
            so the Bokeh server keeps serving until the user presses Ctrl-C.
            Set to ``False`` from interactive contexts (e.g. Jupyter) where
            the host process already keeps the server alive.
    """

    # Seconds ``on_fit_start`` waits for the server thread to report ready.
    _STARTUP_TIMEOUT_S = 5.0
    # Wake-up period of the keep-alive loop in ``on_train_end``.
    _BLOCK_POLL_S = 0.5

    def __init__(
        self,
        transform_fn: Callable[[torch.Tensor], NDArray[np.floating]],
        X: torch.Tensor,
        labels: list[str] | None = None,
        port: int = 0,
        output_backend: OutputBackend = "webgl",
        hover_images: NDArray[np.uint8] | None = None,
        block_after_fit: bool = True,
    ) -> None:
        super().__init__()
        self.transform_fn = transform_fn
        self.X = X
        self.labels = labels
        # Sorted unique labels give the color mapper a stable factor order.
        self.factors = sorted(set(labels)) if labels is not None else None
        self.port = port
        self.output_backend = output_backend
        self.hover_images = hover_images
        self.block_after_fit = block_after_fit
        # Convert the hover images once up front; every layout reuses them.
        self._hover_uris: list[str] | None = (
            [_to_hover_uri(img) for img in hover_images] if hover_images is not None else None
        )
        # One embedding per epoch start / one loss per epoch end.
        self.frames: list[NDArray[np.floating]] = []
        self.losses: list[float] = []
        self.latest_epoch = -1
        # One (document, frames_source, loss_source) entry per open browser session.
        self.sessions: list[tuple[Document, ColumnDataSource, ColumnDataSource]] = []
        # Guards frames / losses / sessions across the training and Bokeh threads.
        self._lock = threading.Lock()
        # Set by the server thread once it is serving (or has failed).
        self._ready = threading.Event()
        self._server_error: BaseException | None = None

    def _tooltip(self) -> str:
        """Return the HTML template used by the scatter's hover tool."""
        parts: list[str] = []
        if self._hover_uris is not None:
            parts.append(
                f"<img src='@{_HOVER_IMAGE_KEY}' style='display:block; margin-bottom:4px'/>"
            )
        body = "index: @index"
        if self.labels is not None:
            body += " · label: @label"
        parts.append(body)
        return f"<div>{''.join(parts)}</div>"

    def _build_layout(
        self,
        frames: list[NDArray[np.floating]],
        losses: list[float],
        is_live: bool,
    ) -> tuple[Column, ColumnDataSource, ColumnDataSource]:
        """Build the scatter + loss plot + controls for one document.

        Args:
            frames: Embedding history to preload; the last frame is displayed.
                Callers should pass a snapshot, not the live ``self.frames``.
            losses: Loss history to preload; plotted against its own index.
            is_live: When ``True`` the layout follows newly pushed frames and
                gets a "Save HTML" button; when ``False`` it is a static
                export that only supports scrubbing and playback.

        Returns:
            ``(layout, frames_source, loss_source)``. The two sources are the
            ones the epoch hooks later append to for live sessions.
        """
        n_pts = self.X.shape[0]
        if frames:
            initial = frames[-1]
            initial_epoch = len(frames) - 1
        else:
            # Nothing recorded yet: park every point at the origin.
            initial = np.zeros((n_pts, 2), dtype=np.float32)
            initial_epoch = -1

        # --- scatter -----------------------------------------------------
        source_data = dict(
            x=initial[:, 0].tolist(),
            y=initial[:, 1].tolist(),
            index=list(range(n_pts)),
        )
        if self.labels is not None:
            source_data["label"] = self.labels
        if self._hover_uris is not None:
            source_data[_HOVER_IMAGE_KEY] = list(self._hover_uris)
        display_source = ColumnDataSource(source_data)
        # One row per frame; each cell holds that frame's full coordinate list.
        frames_source = ColumnDataSource(
            dict(
                x=[f[:, 0].tolist() for f in frames],
                y=[f[:, 1].tolist() for f in frames],
            )
        )
        title_text = "Initializing" if initial_epoch < 0 else f"Epoch {initial_epoch}"
        p = figure(
            title=title_text,
            sizing_mode="stretch_both",
            output_backend=self.output_backend,  # type: ignore[arg-type]
        )
        p.add_tools(HoverTool(tooltips=self._tooltip()))
        if self.labels is not None and self.factors is not None:
            color_mapping = CategoricalColorMapper(
                factors=self.factors, palette=pick_palette(len(self.factors))
            )
            p.scatter(
                "x",
                "y",
                source=display_source,
                color={"field": "label", "transform": color_mapping},
                size=4,
            )
        else:
            p.scatter("x", "y", source=display_source, color="navy", size=4)

        # --- loss plot ---------------------------------------------------
        loss_source = ColumnDataSource(dict(epoch=list(range(len(losses))), loss=list(losses)))
        loss_plot = figure(
            sizing_mode="stretch_width",
            height=LOSS_PLOT_HEIGHT,
            x_axis_label="epoch",
            y_axis_label="loss",
            toolbar_location=None,
        )
        loss_plot.line(x="epoch", y="loss", source=loss_source, line_width=2)
        loss_plot.add_tools(
            HoverTool(
                tooltips=[("epoch", "@epoch"), ("loss", "@loss{0.0000}")],
                mode="vline",
            )
        )
        # Vertical marker on the loss plot tracking the displayed frame.
        epoch_span = Span(
            location=initial_epoch if initial_epoch >= 0 else 0,
            dimension="height",
            line_color="red",
            line_dash="dashed",
            line_width=1,
        )
        loss_plot.add_layout(epoch_span)

        # --- controls ----------------------------------------------------
        # The slider end never drops below 1 (presumably so start < end always
        # holds). The value starts at the end so the "was at end" check in
        # the live-follow callback below is true for a fresh session.
        initial_end = max(1, len(frames) - 1)
        slider = Slider(
            start=0,
            end=initial_end,
            value=initial_end,
            step=1,
            title="epoch",
            sizing_mode="stretch_width",
        )
        play_button = Button(label="Play", width=PLAY_BUTTON_WIDTH)
        log_toggle = Toggle(label="Log Y", width=LOG_BUTTON_WIDTH)
        linear_scale = LinearScale()
        log_scale = LogScale()
        log_toggle.js_on_change(
            "active",
            CustomJS(
                args=dict(
                    loss_plot=loss_plot,
                    linear_scale=linear_scale,
                    log_scale=log_scale,
                ),
                code="""
                loss_plot.y_scale = cb_obj.active ? log_scale : linear_scale;
                """,
            ),
        )
        # Scrubbing runs entirely in the browser: copy the chosen frame into
        # the displayed source; positions with no frame yet are ignored.
        scrub_callback = CustomJS(
            args=dict(
                display_source=display_source,
                frames_source=frames_source,
                title=p.title,
                epoch_span=epoch_span,
            ),
            code="""
            const i = cb_obj.value;
            const x = frames_source.data.x[i];
            const y = frames_source.data.y[i];
            if (x === undefined || y === undefined) return;
            display_source.data['x'] = x;
            display_source.data['y'] = y;
            display_source.change.emit();
            title.text = `Epoch ${i}`;
            epoch_span.location = i;
            """,
        )
        slider.js_on_change("value", scrub_callback)
        # Play/Pause toggles a browser-side interval that advances the slider.
        play_callback = CustomJS(
            args=dict(slider=slider),
            code="""
            if (window._play_interval) {
                clearInterval(window._play_interval);
                window._play_interval = null;
                cb_obj.label = 'Play';
                return;
            }
            if (slider.value >= slider.end) {
                slider.value = slider.start;
            }
            cb_obj.label = 'Pause';
            window._play_interval = setInterval(() => {
                if (slider.value >= slider.end) {
                    clearInterval(window._play_interval);
                    window._play_interval = null;
                    cb_obj.label = 'Play';
                    return;
                }
                slider.value = slider.value + 1;
            }, 50);
            """,
        )
        play_button.js_on_click(play_callback)
        controls: list = [log_toggle, play_button, slider]
        if is_live:
            # When a frame arrives, extend the slider; jump to the new frame
            # only if the user was already parked on the latest one.
            new_frame_callback = CustomJS(
                args=dict(
                    slider=slider,
                    frames_source=frames_source,
                    display_source=display_source,
                    title=p.title,
                    epoch_span=epoch_span,
                ),
                code="""
                const n = frames_source.data.x.length;
                if (n === 0) return;
                const new_end = Math.max(1, n - 1);
                const was_at_end = slider.value === slider.end;
                slider.end = new_end;
                if (was_at_end) {
                    slider.value = new_end;
                    const idx = Math.min(new_end, n - 1);
                    display_source.data['x'] = frames_source.data.x[idx];
                    display_source.data['y'] = frames_source.data.y[idx];
                    display_source.change.emit();
                    title.text = `Epoch ${idx}`;
                    epoch_span.location = idx;
                }
                """,
            )
            frames_source.js_on_change("data", new_frame_callback)
            # Saving needs the Python side, so only live layouts offer it.
            save_button = Button(label="Save HTML", width=SAVE_BUTTON_WIDTH)
            save_button.on_click(self._save_html)
            controls.append(save_button)

        layout = column(
            p,
            loss_plot,
            row(*controls, sizing_mode="stretch_width"),
            sizing_mode="stretch_both",
            styles={
                "max-width": "900px",
                "min-height": "600px",
                "max-height": "1100px",
            },
        )
        return layout, frames_source, loss_source

    def _save_html(self) -> None:
        """Export the history recorded so far to a standalone HTML file.

        The file is written to the current working directory as
        ``glassbox_live_<timestamp>.html`` and its resolved path is printed.
        """
        with self._lock:
            frames_snapshot = list(self.frames)
            losses_snapshot = list(self.losses)
        layout, _, _ = self._build_layout(frames_snapshot, losses_snapshot, is_live=False)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = Path(f"glassbox_live_{timestamp}.html").resolve()
        html = file_html(layout, INLINE, title="Glass Box UMAP")
        # Write UTF-8 explicitly: the default is the locale encoding, which
        # can fail on (or mis-encode) non-ASCII characters in the page.
        filename.write_text(html, encoding="utf-8")
        print(f"Saved {filename}")

    def _doc_handler(self, doc: Document) -> None:
        """Populate a new session's document and register it for live updates."""
        # Snapshot the history and register the session atomically with
        # respect to the epoch hooks. Otherwise a frame appended while the
        # layout is being built could be delivered to this session twice,
        # or not at all.
        with self._lock:
            layout, frames_source, loss_source = self._build_layout(
                list(self.frames), list(self.losses), is_live=True
            )
            doc.add_root(layout)
            entry = (doc, frames_source, loss_source)
            self.sessions.append(entry)

        def cleanup(_ctx: object) -> None:
            # Stop pushing updates to a session whose browser tab went away.
            with self._lock:
                if entry in self.sessions:
                    self.sessions.remove(entry)

        doc.on_session_destroyed(cleanup)

    def _start_server(self) -> None:
        """Run the Bokeh server; thread target that blocks in the IO loop.

        Any exception is stored in ``self._server_error`` and ``self._ready``
        is set, so ``on_fit_start`` can re-raise it on the training thread.
        """
        try:
            io_loop = IOLoop()
            server = Server(
                {"/": Application(FunctionHandler(self._doc_handler))},
                io_loop=io_loop,
                port=self.port,
                session_token_expiration=86400,
            )
            # With port=0 the OS picks the port; record the one actually bound.
            self.port = server.port
            server.start()
            io_loop.add_callback(server.show, "/")
            # Signal readiness from inside the loop, i.e. once it is running.
            io_loop.add_callback(self._ready.set)
            io_loop.start()
        except BaseException as exc:
            self._server_error = exc
            self._ready.set()

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Start the Bokeh server thread and wait briefly for it to come up.

        Raises:
            BaseException: Whatever the server thread failed with during
                startup, re-raised here on the training thread.
        """
        threading.Thread(target=self._start_server, daemon=True).start()
        ready = self._ready.wait(timeout=self._STARTUP_TIMEOUT_S)
        if self._server_error is not None:
            raise self._server_error
        if not ready:
            # Do not announce a URL that may not be served (and may still
            # carry the placeholder port); training continues regardless.
            warnings.warn(
                f"Bokeh server did not report ready within {self._STARTUP_TIMEOUT_S:g} s; "
                "the live embedding view may be unavailable.",
                RuntimeWarning,
                stacklevel=2,
            )
            return
        print(f"Live embedding serving at http://localhost:{self.port}/")

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Record a fresh embedding and push it to every open session."""
        emb = self.transform_fn(self.X)
        new_x = emb[:, 0].tolist()
        new_y = emb[:, 1].tolist()
        with self._lock:
            self.frames.append(emb)
            self.latest_epoch = trainer.current_epoch
            open_sessions = list(self.sessions)
        for doc, frames_source, _loss_source in open_sessions:
            # Defaults bind the per-iteration values; the update itself runs
            # later on the Bokeh event loop, not on this thread.
            def update(frames_source=frames_source, new_x=new_x, new_y=new_y) -> None:
                frames_source.data = {
                    "x": [*frames_source.data["x"], new_x],
                    "y": [*frames_source.data["y"], new_y],
                }

            doc.add_next_tick_callback(update)

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Record the epoch loss (metric ``"loss_epoch"``) and push it to sessions.

        Does nothing when the metric has not been logged.
        """
        loss = trainer.callback_metrics.get("loss_epoch")
        if loss is None:
            return
        new_loss = float(loss)
        with self._lock:
            # Use the index into ``self.losses`` as the x-value, exactly as
            # ``_build_layout`` does, so live sessions, late-opened sessions
            # and saved HTML all show the same loss curve.
            new_epoch = len(self.losses)
            self.losses.append(new_loss)
            open_sessions = list(self.sessions)
        for doc, _frames_source, loss_source in open_sessions:

            def update(loss_source=loss_source, new_loss=new_loss, new_epoch=new_epoch) -> None:
                loss_source.data = {
                    "epoch": [*loss_source.data["epoch"], new_epoch],
                    "loss": [*loss_source.data["loss"], new_loss],
                }

            doc.add_next_tick_callback(update)

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Announce that training finished and optionally keep the process alive.

        With ``block_after_fit=True`` this never returns normally; it blocks
        until the user interrupts with Ctrl-C (``KeyboardInterrupt``
        propagates to the caller).
        """
        if not self.block_after_fit:
            print(f"Training done. Server still serving at http://localhost:{self.port}/.")
            return
        print(
            f"Training done. Server still serving at http://localhost:{self.port}/. "
            "Press Ctrl-C to exit."
        )
        # Poll instead of waiting without a timeout: an untimed wait is not
        # reliably interruptible by Ctrl-C on every platform (notably
        # Windows), whereas waking periodically lets KeyboardInterrupt land.
        keep_alive = threading.Event()
        while not keep_alive.wait(timeout=self._BLOCK_POLL_S):
            pass