Following our integration of Asymmetric Flow Matching, our 400M-parameter NanoDiT converged efficiently in terms of step count, but throughput on our single RTX 4090 was only 0.025 iterations per second (iter_per_sec). A full 5,000-step ablation cycle required 56 hours of active compute.

The Bottleneck

Before scaling the dataset from our 7k curated subset to the full 70k+ pipeline, we needed a faster iteration loop. Our goal was to halve the active compute time either via lower precision (FP8) or parameter reduction (shared spatial modulation).

Here is what worked—and what catastrophically failed.


1. The FP8 Memory Illusion & Breakthrough

Native FP8 via torchao theoretically promises double the tensor-core throughput and half the memory traffic of BF16. However, our initial naïve implementation could not complete a single forward/backward pass within our 24GB VRAM budget: it immediately hit out-of-memory (OOM) errors at batch sizes where BF16 fit comfortably.

The Trap: Dynamic Scale State Overhead

Why did an 8-bit format consume more memory than a 16-bit one? The answer lies in how torchao handles dynamic tensor casting. For dynamic casting, the framework allocates scale state for every converted linear layer and keeps tracking per-tensor amax (absolute-maximum) statistics throughout the forward pass. This metadata overhead, combined with the extra tensors produced by the wrapper casting operations, drastically inflated the memory footprint and erased the expected bandwidth savings.
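To make "scale state" concrete, the pure-Python sketch below contrasts two per-tensor FP8 scaling recipes: dynamic scaling derives the scale from the current tensor's amax, while delayed scaling keeps a per-layer rolling window of past amax values. It illustrates the general technique only and is not torchao's code; the names (`dynamic_scale`, `DelayedScaleState`, `scale_and_saturate`) and the 16-entry window are hypothetical. The only format constant it relies on is 448, the largest finite value of float8_e4m3fn.

```python
from collections import deque

# Largest finite magnitude representable in the float8_e4m3fn format.
E4M3_MAX = 448.0


def amax(values):
    """Absolute maximum of a tensor (modeled here as a flat list of floats)."""
    return max(abs(v) for v in values)


def dynamic_scale(values, eps: float = 1e-12) -> float:
    """Dynamic scaling: derive the scale from the *current* tensor's amax."""
    return E4M3_MAX / max(amax(values), eps)


class DelayedScaleState:
    """Hypothetical per-layer metadata for delayed scaling: a rolling amax window."""

    def __init__(self, history_len: int = 16):
        self.amax_history = deque(maxlen=history_len)

    def update(self, values) -> None:
        # Record this step's amax; the oldest entry drops off once the window is full.
        self.amax_history.append(amax(values))

    def scale(self, eps: float = 1e-12) -> float:
        if not self.amax_history:
            return 1.0  # no statistics yet; fall back to a unit scale (assumption)
        # Scale from the largest amax in the window, not just the current step.
        return E4M3_MAX / max(max(self.amax_history), eps)


def scale_and_saturate(values, scale: float):
    """Multiply into the FP8 range and clamp; real kernels also round to 8-bit values."""
    return [min(max(v * scale, -E4M3_MAX), E4M3_MAX) for v in values]


acts = [0.5, -2.0, 1.25]
print(scale_and_saturate(acts, dynamic_scale(acts)))  # -2.0 maps exactly to -448.0
```

Either way, every converted layer carries scaling metadata alongside its weights, which is the bookkeeping the paragraph above refers to.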

The Fix: Dynamic Tensor Masking & Scoped Autocast

To fit the model back into 24GB of VRAM while preserving FP8 throughput, we had to tame the scale-state overhead. We combined strict dynamic tensor masking with tighter autocast scoping so that FP8 conversion was confined to the heaviest matrix multiplications (the core attention and FFN projections), while the rest of the network bypassed the metadata overhead entirely.
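A minimal, hypothetical sketch of this kind of scoping is a predicate that selects the attention and FFN projections by their fully qualified module name and leaves everything else, including the adaLN conditioning path, out of FP8. The module names below (`attn.qkv`, `mlp.fc1`, `adaLN_modulation`, and so on) are assumptions about a typical DiT layout, not NanoDiT's actual names.

```python
# Hypothetical module-name suffixes for the heavy projections in each DiT block.
FP8_TARGETS = ("attn.qkv", "attn.proj", "mlp.fc1", "mlp.fc2")

# Hypothetical modules that should never be converted (conditioning and output paths).
BF16_KEEP = ("adaLN_modulation", "t_embedder", "final_layer")


def use_fp8(fqn: str, in_features: int, out_features: int) -> bool:
    """Return True if the linear layer at `fqn` should run its matmul in FP8."""
    if any(keep in fqn for keep in BF16_KEEP):
        return False
    if not any(fqn.endswith(target) for target in FP8_TARGETS):
        return False
    # FP8 matmul kernels expect both weight dimensions to be divisible by 16.
    return in_features % 16 == 0 and out_features % 16 == 0


for name, fin, fout in [
    ("blocks.0.attn.qkv", 1024, 3072),
    ("blocks.0.adaLN_modulation.1", 1024, 6144),
    ("blocks.0.mlp.fc2", 4096, 1024),
]:
    print(name, "->", "FP8" if use_fp8(name, fin, fout) else "BF16")
```

With torchao, a predicate like this could be wrapped as the `module_filter_fn(module, fqn)` callback accepted by `torchao.float8.convert_to_float8_training`, returning `use_fp8(fqn, module.in_features, module.out_features)` for `nn.Linear` modules. An allow-list keeps the FP8 surface small and explicit, at the cost of having to update it whenever module names change.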

The Results: >2x Speedup

By resolving the overhead, active compute time plummeted from 56 hours to 26 hours (iter_per_sec increased from 0.025 to 0.053).
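As a quick sanity check, the reported hours follow directly from the throughput figures: a 5,000-step cycle at 0.025 it/s takes about 55.6 hours, and at 0.053 it/s about 26.2 hours, a roughly 2.1x throughput gain. The helper name below is illustrative.

```python
STEPS = 5_000  # steps per ablation cycle


def wall_clock_hours(iter_per_sec: float, steps: int = STEPS) -> float:
    """Active compute time for `steps` iterations at a fixed throughput."""
    return steps / iter_per_sec / 3600


bf16_hours = wall_clock_hours(0.025)  # ~55.6 h
fp8_hours = wall_clock_hours(0.053)   # ~26.2 h
speedup = 0.053 / 0.025               # ~2.12x
print(f"BF16: {bf16_hours:.1f} h, FP8: {fp8_hours:.1f} h, speedup: {speedup:.2f}x")
```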

Crucially, the perceptual quality held up beautifully against BF16. Below are the final LPIPS metrics at step 5000 comparing Arm G (BF16 AsymFlow) against Arm I (FP8 Native):

Metric                  Arm G (BF16)   Arm I (FP8)   Delta
Reconstruction LPIPS    0.900          0.906         +0.006 (Negligible)
Text-only LPIPS         0.920          0.909         -0.011 (Better)
Text Manip Delta        0.485          0.504         +0.019 (Better)

Note: Lower LPIPS is better for perceptual similarity. Higher Text Manip Delta indicates stronger text-controllability.
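Because the table mixes lower-is-better and higher-is-better metrics, the sign of the raw Delta column means different things per row. A small hypothetical helper can normalize each change so that a positive number always favors Arm I:

```python
# (metric, Arm G BF16, Arm I FP8, higher_is_better), values from the table above.
METRICS = [
    ("Reconstruction LPIPS", 0.900, 0.906, False),
    ("Text-only LPIPS", 0.920, 0.909, False),
    ("Text Manip Delta", 0.485, 0.504, True),
]


def improvement(arm_g: float, arm_i: float, higher_is_better: bool) -> float:
    """Signed change from Arm G to Arm I, oriented so that positive means Arm I is better."""
    return (arm_i - arm_g) if higher_is_better else (arm_g - arm_i)


for name, g, i, higher in METRICS:
    print(f"{name}: {improvement(g, i, higher):+.3f}")
```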

Visual Comparisons

Figure: Reconstruction fidelity, conditioned on identity + text (Arm G left, Arm I right).

Figure: Text-only controllability, conditioned purely on the text prompt (Arm G left, Arm I right).

The FP8 model achieves equivalent perceptual quality in less than half the time.


2. The Shared adaLN + LoRA Collapse

While working on compute optimization, we also explored parameter reduction.

The Hypothesis: Our DiT uses an adaLN (Adaptive Layer Normalization) projection in every transformer block to inject the timestep and spatial conditioning signals. What if we shared a single adaLN projection across all blocks and relied on a lightweight per-block rank-8 LoRA to handle block-specific spatial localization? This would save 59M parameters, shrinking the model from 237.7M to 178.8M non-embedding parameters.
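To see where savings of this kind come from, here is a back-of-the-envelope parameter count. It assumes DiT-style adaLN-Zero (six modulation vectors per block, conditioning width equal to the hidden width) and a LoRA applied to the shared projection, so block i computes W·c + B_i·A_i·c. The function names and the sizes in the usage lines are hypothetical, not NanoDiT's configuration.

```python
def per_block_adaln_params(hidden: int, depth: int, n_mod: int = 6) -> int:
    """Independent adaLN in every block: Linear(hidden -> n_mod * hidden) with bias."""
    return depth * (hidden * n_mod * hidden + n_mod * hidden)


def shared_adaln_lora_params(hidden: int, depth: int, rank: int = 8, n_mod: int = 6) -> int:
    """One shared adaLN projection plus a per-block LoRA: A (hidden -> rank), B (rank -> n_mod * hidden)."""
    shared = hidden * n_mod * hidden + n_mod * hidden
    lora_per_block = rank * hidden + rank * n_mod * hidden
    return shared + depth * lora_per_block


hidden, depth = 1024, 12  # hypothetical sizes for illustration only
saved = per_block_adaln_params(hidden, depth) - shared_adaln_lora_params(hidden, depth)
print(f"parameters saved: {saved / 1e6:.1f}M")
```

The trade-off is visible in the formulas: the shared path pays the quadratic hidden-squared cost once instead of once per block, and each block's private modulation capacity shrinks to a rank-8 correction.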

The Reality: The model (Arm H) completely failed to converge. Reconstruction LPIPS stalled at 0.9823, and the output degraded into structural noise and visual dithering by step 5000.

The Lesson: Full-rank, independent per-block modulation is structurally load-bearing in a Diffusion Transformer. In our experiments, compressing the spatial/semantic conditioning pathway into a shared projection plus rank-8 LoRA destroyed the model’s ability to localize features across different depths of the network. The 59M-parameter savings were simply not worth the catastrophic quality collapse.


Conclusion

The architecture is now stabilized, perceptually verified, and fast. The ablation phase is formally closed. With our pipeline executing 5k steps in just 26 hours, we are ready for the “big run”—scaling the dataset, extending the training horizon, and deploying bucket-aware batching.