Skip to content

feat: Feature Injection Phase 2 across TRT pipeline + FP8 ControlNet quantization#21

Open
forkni wants to merge 12 commits into
SDTD_031_devfrom
feat/feature-injection-fp8-cn
Open

feat: Feature Injection Phase 2 across TRT pipeline + FP8 ControlNet quantization#21
forkni wants to merge 12 commits into
SDTD_031_devfrom
feat/feature-injection-fp8-cn

Conversation

@forkni

@forkni forkni commented Jun 13, 2026

Copy link
Copy Markdown
Collaborator

Changes

  • 33e97b3 fix: export conditioning_scale as rank-1 to unblock fp8 CN quantization
  • 311f2d7 fix: CN fp8 engine inherits global Trtprofile=performance flag
  • 7e8700e feat: fp8 TRT engine for ControlNet (calibration, engine path, config gate)
  • a72b27e feat: align fi_strength default to thesis alpha 0.75 (squash cdb5135..a12798e)
  • 6ac025b style: lint reformat of FI model/utils code
  • 38b72df fix: correct SDXL log label for wrapped exports; log FI default when config omits fi keys
  • 0a135d7 fix: collect FI processors after install and keep FI scalars as traced tensors
  • 38ebe01 fix: thread FI cache args and outputs through ControlNet export routing in UnifiedExportWrapper
  • b57b5d7 fix: accept fp8_use_feature_injection in EngineBuilder.build and exclude FI subgraph from FP8 Q/DQ
  • 6f5905f feat: default use_feature_injection=True; document FI keys in example config
  • 04170e0 feat: wire fi_strength/fi_threshold/max_cache_maxframes through config + wrapper; fix live-update chain
  • 25414c0 feat: implement Phase 2 Feature Injection (FI) across TRT pipeline

Branch

feat/feature-injection-fp8-cn -> SDTD_031_dev

forkni added 12 commits June 12, 2026 22:42
StreamV2V thesis §3.4.2 Feature Fusion: nearest-neighbour O-cache blending.
All 10 implementation steps on forkni/v2v-feature-injection.

Steps 1-3 (utils/models/attention_processors):
- get_fi_eligible_mask: mid+up01 eligible, down always False (Appendix B.2)
- create_fi_cache: O-cache bucket allocator (no K/V dim, shape maxframes×B×S×H)
- UNet model: FI bindings, local-sequential names (fio_cache_in_i), fi_strength/fi_threshold [1] scalars
- CachedSTAttnProcessor2_0: get_nn_feats + FI blend (Eq 3.2, α=0.8 default)
  attribute-based I/O (proc._fi_cache / _fi_cache_out), return 2-tuple unchanged

Steps 4-6 (export wrapper, engine, pipeline):
- unet_unified_export: _collect_fi_processors walk, _set_fi_cache, FI args split,
  fi_cache_outs appended to return for ONNX graph tracing
- unet_engine: fio_cache/fi_strength/fi_threshold bindings; returns 3-tuple
  (noise_pred, kvo_cache_out, fio_cache_out) — lazy name lists same as KVO
- pipeline: fio_cache + persistent [1] fp32 tensors; defensive 3-tuple unpack;
  update_kvo_cache extended with fio circular-buffer write (squeeze dim 0)

Steps 7-10 (engine id, wrapper, config, OSC):
- engine_manager: --fi-{bool} suffix prevents stale engine load on FI toggle
- wrapper: use_feature_injection in both signatures; processor walk-order install
  with fi_eligible per mask; create_fi_cache after create_kvo_cache; fi_layer_count
  to UnifiedExportWrapper; fp8_use_feature_injection build opt
- config: use_feature_injection param_map entry
- stream_parameter_updater: fi_strength/fi_threshold in-place tensor update
  (CUDA-graph-safe, no realloc)
…g + wrapper; fix live-update chain

- config.py: add fi_strength (0.8), fi_threshold (0.98), max_cache_maxframes (4) keys
- wrapper.__init__ + _load_model: new fi_strength/fi_threshold params threaded through
- _load_model: replace hardcoded 0.8/0.98 tensors with float(fi_strength)/float(fi_threshold)
- wrapper.update_stream_params: add fi_strength/fi_threshold kwargs, forward to _param_updater
  (fixes live-update chain broken at wrapper level; updater apply side already existed at
  stream_parameter_updater.py:395-401)
… config

Ready-to-test defaults: FI enabled when use_cached_attn=true, strength=0.8, threshold=0.98.
Config default change means TD-generated yamls (no Fienable par yet) will activate FI
automatically whenever cached attention is on.
…ude FI subgraph from FP8 Q/DQ

Phase 2 (313acbe) wired fp8_use_feature_injection at the wrapper call site
(wrapper.py:2094) but never added a consumer, causing TypeError on the first
fp8+FI UNet build. Three-point fix mirroring the fp8_use_cached_attn pattern:

- __init__.py compile_unet: pop fp8_use_feature_injection from build_options
  before **build_options reaches builder.build(); forward to builder.build()
- builder.py EngineBuilder.build: add fp8_use_feature_injection param;
  thread to quantize_onnx_fp8 as use_feature_injection=
- fp8_quantize.py: add use_feature_injection param to quantize_onnx_fp8;
  add _FEATURE_EXCLUDE_PATTERNS["feature_injection"] covering fio_cache_*,
  fi_strength, fi_threshold; extend nodes_to_exclude when flag is set

Also closes the plan Med-risk "FP8 quantization of FI MatMul" item — the FI
cosine-MatMul/Gather subgraph stays FP16, symmetric to the kvo_cache exclusion.
Calibration capture needs no change: _read_onnx_input_specs synthesizes
fio_cache/fi_strength inputs generically, and the gate (use_cached_attn or ...)
is already open since FI requires cached_attn.
… gate)

- capture_calibration_data_controlnet(): synthetic SDXL passes on CN model,
  captures 6 real activations + synthesizes conditioning_scale (1D, const 1.0)
- quantize_onnx_fp8(): fix IndexError on scalar ONNX inputs (empty dims list)
- builder.build(): artifact_prefix param (controlnet.fp8.onnx vs unet.fp8.onnx);
  is_controlnet branch routes to CN calibration capture
- compile_controlnet(): extract fp8/calibration_steps opts, pass pipe_ref=controlnet
- engine_manager get_engine_path(): --fp8 suffix on CONTROLNET path prevents fp16 cache collision
- get_or_load_controlnet_engine() + _get_default_controlnet_build_options(): fp8 param plumbed through
- ControlNetConfig.fp8 field; config.py and wrapper.py pass it to engine manager
Previously the CN fp8 gate only read the per-CN config key. Now
fp8 = (global wrapper fp8) OR (per-CN override), so setting
Trtprofile=performance in the TD UI automatically builds all CN
engines as fp8 without per-controlnet yaml overrides.
modelopt's CalibrationDataProvider sorts input names alphabetically, making
conditioning_scale index-0. With an empty-shape (rank-0) ONNX declaration,
input_shapes[name][0] raises IndexError before calibration can start.

Fix: declare conditioning_scale as (1,) in get_shape_dict / get_sample_input /
get_input_profile. Broadcasting in the traced ONNX Mul is unaffected. The calib
npz already has shape (N,), which now splits into N × (1,) slices as expected.

ControlNetModelEngine probes the loaded engine's binding rank once in __init__
(_cs_rank1) so legacy fp16 engines (rank-0 binding) continue working without
requiring a cache flush.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant