This package demonstrates integrating TransformerEngine's NVFP4 quantization into TorchTitan using the registry override system.
The --nvfp4 flag triggers the override system, which:
- Imports torchtitan_te.overrides (registers the te_nvfp4_linear override)
- Calls apply_overrides(model_config, ...) before model construction
- Replaces all Linear.Config with TENVFP4Linear.Config
- When the model is built, TENVFP4Linear uses te.Linear + te.fp8_autocast
This approach is torch.compile compatible.
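
The flow above (register an override, walk the config before construction, swap matching entries) can be illustrated with a small standalone sketch. Everything in it is a simplified stand-in written for illustration: LinearConfig, NVFP4LinearConfig, ModelConfig, and this toy register/apply_overrides are hypothetical and are not TorchTitan's actual registry implementation, which may differ in signature and behavior.

```python
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Callable, Optional

# Hypothetical stand-ins for Linear.Config and TENVFP4Linear.Config.
@dataclass
class LinearConfig:
    in_features: int
    out_features: int
    bias: bool = False

@dataclass
class NVFP4LinearConfig:
    in_features: int
    out_features: int
    bias: bool = False

# Hypothetical model config holding several linear-layer configs.
@dataclass
class ModelConfig:
    wq: LinearConfig
    output: LinearConfig
    layers: list = field(default_factory=list)

# Toy registry: override name -> (target config type, override function).
_REGISTRY: dict = {}

def register(name: str, target: type) -> Callable:
    """Record an override function for configs of exactly type `target`."""
    def decorator(fn: Callable[[Any], Optional[Any]]) -> Callable:
        _REGISTRY[name] = (target, fn)
        return fn
    return decorator

def apply_overrides(config: Any) -> Any:
    """Walk a config tree and swap in replacements where an override applies."""
    for target, fn in _REGISTRY.values():
        if type(config) is target:
            replacement = fn(config)
            if replacement is not None:  # None means "leave this one alone"
                return replacement
    if is_dataclass(config) and not isinstance(config, type):
        for f in fields(config):
            setattr(config, f.name, apply_overrides(getattr(config, f.name)))
    elif isinstance(config, list):
        config[:] = [apply_overrides(item) for item in config]
    return config

@register("nvfp4_linear", target=LinearConfig)
def nvfp4_linear_override(cfg: LinearConfig) -> Optional[NVFP4LinearConfig]:
    # Mirror the README's rule: skip layers whose dims are not divisible by 16.
    if cfg.in_features % 16 != 0 or cfg.out_features % 16 != 0:
        return None
    return NVFP4LinearConfig(cfg.in_features, cfg.out_features, cfg.bias)

model_config = ModelConfig(
    wq=LinearConfig(256, 256),
    output=LinearConfig(256, 1000),   # 1000 is not divisible by 16, so kept
    layers=[LinearConfig(256, 1024)],
)
apply_overrides(model_config)
print(type(model_config.wq).__name__)      # NVFP4LinearConfig
print(type(model_config.output).__name__)  # LinearConfig
```

Because the swap happens on configs rather than on constructed modules, the model class itself does not need to know about the replacement; it simply builds whatever each config describes.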
pip install -e .

Requires:
- TorchTitan (PR #3396 branch)
- TransformerEngine (with NVFP4 support)
- Blackwell GPU (SM100+)
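
A quick pre-flight check for the GPU requirement might look like the sketch below. It assumes "SM100+" means a CUDA compute capability of 10.0 or higher as reported by torch.cuda.get_device_capability(); the helper names are hypothetical, and TransformerEngine may apply its own, stricter checks.

```python
from typing import Tuple

def supports_nvfp4(capability: Tuple[int, int]) -> bool:
    """Return True for compute capability 10.0 (SM100) or newer.

    Assumption: this threshold is taken from the "SM100+" requirement above.
    """
    return tuple(capability) >= (10, 0)

def current_gpu_supports_nvfp4() -> bool:
    """Check the active CUDA device; returns False without torch or a GPU."""
    try:
        import torch
    except ImportError:
        return False
    if not torch.cuda.is_available():
        return False
    return supports_nvfp4(torch.cuda.get_device_capability())

print(supports_nvfp4((9, 0)))   # Hopper (SM90): False
print(supports_nvfp4((10, 0)))  # Blackwell (SM100): True
```

Running current_gpu_supports_nvfp4() before passing --nvfp4 gives a clearer failure than an error raised deep inside model construction.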
# bf16 baseline
python -m torchtitan_te.train --data fineweb
# NVFP4 quantization
python -m torchtitan_te.train --data fineweb --nvfp4
# NVFP4 + torch.compile
python -m torchtitan_te.train --data fineweb --nvfp4 --compile
# Overfit test (verify training works)
python -m torchtitan_te.train --overfit --nvfp4 --steps 10

# In overrides.py
@register(
    "te_nvfp4_linear",
    target=Linear.Config,
    description="Replace Linear with TE NVFP4 Linear",
)
def te_nvfp4_linear_override(cfg: Linear.Config) -> TENVFP4Linear.Config | None:
    if cfg.in_features % 16 != 0 or cfg.out_features % 16 != 0:
        return None  # Skip - NVFP4 requires dims divisible by 16
    return TENVFP4Linear.Config(
        in_features=cfg.in_features,
        out_features=cfg.out_features,
        bias=cfg.bias,
        param_init=cfg.param_init,
    )

# In train.py
if args.nvfp4:
    from torchtitan.registry import apply_overrides, OverrideConfig
    from . import overrides  # registers te_nvfp4_linear

    override_config = OverrideConfig(modules=["torchtitan_te.overrides"])
    apply_overrides(model_config, override_config)

model = Llama3Model(model_config)  # Now uses TENVFP4Linear instead of Linear

- nvfp4.py - TENVFP4Linear module (wraps te.Linear + fp8_autocast)
- overrides.py - @register decorator for the registry override system
- train.py - Training script demonstrating the integration
- common.py - Shared utilities (model configs, data loading, etc.)