Source code for glass_box_umap.jacobian

import copy
from dataclasses import dataclass
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
from numpy.typing import NDArray
from torch.func import functional_call, jacrev, vmap  # pyright: ignore[reportPrivateImportUsage]


def compute_jacobian(
    model: nn.Module,
    x: torch.Tensor,
    batch_size: int = 1024,
) -> torch.Tensor:
    """Compute per-sample Jacobians of a model using ``vmap`` + ``jacrev``.

    The model is deep-copied and the copy is put in eval mode, so the module
    passed in by the caller is not modified. Each sample is differentiated on
    its own through ``functional_call``, and samples are processed in chunks
    of ``batch_size`` so only one chunk is differentiated at a time.

    Args:
        model: Encoder network (will be deep-copied and set to eval mode).
        x: Input tensor of shape ``(n, in_dim)``.
        batch_size: Number of samples per Jacobian batch. Must be positive.

    Returns:
        Jacobian tensor of shape ``(n, out_dim, in_dim)``. When ``n == 0`` an
        empty tensor with the same trailing dimensions is returned.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    model = copy.deepcopy(model).eval()
    # Build the parameter/buffer mapping once instead of on every call.
    state = {**dict(model.named_parameters()), **dict(model.named_buffers())}

    def func_single(x_single: torch.Tensor) -> torch.Tensor:
        # The model sees a batch of one; the batch axis is dropped again so
        # that jacrev differentiates a single sample's output.
        return functional_call(model, state, (x_single.unsqueeze(0),)).squeeze(0)

    n_samples = x.shape[0]
    if n_samples == 0:
        # torch.cat cannot concatenate an empty list, so probe the output
        # shape with one zero sample and return a correctly shaped empty result.
        with torch.no_grad():
            probe = func_single(x.new_zeros(x.shape[1:]))
        return probe.new_empty((0, *probe.shape, *x.shape[1:]))

    jac_fn = vmap(jacrev(func_single))

    results = []
    # The outer no_grad does not disable jacrev's own differentiation.
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            results.append(jac_fn(x[start : start + batch_size]))

    return torch.cat(results, dim=0)
def project_jacobian(jacobian: torch.Tensor, proj_tensor: torch.Tensor) -> torch.Tensor:
    """Re-express a Jacobian's input axis through a linear projection.

    Each per-sample Jacobian is right-multiplied by ``proj_tensor``. This turns
    a Jacobian taken with respect to a reduced input space (e.g. PCA
    components) into one expressed in terms of the original features.

    Args:
        jacobian: Jacobian tensor of shape ``(n, out_dim, in_dim_reduced)``.
        proj_tensor: Projection matrix of shape
            ``(in_dim_reduced, in_dim_original)``, e.g. ``pca.components_``.

    Returns:
        Jacobian of shape ``(n, out_dim, in_dim_original)``.
    """
    # n: sample, o: output dim, r: reduced input dim, f: original feature.
    contraction = "nor,rf->nof"
    return torch.einsum(contraction, jacobian, proj_tensor)
def reduce_contributions(
    contributions: NDArray[np.floating],
    method: Literal["l2"] = "l2",
) -> NDArray[np.floating]:
    """Reduce per-feature contributions across embedding dimensions.

    Args:
        contributions: Feature contributions with shape
            ``(n_samples, n_components, n_features)``.
        method: Reduction method. ``"l2"`` takes the L2 norm across components
            (axis 1).

    Returns:
        Reduced contributions with shape ``(n_samples, n_features)``.

    Raises:
        ValueError: If ``method`` is not a supported reduction.
    """
    if method == "l2":
        return np.linalg.norm(contributions, axis=1)
    # Fail loudly: falling through would silently return None to the caller.
    raise ValueError(f"Unknown reduction method {method!r}; expected 'l2'.")
@dataclass
class JacobianVerification:
    """Diagnostics from checking that ``f(x) ≈ J(x) @ x``.

    Attributes:
        z_range: Smallest and largest embedding value, as ``(min, max)``.
        reconstruction_range: Smallest and largest value of the Jacobian
            reconstruction, as ``(min, max)``.
        max_error: Largest absolute difference between the embedding and the
            reconstruction.
        mean_error: Average absolute difference between the embedding and the
            reconstruction.
        relative_error: Maximum error expressed relative to the embedding's
            magnitude.
    """

    # Value ranges, each stored as (min, max).
    z_range: tuple[float, float]
    reconstruction_range: tuple[float, float]
    # Absolute-error summaries.
    max_error: float
    mean_error: float
    # Scale-free error.
    relative_error: float
def verify_jacobian(
    Z: NDArray[np.floating],
    J: NDArray[np.floating],
    X: NDArray[np.floating],
) -> JacobianVerification:
    """Measure how closely ``J(x) @ x`` reproduces the embedding ``f(x)``.

    Args:
        Z: Embedding output, shape ``(n, out_dim)``.
        J: Jacobian, shape ``(n, out_dim, in_dim)``.
        X: Input data, shape ``(n, in_dim)``.

    Returns:
        A ``JacobianVerification`` holding the value ranges of the embedding
        and of the reconstruction, plus the max, mean and relative absolute
        error between them.
    """
    # Per-sample matrix-vector product: (n, out_dim, in_dim) x (n, in_dim) -> (n, out_dim).
    reconstruction = np.einsum("noi,ni->no", J, X)

    embedding_span = (float(Z.min()), float(Z.max()))
    reconstruction_span = (float(reconstruction.min()), float(reconstruction.max()))

    # Compute the elementwise error once and reuse it for every summary.
    abs_error = np.abs(Z - reconstruction)
    worst_error = abs_error.max()
    average_error = abs_error.mean()
    # The small constant keeps the ratio finite when the embedding is all zeros.
    embedding_scale = np.abs(Z).max() + 1e-8

    return JacobianVerification(
        z_range=embedding_span,
        reconstruction_range=reconstruction_span,
        max_error=float(worst_error),
        mean_error=float(average_error),
        relative_error=float(worst_error / embedding_scale),
    )