{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "2be08915", "metadata": { "execution": { "iopub.execute_input": "2026-05-11T22:09:07.821967Z", "iopub.status.busy": "2026-05-11T22:09:07.821837Z", "iopub.status.idle": "2026-05-11T22:09:07.871349Z", "shell.execute_reply": "2026-05-11T22:09:07.870906Z" }, "tags": [ "remove-cell", "remove-output" ] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "id": "61ec76cd", "metadata": {}, "source": [ "# Saving & Loading\n", "\n", "A fitted `GlassBoxUMAP` is a PyTorch model plus the state that connects raw input features to that model (PCA basis, feature means, training hyperparameters). The `save` and `load` methods round-trip all of this through a single file on disk." ] }, { "cell_type": "markdown", "id": "a30cf380", "metadata": {}, "source": [ "## Fit a model\n", "\n", "We'll use scikit-learn's digits dataset, which contains 1,797 handwritten digit images flattened to 64 features (8×8 grayscale pixels), standardized here with `StandardScaler`." 
] }, { "cell_type": "code", "execution_count": 2, "id": "ee85341e", "metadata": { "execution": { "iopub.execute_input": "2026-05-11T22:09:07.872840Z", "iopub.status.busy": "2026-05-11T22:09:07.872764Z", "iopub.status.idle": "2026-05-11T22:09:57.047504Z", "shell.execute_reply": "2026-05-11T22:09:57.046914Z" }, "tags": [ "remove-output" ] }, "outputs": [ { "data": { "text/plain": [ "GlassBoxUMAP(n_neighbors=15, min_dist=0.1, metric='euclidean', n_components=2, negative_sample_rate=5, repulsion_strength=1.0, pca_components=None, encoder_name='default', encoder_kwargs={}, lr=0.001, epochs=200, batch_size=10000, num_batches=None, num_workers=0, checkpoint_dir=None, restore_best_weights=True, random_state=None, quiet=True, extra_callbacks=[], _model=UMAPLightningModule(\n", " (encoder): DeepPReLUNet(\n", " (flatten): Flatten(start_dim=1, end_dim=-1)\n", " (model): Sequential(\n", " (0): Linear(in_features=64, out_features=128, bias=False)\n", " (1): VmapPReLU(num_parameters=1)\n", " (2): LayerNormDetached()\n", " (3): Dropout(p=0.0, inplace=False)\n", " (4): Linear(in_features=128, out_features=128, bias=False)\n", " (5): VmapPReLU(num_parameters=1)\n", " (6): LayerNormDetached()\n", " (7): Dropout(p=0.0, inplace=False)\n", " (8): Linear(in_features=128, out_features=128, bias=False)\n", " (9): VmapPReLU(num_parameters=1)\n", " (10): Dropout(p=0.0, inplace=False)\n", " (11): Linear(in_features=128, out_features=2, bias=False)\n", " )\n", " )\n", "), _pca=None, _mean=array([ 0.00000000e+00, -5.47818786e-07, -1.53904040e-07, -8.63056684e-08,\n", " 3.33016494e-08, 1.05046141e-07, -8.64715162e-08, 4.02671319e-08,\n", " -3.43394220e-07, 2.50386222e-07, -9.64553735e-08, -9.03854485e-08,\n", " 7.09815993e-08, 1.78051053e-07, -4.00681195e-08, 1.15826055e-07,\n", " 3.03098091e-07, 1.68597893e-07, -7.03182224e-08, -7.03845586e-08,\n", " 1.29209738e-07, 5.83773918e-09, 1.48596996e-07, -6.45153136e-07,\n", " -7.97879665e-08, -5.36673994e-08, -3.70497446e-08, 1.49393060e-07,\n", " 
1.26705482e-08, -1.28612697e-07, 2.61238824e-07, 2.33920034e-07,\n", " 0.00000000e+00, 5.10802201e-09, -5.98036607e-08, -8.41165146e-08,\n", " 5.27386668e-09, 2.08964526e-08, 6.98074189e-07, 0.00000000e+00,\n", " -3.54775324e-07, -1.37684410e-07, 2.35964080e-07, 4.52507720e-08,\n", " 7.85441259e-08, 2.65351794e-08, 5.48614807e-08, -5.20180720e-07,\n", " 2.73561113e-07, -4.41545382e-07, 3.04491188e-08, 1.63315725e-08,\n", " -1.65728778e-07, 8.82294700e-08, 2.62565578e-07, -2.45898178e-07,\n", " 1.65836568e-07, 2.76761909e-07, -2.48468780e-07, 3.68613030e-07,\n", " 7.01523772e-09, 2.16394383e-07, 3.51508191e-07, 5.54668134e-07],\n", " dtype=float32), _device=device(type='mps'))" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from glass_box_umap import GlassBoxUMAP\n", "\n", "def load_data():\n", " from sklearn.datasets import load_digits\n", " from sklearn.preprocessing import StandardScaler\n", " digits, _ = load_digits(return_X_y=True)\n", " return StandardScaler().fit_transform(digits)\n", "\n", "X = load_data()\n", "embedder = GlassBoxUMAP(quiet=True)\n", "embedder.fit(X)" ] }, { "cell_type": "markdown", "id": "d4d855cb", "metadata": {}, "source": [ "## Save the model\n", "\n", "`save` takes a path and writes a PyTorch checkpoint.\n", "\n", ":::{admonition} save API\n", ":class: api, dropdown\n", "\n", "From the {meth}`API docs `:\n", "\n", "```{eval-rst}\n", ".. 
automethod:: glass_box_umap.GlassBoxUMAP.save\n", " :noindex:\n", "```\n", ":::" ] }, { "cell_type": "code", "execution_count": 3, "id": "3e2a4f8c", "metadata": { "execution": { "iopub.execute_input": "2026-05-11T22:09:57.048683Z", "iopub.status.busy": "2026-05-11T22:09:57.048598Z", "iopub.status.idle": "2026-05-11T22:09:57.080820Z", "shell.execute_reply": "2026-05-11T22:09:57.080436Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "saved model (167.2 KiB) to embedder.pt\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "model_path = Path.cwd() / \"embedder.pt\"\n", "embedder.save(model_path)\n", "\n", "print(f\"saved model ({model_path.stat().st_size / 1024:.1f} KiB) to {model_path.name}\")" ] }, { "cell_type": "markdown", "id": "2861a83b", "metadata": {}, "source": [ "## Load the model\n", "\n", "`GlassBoxUMAP.load` is a classmethod that reconstructs the embedder from a checkpoint.\n", "\n", "\n", ":::{admonition} load API\n", ":class: api, dropdown\n", "\n", "From the {meth}`API docs `:\n", "\n", "```{eval-rst}\n", ".. automethod:: glass_box_umap.GlassBoxUMAP.load\n", " :noindex:\n", "```\n", ":::" ] }, { "cell_type": "code", "execution_count": 4, "id": "0606de75", "metadata": { "execution": { "iopub.execute_input": "2026-05-11T22:09:57.081950Z", "iopub.status.busy": "2026-05-11T22:09:57.081872Z", "iopub.status.idle": "2026-05-11T22:09:57.104393Z", "shell.execute_reply": "2026-05-11T22:09:57.103996Z" } }, "outputs": [], "source": [ "loaded = GlassBoxUMAP.load(model_path)" ] }, { "cell_type": "markdown", "id": "e540f416", "metadata": {}, "source": [ "The reloaded embedder is functionally identical to the original. 
We can confirm that with `==`:" ] }, { "cell_type": "code", "execution_count": 5, "id": "94579964", "metadata": { "execution": { "iopub.execute_input": "2026-05-11T22:09:57.105597Z", "iopub.status.busy": "2026-05-11T22:09:57.105532Z", "iopub.status.idle": "2026-05-11T22:09:57.126410Z", "shell.execute_reply": "2026-05-11T22:09:57.126031Z" } }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "loaded == embedder" ] }, { "cell_type": "markdown", "id": "c6bca615", "metadata": {}, "source": [ "```{note}\n", "`==` performs a *semantic* comparison: it checks that both embedders describe the same trained model (matching architecture, hyperparameters, learned weights, PCA fit, and centering vector) and ignores incidental runtime state such as device placement, logging verbosity, and DataLoader workers. As a result, two embedders loaded from the same file always compare equal, even if one is then moved to the CPU and the other is left on the GPU.\n", "```" ] }, { "cell_type": "markdown", "id": "27392f86", "metadata": {}, "source": [ "Here, `transform` and `compute_contributions` also produce bitwise-identical outputs from the original and the reloaded embedder:" ] }, { "cell_type": "code", "execution_count": 6, "id": "ff18ac88", "metadata": { "execution": { "iopub.execute_input": "2026-05-11T22:09:57.127431Z", "iopub.status.busy": "2026-05-11T22:09:57.127371Z", "iopub.status.idle": "2026-05-11T22:09:57.205766Z", "shell.execute_reply": "2026-05-11T22:09:57.205288Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "embeddings and contributions match exactly\n" ] } ], "source": [ "import numpy as np\n", "\n", "Z_original = embedder.transform(X)\n", "Z_loaded = loaded.transform(X)\n", "assert np.array_equal(Z_original, Z_loaded)\n", "\n", "C_original = embedder.compute_contributions(X)\n", "C_loaded = loaded.compute_contributions(X)\n", "assert np.array_equal(C_original, C_loaded)\n", "\n", 
"print(\"embeddings and contributions match exactly\")" ] }, { "cell_type": "markdown", "id": "78a6f034", "metadata": {}, "source": [ "## Caveats\n", "\n", "### Custom encoders\n", "\n", "If you fit a model with a custom encoder, that same encoder must be registered in the Python session that calls `load`, before `load` runs. See [Saving/Loading](custom_encoders.ipynb#saving-loading) in the Custom Encoders guide for details." ] } ], "metadata": { "kernelspec": { "display_name": "glass-box-umap (3.13.1)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 5 }