
SpatialDINO Lessons: 3D SSL on Lattice Light-Sheet Microscopy

2025-09-01

A detailed log from building SpatialDINO — a 3D self-supervised vision transformer for label-free segmentation and tracking of subcellular dynamics in lattice light-sheet microscopy.

SpatialDINO is a 3D self-supervised foundation model for lattice light-sheet microscopy (LLSM). It learns volumetric representations of live-cell dynamics without a single human label, then drives downstream segmentation and multi-frame tracking of endosomes, viral particles, and other sub-resolution organelles. It beat a prior approach co-led by Nobel laureate Eric Betzig (Chemistry, 2014) on downstream subcellular structure prediction. The work spanned November 2024 through June 2025 with Alex Lavaee, Tom Kirchhausen, and Jose Inacio. Code and weights: github.com/kirchhausenlab/cell_interactome.

This is the long-form engineering log — what we built, what broke, what we learned, grounded in the actual repository.

Why label-free, why 3D-native

LLSM produces 4D (T+ZYX) movies of living cells. Many of the structures we care about are smaller than the diffraction-limited PSF, whose size is set by roughly half the emission wavelength (a few hundred nm), so particles appear as non-Gaussian blobs with no clean boundary. A typical movie has ~100 timepoints × ~80–100 z-frames with 40–80 endosomes per frame. At 10 seconds per annotation (one endosome in one z-frame), one movie takes ~1,500 hours to label. We have 300+ movies. That math is why SAM-style adaptations (MicroSAM, μSAM) don't work for us: they're 2D, scale-sensitive, and need supervision we cannot afford to produce.
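The labeling estimate is easy to reproduce. This is a minimal sketch. It assumes one 10 s annotation per endosome per z-frame per timepoint, which is our reading of the figures above. At the midpoints of the stated ranges (90 z-frames, 60 endosomes) it gives exactly 1,500 hours per movie. The ends of the ranges give roughly 890 and 2,220 hours.

```python
def annotation_hours(timepoints: int, z_frames: int, objects_per_frame: int,
                     seconds_per_annotation: float) -> float:
    """Hours to annotate one movie, assuming one annotation per object per z-frame per timepoint."""
    total_annotations = timepoints * z_frames * objects_per_frame
    return total_annotations * seconds_per_annotation / 3600.0


# Midpoints of the ranges in the text: 100 timepoints, 90 z-frames, 60 endosomes, 10 s each.
per_movie = annotation_hours(100, 90, 60, 10)
print(f"{per_movie:,.0f} h per movie")            # 1,500 h per movie
print(f"{per_movie * 300:,.0f} h for 300 movies")  # 450,000 h for 300 movies
```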

The other forcing function: existing 2.5D adaptations of DINOv2 (apply per-slice, stack outputs) introduce inter-plane discontinuities and lose volumetric context. Endosomes touch and overlap in 3D; their identity is not recoverable from any single slice. So we built SpatialDINO to be 3D-native end to end — patch embedding, augmentation, attention, masking, segmentation, tracking.

Data pipeline

Filtering. Terabytes of raw 4D acquisitions get triaged using max-z-projections of the first frame per movie/channel. A human eyeballs each projection and accepts or rejects it, and valid paths get logged. This semi-automated step is the entire quality-control budget; it's the only way to cover 300+ movies without burning months on inspection.

KMeans cropping. The single biggest data-side win. Random crops on sparse 3D volumes mostly hit background. Instead, we threshold the volume at the 99.75th–99.9th intensity percentile to get a pseudo-binary mask, sample points from the mask, fit KMeans with k=4 clusters, and use the centroids as crop centers. Crop sizes are jittered per axis between 0.3 and 0.6 of the volume's dimensions, which also varies the aspect ratio. Crops contain 3–5× more biological signal than random ones, so the model spends its capacity on foreground rather than noise. This was implemented as a MONAI-compatible KMeansSpatialCropSamplesd transform (data/transforms.py:984-1202).
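The cropping idea fits in a few lines of numpy. This is a minimal sketch, not the repository's KMeansSpatialCropSamplesd. It uses plain Lloyd's iterations instead of a library KMeans. The percentile default, the sample cap, and the helper names `kmeans_crop_centers` / `crop_around` are hypothetical.

```python
import numpy as np


def kmeans_crop_centers(volume, percentile=99.8, k=4, n_samples=2000,
                        n_iter=20, seed=0):
    """Return up to k crop centers (z, y, x) from the brightest voxels of a 3D volume."""
    rng = np.random.default_rng(seed)
    # Pseudo-binary foreground mask from a high intensity percentile.
    mask = volume >= np.percentile(volume, percentile)
    coords = np.argwhere(mask).astype(np.float64)          # (N, 3) voxel coordinates
    if len(coords) > n_samples:
        coords = coords[rng.choice(len(coords), n_samples, replace=False)]
    k = min(k, len(coords))
    # Initialise centroids at random foreground points, then run Lloyd's algorithm.
    centers = coords[rng.choice(len(coords), k, replace=False)]
    for _ in range(n_iter):
        dists = np.linalg.norm(coords[:, None, :] - centers[None, :, :], axis=-1)
        labels = dists.argmin(axis=1)
        for c in range(k):
            members = coords[labels == c]
            if len(members):                                # keep empty clusters in place
                centers[c] = members.mean(axis=0)
    return np.rint(centers).astype(int)


def crop_around(volume, center, frac, rng):
    """Crop a box around `center` whose per-axis size is a random fraction in [frac[0], frac[1])."""
    size = [max(1, int(s * rng.uniform(*frac))) for s in volume.shape]
    starts = [int(np.clip(c - sz // 2, 0, s - sz))
              for c, sz, s in zip(center, size, volume.shape)]
    return volume[tuple(slice(st, st + sz) for st, sz in zip(starts, size))]
```

Usage looks like `centers = kmeans_crop_centers(vol)` followed by one `crop_around(vol, c, (0.3, 0.6), rng)` per center. Clipping the start index keeps crops near the border inside the volume rather than padding them.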

Format & scale. 180k volumes / 2.4 TB packed into 2,596 WebDataset tar shards (~950 MB each, ~60 experiments per shard). Distributed across 3 nodes × 256 CPU cores; ~12–14 hours end-to-end. The shard layout matters — see file-system notes below.
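WebDataset's `ShardWriter` handles shard writing in practice. The sketch below shows the layout using only the standard-library `tarfile` module. Files that share a basename form one sample. A shard rolls over once a size cap is reached (~950 MB by default, matching the text). The key names, the `write_shards` helper, and the size accounting are assumptions. The accounting counts payload bytes only and ignores the 512-byte tar headers.

```python
import io
import tarfile
from pathlib import Path


def write_shards(samples, out_dir, max_bytes=950 * 1024**2, prefix="shard"):
    """Pack (key, {ext: bytes}) samples into size-capped tar shards.

    Files sharing a basename form one sample (e.g. vol_000017.npy + vol_000017.json).
    A new shard starts when adding the next sample would exceed max_bytes of payload.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    paths, tar, used = [], None, 0
    for key, files in samples:
        size = sum(len(b) for b in files.values())
        if tar is None or (used + size > max_bytes and used > 0):
            if tar is not None:
                tar.close()
            path = out_dir / f"{prefix}-{len(paths):06d}.tar"
            tar, used = tarfile.open(path, "w"), 0
            paths.append(path)
        for ext, payload in files.items():
            info = tarfile.TarInfo(name=f"{key}.{ext}")
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
        used += size
    if tar is not None:
        tar.close()
    return paths
```

Keeping whole samples inside one shard is what lets a reader stream each tar sequentially without seeking.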

Custom MONAI transforms (open-sourced). Histogram normalization (re-implementation of ImageJ's algorithm: 65,536 bins on 16-bit data, threshold 1/5000 × total voxels for robust min/max — analogous to CLAHE but on the global histogram), chunked trilinear interpolation (separable kernel processed in windows so a 200×1024×1024 volume doesn't need 150 GB of intermediate tensors), and the KMeans cropping above.
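The histogram normalization can be sketched in numpy, following our understanding of ImageJ's Auto contrast. Bins whose count exceeds total/10 are treated as saturated background and skipped. The display range is then bounded by the first bin from each end whose count exceeds total/5000. This is an illustration, not the repository transform. The function names are hypothetical.

```python
import numpy as np


def imagej_auto_minmax(volume, auto_threshold=5000, n_bins=65536):
    """Robust (min, max) of a 16-bit volume in the style of ImageJ's Auto contrast."""
    data = np.asarray(volume).ravel()
    hist = np.bincount(data.astype(np.int64), minlength=n_bins)
    total = data.size
    limit, threshold = total / 10, total / auto_threshold
    counts = np.where(hist > limit, 0, hist)           # ignore dominant (saturated) bins
    above = np.flatnonzero(counts > threshold)
    if above.size == 0:                                 # degenerate histogram: fall back
        return int(data.min()), int(data.max())
    return int(above[0]), int(above[-1])


def histogram_normalize(volume, **kwargs):
    """Rescale to [0, 1] using the robust range, clipping outliers such as hot pixels."""
    lo, hi = imagej_auto_minmax(volume, **kwargs)
    out = (volume.astype(np.float32) - lo) / max(hi - lo, 1)
    return np.clip(out, 0.0, 1.0)
```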

Architecture: SpatialDINO

A 3D-native DINOv2 with several deliberate divergences. Final config: ViT-Small, 21M parameters, patch size 8, 2 global crops at 112³, 8 local crops at 48³.

3D Conv patch embed. A Conv3d directly turns volumes into token sequences. This is more than convenience — the conv carries the spatial inductive bias that 2D DINOv2 leans on positional encodings for.
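A Conv3d with kernel size equal to stride equal to patch size is the same operation as cutting non-overlapping p³ patches and applying one shared linear projection. The numpy sketch below shows that equivalence for a single-channel volume. It assumes no padding, so trailing voxels that don't fill a whole patch are dropped. It is an illustration, not the repository's module.

```python
import numpy as np


def patch_embed_3d(volume, weight, bias, patch=8):
    """Non-overlapping 3D patch embedding of a (Z, Y, X) volume.

    Equivalent to Conv3d(1, D, kernel_size=patch, stride=patch):
    weight has shape (patch**3, D), bias has shape (D,).
    """
    z, y, x = (s // patch for s in volume.shape)
    v = volume[: z * patch, : y * patch, : x * patch]
    # (z, p, y, p, x, p) -> (z, y, x, p, p, p) -> (z*y*x, p^3)
    patches = v.reshape(z, patch, y, patch, x, patch).transpose(0, 2, 4, 1, 3, 5)
    tokens = patches.reshape(z * y * x, patch**3) @ weight + bias
    return tokens, (z, y, x)
```

For a 112³ global crop at patch 8 this yields a 14×14×14 grid of 2,744 tokens.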

Patch size 8, not 14. DINOv2's patch 14 was too coarse: features over-blurred small structures (viruses, sub-resolution puncta) and segmentation collapsed. Patch 8 lives at the Pareto frontier — patch 4 helps but is computationally prohibitive.
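The trade-off is mostly token arithmetic. A 112³ global crop gives 512 tokens at patch 14, 2,744 at patch 8, and 21,952 at patch 4. Pairwise attention scales with the square of the token count, so going from patch 8 to patch 4 multiplies attention cost by about 64×. This ignores the MLP, which scales linearly.

```python
def tokens_per_crop(crop: int, patch: int) -> int:
    """Tokens for a cubic crop of side `crop` with non-overlapping cubic patches."""
    return (crop // patch) ** 3


baseline = tokens_per_crop(112, 8)
for p in (14, 8, 4):
    n = tokens_per_crop(112, p)
    print(f"patch {p:>2}: {n:>6} tokens, attention cost x{(n / baseline) ** 2:.2f} vs patch 8")
```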

No positional encodings (NoPE). Maybe the cleanest result. With learned 2D-style positional encodings, the model collapsed reproducibly around 10k iterations: high-norm patch tokens, grid-like artifacts in feature maps, dead representations. We tried sincos, learned, and relative — all collapsed. Removing positional encodings entirely fixed it. The encoder has the option plumbed through (models/layers/encoder.py:164-192):

if self.pos_embed_type == "sincos":   ...
elif self.pos_embed_type == "learned": ...
else:
    self.pos_embed = None  # NoPE

The Conv3d already encodes spatial structure into tokens, so attention recovers position implicitly — analogous to NoPE in NLP (Haviv 2022, Kazemnejad 2023). To our knowledge this is the first formal NoPE result for 3D ViTs.
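One property is worth keeping in mind when reading the NoPE result. With no positional encoding, softmax self-attention is permutation-equivariant: permuting the input tokens permutes the outputs in the same way. Any spatial signal therefore has to arrive through token content, which is what the Conv3d patch embed supplies from each local patch. A small single-head numpy check:

```python
import numpy as np


def self_attention(x, wq, wk, wv):
    """Single-head softmax self-attention with no positional encoding."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)
    return attn @ v


rng = np.random.default_rng(0)
n_tokens, dim = 27, 16                             # e.g. a 3x3x3 token grid
x = rng.normal(size=(n_tokens, dim))
wq, wk, wv = (rng.normal(size=(dim, dim)) for _ in range(3))
perm = rng.permutation(n_tokens)

out = self_attention(x, wq, wk, wv)
out_perm = self_attention(x[perm], wq, wk, wv)
print(np.allclose(out[perm], out_perm))           # True: permuting inputs permutes outputs
```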

Test-time register tokens. Borrowed from Jiang et al. 2025 — append a single zero-initialized register token at inference. Marginal but consistent quality bump; 2/4/8 tokens added cost without benefit. Train-time registers reduced but did not prevent collapse, which is what pushed us to NoPE.

Unified head, fewer prototypes. Merged DINO and iBOT heads, dropped output dimension from 128k → 32k prototypes. Saved ~12 GB of GPU memory at no representation-quality cost. The DINO loss enforces global cls-level consistency; iBOT enforces masked-patch consistency; KoLeo regularizes the cls embedding spread. Killing either DINO or iBOT degraded features — they are genuinely complementary.
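The two global-level terms are short to write down. This numpy sketch uses DINO-family temperatures (student 0.1, teacher 0.04). Those are common defaults, not confirmed SpatialDINO settings. The sketch uses the original DINO mean centering, whereas DINOv2 replaces it with Sinkhorn-Knopp. The center EMA update, the teacher-temperature warmup, and the iBOT term are omitted. The iBOT term is the same cross-entropy applied to masked patch tokens.

```python
import numpy as np


def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))


def dino_loss(student_logits, teacher_logits, center, t_student=0.1, t_teacher=0.04):
    """Cross-entropy between the centered, sharpened teacher and the student.

    Logits are (batch, n_prototypes) outputs of the projection head.
    """
    p_teacher = np.exp(log_softmax((teacher_logits - center) / t_teacher))
    log_p_student = log_softmax(student_logits / t_student)
    return float(-(p_teacher * log_p_student).sum(axis=-1).mean())


def koleo_loss(cls_embeddings, eps=1e-8):
    """Kozachenko-Leonenko regularizer: push each L2-normalized embedding
    away from its nearest neighbor in the batch."""
    x = cls_embeddings / np.linalg.norm(cls_embeddings, axis=1, keepdims=True)
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)                    # ignore self-distances
    return float(-np.log(d.min(axis=1) + eps).mean())
```

KoLeo penalizes small nearest-neighbor distances, so a batch of near-identical embeddings scores much worse than a spread-out one. That is why switching it off showed up as collapse.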

Block masking, not random voxel masking. Random voxel masking is trivial in 3D (neighbors interpolate). 3D block masking with foreground-biased sampling (15–50% mask ratio) makes the pretext task hard enough to learn from.
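Here is a minimal numpy sketch of foreground-biased block masking on the token grid (14³ for a 112³ crop at patch 8). It assumes a per-token foreground score, such as mean patch intensity. The block-size range and the helper name are hypothetical. The mask can overshoot the target ratio by up to one block.

```python
import numpy as np


def block_mask_3d(fg_score, mask_ratio=0.3, min_block=2, max_block=4, seed=0,
                  max_tries=1000):
    """Boolean (Z, Y, X) token mask built from cuboid blocks.

    Block centers are drawn with probability proportional to a per-token
    foreground score, so masked blocks land on structure, not background.
    Assumes every grid dimension is at least max_block.
    """
    rng = np.random.default_rng(seed)
    grid = fg_score.shape
    mask = np.zeros(grid, dtype=bool)
    target = int(round(mask_ratio * mask.size))
    probs = fg_score.ravel() + 1e-6                # small floor so background is still possible
    probs = probs / probs.sum()
    for _ in range(max_tries):
        if mask.sum() >= target:
            break
        center = np.unravel_index(rng.choice(mask.size, p=probs), grid)
        size = rng.integers(min_block, max_block + 1, size=3)
        lo = [int(np.clip(c - s // 2, 0, g - s)) for c, s, g in zip(center, size, grid)]
        mask[tuple(slice(l, l + s) for l, s in zip(lo, size))] = True
    return mask
```

Masking contiguous blocks removes whole neighborhoods at once, so a masked token cannot be recovered by interpolating from its immediate neighbors.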

Augmentation stack

The ground rule: DINOv2's photometric jitter assumes RGB natural images, and none of it works on grayscale fluorescence. We replaced color jitter with random contrast (γ ∈ [0.25, 2.0]), random 90° rotations restricted to the Y-X plane (rotations that carry Z into the lateral plane violate the physical anisotropy), random flips on all axes, and Gaussian noise on global crops only. Cropping is KMeans, not random. That's the entire stack: minimal and microscopy-honest.
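Here is a numpy sketch of an augmentation pass in this spirit. It keeps Z out of every rotation by rotating only in the lateral Y-X plane. It assumes crops already normalized to [0, 1]. The noise level and the function name are hypothetical.

```python
import numpy as np


def augment(crop, rng, is_global, gamma_range=(0.25, 2.0), noise_std=0.05):
    """Microscopy-style augmentation of a (Z, Y, X) crop normalized to [0, 1]."""
    # Gamma contrast in place of color jitter.
    out = np.clip(crop, 0.0, 1.0) ** rng.uniform(*gamma_range)
    # 90-degree rotations in the Y-X plane only; Z keeps its orientation.
    out = np.rot90(out, k=int(rng.integers(4)), axes=(1, 2))
    # Independent flips along each axis.
    for axis in range(3):
        if rng.random() < 0.5:
            out = np.flip(out, axis=axis)
    # Additive Gaussian noise on global crops only.
    if is_global:
        out = out + rng.normal(0.0, noise_std, size=out.shape)
    return np.ascontiguousarray(out)
```

For cubic crops the Y-X rotation leaves the shape unchanged. For non-cubic crops it swaps the Y and X extents.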

Pre-training infrastructure

24× A100-40GB across 3 DGX nodes, 250k iterations, ~4 days. BF16 everywhere except the DINO/iBOT heads, which stay in FP32. PyTorch DDP (we evaluated FSDP, but the extra complexity didn't justify the marginal win at 21M params). Effective batch size 336 (14 per GPU).

File system was the hidden boss fight. NFS over TCP/IP gave us 500 ms+ batch-load latencies: every GPU contended through one server, request queues saturated, and the cache thrashed. Switching to RDMA over InfiniBand halved sequential read latency. The bigger lever was WebDataset: streaming sequential reads from POSIX tar shards instead of millions of small file opens. Sustained throughput reached 1.8 GB/s across the cluster, ~60% above per-file access. Each worker claims exclusive shards: no contention, complete coverage.
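Exclusive shard ownership comes down to a strided split over (rank, worker) pairs. The sketch below takes 24 ranks from the text and assumes 8 dataloader workers per rank. The helper name is hypothetical.

```python
def shards_for_worker(shards, rank, world_size, worker_id, num_workers):
    """Disjoint shard subset for one dataloader worker on one rank.

    Each (rank, worker) pair gets a unique global id and takes every
    (world_size * num_workers)-th shard starting there, so no two workers
    open the same tar file and every shard is read exactly once per pass.
    """
    global_id = rank * num_workers + worker_id
    stride = world_size * num_workers
    return shards[global_id::stride]
```

With 2,596 shards over 24 × 8 = 192 workers, 100 workers get 14 shards and 92 get 13. The per-rank sample counts therefore differ slightly, so in practice the epoch length is usually fixed by a sample count rather than by exhausting the shards.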

Caught a Rendezvous backend bug in PyTorch that surfaced as InfiniBand RDMA stalls — filed pytorch#144779. Cost a week.

Collapse history (the iteration log)

Roughly chronological — these are the failure modes that taught us the design:

  • Tried ADOPT in place of Adam, then reverted. Adam was fine.
  • Patch size 8 with global crops [48, 224, 224] and local [16, 48, 48] — asymmetric Z hurt feature quality. Cubic crops won.
  • Higher LR → instant collapse.
  • Shared vs separate iBOT head — kept the unified version after merging.
  • KoLeo on/off/back-on. Off → collapse. KoLeo proto (regularizing the prototype matrix instead of latents) was tried; reverted.
  • torch.clamp on losses to prevent NaN cascades from teacher temperature warmup.
  • Removed DINO head normalization (per Govindarajan 2024 partial-prototype collapse paper) — short-lived experiment.
  • Added a DeSD-style separate prediction head when the cls token started dominating; eventually folded back.
  • The recurring 10k-iter collapse was the signal that pushed us through registers → NoPE.

SINDER for 3D — singular defect repair

Even after NoPE, deep layers of the pretrained ViT exhibited the singular-defect artifacts described by Wang et al. 2024 for 2D DINOv2: high-norm tokens that appear regardless of the input image. We extended SINDER to 3D (models/sinder/).

Defect direction per block. Linearize the attention V-branch as a composed matrix A = A4·A3·A2·A1, and the MLP non-linearity via 100k-sample least-squares regression (GELU and SwiGLU both supported). Compose blocks cumulatively, Gᵢ = Eᵢ Gᵢ₋₁ with Eᵢ the linearized map of block i, and take the leading left singular vector of each Gᵢ as that layer's defect direction.

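for a in anomaly_as:                           # per-block linearized maps E_i, in depth order
    aaa = a @ aaa                              # cumulative product G_i = E_i G_{i-1}
    u, _, _ = LA.svd(aaa)                      # LA: a linalg module such as torch.linalg (assumed)
    accumulative_anomalies.append(u[:, 0])     # leading left singular vector = defect direction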
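A self-contained numpy version of the two steps above follows: a least-squares linearization of a GELU MLP, and the cumulative SVD. It is a sketch under stated assumptions, not the repository code. Inputs are sampled from N(0, I). GELU uses its tanh approximation. Maps use the column-vector convention. The function names are hypothetical.

```python
import numpy as np


def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def linearize_mlp(w1, w2, n_samples=100_000, seed=0):
    """Least-squares linear map E with mlp(x) ~= E @ x, for mlp(x) = gelu(x @ w1) @ w2.

    The N(0, I) sampling distribution is an assumption of this sketch.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_samples, w1.shape[0]))
    y = gelu(x @ w1) @ w2
    m, *_ = np.linalg.lstsq(x, y, rcond=None)       # solves x @ m ~= y
    return m.T                                       # column-vector convention


def defect_directions(block_maps):
    """Leading left singular vector of each cumulative product G_i = E_i G_{i-1}."""
    g = np.eye(block_maps[0].shape[0])
    directions = []
    for e in block_maps:
        g = e @ g
        u, _, _ = np.linalg.svd(g)                  # singular values sorted descending
        directions.append(u[:, 0])
    return directions
```

With zero-mean Gaussian inputs, the fitted map for a GELU MLP comes out close to 0.5·(W1W2)ᵀ, because GELU's average slope under a symmetric input distribution is exactly one half.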

SVD reparameterization. Every linear becomes U @ diag(S + ε) @ Vᵀ. U, S, V freeze; only ε (zero-init) trains. For QKV, only the value branch's ε learns ("no query, no key" — defects live in V representations, not attention patterns). Trainable parameters drop to under 0.1% of the full model.
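The reparameterization is easy to state concretely. In this numpy sketch (`SVDLinear` is a hypothetical name), U, S and Vᵀ are frozen and only ε is trainable. Because ε is zero-initialized, the layer starts out exactly equal to the original weight.

```python
import numpy as np


class SVDLinear:
    """W = U @ diag(S + eps) @ Vt with U, S, Vt frozen and only eps trainable."""

    def __init__(self, weight):
        self.u, self.s, self.vt = np.linalg.svd(weight, full_matrices=False)
        self.eps = np.zeros_like(self.s)             # zero-init: starts exactly at W

    def weight(self):
        # Scaling the columns of U is the same as U @ diag(S + eps).
        return (self.u * (self.s + self.eps)) @ self.vt

    def __call__(self, x):
        return x @ self.weight().T                   # y = W x for a batch of row vectors

    @property
    def n_trainable(self):
        return self.eps.size
```

The parameter count can be checked against a standard ViT-S layout. That layout is assumed here: 12 blocks, width 384, MLP width 1536. Put ε on the V slice of qkv, the attention projection, and both MLP linears. Each of those linears contributes 384 trainable values, so 12 × 4 × 384 = 18,432. That is about 0.09% of 21M, consistent with the under-0.1% figure above.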

3D neighbor loss. For each anomalous token (logit below μ − τσ of layer-wise stats, boundary tokens excluded), compute its 3×3×3 Gaussian-weighted local average and minimize normalized L2 between the token and that smoothed target. The Gaussian is volumetric — 2D SINDER's square neighborhood doesn't generalize because biological structures have isotropic spatial coherence in 3D.
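Here is a numpy sketch of the neighbor loss for a (Z, Y, X, D) token grid. It makes three assumptions, none confirmed by the repository. The target excludes the center token. Gaussian σ is 1 in token units. "Normalized L2" means the squared distance between unit-normalized vectors.

```python
import numpy as np


def gaussian_kernel_3x3x3(sigma=1.0):
    ax = np.array([-1.0, 0.0, 1.0])
    r2 = ax[:, None, None]**2 + ax[None, :, None]**2 + ax[None, None, :]**2
    g = np.exp(-r2 / (2 * sigma**2))
    g[1, 1, 1] = 0.0                                   # assumption: target from neighbors only
    return g / g.sum()


def neighbor_loss(tokens, anomalous, sigma=1.0):
    """Mean normalized L2 between each anomalous token and its Gaussian-weighted
    3x3x3 neighborhood average. tokens: (Z, Y, X, D); anomalous: (Z, Y, X) bool.
    Boundary tokens are skipped because their neighborhood is incomplete."""
    k = gaussian_kernel_3x3x3(sigma)
    z, y, x, _ = tokens.shape
    losses = []
    for i, j, l in np.argwhere(anomalous):
        if not (0 < i < z - 1 and 0 < j < y - 1 and 0 < l < x - 1):
            continue
        patch = tokens[i - 1:i + 2, j - 1:j + 2, l - 1:l + 2]           # (3, 3, 3, D)
        target = np.tensordot(k, patch, axes=([0, 1, 2], [0, 1, 2]))    # (D,)
        t = tokens[i, j, l]
        diff = t / np.linalg.norm(t) - target / np.linalg.norm(target)
        losses.append(np.sum(diff**2))
    return float(np.mean(losses)) if losses else 0.0
```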

Layer-localized gradients. When defects appear at layer ℓ, zero gradients on layers below ℓ − L_limit + 1. Without this, corrections propagate back to early layers and degrade the spatial features the conv patch embed worked hard to learn. L_limit = 10 worked across our depth.
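The gradient window is easy to get off by one. This tiny sketch assumes 0-indexed blocks and ignores layers above ℓ, since a loss at ℓ sends them no gradient. For a defect at block 11 of a 12-block ViT-S with L_limit = 10, blocks 2 through 11 keep their gradients.

```python
def trainable_layers(defect_layer: int, l_limit: int = 10) -> range:
    """Blocks that keep gradients when the defect appears at `defect_layer`.

    Gradients are zeroed on every block below defect_layer - l_limit + 1,
    so at most l_limit blocks, ending at the defect block, are updated.
    """
    first = max(0, defect_layer - l_limit + 1)
    return range(first, defect_layer + 1)
```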

Inference at full-volume scale

A single LLSM volume after upsampling reaches a few million tokens at patch 8. Vanilla attention is quadratic — infeasible.

StreamingEncoder. Tile each volume into chunks of size streaming_patch_chunk_size, embed via Conv3d, and store tokens in a TokenStore (GPU/CPU/disk-backed). Stream Q/KV in blocks (q_block_tokens, kv_block_tokens) through the transformer with online-softmax attention (Triton kernels, log-sum-exp for numerical stability). Tokens write straight back into the unified feature grid via index views — no overlap-blend stitching, no edge artifacts.
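The core of the streaming path is online-softmax attention. Below is a minimal single-head numpy sketch of the recurrence, which is exact up to floating point. The `q_block` and `kv_block` arguments are hypothetical stand-ins for q_block_tokens and kv_block_tokens. The Triton kernels and the TokenStore are not modeled here.

```python
import numpy as np


def streaming_attention(q, k, v, q_block=128, kv_block=256):
    """Exact softmax attention computed block by block with an online softmax.

    Only one (q_block, kv_block) score tile exists at a time; the running max
    `m` and running denominator `l` are rescaled as each new key block arrives.
    """
    n, d = q.shape
    out = np.empty((n, v.shape[1]))
    scale = 1.0 / np.sqrt(d)
    for qs in range(0, n, q_block):
        qb = q[qs:qs + q_block] * scale
        m = np.full(qb.shape[0], -np.inf)            # running row max
        l = np.zeros(qb.shape[0])                    # running softmax denominator
        acc = np.zeros((qb.shape[0], v.shape[1]))    # running unnormalized output
        for ks in range(0, k.shape[0], kv_block):
            s = qb @ k[ks:ks + kv_block].T
            m_new = np.maximum(m, s.max(axis=1))
            correction = np.exp(m - m_new)           # rescale earlier partial sums
            p = np.exp(s - m_new[:, None])
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v[ks:ks + kv_block]
            m = m_new
        out[qs:qs + q_block] = acc / l[:, None]
    return out
```

For one million tokens, a full score matrix has 10¹² entries, about 2 TB in BF16 per head. The blocked form only ever holds one q_block × kv_block tile.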

Sliding-window fallback. For workloads that don't need the streaming path, a MONAI-style sliding window with 448³ ROIs, configurable overlap. Zero overlap gave fastest inference with negligible feature degradation; 10% overlap is the conservative default if downstream segmentation is sensitive.

Multi-node sharding. Plain DistributedSampler(shuffle=False) over the file list. Each rank owns whole volumes — within-volume sharding is unnecessary because StreamingEncoder already handles arbitrary sizes on a single GPU.

GPU transforms wrap CPU MONAI transforms. Preprocessing was the GPU-idle bottleneck. The chunked trilinear interpolation transform matters most: isotropic resampling of a 200×1024×1024 volume on GPU in one pass would need ~150 GB; chunked in windows, memory scales as O(k³) per chunk of side k and stays under a few GB.
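Trilinear interpolation is separable, so the anisotropic Z axis can be resampled on its own and the work chunked over the other axes. This numpy sketch shows only that chunking idea on Z, not the full 3D transform. The align-corners sampling convention and the chunk size are assumptions.

```python
import numpy as np


def resample_z_linear(volume, factor, chunk_rows=64):
    """Linear resampling along Z by `factor`, processed in Y-chunks.

    Peak temporary memory is a few (new_z, chunk_rows, X) slabs instead of
    intermediates for the whole volume.
    """
    z, y, x = volume.shape
    new_z = int(round(z * factor))
    # Output sample positions in input-index coordinates (align-corners convention).
    pos = np.linspace(0.0, z - 1.0, new_z)
    lo = np.floor(pos).astype(int)
    hi = np.minimum(lo + 1, z - 1)
    w = (pos - lo)[:, None, None].astype(np.float32)
    out = np.empty((new_z, y, x), dtype=np.float32)
    for ys in range(0, y, chunk_rows):
        slab = volume[:, ys:ys + chunk_rows].astype(np.float32)
        out[:, ys:ys + chunk_rows] = (1.0 - w) * slab[lo] + w * slab[hi]
    return out
```

For a 200×1024×1024 volume and a 2.404 Z factor, the output has 481 slices and itself needs about 2.0 GB in float32. Each 64-row slab is about 126 MB.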

Outputs per movie. Low-resolution features per timepoint, raw deskewed TIFFs, high-res features and PCA-reduced features, KDE probability maps, Voronoi-Otsu instance masks, and tracks.csv.

Downstream — segmentation and tracking, both unsupervised

Semantic segmentation. Concatenate patch features (D=384) with the per-head attention maps (D=6, one channel per ViT-S head) into a 390-d descriptor. Sum across the feature dim to get a scalar activation volume, bimodal by construction (foreground high, background near zero), so 3D Otsu is the right thresholder. Multi-scale 3D Laplacian-of-Gaussian (σ ∈ ) max-pooled across scales pulls out sub-resolution blobs.
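Below is a numpy sketch of the sum-then-Otsu step; the LoG stage is omitted. Otsu picks the histogram cut that maximizes between-class variance. scikit-image's `threshold_otsu` does the same off the shelf. The bin count and function names here are illustrative.

```python
import numpy as np


def otsu_threshold(values, n_bins=256):
    """Otsu's threshold: maximize between-class variance over histogram cuts."""
    hist, edges = np.histogram(values, bins=n_bins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    w0 = np.cumsum(hist)                             # background count for each cut
    w1 = w0[-1] - w0                                 # foreground count
    m0 = np.cumsum(hist * centers)
    mu0 = m0 / np.maximum(w0, 1)
    mu1 = (m0[-1] - m0) / np.maximum(w1, 1)
    between = w0 * w1 * (mu0 - mu1) ** 2
    return edges[1:][np.argmax(between)]             # upper edge of last background bin


def foreground_mask(features):
    """features: (Z, Y, X, 390) descriptor grid -> boolean foreground mask."""
    activation = features.sum(axis=-1)               # scalar activation volume
    return activation > otsu_threshold(activation.ravel())
```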

Instance segmentation. Voronoi-Otsu in a joint spatial-feature distance space:

d = α · d_spatial + (1−α) · d_feature

with α adapting to local density (≈0.8 sparse, ≈0.3 dense). Seeds come from LoG zero-crossings on feature activations, not raw intensity, so they survive photobleaching. Graph-cut refinement on the normalized-cut criterion splits touching instances. Critical heuristic: Voronoi-Otsu degrades when the particle-to-background intensity ratio falls below ~1.2, and a 120-background / 130-particle case (ratio ≈1.08) fails outright. We pre-denoise with a context-aware filter to lift the foreground before instance separation.
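Here is a numpy sketch of the joint distance and the Voronoi-style assignment it drives. The α endpoints (0.8 sparse, 0.3 dense) come from the text. Several pieces are assumptions:

- Linear interpolation of α between the two density levels.
- Cosine distance as d_feature.
- Euclidean distance divided by a hypothetical `spatial_scale` as d_spatial.

The Otsu, LoG-seed, and graph-cut stages are omitted.

```python
import numpy as np


def adaptive_alpha(local_density, sparse_density, dense_density,
                   alpha_sparse=0.8, alpha_dense=0.3):
    """Interpolate alpha linearly between its sparse and dense settings."""
    t = np.clip((local_density - sparse_density) / (dense_density - sparse_density), 0.0, 1.0)
    return alpha_sparse + t * (alpha_dense - alpha_sparse)


def assign_to_seeds(coords, feats, seed_coords, seed_feats, alpha, spatial_scale):
    """Label each voxel with the seed minimizing alpha*d_spatial + (1-alpha)*d_feature."""
    d_sp = np.linalg.norm(coords[:, None] - seed_coords[None], axis=-1) / spatial_scale
    fn = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sn = seed_feats / np.linalg.norm(seed_feats, axis=1, keepdims=True)
    d_ft = 1.0 - fn @ sn.T                                   # cosine distance
    alpha = np.asarray(alpha, dtype=float).reshape(-1, 1)    # scalar or per-voxel alpha
    return np.argmin(alpha * d_sp + (1.0 - alpha) * d_ft, axis=1)
```

In dense regions the lower α lets feature similarity override small spatial differences. That is what separates touching particles that look alike in space but not in feature space.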

Tracking. TrackPy + feature-similarity. Centroids are intensity-weighted moments (not geometric centers — intensity correlates with biological density). Project the 390-d feature to 10 PCA components on GPU (99.5% variance). Cost matrix combines spatial Euclidean, cosine feature similarity, and TrackPy's NearestVelocityPredict motion term:

C(i,j) = wₛ d_spatial + w_f d_feature + w_m d_motion

Hungarian assignment with virtual source/sink nodes for track init/term. Memory of 3–5 frames + Kalman gap-filling handles photobleaching dropouts and brief out-of-plane motion. Volume-jump detection flags candidate fission/fusion events for review. Centroid-only tracking is the TrackMate-style baseline this is meant to beat: Hungarian linking on point detections alone fails on sub-PSF particles, intensity-variable signal, and dense clustering, all simultaneously.
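Below is a sketch of one frame-to-frame linking step with SciPy's `linear_sum_assignment`. It uses the standard augmented-LAP construction: virtual death and birth nodes on the diagonals, plus a zero dummy block. Several parts are assumptions:

- The weights, the gating threshold `max_cost`, and the birth/death costs.
- A plain constant-velocity prediction in the spirit of NearestVelocityPredict. This is not TrackPy's API.
- The helper names.

Memory and Kalman gap-filling are not modeled.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

BIG = 1e9  # stands in for "forbidden" so the solver only sees finite costs


def weighted_centroid(intensity, mask):
    """Intensity-weighted centroid (z, y, x) of one instance mask."""
    coords = np.argwhere(mask)
    w = intensity[mask]
    return (coords * w[:, None]).sum(axis=0) / w.sum()


def link_cost(prev_pos, prev_vel, prev_feat, cur_pos, cur_feat, w_s=1.0, w_f=1.0, w_m=1.0):
    """C(i, j) = w_s * d_spatial + w_f * d_feature + w_m * d_motion."""
    d_spatial = np.linalg.norm(prev_pos[:, None] - cur_pos[None], axis=-1)
    a = prev_feat / np.linalg.norm(prev_feat, axis=1, keepdims=True)
    b = cur_feat / np.linalg.norm(cur_feat, axis=1, keepdims=True)
    d_feature = 1.0 - a @ b.T                               # cosine distance
    predicted = prev_pos + prev_vel                         # constant-velocity prediction
    d_motion = np.linalg.norm(predicted[:, None] - cur_pos[None], axis=-1)
    return w_s * d_spatial + w_f * d_feature + w_m * d_motion


def link(cost, max_cost, birth_cost, death_cost):
    """Return (links, ended, started) for one pair of frames."""
    n, m = cost.shape
    c = np.where(cost <= max_cost, cost, BIG)               # gate implausible links
    death = np.full((n, n), BIG)
    np.fill_diagonal(death, death_cost)                     # previous detection i ends
    birth = np.full((m, m), BIG)
    np.fill_diagonal(birth, birth_cost)                     # current detection j starts
    filler = np.zeros((m, n))                               # dummy-to-dummy block
    rows, cols = linear_sum_assignment(np.block([[c, death], [birth, filler]]))
    links = [(int(i), int(j)) for i, j in zip(rows, cols)
             if i < n and j < m and c[i, j] < BIG]
    linked_prev = {i for i, _ in links}
    linked_cur = {j for _, j in links}
    ended = [i for i in range(n) if i not in linked_prev]
    started = [j for j in range(m) if j not in linked_cur]
    return links, ended, started
```

A link is preferred over ending one track and starting another whenever its cost is below death_cost + birth_cost. Those two constants therefore set how eagerly tracks break.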

Headline innovations, in one place

  1. NoPE for 3D ViTs. Removing positional encodings cured a recurring 10k-iter collapse. The Conv3d patch embed carries the spatial bias.
  2. KMeans 3D crops. Content-aware cropping over percentile-thresholded volumes — 3–5× more signal per crop, smaller dataset, better features.
  3. Native 3D iBOT block masking with foreground-biased sampling.
  4. StreamingEncoder with token-store + online softmax — arbitrary-size volumes on a single GPU.
  5. 3D SINDER with 3×3×3 Gaussian neighbor loss and ε-only SVD repair touching under 0.1% of parameters.
  6. Lattice deskew baked into augmentation: a [2.404, 1.0, 1.0] Z-anisotropy correction inside the transform stack, not as preprocessing.
  7. WebDataset over RDMA-NFS: sustained 1.8 GB/s for distributed 3D training, where naive NFS stalled at 500 ms+ per-batch load latency.

What I'd do differently

  • Bigger model earlier. ViT-Small at 21M was the right call for this dataset size, but the diminishing-returns experiments suggested ViT-Base would have helped on the hardest endosome clusters.
  • Diffusion-based denoising as an auxiliary pretext. The MAE pixel-recon objective recovered noise rather than structure — we dropped it. A diffusion head trained to denoise would give the encoder a stronger prior given LLSM's brutal SNR floor.
  • Closer co-design with the tracking output earlier. We optimized features for segmentation first; tracking imposed its own constraints (PCA separability, frame-to-frame stability) that we discovered late.