diff --git a/optimized/tensorRT/.gitignore b/optimized/tensorRT/.gitignore index 599b1ff..923a777 100644 --- a/optimized/tensorRT/.gitignore +++ b/optimized/tensorRT/.gitignore @@ -12,6 +12,11 @@ onnx/ *.onnx *.onnx.data +# FP8 calibration data — captured by build/make_calib.py from the model +# checkpoint, a reproducible producer artifact (~hundreds of MB). Regenerated +# on demand, never committed. +*.calib.npz + # T5 tokenizer — generally ignored (downloaded under models//t5gemma/ # alongside the engine for legacy fallback), but the canonical copy ships # bundled at scripts/tokenizer.json (arch-agnostic, 34 MB) so the build path @@ -36,3 +41,7 @@ __pycache__/ # But keep __pycache__/ ignored even under build/ (the un-ignore above would # accidentally re-include build/__pycache__/ otherwise). build/**/__pycache__/ +# Same for make_calib.py's captured calibration npz (the *.calib.npz rule +# above is otherwise shadowed by the build/ un-ignore — make_calib writes it +# next to the build scripts by default). +build/**/*.calib.npz diff --git a/optimized/tensorRT/build/README.md b/optimized/tensorRT/build/README.md index 665ca18..4488718 100644 --- a/optimized/tensorRT/build/README.md +++ b/optimized/tensorRT/build/README.md @@ -69,6 +69,10 @@ python build_from_onnx.py same-s-decoder-fp32 # canonical ONNX is already FP32 python build_from_onnx.py sa3-m-fp32 # reads HF dit.onnx (already FP32) python build_from_onnx.py all-fp32 # every FP32 target python build_from_onnx.py all-both # canonical + FP32 + +# FP8 variant — opt-in, DiT-only. ~1.8x faster steps than FP16-mixed. +# Pair with `sa3_trt --precision fp8` at inference (fp8 DiT + fp16mixed decoder). +python build_from_onnx.py sa3-m-fp8 # reads HF dit_fp8.onnx (ModelOpt PTQ) ``` ### Consumer deps @@ -148,6 +152,33 @@ This wraps every RMSNorm chain, attention `Softmax`, and the RoPE region in `Cas Naive `BuilderFlag.FP16` (without the surgery) catastrophically overflows in RMSNorm variance + attention softmax — the islands are mandatory. BF16 was tried earlier and compounds quantisation error over 8 sampling steps (cos-sim drifts from 0.99 single-step to 0.81 final-latent vs PT FP32) — audibly degraded. +### FP8 DiT (opt-in, ~1.8x) + +`build_dit_fp8.py` extends the FP16-mixed recipe with a ModelOpt FP8 GEMM trunk: it takes the `dit_fp16mixed.onnx` plus a calibration `.npz` (real DiT inputs across the pingpong schedule) and produces `dit_fp8.onnx`: fp8 weight/activation Q/DQ on the MatMuls, the same FP32 islands re-applied (plus the conditioning front-end, which must stay FP32 or the t>=0.984 timestep features flush), and per-channel weight scales. Validated on sa3-m vs the FP16-mixed engine over the 47 reprompt Music prompts x 8 sigmas (L=646): worst single-step latent cosine 0.9982, 8-step compounded euler final-latent cosine mean 0.953 / median 0.957 / worst 0.873 over the 47 prompts (the compounded rollout is chaotic at the early sigmas, so judge by the distribution and by ear; decoded audio under the production pingpong sampler tracks the FP16-mixed generation at ~0.90 RMS-curve correlation), ~10.6-11.0 ms/step (vs ~18.7-19.4), ~1.8x. TRT tactic selection is nondeterministic per build; if a fresh engine benches noticeably slower, rebuild it. Under the stochastic pingpong sampler it yields a different but comparable sample. + +First capture the calibration data from the model checkpoint with `make_calib.py` (drives the model's own pingpong `generate()` to record real DiT inputs across the sampling schedule; prompts come from the repo's own `interface/reprompt.py` Music examples, the deployment-matched reprompt format): + +```bash +python make_calib.py \ + --model-config /SA3-M-hf/model_config.json \ + --checkpoint /SA3-M-hf/model.safetensors \ + --out sa3-m.calib.npz +``` + +Then build the engine: + +```bash +python build_dit_fp8.py \ + --input $HF/onnx/sa3-m/dit_fp16mixed.onnx \ + --calib sa3-m.calib.npz \ + --onnx $HF/onnx/sa3-m/dit_fp8.onnx \ + --engine ../models/$ARCH/sa3-m/dit_fp8.trt +``` + +`make_calib.py` needs only the repo + checkpoint (`torch`, `numpy`, `stable_audio_3`). `build_dit_fp8.py` additionally requires `nvidia-modelopt` + `onnxruntime-gpu` (the calibration-repair pass); consumers compile the published `dit_fp8.onnx` with plain `build_from_onnx.py sa3-m-fp8` (STRONGLY_TYPED, no ModelOpt, no calibration). + +> **Not yet on HF.** `dit_fp8.onnx` + `dit_fp8.onnx.data` are not in the model repo yet, so `build_from_onnx.py sa3-m-fp8` and `sa3_trt --precision fp8` 404 until a producer run uploads them (under exactly those filenames). The consumer recipe and `--precision fp8` plumbing land here so the wiring is reviewed; the artifact upload is the follow-up step. + Each script also writes the ONNX to `/onnx//.onnx`. After all 8 are done: ```bash @@ -166,6 +197,8 @@ git push | `build_from_onnx.py` | One target → download ONNX from HF + compile to TRT. **For the SA3 DiTs, pulls `dit_fp16mixed.onnx` (the pre-processed island-wrapped graph)** so the consumer just needs to invoke `STRONGLY_TYPED` compilation — no `onnx-graphsurgeon` required | consumer | | `build_dit_profile.py` | Build a DiT with custom `(min, opt, max)` profile shapes (experimental — short-form / fixed-shape variants). Operates on either ONNX flavor. | consumer | | `build_dit_fp16mixed.py` | **Producer-side** ONNX surgery: takes the canonical FP32 `dit.onnx`, finds RMSNorm chains + attention `Softmax` + RoPE region, wraps each in `Cast(FP32) ↔ Cast(FP16)` islands, converts non-island weights to FP16, and writes both the modified `dit_fp16mixed.onnx` AND the TRT engine. Only re-run when the model retrains or the island recipe changes. Requires `onnx` + `onnx-graphsurgeon`. | producer | +| `make_calib.py` | **Producer-side** FP8 calibration capture: drives the model's own pingpong `generate()` and records the six DiT engine inputs across the schedule into a `*.calib.npz` for `build_dit_fp8.py`. Needs only the checkpoint (`torch` + `stable_audio_3`). | producer | +| `build_dit_fp8.py` | **Producer-side** FP8 trunk on top of `dit_fp16mixed.onnx`: ModelOpt FP8 PTQ (MatMul/Gemm, max calibration from a `.npz`), restores ModelOpt-corrupted initializers + recalibrates activation scales, re-applies the FP32 islands (incl. the conditioning front-end), and per-channel weight scales. Writes `dit_fp8.onnx` + the TRT engine. ~1.8x faster steps than FP16-mixed. Requires `nvidia-modelopt` + `onnxruntime-gpu`. | producer | | `build_t5gemma.py` | Trace + export T5Gemma encoder ONNX + build TRT | producer | | `build_same_s_decoder.py` | Trace + export SAME-S decoder ONNX + build TRT | producer | | `build_same_s_encoder.py` | Trace + export SAME-S encoder ONNX + build TRT | producer | diff --git a/optimized/tensorRT/build/build.py b/optimized/tensorRT/build/build.py index d1aa962..92a61f5 100755 --- a/optimized/tensorRT/build/build.py +++ b/optimized/tensorRT/build/build.py @@ -93,6 +93,13 @@ def _from_onnx(name): {"label": "[opt-in] DiT sm-sfx FP32", "command": _from_onnx("sa3-sm-sfx-fp32"), "outputs": ["sa3-sm-sfx/dit_fp32.trt"]}, + # FP8 variant, opt-in, DiT-only. ~1.8x faster steps than FP16-mixed + # (ModelOpt PTQ; producer build_dit_fp8.py). A different but comparable + # sample under the stochastic pingpong sampler. + {"label": "[opt-in] DiT medium FP8 (~1.8x)", + "command": _from_onnx("sa3-m-fp8"), + "outputs": ["sa3-m/dit_fp8.trt"], + "opt_in": True}, # built only by number, never via "Build all missing" ] @@ -134,12 +141,19 @@ def render_menu(arch: str, arch_dir: Path) -> list[bool]: size_s = fmt_size(sz) if sz >= 0 else f"{DIM}(missing){RESET}" print(f" {tick} {DIM}{rel}{RESET} {size_s}") - n_missing = built_flags.count(False) + # "Build all missing" covers default targets only; opt-in variants (FP8) + # are built by number, so they are counted and reported separately. + n_missing = sum(1 for t, ok in zip(TARGETS, built_flags) + if not ok and not t.get("opt_in")) + n_optin_missing = sum(1 for t, ok in zip(TARGETS, built_flags) + if not ok and t.get("opt_in")) print() if n_missing == 0: - print(f" {BOLD}{GREEN}[A]{RESET} Build all missing {DIM}(nothing missing — all engines built){RESET}") + print(f" {BOLD}{GREEN}[A]{RESET} Build all missing {DIM}(nothing missing — all default engines built){RESET}") else: print(f" {BOLD}{YELLOW}[A]{RESET} Build all missing {DIM}({n_missing} target(s)){RESET}") + if n_optin_missing: + print(f" {DIM}(+{n_optin_missing} opt-in target(s) not in [A] — build by number){RESET}") print(f" {BOLD}{DIM}[Q]{RESET} Quit") return built_flags @@ -180,7 +194,9 @@ def main() -> int: return 0 if choice in ("a", "all"): - missing = [t for t, ok in zip(TARGETS, built_flags) if not ok] + # Opt-in variants (FP8) are excluded; build them by number. + missing = [t for t, ok in zip(TARGETS, built_flags) + if not ok and not t.get("opt_in")] if not missing: print(f" {DIM}Nothing to build.{RESET}") continue diff --git a/optimized/tensorRT/build/build_dit_fp8.py b/optimized/tensorRT/build/build_dit_fp8.py new file mode 100644 index 0000000..3707302 --- /dev/null +++ b/optimized/tensorRT/build/build_dit_fp8.py @@ -0,0 +1,861 @@ +#!/usr/bin/env python3 +"""Build a SA3 DiT TRT engine with an FP8 GEMM trunk: ModelOpt PTQ on top of +the FP16-mixed recipe (FP32 islands around RMSNorm/Softmax/RoPE) for ~1.8x +faster steps than ``dit_fp16mixed`` on Ada/Blackwell. + +This is a PRODUCER recipe (like ``build_dit_fp16mixed.py``): it transforms the +canonical ``dit_fp16mixed.onnx`` into a quantized ``dit_fp8.onnx`` and compiles +it. Consumers who just want the engine compile the published ``dit_fp8.onnx`` +with ``build_from_onnx.py`` (plain STRONGLY_TYPED, no graphsurgeon / ModelOpt). + +Background — why this is more than ``mtq.quantize(...)``: +- The DiT's per-step velocity error compounds over the 8 pingpong steps. BF16 + fails (final-latent cos ~0.81); naive FP8 PTQ lands ~0.91. The decisive + fixes, in order, take it to 0.9982 worst single-step latent cosine (n=376) + and a compounded euler final-latent cosine of mean 0.953 / worst 0.873 over + the 47 reprompt prompts vs the FP16-mixed engine; under the stochastic + pingpong sampler the result is a different but comparable sample, judged by + ear: + 1. ModelOpt rejects the canonical ONNX because upstream's island surgery + leaves it un-toposorted, and its opset bump leaves ReduceMean's pre-18 + attribute-form axes. We Kahn-sort + version_convert to opset 19 first. + 2. ModelOpt corrupts single elements of a few initializers during + preprocessing (e.g. ``dit.to_timestep_embed.2.bias`` absmax 0.12 -> + 6060). One exploded element inflates all adaLN conditioning -> fp16 + overflow -> a NaN engine and invalid calibration scales. We restore + every corrupted initializer from the source graph and recalibrate all + activation scales on a Q/DQ-bypassed copy (ORT, real conditioning). + 3. ModelOpt flattens the FP32 islands the fp16mixed recipe established + (RMSNorm variance overflows fp16 -> NaN). We re-apply them with the + FP16-mixed recipe's ``find_fp32_islands``, plus the conditioning + front-end that computes in fp32 upstream (timestep expo features / cond + embeds): in fp16 those flush and produce the entire t>=0.984 base + divergence. Initializers we upcast take their true fp32 values from the + source graph, not ModelOpt's fp16-rounded copies. + 4. Per-channel weight scales (GEMM N axis): weights are stored fp16 and + quantized at build time, so one outlier row otherwise crushes the + whole tensor's resolution. Per-channel constant-folds at build time and + costs nothing at runtime. Activations stay per-tensor (TRT requires it + for fp8 activation quant), calibrated with ``max`` (SA3 activation + outliers are signal; percentile clipping noticeably regresses parity). + +Calibration data is an INPUT (``--calib sa3-m.calib.npz``): a capture of real +(x_t, t, t5_hidden, t5_mask, seconds_total, local_add_cond) DiT inputs across +the pingpong sampling schedule. Produce it from the model checkpoint with the +companion ``make_calib.py`` (``python make_calib.py --model-config ... --checkpoint +... --out sa3-m.calib.npz``); the npz keys are the six ONNX input names, each a +leading-axis batch of samples. + +Inputs/outputs stay FP32 so the runtime can swap engines transparently. + +Validated (sa3-m, vs the FP16-mixed engine, over the 47 reprompt Music prompts +x 8 sigmas at L=646; the recipe is shape-independent: activation scales are +per-tensor and the default profile below is dynamic, so the numbers carry over): + - Worst single-step latent cosine (x + dt*v, n=376): 0.9982 (mean 0.9997) + - 8-step compounded euler final-latent cosine, distribution over the 47 + prompts: mean 0.953, median 0.957, p5 0.915, worst 0.873. The compounded + rollout is chaotic at the early sigmas (an eps=1e-3 input perturbation + alone compounds to ~0.967) and the FP16-mixed engine itself scores only + ~0.998 vs PT eager, so this is a guide, not a gate; the shipped gate is + decoded audio under the production pingpong sampler, judged by ear + (RMS-curve correlation ~0.90 vs the FP16-mixed engine's generation). + - Step time (B=1, L=646): ~10.6-11.0 ms (FP16-mixed: ~18.7-19.4 ms) -> ~1.8x. + TRT tactic selection is nondeterministic per build; if a fresh engine + benches noticeably slower, rebuild it. + - A true batched forward amortizes (~1.4x at B=4) once compute drops, unlike + FP16-mixed (<=1.09x): fp8 frees the SM throughput the FP16 engine saturated. + +Usage: + python build_dit_fp8.py + --input onnx/sa3-m/dit_fp16mixed.onnx + --calib sa3-m.calib.npz + --onnx /tmp/dit_fp8.onnx # intermediate (publishable) + --engine ../models//sa3-m/dit_fp8.trt + [--islands-mode {minimal,rope,hybrid}] # default: hybrid + [--calib-samples 16] [--workspace-gb 16] + [--work-dir DIR] [--keep-intermediates] # ~15 GB scratch + [--skip-convert] [--skip-build] + +Requires (producer): tensorrt, torch, onnx, numpy, nvidia-modelopt, +onnxruntime-gpu (the repair pass calibrates activation scales on CUDA EP). +""" +import argparse +import os +import sys +import time +from pathlib import Path + +import numpy as np + +BUILD_DIR = Path(__file__).resolve().parent +sys.path.insert(0, str(BUILD_DIR)) +# Reuse the FP16-mixed recipe's structural FP32-island finder verbatim (the fp8 +# island re-apply is the same RMSNorm/Softmax/RoPE detection it uses). +from build_dit_fp16mixed import find_fp32_islands # noqa: E402 + +T5_TOKENS = 256 +T5_HIDDEN_DIM = 768 +IO_CHANNELS = 256 + +SCALE_FLOOR = 6.2e-5 # fp16 min-normal floor for fp8-stored scales (TRT > 0) +FP8_MAX = 448.0 # e4m3 max + +# Default sa3-m latent profile (matches the published FP16-mixed sa3-m engine +# so the fp8 engine is a drop-in). min=1 keeps short windows on-engine. +_DEFAULT_PROFILE_LATENTS = (1, 1292, 4096) # (min, opt, max) + + +def _dit_profile(min_l: int, opt_l: int, max_l: int) -> dict: + return { + "x": [(1, IO_CHANNELS, min_l), (1, IO_CHANNELS, opt_l), (1, IO_CHANNELS, max_l)], + "t": [(1,), (1,), (1,)], + "t5_hidden": [(1, T5_TOKENS, T5_HIDDEN_DIM)] * 3, + "t5_mask": [(1, T5_TOKENS)] * 3, + "seconds_total": [(1,), (1,), (1,)], + "local_add_cond": [(1, 257, min_l), (1, 257, opt_l), (1, 257, max_l)], + } + +# Initializers ModelOpt is known to corrupt (single exploded elements). Every +# save downstream of the repair re-verifies these against the source graph, +# because an externalized modified tensor can silently revert on save. +KNOWN_RESTORES = ( + "dit.to_timestep_embed.2.bias", + "dit.transformer.layers.22.to_local_embed.0.weight", +) + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _kahn_sort(g) -> None: + """Topologically sort ``g.node`` in place (Kahn). Values, initializers and + external-data references are untouched.""" + available = {i.name for i in g.initializer} + available.update(inp.name for inp in g.input) + available.add("") + pending, ordered = list(g.node), [] + while pending: + progressed, remaining = False, [] + for n in pending: + if all(i in available for i in n.input): + ordered.append(n) + available.update(n.output) + progressed = True + else: + remaining.append(n) + if not progressed: + missing = {i for n in remaining for i in n.input} - available + raise RuntimeError(f"toposort stuck: {list(missing)[:5]}") + pending = remaining + del g.node[:] + g.node.extend(ordered) + + +def _detach_external(model) -> None: + """``onnx.load`` fills ``raw_data`` but keeps stale external-data metadata; + a later ``save_as_external_data`` then skips re-pointing those tensors and + silently references old offsets. Detach everything so saves re-serialize.""" + import onnx + + for i in model.graph.initializer: + if i.data_location == onnx.TensorProto.EXTERNAL: + del i.external_data[:] + i.data_location = onnx.TensorProto.DEFAULT + + +def _save_external(model, path: Path) -> None: + """Save with external data, keeping repaired/scale tensors inline + (size_threshold) and never appending into a pre-existing sidecar. + + The sidecar is ``.onnx.data`` — the repo-wide convention + (build_from_onnx.py / build_dit_fp16mixed.py) so the published + ``dit_fp8.onnx`` + ``dit_fp8.onnx.data`` pair round-trips through HF and + the consumer recipe resolves the external reference.""" + import onnx + + sidecar = Path(str(path) + ".data") + for p in (path, sidecar): + if p.exists(): + os.remove(p) + onnx.save(model, str(path), save_as_external_data=True, + location=sidecar.name, size_threshold=262144) + + +def _reverted_names(path: Path, src_inits: dict, names) -> list: + """Reload ``path`` and return the subset of ``names`` whose values drifted + >1% from ``src_inits`` (i.e. reverted to a corrupted copy on save).""" + import onnx + import onnx.numpy_helper as nh + + chk = onnx.load(str(path)) + ci = {i.name: i for i in chk.graph.initializer} + out = [] + for name in names: + if name not in ci or name not in src_inits: + continue + va = nh.to_array(ci[name]).astype(np.float32) + vb = nh.to_array(src_inits[name]).astype(np.float32) + if np.abs(va - vb).max() > np.abs(vb).max() * 1e-2 + 1e-12: + out.append(name) + return out + + +def _patch_reverts(path: Path, src_inits: dict, names, *, verify=True) -> list: + """Inline-patch (proto only) any initializer in ``names`` that reverted to a + corrupted value when ``path`` was saved, restoring it from ``src_inits``. + Handles the external-data revert bug where a >256 KB tensor silently keeps + stale bytes even after detaching. Returns the names patched.""" + import onnx + import onnx.numpy_helper as nh + + reverted = _reverted_names(path, src_inits, names) + if not reverted: + return [] + mp = onnx.load(str(path), load_external_data=False) + mi = {i.name: i for i in mp.graph.initializer} + for name in reverted: + t, s = mi[name], src_inits[name] + del t.external_data[:] + t.data_location = onnx.TensorProto.DEFAULT + t.raw_data = s.raw_data if s.raw_data else nh.from_array( + nh.to_array(s), name).raw_data + onnx.save(mp, str(path)) # proto only + if verify: + still = _reverted_names(path, src_inits, reverted) + assert not still, f"inline patch did not persist for {still}" + return reverted + + +def _guard_known_restores(path: Path, source_onnx: Path) -> None: + """Reload ``path``, compare KNOWN_RESTORES against ``source_onnx``, and + inline-patch (proto only) any tensor that reverted on save.""" + import onnx + + src = onnx.load(str(source_onnx)) + src_inits = {i.name: i for i in src.graph.initializer} + patched = _patch_reverts(path, src_inits, KNOWN_RESTORES) + if patched: + print(f" [guard] {len(patched)} tensors reverted on save; " + f"patched inline: {patched}") + + +# --------------------------------------------------------------------------- +# step 1: toposort + opset 19 (also a latent fix for the fp16mixed ONNX) +# --------------------------------------------------------------------------- + + +def toposort_opset19(input_onnx: Path, sorted_onnx: Path) -> None: + """Kahn-sort the FP16-mixed graph and convert to a native opset 19. + + Upstream's island surgery leaves the node list un-toposorted (TRT's parser + tolerates it; onnx.checker and ModelOpt do not). ModelOpt's own opset bump + only rewrites the import and leaves ReduceMean's pre-18 ``axes`` attribute, + which ORT then rejects. Do both correctly here so ModelOpt sees a clean + opset-19 model. Copies the external-data sidecar next to the sorted proto. + """ + import shutil + + import onnx + + m = onnx.load(str(input_onnx), load_external_data=False) + _kahn_sort(m.graph) + cur = max((imp.version for imp in m.opset_import + if imp.domain in ("", "ai.onnx")), default=0) + if cur < 19: + from onnx import version_converter + + print(f" converting opset {cur} -> 19 ...") + m = version_converter.convert_version(m, 19) + onnx.save(m, str(sorted_onnx)) + src_data = Path(str(input_onnx) + ".data") + dst_data = Path(str(sorted_onnx).rsplit(".onnx", 1)[0] + ".onnx.data") + # sorted proto references the sidecar by basename; mirror it alongside. + want = None + for i in m.graph.initializer: + if i.data_location == onnx.TensorProto.EXTERNAL and i.external_data: + want = i.external_data[0].value + break + if want is not None: + dst_data = sorted_onnx.parent / want + if src_data.exists() and not dst_data.exists(): + print(f" copying external-data sidecar -> {dst_data.name}") + shutil.copyfile(src_data, dst_data) + onnx.checker.check_model(str(sorted_onnx)) + print(f" toposorted opset-19 proto OK: {sorted_onnx}") + + +# --------------------------------------------------------------------------- +# step 2: fp8 PTQ +# --------------------------------------------------------------------------- + + +def quantize_fp8(sorted_onnx: Path, calib_npz: Path, out_onnx: Path) -> None: + """ModelOpt FP8 PTQ of the weighted GEMMs only. ``disable_mha_qdq`` keeps + attention BMMs and the softmax path on the FP16/FP32 recipe; the trunk + around each Q/DQ stays fp16.""" + from modelopt.onnx.quantization import quantize + + data = dict(np.load(calib_npz)) + n = data[next(iter(data))].shape[0] + print(f" {n} calibration samples, fp8 PTQ on {sorted_onnx.name}") + t0 = time.time() + quantize( + str(sorted_onnx), + quantize_mode="fp8", + calibration_data=data, + calibration_method="max", # amax scales; histogram calibrators fail on + # the zero-range all-zeros local_add_cond + calibration_eps=["cuda:0", "cpu"], + op_types_to_quantize=["MatMul", "Gemm"], + disable_mha_qdq=True, + high_precision_dtype="fp16", + use_external_data_format=True, + output_path=str(out_onnx), + ) + print(f" wrote {out_onnx} in {time.time() - t0:.0f}s") + + +# --------------------------------------------------------------------------- +# step 3: repair corrupted initializers + recalibrate activation scales +# --------------------------------------------------------------------------- + + +def repair(fp8_onnx: Path, sorted_onnx: Path, calib_npz: Path, + n_samples: int, out_onnx: Path) -> None: + import onnx + import onnx.numpy_helper as nh + import onnxruntime as ort + + def _load_detached(path): + mm = onnx.load(str(path)) + _detach_external(mm) + return mm + + m = _load_detached(fp8_onnx) + g = m.graph + inits = {i.name: i for i in g.initializer} + + # --- 1. restore every corrupted initializer vs the source graph -------- + src = _load_detached(sorted_onnx) + src_inits = {i.name: i for i in src.graph.initializer} + restored_names = [] + for name in sorted(set(inits) & set(src_inits)): + init = inits[name] + if init.data_type == onnx.TensorProto.FLOAT8E4M3FN: + continue # quantized weights legitimately differ + va = nh.to_array(init).astype(np.float32) + vb = nh.to_array(src_inits[name]).astype(np.float32) + if va.shape != vb.shape: + continue + ref = np.abs(vb).max() + 1e-12 + if np.abs(va - vb).max() / ref > 1e-2: + n_bad = int((np.abs(va - vb) > ref * 1e-2).sum()) + orig_dtype = nh.to_array(init).dtype + init.CopyFrom(nh.from_array( + nh.to_array(src_inits[name]).astype(orig_dtype), name)) + print(f" restored {name}: {n_bad} corrupted elems, " + f"absmax {np.abs(va).max():.4g} -> {np.abs(vb).max():.4g}") + restored_names.append(name) + print(f" {len(restored_names)} corrupted initializers restored") + + # --- 2. build a Q/DQ-bypassed probe model ------------------------------ + LAYOUT = ("Transpose", "Reshape", "Squeeze", "Unsqueeze") + bypass = onnx.ModelProto() + bypass.CopyFrom(m) + bg = bypass.graph + bcons: dict[str, list] = {} + for n in bg.node: + for i in n.input: + bcons.setdefault(i, []).append(n) + pairs, drop = [], set() + for q in list(bg.node): + if q.op_type != "QuantizeLinear": + continue + cur, chain, dq = q.output[0], [], None + for _ in range(4): + cs = bcons.get(cur, []) + if len(cs) != 1: + break + if cs[0].op_type == "DequantizeLinear": + dq = cs[0] + break + if cs[0].op_type in LAYOUT: + chain.append(cs[0]) + cur = cs[0].output[0] + continue + break + if dq is None: + continue + x_in = q.input[0] + if chain: + chain[0].input[0] = x_in + tail = chain[-1].output[0] + else: + tail = x_in + for c in bg.node: + for k, ci in enumerate(c.input): + if ci == dq.output[0]: + c.input[k] = tail + drop.add(q.name) + drop.add(dq.name) + pairs.append((q.name, dq.name, x_in)) + keep = [n for n in bg.node if n.name not in drop] + del bg.node[:] + bg.node.extend(keep) + probe_tensors = sorted({p[2] for p in pairs + if p[2] not in {i.name for i in bg.initializer}}) + del bg.output[:] + bg.output.extend( + onnx.helper.make_empty_tensor_value_info(t) for t in probe_tensors) + print(f" bypassed {len(pairs)} q/dq pairs, probing {len(probe_tensors)} " + f"activation tensors") + bp_path = out_onnx.parent / "_fp8_repair_bypass.onnx" + _save_external(bypass, bp_path) + + # --- 3. recalibrate activation amax on real conditioning ---------------- + d = np.load(calib_npz) + n_total = d[next(iter(d.files))].shape[0] + idx = np.linspace(0, n_total - 1, n_samples).round().astype(int) + so = ort.SessionOptions() + so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + sess = ort.InferenceSession( + str(bp_path), so, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + active_ep = sess.get_providers()[0] + if "CUDA" not in active_ep: + print(f" WARNING: CUDA EP unavailable; calibrating on {active_ep}. " + f"amax is EP-robust so scales are still valid, but install " + f"onnxruntime-gpu for a faster repair pass.") + print(f" calibrating on {len(idx)} samples ({active_ep})") + amax = {t: 0.0 for t in probe_tensors} + nonfinite: set = set() + feed_keys = ("x", "t", "t5_hidden", "t5_mask", "seconds_total", + "local_add_cond") + for j, i in enumerate(idx): + feed = {k: d[k][i:i + 1] for k in feed_keys} + outs = sess.run(None, feed) + for t, o in zip(probe_tensors, outs): + o = np.asarray(o).astype(np.float32) + fin = o[np.isfinite(o)] + if fin.size != o.size: + # additive-mask path (-inf fill): meaningless to quantize; + # mark its q/dq pair for removal + if t not in nonfinite: + print(f" non-finite tensor {t}: q/dq pair removed") + nonfinite.add(t) + amax[t] = max(amax[t], + float(np.abs(fin).max()) if fin.size else 0.0) + print(f" sample {j + 1}/{len(idx)} done", flush=True) + + # --- 4. write recalibrated scales / neutralize mask-path pairs ---------- + node_by_name = {n.name: n for n in g.node} + written = removed = 0 + for q_name, dq_name, x_in in pairs: + if x_in in nonfinite: + for nn in (q_name, dq_name): + nd = node_by_name[nn] + nd.op_type = "Identity" + del nd.input[1:] + del nd.attribute[:] + removed += 1 + continue + scale = max(amax[x_in] / FP8_MAX, SCALE_FLOOR) + for nn in (q_name, dq_name): + nd = node_by_name[nn] + sinit = inits[nd.input[1]] + orig_dtype = sinit.data_type + arr = nh.to_array(sinit) + sinit.CopyFrom(nh.from_array(np.array(scale, dtype=arr.dtype), + nd.input[1])) + assert sinit.data_type == orig_dtype + written += 1 + del g.value_info[:] # stale fp8 dtype claims on neutralized paths + print(f" wrote {written} recalibrated scales, neutralized {removed} " + f"mask-path q/dq pairs") + + _save_external(m, out_onnx) + # reload-verify EVERY restored tensor (large ones hit the external-data + # revert bug even after detaching); inline-patch any that reverted. This + # set is a superset of KNOWN_RESTORES, so no separate guard call is needed. + patched = _patch_reverts(out_onnx, src_inits, restored_names) + if patched: + print(f" inline-patched {len(patched)} reverted restores") + print(f" wrote+verified {out_onnx} ({len(restored_names)} restores)") + + +# --------------------------------------------------------------------------- +# step 4: re-apply FP32 islands (ModelOpt flattened them) +# --------------------------------------------------------------------------- + + +def reapply_fp32_islands(repaired_onnx: Path, sorted_onnx: Path, + out_onnx: Path, mode: str = "hybrid") -> None: + """Re-protect the FP32 islands the fp16mixed recipe established, which + ModelOpt strips. ``hybrid`` = structural islands (RMSNorm/Softmax) + RoPE + region + the conditioning front-end that computes fp32 upstream. Upcast + island initializers to their true fp32 values from the source graph.""" + import re + + import onnx + import onnx.numpy_helper as nh + + m = onnx.load(str(repaired_onnx)) + _detach_external(m) + g = m.graph + by_name = {n.name: n for n in g.node} + blocked = set(find_fp32_islands(m, mode="minimal")) + + FP16, FP32 = onnx.TensorProto.FLOAT16, onnx.TensorProto.FLOAT + inferred = onnx.shape_inference.infer_shapes( + onnx.load(str(repaired_onnx), load_external_data=False)) + dtype = {vi.name: vi.type.tensor_type.elem_type + for vi in (*inferred.graph.value_info, *inferred.graph.output, + *inferred.graph.input)} + for i in g.initializer: + dtype[i.name] = i.data_type + + upstream_fp32_inits: dict = {} + if mode in ("rope", "hybrid"): + srt_proto = onnx.load(str(sorted_onnx), load_external_data=False) + rope_named = (set(find_fp32_islands(srt_proto, mode="rope")) + - set(find_fp32_islands(srt_proto, mode="minimal"))) + kept = 0 + for name in rope_named: + n = by_name.get(name) + if n is not None and any(dtype.get(o) == FP16 for o in n.output): + blocked.add(name) + kept += 1 + print(f" rope extras: {kept} fp16-carrying nodes kept") + srt_full = onnx.load(str(sorted_onnx)) + for i in srt_full.graph.initializer: + if i.data_type == FP32: + upstream_fp32_inits[i.name] = nh.to_array(i) + del srt_full + print(f" loaded {len(upstream_fp32_inits)} upstream fp32 " + f"initializers for value restore") + + if mode == "hybrid": + # the conditioning front-end (timestep expo features / cond embeds / + # memory-token plumbing — everything before the transformer layers) + # computes fp32 upstream; fp16 here flushes the expo features and is + # the entire t>=0.984 base divergence. + inferred_s = onnx.shape_inference.infer_shapes( + onnx.load(str(sorted_onnx), load_external_data=False)) + sdtype = {vi.name: vi.type.tensor_type.elem_type + for vi in (*inferred_s.graph.value_info, + *inferred_s.graph.output, *inferred_s.graph.input)} + srt_p = onnx.load(str(sorted_onnx), load_external_data=False) + for i in srt_p.graph.initializer: + sdtype[i.name] = i.data_type + layer_pat = re.compile(r"^/transformer/layers\.") + ours_by_out = {o: n for n in g.node for o in n.output} + fe = 0 + for n in srt_p.graph.node: + if n.op_type == "Cast": + continue + for o in n.output: + if sdtype.get(o) == FP32 and not layer_pat.match(o): + ours = ours_by_out.get(o) + if ours is not None and ours.name: + blocked.add(ours.name) + fe += 1 + print(f" hybrid front-end fp32 ops added: {fe}") + + blocked_nodes = [by_name[b] for b in blocked if b in by_name] + print(f" {len(blocked_nodes)} island nodes " + f"({sum(1 for n in blocked_nodes if n.op_type == 'Softmax')} Softmax)") + + inits = {i.name: i for i in g.initializer} + prod = {o: n for n in g.node for o in n.output} + blocked_set = {n.name for n in blocked_nodes} + + def _const_dtype(n): + for a in n.attribute: + if a.name == "value": + return a.t.data_type + return None + + new_nodes, upcast_inits, casts_in, casts_out = [], 0, 0, 0 + for n in blocked_nodes: + for k, inp in enumerate(n.input): + if n.op_type.startswith("Reduce") and k >= 1: + continue # int64 axes initializer — never wrap in a Cast + if inp in inits and inits[inp].data_type == FP16: + if inp in upstream_fp32_inits: + arr = np.asarray(upstream_fp32_inits[inp], dtype=np.float32) + else: + arr = nh.to_array(inits[inp]).astype(np.float32) + inits[inp].CopyFrom(nh.from_array(arr, inp)) + upcast_inits += 1 + elif inp in prod and prod[inp].op_type == "Constant" \ + and _const_dtype(prod[inp]) == FP16: + for a in prod[inp].attribute: + if a.name == "value": + a.t.CopyFrom(nh.from_array( + nh.to_array(a.t).astype(np.float32), a.t.name)) + upcast_inits += 1 + elif inp in prod and prod[inp].name not in blocked_set \ + and dtype.get(inp, FP16) == FP16: + cast_out = f"{inp}__refp32_{casts_in}" + new_nodes.append(onnx.helper.make_node( + "Cast", [inp], [cast_out], + name=f"refp32_in_{casts_in}", to=FP32)) + n.input[k] = cast_out + casts_in += 1 + for o in list(n.output): + if dtype.get(o, FP16) != FP16: + continue + consumers = [c for c in g.node + if o in c.input and c.name not in blocked_set + and not c.name.startswith("refp32_")] + if not consumers: + continue + cast_out = f"{o}__refp16_{casts_out}" + new_nodes.append(onnx.helper.make_node( + "Cast", [o], [cast_out], + name=f"refp16_out_{casts_out}", to=FP16)) + for c in consumers: + for k, ci in enumerate(c.input): + if ci == o: + c.input[k] = cast_out + casts_out += 1 + + g.node.extend(new_nodes) + _kahn_sort(g) + del g.value_info[:] + print(f" upcast {upcast_inits} consts, {casts_in} in-casts, " + f"{casts_out} out-casts") + _save_external(m, out_onnx) + onnx.checker.check_model(str(out_onnx)) + _guard_known_restores(out_onnx, sorted_onnx) + print(f" saved {out_onnx}") + + +# --------------------------------------------------------------------------- +# step 5: per-channel weight scales +# --------------------------------------------------------------------------- + + +def perchannel_weights(islands_onnx: Path, out_onnx: Path, + sorted_onnx: Path) -> None: + """Upgrade weight-side Q/DQ pairs (initializer -> Transpose -> Q -> DQ -> + MatMul) to per-channel scales along the GEMM N (output-feature) axis. + Constant-folds at build time, free at runtime. Activation pairs stay + per-tensor (TRT requires it for fp8 activation quant). + + This is the FINAL save of the published artifact, so it re-externalizes + every initializer (incl. the >256 KB ``layers.22`` weight that hits the + external-data revert bug); re-guard the known restores afterwards.""" + import onnx + import onnx.numpy_helper as nh + + m = onnx.load(str(islands_onnx)) + _detach_external(m) + g = m.graph + inits = {i.name: i for i in g.initializer} + prod = {o: n for n in g.node for o in n.output} + cons: dict[str, list] = {} + for n in g.node: + for i in n.input: + cons.setdefault(i, []).append(n) + + upgraded = skipped = 0 + upgraded_scales: set = set() + for q in g.node: + if q.op_type != "QuantizeLinear": + continue + p = prod.get(q.input[0]) + if p is None or p.op_type != "Transpose" or p.input[0] not in inits: + continue + w_init = inits[p.input[0]] + if w_init.data_type == onnx.TensorProto.FLOAT8E4M3FN: + continue + w = nh.to_array(w_init).astype(np.float32) + if w.ndim != 2: + skipped += 1 + continue + perm = [a.ints for a in p.attribute if a.name == "perm"] + if perm and list(perm[0]) != [1, 0]: + skipped += 1 + continue + cs = cons.get(q.output[0], []) + if len(cs) != 1 or cs[0].op_type != "DequantizeLinear": + skipped += 1 + continue + dq = cs[0] + # stored weight is [out, in]; per-row amax = per-output-channel scale + # (axis 1 of the [in, out] transposed weight the GEMM consumes). + amax = np.abs(w).max(axis=1) + scales = np.maximum(amax / FP8_MAX, SCALE_FLOOR) + for node in (q, dq): + sinit = inits[node.input[1]] + orig_dtype = nh.to_array(sinit).dtype + sinit.CopyFrom(nh.from_array(scales.astype(orig_dtype), node.input[1])) + if len(node.input) > 2 and node.input[2] in inits: + zp = inits[node.input[2]] + znew = np.zeros(scales.shape, dtype=nh.to_array(zp).dtype) + if zp.data_type == onnx.TensorProto.FLOAT8E4M3FN: + t = onnx.helper.make_tensor( + node.input[2], zp.data_type, list(znew.shape), + bytes(len(znew.ravel())), raw=True) + else: + t = nh.from_array(znew, node.input[2]) + t.data_type = zp.data_type + zp.CopyFrom(t) + for a in list(node.attribute): + if a.name == "axis": + node.attribute.remove(a) + node.attribute.append(onnx.helper.make_attribute("axis", 1)) + upgraded_scales.add(node.input[1]) + upgraded += 1 + print(f" upgraded {upgraded} weight pairs to per-channel, skipped {skipped}") + del g.value_info[:] + _save_external(m, out_onnx) + onnx.checker.check_model(str(out_onnx)) + # the final save re-guards the corrupted-then-restored initializers, then + # confirms the per-channel scales survived the round-trip (they grew from + # scalars to vectors but stay inline, well under size_threshold). + _guard_known_restores(out_onnx, sorted_onnx) + chk = onnx.load(str(out_onnx), load_external_data=False) + nvec = sum(1 for i in chk.graph.initializer + if i.name in upgraded_scales and list(i.dims) not in ([], [1])) + assert nvec >= upgraded, \ + f"per-channel scales did not persist ({nvec} vectors < {upgraded})" + print(f" saved {out_onnx} ({nvec} per-channel scale vectors verified)") + + +# --------------------------------------------------------------------------- +# step 6: build the TRT engine +# --------------------------------------------------------------------------- + + +def build_trt_engine(onnx_path: Path, engine_path: Path, + profile_latents: tuple = _DEFAULT_PROFILE_LATENTS, + workspace_gb: int = 16) -> None: + import tensorrt as trt + + print(f"\n building TRT engine -> {engine_path}") + logger = trt.Logger(trt.Logger.WARNING) + builder = trt.Builder(logger) + network = builder.create_network( + 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)) + parser = trt.OnnxParser(network, logger) + if not parser.parse_from_file(str(onnx_path)): + for i in range(parser.num_errors): + print(f" parse error: {parser.get_error(i)}") + sys.exit(2) + cfg = builder.create_builder_config() + cfg.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_gb << 30) + profile = builder.create_optimization_profile() + for name, (lo, opt, hi) in _dit_profile(*profile_latents).items(): + profile.set_shape(name, lo, opt, hi) + cfg.add_optimization_profile(profile) + print(f" latent profile (min/opt/max): {profile_latents}") + print(f" building (workspace {workspace_gb} GB, STRONGLY_TYPED, fp8 trunk)...") + t0 = time.time() + serialized = builder.build_serialized_network(network, cfg) + if serialized is None: + print(" BUILD FAILED") + sys.exit(3) + print(f" built in {time.time() - t0:.0f}s ({serialized.nbytes / 1e6:.0f} MB)") + Path(engine_path).parent.mkdir(parents=True, exist_ok=True) + with open(engine_path, "wb") as f: + f.write(serialized) + print(f" wrote {engine_path}") + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +def convert_to_fp8(input_onnx: Path, calib_npz: Path, out_onnx: Path, + *, mode: str, calib_samples: int, work_dir: Path) -> None: + # Intermediates (five ~3 GB ONNX + sidecars) live in a dedicated work + # dir, NOT next to the published artifact — the documented usage writes + # --onnx straight into the HF repo, which must hold only dit_fp8.onnx(.data). + work = work_dir + work.mkdir(parents=True, exist_ok=True) + sorted_onnx = work / "_fp8_sorted.onnx" + quant_onnx = work / "_fp8_quant.onnx" + repaired_onnx = work / "_fp8_repaired.onnx" + islands_onnx = work / "_fp8_islands.onnx" + + print("=== [1/5] toposort + opset 19 ===") + toposort_opset19(input_onnx, sorted_onnx) + print("=== [2/5] fp8 PTQ ===") + quantize_fp8(sorted_onnx, calib_npz, quant_onnx) + print("=== [3/5] repair + recalibrate ===") + repair(quant_onnx, sorted_onnx, calib_npz, calib_samples, repaired_onnx) + print(f"=== [4/5] re-apply FP32 islands (mode={mode}) ===") + reapply_fp32_islands(repaired_onnx, sorted_onnx, islands_onnx, mode=mode) + print("=== [5/5] per-channel weight scales ===") + perchannel_weights(islands_onnx, out_onnx, sorted_onnx) + print(f"\nfp8 ONNX ready: {out_onnx}") + + +def main(): + ap = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--input", required=True, + help="Canonical FP16-mixed ONNX (onnx/sa3-m/dit_fp16mixed.onnx)") + ap.add_argument("--calib", required=True, + help="Calibration .npz (six DiT input tensors, batched)") + ap.add_argument("--onnx", default="/tmp/dit_fp8.onnx", + help="Output fp8 ONNX (intermediate; publishable to HF)") + ap.add_argument("--engine", default=None, + help="Output TRT engine path (default: alongside --onnx)") + ap.add_argument("--islands-mode", choices=("minimal", "rope", "hybrid"), + default="hybrid", + help="FP32 island coverage (hybrid = +RoPE +front-end; " + "the validated recipe)") + ap.add_argument("--calib-samples", type=int, default=16, + help="Samples used to recalibrate activation scales") + ap.add_argument("--min-latents", type=int, default=_DEFAULT_PROFILE_LATENTS[0]) + ap.add_argument("--opt-latents", type=int, default=_DEFAULT_PROFILE_LATENTS[1]) + ap.add_argument("--max-latents", type=int, default=_DEFAULT_PROFILE_LATENTS[2], + help="TRT latent profile (min/opt/max); default = sa3-m's " + "published profile") + ap.add_argument("--workspace-gb", type=int, default=16) + ap.add_argument("--work-dir", default=None, + help="Scratch dir for the ~15 GB of _fp8_* intermediates " + "(default: /_fp8_work, auto-removed on " + "success; a user-supplied dir is never auto-removed)") + ap.add_argument("--keep-intermediates", action="store_true", + help="Leave the default work dir in place after conversion") + ap.add_argument("--skip-convert", action="store_true", + help="Reuse an existing --onnx (just build)") + ap.add_argument("--skip-build", action="store_true", + help="Only produce the fp8 ONNX") + args = ap.parse_args() + + out_onnx = Path(args.onnx) + out_onnx.parent.mkdir(parents=True, exist_ok=True) + work_is_default = args.work_dir is None + work_dir = out_onnx.parent / "_fp8_work" if work_is_default else Path(args.work_dir) + if not args.skip_convert: + convert_to_fp8(Path(args.input), Path(args.calib), out_onnx, + mode=args.islands_mode, calib_samples=args.calib_samples, + work_dir=work_dir) + if not args.skip_build: + engine = Path(args.engine) if args.engine else out_onnx.with_suffix(".trt") + print("\n=== build TRT engine ===") + build_trt_engine( + out_onnx, engine, + profile_latents=(args.min_latents, args.opt_latents, args.max_latents), + workspace_gb=args.workspace_gb) + # Only auto-remove the dedicated dir we created; never rmtree a + # user-supplied --work-dir (it may be shared or hold other files). + if (work_is_default and not args.skip_convert + and not args.keep_intermediates and work_dir.exists()): + import shutil + + shutil.rmtree(work_dir, ignore_errors=True) + print(f"cleaned intermediates: {work_dir}") + + +if __name__ == "__main__": + main() diff --git a/optimized/tensorRT/build/build_from_onnx.py b/optimized/tensorRT/build/build_from_onnx.py index f75f70d..6460023 100755 --- a/optimized/tensorRT/build/build_from_onnx.py +++ b/optimized/tensorRT/build/build_from_onnx.py @@ -173,6 +173,25 @@ "profile": _DIT_PROFILE, "plugin": False, }, + # ── FP8 variant (DiT-only) ─────────────────────────────────────────── + # Pre-processed FP8-trunk ONNX (ModelOpt PTQ + FP32 islands + per-channel + # weights; producer: build_dit_fp8.py). STRONGLY_TYPED compile: the Q/DQ + # nodes and per-tensor dtypes carry the precision. ~1.8x faster steps than + # FP16-mixed; a different but comparable sample under the pingpong sampler. + "sa3-m-fp8": { + # External-data sidecar travels alongside (smaller than fp16mixed's + # 2.9 GB — the quantised GEMM weights are 1-byte e4m3). + "onnx_hf": ["sa3-m/dit_fp8.onnx", "sa3-m/dit_fp8.onnx.data"], + "trt_local": "sa3-m/dit_fp8.trt", + "flags": set(), + "network": "STRONGLY_TYPED", + "workspace_gb": 16, + "profile": _DIT_PROFILE, + "plugin": False, + # Opt-in: built only by explicit name, never via 'all' / 'all-both'. + # The published dit_fp8.onnx is the gate — see README ("Not yet on HF"). + "opt_in": True, + }, # SAME-L FP32 decoder: the canonical ONNX is FP16-mixed; we upcast every # FP16 initializer/Constant/Cast to FP32 in-process before building. The # Triton SWA plugin already runs FP32 internally, so its contract is @@ -336,7 +355,9 @@ def build_one(name: str) -> str: def main(): - canonical = [k for k in TARGETS if not k.endswith("-fp32")] + opt_in = [k for k in TARGETS if TARGETS[k].get("opt_in")] + canonical = [k for k in TARGETS + if not k.endswith("-fp32") and not TARGETS[k].get("opt_in")] fp32 = [k for k in TARGETS if k.endswith("-fp32")] if len(sys.argv) < 2: @@ -347,6 +368,10 @@ def main(): print("\nFP32 variants (~2x slower, ~2x engine size, opt-in):") for k in fp32: print(f" {k}") + if opt_in: + print("\nOpt-in variants (built only by explicit name, not in any group):") + for k in opt_in: + print(f" {k}") print("\nGroups:") print(" all — every canonical target (default for shipping)") print(" all-fp32 — every FP32 target") diff --git a/optimized/tensorRT/build/make_calib.py b/optimized/tensorRT/build/make_calib.py new file mode 100644 index 0000000..a5a84b2 --- /dev/null +++ b/optimized/tensorRT/build/make_calib.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +"""Capture FP8 calibration data for the SA3 DiT (producer-side). + +``build_dit_fp8.py`` needs a calibration ``.npz``: real DiT inputs sampled +across the pingpong sampling schedule. This script produces one from nothing +but the model checkpoint, by driving the model's own ``generate()`` and +recording the six DiT engine inputs at every sampling step. + +FP8 is the only quantization path in this repo that needs calibration data +(``fp16mixed`` / ``fp32`` need none, MLX is weight-only), so this also +establishes the convention: a ``.calib.npz`` whose keys are the six ONNX +input names, each a leading-axis batch of samples. The npz is a reproducible +producer artifact: gitignored, never committed, regenerated on demand. + +Native by construction: + * Loads with ``loading_utils.load_diffusion_cond`` (the repo's own loader: + ``create_diffusion_cond_from_config`` + ``copy_state_dict``), wrapped in a + ``StableAudioModel``, so it needs only the checkpoint dir + (``model_config.json`` + ``model.safetensors`` + the bundled + ``t5gemma-b-b-ul2/`` encoder, loaded locally; falls back to the config's HF + model_name if absent). No ``StableAudioModel.from_pretrained`` (its + hard-coded checkpoint registry), no ``stable-audio-tools``. + * Prompts come from the repo's own ``interface/reprompt.py``: the Music + few-shot examples (``SYSTEM_PROMPTS["Music"]``), pulled at runtime via the + module's ``_extract_examples`` so there is one source of truth, not a + copy. These are the exact post-reprompt format the model is driven with at + deployment (genre + instruments + rhythm + mood + BPM + Length). + * Captures from the real ``StableAudioModel.generate()`` with + ``sampler_type="pingpong"``, ``steps=8``, ``cfg_scale=1.0`` (the settings + the model is sampled with): the conditioner, the dist-shifted sigma + schedule, and the pingpong renoise are all the real inference path, not a + re-implementation. + * Hooks the ``DiTWrapper`` forward, where conditioning arrives as the + 257-token ``cross_attn_cond`` (``[t5(256) | seconds(1)]``) and the + 257-channel ``local_add_cond`` (``[inpaint_masked_input(256) | + inpaint_mask(1)]``, all zeros for plain text-to-music). The trailing + seconds token is dropped to ``t5_hidden`` (256); ``seconds_total`` is the + raw duration scalar (the engine recomputes the seconds embedding itself). + +The six keys and the verified mapping: + x (N, 256, L) pre-conv latent at each step + t (N,) sigma at each step + t5_hidden (N, 256, 768) raw T5Gemma hidden states (cross-attn token 257 dropped) + t5_mask (N, 256) T5 attention mask + seconds_total (N,) raw duration scalar + local_add_cond (N, 257, L) inpaint local conditioning (zeros for t2m) + +Why these prompts, 1 seed, one duration: + Calibration is amax (``calibration_method="max"``), so the question is "have + we seen the tail that sets each tensor's scale", not "have we covered the + distribution". The quantized prompt-driven GEMMs (``to_cond_embed`` and the + cross-attn projections; attention BMMs are excluded by ``disable_mha_qdq``) + have activation maxima dominated by T5Gemma's systematic outlier channels, + which fire for essentially every prompt, so amax saturates after a handful of + prompts. The reprompt example set (47 prompts, a few hundred samples + over the 8 sigmas) sits comfortably inside the standard PTQ calibration band + (128-512 samples), past the knee. Broader or out-of-distribution prompts can + make calibration *worse*: under max calibration a single OOD outlier inflates + the scale and coarsens fp8 resolution for the in-distribution inputs the + engine actually serves, so the deployment-matched in-repo set is the right + default. A second seed re-rolls noise the outlier channels already maxed, so + seeds are the least useful axis. Duration is irrelevant to fp8 scales (the + seconds-embedding front-end is inside the re-applied FP32 islands, never + quantized), and a single duration keeps the npz rectangular for the build + contract. + +Usage: + python make_calib.py \ + --model-config /model_config.json \ + --checkpoint /model.safetensors \ + --out sa3-m.calib.npz + [--duration 54.0] [--steps 8] [--seed 1528] [--device cuda] + +Requires (producer): torch, numpy, and this repo's ``stable_audio_3`` package +with the model checkpoint. +""" +import argparse +import json +import sys +from pathlib import Path + +import numpy as np + +T5_TOKENS = 256 +T5_HIDDEN_DIM = 768 +IO_CHANNELS = 256 +LOCAL_ADD_CHANNELS = 257 # [inpaint_masked_input(256) | inpaint_mask(1)] + +# Default duration: with the default 6 s duration_padding_sec this adapts to a +# latent length L=646, a representative shape inside the published sa3-m profile's +# dynamic range (min 1, opt 1292, max 4096). Calibration scales are per-tensor, +# so the capture length need not equal the profile's opt point. +DEFAULT_DURATION_S = 54.0 +DEFAULT_STEPS = 8 +DEFAULT_SEED = 1528 + + +def _load_prompts() -> list: + """The Music few-shot examples from the repo's own ``reprompt.py`` (the + exact post-reprompt format the model is driven with at deployment), via the + module's own ``_extract_examples`` parser. Imported lazily so ``--help`` + doesn't pull in transformers (which the model load needs anyway).""" + from stable_audio_3.interface.reprompt import ( + SYSTEM_PROMPTS, _extract_examples) + prompts = _extract_examples(SYSTEM_PROMPTS["Music"]) + if not prompts: + raise RuntimeError( + "no Music examples found in reprompt.SYSTEM_PROMPTS['Music']") + return prompts + + +# --------------------------------------------------------------------------- +# capture +# --------------------------------------------------------------------------- + + +def _load_model(model_config_path: Path, checkpoint_path: Path, device: str): + """Load via the repo's own ``load_diffusion_cond`` (config dict + safetensors) + and wrap in ``StableAudioModel`` so ``generate()`` is the real entry point. + fp32 (model_half=False) for a clean, deterministic capture; amax from fp32 is + a safe upper bound for the fp16/fp8 engine.""" + from stable_audio_3.loading_utils import load_diffusion_cond + from stable_audio_3.model import StableAudioModel + + with open(model_config_path) as f: + model_config = json.load(f) + + # Point the T5Gemma conditioner at the encoder bundled in the checkpoint dir + # (snapshot_download lays it out as //), so capture needs + # only the checkpoint and runs offline. Drop repo_id so a stale network repo + # can never win. Falls back to the config's HF model_name if not bundled. + dest = model_config_path.parent + for c in model_config.get("model", {}).get("conditioning", {}).get("configs", []): + if c.get("type") == "t5gemma": + cc = c.setdefault("config", {}) + subfolder = cc.get("subfolder", "t5gemma-b-b-ul2") + if (dest / subfolder).is_dir(): + cc["model_path"] = str(dest) + cc.pop("repo_id", None) + + print(f" loading model: {checkpoint_path}", flush=True) + model = load_diffusion_cond( + model_config, str(checkpoint_path), device=device, model_half=False) + return StableAudioModel(model, model_config, device, model_half=False) + + +def capture(sa3, prompts, duration, steps, base_seed): + """Hook the DiTWrapper forward and run the pingpong generate() once per + prompt. Each generate() emits exactly ``steps`` DiT calls, one per sigma, so + sigma coverage is balanced without per-sigma capping.""" + rows = {"x": [], "t": [], "t5_hidden": [], "t5_mask": [], "local_add_cond": []} + dit = sa3.dit # DiTWrapper + orig_forward = dit.forward + + def hook(x, t, **kw): + ca = kw["cross_attn_cond"] + mask = kw["cross_attn_mask"] + if ca.shape[1] == T5_TOKENS + 1: # drop the trailing seconds token + ca = ca[:, :T5_TOKENS] + mask = mask[:, :T5_TOKENS] + lac = kw.get("local_add_cond", None) + for i in range(x.shape[0]): # batch is 1 at cfg_scale=1.0 + rows["x"].append(x[i].float().cpu().numpy()) + rows["t"].append(np.array([float(t[i])], dtype=np.float32)) + rows["t5_hidden"].append(ca[i].float().cpu().numpy()) + rows["t5_mask"].append(mask[i].float().cpu().numpy()) + if lac is not None: + rows["local_add_cond"].append(lac[i].float().cpu().numpy()) + else: + rows["local_add_cond"].append( + np.zeros((LOCAL_ADD_CHANNELS, x.shape[-1]), dtype=np.float32)) + return orig_forward(x, t, **kw) + + try: + dit.forward = hook + for p_idx, prompt in enumerate(prompts): + sa3.generate( + prompt=prompt, duration=duration, steps=steps, + cfg_scale=1.0, sampler_type="pingpong", + seed=base_seed + p_idx, + return_latents=True, # skip decode; we only need DiT inputs + ) + print(f" [{p_idx + 1}/{len(prompts)}] {len(rows['x'])} samples", + flush=True) + finally: + dit.forward = orig_forward + + return rows + + +def _save(rows, duration, out_path: Path): + n = len(rows["x"]) + if n == 0: + raise RuntimeError("captured 0 samples") + lengths = {a.shape[-1] for a in rows["x"]} + if len(lengths) != 1: + raise RuntimeError( + f"captured ragged latent lengths {sorted(lengths)} - the npz must be " + f"rectangular; use a single --duration") + L = lengths.pop() + + out = { + "x": np.stack(rows["x"]).astype(np.float32), + "t": np.concatenate(rows["t"]).astype(np.float32), + "t5_hidden": np.stack(rows["t5_hidden"]).astype(np.float32), + "t5_mask": np.stack(rows["t5_mask"]).astype(np.float32), + "seconds_total": np.full(n, float(duration), dtype=np.float32), + "local_add_cond": np.stack(rows["local_add_cond"]).astype(np.float32), + } + out_path.parent.mkdir(parents=True, exist_ok=True) + np.savez(out_path, **out) + + sched = sorted({round(float(t), 6) for t in out["t"]}, reverse=True) + print(f"\n wrote {out_path}: {n} samples, L={L}") + print(f" sigma schedule ({len(sched)}): {[round(s, 4) for s in sched]}") + lo, hi = float(out['t5_hidden'].min()), float(out['t5_hidden'].max()) + print(f" ranges: t5_hidden [{lo:.2f}, {hi:.2f}] " + f"x [{float(out['x'].min()):.2f}, {float(out['x'].max()):.2f}] " + f"local_add_cond absmax {float(np.abs(out['local_add_cond']).max()):.3g}") + return out_path + + +# --------------------------------------------------------------------------- +# main +# --------------------------------------------------------------------------- + + +def main() -> int: + ap = argparse.ArgumentParser( + description=__doc__.split("\n\n")[0], + formatter_class=argparse.RawDescriptionHelpFormatter) + ap.add_argument("--model-config", required=True, + help="Path to model_config.json") + ap.add_argument("--checkpoint", required=True, + help="Path to model.safetensors") + ap.add_argument("--out", default="sa3-m.calib.npz", + help="Output calibration npz") + ap.add_argument("--duration", type=float, default=DEFAULT_DURATION_S, + help="Generation duration in seconds (one value: keeps the " + "npz rectangular; default adapts to L=646)") + ap.add_argument("--steps", type=int, default=DEFAULT_STEPS, + help="Pingpong steps (= sigmas captured per generate)") + ap.add_argument("--seed", type=int, default=DEFAULT_SEED, + help="Base seed (a per-prompt offset is derived)") + ap.add_argument("--device", default="cuda") + args = ap.parse_args() + + prompts = _load_prompts() + n_expected = len(prompts) * args.steps + print(f"capturing {len(prompts)} prompts x {args.steps} sigmas = " + f"{n_expected} samples @ {args.duration}s") + + sa3 = _load_model(Path(args.model_config), Path(args.checkpoint), args.device) + rows = capture(sa3, prompts, args.duration, args.steps, args.seed) + _save(rows, args.duration, Path(args.out)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py b/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py index 7124edb..40a6e1e 100644 --- a/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py +++ b/optimized/tensorRT/scripts/diff_attn_nocast_plugin.py @@ -5,6 +5,8 @@ """ import torch import tensorrt.plugin as trtp +import numpy as np +import numpy.typing as npt from typing import Tuple _stream_cache = {} @@ -13,7 +15,8 @@ @trtp.register("samel::diff_attn_swa") def diff_attn_swa_desc(q_bat: trtp.TensorDesc, k_bat: trtp.TensorDesc, - v_bat: trtp.TensorDesc, num_heads: int) -> trtp.TensorDesc: + v_bat: trtp.TensorDesc, + num_heads: npt.NDArray[np.int64]) -> trtp.TensorDesc: out = q_bat.like() out.shape_expr[-2] = q_bat.shape_expr[-2] // 2 return out @@ -21,7 +24,8 @@ def diff_attn_swa_desc(q_bat: trtp.TensorDesc, k_bat: trtp.TensorDesc, @trtp.impl("samel::diff_attn_swa") def diff_attn_swa_impl(q_bat: trtp.Tensor, k_bat: trtp.Tensor, v_bat: trtp.Tensor, - num_heads: int, outputs: Tuple[trtp.Tensor], stream: int): + num_heads: npt.NDArray[np.int64], + outputs: Tuple[trtp.Tensor], stream: int): global _triton_fn if stream not in _stream_cache: _stream_cache[stream] = torch.cuda.ExternalStream(stream) @@ -37,5 +41,5 @@ def diff_attn_swa_impl(q_bat: trtp.Tensor, k_bat: trtp.Tensor, v_bat: trtp.Tenso # NO dtype cast — Triton auto-compiles for the input dtype o = _triton_fn(q, k, v, window=17) - H = num_heads + H = int(np.asarray(num_heads).reshape(-1)[0]) out_t.copy_(o[:, :, :H, :] - o[:, :, H:, :]) diff --git a/optimized/tensorRT/scripts/sa3_trt.py b/optimized/tensorRT/scripts/sa3_trt.py index 199bc06..5388676 100644 --- a/optimized/tensorRT/scripts/sa3_trt.py +++ b/optimized/tensorRT/scripts/sa3_trt.py @@ -404,9 +404,11 @@ def __init__(self, dit: str, decoder: str, *, Args: dit: one of DIT_CHOICES — "sm-music" / "sm-sfx" / "medium" decoder: one of DECODER_PATHS — "same-s" / "same-l" - precision: "fp16mixed" (default, fastest) or "fp32" (bit-equiv - PyTorch eager, ~2× slower). Engines auto-download - from HF if the requested precision file is missing. + precision: "fp16mixed" (default, fastest), "fp32" (bit-equiv + PyTorch eager, ~2× slower), or "fp8" (medium DiT + only, ModelOpt PTQ, ~1.8× faster steps; pairs with + the fp16mixed decoder). Engines auto-download from + HF if the requested precision file is missing. default_T_lat: latent length to build the initial graph at default_steps: pingpong steps for the initial graph default_seconds: duration condition for the initial graph (used for @@ -634,8 +636,9 @@ def main(): ap.add_argument("--dit", choices=list(DIT_CHOICES.keys()), default=None) ap.add_argument("--decoder", choices=list(DECODER_PATHS.keys()), default=None) ap.add_argument("--precision", choices=list(canon.PRECISIONS), default="fp16mixed", - help="Engine precision: 'fp16mixed' (default, fast) or 'fp32' " - "(bit-equiv PyTorch eager, slower). Auto-downloads from HF.") + help="Engine precision: 'fp16mixed' (default, fast), 'fp32' " + "(bit-equiv PyTorch eager, slower), or 'fp8' (DiT-only " + "ModelOpt PTQ, ~1.8x faster steps). Auto-downloads from HF.") ap.add_argument("--models-dir", default=str(canon.MODELS_DIR)) ap.add_argument("--seconds", type=float, default=30.0) ap.add_argument("--steps", type=int, default=8) diff --git a/optimized/tensorRT/scripts/sa3_trt_core.py b/optimized/tensorRT/scripts/sa3_trt_core.py index db92cd2..a2092ad 100644 --- a/optimized/tensorRT/scripts/sa3_trt_core.py +++ b/optimized/tensorRT/scripts/sa3_trt_core.py @@ -146,27 +146,43 @@ def _detect_gpu_arch() -> str: # # The canonical engines are FP16-mixed (FP16 trunk + FP32 islands around # RMSNorm / Softmax / RoPE). Pure-FP32 variants are also published — same -# numerical behavior as PyTorch eager FP32. Use `--precision fp32` on the -# CLI to pick them; default is `fp16mixed`. +# numerical behavior as PyTorch eager FP32. `fp8` is a DiT-only ModelOpt-PTQ +# variant (~1.8x faster steps than fp16mixed; built by build/build_dit_fp8.py); +# under the stochastic pingpong sampler it gives a different but comparable +# sample. There is no fp8 decoder, so `--precision fp8` pairs the fp8 DiT with +# the fp16mixed decoder. Default is `fp16mixed`. # # The lookup tables below resolve the engine filename per (dit/decoder, # precision). Encoders are FP16-mixed only. DIT_ENGINE_FILENAME = { "fp16mixed": "dit_fp16mixed.trt", "fp32": "dit_fp32.trt", + "fp8": "dit_fp8.trt", } _DIT_SUBDIR = {"sm-music": "sa3-sm-music", "sm-sfx": "sa3-sm-sfx", "medium": "sa3-m"} DECODER_ENGINE_FILENAME = { "same-l": { "fp16mixed": "dec_dynamic_triton_swa.trt", "fp32": "dec_dynamic_fp32.trt", + "fp8": "dec_dynamic_triton_swa.trt", # no fp8 decoder: use fp16mixed }, "same-s": { "fp16mixed": "dec_dynamic_bf16.trt", "fp32": "dec_dynamic_fp32.trt", + "fp8": "dec_dynamic_bf16.trt", # no fp8 decoder: use fp16mixed }, } -PRECISIONS = ("fp16mixed", "fp32") +PRECISIONS = ("fp16mixed", "fp32", "fp8") +# fp8 is only built/published for the medium DiT (sa3-m); the small models +# have no dit_fp8.trt, so guard before a request 404s on a nonexistent file. +_FP8_DITS = ("medium",) + + +def _check_precision(dit_name: str, precision: str) -> None: + if precision == "fp8" and dit_name not in _FP8_DITS: + raise ValueError( + f"precision 'fp8' is only available for dit in {list(_FP8_DITS)} " + f"(sa3-m); got dit={dit_name!r}. Use 'fp16mixed' or 'fp32'.") def get_dit_engine_path(dit_name: str, precision: str = "fp16mixed") -> Path: @@ -174,6 +190,7 @@ def get_dit_engine_path(dit_name: str, precision: str = "fp16mixed") -> Path: raise ValueError(f"unknown dit={dit_name!r}; valid: {list(_DIT_SUBDIR)}") if precision not in DIT_ENGINE_FILENAME: raise ValueError(f"unknown precision={precision!r}; valid: {PRECISIONS}") + _check_precision(dit_name, precision) return ARCH_DIR / _DIT_SUBDIR[dit_name] / DIT_ENGINE_FILENAME[precision] @@ -189,6 +206,7 @@ def get_engine_files(dit_name: str, decoder_name: str, precision: str = "fp16mix with_encoder: bool = False) -> list[str]: """Relative paths (under ARCH_DIR) needed for the chosen pipeline. Pass this list to _ensure_files() to auto-download anything missing from HF.""" + _check_precision(dit_name, precision) files = list(SHARED_FILES) files.append(f"{_DIT_SUBDIR[dit_name]}/{DIT_ENGINE_FILENAME[precision]}") files.append(f"{decoder_name}/{DECODER_ENGINE_FILENAME[decoder_name][precision]}")