tbqh/torchtitan_te_plugin

TorchTitan + TransformerEngine NVFP4 Plugin

This package demonstrates how to integrate TransformerEngine's NVFP4 quantization into TorchTitan through the registry override system.

What it does

Passing the --nvfp4 flag activates the override system, which:

  1. Imports torchtitan_te.overrides (this registers the te_nvfp4_linear override)
  2. Calls apply_overrides(model_config, ...) before model construction
  3. Replaces each Linear.Config whose in_features and out_features are both divisible by 16 with TENVFP4Linear.Config (other layers are left unchanged)
  4. When the model is built, TENVFP4Linear wraps te.Linear and runs it under te.fp8_autocast

This approach is torch.compile compatible.
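
The steps above can be illustrated with a toy registry in plain Python. This is only a sketch of the general pattern, not TorchTitan's implementation: `LinearConfig`, `QuantLinearConfig`, `BlockConfig`, and this simplified `register`/`apply_overrides` pair are hypothetical stand-ins, and the real registry may walk the config tree differently.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Callable, Optional

# Hypothetical stand-ins for illustration; not TorchTitan's real API.
_OVERRIDES: dict[type, Callable] = {}


def register(target: type):
    """Record an override function for one config type."""
    def wrap(fn):
        _OVERRIDES[target] = fn
        return fn
    return wrap


@dataclass
class LinearConfig:
    in_features: int
    out_features: int
    bias: bool = False


@dataclass
class QuantLinearConfig(LinearConfig):
    """Config for the quantized replacement layer."""


@dataclass
class BlockConfig:
    wq: LinearConfig
    wo: LinearConfig


@register(target=LinearConfig)
def quant_override(cfg: LinearConfig) -> Optional[QuantLinearConfig]:
    # Returning None means "leave this layer as it is".
    if cfg.in_features % 16 != 0 or cfg.out_features % 16 != 0:
        return None
    return QuantLinearConfig(cfg.in_features, cfg.out_features, cfg.bias)


def apply_overrides(cfg):
    """Walk a dataclass config tree and swap registered config types in place."""
    for f in fields(cfg):
        child = getattr(cfg, f.name)
        fn = _OVERRIDES.get(type(child))
        if fn is not None:
            replacement = fn(child)
            if replacement is not None:
                setattr(cfg, f.name, replacement)
        elif is_dataclass(child):
            apply_overrides(child)
    return cfg


block = apply_overrides(BlockConfig(wq=LinearConfig(64, 64), wo=LinearConfig(64, 10)))
print(type(block.wq).__name__, type(block.wo).__name__)
```

Because the swap happens on the config before any module is constructed, the model class itself needs no changes: it simply builds whatever each config describes.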

Installation

pip install -e .

Requires:

  • TorchTitan (PR #3396 branch)
  • TransformerEngine (with NVFP4 support)
  • Blackwell GPU (SM100+)

Usage

# bf16 baseline
python -m torchtitan_te.train --data fineweb

# NVFP4 quantization
python -m torchtitan_te.train --data fineweb --nvfp4

# NVFP4 + torch.compile
python -m torchtitan_te.train --data fineweb --nvfp4 --compile

# Overfit test (verify training works)
python -m torchtitan_te.train --overfit --nvfp4 --steps 10
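
For readers who want to see how these flags could fit together, here is a minimal argparse sketch. Only the flag names come from the commands above; the defaults, help strings, and the `build_parser` helper are assumptions and may not match the actual train.py.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Flag names mirror the usage examples; defaults and help text are assumed.
    parser = argparse.ArgumentParser(prog="torchtitan_te.train")
    parser.add_argument("--data", type=str, default=None, help="dataset name, e.g. fineweb")
    parser.add_argument("--nvfp4", action="store_true", help="apply the NVFP4 override")
    parser.add_argument("--compile", action="store_true", help="wrap the model in torch.compile")
    parser.add_argument("--overfit", action="store_true", help="run the overfit sanity check")
    parser.add_argument("--steps", type=int, default=None, help="number of training steps")
    return parser


# Equivalent to: python -m torchtitan_te.train --overfit --nvfp4 --steps 10
args = build_parser().parse_args(["--overfit", "--nvfp4", "--steps", "10"])
print(args)
```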

How the override works

# In overrides.py
@register(
    "te_nvfp4_linear",
    target=Linear.Config,
    description="Replace Linear with TE NVFP4 Linear",
)
def te_nvfp4_linear_override(cfg: Linear.Config) -> TENVFP4Linear.Config | None:
    if cfg.in_features % 16 != 0 or cfg.out_features % 16 != 0:
        return None  # Skip - NVFP4 requires dims divisible by 16
    return TENVFP4Linear.Config(
        in_features=cfg.in_features,
        out_features=cfg.out_features,
        bias=cfg.bias,
        param_init=cfg.param_init,
    )

# In train.py
if args.nvfp4:
    from torchtitan.registry import apply_overrides, OverrideConfig
    from . import overrides  # registers te_nvfp4_linear
    override_config = OverrideConfig(modules=["torchtitan_te.overrides"])
    apply_overrides(model_config, override_config)

model = Llama3Model(model_config)  # Now uses TENVFP4Linear instead of Linear
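
Since the override returns None for any layer whose dimensions are not multiples of 16, a model can end up with a mix of NVFP4 and regular Linear layers. The sketch below applies the same check to a few layer shapes to show which ones would be skipped. The helper name `nvfp4_eligible` and the layer names and shapes are hypothetical, chosen only for illustration.

```python
def nvfp4_eligible(in_features: int, out_features: int, multiple: int = 16) -> bool:
    """Mirror the override's guard: both dimensions must be divisible by `multiple`."""
    return in_features % multiple == 0 and out_features % multiple == 0


# Hypothetical (in_features, out_features) pairs, for illustration only.
layers = {
    "attention.wq": (4096, 4096),
    "feed_forward.w1": (4096, 14336),
    "classifier_head": (4096, 10),
}

skipped = [name for name, (i, o) in layers.items() if not nvfp4_eligible(i, o)]
print("Layers left as regular Linear:", skipped)
```

Running a check like this against a model config before training is a cheap way to confirm that the layers you expect to be quantized actually are.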

Files

  • nvfp4.py - TENVFP4Linear module (wraps te.Linear + fp8_autocast)
  • overrides.py - @register decorator for the registry override system
  • train.py - Training script demonstrating the integration
  • common.py - Shared utilities (model configs, data loading, etc.)
