diff --git a/contrib/models/Qwen-Image-Edit/.gitignore b/contrib/models/Qwen-Image-Edit/.gitignore new file mode 100644 index 00000000..1135e7d7 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/.gitignore @@ -0,0 +1,5 @@ +output/ +output_edited.png +log-neuron-cc.txt +scratch/ +global_metric_store.json diff --git a/contrib/models/Qwen-Image-Edit/OPTIMIZATION.md b/contrib/models/Qwen-Image-Edit/OPTIMIZATION.md new file mode 100644 index 00000000..5839884c --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/OPTIMIZATION.md @@ -0,0 +1,103 @@ +# Qwen-Image-Edit on Trainium2 — optimized (TP4×CP4 + WLO) + +NeuronX adaptation of `Qwen/Qwen-Image-Edit-2509` for AWS Trainium2 inference, with the +latency optimizations applied. **Production config: TP=4 × CP=4 (world=16) + WLO**, which +runs the 896×1184 / 8-step / CFG=1 virtual try-on at **~4.5 s end-to-end** — faster than the +H100 vLLM-Omni reference (4.99 s), lossless. + +> **Scope / applicability.** This optimization round targets a **few-step +> distillation–finetuned checkpoint run with classifier-free guidance disabled +> (CFG=1)** — i.e. ~8 denoising steps and a *single* (positive-only) transformer forward +> per step, no negative-prompt pass. Two consequences shape every number below: +> - With CFG=1 there is no negative branch to batch, so the **CFG-parallel (DP=2) path +> buys nothing** — the right lever is **Context Parallel (CP)**, which shards the single +> forward across more cores. (For CFG>1, V3 CFG's DP=2 batching of neg+pos is the better +> layout; CP scaling here is specifically for the CFG=1 few-step regime.) +> - At only ~8 steps the fixed per-run cost (text encoder + VAE encode/decode, ~1.1 s) is a +> large fraction of E2E, so VAE/text-encoder latency matters far more than it would for a +> 50-step run — which is why the VAE batched-tile win below is material here. +> +> The latency targets (matching/beating H100's 4.99 s) and the CP=4 sweet-spot choice are +> all stated **for this few-step / CFG=1 workload**; a many-step or CFG>1 run has a +> different optimum. + +## What's here + +``` +src/ full NeuronX model/compile/run source (the optimizations live here) +release_v3cp4_wlo/ the production config: compile.sh / run_tryon.sh / test / README + sample outputs +requirements.txt +``` + +Input try-on images (cloth/, input_img/) are NOT included — supply your own and pass via the +run script flags (see release_v3cp4_wlo/run_tryon.sh). + +## Quick start (production config) + +```bash +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +cd /contrib/models/Qwen-Image-Edit + +# 1) compile transformer + vision + LM at world=16 (+WLO); VAE symlinked. ~35 min. +bash release_v3cp4_wlo/compile.sh + +# 2) run try-on (~4.5 s). Edit the cloth/model paths in the script or pass via QIE_CLOTH / QIE_MODEL_IMG. +bash release_v3cp4_wlo/run_tryon.sh out.png +``` + +See `release_v3cp4_wlo/README.md` for the full recipe, the CP scaling rationale (why CP=4), +and the quality test. + +## The optimizations (vs the original ~7.9 s baseline) + +| optimization | effect | how | +|---|---|---| +| **VAE batched-tile** | VAE encode 559 → 298 ms (−47%), decode 422 → 192 ms (−55%), **−6.1% E2E**; numerically identical | compile the tiled VAE encoder/decoder at `batch=N` (N = tiles/image — 6 for 1024², 12 for two-image edit) and run all tiles in **one** NEFF launch instead of N sequential launches. `_tiled_encode`/`_tiled_decode` enumerate tiles → run in chunks of the compiled batch → scatter+crop+blend back into the grid. Collapses ~37 ms/tile launch overhead; matters at few-step because VAE is a fixed cost. Set via `VAE_BATCH_SIZE` in `compile.sh`. **Both `_tiled_encode` and `_tiled_decode` must be updated** (half-applied regresses the decoder). | +| **WLO** (weight layout opt) | −3.2% step, **bit-exact** | pass `priority_model_key="inference"` to `ModelBuilder.compile()` — was never enabled. `QIE_WLO=1` (default). | +| **CP scaling** (CP=2 → CP=4) | **7.51 → 4.50 s** (−40%), step 793 → 411 ms, lossless | QIE transformer is compute-bound but V3 only used 8 of the chip's 32 logical cores; doubling cores (world=8 → 16) halves the step. Only a lever because CFG=1 runs a single forward (no DP=2 batching to exploit). Compile with `--tp_degree 4 --world_size 16`; run with `QIE_WORLD_SIZE=16 NEURON_RT_NUM_CORES=16`. | + +Output is visually equivalent across CP degrees (CP=4 vs CP=2 mean |Δ| 0.78%); not bit-exact +because the CP partition changes bf16 accumulation order. VAE batched-tile and WLO are +numerically clean (identical / bit-exact); CP scaling is lossless in the visual sense. + +### Per-optimization latency reduction + +Each delta is measured against the *then-current* baseline (they are applied sequentially), so +this is a cumulative path, **not** a simple sum. Workload: 896×1184 / 8-step / **CFG=1** +two-image virtual try-on on `trn2.48xlarge`. + +| step | optimization | E2E | what moved | +|---|---|---|---| +| 0 | baseline (V3 CP=2, fp32 reduce, per-tile VAE) | ~7.89 s | — | +| 1 | + bf16 TP all-reduce | ~7.5 s | TP all-reduce bytes halved (~−9% step) | +| 2 | + VAE batched-tile | 7.41 s | VAE enc 559→298 ms, dec 422→192 ms (−6.1% E2E) | +| 3 | + WLO | ~7.3 s | weight layout for inference (−3.2% step, bit-exact) | +| 4 | + CP scaling (CP=2 → CP=4) | **4.50 s** | transformer step 793 → 411 ms (2× cores) | + +> bf16 TP all-reduce and VAE batched-tile predate this round (part of the base contribution); +> the **new** work here is **WLO** and **CP scaling**, which take the verified-correct config +> from 7.51 s to **4.50 s**. After CP scaling the transformer is ~2.3 s of the 4.5 s, so the +> fixed text-encoder + VAE (~1.1 s) is now the dominant remaining cost — which is exactly why +> the VAE batched-tile win is load-bearing in this few-step regime. + +### Why TP=4 × CP=4 (not CP=8 or TP=8) +- **CP=8** (32 cores) reaches 4.10 s but the gain over CP=4 is only ~0.4 s — transformer + marginal return drops to 0.70× (from 0.52×) and the larger world=32 adds ~200 ms of DP + overhead to the TP=4 vision/LM. CP=4 is the sweet spot. +- **TP=8** triggers a flash-kernel seqlen-shard fallback and the 8-rank TP group has no + replica-group mapping in NeuronX for world<64. TP=4 (6 heads/rank) maps cleanly to the torus. +- CP must be a power of 2 (world ∈ {8,16,32}); the runtime rejects other core counts (e.g. + world=12 → "Unsupported topology"). + +## Key env vars (compile / run) + +- `QIE_WLO=1` (default) — weight layout optimization (bit-exact speedup). +- `QIE_WORLD_SIZE=16`, `NEURON_RT_NUM_CORES=16` — required at run time for CP=4. +- `QIE_ALLREDUCE_BF16=1` (default), `QIE_OPT_LEVEL=2`, `QIE_CC_TILING=4` — existing tuned defaults. +- Use the SHORT prompt `让图2的模特换上图1的下装` for try-on (the long prompt mis-edits on this ckpt). + +## Environment + +- Instance `trn2.48xlarge` (64 NeuronCores = 32 logical at LNC2) +- Venv `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (PyTorch 2.9, neuronx-cc 2.22, neuronx-distributed 0.16) +- `PYTHONPATH=src:$PYTHONPATH` for both compile and run diff --git a/contrib/models/Qwen-Image-Edit/README.md b/contrib/models/Qwen-Image-Edit/README.md new file mode 100644 index 00000000..2d9d8460 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/README.md @@ -0,0 +1,219 @@ +# Contrib Model: Qwen-Image-Edit + +NeuronX adaptation of [alibaba-pai/Qwen-Image-Edit-2509](https://huggingface.co/alibaba-pai/Qwen-Image-Edit-2509) for AWS Trainium2 inference. + +> **Latency-optimized config (recommended for few-step / CFG=1):** TP=4 × CP=4 (world=16) + +> WLO runs the 896×1184 / 8-step / **CFG=1** virtual try-on (few-step distillation–finetuned +> checkpoint) at **~4.5 s end-to-end on trn2.48xlarge**, beating the H100 vLLM-Omni reference +> (4.99 s), lossless. This regime is what the optimizations target — with CFG=1 a single +> forward per step makes Context Parallel the right lever; for CFG>1 use the V3 CFG (DP=2) path +> instead. See [`OPTIMIZATION.md`](OPTIMIZATION.md) for the full writeup (VAE batched-tile, +> bf16 TP all-reduce, WLO, CP scaling — each with measured latency deltas) and +> [`release_v3cp4_wlo/`](release_v3cp4_wlo/) for the ready-to-run production config. + +## Model Information + +- **HuggingFace ID:** `alibaba-pai/Qwen-Image-Edit-2509` +- **Model Type:** Diffusion model for image editing +- **Architecture:** Multi-component (Qwen2.5-VL Vision Encoder + Language Model + QwenImageTransformer2DModel + 3D VAE) +- **License:** Check HuggingFace model card + +## Architecture Details + +| Component | Model | Parameters | Neuron Parallelism | +|-----------|-------|------------|-------------------| +| Vision Encoder | Qwen2.5-VL ViT (32 blocks) | ~1.4B | TP=4, float32 (or CPU) | +| Language Model | Qwen2.5-VL LM (28 layers) | ~7B | TP=4, world_size=8 (or CPU) | +| Transformer | QwenImageTransformer2DModel | ~20.4B | TP=4-8, various parallelism modes | +| VAE | 3D AutoencoderKL (causal) | ~300M | Single device, tiled processing | + +Key parameters: +- **Transformer**: 48 attention heads, head_dim=128, inner_dim=6144 +- **Text Hidden Size**: 3584 (Qwen2.5-VL) +- **Dual-stream blocks**: 20 (separate text/image norms+FFN, joint attention) +- **Single-stream blocks**: 40 (concatenated text+image, parallel MLP+attention) + +## Performance + +6 compilation APIs with different parallelism strategies: + +| Version | Parallelism | Attention | Per Step | Total (50 steps) | Notes | +|---------|------------|-----------|----------|-----------------|-------| +| **V3 CFG** | TP=4, DP=2 | NKI Flash | **~0.75s** | **~53s** | Fastest, recommended | +| V3 CP | TP=4, CP=2 | NKI Flash | ~0.77s | ~55s | Context Parallel | +| V1 Flash | TP=8 | NKI Flash | ~1.2s | ~76s | NKI kernel | +| V2 Flash | TP=8 | NKI Flash | ~1.2s | ~76s | ModelBuilder + NKI | +| V2 | TP=8 | Standard SDPA | ~1.2s | ~76s | ModelBuilder | +| V1 | TP=8 | Standard SDPA | ~2.4s | ~136s | Baseline | + +Test: 1024x1024 output, guidance_scale=4.0, trn2.48xlarge, single-image +editing (`patch_multiplier=2`). Total time includes VAE encoding/decoding +and text encoding overhead. + +> Note: a two-image merge (`patch_multiplier=3`) processes a longer joint +> sequence (`S = 1024 + 12288 = 13312`) and lands at roughly 1.6× the +> per-step time of single-image editing on the same configuration. + +## Prerequisites + +- **Instance**: trn2.48xlarge (64 NeuronCores, 1.5TB device memory) +- **Virtual env**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` + - PyTorch 2.9, neuronx-cc 2.22, neuronx-distributed 0.16 +- **NVMe**: Mount RAID at `/opt/dlami/nvme/` (run `src/setup_nvme.sh`) + +## Usage + +### 1. Setup + +```bash +# Mount NVMe RAID +sudo bash src/setup_nvme.sh + +# Activate virtual environment +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Install dependencies +pip install -r requirements.txt +``` + +### 2. Download Model + +```bash +python src/cache_hf_model.py +``` + +### 3. Compile All Components + +```bash +# Compile V3 CFG (recommended, fastest) +bash src/compile.sh v3_cfg + +# Compile V3 CP (Context Parallel) +bash src/compile.sh v3_cp + +# Compile all versions +bash src/compile.sh + +# Custom dimensions: +# bash src/compile.sh +``` + +Compilation takes ~60-120 minutes total depending on version. + +### 4. Run Inference + +`compile.sh` defaults to `patch_multiplier=3` (two-image merge), so the +example below uses two input images. For single-image editing, recompile +with `patch_multiplier=2` first. + +```bash +# Two-image merge (matches compile.sh default of patch_multiplier=3) +NEURON_RT_NUM_CORES=8 PYTHONPATH=src:$PYTHONPATH python src/run_qwen_image_edit.py \ + --compiled_models_dir /opt/dlami/nvme/compiled_models_qwen_image_edit \ + --images assets/image1.png assets/image2.png \ + --prompt "merge subjects from image1 and image2 into a single scene" \ + --patch_multiplier 3 \ + --use_v3_cfg \ + --output output.png + +# Single-image editing (requires recompilation with patch_multiplier=2) +# bash src/compile.sh v3_cfg 1024 1024 448 8 1024 2 1 +# NEURON_RT_NUM_CORES=8 PYTHONPATH=src:$PYTHONPATH python src/run_qwen_image_edit.py \ +# --compiled_models_dir /opt/dlami/nvme/compiled_models_qwen_image_edit \ +# --images assets/image1.png \ +# --prompt "change the sky to sunset" \ +# --patch_multiplier 2 \ +# --use_v3_cfg \ +# --output output.png +``` + +### Runtime toggles + +The transformer dispatch in `compile_transformer_v3_cfg.py` reads four +optional environment variables: + +| Variable | Default | Effect | +|----------|---------|--------| +| `QIE_HOISTED_Q_ATTENTION` | `1` | Use the Phase 16 hoisted-Q `attention_cte` fork. Set to `0` to fall back to upstream `nkilib.core.attention.attention_cte`. | +| `QIE_ALLREDUCE_BF16` | `1` | Phase 17: use `bfloat16` instead of `float32` as the reduce dtype for every TP `RowParallelLinear` all-reduce. Halves the bytes on the wire and saves ~137 ms / step (~9% E2E) with no visible image quality regression. Set to `0` to keep the upstream `fp32` reduce. | +| `QIE_USE_NKILIB_ATTENTION` | `1` | When the hoisted-Q fork is disabled, choose between `nkilib` `attention_cte` (`1`) and the legacy `attention_isa_kernel` (`0`). | +| `QIE_SOFTMAX_DTYPE` | `float32` | Softmax accumulation dtype inside `attention_cte`. `bfloat16` is supported but measured no speedup on Trn2 because `mm_out_dtype` must stay `float32` on Gen3. | + +## Compatibility Matrix + +| Instance/Version | 2.22+ (PyTorch 2.9) | 2.21 and earlier | +|------------------|---------------------|------------------| +| Trn2 (trn2.48xlarge) | Tested | Not tested | +| Trn1 | Not tested | Not tested | +| Inf2 | Not supported | Not supported | + +## Testing + +```bash +# Run component tests +PYTHONPATH=src:$PYTHONPATH pytest test/integration/ --capture=tee-sys -v + +# Run all tests manually +PYTHONPATH=src:$PYTHONPATH python test/integration/run_all_tests.py +``` + +## Key Implementation Notes + +1. **Modulation Layer Sharding**: Uses `ColumnParallelLinear(gather_output=True)` to reduce memory from ~17GB to ~5.2GB per shard. +2. **RoPE Without Complex Numbers**: Neuron doesn't support C64; uses (cos, sin) tuples instead. +3. **M-RoPE Position IDs**: 3D position indices (temporal, height, width) for multimodal tokens. +4. **VAE Interpolation**: Replaces `nearest-exact` with `nearest` for Neuron compatibility. +5. **CFG Parallel**: Batches negative + positive prompts into single forward pass for ~6% speedup over CP. +6. **NKI Flash Attention**: Custom NKI kernel for Trainium2, requires `XLA_DISABLE_FUNCTIONALIZATION=1`. +7. **Hoisted-Q `attention_cte` (Phase 16)**: forked kernel (`attention_cte_qie_hoisted_q.py`) that hoists the Q-tile load out of the K/V section loop. For QIE shapes (`num_sections = 2` and Q identical across sections), this removes the redundant Q reload that the upstream `attention_cte` performs in `section_idx=1`. Bit-exact to the baseline (same `hardware_flops`, same output PNG MD5); −14.7 ms / step, +0.52 pp MFU on the V3 CFG configuration. Toggle via `QIE_HOISTED_Q_ATTENTION` (default `1`). +8. **bf16 TP all-reduce (Phase 17)**: drops the all-reduce dtype on every `RowParallelLinear` (attention output, attention text-output, MLP output) from `float32` to `bfloat16`. The V3 CFG configuration emits 956 TP all-reduces per step totalling ~18 GB; halving the bytes on the wire saves **~137 ms / step (~9% E2E)** — by far the largest single win in this contrib's optimization history. Output images are visually equivalent to the `fp32`-reduce baseline on the 1024×1024 two-image merge workload; not bit-exact (small numerical drift accumulates over 60 blocks × 40 steps, but does not change scene content / faces / composition). Toggle via `QIE_ALLREDUCE_BF16` (default `1`). + +## File Structure + +``` +Qwen-Image-Edit/ + README.md + requirements.txt + assets/ + image1.png, image2.png # Test input images + src/ + run_qwen_image_edit.py # Main inference script + neuron_commons.py # NeuronTextEncoderWrapper, SDPA implementations + neuron_parallel_utils.py # TP sharding utilities + neuron_rope.py # Neuron-compatible RoPE + autoencoder_kl_qwenimage_neuron.py # Neuron-compatible 3D VAE + compile_transformer.py # V1 transformer (TP=8) + compile_transformer_v1_flash.py # V1 Flash (NKI) + compile_transformer_v2.py # V2 (ModelBuilder) + compile_transformer_v2_flash.py # V2 Flash (ModelBuilder + NKI) + compile_transformer_v3_cp.py # V3 Context Parallel (TP=4, CP=2) + compile_transformer_v3_cfg.py # V3 CFG Parallel (TP=4, DP=2) + attention_cte_qie_hoisted_q.py # Phase 16: hoisted-Q attention_cte fork + compile_language_model_v3.py # Language Model V3 (TP=4) + compile_vision_encoder_v3.py # Vision Encoder V3 (TP=4) + compile_text_encoder.py # Vision encoder single-device + compile_vae.py # 3D VAE encoder/decoder + cache_hf_model.py # Download model + compile.sh # Master compilation script + setup_nvme.sh # NVMe RAID setup + test/ + integration/ + run_all_tests.py # Master test runner + test_vae.py # VAE tests + test_transformer.py # Transformer tests + test_text_encoder.py # Text encoder tests + test_component_comparison.py # Neuron vs CPU comparison + test_language_model_simple.py # Language model tests + test_multimodal.py # Multi-image tests + unit/ +``` + +## Example Checkpoints + +* [alibaba-pai/Qwen-Image-Edit-2509](https://huggingface.co/alibaba-pai/Qwen-Image-Edit-2509) + +## Maintainer + +Henan Wan (whn09) + +**Last Updated:** 2026-05-14 diff --git a/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/README.md b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/README.md new file mode 100644 index 00000000..52754497 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/README.md @@ -0,0 +1,90 @@ +# V3 CP=4 + WLO — FINAL / RECOMMENDED QIE try-on config on Trainium2 (beats H100) + +**TP=4 × CP=4 (world=16) + WLO — the chosen production configuration.** Doubles the cores +used vs the CP=2 baseline (16 of the chip's 32 logical cores). Because the QIE transformer is +**compute-bound**, doubling cores **halves the transformer step** — and comms (KV all-gather +over 4 vs 2) does NOT eat the gain. + +> Why CP=4 and not CP=8: CP=8 (world=32, all 32 cores) runs at 4.10s but the gain over CP=4's +> 4.50s is small (only ~0.4s) — the transformer step's marginal return drops (16→32 cores = +> 0.70×, not 0.5×) and the larger world=32 adds ~200ms of DP overhead to the TP=4 vision/LM +> (which don't benefit from more cores). CP=4 (world=16) is the sweet spot: near-linear +> transformer speedup without the diminishing returns / vision-LM penalty of world=32. + +## Measured results (2026-06-01, verified outputs) + +| config | cores | warm step | E2E | output | +|---|---|---|---|---| +| TP4×CP2 + WLO (prev best) | 8 | 793 ms | 7.51 s | correct | +| **TP4×CP4 + WLO** | **16** | **411 ms** | **4.50 s** | correct | +| Δ | 2× | **−48%** | **−3.01 s (−40%)** | visually equiv | + +**4.50 s beats H100 vLLM-Omni's 4.99 s** — first time QIE on Trn2 matches/beats H100, lossless +(more cores + bit-exact WLO). Output is the correct dark/green tiered maxi skirt (matching the +black-skirt cloth input), visually equivalent to both CP=2 and the early known-good 05-23 image. + +Stage breakdown (CP=4): text-enc ~450 ms, VAE-enc ~310 ms, transformer 8× ~411 ms (step1 +~750 ms warmup), VAE-dec ~200 ms. + +## Quality + +CP=4 output vs CP=2 baseline: **mean |Δ| = 0.78%** (px>5: 4.7%); vs the early known-good +05-23 image: 0.49%. NOT bit-exact (CP degree changes the sequence partition → different bf16 +accumulation order) but visually equivalent — same garment/pose/face. `test_cp4_quality.py` +asserts mean |Δ| < 2/255. See `cp4_vs_cp2.png`. + +## Prompt matters (lesson learned the hard way) + +Use the **short** prompt `让图2的模特换上图1的下装`. During packaging I briefly used a long +prompt ("把右图模特腰部以下的下装换成…全部保持…") and got **khaki shorts** (the input cloth is a +BLACK maxi skirt) — and almost shipped it. Always diff the output against the input cloth, not +just "did it run". The short prompt is what was validated on this step4000 checkpoint. + +(Note: an earlier blurry result I first blamed on zero-padding of CP-alignment patches was +actually the long-prompt issue. Zero-padding is fine — verified correct at both CP=2 and CP=4. +The padding code is the original zero-pad.) + +## Why TP=4×CP=4 and not TP=8×CP=2 (same 16 cores)? + +Tested both. TP=8 → 3 heads/rank → flash kernel falls back to seqlen-sharding (needs seqlen_q +div by pow2 ≥512; fixed with padding, compiles). But **to_neuron crashes**: +`failed to init a collective algorithm for provided replica group` / `Failed to find device +to device paths`. + +**Corrected root cause (NOT a hardware/topology limit):** trn2.48xlarge is a 2D torus (16 +cards / 64 cores) and CAN physically do TP=8×CP=2. The real issue is `neuronx_distributed`'s +replica-group mapping (`parallel_state.get_logic_chosen`) does not cover LNC2 + TP=8 + +world<64: it lays the TP-group as (0..7)(8..15), but on the torus device 0 and 8 aren't +directly connected. The kernel docstring explicitly lists this as *Not Supported* (VNC2 8×8: +both LOGIC1 and LOGIC2 fail; needs a (0..7)(16..23)-style mapping NxD hasn't implemented), and +a `world_size<64 → force LOGIC1` fallback locks world=16 into the non-working LOGIC1. It could +be worked around with `ModelBuilder(init_custom_process_group_fn=…)` injecting a torus-friendly +group, but that's nontrivial — AND even if unblocked it likely won't beat TP=4×CP=4 (same 16 +cores/compute, but TP=8's all-reduce spans 8 ranks — pure-TP8 measured 9.74 s for that reason). + +TP=4 keeps 6 heads/rank (no kernel fallback) and a 4-rank group that maps cleanly to the torus +under the stock LOGIC1. **TP=4×CP=4 is the practical optimum at world=16** — and this is finally +WHY V3 uses TP=4 (not just the LM's GQA — TP=4's collective group also fits NxD/torus mapping). + +## Files +- `compile.sh` — compile transformer + vision + LM all at world=16 (+WLO). VAE symlinked. +- `run_tryon.sh [out.png]` — run try-on (short prompt, `QIE_WORLD_SIZE=16 NEURON_RT_NUM_CORES=16`). +- `test_cp4_quality.py ` — assert CP=4 ≈ CP=2 (mean |Δ| < 2/255). +- `tryon_cp4.png` / `tryon_cp2_baseline.png` / `cp4_vs_cp2.png` — verified reference outputs + diff. + +## Reproduce +```bash +bash release_v3cp4_wlo/compile.sh # ~35 min (full shard, world=16) +bash release_v3cp4_wlo/run_tryon.sh release_v3cp4_wlo/tryon_cp4.png # ~4.5 s +python release_v3cp4_wlo/test_cp4_quality.py \ + release_v3cp4_wlo/tryon_cp4.png release_v3cp4_wlo/tryon_cp2_baseline.png +``` + +## Environment +- `trn2.48xlarge` (64 NeuronCores = 32 logical at LNC2), venv + `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` +- Runtime: `NEURON_RT_NUM_CORES=16`, `QIE_WORLD_SIZE=16`, `PYTHONPATH=src:$PYTHONPATH` +- Compile defaults: `QIE_ALLREDUCE_BF16=1`, `QIE_OPT_LEVEL=2`, `QIE_CC_TILING=4`, `QIE_WLO=1` + +## Next: CP=8 (world=32, all 32 cores) +Projected ~205 ms/step → E2E ~3 s. TP=4 so no topology issue. Not yet validated end-to-end. diff --git a/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/compile.sh b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/compile.sh new file mode 100755 index 00000000..df3fb0b2 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/compile.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# V3 CP=4 + WLO compile (TP=4 × CP=4 = world 16, uses 16 of the chip's 32 logical cores). +# +# This is the fast configuration: doubling cores vs the world=8 baseline (CP=2) halves +# the transformer step (793ms → 411ms) because QIE is compute-bound. Plus WLO (bit-exact). +# E2E 4.72s — beats H100 vLLM-Omni's 4.99s. Output is correct (see tryon_cp4.png). +# +# Three components must ALL be compiled at world=16: +# - transformer: --tp_degree 4 --world_size 16 (cp_degree = 16/4 = 4) +# - vision: VISION_WORLD_SIZE=16 +# - language model: LM_WORLD_SIZE=16 (keep --max_sequence_length 1024 to match runtime) +# VAE is single-device (world-agnostic) — symlinked from the CP=2 build. +set -euo pipefail +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +cd "$(dirname "$0")/.." # repo root + +export PYTHONPATH=src:${PYTHONPATH:-} +export QIE_ALLREDUCE_BF16=1 QIE_OPT_LEVEL=2 QIE_CC_TILING=4 QIE_WLO=1 + +MODEL_PATH="${QIE_MODEL_PATH:-/home/ubuntu/checkpoints/Qwen-Image-Edit-2509-step4000}" +DST="${QIE_OUT_DIR:-/opt/dlami/nvme/compiled_models_qwen_image_edit_step4000_896x1184_cp4}" +VAE_SRC="${QIE_VAE_SRC:-/opt/dlami/nvme/compiled_models_qwen_image_edit_step4000_896x1184_vaebatch6}" + +echo "=== transformer TP=4 CP=4 world=16 (+WLO) ===" +python src/compile_transformer_v3_cp.py \ + --model_path "$MODEL_PATH" \ + --height 1184 --width 896 --max_sequence_length 512 \ + --patch_multiplier 3 --tp_degree 4 --world_size 16 --batch_size 1 \ + --compiled_models_dir "$DST" --compiler_workdir /opt/dlami/nvme/cw_cp4 2>&1 | tail -4 + +echo "=== vision encoder world=16 ===" +VISION_WORLD_SIZE=16 python src/compile_vision_encoder_v3.py \ + --model_path "$MODEL_PATH" \ + --compiled_models_dir "$DST" --compiler_workdir /opt/dlami/nvme/cw_vis16 2>&1 | tail -3 + +echo "=== language model world=16 (max_seq 1024) ===" +LM_WORLD_SIZE=16 python src/compile_language_model_v3.py \ + --model_path "$MODEL_PATH" --max_sequence_length 1024 \ + --compiled_models_dir "$DST" --compiler_workdir /opt/dlami/nvme/cw_lm16 2>&1 | tail -3 + +echo "=== symlink VAE (single device) ===" +for c in vae_encoder vae_decoder quant_conv post_quant_conv vae_config.json; do + [ -e "$VAE_SRC/$c" ] && ln -sfn "$VAE_SRC/$c" "$DST/$c" +done +echo "DONE — compiled to $DST" diff --git a/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/cp4_vs_cp2.png b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/cp4_vs_cp2.png new file mode 100644 index 00000000..5a3318df Binary files /dev/null and b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/cp4_vs_cp2.png differ diff --git a/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/run_tryon.sh b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/run_tryon.sh new file mode 100755 index 00000000..3d401a9b --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/run_tryon.sh @@ -0,0 +1,29 @@ +#!/bin/bash +# V3 CP=4 + WLO try-on inference (TP=4 × CP=4 = world 16, 16 cores). ~4.7s E2E. +# Cloth (image 1) onto model (image 2); lower-body garment replaced. +# +# KEY runtime env vs CP=2: QIE_WORLD_SIZE=16 and NEURON_RT_NUM_CORES=16 (uses 16 cores). +set -euo pipefail +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +cd "$(dirname "$0")/.." # repo root + +export NEURON_RT_NUM_CORES=16 +export QIE_WORLD_SIZE=16 +export PYTHONPATH=src:${PYTHONPATH:-} + +OUT_DIR="${QIE_OUT_DIR:-/opt/dlami/nvme/compiled_models_qwen_image_edit_step4000_896x1184_cp4}" +OUT_IMG="${1:-release_v3cp4_wlo/tryon_cp4.png}" +CLOTH="${QIE_CLOTH:-cloth/1686634914e5521d5145f5c95c1b4ee70560881686.jpg}" +MODEL="${QIE_MODEL_IMG:-input_img/1764042352dfda7d588f8da62b2d3aea69d3889bcb.webp}" + +python src/run_qwen_image_edit.py \ + --images "$CLOTH" "$MODEL" \ + --prompt "让图2的模特换上图1的下装" \ + --negative_prompt "" \ + --output "$OUT_IMG" \ + --height 1184 --width 896 --image_h 448 --image_w 336 \ + --patch_multiplier 3 --max_sequence_length 1024 \ + --num_inference_steps 8 --true_cfg_scale 1.0 --seed 42 \ + --compiled_models_dir "$OUT_DIR" \ + --use_v3_cp \ + 2>&1 diff --git a/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/test_cp4_quality.py b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/test_cp4_quality.py new file mode 100644 index 00000000..3fc1caf5 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/test_cp4_quality.py @@ -0,0 +1,38 @@ +""" +Validate CP=4 output is visually equivalent to the CP=2 baseline (same seed/inputs). + +CP=4 is NOT bit-exact vs CP=2 — changing the CP degree changes the sequence partition +and therefore the bf16 accumulation order (and the pad count: 368 vs 112). The result is +ULP-level drift that is visually equivalent, not pixel-identical. This test asserts the +structural agreement is within a visual-equivalence threshold (mean |Δ| < 8/255, the +garment/pose/face preserved), NOT max|Δ|=0. + +Usage: python test_cp4_quality.py +""" +import sys +import numpy as np +from PIL import Image + +THRESH_MEAN_PCT = 2.0 # mean |Δ| as % of 255; visual-equivalence bound + +def main(): + if len(sys.argv) != 3: + print("usage: python test_cp4_quality.py ") + return 2 + a = np.asarray(Image.open(sys.argv[1]).convert("RGB")).astype(np.int32) + b = np.asarray(Image.open(sys.argv[2]).convert("RGB")).astype(np.int32) + if a.shape != b.shape: + print(f"FAIL: shape mismatch {a.shape} vs {b.shape}") + return 1 + d = np.abs(a - b) + mean_pct = d.mean() / 255 * 100 + p5 = 100 * np.mean(d.max(-1) > 5) + print(f"mean|Δ|={mean_pct:.2f}% px>5={p5:.1f}% max|Δ|={int(d.max())}") + if mean_pct < THRESH_MEAN_PCT: + print(f"PASS — CP=4 visually equivalent to CP=2 (mean < {THRESH_MEAN_PCT}%)") + return 0 + print(f"FAIL — CP=4 differs too much from CP=2 (mean >= {THRESH_MEAN_PCT}%)") + return 1 + +if __name__ == "__main__": + sys.exit(main()) diff --git a/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/tryon_cp2_baseline.png b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/tryon_cp2_baseline.png new file mode 100644 index 00000000..44fab437 Binary files /dev/null and b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/tryon_cp2_baseline.png differ diff --git a/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/tryon_cp4.png b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/tryon_cp4.png new file mode 100644 index 00000000..9963858e Binary files /dev/null and b/contrib/models/Qwen-Image-Edit/release_v3cp4_wlo/tryon_cp4.png differ diff --git a/contrib/models/Qwen-Image-Edit/requirements.txt b/contrib/models/Qwen-Image-Edit/requirements.txt new file mode 100644 index 00000000..aa83bd8f --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/requirements.txt @@ -0,0 +1,6 @@ +diffusers @ git+https://github.com/huggingface/diffusers +transformers>=4.45.0 +accelerate +qwen-vl-utils +torchvision +pillow diff --git a/contrib/models/Qwen-Image-Edit/src/__init__.py b/contrib/models/Qwen-Image-Edit/src/__init__.py new file mode 100644 index 00000000..8761f6cf --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/__init__.py @@ -0,0 +1 @@ +# Neuron implementation for Qwen-Image-Edit-2509 diff --git a/contrib/models/Qwen-Image-Edit/src/attention_cte_qie_hoisted_q.py b/contrib/models/Qwen-Image-Edit/src/attention_cte_qie_hoisted_q.py new file mode 100644 index 00000000..6861482d --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/attention_cte_qie_hoisted_q.py @@ -0,0 +1,2955 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Attention Kernel For Context Encoding (Prefill) + +ALGORITHM OVERVIEW: + +Base Attention Algorithm: +The kernel computes: Output = softmax(scale * Q @ K^T) @ V + +Breaking this down into steps: + 1. Compute attention scores: S = scale * Q @ K^T (matmul, called MM1) + 2. Apply masking to S (causal, sliding window, etc.) + 3. Compute row-wise softmax: P = softmax(S) = exp(S - max(S)) / sum(exp(S - max(S))) + 4. Compute final output: O = P @ V (matmul, called MM2) + +The transpose flags (tp_q, tp_k, tp_out) control input/output tensor layouts without +changing the mathematical computation. + +For a simple PyTorch reference implementation, see attention_cte_torch.py + +FEATURES: + +1. Causal Masking (causal_mask=True): + - Masks upper triangle of attention scores: S[i,j] = -inf when i < j + - Enables compute skipping: skip MM1/MM2 for upper triangle tiles + +2. Sliding Window Attention (SWA, when sliding_window > 0): + - Local attention: each query only attends to nearby keys within a window + - Masks attention scores: S[i,j] = -inf when |i - j| > sliding_window + - Currently only works with causal: masks both upper triangle AND positions outside window + - When used with CP: loads only required KV slice to save memory + +3. Context Parallelism (CP, global_cp_deg > 1, cp_offset is not None): + - Distributes long sequence computation across multiple devices/ranks + - Each rank (kernel call) processes a slice of Q sequence with full K/V + - cp_offset indicates which Q slice this rank handles (runtime value) + - Optionally supports strided Q slicing for better load balancing across CP ranks + - Requires dynamic masking since offset unknown at compile time + - Currently only supports causal attention + +4. Prefix Caching (k_prior/v_prior provided): + - K/V split into two parts: prior (cached) and active (current) + - prior_used_len specifies how much of prior to use (dynamic mask) + - Causal mask not required for prior portion (although we still apply SWA if applicable) + +5. Sink Tokens (sink provided): + - Add additional sink token to softmax denominator + +6. Sequence Packing (bound_min/bound_max provided): + - Multiple independent sequences are packed into a single tensor + - Each query position has a [bound_min, bound_max) range defining which KV positions it attends to + - Positions outside the range are masked with -inf, preventing cross-sequence attention + - Compatible with causal masking (both masks are applied simultaneously) + - Not compatible with prefix caching or context parallelism + +7. Grouped Query Attention (GQA, batch_size_kv < batch_size): + - Kernel handles GQA natively without explicit K/V replication + +7. Support for training: + - Kernel can optionally return maximum attention score and softmax denominator (per row) for backpropagation. + +IMPLEMENTATION DETAILS AND LOOP STRUCTURE: + +Level 1: LNC2 Sharding (on Trn2+) + - Shards computation across 2 NeuronCores (LNC=2) + - Primary sharding: Divides batch dimension evenly + - Secondary sharding (for odd batch): Last batch item sharded on seqlen_q + * Uses unequal split (65%/35%) for causal attention to balance load + * Falls back to single core for short sequences (< 1024 tokens) + +Level 2: Batch Loop + - Iterates over batch items assigned to this core + - Each batch item processes independently + - For GQA: maps Q batch_id to correct KV batch_id + +Level 3: Section Loop (Flash Attention for long sequences) + - For K/V length > 10K tokens: divide into 8K-token sections + - Process one section at a time to fit in SBUF memory + - Maintains running statistics (max, sum) across sections + - Final output computed using flash attention rescaling + - For short sequences: single section contains all K/V + - Check https://arxiv.org/abs/2205.14135 and https://arxiv.org/abs/2307.08691 + for more details about flash attention. + +Level 4: Group Loop (Q sequence processing) + - Q sequence divided into groups of 128 tokens (_Q_GRP_SZ) + - Each group processes independently within a section + - Software pipelining: overlaps operations across groups (i, i+1, i+2) + * Group i: PV computation, writeback + * Group i+1: Exp computation + * Group i+2: Q load, QK computation + - Uses modular allocation for efficient buffer reuse + +INTENDED USAGE: + +The kernel supports sequence lengths up to 36864 and is optimized for q sequence +length larger than ~256 (i.e., prefill/context encoding workloads). The head dimension (d) +can be up to 128. Batch size up to 16 has been tested. + +Input dtypes can be bfloat16, float16 or float32. The kernel uses float32 for softmax and +bfloat16 for other operations. + +""" + +from dataclasses import dataclass +from typing import Any, Optional + +import nki +import nki.isa as nisa +import nki.language as nl +from nki.isa import reduce_cmd + +from nkilib.core.utils.allocator import align_to +from nkilib.core.utils.kernel_assert import assert_shape, kernel_assert +from nkilib.core.utils.kernel_helpers import PSUM_BANK_SIZE, div_ceil, get_verified_program_sharding_info +from nkilib.core.utils.logging import get_logger +from nkilib.core.utils.modular_allocator import ModularAllocator +from nkilib.core.utils.stream_shuffle_broadcast import stream_shuffle_broadcast + +logger = get_logger("attention_cte") + +_FLOAT32_MIN = -3.4028235e38 # used for initialization and masking + +""" +Kernel constraints (based on tested range, values outside range might work in practice) +""" +_MAX_BS = 512 # max tested batch size +_MAX_SEQLEN = 131072 # max allowed seqlen +_MAX_BS_TIMES_SEQLEN_QK = 32.0 * 36864 * 36864 # max tested bs*seqlen_q*seqlen_k +_MAX_HEAD_DIM = 128 # max supported head dim (d) +_MIN_GLOBAL_CP_DEGREE = 1 # minimum context parallel degree +_MAX_GLOBAL_CP_DEGREE = 32 # minimum context parallel degree + + +""" +Sharding, tile size and threshold related constants +""" +_MIN_SEQLEN_FOR_LNC2_SHARDING = 1024 # if odd batch, then shard on sequence if len above this +_SEQLEN_SHARDING_SPLIT_FACTOR_DEFAULT = 0.5 # When sharding LNC2 on seqlen, pass 50% of seqlen to each shard +_SEQLEN_SHARDING_SPLIT_FACTOR_CAUSAL = ( + 0.65 # When sharding LNC2 on seqlen, split 65%-35% in causal case to balance compute +) +_Q_GRP_SZ = 128 +_V_TILE_SZ = 128 # V tile size for loading and MM2 operations +_K_TILE_SZ = 512 # K tile size for loading and MM1+masking operations +_EXP_TILE_SZ = 512 # Tile size for exp instructions (must equal _K_TILE_SZ) +_LARGE_TILE_SZ = 2048 # Larger tile size for allocations/pipelining (4 x 512 tiles) +_FLASH_ATTENTION_THRESHOLD = 10 * 1024 # Use flash attention above this K/V length +_FLASH_ATTENTION_SECTION_LENGTH = 8 * 1024 # Section size when using flash attention +_SWA_ALLOCATION_STRATEGY_THRESHOLD = ( + 128 # for SWA, threshold above which allocate more q tiles and use range_select masking +) + + +@nki.jit +def attention_cte_hoisted_q( + q: nl.ndarray, + k: nl.ndarray, + v: nl.ndarray, + scale: float = 1.0, + causal_mask: bool = True, + k_prior: Optional[nl.ndarray] = None, + v_prior: Optional[nl.ndarray] = None, + prior_used_len: Optional[nl.ndarray] = None, + sink: Optional[nl.ndarray] = None, + sliding_window: Optional[int] = None, + tp_q: bool = True, + tp_k: bool = False, + tp_out: bool = False, + cache_softmax: bool = False, + softmax_dtype=nl.float32, + mm_out_dtype=nl.float32, + cp_offset: Optional[nl.ndarray] = None, + global_cp_deg: int = None, + cp_strided_q_slicing: bool = False, + bound_min: Optional[nl.ndarray] = None, + bound_max: Optional[nl.ndarray] = None, +): + """Entrypoint NKI kernel that supports multiple attention variants. + + The kernel can be invoked with 1D SPMD grid for LNC2 or without grid. + + Dimensions: + batch_size: Number of query sequences + batch_size_kv: Number of key/value sequences (for GQA) + seqlen_q: Query sequence length + seqlen_kv: Key/value sequence length + seqlen_prior: Prior key/value sequence length (prefix caching) + d: Head dimension size + + Args: + q (nt.tensor): Query tensor with layout dependent on tp_q parameter + k (nt.tensor): Key tensor with layout dependent on tp_k parameter + v (nt.tensor): Value tensor with shape (batch_size_kv, seqlen, d) + scale (float, optional): Scaling factor for attention scores. It must be set to 1.0 (default value) when + using sliding window, context parallel, or prefix caching. In these cases, q + can be scaled before calling the kernel. + causal_mask (bool, optional): whether to use causal mask (default True) + k_prior (nt.tensor, optional): (Prefix caching) Prior key tensor with layout dependent on tp_k parameter + v_prior (nt.tensor, optional): (Prefix caching) Prior value tensor with shape (batch_size_kv, seqlen_prior, d) + prior_used_len (nt.tensor, optional): (Prefix caching) Actual used length in prior with shape (1,) + sink (nt.tensor, optional): Sink token tensor with shape (batch_size, 1) + sliding_window (int, optional): Sliding window size for attention, None or 0 denotes no sliding window mask + tp_q (bool): Query tensor transpose flag (default True) + tp_k (bool): Key tensor transpose flag (default False) + tp_out (bool): Output tensor transpose flag (default False) + cache_softmax (bool): Whether to cache softmax intermediate values (default False) + softmax_dtype (nl.dtype): Data type for softmax computations (current implementation tested with float32) + cp_offset (nt.tensor, optional): Context parallel offset tensor with shape (1, 1) + global_cp_deg (int, optional): Global context parallel degree + cp_strided_q_slicing (bool, optional): Whether Q is strided for load balancing (default False) + bound_min (nt.tensor, optional): (Sequence packing) Per-query lower bound (inclusive) of the KV range + to attend to, with shape (seqlen_q, 1). Query position i attends only to KV positions j + where bound_min[i] <= j < bound_max[i]. Must be provided together with bound_max. + Not compatible with prefix caching or context parallelism. + bound_max (nt.tensor, optional): (Sequence packing) Per-query upper bound (exclusive) of the KV range + to attend to, with shape (seqlen_q, 1). Must be provided together with bound_min. + + Returns: + Output tensor with attention results. Shape depends on tp_out parameter. + If cache_softmax is True, returns tuple of (output, out_neg_max, out_sum_recip). + + IO Shapes: + - q: + (batch_size, seqlen_q, d) when tp_q is True + (batch_size, d, seqlen_q) when tp_q is False + - k: + (batch_size_kv, seqlen_kv, d) when tp_k is True + (batch_size_kv, d, seqlen_kv) when tp_k is False + - v: (batch_size_kv, seqlen_kv, d) + - returns output with shape: + (batch_size, d, seqlen_q) if tp_out is True + (batch_size, seqlen_q, d) if tp_out is False + + - The math performed is softmax(q @ k) @ v (details described in top-level documentation) + + Prefix Caching: + If k_prior, v_prior and prior_used_len are specified for prefix caching, the + computation is equivalent to prepending the prior k/v (up to prior_used_len) + behind the active k/v and shifting the mask accordingly. The shapes of k_prior + and v_prior must match the shapes of k and v respectively on the non-seqlen + dimensions, while the shape of prior_used_len is (1,). By setting prior_used_len, + the actual prefix length can be chosen dynamically at runtime (up to seqlen_prior). + + MHA and Native GQA Support: + For MHA attention, the heads can be included as part of the batch dimension. + For GQA, when batch_size_kv < batch_size, we expect that batch_size % batch_size_kv == 0, + and the computation is equivalent to first applying torch_interleave on K and V. + Note that GQA typically applies to nheads dimension but the kernel combines + batch and nheads dimensions. + + Softmax Caching (useful during training): + When cache_softmax is True and out_neg_max/out_sum_recip are provided, returns + the negative max and reciprocal sum in the softmax with shapes: + padded_seq_grps = ceil(seqlen_q / 128) (_Q_GRP_SZ = 128) + neg_max: (batch_size, 128, padded_seq_grps) + recip: (batch_size, 128, padded_seq_grps) + + Context Parallel Support: + Enabled when global_cp_deg is set. Since the Q seqlen offset is usually not + known at compile time (based on rank ID), it is expected to be a (1, 1) HBM + input to the kernel. global_cp_deg (a compile time constant int) denotes total + number of ranks / CP degree. + + When cp_strided_q_slicing is False, this indicates FAL has sharded Q into contiguous chunks. + In this case, cp_offset should be passed as rank_id * partial_q_seqlen + + When cp_strided_q_slicing is True, this indicates Q has been sharded in row-strided manner + where stride is global_cp_deg. In this case, cp_offset should be passed as rank_id. Note that + K & V are still assumed to be contiguous. + + As an example for seqlen_q=4, seqlen_kv=12, global_cp_deg=3, rank_id = 1: + cp_strided_q_slicing False => q seqlen slice is [4, 5, 6, 7] and cp_offset is 4. + cp_strided_q_slicing True => q seqlen slice is [1, 4, 7, 10] and cp_offset is 1. + In both cases, KV token order is simply [0, 1, 2, ..., 10, 11] + + Pseudocode: + ``` + # High-level algorithm (see module docstring for detailed implementation) + for each batch in batch_size: + for each section in K/V (flash attention sectioning): + for each Q group (128 tokens): + # MM1: Compute attention scores + scores = Q @ K^T * scale + + # Apply masking (causal, sliding window, CP, prefix caching) + scores = apply_masks(scores) + + # Softmax with running statistics for flash attention + max_score = max(scores) + exp_scores = exp(scores - max_score) + sum_exp = sum(exp_scores) + + # Update running statistics across sections + update_flash_attention_stats(max_score, sum_exp) + + # MM2: Compute output + output += exp_scores @ V + + # Normalize output using flash attention correction + if last_section: + output = output / sum_exp + ``` + + """ + if sliding_window is None: + sliding_window = 0 + + if k_prior is not None: + is_prefix_caching = True + kernel_assert(v_prior is not None, "k_prior is not None but v_prior is None for prefix caching") + kernel_assert( + prior_used_len is not None, + "k_prior is not None but prior_used_len is None for prefix caching", + ) + else: + is_prefix_caching = False + kernel_assert(v_prior is None, "k_prior is None but v_prior is not None.") + kernel_assert(prior_used_len is None, "k_prior is None but prior_used_len is not None.") + + kernel_assert( + (bound_min is None) == (bound_max is None), "bound_min and bound_max must both be set or both be None" + ) + # Sequence packing is active when per-query KV bounds are provided + is_sequence_packed = bound_min is not None + + seqlen_q, seqlen_k_active, seqlen_k_prior, d, out_shape, softmax_shape = _check_input_and_return_shape( + q, + k, + v, + is_prefix_caching, + k_prior, + v_prior, + prior_used_len, + tp_q, + tp_k, + tp_out, + cache_softmax, + ) + if is_sequence_packed: + kernel_assert( + bound_min.shape == (seqlen_q, 1), + f"bound_min shape must be (seqlen_q, 1)=({seqlen_q}, 1), got {bound_min.shape}", + ) + kernel_assert( + bound_max.shape == (seqlen_q, 1), + f"bound_max shape must be (seqlen_q, 1)=({seqlen_q}, 1), got {bound_max.shape}", + ) + kernel_assert(not is_prefix_caching, "is_sequence_packed is not supported with prefix caching") + kernel_assert( + global_cp_deg is None or global_cp_deg <= 1, "is_sequence_packed is not supported with context parallelism" + ) + + result = nl.ndarray(shape=out_shape, dtype=q.dtype, buffer=nl.shared_hbm) + + out_sum_recip, out_neg_max = None, None + if cache_softmax: + out_neg_max = nl.ndarray(shape=softmax_shape, dtype=softmax_dtype, buffer=nl.shared_hbm) + out_sum_recip = nl.ndarray(shape=softmax_shape, dtype=softmax_dtype, buffer=nl.shared_hbm) + + bs = q.shape[0] + bs_kv = k.shape[0] + + # Batch size checks + kernel_assert(bs > 0, f"Batch size must be positive, got {bs}") + kernel_assert( + bs <= _MAX_BS, + f"attention_cte kernel is not tested for batch size above {_MAX_BS}, got {bs}.", + ) + kernel_assert(bs_kv > 0, f"Batch size must be positive, got {bs_kv}") + kernel_assert( + bs % bs_kv == 0, + f"Q batch size must be a multiple of KV batch size, got {bs=}, {bs_kv=}", + ) + + # Sequence length checks + seqlen_k_total = seqlen_k_active + seqlen_k_prior if seqlen_k_prior else seqlen_k_active + kernel_assert( + seqlen_q <= _MAX_SEQLEN, + f"attention_cte kernel is not tested for seqlen above {_MAX_SEQLEN}, got {seqlen_q=}.", + ) + kernel_assert( + seqlen_k_total <= _MAX_SEQLEN, + f"attention_cte kernel is not tested for seqlen above {_MAX_SEQLEN}, got {seqlen_k_total=}.", + ) + bs_seqlen_qk_product = float(bs * seqlen_q) * seqlen_k_total # use float to avoid overflow + if bs_seqlen_qk_product <= _MAX_BS_TIMES_SEQLEN_QK: + logger.warn( + f"attention_cte kernel is not tested for batch size x seqlen_q x seqlen_k above {_MAX_BS_TIMES_SEQLEN_QK}, got {bs_seqlen_qk_product=}.", + ) + kernel_assert( + sliding_window <= _MAX_SEQLEN, + f"attention_cte kernel is not tested for sliding window above {_MAX_SEQLEN}, got {sliding_window=}.", + ) + kernel_assert(sliding_window >= 0, f"sliding_window must be >= 0, got {sliding_window=}.") + + # head dim + kernel_assert(d > 0, f"d must be > 0, got {d=}.") + kernel_assert( + d <= _MAX_HEAD_DIM, + f"we do not support head_dim > {_MAX_HEAD_DIM}, got head dim {d}", + ) + + # mm_out_dtype + kernel_assert( + (str(mm_out_dtype) == str(nl.float32)) + or (str(mm_out_dtype) == str(nl.bfloat16) and (nisa.get_nc_version() >= nisa.nc_version.gen4)), + f"mm_out_dtype (psum) should be in [float32, bfloat16], and 2-byte dtype is only allows in gen4+ (Trn3+)," + f"but got dtype {mm_out_dtype} in hw version {nisa.get_nc_version()}.", + ) + + # Context parallel + if global_cp_deg: + kernel_assert( + _MIN_GLOBAL_CP_DEGREE <= global_cp_deg <= _MAX_GLOBAL_CP_DEGREE, + f"attention_cte kernel is not tested for global_cp_deg outside [{_MIN_GLOBAL_CP_DEGREE}, {_MAX_GLOBAL_CP_DEGREE}], " + f"got {global_cp_deg=}.", + ) + + # Create AttnConfig with high-level configuration + ac = AttnConfig( + seqlen_q=seqlen_q, + seqlen_k_active=seqlen_k_active, + seqlen_k_prior=seqlen_k_prior, + d=d, + tp_q=tp_q, + tp_k=tp_k, + tp_out=tp_out, + is_prefix_caching=is_prefix_caching, + causal_mask=causal_mask, + use_swa=sliding_window > 0, + sliding_window=sliding_window, + use_cp=global_cp_deg is not None, + global_cp_deg=global_cp_deg, + cp_strided_q_slicing=cp_strided_q_slicing, + scale=scale, + cache_softmax=cache_softmax, + dtype=q.dtype, + softmax_dtype=softmax_dtype, + mm_out_dtype=mm_out_dtype, + is_sequence_packed=is_sequence_packed, + ) + + grid_ndim, num_shard, shard_id = get_verified_program_sharding_info("attention_cte", max_sharding=2) + # Shard on batch size while it is divisible (if not sharded, num_shard = 1, shard_id = 0) + num_bs_per_shard = bs // num_shard + bs_offset = shard_id * num_bs_per_shard + + for batch_idx in range(num_bs_per_shard): + kv_batch_id = _q_to_kv_batch_id(batch_idx + bs_offset, bs, bs_kv) + _attention_cte_impl( + q, + k, + v, + k_prior, + v_prior, + prior_used_len, + result, + batch_idx + bs_offset, + kv_batch_id, + ac, + sink=sink, + out_neg_max=out_neg_max, + out_sum_recip=out_sum_recip, + cp_offset=cp_offset, + bound_min=bound_min, + bound_max=bound_max, + ) + + has_remainder = (bs % num_shard) != 0 + last_batch = bs - 1 + + # shard on seqlen_q for the remainder bs + if has_remainder: + last_batch_id = _q_to_kv_batch_id(last_batch, bs, bs_kv) + if seqlen_q >= _MIN_SEQLEN_FOR_LNC2_SHARDING: + # shard unequally on seqlen_q when causal mask and not sliding window/CP + # For CP we shard unequally when we have strided Q slicing. + use_causal_divide_factor = ( + causal_mask and (cp_offset is None or cp_strided_q_slicing) and (sliding_window == 0) + ) + divide_factor = ( + _SEQLEN_SHARDING_SPLIT_FACTOR_CAUSAL + if use_causal_divide_factor + else _SEQLEN_SHARDING_SPLIT_FACTOR_DEFAULT + ) + if is_prefix_caching and use_causal_divide_factor: + s_active, s_prior = v.shape[1], v_prior.shape[1] + divide_factor = ( + _SEQLEN_SHARDING_SPLIT_FACTOR_CAUSAL * s_active + _SEQLEN_SHARDING_SPLIT_FACTOR_DEFAULT * s_prior + ) / (s_active + s_prior) + + total_grps = div_ceil(seqlen_q, _Q_GRP_SZ) + batch_0_grp = int(total_grps * divide_factor) + batch_1_grp = total_grps - batch_0_grp + + batch_length = shard_id * batch_1_grp + (1 - shard_id) * batch_0_grp # shard_id is 0 or 1 + _attention_cte_impl( + q, + k, + v, + k_prior, + v_prior, + prior_used_len, + result, + last_batch, + last_batch_id, + ac, + sink=sink, + out_neg_max=out_neg_max, + out_sum_recip=out_sum_recip, + shard_seqlen_q_start=shard_id * batch_0_grp, + shard_seqlen_q_length=batch_length, + cp_offset=cp_offset, + bound_min=bound_min, + bound_max=bound_max, + ) + else: + # Have core 0 do all the work + if shard_id == 0: + _attention_cte_impl( + q, + k, + v, + k_prior, + v_prior, + prior_used_len, + result, + last_batch, + last_batch_id, + ac, + sink=sink, + out_neg_max=out_neg_max, + out_sum_recip=out_sum_recip, + cp_offset=cp_offset, + bound_min=bound_min, + bound_max=bound_max, + ) + + if cache_softmax: + return result, out_neg_max, out_sum_recip + else: + return result + + +@dataclass +class AttnConfig(nl.NKIObject): + """High-level attention configuration set at kernel entry point. + + Contains user-facing parameters and computed configuration flags. + """ + + # Sequence dimensions + seqlen_q: int = None + seqlen_k_active: int = None + seqlen_k_prior: int = None + d: int = None + + # Transpose flags + tp_q: bool = None + tp_k: bool = None + tp_out: bool = None + + # Masking configuration + is_prefix_caching: bool = None + causal_mask: bool = None + + # Sliding window attention + use_swa: bool = None + sliding_window: int = None + + # Context parallelism + use_cp: bool = None + global_cp_deg: int = None + cp_strided_q_slicing: bool = None + + # Other + scale: float = None + cache_softmax: bool = None + dtype: Any = None + softmax_dtype: Any = None + mm_out_dtype: Any = None + + # sequence packing + is_sequence_packed: bool = None + + +def _attention_cte_impl( + q, + k_active, + v_active, + k_prior, + v_prior, + prior_used_len, + o, + batch_id, + batch_id_kv, + ac: AttnConfig, + sink=None, + out_neg_max=None, + out_sum_recip=None, + shard_seqlen_q_start=-1, + shard_seqlen_q_length=-1, + cp_offset: Any = None, + bound_min: Any = None, + bound_max: Any = None, +): + """ + Internal implementation function for attention computation. + + This function processes a single batch and handles the core attention computation + with flash attention optimization, sectioning, and various masking strategies. + + Args: + q, k_active, v_active: Input tensors + k_prior, v_prior, prior_used_len: Prefix caching tensors + o: Output tensor + batch_id, batch_id_kv: Batch indices + ac: High-level attention configuration + sink: Optional sink token tensor + out_neg_max, out_sum_recip: Optional softmax cache outputs + shard_seqlen_q_start, shard_seqlen_q_length: Seqlen sharding parameters + cp_offset: Context parallel offset tensor + + High-level logic: + For large enough K/V length, we divide the K/V into sections of 8k. + + For each section: + a. Load K and V to SBUF + b. Loop over Q (groups) - each group has seqlen 128 (_Q_GRP_SZ) + c. Within each group: + i. Load Q + ii. Compute QK^T (MM1) and max + iii. Compute exponential and transpose + iv. Compute PV (MM2) + v. Write to output + + Handling multiple sections: + We keep running max, sum, etc. buffers that we keep updating as we go through + sections and use these to update the output using flash attention. + + Pipelining: + To maximize utilization of the hw engines, we use software pipelining to + inform scheduler decisions. This means we manually interleave iterations of + the group loop to make sure certain operations of group i+2 should start even + as group i is being processed. In addition we use modulo allocation for tensors + to enable pipelining across Q groups and K/V tiles within a section. + """ + is_seqlen_sharded = shard_seqlen_q_start >= 0 + + # Compute all tile parameters including section length and number of sections + atp = _compute_tile_parameters(ac, is_seqlen_sharded) + + # Update shard length if not sharded + if not is_seqlen_sharded: + shard_seqlen_q_start = 0 + shard_seqlen_q_length = atp.num_grps + + # Initialize allocator and buffer container + allocator = ModularAllocator(initial_address=0) + bufs = AttnInternalBuffers() + + # Allocate shared utilities (zero bias, sink) + bufs.zero_bias_tensor = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, 1), dtype=nl.float32) + nisa.memset(bufs.zero_bias_tensor, 0.0) + + if sink is not None: + bufs.sink_sb = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, 1), dtype=nl.float32) + # Load sink from HBM to SBUF + nisa.dma_copy(dst=bufs.sink_sb[0, 0], src=sink[batch_id, 0]) + stream_shuffle_broadcast(src=bufs.sink_sb, dst=bufs.sink_sb) + + # Setup range select bounds for dynamic masking (used in CP/SWA/Prefix caching) + _setup_range_select_bounds(ac, atp, bufs, allocator, cp_offset, prior_used_len, bound_min, bound_max) + + # Allocate running statistics (persistent across sections) + bufs.mm1_running_max = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, atp.num_grps), dtype=nl.float32) + bufs.exp_running_sum = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, atp.num_grps), dtype=nl.float32) + bufs.exp_sum_reciprocal = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, atp.num_grps), dtype=nl.float32) + + # === HOISTED Q: allocate Q once with all-resident tiles (no rotation) === + # Original layout: Q allocated inside section loop with num_free_tiles=[2], + # forcing reload at the start of each section. With QIE's num_sections=2, + # the second section reloaded all 13 Q-tiles redundantly. + # + # Hoisted layout: Q gets num_free_tiles = block_dim (all 13 alive), allocated + # before the section loop. Q is preloaded once. _allocate_attention_buffers + # is called with skip_q_alloc=True from the section loop so it doesn't + # clobber the Q address. + bufs.q_sb = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p * atp.num_q_grps_per_load), + dtype=nl.bfloat16, + block_dim=[div_ceil(atp.num_grps, atp.num_q_grps_per_load)], + num_free_tiles=[div_ceil(atp.num_grps, atp.num_q_grps_per_load)], # all alive + align_to=32, + ) + + # Mark allocator checkpoint for section-level reset (now AFTER Q allocation) + sbuf_addr_outer = allocator.get_current_address() + + # Preload all Q-load groups once (covers all sections). + # Use a temp SectionParams (section_idx 0 / offset 0) — the q load impl + # only reads ac, atp, q, batch_id; SectionParams isn't actually consulted + # for Q load (only used by the function for `is_causal` predicate). + _q_preload_sp = SectionParams( + section_idx=0, + section_offset=0, + section_offset_active=0, + next_section_offset_active=0, + section_contains_prefix=False, + next_section_contains_prefix=False, + ) + for _q_grp in range(0, atp.num_grps, atp.num_q_grps_per_load): + _load_q_impl(_q_grp, ac, atp, _q_preload_sp, bufs, q, batch_id, sbuf_addr_outer) + + # Loop over k/v sections + for section_idx in range(atp.num_sections): + # Compute some quantities for this section such as offset and whether section contains active/prior KV + section_offset = atp.section_len * section_idx + section_offset_active, section_contains_prefix = _compute_section_offset_active( + section_offset, ac.is_prefix_caching, atp.seqlen_k_prior_padded + ) + next_section_offset = atp.section_len * (section_idx + 1) + next_section_offset_active, next_section_contains_prefix = _compute_section_offset_active( + next_section_offset, ac.is_prefix_caching, atp.seqlen_k_prior_padded + ) + sp = SectionParams( + section_idx=section_idx, + section_offset=section_offset, + section_offset_active=section_offset_active, + next_section_offset_active=next_section_offset_active, + section_contains_prefix=section_contains_prefix, + next_section_contains_prefix=next_section_contains_prefix, + ) + + # Allocate the internal buffers (skip Q: hoisted above the section loop) + allocator.set_current_address(sbuf_addr_outer) + _allocate_attention_buffers(allocator, ac, atp, bufs, sink, skip_q_alloc=True) + sbuf_addr = allocator.get_current_address() + + # Load K and V for the section + sbuf_addr = _load_k_tile( + k_active, + k_prior, + bufs.k_sb, + batch_id_kv, + sp, + nl.bfloat16, + ac.tp_k, + atp.num_k_tiles_per_section, + sbuf_addr, + load_offset_active=bufs.k_offset_sb_u32, + ) + sbuf_addr = _load_v_tile( + v_active, + v_prior, + bufs.v_sb, + batch_id_kv, + sp, + nl.bfloat16, + atp.num_v_tiles_per_section, + sbuf_addr, + load_offset_active=bufs.k_offset_sb_u32, + ) + + # Start actual compute. Q is already preloaded above the section loop — + # all _load_q_impl calls inside the original code are skipped here. + if shard_seqlen_q_length <= 1: + # no pipelining when there's only 1 group + _qk_and_max_impl(shard_seqlen_q_start, ac, atp, sp, bufs) + _update_max_impl(shard_seqlen_q_start, ac, atp, sp, bufs, sink) + _exp_impl(shard_seqlen_q_start, ac, atp, sp, bufs, sink) + _pv_impl(shard_seqlen_q_start, ac, atp, sp, bufs) + _write_back_impl(shard_seqlen_q_start, ac, atp, sp, bufs, o, batch_id) + + else: + # Do software pipelining. We have a group loop and some initial/final calls + # outside the loop. + _qk_and_max_impl(shard_seqlen_q_start, ac, atp, sp, bufs) + _update_max_impl(shard_seqlen_q_start, ac, atp, sp, bufs, sink) + _exp_impl(shard_seqlen_q_start, ac, atp, sp, bufs, sink) + + _qk_and_max_impl(shard_seqlen_q_start + 1, ac, atp, sp, bufs) + _update_max_impl(shard_seqlen_q_start + 1, ac, atp, sp, bufs, sink) + shard_seqlen_q_end = shard_seqlen_q_start + shard_seqlen_q_length + + for grp_i in range(shard_seqlen_q_start, shard_seqlen_q_end - 2): # for each block of seq_q + if ac.use_swa and atp.is_causal: + nisa.memset(bufs.mm2_sb[grp_i][...], value=0.0) # when use_swa, mm_i == 0 is not the initial tile + + # We try to perform software pipelining where the following operations are overlapped: + # grp_i : PV, write_back + # grp_i+1 : EXP + # grp_i+2 : QK+Max (Q already loaded above) + _exp_impl(grp_i + 1, ac, atp, sp, bufs, sink) + _fused_qkmax_and_pv_impl(grp_i, ac, atp, sp, bufs) + _write_back_impl(grp_i, ac, atp, sp, bufs, o, batch_id) + _update_max_impl(grp_i + 2, ac, atp, sp, bufs, sink) + + _pv_impl(shard_seqlen_q_end - 2, ac, atp, sp, bufs) + + _write_back_impl(shard_seqlen_q_end - 2, ac, atp, sp, bufs, o, batch_id) + _exp_impl(shard_seqlen_q_end - 1, ac, atp, sp, bufs, sink) + _pv_impl(shard_seqlen_q_end - 1, ac, atp, sp, bufs) + _write_back_impl(shard_seqlen_q_end - 1, ac, atp, sp, bufs, o, batch_id) + + # If used with training, we need to also return the softmax intermediates + # num_grps is total number of groups, shard_seqlen_q_length is current shard + # write from [0:128, shard_seqlen_q_start:shard_seqlen_q_start+shard_seqlen_q_length] + # to [batch_id, 0:128, shard_seqlen_q_start:shard_seqlen_q_start+shard_seqlen_q_length] + dst_ap = [[atp.num_grps, atp.sb_p], [1, shard_seqlen_q_length]] + dst_offset = batch_id * atp.sb_p * atp.num_grps + shard_seqlen_q_start + src_ap = [[atp.num_grps, atp.sb_p], [1, shard_seqlen_q_length]] + src_offset = shard_seqlen_q_start + if out_neg_max is not None: + nisa.dma_copy( + out_neg_max.ap(pattern=dst_ap, offset=dst_offset), + src=bufs.mm1_running_max.ap(pattern=src_ap, offset=src_offset), + ) + if out_sum_recip is not None: + nisa.dma_copy( + out_sum_recip.ap(pattern=dst_ap, offset=dst_offset), + src=bufs.exp_sum_reciprocal.ap(pattern=src_ap, offset=src_offset), + ) + + +@dataclass +class AttnInternalBuffers(nl.NKIObject): + """Container for all SBUF and PSUM tensor buffers used in attention computation.""" + + # SBUF tensors to load q/k/v into + q_sb = None + k_sb = None + v_sb = None + + # SBUF/PSUM tensors for computation + + # QK and max + mm1_psum = None # output of MM1 on PSUM + mm1_copy_sb = None # Copy mm1_psum to SBUF for using affine select (Pool engine input needs to be SBUF) + mm1_affine_select_output = None # Output of affine select, goes via TSCR to produce mm1_masked + mm1_masked = None # Masked and scaled (if scale != 1.0) output after MM1 (in SBUF) - produced via affine_select+TSCR or range select + mm1_partial_max = None # tile-wise max after MM1 + mm1_section_max = None # max for section after MM1 + mm1_running_max = None # (persistent across sections) Running max across sections for output after MM1 + prev_mm1_running_max = None # Previous running max across sections for output after MM1 (used to hold value temporarily before section update) + flash_attn_correction_factor = None # Correction factor for flash attn (exp(prev_max-curr_max)) + + # Exp + exp_sb = None # Output of exp + exp_partial_sum = None # Exp-sum per tile + exp_section_sum = None # Exp-sum for section + exp_tp_sb = None # Transposed output of Exp (input to MM2) + exp_running_sum = None # (persistent across sections) Running sum across sections after exp + prev_exp_running_sum = ( + None # Previous running max across sections after exp (used to hold value temporarily before section update) + ) + exp_sum_reciprocal = None # (persistent across sections) Reciprocal of exp-sum, calculated in the last section + + # PV + mm2_psum = None # output of MM2 on PSUM + mm2_sb = None # accumulate output of MM2 (mm2_psum) into SBUF + mm2_prev_output = None # output from previous section, loaded from HBM to SBUF + mm2_accum_flash_attn = None # Accumulated and scaled by flash_attn_correction_factor output of MM2 across sections + mm2_final = None # Output in final section, scaled by exp_sum_reciprocal + + # Optional buffers (for tp_out=True) + tp_flash_attn_correction_factor_psum = None # nc_transpose of flash_attn_correction_factor on PSUM + tp_flash_attn_correction_factor_sb = None # transpose of flash_attn_correction_factor copied to SBUF + tp_exp_sum_reciprocal_psum = None # nc_transpose of exp_sum_reciprocal on PSUM + tp_exp_sum_reciprocal_sb = None # transpose of exp_sum_reciprocal copied to SBUF + mm2_prev_output_scaled = None # Scaled version of prev_output before accumulating + + # Shared/utility tensors + zero_bias_tensor = None # zeros, used for initialization/fallback in multiple places + sink_sb = None # sink loaded to SBUF + range_sel_lbs = None # lower bound for range_select for CP/SWA/Sequence Packing/Prefix caching + range_sel_ubs = None # upper bound for range_select for CP/SWA/Sequence Packing/Prefix caching + range_sel_lbs_prior = None # lower bound for range_select for Prefix caching + range_sel_ubs_prior = None # upper bound for range_select for Prefix caching + k_offset_sb_u32 = None # used for dynamic load for only required k/v for CP+SWA + + +@dataclass +class SectionParams(nl.NKIObject): + section_idx = None # Index of section + section_offset = None # Offset of section + section_offset_active = None # Offset of active K (adjusted by subtracting prior) + next_section_offset_active = None # Offset of active K for next section + section_contains_prefix = None # Whether current section contains prefix + next_section_contains_prefix = None # Whether next section contains prefix + + +@dataclass +class AttnTileParams(nl.NKIObject): + """Tile and buffer sizing parameters computed during implementation. + + Contains derived parameters specific to buffer allocation and tiling. + """ + + seqlen_k_active_updated: int = None # use updated value based on CP/SWA + seqlen_k_prior_padded: int = None # k_prior len padded to multiple of _K_TILE_SZ (512) + is_causal: bool = None # generally same as causal_mask, but can be modified for CP. + # Used to determine whether compute is eliminated. + + # Partition/tile sizes + sb_p: int = None # SBUF partition size (128) + + # Group parameters + num_grps: int = None + num_q_grps_per_load: int = None # load multiple q groups for better DMA efficiency + can_pack_q_load: bool = None # whether Q loads can be packed into num_q_grps_per_load + + # Tile counts per section + num_large_tiles_per_section: int = None + num_k_tiles_per_section: int = None + num_v_tiles_per_section: int = None + + # Exp parameters + exp_inst_elems: int = None # exp tile size + num_exp_insts_per_large_tile: int = None + + # Transpose and MM2 parameters + # After transpose the scores are laid out as (128,4,128) which effectively stores 4 KxQ tiles + # of 128x128 (recall K tile size is 512) + num_tps_in_mm2_grp: int = None # Number of transpose/MM2 per MM2 group (4) + mm2_grp_sz: int = None # Total free dim for MM2 group (4*128 = 512) + + # Use optimized allocation for SWA where more Q groups are allocated since each group + # only handles relatively small few K tiles + use_swa_optimized_allocation: bool = None + + # Dynamic masking - whether to use range select for masking instead of affine select. + # Required when we need runtime-determined masking (e.g., CP/prefix caching), + # or for performance reasons (e.g., SWA to avoid multiple copies with affine select) + dynamic_sel_mask: bool = None + + # Section + section_len: int = None # Length of section, typically 8k if multiple sections, else same as k seqlen + num_sections: int = None # Number of sections + + +def _compute_tile_parameters( + ac: AttnConfig, + is_seqlen_sharded: bool, +) -> AttnTileParams: + """ + Compute all tile and partition parameters for attention computation. + + Args: + ac: High-level attention configuration + is_seqlen_sharded: Whether Q sequence is sharded + + Returns: + AttnTileParams: Complete tile parameter configuration + """ + atp = AttnTileParams() + + # Validate scale parameter for special modes + if ac.use_swa or ac.is_prefix_caching or ac.use_cp: + # Only scale = 1.0 supported in these cases due to use of range select instead of TSCR + kernel_assert( + ac.scale == 1.0, + f"SWA/Prefix Caching/CP only support scale=1.0, but got {ac.scale=}", + ) + + # When we use CP, tiles are dynamically masked (mask unknown at compile time), so we turn off causal + # to disable compute-skipping. For strided Q slicing, we do not turn off causal masking since + # compute can be eliminated from the region that is masked in all ranks. + atp.is_causal = ac.causal_mask + kernel_assert(ac.causal_mask or not ac.use_cp, "CP currently only supports causal attn") + kernel_assert(ac.causal_mask or not ac.use_swa, "SWA currently only supports causal attn") + atp.dynamic_sel_mask = False + if ac.use_cp: + if not ac.cp_strided_q_slicing: + atp.is_causal = False + atp.dynamic_sel_mask = True + if ac.is_sequence_packed: + atp.dynamic_sel_mask = True + atp.seqlen_k_active_updated = ac.seqlen_k_active + atp.use_swa_optimized_allocation = ( + False # whether to allocate more q groups and fewer k tiles for exp and transpose + ) + + # Handle sliding window attention, in which case only at most (seqlen_q + sliding_window - 1) KV slice is loaded (when CP) + if ac.use_swa: + # When using SWA+CP (dynamic sbuf CP offsets), we (1) do dynamic masking with range_selects and (2) load reduced KV + # When not using CP, we apply both upper (causal) and lower (sliding window) triangular compute skipping; + # Note that the reduced KV load for CP+SWA only applies to (active) K not to K_prior. + # For K_prior, caller can choose to pass only the required KV since it not always possible + # to determine the required seqlen a priori on due to dynamic prior_used_len. + # When using strided Q slicing, we need to load entire KV due to masking pattern. + if ac.use_cp and not ac.cp_strided_q_slicing: + atp.seqlen_k_active_updated = min(ac.seqlen_k_active, ac.seqlen_q + ac.sliding_window - 1) + atp.seqlen_k_active_updated = min(ac.seqlen_k_active, align_to(atp.seqlen_k_active_updated, 512)) + else: + if ac.sliding_window <= _SWA_ALLOCATION_STRATEGY_THRESHOLD: + # use range select to save on excess copy instructions on DVE + atp.dynamic_sel_mask = True + # We only use 1 or 2 tile per group so want to overlap the groups more. + if not ac.is_prefix_caching: + # When prefix caching is enabled, we use static_range which means + # that cannot reduce number of 2048 tiles too much without causing + # data race. So we only use this optimization without prefix caching + atp.use_swa_optimized_allocation = True + + # Partition size + atp.sb_p = nl.tile_size.pmax + # assert that _Q_GRP_SZ = _V_TILE_SZ = atp.sb_p (= 128) since that is an implict assumption in the code + # and updating it requires careful updates. + kernel_assert( + _Q_GRP_SZ == atp.sb_p, + f"Internal error: expect Q group size to match SBUF partition dimension, got {_Q_GRP_SZ=}, {atp.sb_p=}", + ) + kernel_assert( + _V_TILE_SZ == atp.sb_p, + f"Internal error: expect V tile size to match SBUF partition dimension, got {_V_TILE_SZ=}, {atp.sb_p=}", + ) + + # Group configuration + atp.num_grps = div_ceil(ac.seqlen_q, atp.sb_p) + atp.can_pack_q_load = not is_seqlen_sharded + + num_q_grps_per_load_dtype = 4 if ac.dtype == nl.float32 else 8 # fewer groups for float32 for SBUF memory + atp.num_q_grps_per_load = min(num_q_grps_per_load_dtype if atp.can_pack_q_load else 1, atp.num_grps) + kernel_assert( + atp.num_q_grps_per_load > 0, + f"num_q_grps_per_load must be positive, got {atp.num_q_grps_per_load}. " + f"This occurs when num_grps={atp.num_grps} is 0 or negative. " + f"Please check that batch_size, num_heads, and sequence length parameters are positive integers.", + ) + + atp.seqlen_k_prior_padded = None + if ac.is_prefix_caching: + # Pad k_prior length to 512 because that is the loading and masking tile size + # and we don't want to mix prior/active into a single tile. + atp.seqlen_k_prior_padded = align_to(ac.seqlen_k_prior, _K_TILE_SZ) + + if ac.is_prefix_caching: + # With prefix caching we ensure every _K_TILE_SZ (512) tile is either full prior or + # fully active. Note that a section can still contain a mix of prior and + # active. The different lengths are as shown below: + # + # +------------------+---------+----+-------------------------------------+ + # | | | | | + # +------------------+---------+----+-------------------------------------+ + # <-----------------> + # prior_used_len + # (dynamic mask) + # <---------------------------> + # seqlen_k_prior + # <--------------------------------><------------------------------------> + # seqlen_k_prior_padded seqlen_k_active + # (multiple of 512) + total_seqlen_k = atp.seqlen_k_prior_padded + atp.seqlen_k_active_updated + else: + total_seqlen_k = atp.seqlen_k_active_updated + + use_flash_attn = total_seqlen_k > _FLASH_ATTENTION_THRESHOLD + if use_flash_attn: + atp.section_len = min(total_seqlen_k, _FLASH_ATTENTION_SECTION_LENGTH) + else: + atp.section_len = total_seqlen_k + + kernel_assert(atp.section_len > 0, f"section_len must be positive, got {atp.section_len}") + atp.num_sections = div_ceil(total_seqlen_k, atp.section_len) + + if not use_flash_attn: + kernel_assert( + atp.num_sections == 1, + "Logic fault, must only have 1 section if not using flash_attn", + ) + + # Tile counts per section + atp.num_large_tiles_per_section = div_ceil(atp.section_len, _LARGE_TILE_SZ) + atp.num_k_tiles_per_section = div_ceil(atp.section_len, _K_TILE_SZ) + atp.num_v_tiles_per_section = div_ceil(atp.section_len, _V_TILE_SZ) + + # K/V tile sizes for exp and transpose/MM2 + atp.exp_inst_elems = _EXP_TILE_SZ + atp.num_exp_insts_per_large_tile = _LARGE_TILE_SZ // atp.exp_inst_elems + atp.num_tps_in_mm2_grp = _K_TILE_SZ // atp.sb_p # 512 // 128 = 4 + atp.mm2_grp_sz = _K_TILE_SZ + + return atp + + +def _setup_range_select_bounds( + ac: AttnConfig, + atp: AttnTileParams, + bufs: AttnInternalBuffers, + allocator: ModularAllocator, + cp_offset: Any, + prior_used_len: Any, + bound_min: Any, + bound_max: Any, +) -> tuple: + """ + Set up range select bounds for dynamic masking (CP/SWA/prefix caching). + """ + # Populate range select lower and/or upper bounds. NOTE: both bounds are inclusive + if atp.dynamic_sel_mask: + # Populate CP offset if needed + cp_offset_sb = None + if ac.use_cp: + # Check and load CP offset, then broadcast onto all partitions + # Note that range_select only supports fp32 bounds, thus all compute for bounds here use fp32 + kernel_assert((cp_offset is not None), "cp_offset missing but global_cp_deg is provided") + kernel_assert( + (cp_offset.shape == (1, 1)), + "cp_offset shape must be (1, 1) for CP attn", + ) + cp_offset_sb = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, 1), dtype=nl.float32) + + nisa.dma_copy( + dst=cp_offset_sb[0, 0], + src=cp_offset.ap(pattern=[[1, 1], [1, 1]], offset=0), + ) + stream_shuffle_broadcast(src=cp_offset_sb, dst=cp_offset_sb) + + # Create range select upper bounds with IOTA + CP offset (if exists) + bufs.range_sel_ubs = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, atp.num_grps), dtype=nl.float32) + # fill in q positions 0 ... num_grps*_Q_GRP_SZ + # Important: need to set channel_multiplier + # Note in case of cp_strided_q_slicing we fill in q positions as 0, global_cp_deg, 2*global_cp_deg ... + if ac.use_cp and ac.cp_strided_q_slicing: + nisa.iota( + bufs.range_sel_ubs[...], + pattern=[[ac.global_cp_deg * atp.sb_p, atp.num_grps]], + channel_multiplier=ac.global_cp_deg, + ) + else: + seq_packed_non_causal = ac.is_sequence_packed and not atp.is_causal + nisa.iota( + bufs.range_sel_ubs[...], + pattern=[[atp.sb_p, atp.num_grps]], + channel_multiplier=0 if seq_packed_non_causal else 1, + offset=ac.seqlen_k_active if seq_packed_non_causal else 0, + ) + + if ac.use_cp: + nisa.tensor_scalar( + bufs.range_sel_ubs[...], + bufs.range_sel_ubs, + op0=nl.add, + operand0=cp_offset_sb, + ) + if ac.is_sequence_packed: + bufs.range_sel_lbs = allocator.alloc_sbuf_tensor( + shape=bufs.range_sel_ubs.shape, dtype=bufs.range_sel_ubs.dtype + ) + local_allocator = ModularAllocator(allocator._current_address) + tmp_buffer = local_allocator.alloc_sbuf_tensor( + shape=bufs.range_sel_ubs.shape, dtype=bufs.range_sel_ubs.dtype + ) + bound_min_reshaped = bound_min.reshape((atp.sb_p, atp.num_grps)) + bound_max_reshaped = bound_max.reshape((atp.sb_p, atp.num_grps)) + nisa.dma_copy( + dst=bufs.range_sel_lbs[...], src=bound_min_reshaped.ap([[1, atp.sb_p], [atp.sb_p, atp.num_grps]]) + ) + nisa.dma_copy(dst=tmp_buffer[...], src=bound_max_reshaped.ap([[1, atp.sb_p], [atp.sb_p, atp.num_grps]])) + nisa.tensor_tensor(dst=bufs.range_sel_ubs, data1=bufs.range_sel_ubs, data2=tmp_buffer, op=nl.minimum) + + # Create range select lower bounds for sliding window + if ac.use_swa: + if ac.is_sequence_packed: + nisa.scalar_tensor_tensor( + dst=bufs.range_sel_lbs, + data=bufs.range_sel_ubs, + op0=nl.add, + operand0=-(ac.sliding_window - 1.0), + op1=nl.maximum, + operand1=bufs.range_sel_lbs, + ) + else: + bufs.range_sel_lbs = allocator.alloc_sbuf_tensor( + shape=bufs.range_sel_ubs.shape, dtype=bufs.range_sel_ubs.dtype + ) + nisa.tensor_scalar( + bufs.range_sel_lbs, + bufs.range_sel_ubs, + op0=nl.add, + operand0=-(ac.sliding_window - 1.0), + ) + + # Setup prefix caching bounds + if ac.is_prefix_caching: + # with prefix caching, during the prior part: + # - the ubs are wrt prior_used_len [note we don't need causal mask and/or CP offset here]. + # - the lbs are used when SWA is enabled (in this case also we need to offset the bounds by prior_used_len) + # where the cp offset (if any) is already included. + # Note that we do not need the k_offset subtraction because the prior KV is loaded fully + prior_used_len_sb = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, 1), dtype=nl.float32) + nisa.dma_copy( + dst=prior_used_len_sb[0, 0], + src=prior_used_len.ap(pattern=[[1, 1], [1, 1]], offset=0), + ) + + stream_shuffle_broadcast(src=prior_used_len_sb, dst=prior_used_len_sb) + bufs.range_sel_ubs_prior = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, atp.num_grps), dtype=nl.float32) + # explicit broadcast ap needed + nisa.tensor_scalar( + bufs.range_sel_ubs_prior[...], + bufs.zero_bias_tensor.ap(pattern=[[1, atp.sb_p], [0, atp.num_grps]], offset=0), + op0=nl.add, + operand0=prior_used_len_sb, + ) + + if ac.use_swa: + bufs.range_sel_lbs_prior = allocator.alloc_sbuf_tensor( + shape=bufs.range_sel_ubs_prior.shape, + dtype=bufs.range_sel_ubs_prior.dtype, + ) + if bufs.range_sel_lbs is not None: + nisa.tensor_scalar( + bufs.range_sel_lbs_prior[...], + bufs.range_sel_lbs, + op0=nl.add, + operand0=prior_used_len_sb, + ) + else: + kernel_assert(not ac.use_cp, "CP+SWA should have dynamic_sel_mask") + # in this case the SWA mask != yet incorporated via range_sel_lbs so we + # add iota and -(sliding_window - 1.) here. This case will happen when + # sliding window is large and hence we prefer to use affine_select rather + # than range_select in active region (but still need dynamic mask for prior). + nisa.iota( + bufs.range_sel_lbs_prior[...], + pattern=[[atp.sb_p, atp.num_grps]], + channel_multiplier=1, + ) + nisa.tensor_scalar( + bufs.range_sel_lbs_prior[...], + bufs.range_sel_lbs_prior, + op0=nl.add, + operand0=prior_used_len_sb, + op1=nl.add, + operand1=-(ac.sliding_window - 1.0), + ) + + # If using SWA and CP, compute K load offset = max(0, cp_offset - sliding_window + 1) + # Also adjust range select bounds because K seqlen now does not start from 0 + if ac.use_swa and ac.use_cp and not ac.cp_strided_q_slicing: + # Find K load offset to fp32 (required dtype as tensor scalar operand) + k_offset_sb = allocator.alloc_sbuf_tensor(shape=(atp.sb_p, 1), dtype=nl.float32) + nisa.tensor_scalar( + k_offset_sb[...], + cp_offset_sb, + op0=nl.add, + operand0=-(atp.seqlen_k_active_updated - ac.seqlen_q), + op1=nl.maximum, + operand1=0.0, + ) + bufs.k_offset_sb_u32 = allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + nisa.tensor_copy(bufs.k_offset_sb_u32[0, 0], k_offset_sb[0, 0]) + + # Adjust range select bounds + nisa.tensor_scalar( + bufs.range_sel_lbs[...], + bufs.range_sel_lbs, + op0=nl.subtract, + operand0=k_offset_sb, + ) + nisa.tensor_scalar( + bufs.range_sel_ubs[...], + bufs.range_sel_ubs, + op0=nl.subtract, + operand0=k_offset_sb, + ) + + +def _allocate_attention_buffers( + allocator: ModularAllocator, + ac: AttnConfig, + atp: AttnTileParams, + bufs: AttnInternalBuffers, + sink: Any, + skip_q_alloc: bool = False, +): + """ + Allocate all SBUF and PSUM buffers needed for attention computation. + + Modifies bufs in-place by allocating and assigning all computation buffers. + + We use the modular allocator with num_free_tiles chosen in order to achieve + multi-buffering and avoid anti-dependencies. The degree of multi-buffering + along the Q group/KV tile axis is chosen based on experimentation. + + skip_q_alloc: If True, skip Q SBUF allocation. Used by the QIE-tuned variant + where Q is allocated once before the section loop with num_free_tiles equal + to its block_dim, so all Q tiles stay alive across sections. + """ + + # Define the partition and free dimension for the two matmuls + mm1_p, mm1_n = atp.sb_p, nl.tile_size.psum_fmax + mm2_p, mm2_n = atp.sb_p, ac.d + + p_k, n_k = ac.d, _K_TILE_SZ # d is reduction dim for MM1 + bufs.k_sb = allocator.alloc_sbuf_tensor( + shape=(p_k, n_k), + dtype=nl.bfloat16, + block_dim=[atp.num_k_tiles_per_section], + num_free_tiles=[atp.num_k_tiles_per_section], + align_to=32, # align for dma transpose + ) + + p_v, n_v = atp.sb_p, ac.d # d is free dim for MM2 + bufs.v_sb = allocator.alloc_sbuf_tensor( + shape=(p_v, n_v), + dtype=nl.bfloat16, + block_dim=[atp.num_v_tiles_per_section], + num_free_tiles=[atp.num_v_tiles_per_section], + ) + + if not skip_q_alloc: + bufs.q_sb = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p * atp.num_q_grps_per_load), + dtype=nl.bfloat16, + block_dim=[div_ceil(atp.num_grps, atp.num_q_grps_per_load)], + num_free_tiles=[2], + align_to=32, # align for dma transpose + ) + + bufs.flash_attn_correction_factor = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + # buffer to hold the partial row-wise maximum from mm1, if we have sink, need one more elt from sink tensor + mm1_partial_max_n_elts = atp.num_k_tiles_per_section + (sink is not None) + bufs.mm1_partial_max = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, mm1_partial_max_n_elts), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + align_to=4, + ) + + bufs.mm1_section_max = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + n_final_reduce_sum_elts = div_ceil(atp.section_len, atp.exp_inst_elems) + (sink is not None) + bufs.exp_partial_sum = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, n_final_reduce_sum_elts), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + bufs.exp_section_sum = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + bufs.prev_mm1_running_max = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + bufs.prev_exp_running_sum = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, 1), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + if ac.tp_out: + bufs.mm2_prev_output = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + else: + bufs.mm2_prev_output = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, ac.d), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + if ac.tp_out: + bufs.mm2_accum_flash_attn = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + else: + bufs.mm2_accum_flash_attn = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, ac.d), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + if ac.tp_out: + bufs.mm2_prev_output_scaled = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + # PSUM allocations for tp_flash_attn_correction_factor_psum and tp_exp_sum_reciprocal_psum + bufs.tp_flash_attn_correction_factor_psum = [] + bufs.tp_exp_sum_reciprocal_psum = [] + for grp_idx in range(atp.num_grps): + tp_flash_attn_correction_factor_psum_tile = nl.ndarray( + (ac.d, atp.sb_p), + dtype=nl.float32, + buffer=nl.psum, + address=(0, ((grp_idx % 2) * 4 + 3) * PSUM_BANK_SIZE), + ) + bufs.tp_flash_attn_correction_factor_psum.append(tp_flash_attn_correction_factor_psum_tile) + tp_exp_sum_reciprocal_psum_tile = nl.ndarray( + (ac.d, atp.sb_p), + dtype=nl.float32, + buffer=nl.psum, + address=(0, ((grp_idx % 2) * 4 + 3) * PSUM_BANK_SIZE), + ) + bufs.tp_exp_sum_reciprocal_psum.append(tp_exp_sum_reciprocal_psum_tile) + + bufs.tp_flash_attn_correction_factor_sb = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[4], + ) + + bufs.tp_exp_sum_reciprocal_sb = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p), + dtype=nl.float32, + block_dim=[atp.num_grps], + num_free_tiles=[4], + ) + + if ac.tp_out: + bufs.mm2_final = allocator.alloc_sbuf_tensor( + shape=(ac.d, atp.sb_p), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + else: + bufs.mm2_final = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, ac.d), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + if ac.tp_out: + bufs.mm2_sb = allocator.alloc_sbuf_tensor( + shape=(mm2_n, mm2_p), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + else: + bufs.mm2_sb = allocator.alloc_sbuf_tensor( + shape=(mm2_p, mm2_n), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps], + num_free_tiles=[2], + ) + + bufs.mm1_masked = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, _LARGE_TILE_SZ), + dtype=nl.float32, + block_dim=[atp.num_grps, atp.num_large_tiles_per_section], + num_free_tiles=[2, atp.num_large_tiles_per_section], + ) + + bufs.exp_sb = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, _LARGE_TILE_SZ), + dtype=nl.bfloat16, + block_dim=[atp.num_grps, atp.num_large_tiles_per_section], + num_free_tiles=([4, 2] if atp.use_swa_optimized_allocation else [1, atp.num_large_tiles_per_section]), + ) + + # mm1_psum PSUM allocation + bufs.mm1_psum = [] + for grp_idx in range(atp.num_grps): + grp_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + tile_row = [] + for k_tile_idx in range(4): + mm1_psum_tile = nl.ndarray( + (mm1_p, mm1_n), + dtype=ac.mm_out_dtype, + buffer=nl.psum, + address=(0, (k_tile_idx % 4) * PSUM_BANK_SIZE), + ) + tile_row.append(mm1_psum_tile) + grp_row.append(tile_row) + bufs.mm1_psum.append(grp_row) + + if not atp.dynamic_sel_mask: + bufs.mm1_copy_sb = allocator.alloc_sbuf_tensor( + shape=(mm1_p, mm1_n), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps, atp.num_large_tiles_per_section, 4], + num_free_tiles=[1, 1, 2], + ) + + bufs.mm1_affine_select_output = allocator.alloc_sbuf_tensor( + shape=(mm1_p, mm1_n), + dtype=ac.mm_out_dtype, + block_dim=[atp.num_grps, atp.num_large_tiles_per_section, 4], + num_free_tiles=[1, 1, 2], + ) + + bufs.exp_tp_sb = allocator.alloc_sbuf_tensor( + shape=(atp.sb_p, atp.mm2_grp_sz), + dtype=nl.bfloat16, + block_dim=[ + atp.num_grps, + atp.num_large_tiles_per_section, + atp.num_tps_in_mm2_grp, + ], + num_free_tiles=( + [4, 2, atp.num_tps_in_mm2_grp] + if atp.use_swa_optimized_allocation + else [2, atp.num_large_tiles_per_section, atp.num_tps_in_mm2_grp] + ), + align_to=32, # align for dma transpose + ) + + # mm2_psum allocation + bufs.mm2_psum = [] + for grp_idx in range(atp.num_grps): + grp_row = [] + for large_tile_idx in range(atp.num_large_tiles_per_section): + if ac.tp_out: + mm2_psum_tile = nl.ndarray( + (mm2_n, mm2_p), + dtype=ac.mm_out_dtype, + buffer=nl.psum, + address=(0, ((4 + (large_tile_idx % 4)) * PSUM_BANK_SIZE)), + ) + else: + mm2_psum_tile = nl.ndarray( + (mm2_p, mm2_n), + dtype=ac.mm_out_dtype, + buffer=nl.psum, + address=(0, ((4 + (large_tile_idx % 4)) * PSUM_BANK_SIZE)), + ) + grp_row.append(mm2_psum_tile) + bufs.mm2_psum.append(grp_row) + + +def _q_to_kv_batch_id(batch_id: int, bs: int, bs_kv: int) -> int: + """Map Q batch id to KV batch id for native GQA support. + + Currently we implement native GQA support by simply using the correct KV batch id + corresponding to the Q batch id but not attempting to optimize the KV loads themselves. + We still get the benefit of not needing to replicate the KV before calling the kernel. + + Example: bs=6, bs_kv=2: mapping from batch_id_q -> batch_id_kv: + {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1} + """ + return batch_id // (bs // bs_kv) + + +def _check_input_and_return_shape( + q, + k, + v, + is_prefix_caching, + k_prior, + v_prior, + prior_used_len, + tp_q, + tp_k, + tp_out, + cache_softmax, +) -> tuple: + """Validate input tensor shapes and compute output shapes. + + Check the shape of inputs base on kernel_name, and return the tuple, + (seqlen_q, seqlen_k, seqlen_k_prior, d, out_shape, cache_softmax_shape) + """ + if is_prefix_caching: + kernel_assert( + prior_used_len.shape == (1,), + "Received unexpected shape for prior_used_len. " + f"Expected (1,), received {prior_used_len.shape}. " + "User note: prefix caching expects a single " + "prior_used_len meaning it cannot be used " + "if multiple requests (batch) are used with different " + "prior_used_len values.", + ) + assert_shape( + prior_used_len, + (1,), + "prior_used_len", + error_text="User note: prefix caching expects a single " + "prior_used_len meaning it cannot be used " + "if multiple requests (batch) are used with different " + "prior_used_len values.", + ) + + if tp_q: + batch_size, seqlen_q, d = q.shape + else: + batch_size, d, seqlen_q = q.shape + seqlen_k_dim = 1 if tp_k else 2 + seqlen_k = k.shape[seqlen_k_dim] + batch_size_kv = k.shape[0] + if tp_k: + assert_shape(k, (batch_size_kv, seqlen_k, d), "k") + else: + assert_shape(k, (batch_size_kv, d, seqlen_k), "k") + assert_shape(v, (batch_size_kv, seqlen_k, d), "v") + if is_prefix_caching: + seqlen_k_prior = k_prior.shape[seqlen_k_dim] + if tp_k: + assert_shape(k_prior, (batch_size_kv, seqlen_k_prior, d), "k_prior") + else: + assert_shape(k_prior, (batch_size_kv, d, seqlen_k_prior), "k_prior") + assert_shape(v_prior, (batch_size_kv, seqlen_k_prior, d), "v_prior") + else: + seqlen_k_prior = None + + out_seqlen = seqlen_q + + if tp_out: + out_shape = (batch_size, d, out_seqlen) + else: + out_shape = (batch_size, out_seqlen, d) + + # compute the shape for cached softmax tensors, negative max and recipriocal sum + cache_softmax_tile_size = 128 + if cache_softmax: + # Current testing/golden does not account for the padded portion properly + kernel_assert( + seqlen_q % cache_softmax_tile_size == 0, + f"For cache softmax, attention_cte currently expects seqlen_q multiple of {cache_softmax_tile_size}, got {seqlen_q=}", + ) + padded_seq_grps = div_ceil(out_seqlen, cache_softmax_tile_size) + cache_softmax_shape = [batch_size, cache_softmax_tile_size, padded_seq_grps] + + return (seqlen_q, seqlen_k, seqlen_k_prior, d, out_shape, cache_softmax_shape) + + +def _load_q_tile(q, out, tp_q, batch_id, grp_i, seqlen_offset, load_dtype, sbuf_addr, grps_per_load=1) -> int: + """Load Q tile from HBM to SBUF, handling transpose based on tp_q flag. + + When tp_q is True, assumes q = (bs, seqlen, d). + + When tp_q is False, assumes q = (bs, d, seqlen), will perform transpose then save into out. + + out[grp_i // grps_per_load] has shape (d, 128 * grps_per_load) + """ + kernel_assert(str(load_dtype) == str(out[0].dtype), "Conflicting dtype") + local_allocator = ModularAllocator(initial_address=sbuf_addr) + if tp_q: + _, seqlen, d = q.shape + if grps_per_load > 1: + # Use DMA transpose + num_p = min(seqlen - seqlen_offset, _Q_GRP_SZ * grps_per_load) + + # TODO: fix below check once proper type is available for tensor dtype + if str(q.dtype) == str(nl.bfloat16): + nisa.dma_transpose( + dst=out[grp_i // grps_per_load].ap([[_Q_GRP_SZ * grps_per_load, d], [1, 1], [1, 1], [1, num_p]]), + src=q.ap( + [[d, num_p], [1, 1], [1, 1], [1, d]], + offset=batch_id * seqlen * d + seqlen_offset * d, + ), + ) + else: + # Need a buffer with same dtype as q as dma_transpose requires same I/O dtype + buffer = local_allocator.alloc_sbuf_tensor( + shape=(d, _Q_GRP_SZ * grps_per_load), + dtype=q.dtype, + align_to=32, # align for dma transpose + ) + + nisa.dma_transpose( + dst=buffer.ap([[_Q_GRP_SZ * grps_per_load, d], [1, 1], [1, 1], [1, num_p]]), + src=q.ap( + [[d, num_p], [1, 1], [1, 1], [1, d]], + offset=batch_id * seqlen * d + seqlen_offset * d, + ), + ) + + nisa.tensor_copy(out[grp_i // grps_per_load][:, :num_p], buffer[:, :num_p]) + else: + # Use NC transpose + kernel_assert( + grps_per_load == 1, + "tp Q on Trn1/shard on seqlen does not yet support packed load", + ) + loaded = local_allocator.alloc_sbuf_tensor(shape=(_Q_GRP_SZ, d), dtype=load_dtype) + tp_dt = load_dtype + psum_buf = nl.ndarray((d, _Q_GRP_SZ), dtype=tp_dt, buffer=nl.psum, address=(0, 0)) + + num_p = min(seqlen - seqlen_offset, _Q_GRP_SZ) + # Convert load() to access pattern + # Original: load(dst=loaded[nl.ds(0, num_p), :], src=q[batch_id, nl.ds(seqlen_offset, _Q_GRP_SZ), 0:d], dtype=load_dtype) + # q shape: (bs, seqlen, d), accessing q[batch_id, seqlen_offset:seqlen_offset+_Q_GRP_SZ, 0:d] + # Pattern: [[d, num_p], [1, d]] + # Offset: batch_id*seqlen*d + seqlen_offset*d + loaded_dst_pat = loaded.ap(pattern=[[d, num_p], [1, d]], offset=0) + q_src_pat = q.ap( + pattern=[[d, num_p], [1, d]], + offset=batch_id * seqlen * d + seqlen_offset * d, + ) + nisa.dma_copy(dst=loaded_dst_pat, src=q_src_pat) + + nisa.nc_transpose(psum_buf[:d, :num_p], loaded[:num_p, :d]) + num_f = min(seqlen - seqlen_offset, _Q_GRP_SZ) + nisa.tensor_copy(out[grp_i][:d, :num_f], psum_buf[:d, :num_f]) + else: + _, d, seqlen = q.shape + + num_f = min(seqlen - seqlen_offset, _Q_GRP_SZ * grps_per_load) + # Convert load() to access pattern + # Original: load(dst=out[grp_i // grps_per_load][nl.ds(0, d), nl.ds(0, num_f)], src=q[batch_id, nl.ds(0, d), nl.ds(seqlen_offset, num_f)], dtype=load_dtype) + # q shape: (bs, d, seqlen), accessing q[batch_id, 0:d, seqlen_offset:seqlen_offset+num_f] + # Pattern: [[seqlen, d], [1, num_f]] + # Offset: batch_id*d*seqlen + seqlen_offset + out_dst_pat = out[grp_i // grps_per_load].ap(pattern=[[_Q_GRP_SZ * grps_per_load, d], [1, num_f]], offset=0) + q_src_pat = q.ap( + pattern=[[seqlen, d], [1, num_f]], + offset=batch_id * d * seqlen + seqlen_offset, + ) + nisa.dma_copy(dst=out_dst_pat, src=q_src_pat) + + +def _get_kv_tile_apc( + is_prefix_caching, + k_active, + k_prior, + seqlen_active, + seqlen_prior, + seqlen_offset, + load_offset_active, +) -> tuple: + """Determine which KV tensor (active or prior) to use based on sequence offset. + + Get information about KV tile (used for Prefix Caching) + """ + if not is_prefix_caching: + return k_active, seqlen_active, seqlen_offset, load_offset_active + else: + seqlen_prior_padded = align_to(seqlen_prior, _K_TILE_SZ) + if seqlen_offset >= seqlen_prior_padded: + return ( + k_active, + seqlen_active, + seqlen_offset - seqlen_prior_padded, + load_offset_active, + ) + else: + return ( + k_prior, + seqlen_prior, + seqlen_offset, + None, # no load_offset used for prior + ) + + +def _load_k_tile( + k_active, + k_prior, + out, + batch_id, + sp: SectionParams, + load_dtype, + tp_k, + num_tiles, + sbuf_addr, + load_offset_active=None, +) -> int: + """Load K tiles from HBM to SBUF in _K_TILE_SZ (512)-element chunks, handling transpose and prefix caching. + + k has shape + (bs, d, seqlen) when tp_k=False + (bs, seqlen, d) when tp_k=True, i.e. a transpose is performed + k_prior (if passed) has shape identical to k except for the seqlen. + Return shape of out[i] is (d, _K_TILE_SIZE) where i = 0..num_k_tiles_per_section + """ + if tp_k: + _, seqlen_active, _ = k_active.shape + else: + _, _, seqlen_active = k_active.shape + seqlen_prior = None + is_prefix_caching = k_prior is not None + if is_prefix_caching: + if tp_k: + _, seqlen_prior, _ = k_prior.shape + else: + _, _, seqlen_prior = k_prior.shape + if num_tiles > 0: + d, n = out[0].shape + sb_p = nl.tile_size.pmax + stride_f = _K_TILE_SZ + + kernel_assert(n == _K_TILE_SZ, f"expect to load in tile of size {_K_TILE_SZ=}") + kernel_assert(str(load_dtype) == str(out[0].dtype), "load dtype mismatch") + local_allocator = ModularAllocator(initial_address=sbuf_addr) + sbuf_addr_max = sbuf_addr + if tp_k: + sbuf_addr_outer = local_allocator.get_current_address() + for tile in range(num_tiles): + local_allocator.set_current_address(sbuf_addr_outer) + + # for APC we need to use either k or k_prior and appropriately adjust sequence length and other quantities + k, seqlen, seqlen_offset, load_offset = _get_kv_tile_apc( + is_prefix_caching, + k_active, + k_prior, + seqlen_active, + seqlen_prior, + sp.section_offset + tile * n, + load_offset_active, + ) + + if seqlen_offset >= seqlen: + # since we always use section_len/512 tiles even for last section + # we might exit early + return sbuf_addr_max + + use_dma_tp = load_offset is None # cannot use dma tp when using dynamic offset + if use_dma_tp: + num_p = min(seqlen - seqlen_offset, n) + # TODO: fix below check once proper type is available for tensor dtype + if str(k.dtype) == str(nl.bfloat16): + nisa.dma_transpose( + dst=out[tile].ap([[n, d], [1, 1], [1, 1], [1, num_p]]), + src=k.ap( + [[d, num_p], [1, 1], [1, 1], [1, d]], + offset=batch_id * seqlen * d + seqlen_offset * d, + ), + ) + else: + # Need a buffer with same dtype as k as dma_transpose requires same I/O dtype + buffer = local_allocator.alloc_sbuf_tensor( + shape=(d, n), + dtype=k.dtype, + align_to=32, # align for dma transpose + ) + sbuf_addr_max = max(sbuf_addr_max, local_allocator.get_current_address()) + nisa.dma_transpose( + dst=buffer.ap([[n, d], [1, 1], [1, 1], [1, num_p]]), + src=k.ap( + [[d, num_p], [1, 1], [1, 1], [1, d]], + offset=batch_id * seqlen * d + seqlen_offset * d, + ), + ) + + nisa.tensor_copy(out[tile][:, :num_p], buffer[:, :num_p]) + else: # not use_dma_tp + num_pe_tps = _K_TILE_SZ // sb_p # number of transposes (4) + loaded = local_allocator.alloc_sbuf_tensor(shape=(sb_p, num_pe_tps, d), dtype=load_dtype) + sbuf_addr_max = max(sbuf_addr_max, local_allocator.get_current_address()) + tp_dt = load_dtype + psum_buf = nl.ndarray((d, num_pe_tps, sb_p), dtype=tp_dt, buffer=nl.psum, address=(0, 0)) + + if load_offset is not None: + # NOTE: NKI is incapable of handling both dynamic and constant offsets in IndirectLoad, also it cannot handle + # dynamic offset used with more than one axis, so we must use four DMAs + for tp_idx in range(num_pe_tps): + num_p = min(seqlen - seqlen_offset - tp_idx * sb_p, sb_p) + if num_p > 0: + ind_offset = local_allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + sbuf_addr_max = max(sbuf_addr_max, local_allocator.get_current_address()) + nisa.tensor_scalar( + ind_offset, + load_offset, + nl.add, + seqlen_offset + tp_idx * sb_p, + ) + loaded_dst_pat = loaded.ap( + pattern=[[num_pe_tps * d, num_p], [1, d]], + offset=tp_idx * d, + ) + k_src_pat = k.ap( + pattern=[[d, num_p], [1, d]], + scalar_offset=ind_offset, + offset=batch_id * seqlen * d, + indirect_dim=1, + ) + nisa.dma_copy(dst=loaded_dst_pat, src=k_src_pat) + else: # not load_offset + # Use strided load to load four tiles of [128, d] + num_inner_f = min(div_ceil(seqlen - seqlen_offset, sb_p), num_pe_tps) + num_p = min(seqlen - seqlen_offset - num_inner_f * num_pe_tps, sb_p) + + # Convert load() to access pattern with 2D mask + # Original: load(dst=loaded[...], src=k[batch_id, seqlen_offset + i_b*128 + i_p, i_f], mask=i_b*128+i_p < seqlen-seqlen_offset) + if seqlen_offset < seqlen: + # case 1: handle rectangular + # Offset: batch_id*seqlen*d + seqlen_offset*d + num_inner_f = min(num_pe_tps, (seqlen - seqlen_offset) // sb_p) + num_p = sb_p + loaded_dst_pat = loaded.ap( + pattern=[[num_pe_tps * d, num_p], [d, num_inner_f], [1, d]], + offset=0, + ) + + k_src_pat = k.ap( + pattern=[[d, num_p], [d * sb_p, num_inner_f], [1, d]], + offset=batch_id * seqlen * d + seqlen_offset * d, + ) + nisa.dma_copy(dst=loaded_dst_pat, src=k_src_pat) + # case 2: handle last row + if num_inner_f < num_pe_tps and (seqlen - seqlen_offset) % sb_p != 0: + num_p = min( + sb_p, + seqlen - seqlen_offset - (seqlen - seqlen_offset) // sb_p * sb_p, + ) + offset = batch_id * seqlen * d + seqlen_offset * d + num_inner_f * sb_p * d + loaded_dst_pat = loaded.ap( + pattern=[[num_pe_tps * d, num_p], [1, d]], + offset=num_inner_f * d, + ) + k_src_pat = k.ap(pattern=[[d, num_p], [1, d]], offset=offset) + nisa.dma_copy(dst=loaded_dst_pat, src=k_src_pat) + + if seqlen_offset < seqlen: + # Transpose loaded[128, 4, d] with four PE transposes + for tp_idx in range(num_pe_tps): + num_p = min(seqlen - seqlen_offset - tp_idx * sb_p, sb_p) + if num_p > 0: + nisa.nc_transpose( + psum_buf[:d, tp_idx, :num_p], + loaded[:num_p, tp_idx, :d], + ) + + # Copy out transposed results + num_f = min(seqlen - seqlen_offset, n) + nisa.tensor_copy( + out[tile][:d, :num_f], + psum_buf.reshape((d, _K_TILE_SZ))[:d, :num_f], + ) + else: # no tp k + for tile in range(num_tiles): + # for APC we need to use either k or k_prior and appropriately adjust sequence length and other quantities + k, seqlen, seqlen_offset, load_offset = _get_kv_tile_apc( + is_prefix_caching, + k_active, + k_prior, + seqlen_active, + seqlen_prior, + sp.section_offset + tile * n, + load_offset_active, + ) + + num_f = min(seqlen - seqlen_offset, n) + if num_f > 0: + if load_offset is not None: + # NOTE: NKI is incapable of handling both dynamic and constant offsets in IndirectLoad + ind_offset = local_allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + sbuf_addr_max = max(sbuf_addr_max, local_allocator.get_current_address()) + nisa.tensor_scalar(ind_offset, load_offset, nl.add, seqlen_offset) + + out_dst_pat = out[tile].ap(pattern=[[stride_f, d], [1, num_f]], offset=0) + k_src_pat = k.ap( + pattern=[[seqlen, d], [1, num_f]], + scalar_offset=ind_offset, + offset=batch_id * d * seqlen, + indirect_dim=2, + ) + nisa.dma_copy(dst=out_dst_pat, src=k_src_pat) + else: + # Convert load() to access pattern + # Original: load(dst=out[i][nl.ds(0, d), nl.ds(0, num_f)], src=k[batch_id, nl.ds(0, d), nl.ds(seqlen_offset, num_f)], dtype=load_dtype) + # k shape: (bs, d, seqlen), accessing k[batch_id, 0:d, seqlen_offset:seqlen_offset+num_f] + # Pattern: [[seqlen, d], [1, num_f]] + # Offset: batch_id*d*seqlen + seqlen_offset + out_dst_pat = out[tile].ap(pattern=[[stride_f, d], [1, num_f]], offset=0) + k_src_pat = k.ap( + pattern=[[seqlen, d], [1, num_f]], + offset=batch_id * d * seqlen + seqlen_offset, + ) + + nisa.dma_copy(dst=out_dst_pat, src=k_src_pat) + + return sbuf_addr_max + + +def _load_v_tile( + v_active, + v_prior, + out, + batch_id, + sp: SectionParams, + load_dtype, + num_tiles, + sbuf_addr, + load_offset_active=None, +) -> int: + """Load V tiles from HBM to SBUF in _V_TILE_SZ (128)-element chunks, handling prefix caching. + + - v of shape (bs, seqlen, d). + - out[i] has shape (_V_TILE_SZ, d) where i = 0..num_v_tiles_per_section + - v_prior (if passed) has shape identical to v except for the seqlen. + """ + local_allocator = ModularAllocator(initial_address=sbuf_addr) + _, seqlen_active, _ = v_active.shape + seqlen_prior = None + is_prefix_caching = v_prior is not None + if is_prefix_caching: + _, seqlen_prior, _ = v_prior.shape + if num_tiles > 0: + p, n = out[0].shape + + d = n + kernel_assert(str(load_dtype) == str(out[0].dtype), "load dtype mismatch") + + for tile in range(num_tiles): + v, seqlen, seqlen_offset, load_offset = _get_kv_tile_apc( + is_prefix_caching, + v_active, + v_prior, + seqlen_active, + seqlen_prior, + sp.section_offset + p * tile, + load_offset_active, + ) + num_p = min(seqlen - seqlen_offset, p) + if num_p > 0: + if load_offset is not None: + # NOTE: NKI is incapable of handling both dynamic and constant offsets in IndirectLoad + ind_offset = local_allocator.alloc_sbuf_tensor(shape=(1, 1), dtype=nl.uint32) + nisa.tensor_scalar(ind_offset, load_offset, nl.add, seqlen_offset) + out_dst_pat = out[tile].ap(pattern=[[n, num_p], [1, n]], offset=0) + v_src_pat = v.ap( + pattern=[[d, num_p], [1, n]], + scalar_offset=ind_offset, + offset=batch_id * seqlen * d, + indirect_dim=1, + ) + nisa.dma_copy(dst=out_dst_pat, src=v_src_pat) + else: + # Convert load() to access pattern + # Original: load(dst=out[i][nl.ds(0, num_p), nl.ds(0, n)], src=v[batch_id, nl.ds(seqlen_offset, num_p), nl.ds(0, n)], dtype=load_dtype) + # v shape: (bs, seqlen, d), accessing v[batch_id, seqlen_offset:seqlen_offset+num_p, 0:n] + # Pattern: [[d, num_p], [1, n]] + # Offset: batch_id*seqlen*d + seqlen_offset*d + out_dst_pat = out[tile].ap(pattern=[[n, num_p], [1, n]], offset=0) + v_src_pat = v.ap( + pattern=[[d, num_p], [1, n]], + offset=batch_id * seqlen * d + seqlen_offset * d, + ) + nisa.dma_copy(dst=out_dst_pat, src=v_src_pat) + + return local_allocator.get_current_address() + + +def _compute_section_offset_active(k_section_offset, is_prefix_caching, seqlen_k_prior_padded) -> tuple: + """Compute active K offset and determine if section contains prefix data for proper masking.""" + section_contains_prefix = False + k_section_offset_active = k_section_offset + if is_prefix_caching: + # in prefix caching case, for masking we fall back to causal=False case + # unless we are in prior portion. Even for active portion, we need to + # adjust the offset. + if k_section_offset >= seqlen_k_prior_padded: + k_section_offset_active = k_section_offset - seqlen_k_prior_padded + else: + section_contains_prefix = True + return k_section_offset_active, section_contains_prefix + + +def _load_q_impl( + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, + q, + batch_id, + sbuf_addr, +): + """Load Q group from HBM to SBUF if compute is needed for this group.""" + # Only load every num_q_grps_per_load grps + if grp_i % atp.num_q_grps_per_load == 0: + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac, atp.num_q_grps_per_load) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if has_any_compute_pred: + q_seqlen_offset = grp_i * _Q_GRP_SZ + _load_q_tile( + q, + bufs.q_sb, + ac.tp_q, + batch_id, + grp_i, + q_seqlen_offset, + bufs.q_sb[0].dtype, + sbuf_addr, + grps_per_load=atp.num_q_grps_per_load, + ) + + +def _qk_and_max_impl( + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, +): + """Compute QK^T matmul (MM1) and find row-wise maximum for this Q group. + Also apply masking if relevant. + """ + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if has_any_compute_pred: + nisa.memset(bufs.mm1_partial_max[grp_i], value=_FLOAT32_MIN) + + for large_tile_idx in range(atp.num_large_tiles_per_section): + _qk_and_max_large_tile_impl(grp_i, large_tile_idx, ac, atp, sp, bufs) + + +def _update_max_impl( + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, + sink, +): + """Update running maximum across sections and compute flash attention correction factor.""" + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if not has_any_compute_pred: + return + + # Step 1: Compute section max + # If we have sink, need to include it in final max compute + if (sink is not None) and (sp.section_idx == 0): + nisa.tensor_copy(bufs.mm1_partial_max[grp_i][:, atp.num_k_tiles_per_section], bufs.sink_sb) + + nisa.tensor_reduce( + bufs.mm1_section_max[grp_i][:, 0], + nl.maximum, + bufs.mm1_partial_max[grp_i], + 1, + negate=True, + ) + + # Step 2: compute and store running max, and flash attention correction factor + if atp.num_sections != 1: + if sp.section_idx == 0: + nisa.tensor_copy(bufs.mm1_running_max[:, grp_i], bufs.mm1_section_max[grp_i]) + nisa.memset(bufs.flash_attn_correction_factor[grp_i][...], value=0.0) + if sp.section_idx > 0: + nisa.activation( + bufs.prev_mm1_running_max[grp_i][...], + nl.copy, + bufs.mm1_running_max[:, grp_i], + scale=-1.0, + bias=bufs.zero_bias_tensor, + ) + nisa.tensor_tensor( + bufs.mm1_running_max[:, grp_i], + bufs.mm1_running_max[:, grp_i], + bufs.mm1_section_max[grp_i], + op=nl.minimum, + ) + nisa.activation( + bufs.flash_attn_correction_factor[grp_i][:, 0], + nl.exp, + bufs.prev_mm1_running_max[grp_i], + bias=bufs.mm1_running_max[:, grp_i], + scale=1.0, + ) + else: + nisa.tensor_copy(bufs.mm1_running_max[:, grp_i], bufs.mm1_section_max[grp_i]) + + +def _exp_impl( + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, + sink, +): + """Compute exponential of masked QK scores, accumulate sum, and perform transpose (required for MM2).""" + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if not has_any_compute_pred: + return + + q_seqlen_offset = grp_i * atp.sb_p + nisa.memset(bufs.exp_partial_sum[grp_i][...], value=0.0) + + for large_tile_idx in range(atp.num_large_tiles_per_section): + kernel_assert( + atp.exp_inst_elems == 512, "Internal validation failed." + ) # prefix caching code assumes this currently, if we increase tile size to 2048, we will need to update logic + + for exp_tile_idx in range(atp.num_exp_insts_per_large_tile): + is_prior_tile, seqlen_k, k_start_pos, _ = _get_kv_tile_apc( + ac.is_prefix_caching, + False, + True, + atp.seqlen_k_active_updated, + ac.seqlen_k_prior, + sp.section_offset + large_tile_idx * _LARGE_TILE_SZ + exp_tile_idx * atp.exp_inst_elems, + None, + ) + num_p = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + num_f = min(seqlen_k - k_start_pos, atp.exp_inst_elems) + + q_start_pos = grp_i * _Q_GRP_SZ + # Only produce matmul if the tile is in the lower triangle, use tile bot-left corner so adjust q + if atp.is_causal and not is_prior_tile: + exp_sel_mask = _has_any_compute_causal(grp_i, k_start_pos, ac) + else: + exp_sel_mask = True + + # If using SWA, also skip bot-left lower triangle. + if ac.use_swa and atp.is_causal and not is_prior_tile: + # Use tile top-right corner so adjust k; also adjust q for sliding window + exp_sel_mask = exp_sel_mask and _has_any_compute_swa(grp_i, k_start_pos, atp.exp_inst_elems, ac) + + if exp_sel_mask and seqlen_k > k_start_pos: + # Step 1: Compute exponential + nisa.activation_reduce( + bufs.exp_sb[grp_i][large_tile_idx][:num_p, nl.ds(exp_tile_idx * atp.exp_inst_elems, num_f)], + op=nl.exp, + data=bufs.mm1_masked[grp_i][large_tile_idx][ + :num_p, nl.ds(exp_tile_idx * atp.exp_inst_elems, num_f) + ], + reduce_op=nl.add, + reduce_res=bufs.exp_partial_sum[grp_i][ + :num_p, + large_tile_idx * atp.num_exp_insts_per_large_tile + exp_tile_idx, + ], + bias=bufs.mm1_running_max[:num_p, grp_i], + ) + + # Step 2: Perform DMA transpose + num_f_outer = num_f // atp.sb_p + num_f_inner = num_f % atp.sb_p + # split dma_transpose into two parts to satisfy API since we have both Q and K sequence masking + # Focusing on exp_tp_sb which is arranged as [128, 4, 128] where each of the 4 [128, 128] blocks + # share the same Q seqlen (on free dim) and cover 4 tiles of K seqlen (partition dim) + # First region, we have num_f_outer [128, 128] blocks each having full partition dim (K) and each + # accessing num_p (<128) on the free dim (Q). + # Second region, we handle the remaining K (num_f_inner) - here we have the (num_f_outer+1)th [128,128] + # block being utilized with num_f_inner access on partition dim and num_p on the free dim. + + # Example: num_f_outer = 3, num_f_inner = 33, num_p = 100 + # Region 1: AP: [[512, 128], [128, 3], [1, 100]] => a, b, c = np.mgrid[0:128, 0:3, 0:100] + # Region 2: AP: [[512, 33], [128, 1], [1, 100]] => a, b, c = np.mgrid[0:33, 0:1, 0:100] with offset 128 * 3 + + # NOTE: we add the [1,1] because we need 4 dims for dma_transpose + + # Case 1: handle 0:128x + if num_f_outer >= 1: + nisa.dma_transpose( + dst=bufs.exp_tp_sb[grp_i][large_tile_idx][exp_tile_idx].ap( + [ + [atp.mm2_grp_sz, atp.sb_p], + [1, 1], + [atp.sb_p, num_f_outer], + [1, num_p], + ] + ), + src=bufs.exp_sb[grp_i][large_tile_idx].ap( + [ + [_LARGE_TILE_SZ, num_p], + [1, 1], + [atp.sb_p, num_f_outer], + [1, atp.sb_p], + ], + offset=exp_tile_idx * atp.mm2_grp_sz, + ), + ) + + # Case 2: handle num_f - 128x + if num_f_inner > 0: + nisa.dma_transpose( + dst=bufs.exp_tp_sb[grp_i][large_tile_idx][exp_tile_idx].ap( + [ + [atp.mm2_grp_sz, num_f_inner], + [1, 1], + [atp.sb_p, 1], + [1, num_p], + ], + offset=num_f_outer * atp.sb_p, + ), + src=bufs.exp_sb[grp_i][large_tile_idx].ap( + [ + [_LARGE_TILE_SZ, num_p], + [1, 1], + [atp.sb_p, 1], + [1, num_f_inner], + ], + offset=exp_tile_idx * atp.mm2_grp_sz + num_f_outer * atp.sb_p, + ), + ) + + # If there is sink, subtract max from it, then take its exp, then append it to sums + if (sink is not None) and (sp.section_idx == 0): + frs_sink_idx = bufs.exp_partial_sum[grp_i].shape[-1] - 1 + nisa.activation( + bufs.exp_partial_sum[grp_i][:, frs_sink_idx], + op=nl.exp, + data=bufs.sink_sb, + bias=bufs.mm1_running_max[:, grp_i], + ) + + +def _pv_impl( + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, +): + """Compute score@value matmul (P@V, MM2) for this Q group.""" + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if has_any_compute_pred: + nisa.memset(bufs.mm2_sb[grp_i][...], value=0.0) + + for large_tile_idx in range(atp.num_large_tiles_per_section): + _pv_large_tile_impl(grp_i, large_tile_idx, ac, atp, sp, bufs) + + +def _fused_qkmax_and_pv_impl( + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, +): + """Fused implementation computing QK+max for group i+2 while computing PV for group i (software pipelining).""" + qkmax_grp = grp_i + 2 + + has_any_compute_pred_pv = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + has_any_compute_pred_qkmax = ( + _has_any_compute_causal(qkmax_grp, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + if has_any_compute_pred_qkmax: + nisa.memset(bufs.mm1_partial_max[qkmax_grp][...], value=_FLOAT32_MIN) + + for large_tile_idx in range(atp.num_large_tiles_per_section): + if has_any_compute_pred_pv: + _pv_large_tile_impl(grp_i, large_tile_idx, ac, atp, sp, bufs) + + if has_any_compute_pred_qkmax: + _qk_and_max_large_tile_impl(qkmax_grp, large_tile_idx, ac, atp, sp, bufs) + + +def _write_back_impl( + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, + o, + batch_id, +): + """Finalize output using flash attention: accumulate across sections, apply softmax normalization, write to HBM.""" + has_any_compute_pred = ( + _has_any_compute_causal(grp_i, sp.section_offset_active, ac) + if (atp.is_causal and not sp.section_contains_prefix) + else True + ) + # if we have compute for this section but not the next section, + # then this is the last section to have compute, and we need to + # execute the final section logic to write final results + next_has_compute_pred = ( + _has_any_compute_causal(grp_i, sp.next_section_offset_active, ac) + if (atp.is_causal and not sp.next_section_contains_prefix) + else True + ) + is_last_section_with_compute = has_any_compute_pred and (not next_has_compute_pred) + + if not has_any_compute_pred: + return + + # Step 1: Compute/update exp-sum and its reciprocal + q_seqlen_offset = grp_i * atp.sb_p + nisa.tensor_reduce(bufs.exp_section_sum[grp_i][...], nl.add, bufs.exp_partial_sum[grp_i], axis=1) + if atp.num_sections != 1: + if sp.section_idx == 0: + nisa.tensor_copy(bufs.exp_running_sum[:, grp_i], bufs.exp_section_sum[grp_i]) + if sp.section_idx > 0: + nisa.tensor_copy( + bufs.prev_exp_running_sum[grp_i][...], + bufs.exp_running_sum[:, grp_i], + ) + nisa.tensor_scalar( + bufs.exp_running_sum[:, grp_i], + bufs.prev_exp_running_sum[grp_i][:, 0], + nl.multiply, + bufs.flash_attn_correction_factor[grp_i], + op1=nl.add, + operand1=bufs.exp_section_sum[grp_i], + ) + if (sp.section_idx == atp.num_sections - 1) or is_last_section_with_compute: + nisa.reciprocal( + bufs.exp_sum_reciprocal[:, grp_i], + bufs.exp_running_sum[:, grp_i], + ) + else: + nisa.reciprocal(bufs.exp_sum_reciprocal[:, grp_i], bufs.exp_section_sum[grp_i]) + + num_p = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + num_f = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + + """ + Step 2: + - if first section: + - if last section with compute: + - write to HBM after scaling by reciprocal + - else: + - write to HBM + - else: + - Load previous output and apply flash attention correction and accumulate + - if last section with compute: + - write to HBM after scaling by reciprocal + - write to HBM + """ + if atp.num_sections != 1: + if sp.section_idx == 0: + if is_last_section_with_compute: + # Last section, so scale by reciprocal and write current output to HBM + _scale_reciprocal_write_back_impl(bufs.mm2_sb[grp_i], grp_i, ac, atp, bufs, o, batch_id, num_p, num_f) + else: + # Not last section, so just write current output to HBM + _write_back_o_impl(bufs.mm2_sb[grp_i], grp_i, ac, atp, o, batch_id, num_p, num_f) + + if sp.section_idx > 0: + # Load previous output scale by flash_attn_correction_factor and accumulate + if ac.tp_out: + # Original: load(dst=mm2_prev_output[grp_i], src=o[batch_id, nl.ds(0, d), nl.ds(grp_i*sb_p, 128)], mask=mm2_sbuf_f_mask) + # o shape: (batch_size, d, seqlen), accessing o[batch_id, 0:d, grp_i*sb_p:grp_i*sb_p+num_f] + # Offset: batch_id*d*seqlen_q + grp_i*sb_p + prev_dst_pat = bufs.mm2_prev_output[grp_i].ap(pattern=[[atp.sb_p, ac.d], [1, num_f]], offset=0) + o_src_pat = o.ap( + pattern=[[ac.seqlen_q, ac.d], [1, num_f]], + offset=batch_id * ac.d * ac.seqlen_q + grp_i * atp.sb_p, + ) + nisa.dma_copy(dst=prev_dst_pat, src=o_src_pat) + + nisa.nc_transpose( + bufs.tp_flash_attn_correction_factor_psum[grp_i].ap( + pattern=[[atp.sb_p, ac.d], [1, atp.sb_p]], offset=0 + ), + bufs.flash_attn_correction_factor[grp_i].ap(pattern=[[1, atp.sb_p], [0, ac.d]], offset=0), + ) + nisa.tensor_copy( + bufs.tp_flash_attn_correction_factor_sb[grp_i][:, :num_f], + bufs.tp_flash_attn_correction_factor_psum[grp_i][:, :num_f], + ) + nisa.tensor_tensor( + bufs.mm2_prev_output_scaled[grp_i][:, :num_f], + bufs.mm2_prev_output[grp_i][:, :num_f], + bufs.tp_flash_attn_correction_factor_sb[grp_i][:, :num_f], + nl.multiply, + ) + nisa.tensor_tensor( + bufs.mm2_accum_flash_attn[grp_i][:, :num_f], + bufs.mm2_prev_output_scaled[grp_i][:, :num_f], + bufs.mm2_sb[grp_i][:, :num_f], + nl.add, + ) + else: + # Original: load(dst=mm2_prev_output[grp_i], src=o[batch_id, grp_i*sb_p+ip_o, if_o], mask=mm2_sbuf_p_mask) + # o shape: (batch_size, seqlen, d), accessing o[batch_id, grp_i*sb_p:grp_i*sb_p+num_p, 0:num_f] + # Offset: batch_id*seqlen_q*d + grp_i*sb_p*d + prev_dst_pat = bufs.mm2_prev_output[grp_i].ap(pattern=[[ac.d, num_p], [1, ac.d]], offset=0) + o_src_pat = o.ap( + pattern=[[ac.d, num_p], [1, ac.d]], + offset=batch_id * ac.seqlen_q * ac.d + grp_i * atp.sb_p * ac.d, + ) + nisa.dma_copy(dst=prev_dst_pat, src=o_src_pat) + nisa.scalar_tensor_tensor( + bufs.mm2_accum_flash_attn[grp_i][:num_p, : ac.d], + data=bufs.mm2_prev_output[grp_i][:num_p, : ac.d], + op0=nl.multiply, + operand0=bufs.flash_attn_correction_factor[grp_i][:num_p, 0], + op1=nl.add, + operand1=bufs.mm2_sb[grp_i][:num_p, : ac.d], + ) + if sp.section_idx == atp.num_sections - 1 or is_last_section_with_compute: + # Last section, so scale by reciprocal and write accumulated output to HBM + _scale_reciprocal_write_back_impl( + bufs.mm2_accum_flash_attn[grp_i], + grp_i, + ac, + atp, + bufs, + o, + batch_id, + num_p, + num_f, + ) + else: + # Not last section, just write accumulated output to HBM + _write_back_o_impl( + bufs.mm2_accum_flash_attn[grp_i], + grp_i, + ac, + atp, + o, + batch_id, + num_p, + num_f, + ) + else: + # Only one section, so scale by reciprocal and write current output to HBM + _scale_reciprocal_write_back_impl(bufs.mm2_sb[grp_i], grp_i, ac, atp, bufs, o, batch_id, num_p, num_f) + + +def _scale_reciprocal_write_back_impl( + src_buf, + grp_i, + ac: AttnConfig, + atp: AttnTileParams, + bufs: AttnInternalBuffers, + o, + batch_id, + num_p, + num_f, +): + """ + Write back o for the final section after multiplication by reciprocal. Transposes reciprocal if tp_out. + """ + if ac.tp_out: + # Original: tp_exp_sum_reciprocal_psum[grp_i] = nisa.nc_transpose(exp_sum_reciprocal[ip_broadcast, grp_i]) + nisa.nc_transpose( + bufs.tp_exp_sum_reciprocal_psum[grp_i].ap(pattern=[[atp.sb_p, ac.d], [1, atp.sb_p]], offset=0), + bufs.exp_sum_reciprocal.ap(pattern=[[atp.num_grps, atp.sb_p], [0, ac.d]], offset=grp_i), + ) + nisa.tensor_copy( + bufs.tp_exp_sum_reciprocal_sb[grp_i][: ac.d, :num_f], + bufs.tp_exp_sum_reciprocal_psum[grp_i][: ac.d, :num_f], + ) + nisa.tensor_tensor( + bufs.mm2_final[grp_i][: ac.d, :num_f], + src_buf[: ac.d, :num_f], + bufs.tp_exp_sum_reciprocal_sb[grp_i][: ac.d, :num_f], + nl.multiply, + ) + else: + nisa.activation( + bufs.mm2_final[grp_i][:num_p, : ac.d], + nl.copy, + src_buf[:num_p, : ac.d], + scale=bufs.exp_sum_reciprocal[:num_p, grp_i], + bias=bufs.zero_bias_tensor[:num_p], + ) + + _write_back_o_impl(bufs.mm2_final[grp_i], grp_i, ac, atp, o, batch_id, num_p, num_f) + + +def _write_back_o_impl(src_buf, grp_i, ac: AttnConfig, atp: AttnTileParams, o, batch_id, num_p, num_f): + """Helper function to write a source buffer to HBM output (o) with proper transpose handling. + + Args: + src_buf: Source buffer in SBUF to copy from + grp_i: Q group index + ac: Attention configuration + atp: Tile parameters + o: Output HBM tensor + batch_id: Batch index + num_p: Number of partition elements + num_f: Number of free elements + """ + if ac.tp_out: + # o shape: (batch_size, d, seqlen), accessing o[batch_id, 0:d, grp_i*sb_p:grp_i*sb_p+num_f] + # Offset: batch_id*d*seqlen_q + grp_i*sb_p + o_dst_pat = o.ap( + pattern=[[ac.seqlen_q, ac.d], [1, num_f]], + offset=batch_id * ac.d * ac.seqlen_q + grp_i * atp.sb_p, + ) + src_pat = src_buf.ap(pattern=[[atp.sb_p, ac.d], [1, num_f]], offset=0) + else: + # o shape: (batch_size, seqlen, d), accessing o[batch_id, grp_i*sb_p:grp_i*sb_p+num_p, 0:d] + # Offset: batch_id*seqlen_q*d + grp_i*sb_p*d + o_dst_pat = o.ap( + pattern=[[ac.d, num_p], [1, ac.d]], + offset=batch_id * ac.seqlen_q * ac.d + grp_i * atp.sb_p * ac.d, + ) + src_pat = src_buf.ap(pattern=[[ac.d, num_p], [1, ac.d]], offset=0) + nisa.dma_copy(dst=o_dst_pat, src=src_pat) + + +def _qk_and_max_large_tile_impl( + qkmax_grp, + large_tile_idx, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, +): + """Compute QK^T matmul (MM1) and find row-wise maximum for this Q group and large (2048) K tile. + Also apply masking if relevant. + """ + + q_seqlen_offset = qkmax_grp * atp.sb_p + + # perform matmul and masking in 512 (_K_TILE_SZ) tile increment on the seqlen dimension + num_k_tiles_in_large_tile = _LARGE_TILE_SZ // _K_TILE_SZ + for k_tile_idx in range(num_k_tiles_in_large_tile): + # Extract relevant tensor tiles for convenience + mm1_psum_tile = bufs.mm1_psum[qkmax_grp][large_tile_idx][k_tile_idx] + if not atp.dynamic_sel_mask: + mm1_copy_sb_tile = bufs.mm1_copy_sb[qkmax_grp][large_tile_idx][k_tile_idx] + mm1_affine_select_output_tile = bufs.mm1_affine_select_output[qkmax_grp][large_tile_idx][k_tile_idx] + mm1_masked_tile = bufs.mm1_masked[qkmax_grp][large_tile_idx] + mm1_partial_max_tile = bufs.mm1_partial_max[qkmax_grp] + + k_tile_idx_in_section = large_tile_idx * num_k_tiles_in_large_tile + k_tile_idx + k_tile_idx_global = atp.num_k_tiles_per_section * sp.section_idx + k_tile_idx_in_section + is_prior_tile, seqlen_k, k_start_pos, _ = _get_kv_tile_apc( + ac.is_prefix_caching, + False, + True, + atp.seqlen_k_active_updated, + ac.seqlen_k_prior, + k_tile_idx_global * _K_TILE_SZ, + None, + ) + + if atp.is_causal and not is_prior_tile: + # Only produce matmul if the tile is in the lower triangle, use tile bot-left corner so adjust q + matmul_selection = _has_any_compute_causal(qkmax_grp, k_start_pos, ac) + # If using SWA, also skip bot-left lower triangle. + if ac.use_swa: + # Use tile top-right corner so adjust k; also adjust q for sliding window + matmul_selection = matmul_selection and _has_any_compute_swa(qkmax_grp, k_start_pos, _K_TILE_SZ, ac) + else: + matmul_selection = True + + if q_seqlen_offset >= ac.seqlen_q or k_start_pos >= seqlen_k: # make sure we don't extend bound + matmul_selection = False + + if matmul_selection and k_tile_idx_in_section < atp.num_k_tiles_per_section: + num_f = min(seqlen_k - k_start_pos, _K_TILE_SZ) + num_q_free = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + + # Step 1: MM1 matmul + nisa.nc_matmul( + mm1_psum_tile[:num_q_free, :num_f], + bufs.q_sb[qkmax_grp // atp.num_q_grps_per_load][ + : ac.d, + nl.ds((qkmax_grp % atp.num_q_grps_per_load) * _Q_GRP_SZ, num_q_free), + ], + bufs.k_sb[k_tile_idx_in_section][:, :num_f], + ) + + # Step 2: Masking + num_p = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + num_f = min(seqlen_k - k_start_pos, _K_TILE_SZ) + + diagonal_sel_mask = ( + matmul_selection and ((qkmax_grp * _Q_GRP_SZ) < (k_start_pos + _K_TILE_SZ)) + if (atp.is_causal and not is_prior_tile and not atp.dynamic_sel_mask) + else False + ) + if ac.use_swa and atp.is_causal and not is_prior_tile: + # when we are using SWA, above condition for diagonal_sel_mask + # might miss some conditions where masking needs to be applied + # since it only checks for causal condition. Therefore we either + # need dynamic mask or affine select mask + diagonal_sel_mask = not atp.dynamic_sel_mask + + if diagonal_sel_mask: # static diagonal mask + # q_pos = qkmax_grp*sb_p + nl.arange(num_p)[:, None] + # k_pos = k_start_pos + nl.arange(num_f)[None, :] + + # Mask off upper-triangle with affine_select + # causal_pred = (q_pos >= k_pos) # causal predicate preventing q tokens to look beyond + # qkmax_grp*sb_p - k_start_pos + nisa.tensor_copy( + mm1_copy_sb_tile[:num_p, :num_f], + mm1_psum_tile[:num_p, :num_f], + ) + nisa.affine_select( + mm1_affine_select_output_tile[:num_p, :num_f], + pattern=[[-1, num_f]], + offset=qkmax_grp * atp.sb_p - k_start_pos, + channel_multiplier=1, + cmp_op=nl.greater_equal, + on_true_tile=mm1_copy_sb_tile[:num_p, :num_f], + on_false_value=_FLOAT32_MIN, + ) + + # Need extra affine_sel for smaller lower-triangle if use_swa (affine_sel cannot combination of masks) + if ac.use_swa: + # swa_pred = (q_pos < k_pos + sliding_window) # k_pos + sliding_window - 1 >= q_pos + nisa.affine_select( + mm1_affine_select_output_tile[:num_p, :num_f], + pattern=[[1, num_f]], + offset=(k_start_pos + ac.sliding_window - 1 - qkmax_grp * atp.sb_p), + channel_multiplier=-1, + cmp_op=nl.greater_equal, + on_true_tile=mm1_affine_select_output_tile[:num_p, :num_f], + on_false_value=_FLOAT32_MIN, + ) + + nisa.tensor_scalar_reduce( + mm1_masked_tile[:num_p, nl.ds(k_tile_idx * _K_TILE_SZ, num_f)], + data=mm1_affine_select_output_tile[:num_p, :num_f], + op0=nl.multiply, + operand0=ac.scale, + reduce_op=nl.maximum, + reduce_res=mm1_partial_max_tile[:num_p, k_tile_idx_in_section], + ) + + elif atp.dynamic_sel_mask or is_prior_tile: # dynamic (compile-time unknown) mask + if is_prior_tile: + bound0 = bufs.range_sel_lbs_prior[:num_p, qkmax_grp] if ac.use_swa else bufs.zero_bias_tensor + bound1 = bufs.range_sel_ubs_prior[:num_p, qkmax_grp] + comp_op1 = nl.less # k < prior_used_len + elif ac.is_sequence_packed: + bound0 = bufs.range_sel_lbs[:num_p, nl.ds(qkmax_grp, 1)] + bound1 = bufs.range_sel_ubs[:num_p, nl.ds(qkmax_grp, 1)] + comp_op1 = nl.less_equal if atp.is_causal else nl.less + else: + bound0 = bufs.range_sel_lbs[:num_p, qkmax_grp] if ac.use_swa else bufs.zero_bias_tensor + bound1 = bufs.range_sel_ubs[:num_p, qkmax_grp] + comp_op1 = nl.less_equal # k <= q + cp_offset + + kernel_assert(ac.scale == 1.0, "range_select path doesn't support scale != 1.0") + nisa.range_select( + mm1_masked_tile[:num_p, nl.ds(k_tile_idx * _K_TILE_SZ, num_f)], + on_true_tile=mm1_psum_tile[:num_p, :num_f], + on_false_value=_FLOAT32_MIN, + comp_op0=nl.greater_equal, + comp_op1=comp_op1, + bound0=bound0[:num_p, :1], + bound1=bound1[:num_p, :1], + reduce_op=nl.maximum, + reduce_res=mm1_partial_max_tile[:num_p, k_tile_idx_in_section], + reduce_cmd=reduce_cmd.reset_reduce, + range_start=k_start_pos, + ) + + else: # no masking + nisa.tensor_scalar_reduce( + mm1_masked_tile[:num_p, nl.ds(k_tile_idx * _K_TILE_SZ, num_f)], + data=mm1_psum_tile[:num_p, :num_f], + op0=nl.multiply, + operand0=ac.scale, + reduce_op=nl.maximum, + reduce_res=mm1_partial_max_tile[:num_p, k_tile_idx_in_section], + ) + + +def _pv_large_tile_impl( + pv_grp, + large_tile_idx, + ac: AttnConfig, + atp: AttnTileParams, + sp: SectionParams, + bufs: AttnInternalBuffers, +): + """Perform MM2 (P@V) matmul for the Q grp and large (2048) V tile.""" + + q_seqlen_offset = pv_grp * atp.sb_p + num_mm2_grps_in_large_tile = _LARGE_TILE_SZ // atp.mm2_grp_sz + mm2_psum_set = False # track if matmul happens so we can skip later step + + mm2_psum_tile = bufs.mm2_psum[pv_grp][large_tile_idx] + + # Step 1: Perform matmul and accumulate in PSUM for each group + for mm2_grp_i in range(num_mm2_grps_in_large_tile): + num_mm2_per_grp = atp.mm2_grp_sz // _V_TILE_SZ + num_mm2_per_large_tile = num_mm2_per_grp * num_mm2_grps_in_large_tile + + exp_tp_sb_tile = bufs.exp_tp_sb[pv_grp][large_tile_idx][mm2_grp_i] + + is_prior_tile, seqlen_k, k_start_pos_512_tile, _ = _get_kv_tile_apc( + ac.is_prefix_caching, + False, + True, + atp.seqlen_k_active_updated, + ac.seqlen_k_prior, + sp.section_offset + large_tile_idx * _LARGE_TILE_SZ + mm2_grp_i * atp.mm2_grp_sz, + None, + ) + + for mm2_i in range(num_mm2_per_grp): + v_tile_idx = large_tile_idx * num_mm2_per_large_tile + mm2_grp_i * num_mm2_per_grp + mm2_i + k_start_pos = k_start_pos_512_tile + mm2_i * _V_TILE_SZ + + num_p = min(seqlen_k - k_start_pos, _V_TILE_SZ) + num_f = min(ac.seqlen_q - q_seqlen_offset, _Q_GRP_SZ) + + # Only produce matmul if the tile is in the lower triangle, use tile bot-left corner so adjust q + mm2_sel_mask = ( + _has_any_compute_causal(pv_grp, k_start_pos, ac) if (atp.is_causal and not is_prior_tile) else True + ) + # If using SWA, also skip bot-left lower triangle. + if ac.use_swa and atp.is_causal and not is_prior_tile: + # Use tile top-right corner so adjust k; also adjust q for sliding window + mm2_sel_mask = mm2_sel_mask and _has_any_compute_swa(pv_grp, k_start_pos, _V_TILE_SZ, ac) + + if mm2_sel_mask and v_tile_idx < atp.num_v_tiles_per_section and num_p > 0 and num_f > 0: + mm2_psum_set = True + # src partition mask: (k_start_pos+nl.arange(128)[:, None]= min_k_in_tile + + +def _has_any_compute_swa(q_grp: int, k_start_pos: int, k_tile_size: int, ac: AttnConfig): + """ + Return true if the given q group has any compute (i.e., we cannot fully mask out) + for the provided k start position, based on the sliding window mask. + + :param q_grp: q group index + :param k_start_pos: start pos of k + :param k_tile_size: tile size of k + :param ac: AttnConfig + """ + # We can completely eliminate compute when when even the smallest q in tile is + # >= largest k in tile + sw + min_q_in_grp = q_grp * _Q_GRP_SZ + if ac.cp_strided_q_slicing: + # multiply by stride (global_cp_deg). For min q we can assume rank_id = 0 + min_q_in_grp = min_q_in_grp * ac.global_cp_deg + max_k_in_tile = k_start_pos + k_tile_size - 1 + return min_q_in_grp < max_k_in_tile + ac.sliding_window diff --git a/contrib/models/Qwen-Image-Edit/src/autoencoder_kl_qwenimage_neuron.py b/contrib/models/Qwen-Image-Edit/src/autoencoder_kl_qwenimage_neuron.py new file mode 100644 index 00000000..3797ff66 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/autoencoder_kl_qwenimage_neuron.py @@ -0,0 +1,1051 @@ +# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We gratefully acknowledge the Wan Team for their outstanding contributions. +# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance. +# For more information about the Wan VAE, please refer to: +# - GitHub: https://github.com/Wan-Video/Wan2.1 +# - Paper: https://huggingface.co/papers/2503.20314 + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin +from diffusers.utils import logging +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.activations import get_activation +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.autoencoders.vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + input_channels=3, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + input_channels=3, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, input_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = False + + # fmt: off + @register_to_config + def __init__( + self, + base_dim: int = 96, + z_dim: int = 16, + dim_mult: Tuple[int, ...] = (1, 2, 4, 4), + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + input_channels: int = 3, + latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], + latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], + ) -> None: + # fmt: on + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/contrib/models/Qwen-Image-Edit/src/cache_hf_model.py b/contrib/models/Qwen-Image-Edit/src/cache_hf_model.py new file mode 100644 index 00000000..36a1f220 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/cache_hf_model.py @@ -0,0 +1,14 @@ +import torch +from diffusers import QwenImageEditPlusPipeline + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + +if __name__ == "__main__": + print(f"Downloading {MODEL_ID} to {CACHE_DIR}...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + cache_dir=CACHE_DIR + ) + print("Model downloaded successfully!") diff --git a/contrib/models/Qwen-Image-Edit/src/compile.sh b/contrib/models/Qwen-Image-Edit/src/compile.sh new file mode 100755 index 00000000..9fa24d2b --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile.sh @@ -0,0 +1,360 @@ +#!/bin/bash + +# Compile Qwen-Image-Edit-2509 for Neuron (trn2) +# ALL components must be compiled to run on Trainium2 +# +# Default settings: +# - Output size: 1024x1024 +# - VAE tile size: 512x512 (fixed, uses tiled processing for larger images) +# - max_sequence_length: 1024 +# - tp_degree: 8 (for transformer) +# - patch_multiplier: 3 (for 2-image merging) +# - batch_size: 1 (for inference batching) +# +# Usage: +# ./compile.sh # Compile all versions +# ./compile.sh v1 # Compile V1 only +# ./compile.sh v2 # Compile V2 only +# ./compile.sh v1_flash # Compile V1 Flash only (NKI Flash Attention) +# ./compile.sh v2_flash # Compile V2 Flash only (ModelBuilder + NKI) +# ./compile.sh v3_cp # Compile V3 CP (Context Parallel + NKI) +# ./compile.sh v3_cp 1024 768 448 8 1024 3 2 # V3 CP with batch_size=2 +# ./compile.sh v3_cfg # Compile V3 CFG (CFG Parallel + NKI, recommended, fastest) +# ./compile.sh v3_cfg 1024 1024 448 8 1024 3 1 # Custom dimensions with batch_size + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +export PYTHONPATH="${SCRIPT_DIR}:$PYTHONPATH" +COMPILED_MODELS_DIR="/opt/dlami/nvme/compiled_models_qwen_image_edit" +COMPILER_WORKDIR="/opt/dlami/nvme/compiler_workdir_qwen_image_edit" + +# Fixed VAE tile size (VAE uses tiled processing for larger images) +VAE_TILE_SIZE=512 + +# Check if first argument is version selector +VERSION_MODE="all" +if [[ "$1" == "v1" || "$1" == "v2" || "$1" == "v1_flash" || "$1" == "v2_flash" || "$1" == "v3_cp" || "$1" == "v3_cfg" ]]; then + VERSION_MODE="$1" + shift +fi + +# Parse arguments +HEIGHT=${1:-1024} +WIDTH=${2:-1024} +IMAGE_SIZE=${3:-448} # Vision encoder image size (must be divisible by 14 and result in even grid) +TP_DEGREE=${4:-8} +MAX_SEQ_LEN=${5:-1024} +PATCH_MULTIPLIER=${6:-3} # 2 for single image editing, 3 for 2 images merging, 1 for generation +BATCH_SIZE=${7:-1} # Batch size for transformer/language model (for batched inference) + +# VAE compile-time batch dimension. For V3 CP / V3 CFG we default to 6 so the +# tiled VAE encoder/decoder can run all 6 tiles of a 1024x1024 image in a +# single NEFF launch (saves ~5x ~37ms launch overhead per image; ~6% E2E). +# The runtime _tiled_encode/_tiled_decode falls back to chunking for any +# tile count, so larger compiled batch is safe. +if [[ "$VERSION_MODE" == "v3_cp" || "$VERSION_MODE" == "v3_cfg" ]]; then + VAE_BATCH_SIZE=${VAE_BATCH_SIZE:-6} +else + VAE_BATCH_SIZE=${VAE_BATCH_SIZE:-1} +fi + +echo "============================================" +echo "Qwen-Image-Edit-2509 Compilation for Neuron" +echo "============================================" +echo "Transformer Version: ${VERSION_MODE}" +echo "Output Size: ${HEIGHT}x${WIDTH}" +echo "VAE Tile Size: ${VAE_TILE_SIZE}x${VAE_TILE_SIZE} (fixed)" +echo "Vision Encoder Image Size: ${IMAGE_SIZE}" +echo "TP Degree: ${TP_DEGREE}" +echo "Max Sequence Length: ${MAX_SEQ_LEN}" +echo "Patch Multiplier: ${PATCH_MULTIPLIER}" +echo "Batch Size: ${BATCH_SIZE}" +echo "VAE Batch Size: ${VAE_BATCH_SIZE}" +echo "" + +# Step 1: Download the model +echo "[Step 1/4] Downloading model..." +python ${SCRIPT_DIR}/cache_hf_model.py +echo "Model downloaded successfully!" +echo "" + +# Step 2: Compile VAE (encoder and decoder) +echo "[Step 2/4] Compiling VAE..." +echo "VAE tile size: ${VAE_TILE_SIZE}x${VAE_TILE_SIZE} (tiled processing for larger images)" +echo "Using modified VAE with 'nearest' interpolation (Neuron doesn't support 'nearest-exact')" +python ${SCRIPT_DIR}/compile_vae.py \ + --height ${VAE_TILE_SIZE} \ + --width ${VAE_TILE_SIZE} \ + --temporal_frames 1 \ + --batch_size ${VAE_BATCH_SIZE} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} +echo "VAE compiled successfully!" +echo "" + +# Step 3: Compile Transformer +echo "[Step 3/4] Compiling Transformer..." +echo " TP=${TP_DEGREE}, patch_multiplier=${PATCH_MULTIPLIER} (for image editing)" + +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v1" ]]; then + echo " Compiling V1 (parallel_model_trace)..." + python ${SCRIPT_DIR}/compile_transformer.py \ + --height ${HEIGHT} \ + --width ${WIDTH} \ + --tp_degree ${TP_DEGREE} \ + --patch_multiplier ${PATCH_MULTIPLIER} \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V1 Transformer compiled successfully!" +fi + +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v2" ]]; then + echo " Compiling V2 (ModelBuilder)..." + python ${SCRIPT_DIR}/compile_transformer_v2.py \ + --height ${HEIGHT} \ + --width ${WIDTH} \ + --tp_degree ${TP_DEGREE} \ + --patch_multiplier ${PATCH_MULTIPLIER} \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} + echo " V2 Transformer compiled successfully!" +fi + +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v1_flash" ]]; then + echo " Compiling V1 Flash (NKI Flash Attention, recommended)..." + python ${SCRIPT_DIR}/compile_transformer_v1_flash.py \ + --height ${HEIGHT} \ + --width ${WIDTH} \ + --tp_degree ${TP_DEGREE} \ + --patch_multiplier ${PATCH_MULTIPLIER} \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V1 Flash Transformer compiled successfully!" +fi + +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v2_flash" ]]; then + echo " Compiling V2 Flash (ModelBuilder + NKI Flash Attention)..." + python ${SCRIPT_DIR}/compile_transformer_v2_flash.py \ + --height ${HEIGHT} \ + --width ${WIDTH} \ + --tp_degree ${TP_DEGREE} \ + --patch_multiplier ${PATCH_MULTIPLIER} \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} + echo " V2 Flash Transformer compiled successfully!" +fi + +if [[ "$VERSION_MODE" == "v3_cp" ]]; then + echo " Compiling V3 CP (Context Parallel + NKI Flash Attention)..." + echo " Using TP=4, world_size=8 (CP=2)" + python ${SCRIPT_DIR}/compile_transformer_v3_cp.py \ + --height ${HEIGHT} \ + --width ${WIDTH} \ + --tp_degree 4 \ + --world_size 8 \ + --patch_multiplier ${PATCH_MULTIPLIER} \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --batch_size ${BATCH_SIZE} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V3 CP Transformer compiled successfully!" + + # Also compile V3 Language Model (ModelBuilder API, TP=4, world_size=8) + echo "" + echo " Compiling V3 Language Model (ModelBuilder API)..." + echo " Using TP=4, world_size=8 (compatible with V3 CP transformer)" + python ${SCRIPT_DIR}/compile_language_model_v3.py \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --batch_size ${BATCH_SIZE} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V3 Language Model compiled successfully!" + + # Also compile V3 Vision Encoder (ModelBuilder API, TP=4, world_size=8, float32) + echo "" + echo " Compiling V3 Vision Encoder (ModelBuilder API)..." + echo " Using TP=4, world_size=8, float32 (faster than single device)" + python ${SCRIPT_DIR}/compile_vision_encoder_v3.py \ + --image_size ${IMAGE_SIZE} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V3 Vision Encoder compiled successfully!" +fi + +if [[ "$VERSION_MODE" == "v3_cfg" ]]; then + echo " Compiling V3 CFG (CFG Parallel + NKI Flash Attention)..." + echo " Using TP=4, world_size=8 (DP=2 for batched CFG)" + python ${SCRIPT_DIR}/compile_transformer_v3_cfg.py \ + --height ${HEIGHT} \ + --width ${WIDTH} \ + --tp_degree 4 \ + --world_size 8 \ + --patch_multiplier ${PATCH_MULTIPLIER} \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V3 CFG Transformer compiled successfully!" + + # Also compile V3 Language Model (shared with V3 CP) + echo "" + echo " Compiling V3 Language Model (ModelBuilder API)..." + echo " Using TP=4, world_size=8 (compatible with V3 CFG transformer)" + python ${SCRIPT_DIR}/compile_language_model_v3.py \ + --max_sequence_length ${MAX_SEQ_LEN} \ + --batch_size ${BATCH_SIZE} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V3 Language Model compiled successfully!" + + # Also compile V3 Vision Encoder (shared with V3 CP) + echo "" + echo " Compiling V3 Vision Encoder (ModelBuilder API)..." + echo " Using TP=4, world_size=8, float32 (faster than single device)" + python ${SCRIPT_DIR}/compile_vision_encoder_v3.py \ + --image_size ${IMAGE_SIZE} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo " V3 Vision Encoder compiled successfully!" +fi +echo "" + +# Step 4: Vision Encoder (float32 for accuracy) - single device version +# Skip for v3_cp/v3_cfg mode since V3 vision encoder is already compiled above +if [[ "$VERSION_MODE" != "v3_cp" && "$VERSION_MODE" != "v3_cfg" ]]; then + echo "[Step 4/4] Compiling Vision Encoder (float32, single device)..." + echo "Note: Text encoder (Qwen2.5-VL) has two components:" + echo " - Vision Encoder: compiled in float32 for accuracy (single device)" + echo " - Language Model: runs on CPU (28Q/4KV heads incompatible with TP=8)" + python ${SCRIPT_DIR}/compile_text_encoder.py \ + --vision_only \ + --image_size ${IMAGE_SIZE} \ + --compiled_models_dir ${COMPILED_MODELS_DIR} \ + --compiler_workdir ${COMPILER_WORKDIR} + echo "Vision Encoder (float32) compiled!" +fi +echo "" + +echo "============================================" +echo "Compilation Complete!" +echo "============================================" +echo "" +echo "Compiled models saved to: ${COMPILED_MODELS_DIR}/" +echo " - vae_encoder/ (tile: ${VAE_TILE_SIZE}x${VAE_TILE_SIZE}, batch: ${VAE_BATCH_SIZE})" +echo " - vae_decoder/ (tile: ${VAE_TILE_SIZE}x${VAE_TILE_SIZE}, batch: ${VAE_BATCH_SIZE})" +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v1" ]]; then + echo " - transformer/ (V1, TP=${TP_DEGREE}, output: ${HEIGHT}x${WIDTH})" +fi +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v2" ]]; then + echo " - transformer_v2/ (V2, TP=${TP_DEGREE}, output: ${HEIGHT}x${WIDTH})" +fi +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v1_flash" ]]; then + echo " - transformer_v1_flash/ (V1 Flash, TP=${TP_DEGREE}, output: ${HEIGHT}x${WIDTH}, NKI Flash Attention)" +fi +if [[ "$VERSION_MODE" == "all" || "$VERSION_MODE" == "v2_flash" ]]; then + echo " - transformer_v2_flash/ (V2 Flash, TP=${TP_DEGREE}, output: ${HEIGHT}x${WIDTH}, ModelBuilder + NKI)" +fi +if [[ "$VERSION_MODE" == "v3_cp" ]]; then + echo " - transformer_v3_cp/ (V3 CP, TP=4, CP=2, output: ${HEIGHT}x${WIDTH}, batch: ${BATCH_SIZE})" + echo " - language_model_v3/ (V3, TP=4, world_size=8, batch: ${BATCH_SIZE})" + echo " - vision_encoder_v3/ (V3, TP=4, world_size=8, float32)" +elif [[ "$VERSION_MODE" == "v3_cfg" ]]; then + echo " - transformer_v3_cfg/ (V3 CFG, TP=4, DP=2, output: ${HEIGHT}x${WIDTH}, batch: 2)" + echo " - language_model_v3/ (V3, TP=4, world_size=8, batch: ${BATCH_SIZE})" + echo " - vision_encoder_v3/ (V3, TP=4, world_size=8, float32)" +else + echo " - vision_encoder/ (float32)" +fi +echo "" +if [[ "$VERSION_MODE" == "v3_cp" ]]; then + echo "Note: V3 CP mode compiles all components with ModelBuilder API" + echo " - Transformer: TP=4, CP=2 (Context Parallel)" + echo " - Language Model: TP=4 (perfect GQA fit)" + echo " - Vision Encoder: TP=4, float32 (faster)" +elif [[ "$VERSION_MODE" == "v3_cfg" ]]; then + echo "Note: V3 CFG mode compiles all components with ModelBuilder API" + echo " - Transformer: TP=4, DP=2 (CFG Parallel, batch=2)" + echo " - Language Model: TP=4 (perfect GQA fit)" + echo " - Vision Encoder: TP=4, float32 (faster)" + echo " CFG Parallel batches negative+positive prompts for ~2x denoising speedup" +else + echo "Note: Language model runs on CPU (GQA 28Q/4KV incompatible with TP=8)" +fi +echo "" +echo "To run inference on Trainium2:" +echo "" +echo " # NOTE: defaults compile with patch_multiplier=3 (TWO-IMAGE merge)." +echo " # The examples below use the 2-image merge form to match that." +echo " # For single-image editing, recompile with patch_multiplier=2." +echo "" +if [[ "$VERSION_MODE" == "v3_cp" ]]; then + echo " # V3 CP (recommended, all V3 components enabled by default):" + echo " NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py \\" + echo " --images image1.png image2.png \\" + echo " --prompt \"merge subjects from image1 and image2 into a single scene\" \\" + echo " --patch_multiplier 3" + echo "" + echo " # Note: --use_v3_vision_encoder is now default (10-15x faster than CPU)" + echo " # Use --no-use_v3_vision_encoder to disable" + echo "" +fi +if [[ "$VERSION_MODE" == "v3_cfg" ]]; then + echo " # V3 CFG (CFG Parallel, batches neg+pos prompts for ~2x denoising speedup):" + echo " NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py \\" + echo " --images image1.png image2.png \\" + echo " --prompt \"merge subjects from image1 and image2 into a single scene\" \\" + echo " --patch_multiplier 3 \\" + echo " --use_v3_cfg" + echo "" + echo " # Note: --use_v3_cfg is mutually exclusive with --use_v3_cp" + echo " # --use_v3_vision_encoder is enabled by default" + echo "" +fi +echo " # V1 Flash (NKI Flash Attention):" +echo " python run_qwen_image_edit.py \\" +echo " --images image1.png image2.png \\" +echo " --prompt \"merge subjects from image1 and image2 into a single scene\" \\" +echo " --patch_multiplier 3 \\" +echo " --use_v1_flash" +echo "" +echo " # V2 Flash (ModelBuilder + NKI, same speed as V1 Flash):" +echo " python run_qwen_image_edit.py \\" +echo " --images image1.png image2.png \\" +echo " --prompt \"merge subjects from image1 and image2 into a single scene\" \\" +echo " --patch_multiplier 3 \\" +echo " --use_v2_flash" +echo "" +echo " # V2 (ModelBuilder):" +echo " python run_qwen_image_edit.py \\" +echo " --images image1.png image2.png \\" +echo " --prompt \"merge subjects from image1 and image2 into a single scene\" \\" +echo " --patch_multiplier 3 \\" +echo " --use_v2" +echo "" +echo " # V1:" +echo " python run_qwen_image_edit.py \\" +echo " --images image1.png image2.png \\" +echo " --prompt \"merge subjects from image1 and image2 into a single scene\" \\" +echo " --patch_multiplier 3" +echo "" + +# Single-image editing example (CFG enabled by default, true_cfg_scale=4.0). +# Requires the model to be compiled with patch_multiplier=2. +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png --prompt "turn the woman into a man" --warmup + +# Two-image merge example (requires patch_multiplier=3, the default in this script). +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "..." --patch_multiplier 3 --warmup + +# Full reference invocations across all transformer versions (two-image wedding-photo merge). +# Prompt asks the model to combine the woman from image1 and the man from image2 into a wedding photo: +# red Chinese-style groom jacket, ornate xiuhe bridal robe, gold phoenix crown, vermillion palace wall +# with carved wooden lattice windows, soft and bright lighting, symmetric composition. +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "Using the woman from image 1 and the man from image 2, generate a wedding photo following this description: the groom wears a red Chinese-style jacket, the bride wears an ornate xiuhe robe and a gold phoenix crown. They stand side by side in front of an ancient vermillion palace wall, with carved wooden lattice windows in the background. Soft bright lighting, symmetric composition, festive and ceremonial atmosphere." --patch_multiplier 3 --warmup --use_v1 +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "Using the woman from image 1 and the man from image 2, generate a wedding photo following this description: the groom wears a red Chinese-style jacket, the bride wears an ornate xiuhe robe and a gold phoenix crown. They stand side by side in front of an ancient vermillion palace wall, with carved wooden lattice windows in the background. Soft bright lighting, symmetric composition, festive and ceremonial atmosphere." --patch_multiplier 3 --warmup --use_v2 +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "Using the woman from image 1 and the man from image 2, generate a wedding photo following this description: the groom wears a red Chinese-style jacket, the bride wears an ornate xiuhe robe and a gold phoenix crown. They stand side by side in front of an ancient vermillion palace wall, with carved wooden lattice windows in the background. Soft bright lighting, symmetric composition, festive and ceremonial atmosphere." --patch_multiplier 3 --warmup --use_v1_flash +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "Using the woman from image 1 and the man from image 2, generate a wedding photo following this description: the groom wears a red Chinese-style jacket, the bride wears an ornate xiuhe robe and a gold phoenix crown. They stand side by side in front of an ancient vermillion palace wall, with carved wooden lattice windows in the background. Soft bright lighting, symmetric composition, festive and ceremonial atmosphere." --patch_multiplier 3 --warmup --use_v2_flash +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "Using the woman from image 1 and the man from image 2, generate a wedding photo following this description: the groom wears a red Chinese-style jacket, the bride wears an ornate xiuhe robe and a gold phoenix crown. They stand side by side in front of an ancient vermillion palace wall, with carved wooden lattice windows in the background. Soft bright lighting, symmetric composition, festive and ceremonial atmosphere." --patch_multiplier 3 --warmup --use_v3_cp +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "Using the woman from image 1 and the man from image 2, generate a wedding photo following this description: the groom wears a red Chinese-style jacket, the bride wears an ornate xiuhe robe and a gold phoenix crown. They stand side by side in front of an ancient vermillion palace wall, with carved wooden lattice windows in the background. Soft bright lighting, symmetric composition, festive and ceremonial atmosphere." --patch_multiplier 3 --warmup --use_v3_cfg +# NEURON_RT_NUM_CORES=8 python run_qwen_image_edit.py --images image1.png image2.png --prompt "Using the woman from image 1 and the man from image 2, generate a wedding photo following this description: the groom wears a red Chinese-style jacket, the bride wears an ornate xiuhe robe and a gold phoenix crown. They stand side by side in front of an ancient vermillion palace wall, with carved wooden lattice windows in the background. Soft bright lighting, symmetric composition, festive and ceremonial atmosphere." --patch_multiplier 3 --warmup diff --git a/contrib/models/Qwen-Image-Edit/src/compile_language_model_v3.py b/contrib/models/Qwen-Image-Edit/src/compile_language_model_v3.py new file mode 100644 index 00000000..09d03f10 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_language_model_v3.py @@ -0,0 +1,389 @@ +""" +Language Model Compilation using ModelBuilder API (V3) for V3 CP Compatibility. + +This script compiles the Qwen2.5-VL Language Model using ModelBuilder API with +tp_degree=4 and world_size=8 to be compatible with the V3 CP transformer. + +Key features: +- Uses ModelBuilder API (NxDModel) for compilation +- Configuration: tp_degree=4, world_size=8 (matching V3 CP transformer) +- TP=4 is perfect for Qwen2.5-VL GQA: 28Q/4=7 heads/rank, 4KV/4=1 head/rank +- No Context Parallel needed (language model processes full sequence) + +Usage: + python compile_language_model_v3.py --max_sequence_length 1024 +""" + +import os +import json +import gc + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +# Compiler flags +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer --enable-fast-loading-neuron-binaries """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import torch.nn as nn +import argparse + +from diffusers import QwenImageEditPlusPipeline + +# ModelBuilder imports +from neuronx_distributed import ModelBuilder, NxDParallelState, shard_checkpoint +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers import parallel_state + +from neuron_parallel_utils import ( + shard_qwen2_attention, + shard_qwen2_mlp, + get_sharded_data, +) + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def load_pipeline(dtype=torch.bfloat16): + """Load pipeline with appropriate kwargs.""" + load_kwargs = {"torch_dtype": dtype, "local_files_only": True} + if CACHE_DIR: + load_kwargs["cache_dir"] = CACHE_DIR + return QwenImageEditPlusPipeline.from_pretrained(MODEL_ID, **load_kwargs) + + +class f32Wrapper(nn.Module): + """Wrapper to run normalization layers in float32 for numerical stability.""" + def __init__(self, original): + super().__init__() + self.original = original + + def forward(self, x, *args, **kwargs): + t = x.dtype + y = x.to(torch.float32) + output = self.original(y, *args, **kwargs) + return output.type(t) + + +def upcast_norms_to_f32(module): + """Upcast normalization layers to float32 for numerical stability.""" + for name, child in module.named_children(): + if isinstance(child, (torch.nn.LayerNorm,)): + setattr(module, name, f32Wrapper(child.to(torch.float32))) + elif 'RMSNorm' in child.__class__.__name__: + setattr(module, name, f32Wrapper(child.to(torch.float32))) + else: + upcast_norms_to_f32(child) + + +class NeuronLanguageModelV3(nn.Module): + """ + Neuron-optimized Qwen2.5-VL Language Model for V3 CP compatibility. + + Uses ModelBuilder API with tp_degree=4, world_size=8. + + Key differences from compile_text_encoder.py: + - Uses ModelBuilder API instead of parallel_model_trace + - world_size=8 to match transformer (even though CP is not used for language model) + - TP=4 for perfect GQA alignment (28Q/4=7, 4KV/4=1 - no padding needed!) + + Note: Unlike V3 CP transformer which splits sequence, language model processes + full sequence on all ranks. The world_size=8 is for compatibility only. + + IMPORTANT: We keep the full language_model structure and just shard the layers, + rather than recreating the forward pass. This ensures position_embeddings are + properly computed from position_ids by the original model's rotary_emb. + """ + + def __init__(self, original_language_model, tp_degree): + super().__init__() + + self.tp_degree = tp_degree + + # Keep the full language model (we'll modify its layers in-place) + self.language_model = original_language_model + + # Copy config for reference + self.config = original_language_model.config + + # Get model structure info + self.hidden_size = self.config.hidden_size # 3584 + self.num_hidden_layers = self.config.num_hidden_layers # 28 + + print(f" Language model config:") + print(f" hidden_size: {self.hidden_size}") + print(f" num_hidden_layers: {self.num_hidden_layers}") + print(f" num_attention_heads: {self.config.num_attention_heads}") # 28 + print(f" num_key_value_heads: {self.config.num_key_value_heads}") # 4 + + # Shard the layers in-place + for i, layer in enumerate(self.language_model.layers): + # Shard attention + layer.self_attn = shard_qwen2_attention(tp_degree, layer.self_attn) + # Shard MLP + layer.mlp = shard_qwen2_mlp(layer.mlp) + if i == 0: + print(f" Sharded layer 0 attention and MLP") + + print(f" Sharded all {len(self.language_model.layers)} layers") + + # Upcast norms to float32 for numerical stability + upcast_norms_to_f32(self.language_model) + + def forward(self, inputs_embeds, attention_mask, position_ids): + """ + Forward pass for language model. + + Args: + inputs_embeds: (batch, seq_len, hidden_size) - combined text+vision embeddings + attention_mask: (batch, seq_len) - 1 for valid tokens, 0 for padding + position_ids: (3, batch, seq_len) - 3D position IDs for M-RoPE + Dims: [t (temporal), h (height), w (width)] x batch x seq_len + + Returns: + hidden_states: (batch, seq_len, hidden_size) + """ + # Call the full language model, which handles: + # 1. Computing position_embeddings from position_ids via rotary_emb + # 2. Creating the attention mask + # 3. Running through all layers + # 4. Final layer norm + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True, + return_dict=True + ) + return outputs.last_hidden_state + + +class TracingWrapper(nn.Module): + """Wrapper for ModelBuilder tracing.""" + + def __init__(self, language_model): + super().__init__() + self.language_model = language_model + + def forward(self, inputs_embeds, attention_mask, position_ids): + return self.language_model(inputs_embeds, attention_mask, position_ids) + + +def compile_language_model_v3(args): + """ + Compile Language Model using ModelBuilder API. + + Configuration: + - tp_degree=4: Perfect for GQA (28Q/4=7, 4KV/4=1) + - world_size=8: Matches V3 CP transformer (even though CP is not used) + """ + tp_degree = 4 # Fixed: perfect GQA alignment + world_size = int(os.environ.get("LM_WORLD_SIZE", 8)) # 4 for TP=4 CP=1, 8 for TP=4 CP=2 + + batch_size = args.batch_size + sequence_length = args.max_sequence_length + hidden_size = 3584 # Qwen2.5-VL hidden size + + print("=" * 60) + print("Compiling Language Model V3 (ModelBuilder API)") + print("=" * 60) + print(f" Batch size: {batch_size}") + print(f" Sequence length: {sequence_length}") + print(f" Hidden size: {hidden_size}") + print(f" TP degree: {tp_degree}") + print(f" World size: {world_size}") + print(f" GQA: 28 Q heads / 4 = 7 per rank, 4 KV heads / 4 = 1 per rank") + print("") + + # Sample inputs + sample_inputs_embeds = torch.randn( + batch_size, sequence_length, hidden_size, dtype=torch.bfloat16 + ) + sample_attention_mask = torch.ones( + batch_size, sequence_length, dtype=torch.int64 + ) + # 3D position_ids for M-RoPE: (3, batch, seq_len) + # For tracing, use simple sequential positions (text-only pattern) + sample_position_ids = torch.arange(sequence_length).view(1, 1, -1).expand(3, batch_size, -1).clone() + + print(f"Sample input shapes:") + print(f" inputs_embeds: {sample_inputs_embeds.shape}") + print(f" attention_mask: {sample_attention_mask.shape}") + print(f" position_ids: {sample_position_ids.shape}") + print("") + + # Use NxDParallelState context for compilation + with NxDParallelState(world_size=world_size, tensor_model_parallel_size=tp_degree): + print("Loading model...") + pipe = load_pipeline(torch.bfloat16) + + # Extract language model + original_language_model = pipe.text_encoder.model.language_model + + # Save unsharded state dict before modifications + print("Saving unsharded state dict...") + unsharded_state = original_language_model.state_dict() + + # Create Neuron language model with sharding + print(f"\nCreating Neuron language model (sharding layers with TP={tp_degree})...") + neuron_language_model = NeuronLanguageModelV3( + original_language_model, tp_degree + ) + neuron_language_model = neuron_language_model.to(torch.bfloat16) + neuron_language_model.eval() + + # Clear pipeline to save memory (language model is now owned by neuron_language_model) + del pipe + gc.collect() + + # Wrap for tracing + model = TracingWrapper(neuron_language_model) + + print("\nInitializing ModelBuilder...") + builder = ModelBuilder(model=model) + + print("Tracing model...") + builder.trace( + kwargs={ + "inputs_embeds": sample_inputs_embeds, + "attention_mask": sample_attention_mask, + "position_ids": sample_position_ids, + }, + tag="inference", + ) + + print("Compiling model...") + # NOTE: Using -O1 instead of -O2 because -O2 can cause numerical issues in some cases + compile_args = "--model-type=transformer -O1 --auto-cast=none" + traced_model = builder.compile( + compiler_args=compile_args, + compiler_workdir=args.compiler_workdir, + ) + + # Save + output_path = f"{args.compiled_models_dir}/language_model_v3" + os.makedirs(output_path, exist_ok=True) + + print(f"\nSaving to {output_path}...") + traced_model.save(os.path.join(output_path, "nxd_model.pt")) + + # Save weights + weights_path = os.path.join(output_path, "weights") + os.makedirs(weights_path, exist_ok=True) + + # Prepare checkpoint for sharding + print("Preparing checkpoint...") + checkpoint = {} + for key, value in model.state_dict().items(): + # Use unsharded weights where available + # Key format: language_model.language_model.layers.X... -> layers.X... + # (TracingWrapper.language_model -> NeuronLanguageModelV3.language_model -> Qwen2_5_VLTextModel) + orig_key = key.replace("language_model.language_model.", "", 1) + if orig_key in unsharded_state: + checkpoint[key] = unsharded_state[orig_key].clone() + else: + checkpoint[key] = value.clone() + + # Shard checkpoint + print("Sharding weights...") + shard_checkpoint( + checkpoint=checkpoint, + model=model, + serialize_path=weights_path, + ) + + # Post-process checkpoints: remove master_weight and add inv_freq + print("\nPost-processing checkpoints...") + from safetensors.torch import load_file, save_file + + # Collect inv_freq buffers from original model (they are not in state_dict) + inv_freq_buffers = {} + for name, buf in neuron_language_model.language_model.named_buffers(): + if 'inv_freq' in name: + full_key = f"language_model.language_model.{name}" + inv_freq_buffers[full_key] = buf.to(torch.bfloat16).clone() + print(f" Collected {len(inv_freq_buffers)} inv_freq buffers") + + for rank in range(tp_degree): + shard_file = os.path.join(weights_path, f"tp{rank}_sharded_checkpoint.safetensors") + if not os.path.exists(shard_file): + print(f" WARNING: {shard_file} not found!") + continue + + # Load checkpoint + data = dict(load_file(shard_file)) + original_count = len(data) + original_size = sum(v.numel() * v.element_size() for v in data.values()) + + # Remove master_weight tensors. Clone because load_file returns mmap'd + # tensors; overwriting the source file invalidates their backing storage. + cleaned = {k: v.clone().contiguous() for k, v in data.items() if 'master_weight' not in k} + + # Add inv_freq buffers + cleaned.update({k: (v.clone().contiguous() if hasattr(v, 'clone') else v) + for k, v in inv_freq_buffers.items()}) + + cleaned_size = sum(v.numel() * v.element_size() for v in cleaned.values()) + + # Save optimized checkpoint + del data + save_file(cleaned, shard_file) + print(f" tp{rank}: {original_count} -> {len(cleaned)} tensors, " + f"{original_size/1e9:.2f}GB -> {cleaned_size/1e9:.2f}GB") + + # Save config + config = { + "max_sequence_length": sequence_length, + "hidden_size": hidden_size, + "batch_size": batch_size, + "tp_degree": tp_degree, + "world_size": world_size, + "num_hidden_layers": 28, + "num_attention_heads": 28, + "num_key_value_heads": 4, + } + config_path = os.path.join(output_path, "config.json") + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + print(f"\nConfig saved to {config_path}") + + print("\n" + "=" * 60) + print("Compilation complete!") + print("=" * 60) + print(f"Model saved to: {output_path}") + print(f" - nxd_model.pt") + print(f" - weights/tp{{0,1,2,3}}_sharded_checkpoint.safetensors") + print(f" - config.json") + print("") + print("To use with V3 CP transformer:") + print(" python run_qwen_image_edit.py --use_v3_cp --use_v3_language_model") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Compile Language Model V3 using ModelBuilder API") + parser.add_argument("--model_path", type=str, default=None, + help="Path to model (local dir or HuggingFace ID). If not set, uses MODEL_ID with CACHE_DIR") + parser.add_argument("--max_sequence_length", type=int, default=1024, + help="Maximum sequence length for compilation") + parser.add_argument("--batch_size", type=int, default=1, + help="Batch size for language model (default: 1)") + parser.add_argument("--compiled_models_dir", type=str, default="/opt/dlami/nvme/compiled_models", + help="Directory to save compiled models") + parser.add_argument("--compiler_workdir", type=str, default="/opt/dlami/nvme/compiler_workdir", + help="Directory for compiler artifacts") + args = parser.parse_args() + + # Override MODEL_ID and CACHE_DIR if model_path is provided + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + compile_language_model_v3(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_text_encoder.py b/contrib/models/Qwen-Image-Edit/src/compile_text_encoder.py new file mode 100644 index 00000000..f55b9e1f --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_text_encoder.py @@ -0,0 +1,727 @@ +""" +Text Encoder Compilation for Qwen-Image-Edit-2509 + +The text encoder (Qwen2.5-VL) is a multimodal vision-language model with: +1. Vision Encoder (Qwen2_5_VisionTransformerPretrainedModel) - 32 blocks +2. Language Model (Qwen2_5_VLTextModel) - 28 layers + +This script compiles both components for Trainium2 using tensor parallelism. +""" + +import os +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" # For trn2 +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" # For trn2 + +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer --enable-fast-loading-neuron-binaries """ # --verbose=INFO +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import argparse +import torch_neuronx +import neuronx_distributed +from functools import partial +from torch import nn + +from diffusers import QwenImageEditPlusPipeline +from neuron_commons import attention_wrapper, f32Wrapper +from neuron_parallel_utils import ( + shard_qwen2_attention, shard_qwen2_mlp, + shard_vision_attention, shard_vision_mlp +) + +# Override SDPA +torch.nn.functional.scaled_dot_product_attention = attention_wrapper + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def load_pipeline(dtype=torch.bfloat16): + """Load pipeline with appropriate kwargs based on MODEL_ID and CACHE_DIR.""" + load_kwargs = {"torch_dtype": dtype, "local_files_only": True} + if CACHE_DIR: + load_kwargs["cache_dir"] = CACHE_DIR + return QwenImageEditPlusPipeline.from_pretrained(MODEL_ID, **load_kwargs) + + +class VisionEncoderWrapper(nn.Module): + """ + Wrapper for the Qwen2.5-VL Vision Encoder. + Compiles the vision transformer that processes image patches. + """ + def __init__(self, visual): + super().__init__() + self.visual = visual + + def forward(self, pixel_values, grid_thw): + """ + Args: + pixel_values: (num_patches, 3*temporal*patch_h*patch_w) - flattened patches + grid_thw: (num_images, 3) - temporal, height, width in grid space + Returns: + image_embeds: (total_patches, hidden_size) + """ + return self.visual(pixel_values, grid_thw) + + +class LanguageModelWrapper(nn.Module): + """ + Wrapper for the Qwen2.5-VL Language Model. + Processes the combined text and vision embeddings. + + IMPORTANT: Must accept position_ids for M-RoPE (Multimodal RoPE) to work correctly. + Qwen2.5-VL uses 3D position_ids with shape [3, batch, seq_len] for: + - t (temporal): frame index for video, 0 for images + - h (height): spatial row position for image tokens + - w (width): spatial column position for image tokens + """ + def __init__(self, language_model, embed_tokens): + super().__init__() + self.language_model = language_model + self.embed_tokens = embed_tokens + + def forward(self, inputs_embeds, attention_mask, position_ids): + """ + Args: + inputs_embeds: (batch, seq_len, hidden_size) - combined text+vision embeddings + attention_mask: (batch, seq_len) + position_ids: (3, batch, seq_len) - 3D position IDs for M-RoPE + Returns: + hidden_states: (batch, seq_len, hidden_size) + """ + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True, + return_dict=True + ) + return outputs.last_hidden_state + + +class FullTextEncoderWrapper(nn.Module): + """ + Full wrapper for the Qwen2.5-VL text encoder with fixed shapes. + This is used when compiling the complete text encoder for image editing. + + For simplicity in compilation, we use a fixed sequence length and image size. + """ + def __init__(self, text_encoder, max_seq_len, num_image_tokens): + super().__init__() + self.text_encoder = text_encoder + self.config = text_encoder.config + self.max_seq_len = max_seq_len + self.num_image_tokens = num_image_tokens + + def forward(self, input_ids, attention_mask, pixel_values, image_grid_thw): + """ + Fixed-shape forward pass for tracing. + + Args: + input_ids: (batch, text_seq_len) + attention_mask: (batch, total_seq_len) + pixel_values: (num_patches, channels) - preprocessed image patches + image_grid_thw: (num_images, 3) - grid dimensions + Returns: + hidden_states: (batch, total_seq_len, hidden_size) + """ + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + return_dict=True + ) + return outputs.hidden_states[-1] + + +def upcast_norms_to_f32(module): + """Upcast normalization layers to float32 for numerical stability.""" + for name, child in module.named_children(): + if isinstance(child, (torch.nn.LayerNorm,)): + setattr(module, name, f32Wrapper(child.to(torch.float32))) + # Handle RMSNorm (Qwen uses this) + elif 'RMSNorm' in child.__class__.__name__: + setattr(module, name, f32Wrapper(child.to(torch.float32))) + else: + upcast_norms_to_f32(child) + + +def get_vision_encoder(tp_degree: int): + """Load and prepare vision encoder for tracing.""" + pipe = load_pipeline(torch.bfloat16) + + visual = pipe.text_encoder.model.visual + visual.eval() + upcast_norms_to_f32(visual) + + return VisionEncoderWrapper(visual), {} + + +def get_language_model(tp_degree: int): + """Load and shard language model for tensor parallelism.""" + pipe = load_pipeline(torch.bfloat16) + + text_encoder = pipe.text_encoder + lang_model = text_encoder.model.language_model + embed_tokens = lang_model.embed_tokens + lang_model.eval() + + # Shard the language model layers + for layer in lang_model.layers: + if hasattr(layer, 'self_attn'): + layer.self_attn = shard_qwen2_attention(tp_degree, layer.self_attn) + if hasattr(layer, 'mlp'): + layer.mlp = shard_qwen2_mlp(layer.mlp) + + upcast_norms_to_f32(lang_model) + + return LanguageModelWrapper(lang_model, embed_tokens), {} + + +def compile_vision_encoder(args): + """ + Compile the Vision Encoder component (single device mode). + + The vision encoder processes image patches and outputs vision embeddings. + Input shape depends on image size and patch configuration. + + Note: For better memory distribution, use compile_vision_encoder_tp() with --vision_tp flag. + """ + batch_size = 1 + image_size = args.image_size + patch_size = 14 + temporal_patch_size = 2 + spatial_merge_size = 2 + + # Validate image_size + if image_size % patch_size != 0: + raise ValueError( + f"image_size ({image_size}) must be divisible by patch_size ({patch_size}). " + f"Valid sizes: 224, 336, 448, 560, etc.") + + num_patches_per_side = image_size // patch_size + if num_patches_per_side % spatial_merge_size != 0: + raise ValueError( + f"image_size / patch_size ({num_patches_per_side}) must be divisible by " + f"spatial_merge_size ({spatial_merge_size}). " + f"Valid image sizes: 224, 336, 448, 560, etc.") + + # Calculate number of patches for a single image + # Qwen2.5-VL uses Conv3d with kernel (temporal_patch_size, patch_size, patch_size) + # For a single frame: num_patches = (H/patch_size) * (W/patch_size) + num_patches_h = image_size // patch_size + num_patches_w = image_size // patch_size + num_patches = num_patches_h * num_patches_w + + # pixel_values shape for the vision encoder + # After preprocessing, it's (num_patches, 3 * temporal_patch_size * patch_size * patch_size) + channels_per_patch = 3 * temporal_patch_size * patch_size * patch_size # 3*2*14*14 = 1176 + + compiler_workdir = args.compiler_workdir + compiled_models_dir = args.compiled_models_dir + + # Always use float32 for vision encoder (required for accuracy) + dtype = torch.float32 + + print("=" * 50) + print("Compiling Vision Encoder (Single Device, float32)") + print("=" * 50) + print(f" Image size: {image_size}x{image_size}") + print(f" Patch size: {patch_size}") + print(f" Num patches: {num_patches}") + print(f" Channels per patch: {channels_per_patch}") + print(f" Dtype: float32 (required for accuracy)") + + pipe = load_pipeline(dtype) + + visual = pipe.text_encoder.model.visual + visual.eval() + + # Keep everything in float32 for maximum precision + + # Sample inputs + # pixel_values: (total_patches, patch_dim) + sample_pixel_values = torch.ones((num_patches, channels_per_patch), dtype=dtype) + # grid_thw: (num_images, 3) - temporal, height, width in grid units + sample_grid_thw = torch.tensor([[1, num_patches_h, num_patches_w]], dtype=torch.int64) + + vision_wrapper = VisionEncoderWrapper(visual) + + # Use --auto-cast=none to prevent precision loss + vision_compiler_flags = compiler_flags + " --auto-cast=none" + + with torch.no_grad(): + try: + compiled_vision = torch_neuronx.trace( + vision_wrapper, + (sample_pixel_values, sample_grid_thw), + compiler_workdir=f"{compiler_workdir}/vision_encoder", + compiler_args=vision_compiler_flags, + inline_weights_to_neff=False + ) + + # Save to vision_encoder/ directory + vision_dir = f"{compiled_models_dir}/vision_encoder" + if not os.path.exists(vision_dir): + os.makedirs(vision_dir) + torch.jit.save(compiled_vision, f"{vision_dir}/model.pt") + print(f"Vision encoder (float32) compiled and saved to {vision_dir}") + return True + + except Exception as e: + print(f"Vision encoder compilation failed: {e}") + return False + + +def get_vision_encoder_tp(tp_degree: int, image_size: int): + """Load and shard vision encoder for tensor parallelism.""" + pipe = load_pipeline(torch.bfloat16) + + visual = pipe.text_encoder.model.visual + visual.eval() + + # Shard the vision encoder blocks + for block in visual.blocks: + if hasattr(block, 'attn'): + block.attn = shard_vision_attention(tp_degree, block.attn) + if hasattr(block, 'mlp'): + block.mlp = shard_vision_mlp(block.mlp) + + upcast_norms_to_f32(visual) + + return VisionEncoderWrapper(visual), {} + + +def compile_vision_encoder_tp(args): + """ + Compile the Vision Encoder with tensor parallelism. + + NOTE: The Qwen2.5-VL vision encoder has dimensions that are NOT divisible by 8. + Specifically, the fused QKV projection has dimension 3420 (1140 * 3). + - 3420 / 8 = 427.5 (NOT divisible) + - 3420 / 4 = 855 (divisible) + - 3420 / 2 = 1710 (divisible) + + Since transformer and language model require TP=8, and mixing different TP degrees + causes world_size conflicts, vision encoder TP is NOT recommended. + + This function will attempt TP compilation but is expected to fail with TP=8. + Use single-device compilation (--vision_only without --vision_tp) instead. + """ + batch_size = 1 + image_size = args.image_size + patch_size = 14 + temporal_patch_size = 2 + spatial_merge_size = 2 + tp_degree = args.tp_degree + + # Check if vision encoder dimensions are compatible with TP degree + vision_embed_dim = 1140 # Qwen2.5-VL vision encoder embed_dim + qkv_dim = vision_embed_dim * 3 # 3420 + + if qkv_dim % tp_degree != 0: + print("=" * 60) + print("WARNING: Vision Encoder TP Compilation Not Supported") + print("=" * 60) + print(f" Vision encoder QKV dimension: {qkv_dim}") + print(f" Requested TP degree: {tp_degree}") + print(f" {qkv_dim} is NOT divisible by {tp_degree}") + print("") + print(" The Qwen2.5-VL vision encoder has dimensions incompatible with TP=8.") + print(" Falling back to single-device compilation...") + print("") + + # Fall back to single device compilation + return compile_vision_encoder(args) + + os.environ["LOCAL_WORLD_SIZE"] = str(tp_degree) + + # Validate image_size + if image_size % patch_size != 0: + raise ValueError( + f"image_size ({image_size}) must be divisible by patch_size ({patch_size}). " + f"Valid sizes: 224, 336, 448, 560, etc.") + + num_patches_per_side = image_size // patch_size + if num_patches_per_side % spatial_merge_size != 0: + raise ValueError( + f"image_size / patch_size ({num_patches_per_side}) must be divisible by " + f"spatial_merge_size ({spatial_merge_size}). " + f"Valid image sizes: 224, 336, 448, 560, etc.") + + num_patches_h = image_size // patch_size + num_patches_w = image_size // patch_size + num_patches = num_patches_h * num_patches_w + + channels_per_patch = 3 * temporal_patch_size * patch_size * patch_size # 1176 + + compiler_workdir = args.compiler_workdir + compiled_models_dir = args.compiled_models_dir + dtype = torch.bfloat16 + + print("=" * 50) + print("Compiling Vision Encoder with Tensor Parallelism") + print("=" * 50) + print(f" Image size: {image_size}x{image_size}") + print(f" Patch size: {patch_size}") + print(f" Num patches: {num_patches}") + print(f" Channels per patch: {channels_per_patch}") + print(f" TP degree: {tp_degree}") + + get_vision_f = partial(get_vision_encoder_tp, tp_degree, image_size) + + # Sample inputs + sample_pixel_values = torch.ones((num_patches, channels_per_patch), dtype=dtype) + sample_grid_thw = torch.tensor([[1, num_patches_h, num_patches_w]], dtype=torch.int64) + + sample_inputs = (sample_pixel_values, sample_grid_thw) + + with torch.no_grad(): + try: + compiled_vision = neuronx_distributed.trace.parallel_model_trace( + get_vision_f, + sample_inputs, + compiler_workdir=f"{compiler_workdir}/vision_encoder_tp", + compiler_args=compiler_flags, + tp_degree=tp_degree, + inline_weights_to_neff=False + ) + + vision_dir = f"{compiled_models_dir}/vision_encoder_tp" + if not os.path.exists(vision_dir): + os.makedirs(vision_dir) + + neuronx_distributed.trace.parallel_model_save( + compiled_vision, vision_dir) + print(f"Vision encoder (TP={tp_degree}) compiled and saved to {vision_dir}") + return True + + except Exception as e: + print(f"Vision encoder TP compilation failed: {e}") + print("Falling back to single-device compilation...") + return compile_vision_encoder(args) + + +def compile_language_model(args): + """ + Compile the Language Model component with tensor parallelism. + + The language model processes text tokens combined with vision embeddings. + + Qwen2.5-VL-7B GQA configuration: + - 28 Q heads, 4 KV heads -> each KV head shared by 7 Q heads + + Supported TP degrees: + - TP=4: Standard sharding (7 Q heads, 1 KV head per rank) + - TP=8: KV replication mode (Q padded to 32 -> 4 per rank, KV replicated -> 1 per rank) + + The KV replication logic in shard_qwen2_attention handles TP=8 correctly by: + 1. Padding Q heads from 28 to 32 (divisible by 8) + 2. Replicating each KV head to pairs of ranks + 3. Updating num_key_value_groups to 4 (4 Q heads / 1 KV head per rank) + """ + batch_size = 1 + sequence_length = args.max_sequence_length + hidden_size = 3584 # Qwen2.5-VL hidden size + + # Use language-specific TP degree + tp_degree = getattr(args, 'language_tp_degree', 8) + + # Validate TP degree + num_kv_heads = 4 + if tp_degree > num_kv_heads and tp_degree % num_kv_heads != 0: + raise ValueError( + f"For TP={tp_degree} > num_kv_heads={num_kv_heads}, " + f"tp_degree must be divisible by num_kv_heads. " + f"Valid TP degrees: 1, 2, 4, 8" + ) + + if tp_degree == 8: + print("=" * 60) + print("INFO: Using KV Head Replication Mode (TP=8)") + print("=" * 60) + print(f" Q heads: 28 -> padded to 32 -> 4 per rank") + print(f" KV heads: 4 -> replicated -> 1 per rank") + print(f" num_key_value_groups: 4 (Q_per_rank / KV_per_rank)") + print("=" * 60) + + os.environ["LOCAL_WORLD_SIZE"] = str(tp_degree) + + compiler_workdir = args.compiler_workdir + compiled_models_dir = args.compiled_models_dir + + print("=" * 50) + print("Compiling Language Model") + print("=" * 50) + print(f" Sequence length: {sequence_length}") + print(f" Hidden size: {hidden_size}") + print(f" TP degree: {tp_degree}") + + get_lang_model_f = partial(get_language_model, tp_degree) + + with torch.no_grad(): + # inputs_embeds: (batch, seq_len, hidden_size) + sample_inputs_embeds = torch.ones( + (batch_size, sequence_length, hidden_size), dtype=torch.bfloat16) + # attention_mask: (batch, seq_len) + sample_attention_mask = torch.ones( + (batch_size, sequence_length), dtype=torch.int64) + # position_ids: (3, batch, seq_len) - 3D for M-RoPE + # For tracing, use simple sequential positions (text-only pattern) + sample_position_ids = torch.arange(sequence_length).view(1, 1, -1).expand(3, batch_size, -1).clone() + + sample_inputs = (sample_inputs_embeds, sample_attention_mask, sample_position_ids) + + try: + compiled_lang_model = neuronx_distributed.trace.parallel_model_trace( + get_lang_model_f, + sample_inputs, + compiler_workdir=f"{compiler_workdir}/language_model", + compiler_args=compiler_flags, + tp_degree=tp_degree, + inline_weights_to_neff=False + ) + + lang_model_dir = f"{compiled_models_dir}/language_model" + if not os.path.exists(lang_model_dir): + os.makedirs(lang_model_dir) + + neuronx_distributed.trace.parallel_model_save( + compiled_lang_model, lang_model_dir) + print(f"Language model compiled and saved to {lang_model_dir}") + return True + + except Exception as e: + print(f"Language model compilation failed: {e}") + return False + + +def compile_text_encoder_full(args): + """ + Compile the full text encoder (vision + language) with fixed shapes. + This is more complex but allows end-to-end compilation. + """ + batch_size = 1 + text_seq_len = args.max_sequence_length + image_size = args.image_size + patch_size = 14 + spatial_merge_size = 2 # Qwen2.5-VL spatial merge + + # Calculate image token count after spatial merge + num_patches_h = image_size // patch_size + num_patches_w = image_size // patch_size + merged_h = num_patches_h // spatial_merge_size + merged_w = num_patches_w // spatial_merge_size + num_image_tokens = merged_h * merged_w + + total_seq_len = text_seq_len + num_image_tokens + tp_degree = args.tp_degree # Use configurable TP degree (default=8) + + os.environ["LOCAL_WORLD_SIZE"] = str(tp_degree) + + compiler_workdir = args.compiler_workdir + compiled_models_dir = args.compiled_models_dir + + print("=" * 50) + print("Compiling Full Text Encoder") + print("=" * 50) + print(f" Image size: {image_size}") + print(f" Text sequence length: {text_seq_len}") + print(f" Image tokens: {num_image_tokens}") + print(f" Total sequence length: {total_seq_len}") + print(f" TP degree: {tp_degree}") + + def get_full_text_encoder(tp_degree): + pipe = load_pipeline(torch.bfloat16) + + text_encoder = pipe.text_encoder + text_encoder.eval() + + # Shard language model + lang_model = text_encoder.model.language_model + for layer in lang_model.layers: + if hasattr(layer, 'self_attn'): + layer.self_attn = shard_qwen2_attention(tp_degree, layer.self_attn) + if hasattr(layer, 'mlp'): + layer.mlp = shard_qwen2_mlp(layer.mlp) + + upcast_norms_to_f32(text_encoder) + + return FullTextEncoderWrapper(text_encoder, total_seq_len, num_image_tokens), {} + + get_encoder_f = partial(get_full_text_encoder, tp_degree) + + # Calculate pixel_values shape + num_patches = num_patches_h * num_patches_w + channels_per_patch = 3 * 2 * patch_size * patch_size # 1176 + + with torch.no_grad(): + sample_inputs = ( + torch.ones((batch_size, text_seq_len), dtype=torch.int64), + torch.ones((batch_size, total_seq_len), dtype=torch.int64), + torch.ones((num_patches, channels_per_patch), dtype=torch.bfloat16), + torch.tensor([[1, num_patches_h, num_patches_w]], dtype=torch.int64), + ) + + try: + compiled_encoder = neuronx_distributed.trace.parallel_model_trace( + get_encoder_f, + sample_inputs, + compiler_workdir=f"{compiler_workdir}/text_encoder", + compiler_args=compiler_flags, + tp_degree=tp_degree, + inline_weights_to_neff=False + ) + + encoder_dir = f"{compiled_models_dir}/text_encoder" + if not os.path.exists(encoder_dir): + os.makedirs(encoder_dir) + + neuronx_distributed.trace.parallel_model_save( + compiled_encoder, encoder_dir) + print(f"Full text encoder compiled and saved to {encoder_dir}") + return True + + except Exception as e: + print(f"Full text encoder compilation failed: {e}") + print("Try compiling vision encoder and language model separately.") + return False + + +def run_in_subprocess(func_name, args, vision_tp=False): + """Run a compilation function in a separate subprocess to avoid XLA conflicts.""" + import subprocess + import sys + + cmd = [ + sys.executable, __file__, + "--mode", "separate", + "--image_size", str(args.image_size), + "--max_sequence_length", str(args.max_sequence_length), + "--compiler_workdir", args.compiler_workdir, + "--compiled_models_dir", args.compiled_models_dir, + "--tp_degree", str(args.tp_degree), + "--language_tp_degree", str(getattr(args, 'language_tp_degree', 4)), + ] + + # Pass model_path if set + if getattr(args, 'model_path', None): + cmd.extend(["--model_path", args.model_path]) + + if func_name == "vision": + cmd.append("--vision_only") + if vision_tp: + cmd.append("--vision_tp") + elif func_name == "language": + cmd.append("--language_only") + + result = subprocess.run(cmd, capture_output=False) + return result.returncode == 0 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--mode", type=str, default="separate", + choices=["separate", "full"], + help="Compilation mode: 'separate' compiles vision and language separately, " + "'full' compiles the entire text encoder together") + parser.add_argument("--max_sequence_length", type=int, default=512, + help="Max text sequence length") + parser.add_argument("--image_size", type=int, default=224, + help="Image size for vision encoder. Must be divisible by 14 (patch_size) " + "and result in even grid for spatial merge. Valid: 224, 336, 448, 560") + parser.add_argument("--compiler_workdir", type=str, default="compiler_workdir", + help="Directory for compiler artifacts") + parser.add_argument("--compiled_models_dir", type=str, default="compiled_models", + help="Directory for compiled models") + parser.add_argument("--vision_only", action="store_true", + help="Only compile vision encoder") + parser.add_argument("--vision_tp", action="store_true", + help="Compile vision encoder with tensor parallelism (TP=8) instead of single device. " + "Helps reduce per-device memory usage.") + parser.add_argument("--language_only", action="store_true", + help="Only compile language model") + parser.add_argument("--use_subprocess", action="store_true", + help="Run each compilation in separate subprocess (avoids XLA conflicts)") + parser.add_argument("--tp_degree", type=int, default=8, + help="Tensor parallel degree for vision encoder TP mode (default=8)") + parser.add_argument("--language_tp_degree", type=int, default=8, + help="Tensor parallel degree for language model. " + "TP=4: Standard sharding. TP=8: KV head replication mode. " + "Default=8 to match transformer TP degree.") + parser.add_argument("--model_path", type=str, default=None, + help="Path to model (local dir or HuggingFace ID). If not set, uses MODEL_ID with CACHE_DIR") + # Note: Vision encoder is always compiled in float32 for accuracy (required) + args = parser.parse_args() + + # Override MODEL_ID and CACHE_DIR if model_path is provided + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + if args.mode == "separate": + # If specific component requested, run directly + if args.vision_only: + if args.vision_tp: + print("\n[Vision Only] Compiling Vision Encoder with TP...") + compile_vision_encoder_tp(args) + else: + print("\n[Vision Only] Compiling Vision Encoder (single device)...") + compile_vision_encoder(args) + elif args.language_only: + print("\n[Language Only] Compiling Language Model...") + compile_language_model(args) + elif args.use_subprocess: + # Run in separate subprocesses to avoid XLA initialization conflicts + if args.vision_tp: + print("\n[Step 1] Compiling Vision Encoder with TP (subprocess)...") + else: + print("\n[Step 1] Compiling Vision Encoder (subprocess)...") + vision_success = run_in_subprocess("vision", args, vision_tp=args.vision_tp) + + print("\n[Step 2] Compiling Language Model (subprocess)...") + lang_success = run_in_subprocess("language", args) + + if vision_success and lang_success: + print("\n" + "=" * 50) + print("Text Encoder Compilation Complete!") + print("=" * 50) + if args.vision_tp: + print(" Vision Encoder: TP={} (saved to vision_encoder_tp/)".format(args.tp_degree)) + else: + print(" Vision Encoder: Single device (saved to vision_encoder/)") + print(" Language Model: TP={} (saved to language_model/)".format(args.language_tp_degree)) + else: + # Default: try sequential but warn about XLA issue + print("\nNOTE: If language model compilation fails with 'Runtime is already initialized',") + print(" run with --use_subprocess flag or compile separately:") + print(" python compile_text_encoder.py --vision_only [--vision_tp]") + print(" python compile_text_encoder.py --language_only") + print("") + + if args.vision_tp: + print("\n[Step 1] Compiling Vision Encoder with TP...") + vision_success = compile_vision_encoder_tp(args) + else: + print("\n[Step 1] Compiling Vision Encoder...") + vision_success = compile_vision_encoder(args) + + print("\n[Step 2] Compiling Language Model...") + lang_success = compile_language_model(args) + + if vision_success and lang_success: + print("\n" + "=" * 50) + print("Text Encoder Compilation Complete!") + print("=" * 50) + if args.vision_tp: + print(" Vision Encoder: TP={} (saved to vision_encoder_tp/)".format(args.tp_degree)) + else: + print(" Vision Encoder: Single device (saved to vision_encoder/)") + print(" Language Model: TP={} (saved to language_model/)".format(args.language_tp_degree)) + else: + compile_text_encoder_full(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_transformer.py b/contrib/models/Qwen-Image-Edit/src/compile_transformer.py new file mode 100644 index 00000000..0fa8c6c4 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_transformer.py @@ -0,0 +1,218 @@ +import os +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" # For trn2 +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" # For trn2 + +# Compiler flags optimized for transformer models (based on Flux reference) +# Key optimizations: +# - --model-type=transformer: Enables transformer-specific optimizations +# - --enable-ccop-compute-overlap: Overlaps communication with computation +# - --auto-cast=none: Preserves bfloat16 precision +# - -O1: Basic optimization level (O2 can cause issues with some models) +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer -O1 --auto-cast=none --enable-fast-loading-neuron-binaries --tensorizer-options='--enable-ccop-compute-overlap' --internal-hlo2tensorizer-options='--fuse-dot-logistic=false' """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import argparse +import neuronx_distributed +from functools import partial +from torch import nn + +from diffusers import QwenImageEditPlusPipeline +from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel + +from neuron_commons import neuron_scaled_dot_product_attention +from neuron_parallel_utils import shard_qwen_attention, shard_feedforward, shard_modulation +from neuron_rope import patch_qwenimage_rope + +# Override SDPA globally for Neuron compatibility during compilation +# NOTE: NKI Flash Attention kernel doesn't work with parallel_model_trace (XLA tracing limitation) +# Using basic attention implementation instead +print("Using Neuron-compatible SDPA for compilation") +torch.nn.functional.scaled_dot_product_attention = neuron_scaled_dot_product_attention + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +class TracingTransformerWrapper(nn.Module): + """Wrapper for tracing the transformer model.""" + def __init__(self, transformer: QwenImageTransformer2DModel, img_shapes): + super().__init__() + self.transformer = transformer + self.config = transformer.config + self.dtype = transformer.dtype + self.device = transformer.device + # Store img_shapes as a fixed attribute for tracing + self.img_shapes = img_shapes + + def forward(self, hidden_states, encoder_hidden_states, timestep): + """ + Forward pass matching QwenImageTransformer2DModel signature. + + Args: + hidden_states: (batch, num_patches, in_channels) - patchified latents + encoder_hidden_states: (batch, text_seq_len, text_hidden_dim) - text embeddings + timestep: (batch,) - diffusion timestep + """ + return self.transformer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + img_shapes=self.img_shapes, + return_dict=False) + + +def get_transformer_model(tp_degree: int, img_shapes: list): + """Load and shard the transformer model for tensor parallelism.""" + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + local_files_only=True, + cache_dir=CACHE_DIR) + + # Patch RoPE to use Neuron-compatible implementation (no complex numbers) + print("Patching RoPE for Neuron compatibility...") + pipe.transformer = patch_qwenimage_rope(pipe.transformer) + + num_blocks = len(pipe.transformer.transformer_blocks) + print(f"Sharding {num_blocks} transformer blocks with TP={tp_degree}") + + # Shard transformer blocks + for block_idx, block in enumerate(pipe.transformer.transformer_blocks): + if block_idx == 0: + print(f"Block 0 attention heads: {block.attn.heads}") + print(f"Block 0 to_q shape: {block.attn.to_q.weight.shape}") + print(f"Block 0 img_mod shape: {block.img_mod[1].weight.shape}") + + # Shard attention + block.attn = shard_qwen_attention(tp_degree, block.attn) + + if block_idx == 0: + print(f"After sharding - Block 0 attention heads: {block.attn.heads}") + + # Shard feedforward (img_mlp and txt_mlp) + block.img_mlp = shard_feedforward(block.img_mlp) + block.txt_mlp = shard_feedforward(block.txt_mlp) + + # Shard modulation layers (img_mod and txt_mod) - THIS WAS MISSING! + # These account for 6.8B params that were duplicated on every rank! + block.img_mod = shard_modulation(block.img_mod) + block.txt_mod = shard_modulation(block.txt_mod) + + if block_idx == 0: + print(f"After sharding - Block 0 img_mod shape: {block.img_mod[1].weight.shape}") + + if (block_idx + 1) % 10 == 0: + print(f" Processed {block_idx + 1}/{num_blocks} blocks") + + print(f"All {num_blocks} blocks sharded successfully") + + transformer_wrapper = TracingTransformerWrapper(pipe.transformer, img_shapes) + return transformer_wrapper, {} + + +def compile_transformer(args): + tp_degree = args.tp_degree # Tensor parallel degree + os.environ["LOCAL_WORLD_SIZE"] = str(tp_degree) + + latent_height = args.height // 8 + latent_width = args.width // 8 + max_sequence_length = args.max_sequence_length + text_hidden_size = 3584 # Text encoder hidden size + in_channels = 64 # QwenImage transformer in_channels + patch_size = 2 # QwenImage patch size + + # For IMAGE EDITING, the pipeline concatenates source image latents with noise latents. + # This is handled by increasing temporal_frames to match patch_multiplier. + # - patch_multiplier=1 (generation): temporal_frames=1, patches = 1 * 32 * 32 = 1024 + # - patch_multiplier=2 (editing): temporal_frames=2, patches = 2 * 32 * 32 = 2048 + temporal_frames = args.patch_multiplier + + # Calculate number of patches + # QwenImage uses patch_size=2, so num_patches = T * (H/8/2) * (W/8/2) + patch_h = latent_height // patch_size + patch_w = latent_width // patch_size + num_patches = temporal_frames * patch_h * patch_w + + if args.patch_multiplier > 1: + print(f" NOTE: Image editing mode with patch_multiplier={args.patch_multiplier}") + print(f" Using temporal_frames={temporal_frames} to generate RoPE for {num_patches} patches") + + # img_shapes: List of (frame, height, width) for each batch item + # Note: height/width here are in patch space (latent_h // patch_size) + # temporal_frames is set to patch_multiplier to match the concatenated patches + img_shapes = [(temporal_frames, patch_h, patch_w)] * args.batch_size + + compiler_workdir = args.compiler_workdir + compiled_models_dir = args.compiled_models_dir + batch_size = args.batch_size # Always 1, CFG runs transformer twice sequentially + + print(f"Compiling transformer with:") + print(f" Image size: {args.height}x{args.width}") + print(f" Latent size: {latent_height}x{latent_width}") + print(f" Patch size: {patch_size}") + print(f" Num patches: {num_patches}") + print(f" Text sequence length: {max_sequence_length}") + print(f" Batch size: {batch_size}") + print(f" img_shapes: {img_shapes}") + + # Sample inputs matching transformer wrapper forward signature + # hidden_states: (batch, num_patches, in_channels) + sample_hidden_states = torch.ones( + (batch_size, num_patches, in_channels), dtype=torch.bfloat16) + # encoder_hidden_states: (batch, text_seq_len, text_hidden_size) + sample_encoder_hidden_states = torch.ones( + (batch_size, max_sequence_length, text_hidden_size), dtype=torch.bfloat16) + # timestep: (batch,) + sample_timestep = torch.ones((batch_size,), dtype=torch.float32) + + get_transformer_f = partial(get_transformer_model, tp_degree, img_shapes) + + with torch.no_grad(): + sample_inputs = ( + sample_hidden_states, + sample_encoder_hidden_states, + sample_timestep, + ) + + compiled_transformer = neuronx_distributed.trace.parallel_model_trace( + get_transformer_f, + sample_inputs, + compiler_workdir=f"{compiler_workdir}/transformer", + compiler_args=compiler_flags, + tp_degree=tp_degree, + inline_weights_to_neff=False, + ) + + compiled_model_dir = f"{compiled_models_dir}/transformer" + if not os.path.exists(compiled_model_dir): + os.makedirs(compiled_model_dir) + + neuronx_distributed.trace.parallel_model_save( + compiled_transformer, compiled_model_dir) + print(f"Transformer compiled and saved to {compiled_model_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--height", type=int, default=512, + help="Height of generated image") + parser.add_argument("--width", type=int, default=512, + help="Width of generated image") + parser.add_argument("--max_sequence_length", type=int, default=512, + help="Max sequence length for text encoder") + parser.add_argument("--batch_size", type=int, default=1, + help="Batch size (always 1, CFG runs transformer twice sequentially)") + parser.add_argument("--tp_degree", type=int, default=8, + help="Tensor parallel degree (8 to match language model)") + parser.add_argument("--patch_multiplier", type=int, default=2, + help="Patch multiplier for image editing (2 for src+noise concat, 1 for generation)") + parser.add_argument("--compiler_workdir", type=str, default="compiler_workdir", + help="Directory for compiler artifacts") + parser.add_argument("--compiled_models_dir", type=str, default="compiled_models", + help="Directory for compiled models") + args = parser.parse_args() + compile_transformer(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_transformer_v1_flash.py b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v1_flash.py new file mode 100644 index 00000000..dc8b562c --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v1_flash.py @@ -0,0 +1,626 @@ +""" +Transformer compilation using parallel_model_trace (V1 API) with NKI Flash Attention. + +Key approach: +1. Uses parallel_model_trace API (supports NKI Flash Attention) +2. RoPE frequencies computed OUTSIDE the model and passed as INPUT tensors (like V2) +3. Uses NKI Flash Attention kernel for better performance + +This combines V1's NKI support with V2's RoPE handling to get the best of both. +""" + +import os +import json +import math + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +# CRITICAL: Disable XLA functionalization to allow NKI kernel in-place operations +# Functionalization converts in-place ops to out-of-place, which breaks NKI kernels +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +# Compiler flags optimized for transformer +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer -O1 --auto-cast=none --enable-fast-loading-neuron-binaries --tensorizer-options='--enable-ccop-compute-overlap' """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import torch.nn as nn +import torch.nn.functional as F +import argparse +import neuronx_distributed +from functools import partial +from typing import Optional, Tuple + +from diffusers import QwenImageEditPlusPipeline + +from neuron_parallel_utils import ( + shard_qwen_attention, + shard_feedforward, + shard_modulation, +) + +# Import NKI Flash Attention - use EXACTLY the same imports as Flux +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronxcc.nki.language import nc +from torch_neuronx.xla_impl.ops import nki_jit # Same as Flux + +# Create NKI callable - EXACTLY like Flux does +_flash_fwd_call = nki_jit()(attention_isa_kernel) + +NKI_AVAILABLE = True +print("NKI Flash Attention kernel loaded successfully") + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def nki_flash_attention(query, key, value): + """ + NKI Flash Attention wrapper for QwenImage. + + Args: + query: [B, H, S, D] - query tensor + key: [B, H, S, D] - key tensor + value: [B, H, S, D] - value tensor + + Returns: + attention output [B, H, S, D] + """ + bs, n_head, q_len, d_head = query.shape + k_len = key.shape[2] + v_len = value.shape[2] + + # Reshape for NKI kernel: [B*H, D, S] for Q/K, [B*H, S, D] for V + q = query.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) + k = key.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, k_len)) + v = value.clone().reshape((bs * n_head, v_len, d_head)) + + # Pre-allocate output + attn_output = torch.zeros((bs * n_head, q_len, d_head), dtype=torch.bfloat16, device=q.device) + + scale = 1 / math.sqrt(d_head) + + # Use sharded kernel for VC_SIZE=2 + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + if vc_size == 2: + grid = (nc(2),) + _flash_fwd_call[grid]( + q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap" + ) + else: + _flash_fwd_call(q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + + # Reshape back to [B, H, S, D] + attn_output = attn_output.reshape((bs, n_head, q_len, d_head)) + + return attn_output + + +class NKIQwenAttention(nn.Module): + """ + Custom attention module for QwenImage that uses NKI Flash Attention directly. + + This completely replaces diffusers' Attention class, similar to how Flux + uses NeuronFluxAttention. This avoids the XLA tracing issues with diffusers' + Attention.forward() method. + + Key design choices (matching Flux): + 1. Transpose Q, K, V to [B, H, S, D] format BEFORE attention + 2. Call NKI attention wrapper with [B, H, S, D] inputs (exactly like Flux) + 3. Transpose back after attention + """ + + def __init__(self, orig_attn): + """ + Initialize from an existing sharded attention module. + + Args: + orig_attn: The sharded diffusers Attention module + """ + super().__init__() + + # Copy all the layers from the original attention + self.heads = orig_attn.heads + self.to_q = orig_attn.to_q + self.to_k = orig_attn.to_k + self.to_v = orig_attn.to_v + self.to_out = orig_attn.to_out + + # Text projections + self.add_q_proj = orig_attn.add_q_proj if hasattr(orig_attn, 'add_q_proj') else None + self.add_k_proj = orig_attn.add_k_proj if hasattr(orig_attn, 'add_k_proj') else None + self.add_v_proj = orig_attn.add_v_proj if hasattr(orig_attn, 'add_v_proj') else None + self.to_add_out = orig_attn.to_add_out if hasattr(orig_attn, 'to_add_out') else None + + # Norms + self.norm_q = orig_attn.norm_q if hasattr(orig_attn, 'norm_q') else None + self.norm_k = orig_attn.norm_k if hasattr(orig_attn, 'norm_k') else None + self.norm_added_q = orig_attn.norm_added_q if hasattr(orig_attn, 'norm_added_q') else None + self.norm_added_k = orig_attn.norm_added_k if hasattr(orig_attn, 'norm_added_k') else None + + def forward( + self, + hidden_states: torch.Tensor, # Image stream [B, S_img, C] + encoder_hidden_states: torch.Tensor = None, # Text stream [B, S_txt, C] + encoder_hidden_states_mask: torch.Tensor = None, + image_rotary_emb: Tuple = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass with NKI Flash Attention - directly calls the kernel. + Follows Flux's pattern: transpose to [B, H, S, D] before attention. + """ + if encoder_hidden_states is None: + raise ValueError("NKIQwenAttention requires encoder_hidden_states") + + batch_size = hidden_states.shape[0] + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream + img_query = self.to_q(hidden_states) + img_key = self.to_k(hidden_states) + img_value = self.to_v(hidden_states) + + # Compute QKV for text stream + txt_query = self.add_q_proj(encoder_hidden_states) + txt_key = self.add_k_proj(encoder_hidden_states) + txt_value = self.add_v_proj(encoder_hidden_states) + + # Get head dimension + inner_dim = img_query.shape[-1] + head_dim = inner_dim // self.heads + + # Reshape to [B, S, H, D] then transpose to [B, H, S, D] - exactly like Flux + img_query = img_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_key = img_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_value = img_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + txt_query = txt_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_key = txt_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_value = txt_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Apply QK normalization (Flux does this after reshape too) + if self.norm_q is not None: + img_query = self.norm_q(img_query) + if self.norm_k is not None: + img_key = self.norm_k(img_key) + if self.norm_added_q is not None: + txt_query = self.norm_added_q(txt_query) + if self.norm_added_k is not None: + txt_key = self.norm_added_k(txt_key) + + # Apply RoPE - note: input is now [B, H, S, D] + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + # Transpose to [B, S, H, D] for RoPE, then back to [B, H, S, D] + img_query = apply_rotary_emb_precomputed(img_query.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + img_key = apply_rotary_emb_precomputed(img_key.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + txt_query = apply_rotary_emb_precomputed(txt_query.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + txt_key = apply_rotary_emb_precomputed(txt_key.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + + # Concatenate for joint attention along sequence dim: [B, H, S_txt + S_img, D] + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + # Use NKI Flash Attention - input is [B, H, S, D] exactly like Flux + joint_hidden_states = nki_flash_attention(joint_query, joint_key, joint_value) + + # Transpose back and reshape: [B, H, S, D] -> [B, S, H*D] + joint_hidden_states = joint_hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = self.to_out[0](img_attn_output) + if len(self.to_out) > 1: + img_attn_output = self.to_out[1](img_attn_output) # dropout + + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +def replace_attention_with_nki(transformer): + """ + Replace all attention modules with NKI versions. + + This completely replaces diffusers' Attention class with our custom + NKIQwenAttention class, similar to how Flux uses NeuronFluxAttention. + """ + for i, block in enumerate(transformer.transformer_blocks): + # Replace the attention module entirely + block.attn = NKIQwenAttention(block.attn) + + print(f"Replaced attention modules with NKI versions on {len(transformer.transformer_blocks)} blocks") + + +def apply_rotary_emb_precomputed( + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """ + Apply rotary embeddings using PRE-COMPUTED cos/sin tensors. + + Handles BOTH use_real=True and use_real=False cases: + - use_real=False (QwenImage default): Complex multiplication simulation + - use_real=True: Standard cos/sin rotation + + Args: + x: [B, S, H, D] - input tensor, D = head_dim = 128 + freqs_cis: Tuple of (cos, sin), each [S, D/2] - NOT interleaved (D/2 = 64) + + Returns: + Rotated tensor [B, S, H, D] + """ + cos, sin = freqs_cis # Each [S, 64] + + # Move to same device as x + cos = cos.to(x.device) + sin = sin.to(x.device) + + if not use_real: + # QwenImage uses use_real=False (complex multiplication) + # Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i + + # Reshape x to [B, S, H, D/2, 2] then split into real/imag + x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [B, S, H, 64, 2] + x_real = x_reshaped[..., 0] # [B, S, H, 64] + x_imag = x_reshaped[..., 1] # [B, S, H, 64] + + # Expand cos/sin for broadcasting: [S, 64] -> [1, S, 1, 64] + cos = cos.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + sin = sin.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + + # Complex multiplication: (x_real + i*x_imag) * (cos + i*sin) + out_real = x_real * cos - x_imag * sin # [B, S, H, 64] + out_imag = x_real * sin + x_imag * cos # [B, S, H, 64] + + # Stack and flatten back to [B, S, H, 128] + out = torch.stack([out_real, out_imag], dim=-1) # [B, S, H, 64, 2] + out = out.flatten(-2) # [B, S, H, 128] + + return out.to(x.dtype) + else: + # use_real=True path (standard rotation) + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + + if use_real_unbind_dim == -1: + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + else: + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +# Patch apply_rotary_emb_qwen to use our pre-computed version +import diffusers.models.transformers.transformer_qwenimage as qwen_module +qwen_module.apply_rotary_emb_qwen = apply_rotary_emb_precomputed +print("Patched apply_rotary_emb_qwen for pre-computed RoPE") + + +class NeuronQwenTransformerV1Flash(nn.Module): + """ + Neuron-optimized QwenImage Transformer for V1 Flash. + + Key features: + - Uses parallel_model_trace API (supports NKI Flash Attention) + - RoPE frequencies are passed as INPUT, not computed internally + - Uses NKI Flash Attention for better performance + """ + + def __init__(self, original_transformer, tp_degree): + super().__init__() + + self.config = original_transformer.config + self.in_channels = original_transformer.config.in_channels + self.out_channels = original_transformer.config.out_channels + self.patch_size = original_transformer.config.patch_size + + # Input projections (keep original) + self.img_in = original_transformer.img_in + self.txt_in = original_transformer.txt_in + + # Time/text embedding (keep original) + self.time_text_embed = original_transformer.time_text_embed + + # Text norm (keep original) + self.txt_norm = original_transformer.txt_norm + + # NOTE: We do NOT copy pos_embed (RoPE) - it will be passed as input! + + # Transformer blocks (need to shard) + self.transformer_blocks = nn.ModuleList() + for i, block in enumerate(original_transformer.transformer_blocks): + # Shard attention + block.attn = shard_qwen_attention(tp_degree, block.attn) + # Shard MLPs + block.img_mlp = shard_feedforward(block.img_mlp) + block.txt_mlp = shard_feedforward(block.txt_mlp) + # Shard modulation + block.img_mod = shard_modulation(block.img_mod) + block.txt_mod = shard_modulation(block.txt_mod) + self.transformer_blocks.append(block) + + if (i + 1) % 10 == 0: + print(f" Sharded block {i+1}/{len(original_transformer.transformer_blocks)}") + + # Final layers (keep original) + self.norm_out = original_transformer.norm_out + self.proj_out = original_transformer.proj_out + + # Store head_dim for RoPE + self.head_dim = 128 # QwenImage uses 128 + self.num_heads = original_transformer.transformer_blocks[0].attn.heads + + # Replace attention modules with NKI versions + # This completely replaces diffusers' Attention class with our custom class + # that directly calls NKI kernel, similar to how Flux does it + replace_attention_with_nki(self) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + img_rotary_emb: torch.Tensor, # [num_patches, 64, 2] + txt_rotary_emb: torch.Tensor, # [text_seq, 64, 2] + ) -> torch.Tensor: + """ + Forward pass with RoPE as INPUT. + """ + # Split RoPE into cos/sin + img_freqs_cos = img_rotary_emb[..., 0] # [num_patches, 64] + img_freqs_sin = img_rotary_emb[..., 1] + txt_freqs_cos = txt_rotary_emb[..., 0] # [text_seq, 64] + txt_freqs_sin = txt_rotary_emb[..., 1] + + # Image input projection + hidden_states = self.img_in(hidden_states) + + # Text processing + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + # Time embedding + timestep = timestep.to(hidden_states.dtype) + temb = self.time_text_embed(timestep, hidden_states) + + # Create rotary_emb tuple + image_rotary_emb = ((img_freqs_cos, img_freqs_sin), (txt_freqs_cos, txt_freqs_sin)) + + # Process through transformer blocks + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # Final norm and projection + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return output + + +class TracingWrapperV1Flash(nn.Module): + """Wrapper for parallel_model_trace tracing.""" + + def __init__(self, transformer): + super().__init__() + self.transformer = transformer + + def forward(self, hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb): + return self.transformer( + hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb + ) + + +def get_rope_from_original_model( + pipe, + frame: int, + height: int, + width: int, + text_seq_len: int, + dtype=torch.bfloat16, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get RoPE directly from the original QwenEmbedRope model. + """ + print(f" Getting RoPE from original model...") + print(f" video_fhw: ({frame}, {height}, {width}), text_seq_len: {text_seq_len}") + + video_fhw = (frame, height, width) + vid_freqs, txt_freqs = pipe.transformer.pos_embed( + video_fhw, txt_seq_lens=[text_seq_len], device=torch.device('cpu') + ) + + print(f" vid_freqs from model: {vid_freqs.shape}, dtype: {vid_freqs.dtype}") + print(f" txt_freqs from model: {txt_freqs.shape}, dtype: {txt_freqs.dtype}") + + # Convert complex to (cos, sin) + img_cos = vid_freqs.real.float() + img_sin = vid_freqs.imag.float() + txt_cos = txt_freqs.real.float() + txt_sin = txt_freqs.imag.float() + + # Stack to [S, 64, 2] + img_rotary_emb = torch.stack([img_cos, img_sin], dim=-1).to(dtype) + txt_rotary_emb = torch.stack([txt_cos, txt_sin], dim=-1).to(dtype) + + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + return img_rotary_emb, txt_rotary_emb + + +def get_transformer_model_v1_flash(tp_degree: int, img_rotary_emb: torch.Tensor, txt_rotary_emb: torch.Tensor): + """Load and create the transformer model for parallel_model_trace.""" + print("Loading pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + local_files_only=True, + cache_dir=CACHE_DIR + ) + + print("Creating Neuron transformer (sharding layers)...") + neuron_transformer = NeuronQwenTransformerV1Flash(pipe.transformer, tp_degree) + neuron_transformer = neuron_transformer.to(torch.bfloat16) + neuron_transformer.eval() + + # Wrap for tracing + model = TracingWrapperV1Flash(neuron_transformer) + + return model, {} + + +def compile_transformer_v1_flash(args): + """Compile transformer using parallel_model_trace with NKI Flash Attention.""" + + tp_degree = args.tp_degree + os.environ["LOCAL_WORLD_SIZE"] = str(tp_degree) + + # Calculate dimensions + latent_h = args.height // 8 + latent_w = args.width // 8 + patch_size = 2 + patch_h = latent_h // patch_size + patch_w = latent_w // patch_size + temporal_frames = args.patch_multiplier + num_patches = temporal_frames * patch_h * patch_w + + text_seq_len = args.max_sequence_length + text_hidden_size = 3584 + in_channels = 64 + head_dim = 128 + + print("=" * 60) + print("Transformer V1 Flash Compilation") + print("=" * 60) + print(f"Image: {args.height}x{args.width}") + print(f"Patches: {num_patches} ({temporal_frames}x{patch_h}x{patch_w})") + print(f"Text seq: {text_seq_len}") + print(f"TP degree: {tp_degree}") + print(f"NKI Flash Attention: Enabled") + + # First, load model to get RoPE + print("\nLoading model to get RoPE...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + local_files_only=True, + cache_dir=CACHE_DIR + ) + + # Get RoPE from original model + print("\nGetting RoPE from original model...") + img_rotary_emb, txt_rotary_emb = get_rope_from_original_model( + pipe=pipe, + frame=temporal_frames, + height=patch_h, + width=patch_w, + text_seq_len=text_seq_len, + ) + + # Clear the pipeline to free memory + del pipe + torch.cuda.empty_cache() if torch.cuda.is_available() else None + + # Sample inputs + sample_hidden_states = torch.randn(1, num_patches, in_channels, dtype=torch.bfloat16) + sample_encoder_hidden_states = torch.randn(1, text_seq_len, text_hidden_size, dtype=torch.bfloat16) + sample_timestep = torch.randn(1, dtype=torch.float32) + + get_transformer_f = partial(get_transformer_model_v1_flash, tp_degree, img_rotary_emb, txt_rotary_emb) + + with torch.no_grad(): + sample_inputs = ( + sample_hidden_states, + sample_encoder_hidden_states, + sample_timestep, + img_rotary_emb, + txt_rotary_emb, + ) + + print("\nTracing model with parallel_model_trace...") + compiled_transformer = neuronx_distributed.trace.parallel_model_trace( + get_transformer_f, + sample_inputs, + compiler_workdir=f"{args.compiler_workdir}/transformer_v1_flash", + compiler_args=compiler_flags, + tp_degree=tp_degree, + inline_weights_to_neff=False, + # Note: spmd_mode requires checkpoint_loader_callable, try without it first + ) + + # Save - use subdirectory for model files (parallel_model_load expects only .pt files) + output_path = f"{args.compiled_models_dir}/transformer_v1_flash" + model_path = f"{output_path}/model" + os.makedirs(model_path, exist_ok=True) + + print(f"\nSaving model to {model_path}...") + neuronx_distributed.trace.parallel_model_save( + compiled_transformer, model_path) + + # Save config in parent directory (not with model files) + config = { + "height": args.height, + "width": args.width, + "num_patches": num_patches, + "text_seq_len": text_seq_len, + "patch_multiplier": args.patch_multiplier, + "tp_degree": tp_degree, + "head_dim": head_dim, + "frame": temporal_frames, + "patch_h": patch_h, + "patch_w": patch_w, + "nki_flash_attention": True, + } + with open(os.path.join(output_path, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # Save pre-computed RoPE in parent directory + torch.save({ + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, os.path.join(output_path, "rope_cache.pt")) + + print("\nCompilation complete!") + print(f"Model saved to: {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--max_sequence_length", type=int, default=1024) + parser.add_argument("--patch_multiplier", type=int, default=3) + parser.add_argument("--tp_degree", type=int, default=8) + parser.add_argument("--compiled_models_dir", type=str, default="/opt/dlami/nvme/compiled_models") + parser.add_argument("--compiler_workdir", type=str, default="/opt/dlami/nvme/compiler_workdir") + args = parser.parse_args() + + compile_transformer_v1_flash(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_transformer_v2.py b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v2.py new file mode 100644 index 00000000..7a3b3767 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v2.py @@ -0,0 +1,476 @@ +""" +Transformer compilation using ModelBuilder (V2 API). + +Key approach: +1. RoPE frequencies computed OUTSIDE the model and passed as INPUT tensors +2. Model does NOT compute RoPE internally - avoids XLA constant-folding +3. Uses ModelBuilder for compilation + +This avoids the RoPE buffer constant-folding issue that broke previous V2 attempts. +Achieves ~2x speedup over V1 (parallel_model_trace) API. +""" + +import os +import json + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +# Compiler flags optimized for transformer +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer -O1 --auto-cast=none --enable-fast-loading-neuron-binaries --tensorizer-options='--enable-ccop-compute-overlap' """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import argparse +from typing import Optional, Tuple, List + +from diffusers import QwenImageEditPlusPipeline + +# ModelBuilder imports +from neuronx_distributed import ModelBuilder, NxDParallelState, shard_checkpoint +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers import parallel_state +from safetensors.torch import save_file + +from neuron_parallel_utils import ( + shard_qwen_attention, + shard_feedforward, + shard_modulation, + get_sharded_data, +) +from neuron_commons import neuron_scaled_dot_product_attention + +# Override SDPA for Neuron compatibility +print("Overriding SDPA for Neuron compatibility") +torch.nn.functional.scaled_dot_product_attention = neuron_scaled_dot_product_attention + +# NOTE: We'll patch apply_rotary_emb_qwen AFTER defining apply_rotary_emb_precomputed +# This is done below after the function definition + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def apply_rotary_emb_precomputed( + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """ + Apply rotary embeddings using PRE-COMPUTED cos/sin tensors. + + Handles BOTH use_real=True and use_real=False cases: + - use_real=False (QwenImage default): Complex multiplication simulation + - use_real=True: Standard cos/sin rotation + + Args: + x: [B, S, H, D] - input tensor, D = head_dim = 128 + freqs_cis: Tuple of (cos, sin), each [S, D/2] - NOT interleaved (D/2 = 64) + + Returns: + Rotated tensor [B, S, H, D] + """ + cos, sin = freqs_cis # Each [S, 64] + + # Move to same device as x + cos = cos.to(x.device) + sin = sin.to(x.device) + + if not use_real: + # QwenImage uses use_real=False (complex multiplication) + # Original code: + # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + # freqs_cis = freqs_cis.unsqueeze(1) # [S, 1, D/2] for broadcasting with [B, S, H, D/2] + # x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) + # + # Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i + # where x = a + bi, freqs = c + di = cos + i*sin + + # Reshape x to [B, S, H, D/2, 2] then split into real/imag + x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [B, S, H, 64, 2] + x_real = x_reshaped[..., 0] # [B, S, H, 64] + x_imag = x_reshaped[..., 1] # [B, S, H, 64] + + # Expand cos/sin for broadcasting: [S, 64] -> [1, S, 1, 64] + cos = cos.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + sin = sin.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + + # Complex multiplication: (x_real + i*x_imag) * (cos + i*sin) + # real part: x_real * cos - x_imag * sin + # imag part: x_real * sin + x_imag * cos + out_real = x_real * cos - x_imag * sin # [B, S, H, 64] + out_imag = x_real * sin + x_imag * cos # [B, S, H, 64] + + # Stack and flatten back to [B, S, H, 128] + out = torch.stack([out_real, out_imag], dim=-1) # [B, S, H, 64, 2] + out = out.flatten(-2) # [B, S, H, 128] + + return out.to(x.dtype) + else: + # use_real=True path (standard rotation) + # Expand for broadcasting: [S, D/2] -> [1, S, 1, D/2] + cos = cos.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + sin = sin.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + + # Interleave: [c0, c1, ...] -> [c0, c0, c1, c1, ...] + cos = cos.repeat_interleave(2, dim=-1) # [1, S, 1, 128] + sin = sin.repeat_interleave(2, dim=-1) # [1, S, 1, 128] + + # Create rotated version + if use_real_unbind_dim == -1: + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + else: + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +# Patch apply_rotary_emb_qwen to use our pre-computed version +import diffusers.models.transformers.transformer_qwenimage as qwen_module +qwen_module.apply_rotary_emb_qwen = apply_rotary_emb_precomputed +print("Patched apply_rotary_emb_qwen for pre-computed RoPE") + + +class NeuronQwenTransformerV2(nn.Module): + """ + Neuron-optimized QwenImage Transformer for V2 API. + + Key difference: RoPE frequencies are passed as INPUT, not computed internally. + This avoids XLA constant-folding issues. + """ + + def __init__(self, original_transformer, tp_degree): + super().__init__() + + self.config = original_transformer.config + self.in_channels = original_transformer.config.in_channels + self.out_channels = original_transformer.config.out_channels + self.patch_size = original_transformer.config.patch_size + + # Input projections (keep original) + self.img_in = original_transformer.img_in # Linear for image patches + self.txt_in = original_transformer.txt_in # Linear for text + + # Time/text embedding (keep original) + self.time_text_embed = original_transformer.time_text_embed + + # Text norm (keep original) + self.txt_norm = original_transformer.txt_norm + + # NOTE: We do NOT copy pos_embed (RoPE) - it will be passed as input! + + # Transformer blocks (need to shard) + self.transformer_blocks = nn.ModuleList() + for i, block in enumerate(original_transformer.transformer_blocks): + # Shard attention + block.attn = shard_qwen_attention(tp_degree, block.attn) + # Shard MLPs + block.img_mlp = shard_feedforward(block.img_mlp) + block.txt_mlp = shard_feedforward(block.txt_mlp) + # Shard modulation + block.img_mod = shard_modulation(block.img_mod) + block.txt_mod = shard_modulation(block.txt_mod) + self.transformer_blocks.append(block) + + if (i + 1) % 10 == 0: + print(f" Sharded block {i+1}/{len(original_transformer.transformer_blocks)}") + + # Final layers (keep original) + self.norm_out = original_transformer.norm_out + self.proj_out = original_transformer.proj_out + + # Store head_dim for RoPE + self.head_dim = 128 # QwenImage uses 128 + self.num_heads = original_transformer.transformer_blocks[0].attn.heads + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + img_rotary_emb: torch.Tensor, # [num_patches, 64, 2] for (cos, sin), NOT interleaved + txt_rotary_emb: torch.Tensor, # [text_seq, 64, 2] for (cos, sin), NOT interleaved + ) -> torch.Tensor: + """ + Forward pass with RoPE as INPUT. + + Args: + hidden_states: [B, num_patches, in_channels] + encoder_hidden_states: [B, text_seq, text_dim] + timestep: [B] + img_rotary_emb: [num_patches, 64, 2] - pre-computed RoPE (NOT interleaved) + txt_rotary_emb: [text_seq, 64, 2] - pre-computed RoPE (NOT interleaved) + """ + # Split RoPE into cos/sin + # Shape: [S, 64] - NOT interleaved, apply_rotary_emb_precomputed will do repeat_interleave + img_freqs_cos = img_rotary_emb[..., 0] # [num_patches, 64] + img_freqs_sin = img_rotary_emb[..., 1] + txt_freqs_cos = txt_rotary_emb[..., 0] # [text_seq, 64] + txt_freqs_sin = txt_rotary_emb[..., 1] + + # Image input projection + hidden_states = self.img_in(hidden_states) # [B, num_patches, inner_dim] + + # Text processing: norm first, then projection + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) # [B, text_seq, inner_dim] + + # Time embedding (takes timestep and hidden_states) + timestep = timestep.to(hidden_states.dtype) + temb = self.time_text_embed(timestep, hidden_states) + + # Create rotary_emb tuple in format expected by diffusers + # Using (cos, sin) tuple format for Neuron compatibility + image_rotary_emb = ((img_freqs_cos, img_freqs_sin), (txt_freqs_cos, txt_freqs_sin)) + + # Process through transformer blocks + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # Final norm and projection + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return output + + +class TracingWrapperV2(nn.Module): + """Wrapper for ModelBuilder tracing.""" + + def __init__(self, transformer): + super().__init__() + self.transformer = transformer + + def forward(self, hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb): + return self.transformer( + hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb + ) + + +def get_rope_from_original_model( + pipe, + frame: int, + height: int, + width: int, + text_seq_len: int, + dtype=torch.bfloat16, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get RoPE directly from the original QwenEmbedRope model. + + This ensures the RoPE values are EXACTLY the same as what V1 uses. + + Returns: + img_rotary_emb: [num_patches, 64, 2] - stacked (cos, sin) from complex freqs + txt_rotary_emb: [text_seq_len, 64, 2] - stacked (cos, sin) from complex freqs + """ + print(f" Getting RoPE from original model...") + print(f" video_fhw: ({frame}, {height}, {width}), text_seq_len: {text_seq_len}") + + # Call original pos_embed to get complex freqs + video_fhw = (frame, height, width) + vid_freqs, txt_freqs = pipe.transformer.pos_embed( + video_fhw, txt_seq_lens=[text_seq_len], device=torch.device('cpu') + ) + + print(f" vid_freqs from model: {vid_freqs.shape}, dtype: {vid_freqs.dtype}") + print(f" txt_freqs from model: {txt_freqs.shape}, dtype: {txt_freqs.dtype}") + + # Convert complex to (cos, sin) + # Complex freqs are e^(i*angle) = cos(angle) + i*sin(angle) + img_cos = vid_freqs.real.float() # [num_patches, 64] + img_sin = vid_freqs.imag.float() # [num_patches, 64] + txt_cos = txt_freqs.real.float() # [text_seq_len, 64] + txt_sin = txt_freqs.imag.float() # [text_seq_len, 64] + + # Stack to [S, 64, 2] + img_rotary_emb = torch.stack([img_cos, img_sin], dim=-1).to(dtype) + txt_rotary_emb = torch.stack([txt_cos, txt_sin], dim=-1).to(dtype) + + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + print(f" img_cos stats: min={img_cos.min():.4f}, max={img_cos.max():.4f}") + print(f" img_sin stats: min={img_sin.min():.4f}, max={img_sin.max():.4f}") + + return img_rotary_emb, txt_rotary_emb + + +def compile_transformer_v2(args): + """Compile transformer using ModelBuilder V2 API with RoPE as input.""" + + tp_degree = args.tp_degree + + # Calculate dimensions + latent_h = args.height // 8 + latent_w = args.width // 8 + patch_size = 2 + patch_h = latent_h // patch_size + patch_w = latent_w // patch_size + temporal_frames = args.patch_multiplier + num_patches = temporal_frames * patch_h * patch_w + + text_seq_len = args.max_sequence_length + text_hidden_size = 3584 + in_channels = 64 + head_dim = 128 + + print("=" * 60) + print("Transformer V2 Compilation") + print("=" * 60) + print(f"Image: {args.height}x{args.width}") + print(f"Patches: {num_patches} ({temporal_frames}x{patch_h}x{patch_w})") + print(f"Text seq: {text_seq_len}") + print(f"TP degree: {tp_degree}") + + # Sample inputs + sample_hidden_states = torch.randn(1, num_patches, in_channels, dtype=torch.bfloat16) + sample_encoder_hidden_states = torch.randn(1, text_seq_len, text_hidden_size, dtype=torch.bfloat16) + sample_timestep = torch.randn(1, dtype=torch.float32) + + with NxDParallelState(world_size=tp_degree, tensor_model_parallel_size=tp_degree): + print("\nLoading model...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + local_files_only=True, + cache_dir=CACHE_DIR + ) + + # Get RoPE directly from original model (ensures exact match with V1) + print("\nGetting RoPE from original model...") + img_rotary_emb, txt_rotary_emb = get_rope_from_original_model( + pipe=pipe, + frame=temporal_frames, + height=patch_h, + width=patch_w, + text_seq_len=text_seq_len, + ) + + # Verify shapes are correct (64 = head_dim // 2) + rope_dim = head_dim // 2 # 64 + assert img_rotary_emb.shape[-2] == rope_dim, f"img_rotary_emb shape wrong: {img_rotary_emb.shape}, expected dim -2 = {rope_dim}" + assert txt_rotary_emb.shape[-2] == rope_dim, f"txt_rotary_emb shape wrong: {txt_rotary_emb.shape}, expected dim -2 = {rope_dim}" + + # Save unsharded state dict before modifications + unsharded_state = pipe.transformer.state_dict() + + print("Creating Neuron transformer (sharding layers)...") + neuron_transformer = NeuronQwenTransformerV2(pipe.transformer, tp_degree) + neuron_transformer = neuron_transformer.to(torch.bfloat16) + neuron_transformer.eval() + + # Wrap for tracing + model = TracingWrapperV2(neuron_transformer) + + print("\nInitializing ModelBuilder...") + builder = ModelBuilder(model=model) + + print("Tracing model...") + builder.trace( + kwargs={ + "hidden_states": sample_hidden_states, + "encoder_hidden_states": sample_encoder_hidden_states, + "timestep": sample_timestep, + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, + tag="inference", + ) + + print("Compiling model...") + compile_args = "--model-type=transformer -O1 --auto-cast=none --lnc=2 --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=4' --internal-hlo2tensorizer-options='--enable-native-kernel=1 --remat'" + traced_model = builder.compile( + compiler_args=compile_args, + ) + + # Save + output_path = f"{args.compiled_models_dir}/transformer_v2" + os.makedirs(output_path, exist_ok=True) + + print(f"\nSaving to {output_path}...") + traced_model.save(os.path.join(output_path, "nxd_model.pt")) + + # Save weights + weights_path = os.path.join(output_path, "weights") + os.makedirs(weights_path, exist_ok=True) + + # Prepare checkpoint for sharding + checkpoint = {} + for key, value in model.state_dict().items(): + # Use unsharded weights where available + orig_key = key.replace("transformer.", "", 1) + if orig_key in unsharded_state: + checkpoint[key] = unsharded_state[orig_key].clone() + else: + checkpoint[key] = value.clone() + + # Shard checkpoint + print("Sharding weights...") + shard_checkpoint( + checkpoint=checkpoint, + model=model, + serialize_path=weights_path, + ) + + # Save config + config = { + "height": args.height, + "width": args.width, + "num_patches": num_patches, + "text_seq_len": text_seq_len, + "patch_multiplier": args.patch_multiplier, + "tp_degree": tp_degree, + "head_dim": head_dim, + "frame": temporal_frames, + "patch_h": patch_h, + "patch_w": patch_w, + } + with open(os.path.join(output_path, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # Save pre-computed RoPE + torch.save({ + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, os.path.join(output_path, "rope_cache.pt")) + + print("\nCompilation complete!") + print(f"Model saved to: {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default=None, + help="Path to model (local dir or HF ID). If not set, uses MODEL_ID with CACHE_DIR") + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--max_sequence_length", type=int, default=1024) + parser.add_argument("--patch_multiplier", type=int, default=3) + parser.add_argument("--tp_degree", type=int, default=8) + parser.add_argument("--compiled_models_dir", type=str, default="/opt/dlami/nvme/compiled_models") + args = parser.parse_args() + + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + compile_transformer_v2(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_transformer_v2_flash.py b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v2_flash.py new file mode 100644 index 00000000..2fb3032f --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v2_flash.py @@ -0,0 +1,656 @@ +""" +Transformer compilation using ModelBuilder (V2 API) with NKI Flash Attention. + +Key approach: +1. Uses ModelBuilder API for compilation (like V2) +2. Uses NKI Flash Attention kernel for hardware-optimized attention (like V1 Flash) +3. RoPE frequencies computed OUTSIDE the model and passed as INPUT tensors +4. Disables XLA functionalization to allow NKI in-place operations + +This combines the best of both: +- ModelBuilder's XLA optimization +- NKI's hardware-optimized attention kernel +""" + +import os +import json +import math + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +# CRITICAL: Disable XLA functionalization to allow NKI kernel in-place operations +# Without this, NKI kernels will fail with "Cannot update immutable parameter" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +# Compiler flags optimized for transformer +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer -O1 --auto-cast=none --enable-fast-loading-neuron-binaries --tensorizer-options='--enable-ccop-compute-overlap' """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import torch.nn as nn +import torch.nn.functional as F +import argparse +from typing import Optional, Tuple, List + +from diffusers import QwenImageEditPlusPipeline + +# ModelBuilder imports +from neuronx_distributed import ModelBuilder, NxDParallelState, shard_checkpoint +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers import parallel_state +from safetensors.torch import save_file + +from neuron_parallel_utils import ( + shard_qwen_attention, + shard_feedforward, + shard_modulation, + get_sharded_data, +) + +# Import NKI Flash Attention - use EXACTLY the same imports as Flux +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronxcc.nki.language import nc +from torch_neuronx.xla_impl.ops import nki_jit + +# Create NKI callable - EXACTLY like Flux does +_flash_fwd_call = nki_jit()(attention_isa_kernel) + +NKI_AVAILABLE = True +print("NKI Flash Attention kernel loaded successfully") + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def nki_flash_attention(query, key, value): + """ + NKI Flash Attention wrapper. + + Args: + query: [B, H, S, D] - query tensor + key: [B, H, S, D] - key tensor + value: [B, H, S, D] - value tensor + + Returns: + attention output [B, H, S, D] + """ + bs, n_head, q_len, d_head = query.shape + k_len = key.shape[2] + v_len = value.shape[2] + + # Reshape for NKI kernel: [B*H, D, S] for Q/K, [B*H, S, D] for V + q = query.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) + k = key.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, k_len)) + v = value.clone().reshape((bs * n_head, v_len, d_head)) + + scale = 1 / math.sqrt(d_head) + + # The LNC2 sharded flash kernel requires seqlen_q divisible by a power of 2 >= 512. + # TP=8 keeps the FULL query (not CP-split), so q_len = joint seq (e.g. 12944) which + # is not aligned. Pad ONLY the query seqlen up to the next multiple of 512 (key/value + # stay full, so padding never pollutes real tokens — pad queries just produce junk rows + # we slice off). q is [B*H, D, q_len]; pad the last dim. + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + ALIGN = 512 + q_pad = ((q_len + ALIGN - 1) // ALIGN) * ALIGN + if vc_size == 2 and q_pad != q_len: + pad_n = q_pad - q_len + q = torch.nn.functional.pad(q, (0, pad_n)) # pad last dim (seqlen_q) of [B*H, D, q_len] + else: + q_pad = q_len + + attn_output = torch.zeros((bs * n_head, q_pad, d_head), dtype=torch.bfloat16, device=q.device) + + if vc_size == 2: + grid = (nc(2),) + _flash_fwd_call[grid]( + q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap" + ) + else: + _flash_fwd_call(q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + + # Slice off the padded query rows + if attn_output.shape[1] != q_len: + attn_output = attn_output[:, :q_len, :] + + # Reshape back to [B, H, S, D] + attn_output = attn_output.reshape((bs, n_head, q_len, d_head)) + + return attn_output + + +def apply_rotary_emb_precomputed( + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """ + Apply rotary embeddings using PRE-COMPUTED cos/sin tensors. + + Args: + x: [B, S, H, D] - input tensor, D = head_dim = 128 + freqs_cis: Tuple of (cos, sin), each [S, D/2] - NOT interleaved (D/2 = 64) + + Returns: + Rotated tensor [B, S, H, D] + """ + cos, sin = freqs_cis # Each [S, 64] + + # Move to same device as x + cos = cos.to(x.device) + sin = sin.to(x.device) + + if not use_real: + # QwenImage uses use_real=False (complex multiplication) + x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [B, S, H, 64, 2] + x_real = x_reshaped[..., 0] # [B, S, H, 64] + x_imag = x_reshaped[..., 1] # [B, S, H, 64] + + # Expand cos/sin for broadcasting: [S, 64] -> [1, S, 1, 64] + cos = cos.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + sin = sin.unsqueeze(0).unsqueeze(2) # [1, S, 1, 64] + + # Complex multiplication: (x_real + i*x_imag) * (cos + i*sin) + out_real = x_real * cos - x_imag * sin # [B, S, H, 64] + out_imag = x_real * sin + x_imag * cos # [B, S, H, 64] + + # Stack and flatten back to [B, S, H, 128] + out = torch.stack([out_real, out_imag], dim=-1) # [B, S, H, 64, 2] + out = out.flatten(-2) # [B, S, H, 128] + + return out.to(x.dtype) + else: + # use_real=True path (standard rotation) + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + + if use_real_unbind_dim == -1: + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + else: + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +# Patch apply_rotary_emb_qwen to use our pre-computed version +import diffusers.models.transformers.transformer_qwenimage as qwen_module +qwen_module.apply_rotary_emb_qwen = apply_rotary_emb_precomputed +print("Patched apply_rotary_emb_qwen for pre-computed RoPE") + + +class NKIQwenAttention(nn.Module): + """ + Custom attention module for QwenImage that uses NKI Flash Attention directly. + + This completely replaces diffusers' Attention class, similar to how Flux + uses NeuronFluxAttention. + """ + + def __init__(self, orig_attn): + """Initialize from an existing sharded attention module.""" + super().__init__() + + # Copy all the layers from the original attention + self.heads = orig_attn.heads + self.to_q = orig_attn.to_q + self.to_k = orig_attn.to_k + self.to_v = orig_attn.to_v + self.to_out = orig_attn.to_out + + # Text projections + self.add_q_proj = orig_attn.add_q_proj if hasattr(orig_attn, 'add_q_proj') else None + self.add_k_proj = orig_attn.add_k_proj if hasattr(orig_attn, 'add_k_proj') else None + self.add_v_proj = orig_attn.add_v_proj if hasattr(orig_attn, 'add_v_proj') else None + self.to_add_out = orig_attn.to_add_out if hasattr(orig_attn, 'to_add_out') else None + + # Norms + self.norm_q = orig_attn.norm_q if hasattr(orig_attn, 'norm_q') else None + self.norm_k = orig_attn.norm_k if hasattr(orig_attn, 'norm_k') else None + self.norm_added_q = orig_attn.norm_added_q if hasattr(orig_attn, 'norm_added_q') else None + self.norm_added_k = orig_attn.norm_added_k if hasattr(orig_attn, 'norm_added_k') else None + + def forward( + self, + hidden_states: torch.Tensor, # Image stream [B, S_img, C] + encoder_hidden_states: torch.Tensor = None, # Text stream [B, S_txt, C] + encoder_hidden_states_mask: torch.Tensor = None, + image_rotary_emb: Tuple = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass with NKI Flash Attention.""" + if encoder_hidden_states is None: + raise ValueError("NKIQwenAttention requires encoder_hidden_states") + + batch_size = hidden_states.shape[0] + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream + img_query = self.to_q(hidden_states) + img_key = self.to_k(hidden_states) + img_value = self.to_v(hidden_states) + + # Compute QKV for text stream + txt_query = self.add_q_proj(encoder_hidden_states) + txt_key = self.add_k_proj(encoder_hidden_states) + txt_value = self.add_v_proj(encoder_hidden_states) + + # Get head dimension + inner_dim = img_query.shape[-1] + head_dim = inner_dim // self.heads + + # Reshape to [B, S, H, D] then transpose to [B, H, S, D] - exactly like Flux + img_query = img_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_key = img_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_value = img_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + txt_query = txt_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_key = txt_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_value = txt_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Apply QK normalization + if self.norm_q is not None: + img_query = self.norm_q(img_query) + if self.norm_k is not None: + img_key = self.norm_k(img_key) + if self.norm_added_q is not None: + txt_query = self.norm_added_q(txt_query) + if self.norm_added_k is not None: + txt_key = self.norm_added_k(txt_key) + + # Apply RoPE - note: input is now [B, H, S, D] + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + # Transpose to [B, S, H, D] for RoPE, then back to [B, H, S, D] + img_query = apply_rotary_emb_precomputed(img_query.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + img_key = apply_rotary_emb_precomputed(img_key.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + txt_query = apply_rotary_emb_precomputed(txt_query.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + txt_key = apply_rotary_emb_precomputed(txt_key.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + + # Concatenate for joint attention along sequence dim: [B, H, S_txt + S_img, D] + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + if os.getenv("QIE_DROP_CM", "0") == "1": + # drop_cm sparse attention (TP=8 symmetric => no merge): drop cloth<->model + # cross-attention. joint layout [txt, noise, cloth, model], img = 3 frames of F. + # query groups: (txt+noise) see ALL; cloth see [txt,noise,cloth]; model see [txt,noise,model]. + # Each group is ONE full softmax over its own KV (verified bit-exact, no merge). + S_img = img_query.shape[2] + F = S_img // 3 # base_patches per frame + st = seq_txt + nz_end = st + F # txt+noise end + cl_end = st + 2 * F # cloth end + # KV slices + k_all, v_all = joint_key, joint_value + k_tnc = joint_key[:, :, :cl_end, :]; v_tnc = joint_value[:, :, :cl_end, :] # [txt,noise,cloth] + # [txt,noise,model] = gather (drop cloth band) + k_tnm = torch.cat([joint_key[:, :, :nz_end, :], joint_key[:, :, cl_end:, :]], dim=2) + v_tnm = torch.cat([joint_value[:, :, :nz_end, :], joint_value[:, :, cl_end:, :]], dim=2) + # query groups + q_tn = joint_query[:, :, :nz_end, :] # txt+noise + q_cl = joint_query[:, :, nz_end:cl_end, :] # cloth + q_md = joint_query[:, :, cl_end:, :] # model + o_tn = nki_flash_attention(q_tn, k_all, v_all) + o_cl = nki_flash_attention(q_cl, k_tnc, v_tnc) + o_md = nki_flash_attention(q_md, k_tnm, v_tnm) + joint_hidden_states = torch.cat([o_tn, o_cl, o_md], dim=2) + else: + # Use NKI Flash Attention - input is [B, H, S, D] exactly like Flux + joint_hidden_states = nki_flash_attention(joint_query, joint_key, joint_value) + + # Transpose back and reshape: [B, H, S, D] -> [B, S, H*D] + joint_hidden_states = joint_hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output = self.to_out[0](img_attn_output) + if len(self.to_out) > 1: + img_attn_output = self.to_out[1](img_attn_output) # dropout + + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +def replace_attention_with_nki(transformer): + """Replace all attention modules with NKI versions.""" + for i, block in enumerate(transformer.transformer_blocks): + block.attn = NKIQwenAttention(block.attn) + print(f"Replaced attention modules with NKI versions on {len(transformer.transformer_blocks)} blocks") + + +class NeuronQwenTransformerV2Flash(nn.Module): + """ + Neuron-optimized QwenImage Transformer for V2 Flash. + + Combines: + - ModelBuilder API for compilation (V2) + - NKI Flash Attention for hardware-optimized attention (V1 Flash) + - Pre-computed RoPE as input tensors + """ + + def __init__(self, original_transformer, tp_degree): + super().__init__() + + self.config = original_transformer.config + self.in_channels = original_transformer.config.in_channels + self.out_channels = original_transformer.config.out_channels + self.patch_size = original_transformer.config.patch_size + + # Input projections (keep original) + self.img_in = original_transformer.img_in + self.txt_in = original_transformer.txt_in + + # Time/text embedding (keep original) + self.time_text_embed = original_transformer.time_text_embed + + # Text norm (keep original) + self.txt_norm = original_transformer.txt_norm + + # NOTE: We do NOT copy pos_embed (RoPE) - it will be passed as input! + + # Transformer blocks (need to shard) + self.transformer_blocks = nn.ModuleList() + for i, block in enumerate(original_transformer.transformer_blocks): + # Shard attention + block.attn = shard_qwen_attention(tp_degree, block.attn) + # Shard MLPs + block.img_mlp = shard_feedforward(block.img_mlp) + block.txt_mlp = shard_feedforward(block.txt_mlp) + # Shard modulation + block.img_mod = shard_modulation(block.img_mod) + block.txt_mod = shard_modulation(block.txt_mod) + self.transformer_blocks.append(block) + + if (i + 1) % 10 == 0: + print(f" Sharded block {i+1}/{len(original_transformer.transformer_blocks)}") + + # Final layers (keep original) + self.norm_out = original_transformer.norm_out + self.proj_out = original_transformer.proj_out + + # Store head_dim for RoPE + self.head_dim = 128 + self.num_heads = original_transformer.transformer_blocks[0].attn.heads + + # Replace attention modules with NKI versions AFTER sharding + replace_attention_with_nki(self) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + img_rotary_emb: torch.Tensor, # [num_patches, 64, 2] + txt_rotary_emb: torch.Tensor, # [text_seq, 64, 2] + ) -> torch.Tensor: + """Forward pass with RoPE as INPUT and NKI Flash Attention.""" + # Split RoPE into cos/sin + img_freqs_cos = img_rotary_emb[..., 0] # [num_patches, 64] + img_freqs_sin = img_rotary_emb[..., 1] + txt_freqs_cos = txt_rotary_emb[..., 0] # [text_seq, 64] + txt_freqs_sin = txt_rotary_emb[..., 1] + + # Image input projection + hidden_states = self.img_in(hidden_states) + + # Text processing + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + # Time embedding + timestep = timestep.to(hidden_states.dtype) + temb = self.time_text_embed(timestep, hidden_states) + + # Create rotary_emb tuple + image_rotary_emb = ((img_freqs_cos, img_freqs_sin), (txt_freqs_cos, txt_freqs_sin)) + + # Process through transformer blocks + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # Final norm and projection + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + return output + + +class TracingWrapperV2Flash(nn.Module): + """Wrapper for ModelBuilder tracing.""" + + def __init__(self, transformer): + super().__init__() + self.transformer = transformer + + def forward(self, hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb): + return self.transformer( + hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb + ) + + +def get_rope_from_original_model( + pipe, + frame: int, + height: int, + width: int, + text_seq_len: int, + dtype=torch.bfloat16, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get RoPE directly from the original QwenEmbedRope model.""" + print(f" Getting RoPE from original model...") + print(f" video_fhw: ({frame}, {height}, {width}), text_seq_len: {text_seq_len}") + + video_fhw = (frame, height, width) + vid_freqs, txt_freqs = pipe.transformer.pos_embed( + video_fhw, max_txt_seq_len=text_seq_len, device=torch.device('cpu') + ) + + print(f" vid_freqs from model: {vid_freqs.shape}, dtype: {vid_freqs.dtype}") + print(f" txt_freqs from model: {txt_freqs.shape}, dtype: {txt_freqs.dtype}") + + # Convert complex to (cos, sin) + img_cos = vid_freqs.real.float() + img_sin = vid_freqs.imag.float() + txt_cos = txt_freqs.real.float() + txt_sin = txt_freqs.imag.float() + + # Stack to [S, 64, 2] + img_rotary_emb = torch.stack([img_cos, img_sin], dim=-1).to(dtype) + txt_rotary_emb = torch.stack([txt_cos, txt_sin], dim=-1).to(dtype) + + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + return img_rotary_emb, txt_rotary_emb + + +def compile_transformer_v2_flash(args): + """Compile transformer using ModelBuilder V2 API with NKI Flash Attention.""" + + tp_degree = args.tp_degree + + # Calculate dimensions + latent_h = args.height // 8 + latent_w = args.width // 8 + patch_size = 2 + patch_h = latent_h // patch_size + patch_w = latent_w // patch_size + temporal_frames = args.patch_multiplier + num_patches = temporal_frames * patch_h * patch_w + + text_seq_len = args.max_sequence_length + text_hidden_size = 3584 + in_channels = 64 + head_dim = 128 + + print("=" * 60) + print("Transformer V2 Flash Compilation") + print("=" * 60) + print(f"Image: {args.height}x{args.width}") + print(f"Patches: {num_patches} ({temporal_frames}x{patch_h}x{patch_w})") + print(f"Text seq: {text_seq_len}") + print(f"TP degree: {tp_degree}") + print(f"NKI Flash Attention: Enabled") + print(f"XLA_DISABLE_FUNCTIONALIZATION: {os.environ.get('XLA_DISABLE_FUNCTIONALIZATION', 'not set')}") + + # Sample inputs + sample_hidden_states = torch.randn(1, num_patches, in_channels, dtype=torch.bfloat16) + sample_encoder_hidden_states = torch.randn(1, text_seq_len, text_hidden_size, dtype=torch.bfloat16) + sample_timestep = torch.randn(1, dtype=torch.float32) + + with NxDParallelState(world_size=tp_degree, tensor_model_parallel_size=tp_degree): + print("\nLoading model...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=torch.bfloat16, + local_files_only=True, + cache_dir=CACHE_DIR + ) + + # Get RoPE from original model + print("\nGetting RoPE from original model...") + img_rotary_emb, txt_rotary_emb = get_rope_from_original_model( + pipe=pipe, + frame=temporal_frames, + height=patch_h, + width=patch_w, + text_seq_len=text_seq_len, + ) + + # Save unsharded state dict before modifications + unsharded_state = pipe.transformer.state_dict() + + print("Creating Neuron transformer (sharding layers + NKI attention)...") + neuron_transformer = NeuronQwenTransformerV2Flash(pipe.transformer, tp_degree) + neuron_transformer = neuron_transformer.to(torch.bfloat16) + neuron_transformer.eval() + + # Wrap for tracing + model = TracingWrapperV2Flash(neuron_transformer) + + print("\nInitializing ModelBuilder...") + builder = ModelBuilder(model=model) + + print("Tracing model with NKI Flash Attention...") + builder.trace( + kwargs={ + "hidden_states": sample_hidden_states, + "encoder_hidden_states": sample_encoder_hidden_states, + "timestep": sample_timestep, + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, + tag="inference", + ) + + print("Compiling model...") + compile_args = "--model-type=transformer -O1 --auto-cast=none --lnc=2 --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=4' --internal-hlo2tensorizer-options='--enable-native-kernel=1 --remat'" + traced_model = builder.compile( + compiler_args=compile_args, + ) + + # Save + output_path = f"{args.compiled_models_dir}/transformer_v2_flash" + os.makedirs(output_path, exist_ok=True) + + print(f"\nSaving to {output_path}...") + traced_model.save(os.path.join(output_path, "nxd_model.pt")) + + # Save weights + weights_path = os.path.join(output_path, "weights") + os.makedirs(weights_path, exist_ok=True) + + # Prepare checkpoint for sharding + checkpoint = {} + for key, value in model.state_dict().items(): + # Use unsharded weights where available + orig_key = key.replace("transformer.", "", 1) + if orig_key in unsharded_state: + checkpoint[key] = unsharded_state[orig_key].clone() + else: + checkpoint[key] = value.clone() + + # Shard checkpoint + print("Sharding weights...") + shard_checkpoint( + checkpoint=checkpoint, + model=model, + serialize_path=weights_path, + ) + + # Save config + config = { + "height": args.height, + "width": args.width, + "num_patches": num_patches, + "text_seq_len": text_seq_len, + "patch_multiplier": args.patch_multiplier, + "tp_degree": tp_degree, + "head_dim": head_dim, + "frame": temporal_frames, + "patch_h": patch_h, + "patch_w": patch_w, + "nki_flash_attention": True, + } + with open(os.path.join(output_path, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # Save pre-computed RoPE + torch.save({ + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, os.path.join(output_path, "rope_cache.pt")) + + print("\nCompilation complete!") + print(f"Model saved to: {output_path}") + print("\nTo run inference:") + print(f" python run_qwen_image_edit.py --images img1.png img2.png --prompt '...' --use_v2_flash") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default=None, + help="Path to model (local dir or HF ID). If not set, uses MODEL_ID with CACHE_DIR") + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--max_sequence_length", type=int, default=1024) + parser.add_argument("--patch_multiplier", type=int, default=3) + parser.add_argument("--tp_degree", type=int, default=8) + parser.add_argument("--compiled_models_dir", type=str, default="/opt/dlami/nvme/compiled_models") + args = parser.parse_args() + + # Override MODEL_ID/CACHE_DIR if model_path provided (matches V3 CP) + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + compile_transformer_v2_flash(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_transformer_v3_cfg.py b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v3_cfg.py new file mode 100644 index 00000000..bd36ddf5 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v3_cfg.py @@ -0,0 +1,807 @@ +""" +Transformer compilation with CFG Parallelism (V3 CFG) using ModelBuilder API. + +Key approach: +1. Uses ModelBuilder API (like V3 CP) for compilation +2. Configures world_size=8, tp_degree=4 (implicit DP=2 for CFG) +3. Batches positive + negative prompts (batch_size=2), each DP rank processes one +4. No K/V all-gather needed (each rank has full sequence) +5. Uses NKI Flash Attention for optimal performance + +CFG Parallel works by: +- Model parameters are sharded with TP=4 +- DP group (2 ranks) is used for CFG parallelism +- Input is scattered along batch dim (dim=0): rank 0 gets negative, rank 1 gets positive +- Each DP rank processes one complete batch item (full sequence) +- Output is gathered along batch dim (dim=0) and CFG formula is applied + +CFG Parallel and Context Parallel are mutually exclusive. +""" + +import os +import json +import math + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +# Compiler flags - same as Flux for CP mode +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer --auto-cast=none --enable-fast-loading-neuron-binaries --tensorizer-options='--enable-ccop-compute-overlap' --internal-hlo2tensorizer-options='--enable-state-buffer-mode=hybrid --remat-by-default' """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import torch.nn as nn +import torch.nn.functional as F +import argparse +from typing import Optional, Tuple, List + +from diffusers import QwenImageEditPlusPipeline + +# ModelBuilder imports +from neuronx_distributed import ModelBuilder, NxDParallelState, shard_checkpoint +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + SPMDRank, +) +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_tensor_model_parallel_region_with_dim, + scatter_to_process_group_spmd, +) + +from neuron_parallel_utils import ( + shard_qwen_attention, + shard_feedforward, + shard_modulation, + get_sharded_data, +) + +# Import NKI Flash Attention (production-grade nkilib kernel) +from nkilib.core.attention.attention_cte import attention_cte as _nkilib_attention_cte +from neuronxcc.nki.language import nc +from torch_neuronx.xla_impl.ops import nki_jit + +# Keep the legacy private kernel as a fallback in case nkilib is missing +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel +_legacy_flash_fwd_call = nki_jit()(attention_isa_kernel) + +# attention_cte is already wrapped with @nki.jit. In XLA venv we don't have +# torch_neuronx.nki_op/wrap_nki, so call it directly without re-wrapping +# and without [grid] subscript — let the kernel handle LNC internally. +_attention_cte_call = _nkilib_attention_cte + +# Phase 16: attention_cte fork that hoists the Q load out of the section +# loop. QIE has num_sections=2 with identical Q across sections, so the +# baseline kernel reloads Q in section_idx=1 unnecessarily. Hoisting saves +# ~50% of Q DMA traffic on this workload. +USE_HOISTED_Q_ATTENTION = os.getenv("QIE_HOISTED_Q_ATTENTION", "1") == "1" +if USE_HOISTED_Q_ATTENTION: + from attention_cte_qie_hoisted_q import attention_cte_hoisted_q as _hoisted_q_call + +# Toggle between the two via env var (easy rollback during testing) +USE_NKILIB_ATTENTION = os.getenv("QIE_USE_NKILIB_ATTENTION", "1") == "1" + +# Phase 17: lower the all-reduce dtype from fp32 to bf16 in the transformer +# blocks' RowParallelLinear layers. The default upstream NxDI behavior is fp32 +# reduce, which on the V3 CFG configuration spends ~204 ms per step on 956 +# TP all-reduces totalling ~18 GB. bf16 halves the bytes on the wire and saves +# ~137 ms / step (~9% E2E) with no visible image quality regression on the +# 1024x1024 two-image merge workload. Set QIE_ALLREDUCE_BF16=0 to revert. +ALLREDUCE_BF16 = os.getenv("QIE_ALLREDUCE_BF16", "1") == "1" +_REDUCE_DTYPE = torch.bfloat16 if ALLREDUCE_BF16 else torch.float32 + +print(f"NKI Flash Attention kernel loaded: " + f"{'attention_cte_hoisted_q (Phase 16: hoisted Q load)' if USE_HOISTED_Q_ATTENTION else ('nkilib.core.attention.attention_cte' if USE_NKILIB_ATTENTION else 'attention_isa_kernel (legacy)')}") +if ALLREDUCE_BF16: + print(f" + TP all-reduce dtype = bf16 (Phase 17 experiment)") + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def nki_flash_attention(query, key, value): + """ + NKI Flash Attention wrapper. + + Args: + query: [B, H, S, D] + key: [B, H, S, D] + value: [B, H, S, D] + + Returns: + attention output [B, H, S, D] + + Uses nkilib.core.attention.attention_cte by default (production-grade + kernel with controllable softmax/matmul dtype, causal_mask, sliding_window, + and KV-cache support). Set QIE_USE_NKILIB_ATTENTION=0 to fall back to the + legacy attention_isa_kernel. + """ + bs, n_head, q_len, d_head = query.shape + k_len = key.shape[2] + v_len = value.shape[2] + scale = 1 / math.sqrt(d_head) + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + + if USE_HOISTED_Q_ATTENTION: + # Phase 16: attention_cte fork that hoists Q load above the section loop. + # Same public API as attention_cte (causal_mask, tp_q/k/out, softmax_dtype, ...). + q = query.reshape(bs * n_head, q_len, d_head) + k = key.reshape(bs * n_head, k_len, d_head) + v = value.reshape(bs * n_head, v_len, d_head) + softmax_dtype = os.getenv("QIE_SOFTMAX_DTYPE", "float32") + kernel = _hoisted_q_call[vc_size] + attn_output = kernel( + q, k, v, + scale=scale, causal_mask=False, + tp_q=True, tp_k=True, tp_out=False, + softmax_dtype=softmax_dtype, + mm_out_dtype="float32", + ) + return attn_output.reshape(bs, n_head, q_len, d_head) + + if USE_NKILIB_ATTENTION: + # attention_cte expects: + # q: [B, seqlen_q, d] (tp_q=True, default) + # k: [B, seqlen_kv, d] (tp_k=True) + # v: [B, seqlen_kv, d] + # Collapse B and H into the leading dim. + q = query.reshape(bs * n_head, q_len, d_head) + k = key.reshape(bs * n_head, k_len, d_head) + v = value.reshape(bs * n_head, v_len, d_head) + # QIE joint attention: no causal mask. + # Use [vc_size] bracket syntax to set LNC; kernel internally shards + # across LNC2 on batch (and on seqlen_q when batch is odd). + # + # softmax_dtype can be lowered to bf16 via QIE_SOFTMAX_DTYPE env, but + # we measured no speedup on Trn2 for QIE (1249ms both fp32 and bf16) + # because mm_out_dtype must stay float32 on Gen3 hardware. Default fp32. + softmax_dtype = os.getenv("QIE_SOFTMAX_DTYPE", "float32") + kernel = _attention_cte_call[vc_size] + attn_output = kernel( + q, k, v, + scale=scale, causal_mask=False, + tp_q=True, tp_k=True, tp_out=False, + softmax_dtype=softmax_dtype, + mm_out_dtype="float32", + ) + # attention_cte with tp_out=False returns [B, seqlen_q, d] + return attn_output.reshape(bs, n_head, q_len, d_head) + + # Legacy path: attention_isa_kernel + q = query.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) + k = key.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, k_len)) + v = value.clone().reshape((bs * n_head, v_len, d_head)) + attn_output = torch.zeros((bs * n_head, q_len, d_head), dtype=torch.bfloat16, device=q.device) + if vc_size == 2: + grid = (nc(2),) + _legacy_flash_fwd_call[grid](q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + else: + _legacy_flash_fwd_call(q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + return attn_output.reshape((bs, n_head, q_len, d_head)) + + +class CFGNKIQwenAttention(nn.Module): + """ + CFG Parallel + NKI Flash Attention for QwenImage. + + Key differences from CPNKIQwenAttention: + - No K/V all-gather (each DP rank has full sequence for its batch item) + - Uses NKI Flash Attention kernel + """ + + def __init__(self, orig_attn): + super().__init__() + + self.heads = orig_attn.heads + self.to_q = orig_attn.to_q + self.to_k = orig_attn.to_k + self.to_v = orig_attn.to_v + self.to_out = orig_attn.to_out + + self.add_q_proj = orig_attn.add_q_proj if hasattr(orig_attn, 'add_q_proj') else None + self.add_k_proj = orig_attn.add_k_proj if hasattr(orig_attn, 'add_k_proj') else None + self.add_v_proj = orig_attn.add_v_proj if hasattr(orig_attn, 'add_v_proj') else None + self.to_add_out = orig_attn.to_add_out if hasattr(orig_attn, 'to_add_out') else None + + self.norm_q = orig_attn.norm_q if hasattr(orig_attn, 'norm_q') else None + self.norm_k = orig_attn.norm_k if hasattr(orig_attn, 'norm_k') else None + self.norm_added_q = orig_attn.norm_added_q if hasattr(orig_attn, 'norm_added_q') else None + self.norm_added_k = orig_attn.norm_added_k if hasattr(orig_attn, 'norm_added_k') else None + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + image_rotary_emb: Tuple = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward with NKI attention. No K/V gathering needed for CFG parallel. + """ + if encoder_hidden_states is None: + raise ValueError("CFGNKIQwenAttention requires encoder_hidden_states") + + batch_size = hidden_states.shape[0] + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream + img_query = self.to_q(hidden_states) + img_key = self.to_k(hidden_states) + img_value = self.to_v(hidden_states) + + # Compute QKV for text stream + txt_query = self.add_q_proj(encoder_hidden_states) + txt_key = self.add_k_proj(encoder_hidden_states) + txt_value = self.add_v_proj(encoder_hidden_states) + + inner_dim = img_query.shape[-1] + head_dim = inner_dim // self.heads + + # Reshape to [B, H, S, D] + img_query = img_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_key = img_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_value = img_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + txt_query = txt_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_key = txt_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_value = txt_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Apply QK normalization + if self.norm_q is not None: + img_query = self.norm_q(img_query) + if self.norm_k is not None: + img_key = self.norm_k(img_key) + if self.norm_added_q is not None: + txt_query = self.norm_added_q(txt_query) + if self.norm_added_k is not None: + txt_key = self.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_precomputed(img_query.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + img_key = apply_rotary_emb_precomputed(img_key.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + txt_query = apply_rotary_emb_precomputed(txt_query.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + txt_key = apply_rotary_emb_precomputed(txt_key.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + + # No K/V all-gather needed for CFG parallel + # Each DP rank has one complete batch item with full sequence + + # Concatenate for joint attention + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + # NKI Flash Attention + joint_hidden_states = nki_flash_attention(joint_query, joint_key, joint_value) + + # Transpose and reshape + joint_hidden_states = joint_hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] + img_attn_output = joint_hidden_states[:, seq_txt:, :] + + # Output projections + img_attn_output = self.to_out[0](img_attn_output) + if len(self.to_out) > 1: + img_attn_output = self.to_out[1](img_attn_output) + + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +def apply_rotary_emb_precomputed( + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """Apply rotary embeddings using pre-computed cos/sin tensors.""" + cos, sin = freqs_cis + cos = cos.to(x.device) + sin = sin.to(x.device) + + if not use_real: + x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2) + x_real = x_reshaped[..., 0] + x_imag = x_reshaped[..., 1] + + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + out_real = x_real * cos - x_imag * sin + out_imag = x_real * sin + x_imag * cos + + out = torch.stack([out_real, out_imag], dim=-1) + out = out.flatten(-2) + + return out.to(x.dtype) + else: + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + + if use_real_unbind_dim == -1: + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + else: + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +# Patch apply_rotary_emb_qwen +import diffusers.models.transformers.transformer_qwenimage as qwen_module +qwen_module.apply_rotary_emb_qwen = apply_rotary_emb_precomputed +print("Patched apply_rotary_emb_qwen for pre-computed RoPE") + + +def split_along_dim(tensor, dim, rank, data_parallel_group): + """Split tensor along dimension using scatter_to_process_group_spmd.""" + tensor = scatter_to_process_group_spmd( + tensor, + partition_dim=dim, + rank=rank, + process_group=data_parallel_group, + ) + return tensor + + +def get_dp_rank_spmd(global_rank: torch.Tensor, tp_degree: int) -> torch.Tensor: + """ + Compute DP rank from global rank for SPMD execution. + + With world_size=8 and tp_degree=4: + - Ranks 0-3 are DP rank 0 + - Ranks 4-7 are DP rank 1 + """ + dp_rank = torch.div( + global_rank, + tp_degree, + rounding_mode="floor", + ).to(torch.int32) + return dp_rank + + +class NeuronQwenTransformerV3CFG(nn.Module): + """ + Neuron-optimized QwenImage Transformer with CFG Parallelism. + + Features: + - TP=4 for model parameter sharding + - CFG enabled (via DP group) for batch parallelism + - Input scattered along batch dim (dim=0): [2,S,C] -> [1,S,C] per rank + - No K/V all-gather (each rank has full sequence) + - Output gathered along batch dim (dim=0) + - NKI Flash Attention + """ + + def __init__(self, original_transformer, tp_degree, world_size, cfg_parallel_enabled=False): + super().__init__() + + self.config = original_transformer.config + self.in_channels = original_transformer.config.in_channels + self.out_channels = original_transformer.config.out_channels + self.patch_size = original_transformer.config.patch_size + self.cfg_parallel_enabled = cfg_parallel_enabled + self.tp_degree = tp_degree + self.world_size = world_size + + # SPMDRank for getting global rank at runtime (crucial for SPMD scatter/gather) + self.global_rank = SPMDRank(world_size=world_size) + + # DP group for CFG communication + self.data_parallel_group = parallel_state.get_data_parallel_group() + + # Input projections + self.img_in = original_transformer.img_in + self.txt_in = original_transformer.txt_in + + # Time/text embedding + self.time_text_embed = original_transformer.time_text_embed + + # Text norm + self.txt_norm = original_transformer.txt_norm + + # Transformer blocks with TP sharding + self.transformer_blocks = nn.ModuleList() + for i, block in enumerate(original_transformer.transformer_blocks): + # Shard with TP degree + block.attn = shard_qwen_attention(tp_degree, block.attn, reduce_dtype=_REDUCE_DTYPE) + block.img_mlp = shard_feedforward(block.img_mlp, reduce_dtype=_REDUCE_DTYPE) + block.txt_mlp = shard_feedforward(block.txt_mlp, reduce_dtype=_REDUCE_DTYPE) + block.img_mod = shard_modulation(block.img_mod) + block.txt_mod = shard_modulation(block.txt_mod) + self.transformer_blocks.append(block) + + if (i + 1) % 10 == 0: + print(f" Sharded block {i+1}/{len(original_transformer.transformer_blocks)}") + + # Replace attention with CFG+NKI version + self._replace_attention() + + # Final layers + self.norm_out = original_transformer.norm_out + self.proj_out = original_transformer.proj_out + + self.head_dim = 128 + self.num_heads = original_transformer.transformer_blocks[0].attn.heads + + def _replace_attention(self): + """Replace attention modules with CFG+NKI versions (no K/V gathering).""" + for i, block in enumerate(self.transformer_blocks): + block.attn = CFGNKIQwenAttention(block.attn) + print(f"Replaced attention with CFG+NKI versions on {len(self.transformer_blocks)} blocks") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + img_rotary_emb: torch.Tensor, + txt_rotary_emb: torch.Tensor, + ) -> torch.Tensor: + """Forward pass with CFG Parallel data splitting along batch dim.""" + + # ========== CFG PARALLEL: SPLIT DATA AT ENTRY (dim=0, batch) ========== + if self.cfg_parallel_enabled: + # Compute DP rank at runtime using SPMDRank + dp_rank = get_dp_rank_spmd(self.global_rank.get_rank(), self.tp_degree) + + # Split hidden_states along batch dim (dim=0): [2,S,C] -> [1,S,C] + hidden_states = split_along_dim( + hidden_states, dim=0, rank=dp_rank, data_parallel_group=self.data_parallel_group + ) + + # Split encoder_hidden_states along batch dim (dim=0): [2,S,C] -> [1,S,C] + encoder_hidden_states = split_along_dim( + encoder_hidden_states, dim=0, rank=dp_rank, data_parallel_group=self.data_parallel_group + ) + + # Split timestep along batch dim (dim=0): [2] -> [1] + timestep = split_along_dim( + timestep, dim=0, rank=dp_rank, data_parallel_group=self.data_parallel_group + ) + + # Do NOT scatter RoPE - position-indexed, same for both batch items + + # Split RoPE into cos/sin + img_freqs_cos = img_rotary_emb[..., 0] + img_freqs_sin = img_rotary_emb[..., 1] + txt_freqs_cos = txt_rotary_emb[..., 0] + txt_freqs_sin = txt_rotary_emb[..., 1] + + # Image input projection + hidden_states = self.img_in(hidden_states) + + # Text processing + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + # Time embedding + timestep = timestep.to(hidden_states.dtype) + temb = self.time_text_embed(timestep, hidden_states) + + # Create rotary_emb tuple + image_rotary_emb = ((img_freqs_cos, img_freqs_sin), (txt_freqs_cos, txt_freqs_sin)) + + # Process through blocks + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # Final norm and projection + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + # ========== CFG PARALLEL: GATHER OUTPUT (dim=0, batch) ========== + if self.cfg_parallel_enabled: + # Before gather: output has shape [1, patches, C] + output = gather_from_tensor_model_parallel_region_with_dim( + output, gather_dim=0, process_group=self.data_parallel_group + ) + # After gather: output has shape [2, patches, C] + + return output + + +class TracingWrapper(nn.Module): + """Wrapper for tracing.""" + + def __init__(self, transformer): + super().__init__() + self.transformer = transformer + + def forward(self, hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb): + return self.transformer( + hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb + ) + + +def get_rope_from_original_model(pipe, frame, height, width, text_seq_len, dtype=torch.bfloat16): + """Get RoPE from original model.""" + print(f" Getting RoPE: video_fhw=({frame}, {height}, {width}), text_seq_len={text_seq_len}") + + video_fhw = (frame, height, width) + vid_freqs, txt_freqs = pipe.transformer.pos_embed( + video_fhw, max_txt_seq_len=text_seq_len, device=torch.device('cpu') + ) + + img_cos = vid_freqs.real.float() + img_sin = vid_freqs.imag.float() + txt_cos = txt_freqs.real.float() + txt_sin = txt_freqs.imag.float() + + img_rotary_emb = torch.stack([img_cos, img_sin], dim=-1).to(dtype) + txt_rotary_emb = torch.stack([txt_cos, txt_sin], dim=-1).to(dtype) + + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + return img_rotary_emb, txt_rotary_emb + + +def compile_transformer_v3_cfg(args): + """Compile transformer with CFG Parallelism using ModelBuilder API.""" + + tp_degree = args.tp_degree + world_size = args.world_size + cfg_parallel_enabled = (world_size != tp_degree) + + if cfg_parallel_enabled: + dp_degree = world_size // tp_degree + print(f"CFG Parallel enabled: DP={dp_degree}") + else: + dp_degree = 1 + + # Calculate dimensions + latent_h = args.height // 8 + latent_w = args.width // 8 + patch_size = 2 + patch_h = latent_h // patch_size + patch_w = latent_w // patch_size + temporal_frames = args.patch_multiplier + num_patches = temporal_frames * patch_h * patch_w + text_seq_len = args.max_sequence_length + + text_hidden_size = 3584 + in_channels = 64 + head_dim = 128 + + # CFG alignment padding (simpler than CP - no sequence splitting) + # Just pad num_patches so total_seq = num_patches + text_seq_len is multiple of 128 + total_seq = num_patches + text_seq_len + alignment = 128 + need_padding = (alignment - total_seq % alignment) % alignment + num_patches_padded = num_patches + need_padding + patches_padding = need_padding + + # Hard-coded batch_size=2 for CFG (one positive + one negative) + batch_size = 2 + + print("=" * 60) + print("Transformer V3 CFG Parallel Compilation") + print("=" * 60) + print(f"Image: {args.height}x{args.width}") + print(f"Original patches: {num_patches}") + if patches_padding > 0: + print(f"Padded patches: {num_patches_padded} (+{patches_padding} for alignment)") + print(f"Total seq (padded): {num_patches_padded + text_seq_len}") + print(f"Total text seq: {text_seq_len}") + print(f"TP degree: {tp_degree}") + print(f"World size: {world_size}") + print(f"CFG Parallel: {cfg_parallel_enabled} (DP={dp_degree})") + print(f"NKI Flash Attention: Enabled") + print(f"Batch size: {batch_size} (hard-coded for CFG)") + + # Sample inputs (batch_size=2 for CFG) + sample_hidden_states = torch.randn(batch_size, num_patches_padded, in_channels, dtype=torch.bfloat16) + sample_encoder_hidden_states = torch.randn(batch_size, text_seq_len, text_hidden_size, dtype=torch.bfloat16) + sample_timestep = torch.randn(batch_size, dtype=torch.float32) + + # Use NxDParallelState context for compilation + # world_size=8, tensor_model_parallel_size=4 means DP=2 (used for CFG) + with NxDParallelState(world_size=world_size, tensor_model_parallel_size=tp_degree): + print("\nLoading model...") + load_kwargs = {"torch_dtype": torch.bfloat16, "local_files_only": True} + if CACHE_DIR: + load_kwargs["cache_dir"] = CACHE_DIR + pipe = QwenImageEditPlusPipeline.from_pretrained(MODEL_ID, **load_kwargs) + + # Get full RoPE + print("\nGetting RoPE...") + img_rotary_emb, txt_rotary_emb = get_rope_from_original_model( + pipe=pipe, + frame=temporal_frames, + height=patch_h, + width=patch_w, + text_seq_len=text_seq_len, + ) + + print(f" img RoPE (original): {img_rotary_emb.shape}") + print(f" txt RoPE: {txt_rotary_emb.shape}") + + # Pad img_rotary_emb if needed for alignment + if patches_padding > 0: + rope_padding = img_rotary_emb[-1:].repeat(patches_padding, 1, 1) + img_rotary_emb = torch.cat([img_rotary_emb, rope_padding], dim=0) + print(f" img RoPE (padded): {img_rotary_emb.shape} (+{patches_padding})") + + # Save unsharded state dict before modifications + unsharded_state = pipe.transformer.state_dict() + + # Create Neuron transformer + print("\nCreating Neuron transformer (sharding layers with TP={}, world_size={})...".format(tp_degree, world_size)) + neuron_transformer = NeuronQwenTransformerV3CFG( + pipe.transformer, tp_degree, world_size, cfg_parallel_enabled + ) + neuron_transformer = neuron_transformer.to(torch.bfloat16) + neuron_transformer.eval() + + # Wrap for tracing + model = TracingWrapper(neuron_transformer) + + print("\nInitializing ModelBuilder...") + builder = ModelBuilder(model=model) + + print("Tracing model...") + builder.trace( + kwargs={ + "hidden_states": sample_hidden_states, + "encoder_hidden_states": sample_encoder_hidden_states, + "timestep": sample_timestep, + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, + tag="inference", + ) + + print("Compiling model...") + compile_args = "--model-type=transformer -O2 --auto-cast=none --lnc=2 --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=4' --internal-hlo2tensorizer-options='--enable-native-kernel=1 --remat'" + traced_model = builder.compile( + compiler_args=compile_args, + compiler_workdir=args.compiler_workdir, + ) + + # Save + output_path = f"{args.compiled_models_dir}/transformer_v3_cfg" + os.makedirs(output_path, exist_ok=True) + + print(f"\nSaving to {output_path}...") + traced_model.save(os.path.join(output_path, "nxd_model.pt")) + + # Save weights + weights_path = os.path.join(output_path, "weights") + os.makedirs(weights_path, exist_ok=True) + + # Prepare checkpoint for sharding + checkpoint = {} + global_rank_state = {} # Save SPMDRank state separately (not sharded) + for key, value in model.state_dict().items(): + # Save SPMDRank module state separately - it's not sharded, same on all ranks + if 'global_rank' in key: + print(f" Saving SPMDRank key separately: {key}") + global_rank_state[key] = value.clone() + continue + # Use unsharded weights where available + orig_key = key.replace("transformer.", "", 1) + if orig_key in unsharded_state: + checkpoint[key] = unsharded_state[orig_key].clone() + else: + checkpoint[key] = value.clone() + + # Shard checkpoint + print("Sharding weights...") + shard_checkpoint( + checkpoint=checkpoint, + model=model, + serialize_path=weights_path, + ) + + # Post-process sharded checkpoints: + # 1. Remove master_weight tensors (they duplicate sharded weights, wastes ~50% space) + # 2. Add global_rank state (SPMDRank) to each checkpoint + print("\nPost-processing sharded checkpoints...") + from safetensors.torch import load_file, save_file + for rank in range(tp_degree): # Only TP checkpoints are created, CFG duplicates them at load time + shard_file = os.path.join(weights_path, f"tp{rank}_sharded_checkpoint.safetensors") + if not os.path.exists(shard_file): + print(f" WARNING: {shard_file} not found") + continue + + shard_data = dict(load_file(shard_file)) + original_count = len(shard_data) + original_size = sum(v.numel() * v.element_size() for v in shard_data.values()) + + # Remove master_weight tensors (they duplicate the sharded weights). + # Clone because load_file returns mmap'd tensors; overwriting the source + # file invalidates their backing storage and safetensors errors with + # "Bad address (os error 14)" during serialization. + cleaned = {k: v.clone().contiguous() for k, v in shard_data.items() if 'master_weight' not in k} + + # Add SPMDRank state (same value for all ranks) + if global_rank_state: + cleaned.update({k: v.clone().contiguous() if hasattr(v, 'clone') else v + for k, v in global_rank_state.items()}) + + cleaned_size = sum(v.numel() * v.element_size() for v in cleaned.values()) + # Drop the original mmap before writing to the same path + del shard_data + save_file(cleaned, shard_file) + print(f" tp{rank}: {original_count} -> {len(cleaned)} tensors, " + f"{original_size/1e9:.2f}GB -> {cleaned_size/1e9:.2f}GB") + + # Save config + config = { + "height": args.height, + "width": args.width, + "num_patches": num_patches, + "num_patches_padded": num_patches_padded, + "patches_padding": patches_padding, + "text_seq_len": text_seq_len, + "patch_multiplier": args.patch_multiplier, + "tp_degree": tp_degree, + "world_size": world_size, + "cfg_parallel": cfg_parallel_enabled, + "dp_degree": dp_degree, + "head_dim": head_dim, + "frame": temporal_frames, + "patch_h": patch_h, + "patch_w": patch_w, + "nki_flash_attention": True, + "batch_size": batch_size, + } + with open(os.path.join(output_path, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # Save pre-computed RoPE + torch.save({ + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, os.path.join(output_path, "rope_cache.pt")) + + print("\nCompilation complete!") + print(f"Model saved to: {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default=None, + help="Path to model (local dir or HuggingFace ID). If not set, uses MODEL_ID with CACHE_DIR") + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--max_sequence_length", type=int, default=1024) + parser.add_argument("--patch_multiplier", type=int, default=3) + parser.add_argument("--tp_degree", type=int, default=4) + parser.add_argument("--world_size", type=int, default=8) + parser.add_argument("--compiled_models_dir", type=str, default="/opt/dlami/nvme/compiled_models") + parser.add_argument("--compiler_workdir", type=str, default="/opt/dlami/nvme/compiler_workdir") + args = parser.parse_args() + + # Override MODEL_ID and CACHE_DIR if model_path is provided + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + compile_transformer_v3_cfg(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_transformer_v3_cp.py b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v3_cp.py new file mode 100644 index 00000000..965bec66 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_transformer_v3_cp.py @@ -0,0 +1,801 @@ +""" +Transformer compilation with Context Parallel (V3 CP) using ModelBuilder API. + +Key approach: +1. Uses ModelBuilder API (like V2) for compilation +2. Configures world_size=8, tp_degree=4 (implicit CP=2) +3. K/V are all-gathered across DP group before attention +4. Uses NKI Flash Attention for optimal performance + +This is inspired by Flux's context parallel implementation which achieves +near-H100 performance on TRN2. + +Context Parallel works by: +- Model parameters are sharded with TP=4 +- DP group (2 ranks) is used for sequence parallelism +- Each DP rank processes half the sequence (queries) +- K/V are all-gathered so each rank sees full K/V +""" + +import os +import json +import math + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +# Compiler flags - same as Flux for CP mode +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer --auto-cast=none --enable-fast-loading-neuron-binaries --tensorizer-options='--enable-ccop-compute-overlap' --internal-hlo2tensorizer-options='--enable-state-buffer-mode=hybrid --remat-by-default' """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import torch.nn as nn +import torch.nn.functional as F +import argparse +from typing import Optional, Tuple, List + +from diffusers import QwenImageEditPlusPipeline + +# ModelBuilder imports +from neuronx_distributed import ModelBuilder, NxDParallelState, shard_checkpoint +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + SPMDRank, +) +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_tensor_model_parallel_region_with_dim, + scatter_to_process_group_spmd, +) + +from neuron_parallel_utils import ( + shard_qwen_attention, + shard_feedforward, + shard_modulation, + get_sharded_data, +) + +# Import NKI Flash Attention +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronxcc.nki.language import nc +from torch_neuronx.xla_impl.ops import nki_jit + +_flash_fwd_call = nki_jit()(attention_isa_kernel) + +print("NKI Flash Attention kernel loaded successfully") + +# Phase 17: lower the all-reduce dtype on RowParallelLinear from fp32 -> bf16. +# Mirrors the V3 CFG plumbing; halves the bytes on the wire for every TP +# all-reduce inside the transformer blocks. Default on; set +# QIE_ALLREDUCE_BF16=0 to revert. +ALLREDUCE_BF16 = os.getenv("QIE_ALLREDUCE_BF16", "1") == "1" +_REDUCE_DTYPE = torch.bfloat16 if ALLREDUCE_BF16 else torch.float32 +if ALLREDUCE_BF16: + print(" + TP all-reduce dtype = bf16 (Phase 17 experiment)") + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def nki_flash_attention(query, key, value): + """ + NKI Flash Attention wrapper. + + Args: + query: [B, H, S, D] + key: [B, H, S, D] + value: [B, H, S, D] + + Returns: + attention output [B, H, S, D] + """ + bs, n_head, q_len, d_head = query.shape + k_len = key.shape[2] + v_len = value.shape[2] + + q = query.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, q_len)) + k = key.clone().permute(0, 1, 3, 2).reshape((bs * n_head, d_head, k_len)) + v = value.clone().reshape((bs * n_head, v_len, d_head)) + + attn_output = torch.zeros((bs * n_head, q_len, d_head), dtype=torch.bfloat16, device=q.device) + scale = 1 / math.sqrt(d_head) + + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + if vc_size == 2: + grid = (nc(2),) + _flash_fwd_call[grid](q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + else: + _flash_fwd_call(q, k, v, scale, attn_output, kernel_name="AttentionMMSoftmaxMMWithoutSwap") + + return attn_output.reshape((bs, n_head, q_len, d_head)) + + +class CPNKIQwenAttention(nn.Module): + """ + Context Parallel + NKI Flash Attention for QwenImage. + + Key features: + 1. K/V are all-gathered across CP group before attention + 2. Uses NKI Flash Attention kernel + 3. Each CP rank processes its portion of queries against full K/V + """ + + def __init__(self, orig_attn, context_parallel_enabled=False, data_parallel_group=None): + super().__init__() + + self.context_parallel_enabled = context_parallel_enabled + self.data_parallel_group = data_parallel_group + self.heads = orig_attn.heads + self.to_q = orig_attn.to_q + self.to_k = orig_attn.to_k + self.to_v = orig_attn.to_v + self.to_out = orig_attn.to_out + + self.add_q_proj = orig_attn.add_q_proj if hasattr(orig_attn, 'add_q_proj') else None + self.add_k_proj = orig_attn.add_k_proj if hasattr(orig_attn, 'add_k_proj') else None + self.add_v_proj = orig_attn.add_v_proj if hasattr(orig_attn, 'add_v_proj') else None + self.to_add_out = orig_attn.to_add_out if hasattr(orig_attn, 'to_add_out') else None + + self.norm_q = orig_attn.norm_q if hasattr(orig_attn, 'norm_q') else None + self.norm_k = orig_attn.norm_k if hasattr(orig_attn, 'norm_k') else None + self.norm_added_q = orig_attn.norm_added_q if hasattr(orig_attn, 'norm_added_q') else None + self.norm_added_k = orig_attn.norm_added_k if hasattr(orig_attn, 'norm_added_k') else None + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + encoder_hidden_states_mask: torch.Tensor = None, + image_rotary_emb: Tuple = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward with Context Parallel K/V gathering and NKI attention. + """ + if encoder_hidden_states is None: + raise ValueError("CPNKIQwenAttention requires encoder_hidden_states") + + batch_size = hidden_states.shape[0] + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream + img_query = self.to_q(hidden_states) + img_key = self.to_k(hidden_states) + img_value = self.to_v(hidden_states) + + # Compute QKV for text stream + txt_query = self.add_q_proj(encoder_hidden_states) + txt_key = self.add_k_proj(encoder_hidden_states) + txt_value = self.add_v_proj(encoder_hidden_states) + + inner_dim = img_query.shape[-1] + head_dim = inner_dim // self.heads + + # Reshape to [B, H, S, D] + img_query = img_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_key = img_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + img_value = img_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + txt_query = txt_query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_key = txt_key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + txt_value = txt_value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Apply QK normalization + if self.norm_q is not None: + img_query = self.norm_q(img_query) + if self.norm_k is not None: + img_key = self.norm_k(img_key) + if self.norm_added_q is not None: + txt_query = self.norm_added_q(txt_query) + if self.norm_added_k is not None: + txt_key = self.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + img_freqs, txt_freqs = image_rotary_emb + img_query = apply_rotary_emb_precomputed(img_query.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + img_key = apply_rotary_emb_precomputed(img_key.transpose(1, 2), img_freqs, use_real=False).transpose(1, 2) + txt_query = apply_rotary_emb_precomputed(txt_query.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + txt_key = apply_rotary_emb_precomputed(txt_key.transpose(1, 2), txt_freqs, use_real=False).transpose(1, 2) + + # Context Parallel: All-gather K/V across DP group + if self.context_parallel_enabled: + # Gather image K/V + img_stacked_kv = torch.stack([img_key, img_value], dim=0) + img_stacked_kv = gather_from_tensor_model_parallel_region_with_dim( + img_stacked_kv, gather_dim=3, process_group=self.data_parallel_group + ) + img_key, img_value = torch.unbind(img_stacked_kv, dim=0) + + # Gather text K/V + txt_stacked_kv = torch.stack([txt_key, txt_value], dim=0) + txt_stacked_kv = gather_from_tensor_model_parallel_region_with_dim( + txt_stacked_kv, gather_dim=3, process_group=self.data_parallel_group + ) + txt_key, txt_value = torch.unbind(txt_stacked_kv, dim=0) + + # Concatenate for joint attention + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) + + # NKI Flash Attention + joint_hidden_states = nki_flash_attention(joint_query, joint_key, joint_value) + + # Transpose and reshape + joint_hidden_states = joint_hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split back (use original local seq_txt for splitting) + txt_attn_output = joint_hidden_states[:, :seq_txt, :] + img_attn_output = joint_hidden_states[:, seq_txt:, :] + + # Output projections + img_attn_output = self.to_out[0](img_attn_output) + if len(self.to_out) > 1: + img_attn_output = self.to_out[1](img_attn_output) + + txt_attn_output = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +def apply_rotary_emb_precomputed( + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """Apply rotary embeddings using pre-computed cos/sin tensors.""" + cos, sin = freqs_cis + cos = cos.to(x.device) + sin = sin.to(x.device) + + if not use_real: + x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2) + x_real = x_reshaped[..., 0] + x_imag = x_reshaped[..., 1] + + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + out_real = x_real * cos - x_imag * sin + out_imag = x_real * sin + x_imag * cos + + out = torch.stack([out_real, out_imag], dim=-1) + out = out.flatten(-2) + + return out.to(x.dtype) + else: + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + cos = cos.repeat_interleave(2, dim=-1) + sin = sin.repeat_interleave(2, dim=-1) + + if use_real_unbind_dim == -1: + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + else: + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +# Patch apply_rotary_emb_qwen +import diffusers.models.transformers.transformer_qwenimage as qwen_module +qwen_module.apply_rotary_emb_qwen = apply_rotary_emb_precomputed +print("Patched apply_rotary_emb_qwen for pre-computed RoPE") + + +def split_along_dim(tensor, dim, rank, data_parallel_group): + """Split tensor along dimension using scatter_to_process_group_spmd.""" + tensor = scatter_to_process_group_spmd( + tensor, + partition_dim=dim, + rank=rank, + process_group=data_parallel_group, + ) + return tensor + + +def get_dp_rank_spmd(global_rank: torch.Tensor, tp_degree: int) -> torch.Tensor: + """ + Compute DP rank from global rank for SPMD execution. + + With world_size=8 and tp_degree=4: + - Ranks 0-3 are DP rank 0 + - Ranks 4-7 are DP rank 1 + """ + dp_rank = torch.div( + global_rank, + tp_degree, + rounding_mode="floor", + ).to(torch.int32) + return dp_rank + + +class NeuronQwenTransformerV3CP(nn.Module): + """ + Neuron-optimized QwenImage Transformer with Context Parallel. + + Features: + - TP=4 for model parameter sharding + - CP enabled (via DP group) for sequence parallelism + - Data is SPLIT at entry, K/V gathered in attention, output gathered at exit + - NKI Flash Attention + """ + + def __init__(self, original_transformer, tp_degree, world_size, context_parallel_enabled=False): + super().__init__() + + self.config = original_transformer.config + self.in_channels = original_transformer.config.in_channels + self.out_channels = original_transformer.config.out_channels + self.patch_size = original_transformer.config.patch_size + self.context_parallel_enabled = context_parallel_enabled + self.tp_degree = tp_degree + self.world_size = world_size + + # SPMDRank for getting global rank at runtime (crucial for SPMD scatter/gather) + self.global_rank = SPMDRank(world_size=world_size) + + # DP group for CP communication + self.data_parallel_group = parallel_state.get_data_parallel_group() + + # Input projections + self.img_in = original_transformer.img_in + self.txt_in = original_transformer.txt_in + + # Time/text embedding + self.time_text_embed = original_transformer.time_text_embed + + # Text norm + self.txt_norm = original_transformer.txt_norm + + # Transformer blocks with TP sharding + self.transformer_blocks = nn.ModuleList() + for i, block in enumerate(original_transformer.transformer_blocks): + # Shard with TP degree + block.attn = shard_qwen_attention(tp_degree, block.attn, reduce_dtype=_REDUCE_DTYPE) + block.img_mlp = shard_feedforward(block.img_mlp, reduce_dtype=_REDUCE_DTYPE) + block.txt_mlp = shard_feedforward(block.txt_mlp, reduce_dtype=_REDUCE_DTYPE) + block.img_mod = shard_modulation(block.img_mod) + block.txt_mod = shard_modulation(block.txt_mod) + self.transformer_blocks.append(block) + + if (i + 1) % 10 == 0: + print(f" Sharded block {i+1}/{len(original_transformer.transformer_blocks)}") + + # Replace attention with CP+NKI version + self._replace_attention() + + # Final layers + self.norm_out = original_transformer.norm_out + self.proj_out = original_transformer.proj_out + + self.head_dim = 128 + self.num_heads = original_transformer.transformer_blocks[0].attn.heads + + def _replace_attention(self): + """Replace attention modules with CP+NKI versions.""" + for i, block in enumerate(self.transformer_blocks): + block.attn = CPNKIQwenAttention( + block.attn, self.context_parallel_enabled, self.data_parallel_group + ) + print(f"Replaced attention with CP+NKI versions on {len(self.transformer_blocks)} blocks") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.Tensor, + img_rotary_emb: torch.Tensor, + txt_rotary_emb: torch.Tensor, + ) -> torch.Tensor: + """Forward pass with Context Parallel data splitting.""" + + # Store original shapes for verification + orig_hidden_shape = hidden_states.shape + orig_enc_shape = encoder_hidden_states.shape + + # ========== CONTEXT PARALLEL: SPLIT DATA AT ENTRY ========== + if self.context_parallel_enabled: + # Compute DP rank at runtime using SPMDRank (returns different values per rank) + dp_rank = get_dp_rank_spmd(self.global_rank.get_rank(), self.tp_degree) + + # Split hidden_states along sequence dim (dim=1) + hidden_states = split_along_dim( + hidden_states, dim=1, rank=dp_rank, data_parallel_group=self.data_parallel_group + ) + + # Split encoder_hidden_states along sequence dim (dim=1) + encoder_hidden_states = split_along_dim( + encoder_hidden_states, dim=1, rank=dp_rank, data_parallel_group=self.data_parallel_group + ) + + # Split RoPE along position dim (dim=0) + img_rotary_emb = split_along_dim( + img_rotary_emb, dim=0, rank=dp_rank, data_parallel_group=self.data_parallel_group + ) + txt_rotary_emb = split_along_dim( + txt_rotary_emb, dim=0, rank=dp_rank, data_parallel_group=self.data_parallel_group + ) + + # Split RoPE into cos/sin + img_freqs_cos = img_rotary_emb[..., 0] + img_freqs_sin = img_rotary_emb[..., 1] + txt_freqs_cos = txt_rotary_emb[..., 0] + txt_freqs_sin = txt_rotary_emb[..., 1] + + # Image input projection + hidden_states = self.img_in(hidden_states) + + # Text processing + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + # Time embedding + timestep = timestep.to(hidden_states.dtype) + temb = self.time_text_embed(timestep, hidden_states) + + # Create rotary_emb tuple + image_rotary_emb = ((img_freqs_cos, img_freqs_sin), (txt_freqs_cos, txt_freqs_sin)) + + # Process through blocks + for block in self.transformer_blocks: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=None, + temb=temb, + image_rotary_emb=image_rotary_emb, + ) + + # Final norm and projection + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + # ========== CONTEXT PARALLEL: GATHER OUTPUT ========== + if self.context_parallel_enabled: + # Before gather: output has shape [B, local_patches, C] + output = gather_from_tensor_model_parallel_region_with_dim( + output, gather_dim=1, process_group=self.data_parallel_group + ) + # After gather: output should have shape [B, full_patches, C] + # Verify that we recovered the original sequence length + # orig_hidden_shape[1] is the original num_patches + + return output + + +class TracingWrapper(nn.Module): + """Wrapper for tracing.""" + + def __init__(self, transformer): + super().__init__() + self.transformer = transformer + + def forward(self, hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb): + return self.transformer( + hidden_states, encoder_hidden_states, timestep, + img_rotary_emb, txt_rotary_emb + ) + + +def get_rope_from_original_model(pipe, frame, height, width, text_seq_len, dtype=torch.bfloat16): + """Get RoPE from original model.""" + print(f" Getting RoPE: video_fhw=({frame}, {height}, {width}), text_seq_len={text_seq_len}") + + video_fhw = (frame, height, width) + vid_freqs, txt_freqs = pipe.transformer.pos_embed( + video_fhw, max_txt_seq_len=text_seq_len, device=torch.device('cpu') + ) + + img_cos = vid_freqs.real.float() + img_sin = vid_freqs.imag.float() + txt_cos = txt_freqs.real.float() + txt_sin = txt_freqs.imag.float() + + img_rotary_emb = torch.stack([img_cos, img_sin], dim=-1).to(dtype) + txt_rotary_emb = torch.stack([txt_cos, txt_sin], dim=-1).to(dtype) + + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + return img_rotary_emb, txt_rotary_emb + + +def compile_transformer_v3_cp(args): + """Compile transformer with Context Parallel using ModelBuilder API.""" + + tp_degree = args.tp_degree + world_size = args.world_size + context_parallel_enabled = (world_size != tp_degree) + + if context_parallel_enabled: + cp_degree = world_size // tp_degree + print(f"Context Parallel enabled: CP={cp_degree}") + else: + cp_degree = 1 + + # Calculate dimensions + latent_h = args.height // 8 + latent_w = args.width // 8 + patch_size = 2 + patch_h = latent_h // patch_size + patch_w = latent_w // patch_size + temporal_frames = args.patch_multiplier + num_patches = temporal_frames * patch_h * patch_w + text_seq_len = args.max_sequence_length + + text_hidden_size = 3584 + in_channels = 64 + head_dim = 128 + + # Calculate CP alignment padding (padding goes to patches, not text) + # This keeps text_seq_len unchanged, avoiding RoPE position issues + if context_parallel_enabled: + local_patches = num_patches // cp_degree + local_text = text_seq_len // cp_degree + local_total = local_patches + local_text + + # NKI Flash Attention requires local seqlen_q to be a multiple of 128. BUT when + # tp_degree >= 8 each rank has few heads (24/8=3), so the kernel falls back to a + # SHARDED-attention path that requires seqlen_q divisible by a power of 2 >= 512. + # Align to 512 in that case (512 satisfies the divisibility); 128 otherwise. + alignment = 512 if tp_degree >= 8 else 128 + need_padding = (alignment - local_total % alignment) % alignment + patches_padding = need_padding * cp_degree # Total padding for patches + num_patches_padded = num_patches + patches_padding + else: + patches_padding = 0 + num_patches_padded = num_patches + + print("=" * 60) + print("Transformer V3 Context Parallel Compilation") + print("=" * 60) + print(f"Image: {args.height}x{args.width}") + print(f"Original patches: {num_patches}") + if patches_padding > 0: + print(f"Padded patches: {num_patches_padded} (+{patches_padding} for CP alignment)") + print(f"Total text seq: {text_seq_len}") + print(f"TP degree: {tp_degree}") + print(f"World size: {world_size}") + print(f"Context Parallel: {context_parallel_enabled} (CP={cp_degree})") + print(f"NKI Flash Attention: Enabled") + print(f"Batch size: {args.batch_size}") + + # Sample inputs (use padded num_patches for compilation) + batch_size = args.batch_size + sample_hidden_states = torch.randn(batch_size, num_patches_padded, in_channels, dtype=torch.bfloat16) + sample_encoder_hidden_states = torch.randn(batch_size, text_seq_len, text_hidden_size, dtype=torch.bfloat16) + sample_timestep = torch.randn(batch_size, dtype=torch.float32) + + # Use NxDParallelState context for compilation + # world_size=8, tensor_model_parallel_size=4 means DP=2 (used for CP) + with NxDParallelState(world_size=world_size, tensor_model_parallel_size=tp_degree): + print("\nLoading model...") + load_kwargs = {"torch_dtype": torch.bfloat16, "local_files_only": True} + if CACHE_DIR: + load_kwargs["cache_dir"] = CACHE_DIR + pipe = QwenImageEditPlusPipeline.from_pretrained(MODEL_ID, **load_kwargs) + + # Get full RoPE + print("\nGetting RoPE...") + img_rotary_emb, txt_rotary_emb = get_rope_from_original_model( + pipe=pipe, + frame=temporal_frames, + height=patch_h, + width=patch_w, + text_seq_len=text_seq_len, + ) + + print(f" img RoPE (original): {img_rotary_emb.shape}") + print(f" txt RoPE: {txt_rotary_emb.shape}") + + # Pad img_rotary_emb if needed for CP alignment + if patches_padding > 0: + # Repeat last position's RoPE for padding (position doesn't matter for padding tokens) + rope_padding = img_rotary_emb[-1:].repeat(patches_padding, 1, 1) + img_rotary_emb = torch.cat([img_rotary_emb, rope_padding], dim=0) + print(f" img RoPE (padded): {img_rotary_emb.shape} (+{patches_padding})") + + # Save unsharded state dict before modifications + unsharded_state = pipe.transformer.state_dict() + + # Create Neuron transformer + print("\nCreating Neuron transformer (sharding layers with TP={}, world_size={})...".format(tp_degree, world_size)) + neuron_transformer = NeuronQwenTransformerV3CP( + pipe.transformer, tp_degree, world_size, context_parallel_enabled + ) + neuron_transformer = neuron_transformer.to(torch.bfloat16) + neuron_transformer.eval() + + # Wrap for tracing + model = TracingWrapper(neuron_transformer) + + print("\nInitializing ModelBuilder...") + builder = ModelBuilder(model=model) + + print("Tracing model...") + builder.trace( + kwargs={ + "hidden_states": sample_hidden_states, + "encoder_hidden_states": sample_encoder_hidden_states, + "timestep": sample_timestep, + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, + tag="inference", + ) + + print("Compiling model...") + # Pass compiler args directly to compile() for State Buffer optimization + # --enable-native-kernel=1: enables native kernel mode + # --remat: enables rematerialization to save memory + # NOTE: -O2 matches the V3 CFG path (compile_transformer_v3_cfg.py); validated on QIE shapes. + # QIE_OPT_LEVEL: opt-in override (default O2). Set to "3" to try -O3 (NOTE: known to OOM on QIE V3 CP). + # QIE_CC_TILING: --cc-pipeline-tiling-factor (default 4). Larger = more comm/compute overlap chunks. + # NOTE on fp8: --auto-cast=matmult --auto-cast-type=fp8_e4m3 is a NO-OP here — it casts fp32→fp8, + # but QIE runs bf16 end-to-end (PyTorch model bf16, AR bf16, attention bf16). Verified + # 2026-05-28: NEFF compiled with the flag matched baseline at 7.66 s E2E. To actually get fp8, + # weights must be explicitly quantized in PyTorch (per-channel scale + calibration) — comparable + # engineering scope to writing custom NKI matmul kernels. + _opt = os.environ.get("QIE_OPT_LEVEL", "2") + _cc_tiling = os.environ.get("QIE_CC_TILING", "4") + compile_args = f"--model-type=transformer -O{_opt} --auto-cast=none --lnc=2 --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor={_cc_tiling}' --internal-hlo2tensorizer-options='--enable-native-kernel=1 --remat'" + print(f" compile_args: {compile_args}") + # QIE_WLO: weight layout optimization. Passing priority_model_key triggers + # ModelBuilder._compile_priority_model -> mark_weights_for_wlo + compile_wlo + + # compile_layout_transformer. The layout transformer is embedded in nxd_model.pt and + # applied at runtime by NxDModel.set_weights (nxd_model.py:301) — so disk weights stay + # in normal layout and QIE_SKIP_WEIGHTS=1 fast-recompile still picks up the new layout. + # bit-exact (only reshapes weight storage, no numerical change). Default on; =0 to revert. + _wlo = os.environ.get("QIE_WLO", "1") == "1" + _priority_key = "inference" if _wlo else None + if _wlo: + print(" + weight layout optimization (priority_model_key='inference')") + traced_model = builder.compile( + priority_model_key=_priority_key, + compiler_args=compile_args, + compiler_workdir=args.compiler_workdir, + ) + + # Save + output_path = f"{args.compiled_models_dir}/transformer_v3_cp" + os.makedirs(output_path, exist_ok=True) + + print(f"\nSaving to {output_path}...") + traced_model.save(os.path.join(output_path, "nxd_model.pt")) + + # QIE_SKIP_WEIGHTS=1: only re-emit nxd_model.pt (the NEFF) and exit. + # Lets compile-flag experiments (CC tiling, opt level) reuse existing sharded weights. + if os.environ.get("QIE_SKIP_WEIGHTS", "0") == "1": + print("\nQIE_SKIP_WEIGHTS=1 — skipping weight sharding/config/rope save.") + print(f"Updated NEFF only at: {output_path}/nxd_model.pt") + return + + # Save weights + weights_path = os.path.join(output_path, "weights") + os.makedirs(weights_path, exist_ok=True) + + # Prepare checkpoint for sharding + checkpoint = {} + global_rank_state = {} # Save SPMDRank state separately (not sharded) + for key, value in model.state_dict().items(): + # Save SPMDRank module state separately - it's not sharded, same on all ranks + if 'global_rank' in key: + print(f" Saving SPMDRank key separately: {key}") + global_rank_state[key] = value.clone() + continue + # Use unsharded weights where available + orig_key = key.replace("transformer.", "", 1) + if orig_key in unsharded_state: + checkpoint[key] = unsharded_state[orig_key].clone() + else: + checkpoint[key] = value.clone() + + # Shard checkpoint + print("Sharding weights...") + shard_checkpoint( + checkpoint=checkpoint, + model=model, + serialize_path=weights_path, + ) + + # Post-process sharded checkpoints: + # 1. Remove master_weight tensors (they duplicate sharded weights, wastes ~50% space) + # 2. Add global_rank state (SPMDRank) to each checkpoint + print("\nPost-processing sharded checkpoints...") + from safetensors.torch import load_file, save_file + for rank in range(tp_degree): # Only TP checkpoints are created, CP duplicates them at load time + shard_file = os.path.join(weights_path, f"tp{rank}_sharded_checkpoint.safetensors") + if not os.path.exists(shard_file): + print(f" WARNING: {shard_file} not found") + continue + + shard_data = dict(load_file(shard_file)) + original_count = len(shard_data) + original_size = sum(v.numel() * v.element_size() for v in shard_data.values()) + + # Clone to fresh storage — load_file returns memory-mapped tensors, + # writing back to the same path while those mmaps are alive triggers + # ``SafetensorError: I/O error: Bad address (os error 14)``. + cleaned = {k: v.detach().clone().contiguous() for k, v in shard_data.items() if 'master_weight' not in k} + del shard_data + + # Add SPMDRank state (same value for all ranks) + if global_rank_state: + cleaned.update(global_rank_state) + + cleaned_size = sum(v.numel() * v.element_size() for v in cleaned.values()) + save_file(cleaned, shard_file) + print(f" tp{rank}: {original_count} -> {len(cleaned)} tensors, " + f"{original_size/1e9:.2f}GB -> {cleaned_size/1e9:.2f}GB") + + # Save config + config = { + "height": args.height, + "width": args.width, + "num_patches": num_patches, + "num_patches_padded": num_patches_padded, + "patches_padding": patches_padding, + "text_seq_len": text_seq_len, + "patch_multiplier": args.patch_multiplier, + "tp_degree": tp_degree, + "world_size": world_size, + "context_parallel": context_parallel_enabled, + "cp_degree": cp_degree, + "head_dim": head_dim, + "frame": temporal_frames, + "patch_h": patch_h, + "patch_w": patch_w, + "nki_flash_attention": True, + "batch_size": batch_size, + } + with open(os.path.join(output_path, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # Save pre-computed RoPE + torch.save({ + "img_rotary_emb": img_rotary_emb, + "txt_rotary_emb": txt_rotary_emb, + }, os.path.join(output_path, "rope_cache.pt")) + + print("\nCompilation complete!") + print(f"Model saved to: {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default=None, + help="Path to model (local dir or HuggingFace ID). If not set, uses MODEL_ID with CACHE_DIR") + parser.add_argument("--height", type=int, default=1024) + parser.add_argument("--width", type=int, default=1024) + parser.add_argument("--max_sequence_length", type=int, default=1024) + parser.add_argument("--patch_multiplier", type=int, default=3) + parser.add_argument("--tp_degree", type=int, default=4) + parser.add_argument("--world_size", type=int, default=8) + parser.add_argument("--batch_size", type=int, default=1, + help="Batch size for compiled model (default: 1)") + parser.add_argument("--compiled_models_dir", type=str, default="/opt/dlami/nvme/compiled_models") + parser.add_argument("--compiler_workdir", type=str, default="/opt/dlami/nvme/compiler_workdir") + args = parser.parse_args() + + # Override MODEL_ID and CACHE_DIR if model_path is provided + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + compile_transformer_v3_cp(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_vae.py b/contrib/models/Qwen-Image-Edit/src/compile_vae.py new file mode 100644 index 00000000..c7050361 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_vae.py @@ -0,0 +1,301 @@ +import os + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" # For trn2 +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" # For trn2 + +compiler_flags = """ --target=trn2 --lnc=2 --model-type=unet-inference --enable-fast-loading-neuron-binaries """ # --verbose=INFO +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import argparse +import torch_neuronx +from torch import nn + +from diffusers import QwenImageEditPlusPipeline +from neuron_commons import attention_wrapper, f32Wrapper + +# Import modified VAE that uses 'nearest' instead of 'nearest-exact' +# (Neuron doesn't support 'nearest-exact' interpolation mode) +from autoencoder_kl_qwenimage_neuron import AutoencoderKLQwenImage as NeuronAutoencoder + +# Override SDPA +torch.nn.functional.scaled_dot_product_attention = attention_wrapper + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +class VAEEncoderWrapper(nn.Module): + """Wrapper for VAE encoder.""" + + def __init__(self, encoder): + super().__init__() + self.encoder = encoder + + def forward(self, x): + return self.encoder(x) + + +class VAEDecoderWrapper(nn.Module): + """Wrapper for VAE decoder.""" + + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + + def forward(self, x): + return self.decoder(x) + + +def upcast_norms_to_f32(module): + """Upcast normalization layers to float32 for numerical stability.""" + for name, child in module.named_children(): + if isinstance(child, (torch.nn.GroupNorm, torch.nn.LayerNorm)): + setattr(module, name, f32Wrapper(child.to(torch.float32))) + else: + upcast_norms_to_f32(child) + + +def compile_vae(args): + """ + Compile VAE for QwenImage. + + Note: QwenImage VAE uses 3D convolutions (for video/multi-frame support). + Input shape: (batch, channels, temporal_frames, height, width) + For single image inference, temporal_frames=1. + """ + latent_height = args.height // 8 + latent_width = args.width // 8 + temporal_frames = args.temporal_frames # Number of temporal frames + latent_temporal = temporal_frames # Temporal dimension in latent space + + compiler_workdir = args.compiler_workdir + compiled_models_dir = args.compiled_models_dir + batch_size = args.batch_size + dtype = torch.bfloat16 + + load_kwargs = {"local_files_only": True, "torch_dtype": dtype} + if CACHE_DIR: + load_kwargs["cache_dir"] = CACHE_DIR + pipe = QwenImageEditPlusPipeline.from_pretrained(MODEL_ID, **load_kwargs) + + # Replace VAE with Neuron-compatible version (uses 'nearest' instead of 'nearest-exact') + print("Replacing VAE with Neuron-compatible version...") + original_vae_config = pipe.vae.config + neuron_vae = NeuronAutoencoder( + base_dim=original_vae_config.base_dim, + z_dim=original_vae_config.z_dim, + dim_mult=original_vae_config.dim_mult, + num_res_blocks=original_vae_config.num_res_blocks, + attn_scales=original_vae_config.attn_scales, + temperal_downsample=original_vae_config.temperal_downsample, + dropout=original_vae_config.dropout, + input_channels=getattr(original_vae_config, "input_channels", 3), + latents_mean=original_vae_config.latents_mean, + latents_std=original_vae_config.latents_std, + ) + # Load weights from original VAE + neuron_vae.load_state_dict(pipe.vae.state_dict()) + neuron_vae = neuron_vae.to(dtype) + pipe.vae = neuron_vae + + z_dim = pipe.vae.config.z_dim # 16 for QwenImage VAE + + # Compile VAE Encoder + print("Compiling VAE encoder...") + print( + f" Input shape: ({batch_size}, 3, {temporal_frames}, {args.height}, {args.width})" + ) + encoder = pipe.vae.encoder + encoder.eval() + upcast_norms_to_f32(encoder) + + with torch.no_grad(): + # Encoder input: (batch, channels, temporal_frames, height, width) - 5D for Conv3d + encoder_input = torch.rand( + (batch_size, 3, temporal_frames, args.height, args.width), dtype=dtype + ) + compiled_encoder = torch_neuronx.trace( + encoder, + encoder_input, + compiler_workdir=f"{compiler_workdir}/vae_encoder", + compiler_args=compiler_flags, + inline_weights_to_neff=False, + ) + + encoder_dir = f"{compiled_models_dir}/vae_encoder" + if not os.path.exists(encoder_dir): + os.makedirs(encoder_dir) + torch.jit.save(compiled_encoder, f"{encoder_dir}/model.pt") + print(f"VAE encoder compiled and saved to {encoder_dir}") + + # Compile VAE Decoder + # NOTE: At LNC=2 (trn2.3xlarge default), NEURON_CUSTOM_SILU=1 and + # NEURON_FUSE_SOFTMAX=1 cause an internal compiler error (NCC_IBIR182) + # for the VAE decoder. The encoder compiles fine with these flags. + # We disable them for decoder compilation and restore afterward. + saved_silu = os.environ.get("NEURON_CUSTOM_SILU") + saved_softmax = os.environ.get("NEURON_FUSE_SOFTMAX") + os.environ["NEURON_CUSTOM_SILU"] = "0" + os.environ["NEURON_FUSE_SOFTMAX"] = "0" + + print("Compiling VAE decoder...") + print( + f" Input shape: ({batch_size}, {z_dim}, {latent_temporal}, {latent_height}, {latent_width})" + ) + print( + f" NOTE: NEURON_CUSTOM_SILU and NEURON_FUSE_SOFTMAX disabled for decoder (LNC=2 compatibility)" + ) + decoder = pipe.vae.decoder + decoder.eval() + upcast_norms_to_f32(decoder) + + with torch.no_grad(): + # Decoder input: (batch, z_dim, temporal_frames, latent_height, latent_width) - 5D + decoder_input = torch.rand( + (batch_size, z_dim, latent_temporal, latent_height, latent_width), + dtype=dtype, + ) + compiled_decoder = torch_neuronx.trace( + decoder, + decoder_input, + compiler_workdir=f"{compiler_workdir}/vae_decoder", + compiler_args=compiler_flags, + inline_weights_to_neff=False, + ) + + decoder_dir = f"{compiled_models_dir}/vae_decoder" + if not os.path.exists(decoder_dir): + os.makedirs(decoder_dir) + torch.jit.save(compiled_decoder, f"{decoder_dir}/model.pt") + print(f"VAE decoder compiled and saved to {decoder_dir}") + + # Restore NEURON_CUSTOM_SILU and NEURON_FUSE_SOFTMAX after decoder compilation + if saved_silu is not None: + os.environ["NEURON_CUSTOM_SILU"] = saved_silu + if saved_softmax is not None: + os.environ["NEURON_FUSE_SOFTMAX"] = saved_softmax + + # Compile quant_conv and post_quant_conv if they exist + if hasattr(pipe.vae, "quant_conv") and pipe.vae.quant_conv is not None: + print("Compiling quant_conv...") + with torch.no_grad(): + quant_input = torch.rand( + (batch_size, z_dim * 2, latent_temporal, latent_height, latent_width), + dtype=dtype, + ) + compiled_quant = torch_neuronx.trace( + pipe.vae.quant_conv, + quant_input, + compiler_workdir=f"{compiler_workdir}/quant_conv", + compiler_args=compiler_flags, + inline_weights_to_neff=False, + ) + quant_dir = f"{compiled_models_dir}/quant_conv" + if not os.path.exists(quant_dir): + os.makedirs(quant_dir) + torch.jit.save(compiled_quant, f"{quant_dir}/model.pt") + print(f"quant_conv compiled and saved to {quant_dir}") + + if hasattr(pipe.vae, "post_quant_conv") and pipe.vae.post_quant_conv is not None: + print("Compiling post_quant_conv...") + with torch.no_grad(): + post_quant_input = torch.rand( + (batch_size, z_dim, latent_temporal, latent_height, latent_width), + dtype=dtype, + ) + compiled_post_quant = torch_neuronx.trace( + pipe.vae.post_quant_conv, + post_quant_input, + compiler_workdir=f"{compiler_workdir}/post_quant_conv", + compiler_args=compiler_flags, + inline_weights_to_neff=False, + ) + post_quant_dir = f"{compiled_models_dir}/post_quant_conv" + if not os.path.exists(post_quant_dir): + os.makedirs(post_quant_dir) + torch.jit.save(compiled_post_quant, f"{post_quant_dir}/model.pt") + print(f"post_quant_conv compiled and saved to {post_quant_dir}") + + # Save VAE config + import json + + vae_config = { + "height": args.height, + "width": args.width, + "temporal_frames": temporal_frames, + "batch_size": batch_size, + "z_dim": z_dim, + "latent_height": latent_height, + "latent_width": latent_width, + } + config_path = f"{compiled_models_dir}/vae_config.json" + with open(config_path, "w") as f: + json.dump(vae_config, f, indent=2) + print(f"VAE config saved to {config_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model (local dir or HuggingFace ID). If not set, uses MODEL_ID with CACHE_DIR", + ) + parser.add_argument( + "--height", + type=int, + default=512, + help="Height of generated image (compile tile size)", + ) + parser.add_argument( + "--width", + type=int, + default=512, + help="Width of generated image (compile tile size)", + ) + parser.add_argument( + "--temporal_frames", + type=int, + default=1, + help="Number of temporal frames (1 for single image)", + ) + parser.add_argument( + "--batch_size", type=int, default=1, help="Batch size for VAE (default: 1)" + ) + parser.add_argument( + "--compiler_workdir", + type=str, + default="compiler_workdir", + help="Directory for compiler artifacts", + ) + parser.add_argument( + "--compiled_models_dir", + type=str, + default="compiled_models", + help="Directory for compiled models", + ) + args = parser.parse_args() + + # Override MODEL_ID and CACHE_DIR if model_path is provided + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + print("=" * 60) + print("VAE Compilation for Neuron") + print("=" * 60) + print(f"Compile tile size: {args.height}x{args.width}") + print(f"Batch size: {args.batch_size}") + print("") + print("NOTE: For inference at larger resolutions (e.g., 1024x1024),") + print(" tiled VAE processing will be used automatically.") + print(" The VAE is compiled at this tile size for memory efficiency.") + print(" With batch_size > 1, multiple tiles can be processed in parallel.") + print("") + + compile_vae(args) diff --git a/contrib/models/Qwen-Image-Edit/src/compile_vision_encoder_v3.py b/contrib/models/Qwen-Image-Edit/src/compile_vision_encoder_v3.py new file mode 100644 index 00000000..e9e2f284 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/compile_vision_encoder_v3.py @@ -0,0 +1,564 @@ +""" +Vision Encoder Compilation using ModelBuilder API (V3) for TP=4 Acceleration. + +This script compiles the Qwen2.5-VL Vision Encoder using ModelBuilder API with +tp_degree=4 and world_size=8 for faster inference while maintaining float32 precision. + +Key features: +- Uses ModelBuilder API (NxDModel) for compilation +- Configuration: tp_degree=4, world_size=8 (matching V3 CP transformer) +- Float32 precision for accuracy (required for vision encoder) +- Vision encoder hidden_size=1280, QKV=3840, MLP intermediate=3420 +- TP=4 works: 3840/4=960, 3420/4=855 (both divisible) + +Usage: + python compile_vision_encoder_v3.py --image_size 448 +""" + +import os +import json +import gc + +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" +os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "0" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +# Compiler flags +compiler_flags = """ --target=trn2 --lnc=2 --model-type=transformer --enable-fast-loading-neuron-binaries """ +os.environ["NEURON_CC_FLAGS"] = os.environ.get("NEURON_CC_FLAGS", "") + compiler_flags + +import torch +import torch.nn as nn +import argparse + +from diffusers import QwenImageEditPlusPipeline + +# ModelBuilder imports +from neuronx_distributed import ModelBuilder, NxDParallelState, shard_checkpoint +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed.parallel_layers import parallel_state + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +def load_pipeline(dtype=torch.float32): + """Load pipeline with appropriate kwargs.""" + load_kwargs = {"torch_dtype": dtype, "local_files_only": True} + if CACHE_DIR: + load_kwargs["cache_dir"] = CACHE_DIR + return QwenImageEditPlusPipeline.from_pretrained(MODEL_ID, **load_kwargs) + + +class f32Wrapper(nn.Module): + """Wrapper to run normalization layers in float32 for numerical stability.""" + + def __init__(self, original): + super().__init__() + self.original = original + + def forward(self, x, *args, **kwargs): + t = x.dtype + y = x.to(torch.float32) + output = self.original(y, *args, **kwargs) + return output.type(t) + + +def upcast_norms_to_f32(module): + """Upcast normalization layers to float32 for numerical stability.""" + for name, child in module.named_children(): + if isinstance(child, (torch.nn.LayerNorm,)): + setattr(module, name, f32Wrapper(child.to(torch.float32))) + elif "RMSNorm" in child.__class__.__name__: + setattr(module, name, f32Wrapper(child.to(torch.float32))) + else: + upcast_norms_to_f32(child) + + +def get_sharded_data(data, dim): + """Get this rank's portion of sharded data.""" + tp_rank = parallel_state.get_tensor_model_parallel_rank() + tp_degree = parallel_state.get_tensor_model_parallel_size() + + total_size = data.shape[dim] + shard_size = total_size // tp_degree + + start = tp_rank * shard_size + end = start + shard_size + + if dim == 0: + return data[start:end].clone() + elif dim == 1: + return data[:, start:end].clone() + else: + raise ValueError(f"Unsupported shard dimension: {dim}") + + +def shard_vision_attention_fp32(tp_degree: int, attn): + """ + Shard Qwen2.5-VL Vision Encoder attention module with float32 precision. + + Vision attention uses fused QKV projection: + - qkv: (in_features, 3 * in_features) -> splits into Q, K, V + - proj: output projection + + Qwen2.5-VL vision encoder: + - hidden_size (embed_dim) = 1280 + - num_heads = 16, head_dim = 80 + - QKV dim = 3840 = 1280 * 3 + - 3840 / 4 = 960 (divisible, TP=4 works) + + IMPORTANT: Must also update num_heads after sharding! + - With TP=4: num_heads becomes 16/4 = 4 per rank + """ + orig_qkv = attn.qkv + orig_proj = attn.proj + + # Update num_heads for this rank (critical for correct attention computation) + original_num_heads = attn.num_heads + attn.num_heads = original_num_heads // tp_degree + + # Shard fused QKV projection + attn.qkv = ColumnParallelLinear( + orig_qkv.in_features, + orig_qkv.out_features, + bias=(orig_qkv.bias is not None), + gather_output=False, + dtype=torch.float32, + ) + attn.qkv.weight.data = get_sharded_data(orig_qkv.weight.data, 0) + if orig_qkv.bias is not None: + attn.qkv.bias.data = get_sharded_data(orig_qkv.bias.data, 0) + del orig_qkv + + # Shard output projection + attn.proj = RowParallelLinear( + orig_proj.in_features, + orig_proj.out_features, + bias=(orig_proj.bias is not None), + input_is_parallel=True, + dtype=torch.float32, + ) + attn.proj.weight.data = get_sharded_data(orig_proj.weight.data, 1) + if orig_proj.bias is not None: + attn.proj.bias.data = orig_proj.bias.data.detach() + del orig_proj + + return attn + + +def shard_vision_mlp_fp32(mlp): + """ + Shard Qwen2.5-VL Vision Encoder MLP module with float32 precision. + + Vision MLP uses SwiGLU-style architecture: + - gate_proj: (hidden_size, intermediate_size) + - up_proj: (hidden_size, intermediate_size) + - down_proj: (intermediate_size, hidden_size) + + Qwen2.5-VL vision encoder: + - hidden_size = 1280 + - intermediate_size = 3420 + - 3420 / 4 = 855 (divisible) + """ + orig_gate = mlp.gate_proj + orig_up = mlp.up_proj + orig_down = mlp.down_proj + + # Shard gate projection + mlp.gate_proj = ColumnParallelLinear( + orig_gate.in_features, + orig_gate.out_features, + bias=(orig_gate.bias is not None), + gather_output=False, + dtype=torch.float32, + ) + mlp.gate_proj.weight.data = get_sharded_data(orig_gate.weight.data, 0) + if orig_gate.bias is not None: + mlp.gate_proj.bias.data = get_sharded_data(orig_gate.bias.data, 0) + del orig_gate + + # Shard up projection + mlp.up_proj = ColumnParallelLinear( + orig_up.in_features, + orig_up.out_features, + bias=(orig_up.bias is not None), + gather_output=False, + dtype=torch.float32, + ) + mlp.up_proj.weight.data = get_sharded_data(orig_up.weight.data, 0) + if orig_up.bias is not None: + mlp.up_proj.bias.data = get_sharded_data(orig_up.bias.data, 0) + del orig_up + + # Shard down projection + mlp.down_proj = RowParallelLinear( + orig_down.in_features, + orig_down.out_features, + bias=(orig_down.bias is not None), + input_is_parallel=True, + dtype=torch.float32, + ) + mlp.down_proj.weight.data = get_sharded_data(orig_down.weight.data, 1) + if orig_down.bias is not None: + mlp.down_proj.bias.data = orig_down.bias.data.detach() + del orig_down + + return mlp + + +class NeuronVisionEncoderV3(nn.Module): + """ + Neuron-optimized Qwen2.5-VL Vision Encoder with TP=4, float32 precision. + + Uses ModelBuilder API with tp_degree=4, world_size=8. + + Key features: + - TP=4 for parallel computation (3420 QKV dim / 4 = 855, divisible) + - Float32 precision for accuracy (required for vision encoder) + - World_size=8 for compatibility with V3 CP transformer + """ + + def __init__(self, original_visual, tp_degree): + super().__init__() + + self.tp_degree = tp_degree + + # Keep the full visual encoder (we'll modify its layers in-place) + self.visual = original_visual + + # Get model structure info from config + self.embed_dim = original_visual.config.hidden_size # 1280 + self.num_heads = original_visual.config.num_heads # 16 + + print(f" Vision encoder config:") + print(f" embed_dim (hidden_size): {self.embed_dim}") + print(f" num_heads: {self.num_heads}") + print(f" QKV dim: {self.embed_dim * 3} = {self.embed_dim} * 3") + print(f" QKV per rank: {self.embed_dim * 3 // tp_degree}") + + # Shard the transformer blocks + for i, block in enumerate(self.visual.blocks): + if hasattr(block, "attn"): + block.attn = shard_vision_attention_fp32(tp_degree, block.attn) + if hasattr(block, "mlp"): + block.mlp = shard_vision_mlp_fp32(block.mlp) + if i == 0: + print(f" Sharded block 0 attention and MLP") + + print(f" Sharded all {len(self.visual.blocks)} blocks") + + # Upcast norms to float32 (already float32, but ensure wrapper) + upcast_norms_to_f32(self.visual) + + def forward(self, pixel_values, grid_thw): + """ + Forward pass for vision encoder. + + Args: + pixel_values: (num_patches, channels_per_patch) - flattened image patches + grid_thw: (num_images, 3) - temporal, height, width grid dimensions + + Returns: + image_embeds: (num_output_tokens, hidden_size) - vision embeddings after merger + """ + return self.visual(pixel_values, grid_thw) + + +class TracingWrapper(nn.Module): + """Wrapper for ModelBuilder tracing.""" + + def __init__(self, vision_encoder): + super().__init__() + self.vision_encoder = vision_encoder + + def forward(self, pixel_values, grid_thw): + return self.vision_encoder(pixel_values, grid_thw) + + +def compile_vision_encoder_v3(args): + """ + Compile Vision Encoder using ModelBuilder API. + + Configuration: + - tp_degree=4: Works with vision encoder dimensions (3420 / 4 = 855) + - world_size=8: Matches V3 CP transformer + - dtype=float32: Required for accuracy + """ + tp_degree = 4 # Fixed: vision encoder dimensions require TP=4 + world_size = int(os.environ.get("VISION_WORLD_SIZE", 8)) # 4 for TP=4 CP=1, 8 for TP=4 CP=2 + + image_h = args.image_h if args.image_h else args.image_size + image_w = args.image_w if args.image_w else args.image_size + patch_size = 14 + temporal_patch_size = 2 + spatial_merge_size = 2 + + for name, dim in (("image_h", image_h), ("image_w", image_w)): + if dim % patch_size != 0: + raise ValueError( + f"{name} ({dim}) must be divisible by patch_size ({patch_size}). " + f"Valid sizes: 224, 336, 448, 560, etc." + ) + if (dim // patch_size) % spatial_merge_size != 0: + raise ValueError( + f"{name} / patch_size ({dim // patch_size}) must be divisible by " + f"spatial_merge_size ({spatial_merge_size})." + ) + + num_patches_h = image_h // patch_size + num_patches_w = image_w // patch_size + num_patches = num_patches_h * num_patches_w + + # pixel_values shape: (num_patches, channels_per_patch) + channels_per_patch = 3 * temporal_patch_size * patch_size * patch_size # 1176 + + print("=" * 60) + print("Compiling Vision Encoder V3 (ModelBuilder API, TP=4, float32)") + print("=" * 60) + print(f" Image size (HxW): {image_h}x{image_w}") + print(f" Patch size: {patch_size}") + print(f" Num patches (HxW): {num_patches_h}x{num_patches_w} = {num_patches}") + print(f" Channels per patch: {channels_per_patch}") + print(f" TP degree: {tp_degree}") + print(f" World size: {world_size}") + print(f" Dtype: float32 (required for accuracy)") + print("") + + # Sample inputs + sample_pixel_values = torch.randn( + num_patches, channels_per_patch, dtype=torch.float32 + ) + sample_grid_thw = torch.tensor( + [[1, num_patches_h, num_patches_w]], dtype=torch.int64 + ) + + print(f"Sample input shapes:") + print(f" pixel_values: {sample_pixel_values.shape}") + print(f" grid_thw: {sample_grid_thw.shape}") + print("") + + # Use NxDParallelState context for compilation + with NxDParallelState(world_size=world_size, tensor_model_parallel_size=tp_degree): + # On trn2.3xlarge (or instances with <96GB RAM), loading the full pipeline + # in fp32 (~95 GB) will OOM. Load in bf16 to save memory, then extract + # the vision encoder and explicitly convert its weights to fp32. + # On trn2.48xlarge, fp32 loading works fine. + load_dtype = torch.bfloat16 if args.load_bf16 else torch.float32 + print(f"Loading model in {load_dtype}...") + pipe = load_pipeline(load_dtype) + + # Extract vision encoder + original_visual = pipe.text_encoder.model.visual + + # Save unsharded state dict before modifications. + # CRITICAL: If pipeline was loaded in bf16, the state dict will be bf16. + # Vision encoder requires fp32 for accuracy, so we must explicitly cast. + print("Saving unsharded state dict...") + unsharded_state = { + k: v.to(torch.float32) for k, v in original_visual.state_dict().items() + } + + # Convert vision encoder to fp32 before sharding + if load_dtype != torch.float32: + original_visual = original_visual.to(torch.float32) + + # Create Neuron vision encoder with sharding + print( + f"\nCreating Neuron vision encoder (sharding layers with TP={tp_degree})..." + ) + neuron_vision_encoder = NeuronVisionEncoderV3(original_visual, tp_degree) + neuron_vision_encoder = neuron_vision_encoder.to(torch.float32) + neuron_vision_encoder.eval() + + # Clear pipeline to save memory (important on trn2.3xlarge) + del pipe + gc.collect() + + # Wrap for tracing + model = TracingWrapper(neuron_vision_encoder) + + print("\nInitializing ModelBuilder...") + builder = ModelBuilder(model=model) + + print("Tracing model...") + builder.trace( + kwargs={ + "pixel_values": sample_pixel_values, + "grid_thw": sample_grid_thw, + }, + tag="inference", + ) + + print("Compiling model...") + # Use --auto-cast=none to preserve float32 precision + # NOTE: Using -O1 instead of -O2 because -O2 can cause numerical issues in some cases + compile_args = "--model-type=transformer -O1 --auto-cast=none" + traced_model = builder.compile( + compiler_args=compile_args, + compiler_workdir=args.compiler_workdir, + ) + + # Save + output_path = f"{args.compiled_models_dir}/vision_encoder_v3" + os.makedirs(output_path, exist_ok=True) + + print(f"\nSaving to {output_path}...") + traced_model.save(os.path.join(output_path, "nxd_model.pt")) + + # Save weights + weights_path = os.path.join(output_path, "weights") + os.makedirs(weights_path, exist_ok=True) + + # Prepare checkpoint for sharding + print("Preparing checkpoint...") + checkpoint = {} + for key, value in model.state_dict().items(): + # Use unsharded weights where available + # Key format: vision_encoder.visual.blocks.X... -> blocks.X... + orig_key = key.replace("vision_encoder.visual.", "", 1) + if orig_key in unsharded_state: + checkpoint[key] = unsharded_state[orig_key].clone() + else: + checkpoint[key] = value.clone() + + # Shard checkpoint + print("Sharding weights...") + shard_checkpoint( + checkpoint=checkpoint, + model=model, + serialize_path=weights_path, + ) + + # Post-process checkpoints: remove master_weight and add inv_freq + print("\nPost-processing checkpoints...") + from safetensors.torch import load_file, save_file + + # Collect inv_freq buffers from original model (they are not in state_dict) + inv_freq_buffers = {} + for name, buf in neuron_vision_encoder.visual.named_buffers(): + if "inv_freq" in name: + full_key = f"vision_encoder.visual.{name}" + inv_freq_buffers[full_key] = buf.to(torch.float32).clone() + print(f" Collected {len(inv_freq_buffers)} inv_freq buffers") + + for rank in range(tp_degree): + shard_file = os.path.join( + weights_path, f"tp{rank}_sharded_checkpoint.safetensors" + ) + if not os.path.exists(shard_file): + print(f" WARNING: {shard_file} not found!") + continue + + # Load checkpoint + data = dict(load_file(shard_file)) + original_count = len(data) + original_size = sum(v.numel() * v.element_size() for v in data.values()) + + # Remove master_weight tensors. Clone to detach from mmap before overwrite. + cleaned = {k: v.clone().contiguous() for k, v in data.items() if "master_weight" not in k} + + # Add inv_freq buffers + cleaned.update({k: (v.clone().contiguous() if hasattr(v, 'clone') else v) + for k, v in inv_freq_buffers.items()}) + + cleaned_size = sum(v.numel() * v.element_size() for v in cleaned.values()) + + # Save optimized checkpoint + del data + save_file(cleaned, shard_file) + print( + f" tp{rank}: {original_count} -> {len(cleaned)} tensors, " + f"{original_size / 1e9:.2f}GB -> {cleaned_size / 1e9:.2f}GB" + ) + + # Save config + config = { + "tp_degree": tp_degree, + "world_size": world_size, + "image_size": image_h if image_h == image_w else max(image_h, image_w), + "image_h": image_h, + "image_w": image_w, + "patch_size": patch_size, + "num_patches": num_patches, + "num_patches_h": num_patches_h, + "num_patches_w": num_patches_w, + "channels_per_patch": channels_per_patch, + "embed_dim": neuron_vision_encoder.embed_dim, + "num_heads": neuron_vision_encoder.num_heads, + "dtype": "float32", + } + config_path = os.path.join(output_path, "config.json") + with open(config_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"\nVision Encoder V3 compiled successfully!") + print(f" Output: {output_path}") + print(f" Config: {config_path}") + print(f" Weights: {weights_path}") + + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Compile Vision Encoder V3 using ModelBuilder API" + ) + parser.add_argument( + "--image_size", + type=int, + default=448, + help="Vision encoder input image size when square (default: 448). " + "Ignored if --image_h and --image_w are set.", + ) + parser.add_argument( + "--image_h", + type=int, + default=None, + help="Vision encoder input image height (in pixels). Overrides --image_size for height.", + ) + parser.add_argument( + "--image_w", + type=int, + default=None, + help="Vision encoder input image width (in pixels). Overrides --image_size for width.", + ) + parser.add_argument( + "--compiled_models_dir", + type=str, + default="/opt/dlami/nvme/compiled_models", + help="Output directory for compiled models", + ) + parser.add_argument( + "--compiler_workdir", + type=str, + default="/opt/dlami/nvme/compiler_workdir", + help="Compiler working directory", + ) + parser.add_argument( + "--model_path", + type=str, + default=None, + help="Path to model (local dir or HuggingFace ID)", + ) + parser.add_argument( + "--load_bf16", + action="store_true", + default=False, + help="Load pipeline in bf16 to save memory (for trn2.3xlarge). " + "Weights are automatically cast to fp32 for compilation.", + ) + + args = parser.parse_args() + + # Override MODEL_ID if model_path is provided + if args.model_path: + MODEL_ID = args.model_path + CACHE_DIR = None + + compile_vision_encoder_v3(args) diff --git a/contrib/models/Qwen-Image-Edit/src/neuron_commons.py b/contrib/models/Qwen-Image-Edit/src/neuron_commons.py new file mode 100644 index 00000000..0ded645f --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/neuron_commons.py @@ -0,0 +1,958 @@ +import torch +import math +from torch import nn +from diffusers import QwenImageEditPlusPipeline +from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration + +# Try to import NKI kernel, but don't fail if not available +try: + import neuronxcc.nki as nki + from neuronxcc.nki.language import nc + try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel + except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + _flash_fwd_call = nki.jit()(attention_isa_kernel) + NKI_AVAILABLE = True + print(f"NKI Flash Attention kernel loaded successfully") +except ImportError as e: + _flash_fwd_call = None + NKI_AVAILABLE = False + nc = None + print(f"NKI Flash Attention not available: {e}") + + +class InferenceTextEncoderWrapper(nn.Module): + """Wrapper for Qwen2.5-VL text encoder for inference on Neuron.""" + def __init__(self, dtype, text_encoder: Qwen2_5_VLForConditionalGeneration): + super().__init__() + self.dtype = dtype + self.device = text_encoder.device + self.text_encoder = text_encoder + self.config = text_encoder.config + + def forward(self, input_ids, attention_mask=None, pixel_values=None, + image_grid_thw=None, **kwargs): + outputs = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + **kwargs + ) + return outputs + + +class NeuronTextEncoderWrapper(nn.Module): + """ + Wrapper for compiled Qwen2.5-VL text encoder on Neuron. + + Combines separately compiled vision encoder and language model. + This wrapper handles the embedding combination logic that normally + happens inside the original text encoder. + + Supports three modes for Language Model: + 1. compiled_language_model: Neuron-compiled model with parallel_model_trace (TP=8) + 2. compiled_language_model_v3: Neuron-compiled model with ModelBuilder API (TP=4, world_size=8) + 3. cpu_language_model: Original model on CPU (slower but avoids GQA issues) + + IMPORTANT: This wrapper COPIES necessary components and does NOT keep + references to the original model, to avoid memory bloat. + """ + def __init__(self, original_text_encoder, compiled_vision_encoder=None, + compiled_vision_encoder_v3=None, # V3 vision encoder (TP=4, NxDModel) + compiled_language_model=None, compiled_language_model_v3=None, + cpu_language_model=None, + cpu_vision_encoder=None, # Option to use CPU vision encoder + image_size=448, max_seq_len=512, + image_h=None, image_w=None, # Portrait NEFF: vision encoder H, W (overrides image_size) + language_model_batch_size=1): # Batch size for V3 language model + super().__init__() + # Copy config (small object) + self.config = original_text_encoder.config + self.dtype = torch.bfloat16 + + # IMPORTANT: Copy embed_tokens weights instead of keeping reference! + # This allows the original model to be garbage collected. + orig_embed = original_text_encoder.model.language_model.embed_tokens + self.embed_tokens = nn.Embedding( + orig_embed.num_embeddings, + orig_embed.embedding_dim, + padding_idx=orig_embed.padding_idx, + dtype=torch.bfloat16 + ) + self.embed_tokens.weight.data = orig_embed.weight.data.clone().to(torch.bfloat16) + print(f" Copied embed_tokens: {orig_embed.num_embeddings} x {orig_embed.embedding_dim} " + f"= {orig_embed.weight.numel() * 2 / 1e9:.2f} GB") + + # Copy visual_merger if it exists (small module) + # Note: For V3 vision encoder, merger is included in the compiled model + if compiled_vision_encoder_v3 is None and hasattr(original_text_encoder.model.visual, 'merger'): + # Deep copy the merger module (only needed for non-V3 or CPU vision encoder) + import copy + self.visual_merger = copy.deepcopy(original_text_encoder.model.visual.merger) + self.visual_merger = self.visual_merger.to(torch.bfloat16) + else: + self.visual_merger = None + + # Compiled models + self.compiled_vision_encoder = compiled_vision_encoder + self.compiled_vision_encoder_v3 = compiled_vision_encoder_v3 # V3 (NxDModel, TP=4) + self.compiled_language_model = compiled_language_model + self.compiled_language_model_v3 = compiled_language_model_v3 + + # CPU Vision Encoder (for better accuracy, avoids compilation precision loss) + self.cpu_vision_encoder = cpu_vision_encoder + self.use_cpu_vision_encoder = cpu_vision_encoder is not None + + # V3 Vision Encoder (ModelBuilder API, TP=4, world_size=8, float32) + self.use_v3_vision_encoder = compiled_vision_encoder_v3 is not None + + # CPU Language Model (alternative to compiled, avoids GQA alignment issues) + self.cpu_language_model = cpu_language_model + self.use_cpu_language_model = cpu_language_model is not None + + # V3 Language Model (ModelBuilder API, TP=4, world_size=8) + self.use_v3_language_model = compiled_language_model_v3 is not None + self.language_model_batch_size = language_model_batch_size # Compiled batch size + + # DO NOT keep original_text_encoder - it's 16+ GB! + # self.original_text_encoder = original_text_encoder # REMOVED! + + # Image processing parameters + self.image_size = image_size + self.max_seq_len = max_seq_len + self.patch_size = 14 + self.spatial_merge_size = 2 + + # Portrait/landscape NEFF support: vision encoder may have been compiled + # with non-square dims (image_h != image_w). When provided, these override + # the square assumption. + self.image_h = image_h if image_h is not None else image_size + self.image_w = image_w if image_w is not None else image_size + self.grid_h = self.image_h // self.patch_size + self.grid_w = self.image_w // self.patch_size + self.num_patches_per_image = self.grid_h * self.grid_w + + # Calculate expected dimensions + self.num_image_tokens = (self.grid_h // self.spatial_merge_size) * (self.grid_w // self.spatial_merge_size) + + # Special token IDs from config + self.image_token_id = getattr(self.config, 'image_token_id', 151655) + self.vision_start_token_id = getattr(self.config, 'vision_start_token_id', 151652) + + def _get_rope_index(self, input_ids, image_grid_thw, attention_mask): + """ + Calculate 3D position_ids for M-RoPE (Multimodal RoPE). + + For multimodal input (text + images), position_ids have different patterns: + - Text tokens: sequential positions (same for t, h, w dimensions) + - Image tokens: 3D grid positions based on spatial layout + + This replicates the logic from Qwen2_5_VLModel.get_rope_index(). + + OPTIMIZED: Uses vectorized tensor operations to avoid CPU synchronization. + """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + # If no images, use simple text-only position_ids + if image_grid_thw is None: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + else: + position_ids = torch.arange(seq_len, device=device).view(1, 1, -1).expand(3, batch_size, -1) + return position_ids + + # Multimodal case: vectorized computation of 3D positions + # Get grid dimensions (avoid .tolist() by using tensor indexing) + t = image_grid_thw[0, 0] + h = image_grid_thw[0, 1] + w = image_grid_thw[0, 2] + llm_grid_h = h // self.spatial_merge_size + llm_grid_w = w // self.spatial_merge_size + grid_hw = llm_grid_h * llm_grid_w + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + # Create image token mask for all batches at once + is_image_token = (input_ids == self.image_token_id) # [batch, seq] + + # Check if any batch has image tokens (avoid .item() by checking tensor) + has_images = is_image_token.any() + + if not has_images: + # No images in any batch, use simple sequential positions + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + else: + position_ids = torch.arange(seq_len, device=device).view(1, 1, -1).expand(3, batch_size, -1) + return position_ids + + # Initialize position_ids + position_ids = torch.zeros(3, batch_size, seq_len, dtype=torch.long, device=device) + + # Process each batch (still need loop for batch, but inner ops are vectorized) + for b in range(batch_size): + valid_mask = attention_mask[b] == 1 + valid_len = valid_mask.sum() + + # Get image token mask for valid positions + batch_is_image = is_image_token[b] & valid_mask + num_image_tokens = batch_is_image.sum() + + if num_image_tokens == 0: + # No images, use sequential positions + pos = torch.arange(seq_len, device=device) + masked_pos = pos * valid_mask.long() + # Compute cumsum for valid positions only + cumsum = valid_mask.long().cumsum(-1) - 1 + cumsum = cumsum * valid_mask.long() + position_ids[:, b, :] = cumsum.unsqueeze(0).expand(3, -1) + continue + + # Vectorized computation for multimodal case + # Create index arrays for image tokens + image_indices = torch.where(batch_is_image)[0] # positions of image tokens + num_imgs = image_indices.shape[0] + + # Compute grid positions for all image tokens at once + img_local_idx = torch.arange(num_imgs, device=device) + t_pos = img_local_idx // grid_hw + remainder = img_local_idx % grid_hw + h_pos = remainder // llm_grid_w + w_pos = remainder % llm_grid_w + + # Compute text offset: count non-image tokens before each position + # First, get cumulative count of non-image tokens + is_text = valid_mask & ~batch_is_image + text_cumsum = is_text.long().cumsum(-1) + + # For image tokens, the offset is the text count before the first image token + first_image_idx = image_indices[0] if num_imgs > 0 else 0 + text_offset = text_cumsum[first_image_idx] - (1 if is_text[first_image_idx] else 0) + if first_image_idx > 0: + text_offset = text_cumsum[first_image_idx - 1] + else: + text_offset = torch.zeros(1, dtype=torch.long, device=device)[0] + + # Set image token positions + position_ids[0, b, image_indices] = text_offset + t_pos + position_ids[1, b, image_indices] = text_offset + h_pos + position_ids[2, b, image_indices] = text_offset + w_pos + + # Compute max position used by images + max_img_pos = torch.max(torch.stack([t_pos, h_pos, w_pos]).max(dim=0)[0]) + after_image_offset = text_offset + max_img_pos + 1 + + # Set text token positions + # Text before images: sequential from 0 + text_before_first_image = torch.arange(seq_len, device=device) < first_image_idx + text_before_mask = is_text & text_before_first_image + if text_before_mask.any(): + text_before_pos = text_before_mask.long().cumsum(-1) - 1 + text_before_pos = text_before_pos * text_before_mask.long() + for d in range(3): + position_ids[d, b, :] = torch.where( + text_before_mask, + text_before_pos, + position_ids[d, b, :] + ) + + # Text after images: sequential from after_image_offset + last_image_idx = image_indices[-1] if num_imgs > 0 else 0 + text_after_last_image = torch.arange(seq_len, device=device) > last_image_idx + text_after_mask = is_text & text_after_last_image + if text_after_mask.any(): + # Count text tokens after last image + text_after_local = text_after_mask.long().cumsum(-1) + # Subtract count at last_image_idx to get local index + offset_at_last = text_after_local[last_image_idx] if last_image_idx < seq_len else 0 + text_after_pos = after_image_offset + (text_after_local - offset_at_last - 1) + text_after_pos = text_after_pos * text_after_mask.long() + for d in range(3): + position_ids[d, b, :] = torch.where( + text_after_mask, + text_after_pos, + position_ids[d, b, :] + ) + + return position_ids + + def forward(self, input_ids=None, attention_mask=None, pixel_values=None, + image_grid_thw=None, output_hidden_states=True, return_dict=True, **kwargs): + """ + Forward pass combining vision encoder and language model. + + For Neuron inference, we run: + 1. Vision encoder on compiled model (or CPU fallback) + 2. Combine image embeds with text embeds + 3. Pad to max_seq_len for compiled model + 4. Language model on compiled model + 5. Remove padding from output + """ + batch_size = input_ids.shape[0] if input_ids is not None else 1 + + # Step 1: Process images through vision encoder + import os as _os2, time as _time2 + _vis_t0 = _time2.time() if _os2.environ.get("QIE_STAGE_TIMING", "0") == "1" else None + if pixel_values is not None: + # Determine dtype for vision encoder + # - CPU vision encoder: use original dtype (usually float32 from pipeline) + # - Compiled vision encoder: always float32 (required for accuracy) + if self.use_cpu_vision_encoder: + # Keep original dtype for CPU (highest precision) + pass + else: + # Use float32 for compiled vision encoder (required for accuracy) + pixel_values = pixel_values.to(torch.float32) + + # Option 1: Use CPU Vision Encoder (highest accuracy) + if self.use_cpu_vision_encoder: + with torch.no_grad(): + image_embeds = self.cpu_vision_encoder(pixel_values, image_grid_thw) + + # Option 2: Use V3 Vision Encoder (TP=4, NxDModel, float32, fast) + elif self.use_v3_vision_encoder: + # V3 vision encoder expects fixed patch count for single image + # Compiled vision encoder NEFF expects fixed (grid_h * grid_w) patches. + # For portrait NEFF (image_h != image_w), grid is non-square. + expected_patches_per_image = self.num_patches_per_image # grid_h * grid_w + actual_patches = pixel_values.shape[0] + num_images = image_grid_thw.shape[0] + + # For multi-image input, process each image separately + if num_images > 1: + all_embeds = [] + patch_idx = 0 + for img_idx in range(num_images): + # Use tensor indexing to avoid .tolist() CPU sync + t = image_grid_thw[img_idx, 0] + h = image_grid_thw[img_idx, 1] + w = image_grid_thw[img_idx, 2] + img_patches = (t * h * w).item() # Need scalar for slicing + + img_pixel_values = pixel_values[patch_idx:patch_idx + img_patches] + patch_idx += img_patches + + # Pad or truncate to expected size + if img_patches < expected_patches_per_image: + padding = torch.zeros( + expected_patches_per_image - img_patches, + img_pixel_values.shape[1], + dtype=img_pixel_values.dtype, + device=img_pixel_values.device + ) + img_pixel_values = torch.cat([img_pixel_values, padding], dim=0) + elif img_patches > expected_patches_per_image: + img_pixel_values = img_pixel_values[:expected_patches_per_image] + + # Create grid_thw for single image (matches compiled NEFF grid) + single_grid_thw = torch.tensor([[1, self.grid_h, self.grid_w]], dtype=torch.int64) + + # Run V3 vision encoder (NxDModel) + img_embeds = self.compiled_vision_encoder_v3( + pixel_values=img_pixel_values, + grid_thw=single_grid_thw + ) + + # Calculate actual output tokens (after spatial merge) + merged_h = h // self.spatial_merge_size + merged_w = w // self.spatial_merge_size + actual_output_tokens = (t * merged_h * merged_w).item() + + # Truncate to actual output size (remove padding) + img_embeds = img_embeds[:actual_output_tokens] + all_embeds.append(img_embeds) + + image_embeds = torch.cat(all_embeds, dim=0) + else: + # Single image processing + if actual_patches != expected_patches_per_image: + if actual_patches < expected_patches_per_image: + padding = torch.zeros( + expected_patches_per_image - actual_patches, + pixel_values.shape[1], + dtype=pixel_values.dtype, + device=pixel_values.device + ) + pixel_values = torch.cat([pixel_values, padding], dim=0) + else: + pixel_values = pixel_values[:expected_patches_per_image] + + image_grid_thw = torch.tensor([[1, self.grid_h, self.grid_w]], dtype=torch.int64) + + image_embeds = self.compiled_vision_encoder_v3( + pixel_values=pixel_values, + grid_thw=image_grid_thw + ) + + # Convert output to bfloat16 for downstream processing + image_embeds = image_embeds.to(torch.bfloat16) + + # Option 3: Use single-device compiled Vision Encoder (slower) + elif self.compiled_vision_encoder is not None: + # Compiled vision encoder expects fixed patch count for single image + # Compiled vision encoder NEFF expects fixed (grid_h * grid_w) patches. + # For portrait NEFF (image_h != image_w), grid is non-square. + expected_patches_per_image = self.num_patches_per_image # grid_h * grid_w + actual_patches = pixel_values.shape[0] + num_images = image_grid_thw.shape[0] + + # For multi-image input, process each image separately + if num_images > 1: + # Process each image through compiled vision encoder + all_embeds = [] + patch_idx = 0 + for img_idx in range(num_images): + # Use tensor indexing to avoid .tolist() CPU sync + t = image_grid_thw[img_idx, 0] + h = image_grid_thw[img_idx, 1] + w = image_grid_thw[img_idx, 2] + img_patches = (t * h * w).item() # Need scalar for slicing + + # Extract patches for this image + img_pixel_values = pixel_values[patch_idx:patch_idx + img_patches] + patch_idx += img_patches + + # Pad or truncate to expected size + if img_patches < expected_patches_per_image: + padding = torch.zeros( + expected_patches_per_image - img_patches, + img_pixel_values.shape[1], + dtype=img_pixel_values.dtype, + device=img_pixel_values.device + ) + img_pixel_values = torch.cat([img_pixel_values, padding], dim=0) + elif img_patches > expected_patches_per_image: + img_pixel_values = img_pixel_values[:expected_patches_per_image] + + # Create grid_thw for single image (matches compiled NEFF grid) + single_grid_thw = torch.tensor([[1, self.grid_h, self.grid_w]], dtype=torch.int64) + + # Run vision encoder for this image + img_embeds = self.compiled_vision_encoder(img_pixel_values, single_grid_thw) + + # Calculate actual output tokens (after spatial merge) + merged_h = h // self.spatial_merge_size + merged_w = w // self.spatial_merge_size + actual_output_tokens = (t * merged_h * merged_w).item() + + # Truncate to actual output size (remove padding) + img_embeds = img_embeds[:actual_output_tokens] + all_embeds.append(img_embeds) + + # Concatenate all image embeddings + image_embeds = torch.cat(all_embeds, dim=0) + else: + # Single image processing + if actual_patches != expected_patches_per_image: + if actual_patches < expected_patches_per_image: + padding = torch.zeros( + expected_patches_per_image - actual_patches, + pixel_values.shape[1], + dtype=pixel_values.dtype, + device=pixel_values.device + ) + pixel_values = torch.cat([pixel_values, padding], dim=0) + else: + pixel_values = pixel_values[:expected_patches_per_image] + + image_grid_thw = torch.tensor([[1, self.grid_h, self.grid_w]], dtype=torch.int64) + + image_embeds = self.compiled_vision_encoder(pixel_values, image_grid_thw) + + # Convert output to bfloat16 for downstream processing + image_embeds = image_embeds.to(torch.bfloat16) + # Note: merger is already included in compiled_vision_encoder + else: + # No vision encoder available + raise RuntimeError( + "No vision encoder available! Please either:\n" + " 1. Compile: python neuron_qwen_image_edit/compile_text_encoder.py --vision_only\n" + " 2. Use --cpu_vision_encoder flag" + ) + else: + image_embeds = None + + if _vis_t0 is not None: + print(f" [vision_v3] {(_time2.time()-_vis_t0)*1000:.1f} ms", flush=True) + + # Step 2: Get text embeddings + text_embeds = self.embed_tokens(input_ids) + + # Step 3: Combine embeddings + # Find image token positions and replace with image embeddings + if image_embeds is not None: + # The image token ID in Qwen2.5-VL + image_token_id = self.config.image_token_id if hasattr(self.config, 'image_token_id') else 151655 + + # Create combined embeddings + inputs_embeds = self._merge_embeddings( + text_embeds, image_embeds, input_ids, image_token_id + ) + else: + inputs_embeds = text_embeds + + # Step 4: Calculate 3D position_ids for M-RoPE (required by Qwen2.5-VL) + # For multimodal input (text + images), position_ids have special patterns: + # - Text tokens: sequential positions (same for t, h, w dimensions) + # - Image tokens: 3D grid positions based on spatial layout + position_ids = self._get_rope_index(input_ids, image_grid_thw, attention_mask) + + # Step 5: Run language model (CPU, V3, or compiled) + if self.use_cpu_language_model: + # CPU Language Model mode - no padding needed, handles dynamic sequence lengths + # This avoids GQA alignment issues that occur with TP != 4 + with torch.no_grad(): + cpu_outputs = self.cpu_language_model( + inputs_embeds=inputs_embeds.to(torch.bfloat16), + attention_mask=attention_mask, + position_ids=position_ids, # Pass 3D position_ids for M-RoPE + output_hidden_states=True, + return_dict=True + ) + hidden_states = cpu_outputs.last_hidden_state + + # Create output similar to original + if return_dict: + return type('TextEncoderOutput', (), { + 'hidden_states': (hidden_states,), + 'last_hidden_state': hidden_states + })() + return hidden_states + + elif self.use_v3_language_model: + # V3 Language Model mode (ModelBuilder API, TP=4, world_size=8) + # Compatible with V3 CP transformer + original_seq_len = inputs_embeds.shape[1] + hidden_size = inputs_embeds.shape[2] + + if original_seq_len < self.max_seq_len: + # Pad inputs_embeds with zeros + pad_len = self.max_seq_len - original_seq_len + embed_padding = torch.zeros( + batch_size, pad_len, hidden_size, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device + ) + inputs_embeds = torch.cat([inputs_embeds, embed_padding], dim=1) + + # Pad attention_mask with zeros (masked positions) + if attention_mask is not None: + mask_padding = torch.zeros( + batch_size, pad_len, + dtype=attention_mask.dtype, + device=attention_mask.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + # Pad position_ids with sequential positions + if position_ids is not None: + # position_ids shape: (3, batch, seq_len) + last_pos = position_ids[:, :, -1:] + 1 + pad_positions = last_pos + torch.arange(pad_len, device=position_ids.device).view(1, 1, -1) + position_ids = torch.cat([position_ids, pad_positions], dim=2) + elif original_seq_len > self.max_seq_len: + # Truncate if too long + print(f" WARNING: Sequence length {original_seq_len} > max_seq_len {self.max_seq_len}, truncating") + inputs_embeds = inputs_embeds[:, :self.max_seq_len, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :self.max_seq_len] + if position_ids is not None: + position_ids = position_ids[:, :, :self.max_seq_len] + original_seq_len = self.max_seq_len + + # Handle batch padding if needed + actual_batch_size = inputs_embeds.shape[0] + if actual_batch_size < self.language_model_batch_size: + pad_batch = self.language_model_batch_size - actual_batch_size + # Pad inputs_embeds + inputs_embeds = torch.cat([ + inputs_embeds, + torch.zeros((pad_batch, inputs_embeds.shape[1], inputs_embeds.shape[2]), + dtype=inputs_embeds.dtype, device=inputs_embeds.device) + ], dim=0) + # Pad attention_mask + if attention_mask is not None: + attention_mask = torch.cat([ + attention_mask, + torch.zeros((pad_batch, attention_mask.shape[1]), + dtype=attention_mask.dtype, device=attention_mask.device) + ], dim=0) + # Pad position_ids (shape: 3, batch, seq_len) + if position_ids is not None: + position_ids = torch.cat([ + position_ids, + position_ids[:, :1, :].repeat(1, pad_batch, 1) # Repeat first sample's positions + ], dim=1) + + # Run V3 compiled language model (NxDModel) + # V3 model expects: inputs_embeds, attention_mask, position_ids + import os as _os, time as _time + _lm_t0 = _time.time() if _os.environ.get("QIE_STAGE_TIMING", "0") == "1" else None + hidden_states = self.compiled_language_model_v3( + inputs_embeds.to(torch.bfloat16), + attention_mask, + position_ids + ) + if _lm_t0 is not None: + print(f" [LM_v3] {(_time.time()-_lm_t0)*1000:.1f} ms", flush=True) + + # Remove batch padding from output + if actual_batch_size < self.language_model_batch_size: + hidden_states = hidden_states[:actual_batch_size] + + # Remove sequence padding from output + hidden_states = hidden_states[:, :original_seq_len, :] + + # Create output similar to original + if return_dict: + return type('TextEncoderOutput', (), { + 'hidden_states': (hidden_states,), + 'last_hidden_state': hidden_states + })() + return hidden_states + + elif self.compiled_language_model is not None: + # Neuron compiled Language Model mode - requires fixed sequence length + original_seq_len = inputs_embeds.shape[1] + hidden_size = inputs_embeds.shape[2] + + if original_seq_len < self.max_seq_len: + # Pad inputs_embeds with zeros + pad_len = self.max_seq_len - original_seq_len + embed_padding = torch.zeros( + batch_size, pad_len, hidden_size, + dtype=inputs_embeds.dtype, + device=inputs_embeds.device + ) + inputs_embeds = torch.cat([inputs_embeds, embed_padding], dim=1) + + # Pad attention_mask with zeros (masked positions) + if attention_mask is not None: + mask_padding = torch.zeros( + batch_size, pad_len, + dtype=attention_mask.dtype, + device=attention_mask.device + ) + attention_mask = torch.cat([attention_mask, mask_padding], dim=1) + + # Pad position_ids with sequential positions + if position_ids is not None: + # position_ids shape: (3, batch, seq_len) + # Pad with sequential positions continuing from the last position + last_pos = position_ids[:, :, -1:] + 1 # (3, batch, 1) + pad_positions = last_pos + torch.arange(pad_len, device=position_ids.device).view(1, 1, -1) + position_ids = torch.cat([position_ids, pad_positions], dim=2) + elif original_seq_len > self.max_seq_len: + # Truncate if too long + print(f" WARNING: Sequence length {original_seq_len} > max_seq_len {self.max_seq_len}, truncating") + inputs_embeds = inputs_embeds[:, :self.max_seq_len, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :self.max_seq_len] + if position_ids is not None: + position_ids = position_ids[:, :, :self.max_seq_len] + original_seq_len = self.max_seq_len + + # Run compiled language model with position_ids for M-RoPE + hidden_states = self.compiled_language_model(inputs_embeds, attention_mask, position_ids) + + # Remove padding from output (restore original sequence length) + hidden_states = hidden_states[:, :original_seq_len, :] + + # Create output similar to original + if return_dict: + return type('TextEncoderOutput', (), { + 'hidden_states': (hidden_states,), + 'last_hidden_state': hidden_states + })() + return hidden_states + + else: + # No language model available + raise RuntimeError( + "No language model available! Please either:\n" + "1. Compile V3 language model: python neuron_qwen_image_edit/compile_language_model_v3.py\n" + "2. Compile V1 language model: python neuron_qwen_image_edit/compile_text_encoder.py --language_only\n" + "3. Use CPU language model by passing cpu_language_model to NeuronTextEncoderWrapper" + ) + + def _merge_embeddings(self, text_embeds, image_embeds, input_ids, image_token_id): + """ + Merge text and image embeddings at image token positions. + + OPTIMIZED: Uses index-based replacement to minimize CPU synchronization. + """ + batch_size, seq_len, hidden_size = text_embeds.shape + + if image_embeds is None: + return text_embeds + + # Find positions of image tokens + image_mask = (input_ids == image_token_id) # [batch, seq] + + # Clone to avoid modifying original + inputs_embeds = text_embeds.clone() + + # For batch_size=1, use optimized path with nonzero + if batch_size == 1: + # Get indices of image tokens (returns [N, 2] for 2D input, we need column 1) + image_indices = image_mask[0].nonzero(as_tuple=True)[0] # [num_image_tokens] + num_image_positions = image_indices.shape[0] + + if num_image_positions > 0: + # Handle case where image_embeds has fewer tokens than positions + num_to_use = min(num_image_positions, image_embeds.shape[0]) + + # Use index_copy_ for efficient in-place replacement + inputs_embeds[0, image_indices[:num_to_use]] = image_embeds[:num_to_use] + + return inputs_embeds + + # For batch_size > 1, process each batch + for b in range(batch_size): + image_indices = image_mask[b].nonzero(as_tuple=True)[0] + num_image_positions = image_indices.shape[0] + + if num_image_positions > 0: + num_to_use = min(num_image_positions, image_embeds.shape[0]) + inputs_embeds[b, image_indices[:num_to_use]] = image_embeds[:num_to_use] + + return inputs_embeds + + +class InferenceTransformerWrapper(nn.Module): + """Wrapper for QwenImageTransformer2DModel for inference on Neuron.""" + def __init__(self, transformer: QwenImageTransformer2DModel): + super().__init__() + self.transformer = transformer + self.config = transformer.config + self.dtype = transformer.dtype + self.device = transformer.device + + def forward(self, hidden_states, encoder_hidden_states=None, + timestep=None, encoder_attention_mask=None, + pooled_projections=None, return_dict=False, **kwargs): + output = self.transformer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + encoder_attention_mask=encoder_attention_mask, + pooled_projections=pooled_projections, + return_dict=return_dict, + ) + return output + + +class SimpleWrapper(nn.Module): + """Simple wrapper for VAE decoder and other modules.""" + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + return self.model(x) + + +class f32Wrapper(nn.Module): + """Wrapper to run normalization layers in float32 for numerical stability.""" + def __init__(self, original): + super().__init__() + self.original = original + + def forward(self, x): + t = x.dtype + y = x.to(torch.float32) + output = self.original(y) + return output.type(t) + + +def neuron_scaled_dot_product_attention(query, key, value, attn_mask=None, + dropout_p=None, is_causal=None, scale=None, + enable_gqa=False, **kwargs): + """Custom scaled dot product attention optimized for Neuron. + + Supports: + - Grouped Query Attention (GQA) where num_kv_heads < num_q_heads + - Causal masking when is_causal=True + - Explicit attention masks (attn_mask) + """ + orig_shape = None + orig_query_shape = query.shape + q_len = query.shape[-2] + kv_len = key.shape[-2] + + if len(query.shape) == 4: + orig_shape = query.shape + batch_size, num_q_heads, seq_len, head_dim = query.shape + _, num_kv_heads, _, _ = key.shape + + # Handle GQA: repeat K/V heads to match Q heads + if num_kv_heads != num_q_heads: + num_groups = num_q_heads // num_kv_heads + # Repeat K and V along head dimension + key = key.repeat_interleave(num_groups, dim=1) + value = value.repeat_interleave(num_groups, dim=1) + + def to3d(x): + return x.reshape(-1, x.shape[2], x.shape[3]) + query, key, value = map(to3d, [query, key, value]) + + # Use provided scale or default to 1/sqrt(d_k) + if scale is None: + scale = 1 / math.sqrt(query.size(-1)) + + # Compute attention scores: [batch*heads, q_len, kv_len] + attention_scores = torch.bmm(query, key.transpose(-1, -2)) * scale + + # Apply causal mask if requested + if is_causal: + # Create causal mask: positions above the main diagonal are masked (-inf) + # Shape: (q_len, kv_len) + # Use torch.where to avoid NaN from 0 * -inf + causal_mask = torch.triu( + torch.ones(q_len, kv_len, device=attention_scores.device), + diagonal=1 + ) + causal_mask = torch.where( + causal_mask == 1, + torch.tensor(float('-inf'), dtype=attention_scores.dtype, device=attention_scores.device), + torch.tensor(0.0, dtype=attention_scores.dtype, device=attention_scores.device) + ) + attention_scores = attention_scores + causal_mask + + # Apply explicit attention mask if provided + if attn_mask is not None: + # attn_mask can be: + # - 2D: (q_len, kv_len) - applied to all batches/heads + # - 3D: (batch*heads, q_len, kv_len) - per-head mask + # - 4D: (batch, heads, q_len, kv_len) - full mask + if attn_mask.dim() == 4: + # Reshape 4D mask to 3D + attn_mask = attn_mask.reshape(-1, attn_mask.shape[-2], attn_mask.shape[-1]) + elif attn_mask.dim() == 2: + # Broadcast 2D mask + attn_mask = attn_mask.unsqueeze(0) + + # Convert boolean mask to additive mask if needed + if attn_mask.dtype == torch.bool: + attn_mask = torch.where(attn_mask, 0.0, float('-inf')) + + attention_scores = attention_scores + attn_mask.to(attention_scores.dtype) + + attention_probs = attention_scores.softmax(dim=-1) + attn_out = torch.bmm(attention_probs, value) + + if orig_shape: + attn_out = attn_out.reshape( + orig_shape[0], orig_shape[1], attn_out.shape[1], attn_out.shape[2] + ) + return attn_out + + +def attention_wrapper_sharded_without_swap(query, key, value): + """Sharded attention wrapper using NKI kernel for trn2. + + Note: This kernel requires Q, K, V to have the same sequence length. + For cross-attention with different lengths, fall back to basic attention. + """ + import os + + bs, n_head, q_len, d_head = query.shape + _, _, kv_len, _ = key.shape + + # NKI kernel requires same sequence length for Q, K, V and NKI must be available + if q_len != kv_len or not NKI_AVAILABLE or _flash_fwd_call is None: + # Fall back to basic attention + return neuron_scaled_dot_product_attention(query, key, value) + + # Reshape for NKI kernel: expects [bs*n_head, d_head, seq_len] for Q, K + # and [bs*n_head, seq_len, d_head] for V + q = query.clone().permute(0, 1, 3, 2).reshape((bs*n_head, d_head, q_len)) + k = key.clone().permute(0, 1, 3, 2).reshape((bs*n_head, d_head, kv_len)) + v = value.clone().reshape((bs*n_head, kv_len, d_head)) + attn_output = torch.zeros((bs*n_head, q_len, d_head), dtype=torch.bfloat16, device=q.device) + + # Compute scale: 1/sqrt(d_head) + scale = 1.0 / math.sqrt(d_head) + + # Check if using virtual core size 2 (TRN2 default) + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "2")) + use_sharded_attention_kernel = (vc_size == 2) + + if use_sharded_attention_kernel: + grid = (nc(2),) + _flash_fwd_call[grid](q, k, v, scale, attn_output, + kernel_name="AttentionMMSoftmaxMMWithoutSwap") + else: + _flash_fwd_call(q, k, v, scale, attn_output, + kernel_name="AttentionMMSoftmaxMMWithoutSwap") + + attn_output = attn_output.reshape((bs, n_head, q_len, d_head)) + return attn_output + + +# Store original SDPA function +sdpa_original = torch.nn.functional.scaled_dot_product_attention + + +def attention_wrapper(query, key, value, attn_mask=None, dropout_p=None, is_causal=None, + scale=None, enable_gqa=False): + """Attention wrapper for text encoder. + + Always uses our custom implementation for better Neuron tracing compatibility. + The custom implementation supports: + - Causal masking (is_causal=True) + - Explicit attention masks (attn_mask) + - GQA (handled by repeat_kv in model's forward, but we handle leftovers) + """ + # Always use our custom implementation for Neuron compatibility + return neuron_scaled_dot_product_attention(query, key, value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale) + + +def attention_wrapper_for_transformer(query, key, value, attn_mask=None, + dropout_p=None, is_causal=None, + scale=None, enable_gqa=False): + """Attention wrapper for transformer using NKI Flash Attention kernel. + + Uses NKI kernel for optimal performance on Trainium2. + Falls back to basic attention for incompatible shapes. + """ + # Check if NKI kernel can be used: + # 1. NKI must be available + # 2. Q, K, V must have same sequence length (joint attention) + # 3. No attention mask (NKI doesn't support masks well) + # 4. Not causal attention + + bs, n_head, q_len, d_head = query.shape + _, _, kv_len, _ = key.shape + + use_nki = ( + NKI_AVAILABLE and + _flash_fwd_call is not None and + q_len == kv_len and + attn_mask is None and + not is_causal + ) + + if use_nki: + # Use NKI Flash Attention kernel + return attention_wrapper_sharded_without_swap(query, key, value) + else: + # Fall back to basic attention + return neuron_scaled_dot_product_attention(query, key, value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal) diff --git a/contrib/models/Qwen-Image-Edit/src/neuron_parallel_utils.py b/contrib/models/Qwen-Image-Edit/src/neuron_parallel_utils.py new file mode 100644 index 00000000..e2816a9a --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/neuron_parallel_utils.py @@ -0,0 +1,604 @@ +import torch +from torch import nn +from diffusers.models.attention import FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.normalization import RMSNorm +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, RowParallelLinear +from neuronx_distributed.parallel_layers.pad import get_number_of_extra_heads, pad_model +import neuronx_distributed.parallel_layers.utils as neuronx_dist_utils + + +class ShardedRMSNorm(nn.Module): + """RMSNorm that works with sharded hidden dimensions.""" + def __init__(self, dim, eps=1e-6, elementwise_affine=True): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter('weight', None) + + def forward(self, x): + # RMSNorm computation - normalize over last dimension + rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps) + x_normed = x / rms + if self.weight is not None: + return x_normed * self.weight + return x_normed + + +def get_sharded_data(data, dim): + """Shard data across tensor parallel ranks.""" + tp_rank = parallel_state.get_tensor_model_parallel_rank() + s = data.shape[dim] // parallel_state.get_tensor_model_parallel_size() + if dim == 0: + return data[s * tp_rank : s * (tp_rank + 1)].clone() + elif dim == 1: + return data[:, s * tp_rank : s * (tp_rank + 1)].clone() + + +def shard_rmsnorm(orig_norm, new_dim): + """Create a sharded RMSNorm from an original RMSNorm.""" + eps = orig_norm.eps if hasattr(orig_norm, 'eps') else 1e-6 + elementwise_affine = hasattr(orig_norm, 'weight') and orig_norm.weight is not None + + new_norm = ShardedRMSNorm(new_dim, eps=eps, elementwise_affine=elementwise_affine) + + if elementwise_affine and orig_norm.weight is not None: + new_norm.weight.data = get_sharded_data(orig_norm.weight.data, 0) + + return new_norm + + +def shard_qwen_attention(tp_degree: int, attn: Attention, reduce_dtype: torch.dtype = torch.float32): + """ + Shard QwenImage attention module for tensor parallelism. + This handles both image attention (to_q/k/v) and text attention (add_q/k/v_proj). + + ``reduce_dtype`` controls the dtype used by the row-parallel linears' + all-reduce. Default fp32 matches the upstream NxDI default; passing bf16 + halves the bytes on the wire at the cost of a slightly less precise + reduction (acceptable for attention output projections in this model). + """ + orig_inner_dim = attn.to_q.out_features + dim_head = orig_inner_dim // attn.heads + assert orig_inner_dim % attn.heads == 0 + orig_num_heads = attn.heads + total_padded_heads = attn.heads + get_number_of_extra_heads(attn.heads, tp_degree) + attn.heads = neuronx_dist_utils.divide(total_padded_heads, tp_degree) + attn.sliceable_head_dim = attn.heads + new_inner_dim = dim_head * attn.heads + attn.inner_dim = new_inner_dim + + # Shard image attention projections (to_q, to_k, to_v) + orig_q = attn.to_q + attn.to_q = ColumnParallelLinear( + attn.to_q.in_features, + attn.to_q.out_features, + bias=(attn.to_q.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + attn.to_q.weight.data = get_sharded_data(orig_q.weight.data, 0) + if attn.to_q.bias is not None: + attn.to_q.bias.data = get_sharded_data(orig_q.bias.data, 0) + del orig_q + + orig_k = attn.to_k + attn.to_k = ColumnParallelLinear( + attn.to_k.in_features, + attn.to_k.out_features, + bias=(attn.to_k.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + attn.to_k.weight.data = get_sharded_data(orig_k.weight.data, 0) + if attn.to_k.bias is not None: + attn.to_k.bias.data = get_sharded_data(orig_k.bias.data, 0) + del orig_k + + orig_v = attn.to_v + attn.to_v = ColumnParallelLinear( + attn.to_v.in_features, + attn.to_v.out_features, + bias=(attn.to_v.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + attn.to_v.weight.data = get_sharded_data(orig_v.weight.data, 0) + if attn.to_v.bias is not None: + attn.to_v.bias.data = get_sharded_data(orig_v.bias.data, 0) + del orig_v + + # Shard output projection + orig_out = attn.to_out[0] + attn.to_out[0] = RowParallelLinear( + attn.to_out[0].in_features, + attn.to_out[0].out_features, + bias=(attn.to_out[0].bias is not None), + input_is_parallel=True, + dtype=torch.bfloat16, + reduce_dtype=reduce_dtype) + attn.to_out[0].weight.data = get_sharded_data(orig_out.weight.data, 1) + if attn.to_out[0].bias is not None: + attn.to_out[0].bias.data = orig_out.bias.data.detach() + del orig_out + + # Shard text attention projections (add_q_proj, add_k_proj, add_v_proj) + if hasattr(attn, 'add_q_proj') and attn.add_q_proj is not None: + orig_add_q = attn.add_q_proj + attn.add_q_proj = ColumnParallelLinear( + orig_add_q.in_features, + orig_add_q.out_features, + bias=(orig_add_q.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + attn.add_q_proj.weight.data = get_sharded_data(orig_add_q.weight.data, 0) + if orig_add_q.bias is not None: + attn.add_q_proj.bias.data = get_sharded_data(orig_add_q.bias.data, 0) + del orig_add_q + + if hasattr(attn, 'add_k_proj') and attn.add_k_proj is not None: + orig_add_k = attn.add_k_proj + attn.add_k_proj = ColumnParallelLinear( + orig_add_k.in_features, + orig_add_k.out_features, + bias=(orig_add_k.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + attn.add_k_proj.weight.data = get_sharded_data(orig_add_k.weight.data, 0) + if orig_add_k.bias is not None: + attn.add_k_proj.bias.data = get_sharded_data(orig_add_k.bias.data, 0) + del orig_add_k + + if hasattr(attn, 'add_v_proj') and attn.add_v_proj is not None: + orig_add_v = attn.add_v_proj + attn.add_v_proj = ColumnParallelLinear( + orig_add_v.in_features, + orig_add_v.out_features, + bias=(orig_add_v.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + attn.add_v_proj.weight.data = get_sharded_data(orig_add_v.weight.data, 0) + if orig_add_v.bias is not None: + attn.add_v_proj.bias.data = get_sharded_data(orig_add_v.bias.data, 0) + del orig_add_v + + # Shard to_add_out + if hasattr(attn, 'to_add_out') and attn.to_add_out is not None: + orig_add_out = attn.to_add_out + attn.to_add_out = RowParallelLinear( + orig_add_out.in_features, + orig_add_out.out_features, + bias=(orig_add_out.bias is not None), + input_is_parallel=True, + dtype=torch.bfloat16, + reduce_dtype=reduce_dtype) + attn.to_add_out.weight.data = get_sharded_data(orig_add_out.weight.data, 1) + if orig_add_out.bias is not None: + attn.to_add_out.bias.data = orig_add_out.bias.data.detach() + del orig_add_out + + # Note: RMSNorm layers (norm_q, norm_k, norm_added_q, norm_added_k) should NOT be sharded! + # They operate on head_dim (128) which doesn't change with tensor parallelism. + # The norms are applied AFTER unflatten to [batch, seq, heads, head_dim], + # so they normalize over head_dim, not inner_dim. + + # Note: pad_model is not needed when heads are evenly divisible by tp_degree + # For QwenImage: 24 heads / 4 = 6 heads per rank (evenly divisible) + return attn + + +def shard_feedforward(ff: FeedForward, reduce_dtype: torch.dtype = torch.float32) -> FeedForward: + """Shard FeedForward module for tensor parallelism. + + See ``shard_qwen_attention`` for the meaning of ``reduce_dtype``. + """ + # Shard the first linear layer (GELU projection) + orig_proj = ff.net[0].proj + ff.net[0].proj = ColumnParallelLinear( + ff.net[0].proj.in_features, + ff.net[0].proj.out_features, + bias=(ff.net[0].proj.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + ff.net[0].proj.weight.data = get_sharded_data(orig_proj.weight.data, 0) + if ff.net[0].proj.bias is not None: + ff.net[0].proj.bias.data = get_sharded_data(orig_proj.bias.data, 0) + del orig_proj + + # Shard the output linear layer + orig_linear = ff.net[2] + ff.net[2] = RowParallelLinear( + ff.net[2].in_features, + ff.net[2].out_features, + bias=(ff.net[2].bias is not None), + input_is_parallel=True, + dtype=torch.bfloat16, + reduce_dtype=reduce_dtype) + ff.net[2].weight.data = get_sharded_data(orig_linear.weight.data, 1) + if ff.net[2].bias is not None: + ff.net[2].bias.data = orig_linear.bias.data.detach() + del orig_linear + return ff + + +def shard_modulation(mod: nn.Sequential) -> nn.Sequential: + """ + Shard modulation layer (img_mod, txt_mod) for tensor parallelism. + + Modulation layers are Sequential(SiLU, Linear) with shape [18432, 3072]. + 18432 = 6 * 3072 (for 6 modulation outputs: shift, scale for 3 different targets) + + We shard the output dimension (18432) across TP ranks. + + IMPORTANT: When gather_output=True, the output is gathered to full size BEFORE + adding the bias. So we must NOT shard the bias - it needs to be full size (18432). + """ + # mod[0] is SiLU (no weights) + # mod[1] is Linear(3072, 18432) + orig_linear = mod[1] + + mod[1] = ColumnParallelLinear( + orig_linear.in_features, + orig_linear.out_features, + bias=(orig_linear.bias is not None), + gather_output=True, # Need to gather for modulation to work correctly + dtype=torch.bfloat16) + # Shard weights across output dimension + mod[1].weight.data = get_sharded_data(orig_linear.weight.data, 0) + # IMPORTANT: Do NOT shard bias when gather_output=True! + # The bias is added after gathering, so it needs full size + if orig_linear.bias is not None: + mod[1].bias.data = orig_linear.bias.data.clone().to(torch.bfloat16) + del orig_linear + + return mod + + +def get_sharded_data_with_replication(data, dim, num_heads, tp_degree): + """ + Shard data with head replication when num_heads < tp_degree. + + For GQA models where num_kv_heads < tp_degree, we replicate KV heads + so each rank gets a copy. E.g., with 4 KV heads and TP=8: + - Heads are replicated 2x to make 8 virtual heads + - Each rank gets 1 virtual head (which is a copy of the original) + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + tp_size = parallel_state.get_tensor_model_parallel_size() + + if num_heads >= tp_size: + # Normal sharding + return get_sharded_data(data, dim) + else: + # Replication mode: num_heads < tp_size + # Each head is replicated (tp_size // num_heads) times + replication_factor = tp_size // num_heads + # Map tp_rank to the original head index + original_head_idx = tp_rank // replication_factor + + head_dim = data.shape[dim] // num_heads + if dim == 0: + start = original_head_idx * head_dim + end = (original_head_idx + 1) * head_dim + return data[start:end].clone() + elif dim == 1: + start = original_head_idx * head_dim + end = (original_head_idx + 1) * head_dim + return data[:, start:end].clone() + + +def shard_qwen2_attention(tp_degree: int, self_attn): + """ + Shard Qwen2/Qwen2.5-VL self attention module (used in text encoder). + + Handles GQA (Grouped Query Attention) where num_key_value_heads < num_heads. + For Qwen2.5-VL: num_heads=28, num_key_value_heads=4 + + Supports two modes: + 1. tp_degree <= num_kv_heads: Standard sharding (each rank gets subset of KV heads) + 2. tp_degree > num_kv_heads: KV head replication (each rank gets replicated KV heads) + + With tp_degree=8 and num_kv_heads=4: + - Q heads: 28 -> padded to 32 -> 4 per rank + - KV heads: 4 -> replicated to 8 -> 1 per rank (each pair of ranks shares same KV head) + """ + # Get original dimensions + orig_q = self_attn.q_proj + orig_k = self_attn.k_proj + orig_v = self_attn.v_proj + orig_o = self_attn.o_proj + + # Get KV head count + num_kv_heads = getattr(self_attn, 'num_key_value_heads', self_attn.num_heads) + num_q_heads = self_attn.num_heads + + # Check if KV replication is needed + kv_replicate_mode = num_kv_heads < tp_degree + if kv_replicate_mode: + # Replication mode: tp_degree must be divisible by num_kv_heads + if tp_degree % num_kv_heads != 0: + raise ValueError( + f"For KV head replication, tp_degree ({tp_degree}) must be divisible by " + f"num_key_value_heads ({num_kv_heads})") + print(f" Using KV head replication mode: {num_kv_heads} KV heads replicated across {tp_degree} ranks") + + # Calculate padded heads for Q + extra_q_heads = get_number_of_extra_heads(num_q_heads, tp_degree) + total_padded_q_heads = num_q_heads + extra_q_heads + q_head_dim = orig_q.out_features // num_q_heads # 3584 / 28 = 128 + padded_q_out_features = total_padded_q_heads * q_head_dim # 32 * 128 = 4096 + + print(f" Q heads: {num_q_heads} -> padded to {total_padded_q_heads}, " + f"out_features: {orig_q.out_features} -> {padded_q_out_features}") + + # Update number of heads per rank + self_attn.num_heads = neuronx_dist_utils.divide(total_padded_q_heads, tp_degree) + if hasattr(self_attn, 'num_key_value_heads'): + if kv_replicate_mode: + # In replication mode, each rank effectively has 1 KV head (replicated) + self_attn.num_key_value_heads = 1 + else: + self_attn.num_key_value_heads = self_attn.num_key_value_heads // tp_degree + + # CRITICAL: Update num_key_value_groups! + # This is used by repeat_kv() in attention forward to expand KV heads + if hasattr(self_attn, 'num_key_value_groups'): + self_attn.num_key_value_groups = self_attn.num_heads // self_attn.num_key_value_heads + print(f" Updated num_key_value_groups: {self_attn.num_key_value_groups}") + + # Shard Q projection (with padding if needed) + # Need to pad weights before sharding when num_heads is not divisible by tp_degree + q_weight_padded = orig_q.weight.data + q_bias_padded = orig_q.bias.data if orig_q.bias is not None else None + + if extra_q_heads > 0: + # Pad Q weights with zeros for extra heads + padding_size = extra_q_heads * q_head_dim + q_weight_padding = torch.zeros( + (padding_size, orig_q.in_features), + dtype=orig_q.weight.dtype, + device=orig_q.weight.device) + q_weight_padded = torch.cat([orig_q.weight.data, q_weight_padding], dim=0) + + if orig_q.bias is not None: + q_bias_padding = torch.zeros( + padding_size, + dtype=orig_q.bias.dtype, + device=orig_q.bias.device) + q_bias_padded = torch.cat([orig_q.bias.data, q_bias_padding], dim=0) + + # Now create ColumnParallelLinear with padded dimensions + self_attn.q_proj = ColumnParallelLinear( + orig_q.in_features, + padded_q_out_features, # Use padded out_features + bias=(orig_q.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + self_attn.q_proj.weight.data = get_sharded_data(q_weight_padded, 0) + if orig_q.bias is not None: + self_attn.q_proj.bias.data = get_sharded_data(q_bias_padded, 0) + del orig_q + + # Shard K projection (replicated if kv_replicate_mode) + # Get head_dim for KV + kv_head_dim = orig_k.out_features // num_kv_heads # 512 / 4 = 128 + + if kv_replicate_mode: + # In replication mode, use regular nn.Linear (not ColumnParallelLinear) + # because we want each rank to have 1 full KV head, not a fraction + # Each rank gets 1 KV head = head_dim features + kv_out_features_per_rank = kv_head_dim # 128 + + self_attn.k_proj = nn.Linear( + orig_k.in_features, + kv_out_features_per_rank, + bias=(orig_k.bias is not None), + dtype=torch.bfloat16) + self_attn.k_proj.weight.data = get_sharded_data_with_replication( + orig_k.weight.data, 0, num_kv_heads, tp_degree) + if orig_k.bias is not None: + self_attn.k_proj.bias.data = get_sharded_data_with_replication( + orig_k.bias.data, 0, num_kv_heads, tp_degree) + else: + self_attn.k_proj = ColumnParallelLinear( + orig_k.in_features, + orig_k.out_features, + bias=(orig_k.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + self_attn.k_proj.weight.data = get_sharded_data(orig_k.weight.data, 0) + if orig_k.bias is not None: + self_attn.k_proj.bias.data = get_sharded_data(orig_k.bias.data, 0) + del orig_k + + # Shard V projection (replicated if kv_replicate_mode) + if kv_replicate_mode: + # Same as K: use regular nn.Linear with replicated weights + kv_out_features_per_rank = kv_head_dim # 128 + + self_attn.v_proj = nn.Linear( + orig_v.in_features, + kv_out_features_per_rank, + bias=(orig_v.bias is not None), + dtype=torch.bfloat16) + self_attn.v_proj.weight.data = get_sharded_data_with_replication( + orig_v.weight.data, 0, num_kv_heads, tp_degree) + if orig_v.bias is not None: + self_attn.v_proj.bias.data = get_sharded_data_with_replication( + orig_v.bias.data, 0, num_kv_heads, tp_degree) + else: + self_attn.v_proj = ColumnParallelLinear( + orig_v.in_features, + orig_v.out_features, + bias=(orig_v.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + self_attn.v_proj.weight.data = get_sharded_data(orig_v.weight.data, 0) + if orig_v.bias is not None: + self_attn.v_proj.bias.data = get_sharded_data(orig_v.bias.data, 0) + del orig_v + + # Shard O projection (always sharded based on Q heads) + # O projection input comes from attention output, which has padded_q_out_features + # We need to pad the O weight's input dimension to match + + o_weight_padded = orig_o.weight.data + + if extra_q_heads > 0: + # Original O weight: (out_features, in_features) = (3584, 3584) + # Need to pad input dimension to padded_q_out_features = 4096 + padding_size = extra_q_heads * q_head_dim + o_weight_padding = torch.zeros( + (orig_o.out_features, padding_size), + dtype=orig_o.weight.dtype, + device=orig_o.weight.device) + o_weight_padded = torch.cat([orig_o.weight.data, o_weight_padding], dim=1) + + self_attn.o_proj = RowParallelLinear( + padded_q_out_features, # Use padded in_features + orig_o.out_features, + bias=(orig_o.bias is not None), + input_is_parallel=True, + dtype=torch.bfloat16) + self_attn.o_proj.weight.data = get_sharded_data(o_weight_padded, 1) + if orig_o.bias is not None: + self_attn.o_proj.bias.data = orig_o.bias.data.detach() + del orig_o + + return self_attn + + +def shard_vision_attention(tp_degree: int, attn): + """ + Shard Qwen2.5-VL Vision Encoder attention module. + + Vision attention uses fused QKV projection: + - qkv: (in_features, 3 * in_features) -> splits into Q, K, V + - proj: output projection + """ + orig_qkv = attn.qkv + orig_proj = attn.proj + + # Shard fused QKV projection + attn.qkv = ColumnParallelLinear( + orig_qkv.in_features, + orig_qkv.out_features, + bias=(orig_qkv.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + attn.qkv.weight.data = get_sharded_data(orig_qkv.weight.data, 0) + if orig_qkv.bias is not None: + attn.qkv.bias.data = get_sharded_data(orig_qkv.bias.data, 0) + del orig_qkv + + # Shard output projection + attn.proj = RowParallelLinear( + orig_proj.in_features, + orig_proj.out_features, + bias=(orig_proj.bias is not None), + input_is_parallel=True, + dtype=torch.bfloat16) + attn.proj.weight.data = get_sharded_data(orig_proj.weight.data, 1) + if orig_proj.bias is not None: + attn.proj.bias.data = orig_proj.bias.data.detach() + del orig_proj + + return attn + + +def shard_vision_mlp(mlp): + """ + Shard Qwen2.5-VL Vision Encoder MLP module. + + Uses gate_proj, up_proj, down_proj like Qwen2 MLP. + """ + orig_gate = mlp.gate_proj + orig_up = mlp.up_proj + orig_down = mlp.down_proj + + # Shard gate projection + mlp.gate_proj = ColumnParallelLinear( + orig_gate.in_features, + orig_gate.out_features, + bias=(orig_gate.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + mlp.gate_proj.weight.data = get_sharded_data(orig_gate.weight.data, 0) + if orig_gate.bias is not None: + mlp.gate_proj.bias.data = get_sharded_data(orig_gate.bias.data, 0) + del orig_gate + + # Shard up projection + mlp.up_proj = ColumnParallelLinear( + orig_up.in_features, + orig_up.out_features, + bias=(orig_up.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + mlp.up_proj.weight.data = get_sharded_data(orig_up.weight.data, 0) + if orig_up.bias is not None: + mlp.up_proj.bias.data = get_sharded_data(orig_up.bias.data, 0) + del orig_up + + # Shard down projection + mlp.down_proj = RowParallelLinear( + orig_down.in_features, + orig_down.out_features, + bias=(orig_down.bias is not None), + input_is_parallel=True, + dtype=torch.bfloat16) + mlp.down_proj.weight.data = get_sharded_data(orig_down.weight.data, 1) + if orig_down.bias is not None: + mlp.down_proj.bias.data = orig_down.bias.data.detach() + del orig_down + + return mlp + + +def shard_qwen2_mlp(mlp): + """ + Shard Qwen2 MLP module (used in text encoder). + """ + orig_gate = mlp.gate_proj + orig_up = mlp.up_proj + orig_down = mlp.down_proj + + # Shard gate projection + mlp.gate_proj = ColumnParallelLinear( + orig_gate.in_features, + orig_gate.out_features, + bias=(orig_gate.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + mlp.gate_proj.weight.data = get_sharded_data(orig_gate.weight.data, 0) + if orig_gate.bias is not None: + mlp.gate_proj.bias.data = get_sharded_data(orig_gate.bias.data, 0) + del orig_gate + + # Shard up projection + mlp.up_proj = ColumnParallelLinear( + orig_up.in_features, + orig_up.out_features, + bias=(orig_up.bias is not None), + gather_output=False, + dtype=torch.bfloat16) + mlp.up_proj.weight.data = get_sharded_data(orig_up.weight.data, 0) + if orig_up.bias is not None: + mlp.up_proj.bias.data = get_sharded_data(orig_up.bias.data, 0) + del orig_up + + # Shard down projection + mlp.down_proj = RowParallelLinear( + orig_down.in_features, + orig_down.out_features, + bias=(orig_down.bias is not None), + input_is_parallel=True, + dtype=torch.bfloat16) + mlp.down_proj.weight.data = get_sharded_data(orig_down.weight.data, 1) + if orig_down.bias is not None: + mlp.down_proj.bias.data = orig_down.bias.data.detach() + del orig_down + + return mlp diff --git a/contrib/models/Qwen-Image-Edit/src/neuron_rope.py b/contrib/models/Qwen-Image-Edit/src/neuron_rope.py new file mode 100644 index 00000000..5266f604 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/neuron_rope.py @@ -0,0 +1,307 @@ +""" +Neuron-compatible RoPE (Rotary Position Embedding) implementation for QwenImage. + +This module provides RoPE implementations that don't use complex numbers, +which are not supported by AWS Neuron. + +The original QwenImage uses torch.polar() to create complex frequencies, +but Neuron doesn't support C64 (complex64) datatypes. This implementation +uses (cos, sin) pairs instead. +""" + +import torch +from torch import nn +from typing import List, Tuple, Optional, Union +import functools + + +class NeuronQwenEmbedRope(nn.Module): + """ + Neuron-compatible RoPE for QwenImage that doesn't use complex numbers. + + Instead of storing complex frequencies, we store (cos, sin) pairs. + The original implementation uses: + freqs = torch.polar(torch.ones_like(freqs), freqs) # complex + We use: + cos_freqs = torch.cos(freqs) + sin_freqs = torch.sin(freqs) + """ + def __init__(self, theta: int, axes_dim: List[int], scale_rope: bool = False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.scale_rope = scale_rope + + # Precompute position indices (same as original) + pos_index = torch.arange(4096) + neg_index = torch.arange(4096).flip(0) * -1 - 1 + + # Compute frequencies as (cos, sin) instead of complex + # Original: torch.polar(ones, freqs) -> complex exp(i*freqs) + # We store: cos(freqs), sin(freqs) separately + self.pos_freqs_cos, self.pos_freqs_sin = self._compute_all_freqs(pos_index) + self.neg_freqs_cos, self.neg_freqs_sin = self._compute_all_freqs(neg_index) + + def _rope_params_real(self, index: torch.Tensor, dim: int, theta: int = 10000) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Compute RoPE frequencies as (cos, sin) instead of complex. + + Original: freqs = torch.polar(torch.ones_like(freqs), freqs) + This returns complex tensor of shape [len(index), dim//2] + + We return (cos, sin) each of shape [len(index), dim//2] + """ + assert dim % 2 == 0 + # Compute angles: outer product of positions and frequency bases + freqs = torch.outer( + index.float(), + 1.0 / torch.pow(theta, torch.arange(0, dim, 2).float() / dim) + ) + # Return cos and sin instead of complex polar + return torch.cos(freqs), torch.sin(freqs) + + def _compute_all_freqs(self, index: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute frequencies for all axes and concatenate.""" + freqs = [] + for dim in self.axes_dim: + cos_f, sin_f = self._rope_params_real(index, dim, self.theta) + freqs.append((cos_f, sin_f)) + + # Concatenate along dimension axis + # Each has shape [4096, axes_dim[i]//2] + cos_all = torch.cat([f[0] for f in freqs], dim=1) + sin_all = torch.cat([f[1] for f in freqs], dim=1) + + return cos_all, sin_all + + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + txt_seq_lens: Optional[List[int]] = None, + device: torch.device = None, + max_txt_seq_len: Optional[Union[int, torch.Tensor]] = None, + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + """ + Compute RoPE frequencies for video and text. + + Handles multiple img_shapes formats: + - (T, H, W): single tuple for one video + - [(T, H, W)]: list with single tuple + - [(T1, H, W), (T2, H, W)]: list of tuples (multiple images) + - [[(T1, H, W), (T2, H, W)]]: nested list (batch of multiple images) + + For multiple images, frames are summed to get total patch count. + + Returns: + Tuple of (vid_freqs, txt_freqs), each being (cos, sin) tuple + """ + # Handle deprecated txt_seq_lens parameter + if txt_seq_lens is not None and max_txt_seq_len is None: + max_txt_seq_len = max(txt_seq_lens) if isinstance(txt_seq_lens, list) else txt_seq_lens + + if max_txt_seq_len is None: + raise ValueError("Either max_txt_seq_len or txt_seq_lens must be provided.") + + # Parse video_fhw into (total_frames, height, width) + # Need to handle different formats correctly: + # 1. (T, H, W) - single tuple + # 2. [(T, H, W)] - list with single tuple + # 3. [(T1, H, W), (T2, H, W)] - list of tuples for multiple images + # 4. [[(T1, H, W), (T2, H, W)]] - nested list for batch + + if isinstance(video_fhw, tuple) and len(video_fhw) == 3 and isinstance(video_fhw[0], int): + # Format 1: (T, H, W) - single tuple + frame, height, width = video_fhw + elif isinstance(video_fhw, list) and len(video_fhw) > 0: + first_elem = video_fhw[0] + if isinstance(first_elem, tuple) and len(first_elem) == 3 and isinstance(first_elem[0], int): + # Format 2 or 3: [(T, H, W)] or [(T1, H, W), (T2, H, W), ...] + # Sum frames from all tuples, assume same H, W + frame = sum(t[0] for t in video_fhw) + height, width = first_elem[1], first_elem[2] + elif isinstance(first_elem, (list, tuple)) and len(first_elem) > 0: + # Format 4: [[(T1, H, W), (T2, H, W), ...]] - nested list + # Take first batch item, sum frames from all images + shapes = first_elem + if isinstance(shapes[0], tuple) and len(shapes[0]) == 3: + frame = sum(t[0] for t in shapes) + height, width = shapes[0][1], shapes[0][2] + else: + raise ValueError(f"Unsupported nested video_fhw format: {video_fhw}") + else: + raise ValueError(f"Unsupported video_fhw format: {video_fhw}") + else: + raise ValueError(f"Unsupported video_fhw format: {video_fhw}") + + # Compute video frequencies + vid_cos, vid_sin = self._compute_video_freqs(frame, height, width, device) + + # Compute text frequencies + max_txt_seq_len_int = int(max_txt_seq_len) + if self.scale_rope: + max_vid_index = max(height // 2, width // 2) + else: + max_vid_index = max(height, width) + + txt_cos = self.pos_freqs_cos.to(device)[max_vid_index:max_vid_index + max_txt_seq_len_int] + txt_sin = self.pos_freqs_sin.to(device)[max_vid_index:max_vid_index + max_txt_seq_len_int] + + return (vid_cos, vid_sin), (txt_cos, txt_sin) + + def _compute_video_freqs( + self, frame: int, height: int, width: int, device: torch.device = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute video frequencies for given dimensions.""" + seq_lens = frame * height * width + + pos_cos = self.pos_freqs_cos.to(device) if device is not None else self.pos_freqs_cos + pos_sin = self.pos_freqs_sin.to(device) if device is not None else self.pos_freqs_sin + neg_cos = self.neg_freqs_cos.to(device) if device is not None else self.neg_freqs_cos + neg_sin = self.neg_freqs_sin.to(device) if device is not None else self.neg_freqs_sin + + # Split by axes dimensions (each is dim//2 because we computed with dim//2 freqs) + split_dims = [x // 2 for x in self.axes_dim] + + pos_cos_split = pos_cos.split(split_dims, dim=1) + pos_sin_split = pos_sin.split(split_dims, dim=1) + neg_cos_split = neg_cos.split(split_dims, dim=1) + neg_sin_split = neg_sin.split(split_dims, dim=1) + + # Frame frequencies (always from positive) + freqs_frame_cos = pos_cos_split[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + freqs_frame_sin = pos_sin_split[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) + + if self.scale_rope: + # Height: combine negative and positive + h_neg_len = height - height // 2 + freqs_height_cos = torch.cat([neg_cos_split[1][-h_neg_len:], pos_cos_split[1][:height // 2]], dim=0) + freqs_height_sin = torch.cat([neg_sin_split[1][-h_neg_len:], pos_sin_split[1][:height // 2]], dim=0) + freqs_height_cos = freqs_height_cos.view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_height_sin = freqs_height_sin.view(1, height, 1, -1).expand(frame, height, width, -1) + + # Width: combine negative and positive + w_neg_len = width - width // 2 + freqs_width_cos = torch.cat([neg_cos_split[2][-w_neg_len:], pos_cos_split[2][:width // 2]], dim=0) + freqs_width_sin = torch.cat([neg_sin_split[2][-w_neg_len:], pos_sin_split[2][:width // 2]], dim=0) + freqs_width_cos = freqs_width_cos.view(1, 1, width, -1).expand(frame, height, width, -1) + freqs_width_sin = freqs_width_sin.view(1, 1, width, -1).expand(frame, height, width, -1) + else: + freqs_height_cos = pos_cos_split[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_height_sin = pos_sin_split[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) + freqs_width_cos = pos_cos_split[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + freqs_width_sin = pos_sin_split[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) + + # Concatenate all axes + freqs_cos = torch.cat([freqs_frame_cos, freqs_height_cos, freqs_width_cos], dim=-1).reshape(seq_lens, -1) + freqs_sin = torch.cat([freqs_frame_sin, freqs_height_sin, freqs_width_sin], dim=-1).reshape(seq_lens, -1) + + return freqs_cos.clone().contiguous(), freqs_sin.clone().contiguous() + + +def apply_rotary_emb_neuron( + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor, torch.Tensor], + use_real: bool = True, + use_real_unbind_dim: int = -1, +) -> torch.Tensor: + """ + Apply rotary embeddings without using complex numbers. + + This is a drop-in replacement for apply_rotary_emb_qwen that uses + (cos, sin) tuples instead of complex tensors. + + The rotation is applied as: + out[2k] = x[2k] * cos[k] - x[2k+1] * sin[k] + out[2k+1] = x[2k] * sin[k] + x[2k+1] * cos[k] + + This is equivalent to complex multiplication: + (x_real + i*x_imag) * (cos + i*sin) = (x_real*cos - x_imag*sin) + i*(x_real*sin + x_imag*cos) + + Args: + x: Input tensor [B, S, H, D] + freqs_cis: Tuple of (cos, sin) tensors, each [S, D//2] + use_real: Always True for Neuron (we don't use complex) + use_real_unbind_dim: Dimension for unbinding (-1 or -2) + + Returns: + Tensor with rotary embeddings applied + """ + cos, sin = freqs_cis + + # cos/sin have shape [S, D//2] where D is the head_dim + # x has shape [B, S, H, D] + + # Expand cos/sin to match x's D dimension by interleaving + # [c0, c1, ..., c31] -> [c0, c0, c1, c1, ..., c31, c31] + # This uses repeat_interleave which is more compiler-friendly than stack+flatten + cos = cos.repeat_interleave(2, dim=-1) # [S, D] + sin = sin.repeat_interleave(2, dim=-1) # [S, D] + + # Expand dims for broadcasting: [S, D] -> [1, S, 1, D] + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + # Move to same device as x + cos = cos.to(x.device) + sin = sin.to(x.device) + + # For use_real_unbind_dim == -1 (default for QwenImage) + # x is stored as [x0_real, x0_imag, x1_real, x1_imag, ...] + # x_rotated should be [-x0_imag, x0_real, -x1_imag, x1_real, ...] + if use_real_unbind_dim == -1: + # Reshape to separate real/imag pairs, then create rotated version + # Use view instead of reshape for better tracing + orig_shape = x.shape + x_reshape = x.view(orig_shape[0], orig_shape[1], orig_shape[2], -1, 2) # [B, S, H, D//2, 2] + # Create rotated: [-imag, real] for each pair + x_rotated = torch.cat([-x_reshape[..., 1:2], x_reshape[..., 0:1]], dim=-1) # [B, S, H, D//2, 2] + x_rotated = x_rotated.view(orig_shape) # [B, S, H, D] + + elif use_real_unbind_dim == -2: + # x is stored as [x0_real, x1_real, ..., x0_imag, x1_imag, ...] + half_d = x.shape[-1] // 2 + x_real = x[..., :half_d] + x_imag = x[..., half_d:] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"use_real_unbind_dim={use_real_unbind_dim} but should be -1 or -2.") + + # Apply rotation: out = x * cos + x_rotated * sin + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + return out + + +def patch_qwenimage_rope(transformer): + """ + Patch the QwenImage transformer to use Neuron-compatible RoPE. + + This replaces the complex-number based RoPE with sin/cos based implementation. + """ + # Get original config + orig_rope = transformer.pos_embed + theta = orig_rope.theta + axes_dim = orig_rope.axes_dim + scale_rope = orig_rope.scale_rope + + print(f" Original RoPE: theta={theta}, axes_dim={axes_dim}, scale_rope={scale_rope}") + + # Replace with Neuron-compatible version + transformer.pos_embed = NeuronQwenEmbedRope( + theta=theta, + axes_dim=axes_dim, + scale_rope=scale_rope + ) + + # Patch the apply_rotary_emb_qwen function to use our version + import diffusers.models.transformers.transformer_qwenimage as qwen_module + + # Store original function + if not hasattr(qwen_module, '_orig_apply_rotary_emb_qwen'): + qwen_module._orig_apply_rotary_emb_qwen = qwen_module.apply_rotary_emb_qwen + + # Replace with neuron-compatible version + qwen_module.apply_rotary_emb_qwen = apply_rotary_emb_neuron + + print(" Patched QwenImage transformer with Neuron-compatible RoPE (no complex numbers)") + return transformer diff --git a/contrib/models/Qwen-Image-Edit/src/run_qwen_image_edit.py b/contrib/models/Qwen-Image-Edit/src/run_qwen_image_edit.py new file mode 100644 index 00000000..b5ba273a --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/run_qwen_image_edit.py @@ -0,0 +1,3196 @@ +""" +Qwen-Image-Edit-2509 Inference Script for AWS Trainium2 + +This script runs the Qwen-Image-Edit model ENTIRELY on Neuron devices. +All components (Text Encoder, Transformer, VAE) run on Trainium2. + +Components: +- Text Encoder (Qwen2.5-VL): Vision encoder + Language model +- Transformer: QwenImageTransformer2DModel (TP=8) +- VAE: Encoder and Decoder + +Usage: + # Single image editing: + python run_qwen_image_edit.py --images input.jpg --prompt "change the sky to sunset" + + # Multi-image editing (1-3 images): + python run_qwen_image_edit.py --images img1.jpg img2.jpg --prompt "combine these images" +""" + +import os + +# ============================================================================ +# CRITICAL: Set Neuron environment variables BEFORE any other imports! +# These MUST match the compilation settings. +# ============================================================================ +# NOTE: Transformer uses TP=8. Language Model can run on: +# - Neuron with TP=4 (correct GQA alignment, but requires separate process) +# - CPU (slower but works in same process as TP=8 Transformer) +# +# GQA alignment issue: 28Q/4KV heads requires TP=4 for correct alignment, +# but TP=4 causes OOM on Transformer. So we default to CPU Language Model. +TP_DEGREE = 8 # For Transformer; Language Model runs on CPU by default + +# Set SPMD world size. Default 8 (V3 CP/CFG world=8). Override via QIE_WORLD_SIZE for +# larger CP (e.g. TP=4 x CP=4 = world 16). Must match the transformer NEFF's world_size. +_WORLD = int(os.environ.get("QIE_WORLD_SIZE", str(TP_DEGREE))) +os.environ["LOCAL_WORLD_SIZE"] = str(_WORLD) + +# Neuron runtime settings - MUST match compilation +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" # For trn2 LNC=2 +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" # For trn2 LNC=2 + +# Neuron compiler settings (for any runtime compilation) +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" + +print(f"Neuron runtime configured: world={_WORLD}, LNC=2") + +import argparse +import contextlib +import random +import time + +import numpy as np +import torch +import torch_neuronx +import neuronx_distributed +from PIL import Image + +from diffusers import QwenImageEditPlusPipeline +from diffusers.utils import load_image + +# Import Neuron-compatible VAE +from autoencoder_kl_qwenimage_neuron import ( + AutoencoderKLQwenImage as NeuronAutoencoder +) +from neuron_commons import NeuronTextEncoderWrapper + +# Import NxDModel for V2 API loading +try: + from neuronx_distributed.trace.nxd_model.nxd_model import NxDModel + NXD_MODEL_AVAILABLE = True +except ImportError: + NXD_MODEL_AVAILABLE = False + print("WARNING: NxDModel not available. V2 models cannot be loaded.") + +# Constants +COMPILED_MODELS_DIR = "/opt/dlami/nvme/compiled_models_qwen_image_edit" +HUGGINGFACE_CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = os.environ.get("QIE_MODEL_PATH", "Qwen/Qwen-Image-Edit-2509") +SEED = 42 + + +def set_seed(seed: int): + """Set all random seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + print(f"Random seed set to: {seed}") + + +class NeuronTransformerWrapper(torch.nn.Module): + """ + Wrapper for compiled transformer model on Trainium2. + """ + def __init__(self, original_transformer, compiled_transformer, img_shapes, + expected_num_patches=1024, expected_seq_len=512): + super().__init__() + self.config = original_transformer.config + self.dtype = original_transformer.dtype + self.device = original_transformer.device + self.compiled_transformer = compiled_transformer + self.img_shapes = img_shapes + self.expected_num_patches = expected_num_patches + self.expected_seq_len = expected_seq_len + + @contextlib.contextmanager + def cache_context(self, name: str): + """Dummy cache context for compatibility with pipeline. + Compiled models don't use dynamic caching.""" + yield + + def forward(self, hidden_states, encoder_hidden_states=None, + timestep=None, img_shapes=None, return_dict=False, **kwargs): + """ + Forward pass using compiled transformer on Neuron. + Handles shape padding and dtype conversion for compiled model. + """ + batch_size = hidden_states.shape[0] + + # Debug: Print shapes on first call + if not hasattr(self, '_debug_printed'): + print(f"DEBUG Transformer input shapes:") + print(f" hidden_states: {hidden_states.shape}") + print(f" encoder_hidden_states: {encoder_hidden_states.shape}") + print(f" timestep: {timestep.shape}, dtype={timestep.dtype}") + print(f" img_shapes: {img_shapes}") + print(f" Expected: num_patches={self.expected_num_patches}, seq_len={self.expected_seq_len}") + self._debug_printed = True + + # 1. Handle hidden_states shape (num_patches dimension) + # Compiled model expects (batch, expected_num_patches, 64) + actual_patches = hidden_states.shape[1] + if actual_patches != self.expected_num_patches: + if actual_patches < self.expected_num_patches: + # Pad with zeros + pad_size = self.expected_num_patches - actual_patches + padding = torch.zeros( + (batch_size, pad_size, hidden_states.shape[2]), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + hidden_states = torch.cat([hidden_states, padding], dim=1) + else: + # Truncate - This is problematic! The model was compiled for fewer patches. + # This likely means the transformer needs to be recompiled with correct shape. + print(f"ERROR: hidden_states has {actual_patches} patches but model expects {self.expected_num_patches}") + print(f" You may need to recompile the transformer with correct dimensions.") + print(f" Truncating will produce incorrect results!") + hidden_states = hidden_states[:, :self.expected_num_patches, :] + + # 2. Handle encoder_hidden_states shape (sequence length) + # Compiled model expects (batch, expected_seq_len, 3584) + actual_seq_len = encoder_hidden_states.shape[1] + if actual_seq_len != self.expected_seq_len: + if actual_seq_len < self.expected_seq_len: + # Pad with zeros + pad_size = self.expected_seq_len - actual_seq_len + padding = torch.zeros( + (batch_size, pad_size, encoder_hidden_states.shape[2]), + dtype=encoder_hidden_states.dtype, + device=encoder_hidden_states.device + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) + else: + # Truncate + print(f"WARNING: Truncating encoder_hidden_states from {actual_seq_len} to {self.expected_seq_len}") + encoder_hidden_states = encoder_hidden_states[:, :self.expected_seq_len, :] + + # 3. Convert timestep to float32 (compiled model expects float32) + timestep = timestep.to(torch.float32) + + # Run on compiled Neuron model + output = self.compiled_transformer( + hidden_states, + encoder_hidden_states, + timestep + ) + + # 4. Remove padding from output if we padded hidden_states + if actual_patches < self.expected_num_patches: + output = (output[0][:, :actual_patches, :],) + output[1:] + + if return_dict: + from diffusers.models.modeling_outputs import Transformer2DModelOutput + return Transformer2DModelOutput(sample=output[0]) + return output + + +class NeuronTransformerWrapperV2(torch.nn.Module): + """ + Wrapper for V2 compiled transformer (ModelBuilder API) on Trainium2. + + Key difference from V1: RoPE frequencies are passed as input, not computed internally. + """ + def __init__(self, original_transformer, nxd_model, img_rotary_emb, txt_rotary_emb, + expected_num_patches=1024, expected_seq_len=512, temporal_frames=3): + super().__init__() + self.config = original_transformer.config + self.dtype = original_transformer.dtype + self.device = original_transformer.device + self.nxd_model = nxd_model + + # Pre-computed RoPE frequencies + self.img_rotary_emb = img_rotary_emb + self.txt_rotary_emb = txt_rotary_emb + + self.expected_num_patches = expected_num_patches + self.expected_seq_len = expected_seq_len + self.temporal_frames = temporal_frames + # Base patches per frame (noise prediction output size) + self.base_patches = expected_num_patches // temporal_frames + + @contextlib.contextmanager + def cache_context(self, name: str): + """Dummy cache context for compatibility with pipeline.""" + yield + + def forward(self, hidden_states, encoder_hidden_states=None, + timestep=None, img_shapes=None, return_dict=False, **kwargs): + """Forward pass using V2 compiled transformer with RoPE as input.""" + batch_size = hidden_states.shape[0] + + # Debug: Print shapes on first call + if not hasattr(self, '_debug_printed'): + print(f"DEBUG Transformer V2 input shapes:") + print(f" hidden_states: {hidden_states.shape}") + print(f" encoder_hidden_states: {encoder_hidden_states.shape}") + print(f" timestep: {timestep.shape}") + print(f" img_rotary_emb: {self.img_rotary_emb.shape}") + print(f" txt_rotary_emb: {self.txt_rotary_emb.shape}") + print(f" temporal_frames: {self.temporal_frames}, base_patches: {self.base_patches}") + print(f" Will extract last {self.base_patches} patches as noise prediction") + self._debug_printed = True + + # Handle hidden_states padding + actual_patches = hidden_states.shape[1] + if actual_patches != self.expected_num_patches: + if actual_patches < self.expected_num_patches: + pad_size = self.expected_num_patches - actual_patches + padding = torch.zeros( + (batch_size, pad_size, hidden_states.shape[2]), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + hidden_states = torch.cat([hidden_states, padding], dim=1) + else: + print(f"ERROR: hidden_states has {actual_patches} patches but model expects {self.expected_num_patches}") + hidden_states = hidden_states[:, :self.expected_num_patches, :] + + # Handle encoder_hidden_states padding + actual_seq_len = encoder_hidden_states.shape[1] + if actual_seq_len != self.expected_seq_len: + if actual_seq_len < self.expected_seq_len: + pad_size = self.expected_seq_len - actual_seq_len + padding = torch.zeros( + (batch_size, pad_size, encoder_hidden_states.shape[2]), + dtype=encoder_hidden_states.dtype, + device=encoder_hidden_states.device + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) + else: + print(f"WARNING: Truncating encoder_hidden_states from {actual_seq_len} to {self.expected_seq_len}") + encoder_hidden_states = encoder_hidden_states[:, :self.expected_seq_len, :] + + # Convert timestep to float32 + timestep = timestep.to(torch.float32) + + # Run V2 model with RoPE as input + output = self.nxd_model( + hidden_states, + encoder_hidden_states, + timestep, + self.img_rotary_emb, + self.txt_rotary_emb + ) + + # Extract tensor from output (handle tuple or tensor) + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + + # For image editing, the model processes temporal_frames * base_patches + # but should only return the noise prediction for one frame (base_patches) + # Try extracting the FIRST frame (index 0) as noise prediction + # (QwenImage may use frame 0 for noise, unlike other models that use last frame) + output_tensor = output_tensor[:, :self.base_patches, :] + + if return_dict: + from diffusers.models.modeling_outputs import Transformer2DModelOutput + return Transformer2DModelOutput(sample=output_tensor) + return (output_tensor,) + + +def load_transformer_v2(compiled_models_dir: str, pipe, args): + """ + Load V2 compiled transformer model using NxDModel API. + + V2 models are compiled with ModelBuilder and require: + 1. nxd_model.pt - the compiled model + 2. weights/ - sharded checkpoints + 3. rope_cache.pt - pre-computed RoPE tensors + 4. config.json - model configuration + + Args: + compiled_models_dir: Directory containing compiled model artifacts + pipe: Pipeline with original transformer (for config) + args: Command line arguments + + Returns: + NeuronTransformerWrapperV2 wrapping the loaded model + """ + import json + + if not NXD_MODEL_AVAILABLE: + raise RuntimeError( + "NxDModel is not available. Please ensure neuronx_distributed is installed correctly." + ) + + v2_path = f"{compiled_models_dir}/transformer_v2" + nxd_model_path = f"{v2_path}/nxd_model.pt" + weights_path = f"{v2_path}/weights" + rope_cache_path = f"{v2_path}/rope_cache.pt" + config_path = f"{v2_path}/config.json" + + # Validate all required files exist + if not os.path.exists(nxd_model_path): + raise FileNotFoundError( + f"V2 transformer model not found at {nxd_model_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2.py" + ) + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"V2 transformer weights not found at {weights_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2.py" + ) + if not os.path.exists(rope_cache_path): + raise FileNotFoundError( + f"V2 RoPE cache not found at {rope_cache_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2.py" + ) + + # Load config + print(f" Loading V2 config from {config_path}...") + with open(config_path, "r") as f: + config = json.load(f) + + expected_num_patches = config["num_patches"] + expected_seq_len = config["text_seq_len"] + temporal_frames = config.get("frame", config.get("patch_multiplier", 3)) + base_patches = expected_num_patches // temporal_frames + print(f" V2 config: patches={expected_num_patches}, seq_len={expected_seq_len}") + print(f" V2 config: temporal_frames={temporal_frames}, base_patches={base_patches}") + + # Load pre-computed RoPE tensors + print(f" Loading RoPE cache from {rope_cache_path}...") + rope_cache = torch.load(rope_cache_path) + img_rotary_emb = rope_cache["img_rotary_emb"].to(torch.bfloat16) + txt_rotary_emb = rope_cache["txt_rotary_emb"].to(torch.bfloat16) + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + # Load the compiled model using NxDModel.load() + print(f" Loading V2 model from {nxd_model_path}...") + nxd_model = NxDModel.load(nxd_model_path) + + # Load sharded checkpoints + from safetensors.torch import load_file + tp_degree = config.get("tp_degree", 8) + print(f" Loading sharded weights for TP={tp_degree}...") + sharded_checkpoints = [] + for rank in range(tp_degree): + ckpt_path = f"{weights_path}/tp{rank}_sharded_checkpoint.safetensors" + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + ckpt = load_file(ckpt_path) + sharded_checkpoints.append(ckpt) + if rank == 0: + print(f" Rank 0 checkpoint keys: {len(ckpt)} tensors") + + # Initialize model with weights and move to Neuron + print(" Setting weights...") + nxd_model.set_weights(sharded_checkpoints) + print(" Moving model to Neuron...") + nxd_model.to_neuron() + print(" V2 model initialized on Neuron!") + + # Create wrapper + wrapper = NeuronTransformerWrapperV2( + original_transformer=pipe.transformer, + nxd_model=nxd_model, + img_rotary_emb=img_rotary_emb, + txt_rotary_emb=txt_rotary_emb, + expected_num_patches=expected_num_patches, + expected_seq_len=expected_seq_len, + temporal_frames=temporal_frames, + ) + + return wrapper + + +class NeuronTransformerWrapperV1Flash(torch.nn.Module): + """ + Wrapper for V1 Flash compiled transformer (parallel_model_trace + NKI Flash Attention). + + Key features: + - Uses parallel_model_trace API (supports NKI Flash Attention) + - RoPE frequencies are passed as input (like V2) + - Uses NKI Flash Attention for better performance + """ + def __init__(self, original_transformer, compiled_transformer, img_rotary_emb, txt_rotary_emb, + expected_num_patches=1024, expected_seq_len=512, temporal_frames=3): + super().__init__() + self.config = original_transformer.config + self.dtype = original_transformer.dtype + self.device = original_transformer.device + self.compiled_transformer = compiled_transformer + + # Pre-computed RoPE frequencies + self.img_rotary_emb = img_rotary_emb + self.txt_rotary_emb = txt_rotary_emb + + self.expected_num_patches = expected_num_patches + self.expected_seq_len = expected_seq_len + self.temporal_frames = temporal_frames + self.base_patches = expected_num_patches // temporal_frames + + @contextlib.contextmanager + def cache_context(self, name: str): + """Dummy cache context for compatibility with pipeline.""" + yield + + def forward(self, hidden_states, encoder_hidden_states=None, + timestep=None, img_shapes=None, return_dict=False, **kwargs): + """Forward pass using V1 Flash compiled transformer with RoPE as input.""" + batch_size = hidden_states.shape[0] + + # Debug: Print shapes on first call + if not hasattr(self, '_debug_printed'): + print(f"DEBUG Transformer V1 Flash input shapes:") + print(f" hidden_states: {hidden_states.shape}") + print(f" encoder_hidden_states: {encoder_hidden_states.shape}") + print(f" timestep: {timestep.shape}") + print(f" img_rotary_emb: {self.img_rotary_emb.shape}") + print(f" txt_rotary_emb: {self.txt_rotary_emb.shape}") + print(f" temporal_frames: {self.temporal_frames}, base_patches: {self.base_patches}") + self._debug_printed = True + + # Handle hidden_states padding + actual_patches = hidden_states.shape[1] + if actual_patches != self.expected_num_patches: + if actual_patches < self.expected_num_patches: + pad_size = self.expected_num_patches - actual_patches + padding = torch.zeros( + (batch_size, pad_size, hidden_states.shape[2]), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + hidden_states = torch.cat([hidden_states, padding], dim=1) + else: + print(f"ERROR: hidden_states has {actual_patches} patches but model expects {self.expected_num_patches}") + hidden_states = hidden_states[:, :self.expected_num_patches, :] + + # Handle encoder_hidden_states padding + actual_seq_len = encoder_hidden_states.shape[1] + if actual_seq_len != self.expected_seq_len: + if actual_seq_len < self.expected_seq_len: + pad_size = self.expected_seq_len - actual_seq_len + padding = torch.zeros( + (batch_size, pad_size, encoder_hidden_states.shape[2]), + dtype=encoder_hidden_states.dtype, + device=encoder_hidden_states.device + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) + else: + print(f"WARNING: Truncating encoder_hidden_states from {actual_seq_len} to {self.expected_seq_len}") + encoder_hidden_states = encoder_hidden_states[:, :self.expected_seq_len, :] + + # Convert timestep to float32 + timestep = timestep.to(torch.float32) + + # Run compiled transformer with RoPE as input + output = self.compiled_transformer( + hidden_states, + encoder_hidden_states, + timestep, + self.img_rotary_emb, + self.txt_rotary_emb + ) + + # Extract tensor from output + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + + # Extract first frame as noise prediction (same as V2) + output_tensor = output_tensor[:, :self.base_patches, :] + + if return_dict: + from diffusers.models.modeling_outputs import Transformer2DModelOutput + return Transformer2DModelOutput(sample=output_tensor) + return (output_tensor,) + + +def load_transformer_v1_flash(compiled_models_dir: str, pipe, args): + """ + Load V1 Flash compiled transformer model using parallel_model_load. + + V1 Flash models are compiled with parallel_model_trace and require: + 1. Model files in transformer_v1_flash/ directory + 2. rope_cache.pt - pre-computed RoPE tensors + 3. config.json - model configuration + + Args: + compiled_models_dir: Directory containing compiled model artifacts + pipe: Pipeline with original transformer (for config) + args: Command line arguments + + Returns: + NeuronTransformerWrapperV1Flash wrapping the loaded model + """ + import json + + v1_flash_path = f"{compiled_models_dir}/transformer_v1_flash" + model_path = f"{v1_flash_path}/model" # Model files are in subdirectory + rope_cache_path = f"{v1_flash_path}/rope_cache.pt" + config_path = f"{v1_flash_path}/config.json" + + # Validate files exist + if not os.path.exists(model_path): + raise FileNotFoundError( + f"V1 Flash transformer not found at {model_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v1_flash.py" + ) + if not os.path.exists(rope_cache_path): + raise FileNotFoundError( + f"V1 Flash RoPE cache not found at {rope_cache_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v1_flash.py" + ) + + # Load config + print(f" Loading V1 Flash config from {config_path}...") + with open(config_path, "r") as f: + config = json.load(f) + + expected_num_patches = config["num_patches"] + expected_seq_len = config["text_seq_len"] + temporal_frames = config.get("frame", config.get("patch_multiplier", 3)) + base_patches = expected_num_patches // temporal_frames + print(f" V1 Flash config: patches={expected_num_patches}, seq_len={expected_seq_len}") + print(f" V1 Flash config: temporal_frames={temporal_frames}, base_patches={base_patches}") + print(f" NKI Flash Attention: {config.get('nki_flash_attention', False)}") + + # Load pre-computed RoPE tensors + print(f" Loading RoPE cache from {rope_cache_path}...") + rope_cache = torch.load(rope_cache_path) + img_rotary_emb = rope_cache["img_rotary_emb"].to(torch.bfloat16) + txt_rotary_emb = rope_cache["txt_rotary_emb"].to(torch.bfloat16) + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + # Load compiled model using parallel_model_load (from model subdirectory) + print(f" Loading V1 Flash model from {model_path}...") + compiled_transformer = neuronx_distributed.trace.parallel_model_load(model_path) + print(" V1 Flash model loaded!") + + # Create wrapper + wrapper = NeuronTransformerWrapperV1Flash( + original_transformer=pipe.transformer, + compiled_transformer=compiled_transformer, + img_rotary_emb=img_rotary_emb, + txt_rotary_emb=txt_rotary_emb, + expected_num_patches=expected_num_patches, + expected_seq_len=expected_seq_len, + temporal_frames=temporal_frames, + ) + + return wrapper + + +def load_transformer_v2_flash(compiled_models_dir: str, pipe, args): + """ + Load V2 Flash compiled transformer model using NxDModel API. + + V2 Flash models combine ModelBuilder API with NKI Flash Attention: + 1. nxd_model.pt - the compiled model + 2. weights/ - sharded checkpoints + 3. rope_cache.pt - pre-computed RoPE tensors + 4. config.json - model configuration + + Args: + compiled_models_dir: Directory containing compiled model artifacts + pipe: Pipeline with original transformer (for config) + args: Command line arguments + + Returns: + NeuronTransformerWrapperV2 wrapping the loaded model (reuses V2 wrapper) + """ + import json + + if not NXD_MODEL_AVAILABLE: + raise RuntimeError( + "NxDModel is not available. Please ensure neuronx_distributed is installed correctly." + ) + + v2_flash_path = f"{compiled_models_dir}/transformer_v2_flash" + nxd_model_path = f"{v2_flash_path}/nxd_model.pt" + weights_path = f"{v2_flash_path}/weights" + rope_cache_path = f"{v2_flash_path}/rope_cache.pt" + config_path = f"{v2_flash_path}/config.json" + + # Validate all required files exist + if not os.path.exists(nxd_model_path): + raise FileNotFoundError( + f"V2 Flash transformer model not found at {nxd_model_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2_flash.py" + ) + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"V2 Flash transformer weights not found at {weights_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2_flash.py" + ) + if not os.path.exists(rope_cache_path): + raise FileNotFoundError( + f"V2 Flash RoPE cache not found at {rope_cache_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2_flash.py" + ) + + # Load config + print(f" Loading V2 Flash config from {config_path}...") + with open(config_path, "r") as f: + config = json.load(f) + + expected_num_patches = config["num_patches"] + expected_seq_len = config["text_seq_len"] + temporal_frames = config.get("frame", config.get("patch_multiplier", 3)) + base_patches = expected_num_patches // temporal_frames + print(f" V2 Flash config: patches={expected_num_patches}, seq_len={expected_seq_len}") + print(f" V2 Flash config: temporal_frames={temporal_frames}, base_patches={base_patches}") + print(f" NKI Flash Attention: {config.get('nki_flash_attention', False)}") + + # Load pre-computed RoPE tensors + print(f" Loading RoPE cache from {rope_cache_path}...") + rope_cache = torch.load(rope_cache_path) + img_rotary_emb = rope_cache["img_rotary_emb"].to(torch.bfloat16) + txt_rotary_emb = rope_cache["txt_rotary_emb"].to(torch.bfloat16) + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + # Load the compiled model using NxDModel.load() + print(f" Loading V2 Flash model from {nxd_model_path}...") + nxd_model = NxDModel.load(nxd_model_path) + + # Load sharded checkpoints + from safetensors.torch import load_file + tp_degree = config.get("tp_degree", 8) + print(f" Loading sharded weights for TP={tp_degree}...") + sharded_checkpoints = [] + for rank in range(tp_degree): + ckpt_path = f"{weights_path}/tp{rank}_sharded_checkpoint.safetensors" + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + ckpt = load_file(ckpt_path) + sharded_checkpoints.append(ckpt) + if rank == 0: + print(f" Rank 0 checkpoint keys: {len(ckpt)} tensors") + + # Initialize model with weights and move to Neuron + print(" Setting weights...") + nxd_model.set_weights(sharded_checkpoints) + print(" Moving model to Neuron...") + nxd_model.to_neuron() + print(" V2 Flash model initialized on Neuron!") + + # Create wrapper (reuse V2 wrapper since interface is the same) + wrapper = NeuronTransformerWrapperV2( + original_transformer=pipe.transformer, + nxd_model=nxd_model, + img_rotary_emb=img_rotary_emb, + txt_rotary_emb=txt_rotary_emb, + expected_num_patches=expected_num_patches, + expected_seq_len=expected_seq_len, + temporal_frames=temporal_frames, + ) + + return wrapper + + +class NeuronTransformerWrapperV3CP(torch.nn.Module): + """ + Wrapper for V3 CP (Context Parallel) compiled transformer. + + Key features: + - Uses TP=4, CP=2 (world_size=8) + - K/V are all-gathered across CP group before attention + - Each CP rank processes part of the sequence + - RoPE is sharded per CP rank + """ + def __init__(self, original_transformer, nxd_model, img_rotary_emb, txt_rotary_emb, + expected_num_patches=1024, num_patches_padded=None, patches_padding=0, + expected_seq_len=512, temporal_frames=3, cp_degree=2, compiled_batch_size=1): + super().__init__() + self.config = original_transformer.config + self.dtype = original_transformer.dtype + self.device = original_transformer.device + self.nxd_model = nxd_model + + # Full RoPE (will be sharded at runtime per CP rank) + self.img_rotary_emb_full = img_rotary_emb + self.txt_rotary_emb_full = txt_rotary_emb + + self.expected_num_patches = expected_num_patches + self.num_patches_padded = num_patches_padded if num_patches_padded else expected_num_patches + self.patches_padding = patches_padding + self.expected_seq_len = expected_seq_len + self.temporal_frames = temporal_frames + self.base_patches = expected_num_patches // temporal_frames + self.cp_degree = cp_degree + self.compiled_batch_size = compiled_batch_size + + # Local dimensions (per CP rank) - use padded value for internal computation + self.local_num_patches = self.num_patches_padded // cp_degree + self.local_seq_len = expected_seq_len // cp_degree + + @contextlib.contextmanager + def cache_context(self, name: str): + """Dummy cache context for compatibility with pipeline.""" + yield + + def forward(self, hidden_states, encoder_hidden_states=None, + timestep=None, img_shapes=None, return_dict=False, **kwargs): + """Forward pass with Context Parallel.""" + import time as _time + import os as _os + _step_t0 = _time.time() + actual_batch_size = hidden_states.shape[0] + + if not hasattr(self, '_step_count'): + self._step_count = 0 + else: + self._step_count += 1 + + _profile_step = int(_os.environ.get("QIE_PROFILE_STEP", "-1")) + if _profile_step >= 0: + _should_profile = (self._step_count == _profile_step + and not getattr(self, "_did_profile", False)) + else: + _should_profile = False + if _should_profile: + import libneuronxla as _lnx + _dump = _os.environ.get("QIE_PROFILE_DIR", "/tmp/qie_ntff") + _os.makedirs(_dump, exist_ok=True) + _lnx.set_global_profiler_dump_to(_dump) + _lnx.start_global_profiler_inspect(_dump) + print(f">>> QIE V3 CP profile START (step {self._step_count}, dump={_dump})", flush=True) + + # Debug: Print shapes on first call (avoid .min()/.max()/.mean() to prevent CPU sync) + if not hasattr(self, '_debug_printed'): + print(f"DEBUG Transformer V3 CP input shapes:") + print(f" hidden_states: {hidden_states.shape}") + print(f" encoder_hidden_states: {encoder_hidden_states.shape}") + print(f" timestep: {timestep.shape}") + print(f" img_rotary_emb_full: {self.img_rotary_emb_full.shape}") + print(f" txt_rotary_emb_full: {self.txt_rotary_emb_full.shape}") + print(f" CP degree: {self.cp_degree}") + print(f" Compiled batch size: {self.compiled_batch_size}") + print(f" Local patches: {self.local_num_patches}, Local seq_len: {self.local_seq_len}") + self._debug_printed = True + + # Handle batch size padding if needed + # If actual batch size < compiled batch size, we need to pad + if actual_batch_size < self.compiled_batch_size: + pad_batch = self.compiled_batch_size - actual_batch_size + # Pad hidden_states + hidden_states = torch.cat([ + hidden_states, + torch.zeros((pad_batch, hidden_states.shape[1], hidden_states.shape[2]), + dtype=hidden_states.dtype, device=hidden_states.device) + ], dim=0) + # Pad encoder_hidden_states + encoder_hidden_states = torch.cat([ + encoder_hidden_states, + torch.zeros((pad_batch, encoder_hidden_states.shape[1], encoder_hidden_states.shape[2]), + dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device) + ], dim=0) + # Pad timestep + timestep = torch.cat([ + timestep, + timestep[-1:].repeat(pad_batch) # Repeat last timestep for padding + ], dim=0) + elif actual_batch_size > self.compiled_batch_size: + raise ValueError( + f"Input batch size ({actual_batch_size}) exceeds compiled batch size ({self.compiled_batch_size}). " + f"Please recompile the model with --batch_size {actual_batch_size} or higher." + ) + + batch_size = hidden_states.shape[0] # Now equals compiled_batch_size + + # For CP, the model expects LOCAL data (already sharded) + # Since we're running inference, we pass full data and let the model handle it + # The compiled model has the gather/scatter logic built in + + # Handle hidden_states padding to expected_num_patches first + actual_patches = hidden_states.shape[1] + if actual_patches != self.expected_num_patches: + if actual_patches < self.expected_num_patches: + pad_size = self.expected_num_patches - actual_patches + padding = torch.zeros( + (batch_size, pad_size, hidden_states.shape[2]), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + hidden_states = torch.cat([hidden_states, padding], dim=1) + else: + print(f"ERROR: hidden_states has {actual_patches} patches but model expects {self.expected_num_patches}") + hidden_states = hidden_states[:, :self.expected_num_patches, :] + + # Apply CP alignment padding if needed (padding goes to patches, not text) + # This ensures CP split results in sequences aligned for NKI Flash Attention. + # Zero-padding is fine here (verified at CP=2 and CP=4 with correct outputs); the + # pad patches' contribution is negligible vs real tokens in the (unmasked) attention. + if self.patches_padding > 0: + cp_padding = torch.zeros( + (batch_size, self.patches_padding, hidden_states.shape[2]), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + hidden_states = torch.cat([hidden_states, cp_padding], dim=1) + + # Handle encoder_hidden_states padding (no CP padding needed here, text_seq stays unchanged) + actual_seq_len = encoder_hidden_states.shape[1] + if actual_seq_len != self.expected_seq_len: + if actual_seq_len < self.expected_seq_len: + pad_size = self.expected_seq_len - actual_seq_len + padding = torch.zeros( + (batch_size, pad_size, encoder_hidden_states.shape[2]), + dtype=encoder_hidden_states.dtype, + device=encoder_hidden_states.device + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) + else: + print(f"WARNING: Truncating encoder_hidden_states from {actual_seq_len} to {self.expected_seq_len}") + encoder_hidden_states = encoder_hidden_states[:, :self.expected_seq_len, :] + + # Convert timestep to float32 + timestep = timestep.to(torch.float32) + + # Run model + # Note: For CP models compiled with ModelBuilder, the sharding is handled internally + # We pass full data and full RoPE - the model handles the rest + output = self.nxd_model( + hidden_states, + encoder_hidden_states, + timestep, + self.img_rotary_emb_full, + self.txt_rotary_emb_full + ) + + # Extract tensor from output + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + + # Extract first frame as noise prediction + output_tensor = output_tensor[:, :self.base_patches, :] + + # Remove batch padding if we added it + if actual_batch_size < self.compiled_batch_size: + output_tensor = output_tensor[:actual_batch_size] + + if _should_profile: + import libneuronxla as _lnx + _lnx.stop_global_profiler_inspect() + self._did_profile = True + print(f">>> QIE V3 CP profile STOP", flush=True) + + _step_dt = _time.time() - _step_t0 + if not hasattr(self, "_step_times"): + self._step_times = [] + self._step_times.append(_step_dt) + print(f" [transformer_step {len(self._step_times)}] {_step_dt*1000:.1f} ms", flush=True) + + if return_dict: + from diffusers.models.modeling_outputs import Transformer2DModelOutput + return Transformer2DModelOutput(sample=output_tensor) + return (output_tensor,) + + +class NeuronTransformerWrapperV3CFG(torch.nn.Module): + """ + Wrapper for V3 CFG (CFG Parallel) compiled transformer. + + Key features: + - Uses TP=4, DP=2 (world_size=8) + - Batches positive + negative prompts (batch_size=2) + - Each DP rank processes one complete batch item (full sequence) + - No K/V all-gather needed + """ + def __init__(self, original_transformer, nxd_model, img_rotary_emb, txt_rotary_emb, + expected_num_patches=1024, num_patches_padded=None, patches_padding=0, + expected_seq_len=512, temporal_frames=3, dp_degree=2): + super().__init__() + self.config = original_transformer.config + self.dtype = original_transformer.dtype + self.device = original_transformer.device + self.nxd_model = nxd_model + + # Full RoPE (same for both batch items, not scattered) + self.img_rotary_emb_full = img_rotary_emb + self.txt_rotary_emb_full = txt_rotary_emb + + self.expected_num_patches = expected_num_patches + self.num_patches_padded = num_patches_padded if num_patches_padded else expected_num_patches + self.patches_padding = patches_padding + self.expected_seq_len = expected_seq_len + self.temporal_frames = temporal_frames + self.base_patches = expected_num_patches // temporal_frames + self.dp_degree = dp_degree + # CFG always uses batch_size=2 (positive + negative) + self.compiled_batch_size = 2 + + @contextlib.contextmanager + def cache_context(self, name: str): + """Dummy cache context for compatibility with pipeline.""" + yield + + def forward(self, hidden_states, encoder_hidden_states=None, + timestep=None, img_shapes=None, return_dict=False, **kwargs): + """Forward pass with CFG Parallel. Expects batch_size=2 input.""" + import time as _time + _step_t0 = _time.time() + batch_size = hidden_states.shape[0] + + # Opt-in per-step profile capture via QIE_PROFILE_STEP env var. + # Writes NTFF to $QIE_PROFILE_DIR (default /tmp/qie_ntff). + import os as _os + _profile_step = int(_os.environ.get("QIE_PROFILE_STEP", "0")) + if _profile_step > 0 and not hasattr(self, "_step_count"): + self._step_count = 0 + if _profile_step > 0: + self._step_count = getattr(self, "_step_count", 0) + 1 + _should_profile = (self._step_count == _profile_step + and not getattr(self, "_did_profile", False)) + else: + _should_profile = False + if _should_profile: + import libneuronxla as _lnx + _dump = _os.environ.get("QIE_PROFILE_DIR", "/tmp/qie_ntff") + _os.makedirs(_dump, exist_ok=True) + _lnx.set_global_profiler_dump_to(_dump) + _lnx.start_global_profiler_inspect(_dump) + print(f">>> QIE profile START (step {self._step_count}, dump={_dump})", flush=True) + + if not hasattr(self, '_debug_printed'): + print(f"DEBUG Transformer V3 CFG input shapes:") + print(f" hidden_states: {hidden_states.shape}") + print(f" encoder_hidden_states: {encoder_hidden_states.shape}") + print(f" timestep: {timestep.shape}") + print(f" img_rotary_emb_full: {self.img_rotary_emb_full.shape}") + print(f" txt_rotary_emb_full: {self.txt_rotary_emb_full.shape}") + print(f" DP degree: {self.dp_degree}") + print(f" Compiled batch size: {self.compiled_batch_size}") + self._debug_printed = True + + if batch_size != self.compiled_batch_size: + raise ValueError( + f"V3 CFG requires batch_size={self.compiled_batch_size} " + f"(negative + positive), got {batch_size}" + ) + + # Pad hidden_states to expected_num_patches + actual_patches = hidden_states.shape[1] + if actual_patches < self.expected_num_patches: + pad_size = self.expected_num_patches - actual_patches + padding = torch.zeros( + (batch_size, pad_size, hidden_states.shape[2]), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + hidden_states = torch.cat([hidden_states, padding], dim=1) + elif actual_patches > self.expected_num_patches: + hidden_states = hidden_states[:, :self.expected_num_patches, :] + + # Apply alignment padding if needed + if self.patches_padding > 0: + cfg_padding = torch.zeros( + (batch_size, self.patches_padding, hidden_states.shape[2]), + dtype=hidden_states.dtype, + device=hidden_states.device + ) + hidden_states = torch.cat([hidden_states, cfg_padding], dim=1) + + # Pad encoder_hidden_states to expected_seq_len + actual_seq_len = encoder_hidden_states.shape[1] + if actual_seq_len < self.expected_seq_len: + pad_size = self.expected_seq_len - actual_seq_len + padding = torch.zeros( + (batch_size, pad_size, encoder_hidden_states.shape[2]), + dtype=encoder_hidden_states.dtype, + device=encoder_hidden_states.device + ) + encoder_hidden_states = torch.cat([encoder_hidden_states, padding], dim=1) + elif actual_seq_len > self.expected_seq_len: + encoder_hidden_states = encoder_hidden_states[:, :self.expected_seq_len, :] + + # Convert timestep to float32 + timestep = timestep.to(torch.float32) + + # Run model - passes full RoPE (same for both batch items) + output = self.nxd_model( + hidden_states, + encoder_hidden_states, + timestep, + self.img_rotary_emb_full, + self.txt_rotary_emb_full + ) + + # Extract tensor from output + if isinstance(output, tuple): + output_tensor = output[0] + else: + output_tensor = output + + # Extract first frame as noise prediction for both batch items + output_tensor = output_tensor[:, :self.base_patches, :] + + if _should_profile: + import libneuronxla as _lnx + _lnx.stop_global_profiler_inspect() + self._did_profile = True + print(f">>> QIE profile STOP", flush=True) + + _step_dt = _time.time() - _step_t0 + if not hasattr(self, "_step_times"): + self._step_times = [] + self._step_times.append(_step_dt) + print(f" [transformer_step {len(self._step_times)}] {_step_dt*1000:.1f} ms") + + if return_dict: + from diffusers.models.modeling_outputs import Transformer2DModelOutput + return Transformer2DModelOutput(sample=output_tensor) + return (output_tensor,) + + +def load_transformer_v3_cfg(compiled_models_dir: str, pipe, args): + """ + Load V3 CFG compiled transformer with CFG Parallelism. + + V3 CFG models use: + - TP=4, DP=2 (world_size=8) + - Batch parallelism for negative + positive prompts + - NKI Flash Attention + + Args: + compiled_models_dir: Directory containing compiled model artifacts + pipe: Pipeline with original transformer (for config) + args: Command line arguments + + Returns: + NeuronTransformerWrapperV3CFG wrapping the loaded model + """ + import json + + if not NXD_MODEL_AVAILABLE: + raise RuntimeError( + "NxDModel is not available. Please ensure neuronx_distributed is installed correctly." + ) + + v3_cfg_path = f"{compiled_models_dir}/transformer_v3_cfg" + nxd_model_path = f"{v3_cfg_path}/nxd_model.pt" + weights_path = f"{v3_cfg_path}/weights" + rope_cache_path = f"{v3_cfg_path}/rope_cache.pt" + config_path = f"{v3_cfg_path}/config.json" + + # Validate files exist + for path, name in [(nxd_model_path, "model"), (weights_path, "weights"), (rope_cache_path, "RoPE cache")]: + if not os.path.exists(path): + raise FileNotFoundError( + f"V3 CFG transformer {name} not found at {path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v3_cfg.py" + ) + + # Load config + print(f" Loading V3 CFG config from {config_path}...") + with open(config_path, "r") as f: + config = json.load(f) + + expected_num_patches = config["num_patches"] + num_patches_padded = config.get("num_patches_padded", expected_num_patches) + patches_padding = config.get("patches_padding", 0) + expected_seq_len = config["text_seq_len"] + temporal_frames = config.get("frame", config.get("patch_multiplier", 3)) + tp_degree = config.get("tp_degree", 4) + world_size = config.get("world_size", 8) + dp_degree = config.get("dp_degree", 2) + compiled_batch_size = config.get("batch_size", 2) + base_patches = expected_num_patches // temporal_frames + + print(f" V3 CFG config: patches={expected_num_patches}, seq_len={expected_seq_len}") + if patches_padding > 0: + print(f" V3 CFG config: patches_padded={num_patches_padded} (+{patches_padding} for alignment)") + print(f" V3 CFG config: temporal_frames={temporal_frames}, base_patches={base_patches}") + print(f" V3 CFG config: TP={tp_degree}, world_size={world_size}, DP={dp_degree}") + print(f" V3 CFG config: batch_size={compiled_batch_size}") + print(f" CFG Parallel: {config.get('cfg_parallel', False)}") + print(f" NKI Flash Attention: {config.get('nki_flash_attention', False)}") + + # Load pre-computed RoPE tensors (full, not sharded) + print(f" Loading RoPE cache from {rope_cache_path}...") + rope_cache = torch.load(rope_cache_path) + img_rotary_emb = rope_cache["img_rotary_emb"].to(torch.bfloat16) + txt_rotary_emb = rope_cache["txt_rotary_emb"].to(torch.bfloat16) + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + # Load the compiled model using NxDModel.load() + print(f" Loading V3 CFG model from {nxd_model_path}...") + nxd_model = NxDModel.load(nxd_model_path) + + # Load sharded checkpoints + # For CFG Parallel: TP=4 but world_size=8 + # Each DP rank uses the same weights as its corresponding TP rank + # Duplicate: [tp0, tp1, tp2, tp3] -> [tp0, tp1, tp2, tp3, tp0, tp1, tp2, tp3] + from safetensors.torch import load_file + print(f" Loading sharded weights for TP={tp_degree}, world_size={world_size}...") + + # First load the TP checkpoints + tp_checkpoints = [] + for rank in range(tp_degree): + ckpt_path = f"{weights_path}/tp{rank}_sharded_checkpoint.safetensors" + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + ckpt = load_file(ckpt_path) + tp_checkpoints.append(ckpt) + if rank == 0: + print(f" Rank 0 checkpoint keys: {len(ckpt)} tensors") + + # Duplicate checkpoints for each DP rank with unique global_rank values + sharded_checkpoints = [] + for dp_rank in range(dp_degree): + for tp_rank in range(tp_degree): + world_rank = dp_rank * tp_degree + tp_rank + ckpt_copy = {k: v.clone() for k, v in tp_checkpoints[tp_rank].items()} + + # Set the correct global_rank for SPMD scatter/gather + global_rank_key = 'transformer.global_rank.rank' + if global_rank_key in ckpt_copy: + ckpt_copy[global_rank_key] = torch.tensor([world_rank], dtype=torch.int32) + if world_rank < 2 or world_rank >= world_size - 2: + print(f" World rank {world_rank}: global_rank set to {world_rank}") + + sharded_checkpoints.append(ckpt_copy) + + print(f" Total checkpoints: {len(sharded_checkpoints)} (TP={tp_degree} x DP={dp_degree})") + print(f" Each world rank has unique global_rank for SPMD execution") + + # Initialize model with weights and move to Neuron + print(" Setting weights...") + for i in [0, 4]: # Check first rank of each DP group + if i < len(sharded_checkpoints): + ckpt = sharded_checkpoints[i] + gr_key = 'transformer.global_rank.rank' + if gr_key in ckpt: + print(f" Checkpoint[{i}] global_rank = {ckpt[gr_key].item()}") + nxd_model.set_weights(sharded_checkpoints) + print(" Moving model to Neuron...") + nxd_model.to_neuron() + print(" V3 CFG model initialized on Neuron!") + + # Create wrapper + wrapper = NeuronTransformerWrapperV3CFG( + original_transformer=pipe.transformer, + nxd_model=nxd_model, + img_rotary_emb=img_rotary_emb, + txt_rotary_emb=txt_rotary_emb, + expected_num_patches=expected_num_patches, + num_patches_padded=num_patches_padded, + patches_padding=patches_padding, + expected_seq_len=expected_seq_len, + temporal_frames=temporal_frames, + dp_degree=dp_degree, + ) + + return wrapper + + +def patch_pipeline_for_cfg_parallel(pipe): + """ + Monkey-patch the pipeline's denoising loop for batched CFG inference. + + Instead of two sequential transformer calls (positive + negative), + this batches both into a single call with batch_size=2. + The V3 CFG transformer scatters along batch dim across DP ranks. + """ + import types + import numpy as np + from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import ( + calculate_dimensions, + calculate_shift, + retrieve_timesteps, + QwenImagePipelineOutput, + CONDITION_IMAGE_SIZE, + VAE_IMAGE_SIZE, + logger, + ) + try: + from diffusers.utils import XLA_AVAILABLE + except ImportError: + XLA_AVAILABLE = False + if XLA_AVAILABLE: + import torch_xla.core.xla_model as xm + + def __call__( + self, + image=None, + prompt=None, + negative_prompt=None, + true_cfg_scale: float = 4.0, + height=None, + width=None, + num_inference_steps: int = 50, + sigmas=None, + guidance_scale=None, + num_images_per_prompt: int = 1, + generator=None, + latents=None, + prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds=None, + negative_prompt_embeds_mask=None, + output_type="pil", + return_dict=True, + attention_kwargs=None, + callback_on_step_end=None, + callback_on_step_end_tensor_inputs=["latents"], + max_sequence_length: int = 512, + ): + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs + self.check_inputs( + prompt, height, width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + import time as _time + _qie_stage_times = getattr(self, "_qie_stage_times", {}) + _te_t0 = _time.time() + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + _qie_stage_times["text_encoder"] = _time.time() - _te_t0 + print(f" [text_encoder] {_qie_stage_times['text_encoder']*1000:.1f} ms", flush=True) + self._qie_stage_times = _qie_stage_times + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + _ve_t0 = _time.time() + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + _qie_stage_times["vae_encode"] = _time.time() - _ve_t0 + print(f" [vae_encode] {_qie_stage_times['vae_encode']*1000:.1f} ms (encodes {len(vae_images)} input image(s))", flush=True) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, sigmas=sigmas, mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + # ===== CFG PARALLEL: Pre-concatenate embeddings ===== + if do_true_cfg: + # Pad negative_prompt_embeds to match prompt_embeds length if needed + neg_seq = negative_prompt_embeds.shape[1] + pos_seq = prompt_embeds.shape[1] + if neg_seq < pos_seq: + pad = torch.zeros( + (negative_prompt_embeds.shape[0], pos_seq - neg_seq, negative_prompt_embeds.shape[2]), + dtype=negative_prompt_embeds.dtype, device=negative_prompt_embeds.device + ) + negative_prompt_embeds = torch.cat([negative_prompt_embeds, pad], dim=1) + elif pos_seq < neg_seq: + pad = torch.zeros( + (prompt_embeds.shape[0], neg_seq - pos_seq, prompt_embeds.shape[2]), + dtype=prompt_embeds.dtype, device=prompt_embeds.device + ) + prompt_embeds = torch.cat([prompt_embeds, pad], dim=1) + # [negative, positive] along batch dim -> [2, seq, C] + batched_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + latent_model_input = latents + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + + if do_true_cfg: + # If the wrapped transformer was compiled with batch=1 + # (e.g. V3 CP NEFF), run CFG sequentially as two batch=1 + # forwards. Otherwise (V3 CFG NEFF), batch neg+pos into a + # single batch=2 forward and split. + cb = getattr(self.transformer, "compiled_batch_size", 2) + if cb < 2: + # Sequential CFG: two batch=1 calls + ts1 = t.expand(latents.shape[0]).to(latents.dtype) / 1000 + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=ts1, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + return_dict=False, + )[0][:, :latents.size(1)] + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=ts1, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + return_dict=False, + )[0][:, :latents.size(1)] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + # ===== CFG PARALLEL: Single batched call ===== + # Duplicate latents for both negative and positive: [2, patches, C] + batched_latent = torch.cat([latent_model_input, latent_model_input], dim=0) + batched_timestep = t.expand(2).to(latents.dtype) / 1000 + + batched_output = self.transformer( + hidden_states=batched_latent, + timestep=batched_timestep, + encoder_hidden_states=batched_embeds, + img_shapes=img_shapes, + return_dict=False, + )[0] + + # Split: index 0 is negative, index 1 is positive + noise_pred = batched_output[1:2, :latents.size(1)] # positive + neg_noise_pred = batched_output[0:1, :latents.size(1)] # negative + + # Apply CFG with norm rescale (Qwen-specific) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + else: + # No CFG - single call + timestep_input = t.expand(latents.shape[0]).to(latents.dtype) + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep_input / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + img_shapes=img_shapes, + return_dict=False, + )[0] + noise_pred = noise_pred[:, :latents.size(1)] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + import time as _time + _vae_t0 = _time.time() + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + _vae_dt = _time.time() - _vae_t0 + print(f" [vae_decode] {_vae_dt*1000:.1f} ms", flush=True) + if hasattr(self, "_qie_stage_times"): + self._qie_stage_times["vae_decode"] = _vae_dt + image = self.image_processor.postprocess(image, output_type=output_type) + + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) + + pipe.__class__.__call__ = __call__ + print(" Pipeline patched for CFG Parallel (batched denoising loop)") + + +def load_transformer_v3_cp(compiled_models_dir: str, pipe, args): + """ + Load V3 CP compiled transformer with Context Parallel. + + V3 CP models use: + - TP=4, CP=2 (world_size=8) + - K/V all-gather across CP group + - NKI Flash Attention + + Args: + compiled_models_dir: Directory containing compiled model artifacts + pipe: Pipeline with original transformer (for config) + args: Command line arguments + + Returns: + NeuronTransformerWrapperV3CP wrapping the loaded model + """ + import json + + if not NXD_MODEL_AVAILABLE: + raise RuntimeError( + "NxDModel is not available. Please ensure neuronx_distributed is installed correctly." + ) + + v3_cp_path = f"{compiled_models_dir}/transformer_v3_cp" + nxd_model_path = f"{v3_cp_path}/nxd_model.pt" + weights_path = f"{v3_cp_path}/weights" + rope_cache_path = f"{v3_cp_path}/rope_cache.pt" + config_path = f"{v3_cp_path}/config.json" + + # Validate files exist + if not os.path.exists(nxd_model_path): + raise FileNotFoundError( + f"V3 CP transformer model not found at {nxd_model_path}\n" + "Please run: ./compile.sh v3_cp" + ) + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"V3 CP transformer weights not found at {weights_path}\n" + "Please run: ./compile.sh v3_cp" + ) + if not os.path.exists(rope_cache_path): + raise FileNotFoundError( + f"V3 CP RoPE cache not found at {rope_cache_path}\n" + "Please run: ./compile.sh v3_cp" + ) + + # Load config + print(f" Loading V3 CP config from {config_path}...") + with open(config_path, "r") as f: + config = json.load(f) + + expected_num_patches = config["num_patches"] + num_patches_padded = config.get("num_patches_padded", expected_num_patches) + patches_padding = config.get("patches_padding", 0) + expected_seq_len = config["text_seq_len"] + temporal_frames = config.get("frame", config.get("patch_multiplier", 3)) + tp_degree = config.get("tp_degree", 4) + world_size = config.get("world_size", 8) + cp_degree = config.get("cp_degree", 2) + compiled_batch_size = config.get("batch_size", 1) + base_patches = expected_num_patches // temporal_frames + + print(f" V3 CP config: patches={expected_num_patches}, seq_len={expected_seq_len}") + if patches_padding > 0: + print(f" V3 CP config: patches_padded={num_patches_padded} (+{patches_padding} for CP alignment)") + print(f" V3 CP config: temporal_frames={temporal_frames}, base_patches={base_patches}") + print(f" V3 CP config: TP={tp_degree}, world_size={world_size}, CP={cp_degree}") + print(f" V3 CP config: batch_size={compiled_batch_size}") + print(f" Context Parallel: {config.get('context_parallel', False)}") + print(f" NKI Flash Attention: {config.get('nki_flash_attention', False)}") + + # Load pre-computed RoPE tensors (full, not sharded) + print(f" Loading RoPE cache from {rope_cache_path}...") + rope_cache = torch.load(rope_cache_path) + img_rotary_emb = rope_cache["img_rotary_emb"].to(torch.bfloat16) + txt_rotary_emb = rope_cache["txt_rotary_emb"].to(torch.bfloat16) + print(f" img_rotary_emb: {img_rotary_emb.shape}") + print(f" txt_rotary_emb: {txt_rotary_emb.shape}") + + # Load the compiled model using NxDModel.load() + print(f" Loading V3 CP model from {nxd_model_path}...") + nxd_model = NxDModel.load(nxd_model_path) + + # Load sharded checkpoints + # For Context Parallel: TP=4 but world_size=8 + # Each DP rank (CP rank) uses the same weights as its corresponding TP rank + # So we need to duplicate: [tp0, tp1, tp2, tp3] -> [tp0, tp1, tp2, tp3, tp0, tp1, tp2, tp3] + from safetensors.torch import load_file + print(f" Loading sharded weights for TP={tp_degree}, world_size={world_size}...") + + # First load the TP checkpoints + tp_checkpoints = [] + for rank in range(tp_degree): + ckpt_path = f"{weights_path}/tp{rank}_sharded_checkpoint.safetensors" + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + ckpt = load_file(ckpt_path) + tp_checkpoints.append(ckpt) + if rank == 0: + print(f" Rank 0 checkpoint keys: {len(ckpt)} tensors") + + # For CP, duplicate checkpoints for each DP rank + # world_size = tp_degree * dp_degree (dp_degree = cp_degree) + # IMPORTANT: Each world rank needs a unique global_rank value for SPMD scatter/gather + sharded_checkpoints = [] + for dp_rank in range(cp_degree): + for tp_rank in range(tp_degree): + # Clone the checkpoint so we can modify global_rank independently + world_rank = dp_rank * tp_degree + tp_rank + ckpt_copy = {k: v.clone() for k, v in tp_checkpoints[tp_rank].items()} + + # Set the correct global_rank for this world rank + # This is CRITICAL for SPMDRank to return the correct rank at runtime + global_rank_key = 'transformer.global_rank.rank' + if global_rank_key in ckpt_copy: + ckpt_copy[global_rank_key] = torch.tensor([world_rank], dtype=torch.int32) + if world_rank < 2 or world_rank >= world_size - 2: + print(f" World rank {world_rank}: global_rank set to {world_rank}") + + sharded_checkpoints.append(ckpt_copy) + + print(f" Total checkpoints: {len(sharded_checkpoints)} (TP={tp_degree} x CP={cp_degree})") + print(f" Each world rank has unique global_rank for SPMD execution") + + # Initialize model with weights and move to Neuron + print(" Setting weights...") + # Debug: Verify global_rank values in checkpoints + for i in [0, 4]: # Check first rank of each DP group + if i < len(sharded_checkpoints): + ckpt = sharded_checkpoints[i] + gr_key = 'transformer.global_rank.rank' + if gr_key in ckpt: + print(f" Checkpoint[{i}] global_rank = {ckpt[gr_key].item()}") + else: + print(f" WARNING: Checkpoint[{i}] missing {gr_key}") + nxd_model.set_weights(sharded_checkpoints) + print(" Moving model to Neuron...") + nxd_model.to_neuron() + print(" V3 CP model initialized on Neuron!") + + # Create wrapper + wrapper = NeuronTransformerWrapperV3CP( + original_transformer=pipe.transformer, + nxd_model=nxd_model, + img_rotary_emb=img_rotary_emb, + txt_rotary_emb=txt_rotary_emb, + expected_num_patches=expected_num_patches, + num_patches_padded=num_patches_padded, + patches_padding=patches_padding, + expected_seq_len=expected_seq_len, + temporal_frames=temporal_frames, + cp_degree=cp_degree, + compiled_batch_size=compiled_batch_size, + ) + + return wrapper + + +def load_language_model_v3(compiled_models_dir: str): + """ + Load V3 compiled language model using NxDModel. + + V3 language models use: + - TP=4, world_size=8 (matching V3 CP transformer) + - ModelBuilder API (NxDModel) + + Note: Unlike V3 CP transformer which splits sequence (Context Parallel), + the language model processes the full sequence on all ranks. + Checkpoints are simply duplicated for world_size=8. + + Returns: + NxDModel wrapping the loaded language model + """ + import json + + if not NXD_MODEL_AVAILABLE: + raise RuntimeError( + "NxDModel is not available. Please ensure neuronx_distributed is installed correctly." + ) + + v3_path = f"{compiled_models_dir}/language_model_v3" + nxd_model_path = f"{v3_path}/nxd_model.pt" + weights_path = f"{v3_path}/weights" + config_path = f"{v3_path}/config.json" + + # Validate files exist + if not os.path.exists(nxd_model_path): + raise FileNotFoundError( + f"V3 language model not found at {nxd_model_path}\n" + "Please run: python neuron_qwen_image_edit/compile_language_model_v3.py" + ) + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"V3 language model weights not found at {weights_path}\n" + "Please run: python neuron_qwen_image_edit/compile_language_model_v3.py" + ) + + # Load config + print(f" Loading V3 language model config from {config_path}...") + with open(config_path, "r") as f: + config = json.load(f) + + tp_degree = config.get("tp_degree", 4) + world_size = config.get("world_size", 8) + max_seq_len = config.get("max_sequence_length", 1024) + batch_size = config.get("batch_size", 1) + cp_degree = world_size // tp_degree # 2 + + print(f" V3 language model config:") + print(f" TP={tp_degree}, world_size={world_size}, batch_size={batch_size}") + print(f" max_sequence_length={max_seq_len}") + print(f" GQA: 28Q/4=7 heads/rank, 4KV/4=1 head/rank (perfect fit)") + + # Load the compiled model using NxDModel.load() + print(f" Loading V3 language model from {nxd_model_path}...") + nxd_model = NxDModel.load(nxd_model_path) + + # Load sharded checkpoints + # For world_size=8 with TP=4: duplicate TP checkpoints for each CP rank + from safetensors.torch import load_file + print(f" Loading sharded weights for TP={tp_degree}, world_size={world_size}...") + + # First load the TP checkpoints (only tp_degree files exist) + tp_checkpoints = [] + for rank in range(tp_degree): + ckpt_path = f"{weights_path}/tp{rank}_sharded_checkpoint.safetensors" + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + ckpt = load_file(ckpt_path) + tp_checkpoints.append(ckpt) + if rank == 0: + print(f" Rank 0 checkpoint keys: {len(ckpt)} tensors") + + # Duplicate for world_size=8 + # Unlike transformer CP which needs different global_rank values, + # language model processes full sequence on all ranks (no CP scatter/gather) + # So we simply duplicate the TP checkpoints + sharded_checkpoints = [] + for cp_rank in range(cp_degree): + for tp_rank in range(tp_degree): + # Clone the checkpoint + ckpt_copy = {k: v.clone() for k, v in tp_checkpoints[tp_rank].items()} + sharded_checkpoints.append(ckpt_copy) + + print(f" Total checkpoints: {len(sharded_checkpoints)} (TP={tp_degree} x CP={cp_degree})") + + # Initialize model with weights and move to Neuron + print(" Setting weights...") + nxd_model.set_weights(sharded_checkpoints) + print(" Moving model to Neuron...") + nxd_model.to_neuron() + print(" V3 language model initialized on Neuron!") + + return nxd_model, config + + +def load_vision_encoder_v3(compiled_models_dir: str): + """ + Load V3 compiled vision encoder using NxDModel. + + V3 vision encoder uses: + - TP=4, world_size=8 (matching V3 CP transformer) + - ModelBuilder API (NxDModel) + - Float32 precision for accuracy + + Note: Vision encoder dimensions require TP=4: + - QKV dim = 3420, 3420/4=855 (divisible) + - 3420/8=427.5 (NOT divisible, TP=8 doesn't work) + + Returns: + NxDModel wrapping the loaded vision encoder, config dict + """ + import json + + if not NXD_MODEL_AVAILABLE: + raise RuntimeError( + "NxDModel is not available. Please ensure neuronx_distributed is installed correctly." + ) + + v3_path = f"{compiled_models_dir}/vision_encoder_v3" + nxd_model_path = f"{v3_path}/nxd_model.pt" + weights_path = f"{v3_path}/weights" + config_path = f"{v3_path}/config.json" + + # Validate files exist + if not os.path.exists(nxd_model_path): + raise FileNotFoundError( + f"V3 vision encoder not found at {nxd_model_path}\n" + "Please run: python neuron_qwen_image_edit/compile_vision_encoder_v3.py" + ) + if not os.path.exists(weights_path): + raise FileNotFoundError( + f"V3 vision encoder weights not found at {weights_path}\n" + "Please run: python neuron_qwen_image_edit/compile_vision_encoder_v3.py" + ) + + # Load config + print(f" Loading V3 vision encoder config from {config_path}...") + with open(config_path, "r") as f: + config = json.load(f) + + tp_degree = config.get("tp_degree", 4) + world_size = config.get("world_size", 8) + image_size = config.get("image_size", 448) + cp_degree = world_size // tp_degree # 2 + + print(f" V3 vision encoder config:") + print(f" TP={tp_degree}, world_size={world_size}") + print(f" image_size={image_size}") + print(f" dtype=float32 (required for accuracy)") + + # Load the compiled model using NxDModel.load() + print(f" Loading V3 vision encoder from {nxd_model_path}...") + nxd_model = NxDModel.load(nxd_model_path) + + # Load sharded checkpoints + # For world_size=8 with TP=4: duplicate TP checkpoints for each CP rank + from safetensors.torch import load_file + print(f" Loading sharded weights for TP={tp_degree}, world_size={world_size}...") + + # First load the TP checkpoints (only tp_degree files exist) + tp_checkpoints = [] + for rank in range(tp_degree): + ckpt_path = f"{weights_path}/tp{rank}_sharded_checkpoint.safetensors" + if not os.path.exists(ckpt_path): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + ckpt = load_file(ckpt_path) + tp_checkpoints.append(ckpt) + if rank == 0: + print(f" Rank 0 checkpoint keys: {len(ckpt)} tensors") + + # Duplicate for world_size=8 + # Vision encoder processes fixed-size patches on all ranks (no CP scatter/gather) + # So we simply duplicate the TP checkpoints + sharded_checkpoints = [] + for cp_rank in range(cp_degree): + for tp_rank in range(tp_degree): + # Clone the checkpoint + ckpt_copy = {k: v.clone() for k, v in tp_checkpoints[tp_rank].items()} + sharded_checkpoints.append(ckpt_copy) + + print(f" Total checkpoints: {len(sharded_checkpoints)} (TP={tp_degree} x CP={cp_degree})") + + # Initialize model with weights and move to Neuron + print(" Setting weights...") + nxd_model.set_weights(sharded_checkpoints) + print(" Moving model to Neuron...") + nxd_model.to_neuron() + print(" V3 vision encoder initialized on Neuron (TP=4, float32)!") + + return nxd_model, config + + +class NeuronVAEWrapper(torch.nn.Module): + """ + Wrapper for VAE with compiled encoder and decoder on Trainium2. + + Supports tiled processing for images larger than the compiled tile size. + """ + def __init__(self, original_vae, compiled_encoder, compiled_decoder, + compiled_quant_conv=None, compiled_post_quant_conv=None, + expected_height=512, expected_width=512, + compiled_batch_size=1, cpu_decode=False): + super().__init__() + self.config = original_vae.config + self.dtype = original_vae.dtype + + # Compiled models - ALL run on Neuron + self.compiled_encoder = compiled_encoder + self.compiled_decoder = compiled_decoder + self.compiled_quant_conv = compiled_quant_conv + self.compiled_post_quant_conv = compiled_post_quant_conv + + # Batch size the VAE was compiled with (for batched encode/decode) + self.compiled_batch_size = compiled_batch_size + + # CPU decode mode for debugging + self.cpu_decode = cpu_decode + if cpu_decode: + print(" [DEBUG] VAE Decoder will run on CPU!") + # Keep CPU decoder and post_quant_conv + self.cpu_decoder = original_vae.decoder + self.cpu_post_quant_conv = original_vae.post_quant_conv + self.cpu_decoder.eval() + + # Scaling factors - convert to tensors for broadcasting + # Shape: (1, z_dim, 1, 1, 1) for proper broadcasting with 5D latents (b, c, t, h, w) + if isinstance(original_vae.latents_mean, list): + self.latents_mean = torch.tensor(original_vae.latents_mean).view(1, -1, 1, 1, 1) + else: + self.latents_mean = original_vae.latents_mean + if isinstance(original_vae.latents_std, list): + self.latents_std = torch.tensor(original_vae.latents_std).view(1, -1, 1, 1, 1) + else: + self.latents_std = original_vae.latents_std + + # z_dim for shape calculations + self.z_dim = original_vae.config.z_dim + + # Expected input size for compiled model (tile size) + self.expected_height = expected_height + self.expected_width = expected_width + + # Tiling parameters for larger images + self.tile_sample_min_height = expected_height + self.tile_sample_min_width = expected_width + # Overlap between tiles (for blending) + self.tile_overlap = 64 # pixels of overlap + self.tile_sample_stride_height = expected_height - self.tile_overlap + self.tile_sample_stride_width = expected_width - self.tile_overlap + # Spatial compression ratio (8x for this VAE) + self.spatial_compression_ratio = 8 + + def _needs_tiling(self, h, w): + """Check if image needs tiled processing.""" + return h > self.expected_height or w > self.expected_width + + def _blend_v(self, a, b, blend_extent): + """Blend two tensors vertically.""" + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def _blend_h(self, a, b, blend_extent): + """Blend two tensors horizontally.""" + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _encode_tile(self, x): + """Encode a single tile through compiled encoder.""" + actual_batch = x.shape[0] + + # Pad batch dimension if needed + if actual_batch < self.compiled_batch_size: + pad_batch = self.compiled_batch_size - actual_batch + x = torch.cat([x, torch.zeros_like(x[:1]).repeat(pad_batch, 1, 1, 1, 1)], dim=0) + + h = self.compiled_encoder(x) + if self.compiled_quant_conv is not None: + moments = self.compiled_quant_conv(h) + else: + moments = h + + # Remove batch padding + if actual_batch < self.compiled_batch_size: + moments = moments[:actual_batch] + + return moments + + def _decode_tile(self, z): + """Decode a single tile through compiled decoder.""" + actual_batch = z.shape[0] + + # Pad batch dimension if needed + if actual_batch < self.compiled_batch_size: + pad_batch = self.compiled_batch_size - actual_batch + z = torch.cat([z, torch.zeros_like(z[:1]).repeat(pad_batch, 1, 1, 1, 1)], dim=0) + + if self.compiled_post_quant_conv is not None: + z = self.compiled_post_quant_conv(z) + output = self.compiled_decoder(z) + + # Remove batch padding + if actual_batch < self.compiled_batch_size: + output = output[:actual_batch] + + return output + + def encode(self, x, return_dict=True): + """Encode images to latents on Neuron. Supports tiled encoding for large images.""" + # Ensure 5D format: (batch, channels, temporal, height, width) + if len(x.shape) == 4: + x = x.unsqueeze(2) # Add temporal dimension + + b, c, t, h, w = x.shape + + # Convert to bfloat16 (compiled models expect bfloat16) + x = x.to(torch.bfloat16) + + # Check if tiling is needed + if self._needs_tiling(h, w): + print(f" Using tiled encoding: {h}x{w} -> tiles of {self.expected_height}x{self.expected_width}") + moments = self._tiled_encode(x) + else: + # Pad to expected size if smaller + if h != self.expected_height or w != self.expected_width: + # Pad with zeros + pad_h = self.expected_height - h + pad_w = self.expected_width - w + x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h)) + + moments = self._encode_tile(x) + + # Remove padding from latents if we padded + if h != self.expected_height or w != self.expected_width: + latent_h = h // self.spatial_compression_ratio + latent_w = w // self.spatial_compression_ratio + moments = moments[:, :, :, :latent_h, :latent_w] + + # Split into mean and logvar + mean, logvar = moments.chunk(2, dim=1) + + # Sample from distribution (for sample() method) + std = torch.exp(0.5 * logvar) + sample = mean + std * torch.randn_like(std) + + if return_dict: + class LatentDist: + def __init__(self, sample_val, mean_val): + self._sample = sample_val + self._mean = mean_val + def sample(self): + return self._sample + def mode(self): + return self._mean + @property + def mean(self): + return self._mean + + class EncoderOutput: + def __init__(self, latent_dist): + self.latent_dist = latent_dist + + return EncoderOutput(LatentDist(sample, mean)) + return sample + + def _tiled_encode(self, x): + """Encode large image using tiled processing. + + When the compiled NEFF has batch_size > 1, tiles are gathered into a + single batched call instead of one NEFF launch per tile. For an input + that produces N tiles and a compiled batch size B: + - if N <= B: 1 NEFF call with batch padding + - if N > B: ceil(N/B) NEFF calls with B tiles each + Each NEFF launch pays a ~37 ms overhead, so collapsing 6 launches into + 1 saves ~5×37 ms ≈ 185 ms per input image. + """ + b, c, t, h, w = x.shape + + # Latent dimensions + latent_h = h // self.spatial_compression_ratio + latent_w = w // self.spatial_compression_ratio + tile_latent_h = self.expected_height // self.spatial_compression_ratio + tile_latent_w = self.expected_width // self.spatial_compression_ratio + tile_latent_stride_h = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_w = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_h = tile_latent_h - tile_latent_stride_h + blend_w = tile_latent_w - tile_latent_stride_w + + # Phase 1: enumerate all tiles, padding edge tiles to expected size. + tile_inputs = [] + tile_meta = [] # (row_idx, col_idx, actual_h, actual_w) + row_idx = 0 + for i in range(0, h, self.tile_sample_stride_height): + col_idx = 0 + for j in range(0, w, self.tile_sample_stride_width): + tile_h_end = min(i + self.tile_sample_min_height, h) + tile_w_end = min(j + self.tile_sample_min_width, w) + tile = x[:, :, :, i:tile_h_end, j:tile_w_end] + actual_h, actual_w = tile.shape[3], tile.shape[4] + if actual_h < self.expected_height or actual_w < self.expected_width: + pad_h = self.expected_height - actual_h + pad_w = self.expected_width - actual_w + tile = torch.nn.functional.pad(tile, (0, pad_w, 0, pad_h)) + tile_inputs.append(tile) + tile_meta.append((row_idx, col_idx, actual_h, actual_w)) + col_idx += 1 + row_idx += 1 + n_rows = row_idx + n_cols = max(m[1] for m in tile_meta) + 1 + + # Phase 2: run encoder in chunks of compiled_batch_size. + # Each tile already has batch dim 1; cat along dim 0 to form the batch. + encoded_flat = [] + bs = self.compiled_batch_size + for chunk_start in range(0, len(tile_inputs), bs): + chunk = tile_inputs[chunk_start:chunk_start + bs] + batched = torch.cat(chunk, dim=0) # (n_chunk, c, t, eh, ew) + encoded = self._encode_tile(batched) # NEFF call; pads to bs internally + # _encode_tile already strips batch padding; encoded has n_chunk samples + for k in range(len(chunk)): + encoded_flat.append(encoded[k:k+1]) + + # Phase 3: scatter into 2D grid and crop back to non-padded sizes. + rows = [[None] * n_cols for _ in range(n_rows)] + for (ri, ci, actual_h, actual_w), enc in zip(tile_meta, encoded_flat): + if actual_h < self.expected_height or actual_w < self.expected_width: + crop_h = actual_h // self.spatial_compression_ratio + crop_w = actual_w // self.spatial_compression_ratio + enc = enc[:, :, :, :crop_h, :crop_w] + rows[ri][ci] = enc + + # Blend tiles together (unchanged from per-tile path) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self._blend_v(rows[i - 1][j], tile, blend_h) + if j > 0: + tile = self._blend_h(row[j - 1], tile, blend_w) + result_row.append(tile[:, :, :, :tile_latent_stride_h, :tile_latent_stride_w]) + result_rows.append(torch.cat(result_row, dim=-1)) + + return torch.cat(result_rows, dim=3)[:, :, :, :latent_h, :latent_w] + + def decode(self, z, return_dict=True): + """Decode latents to images on Neuron. Supports tiled decoding for large latents.""" + # NOTE: Do NOT unscale latents here! + # The pipeline already unscales latents before calling decode + + # Ensure 5D format + if len(z.shape) == 4: + z = z.unsqueeze(2) + + b, c, t, latent_h, latent_w = z.shape + + # Convert to bfloat16 + z = z.to(torch.bfloat16) + + # Calculate output image size + output_h = latent_h * self.spatial_compression_ratio + output_w = latent_w * self.spatial_compression_ratio + + if self.cpu_decode: + # CPU decode mode for debugging + z_cpu = z.to(torch.float32) + with torch.no_grad(): + z_cpu = self.cpu_post_quant_conv(z_cpu) + dec = self.cpu_decoder(z_cpu) + dec = dec.to(torch.bfloat16) + elif self._needs_tiling(output_h, output_w): + print(f" Using tiled decoding: latent {latent_h}x{latent_w} -> image {output_h}x{output_w}") + dec = self._tiled_decode(z) + else: + # Check if latent needs padding to match compiled size + expected_latent_h = self.expected_height // self.spatial_compression_ratio + expected_latent_w = self.expected_width // self.spatial_compression_ratio + + if latent_h != expected_latent_h or latent_w != expected_latent_w: + # Pad latents + pad_h = expected_latent_h - latent_h + pad_w = expected_latent_w - latent_w + z = torch.nn.functional.pad(z, (0, pad_w, 0, pad_h)) + + dec = self._decode_tile(z) + + # Crop output if we padded + if latent_h != expected_latent_h or latent_w != expected_latent_w: + dec = dec[:, :, :, :output_h, :output_w] + + if return_dict: + from diffusers.models.autoencoders.vae import DecoderOutput + return DecoderOutput(sample=dec) + return (dec,) + + def _tiled_decode(self, z): + """Decode large latents using tiled processing. + + When the compiled NEFF has batch_size > 1, all tiles are gathered into + chunks of compiled_batch_size and the decoder NEFF is invoked once per + chunk instead of once per tile (mirrors the _tiled_encode path). + """ + b, c, t, latent_h, latent_w = z.shape + + # Calculate dimensions + output_h = latent_h * self.spatial_compression_ratio + output_w = latent_w * self.spatial_compression_ratio + + tile_latent_h = self.expected_height // self.spatial_compression_ratio + tile_latent_w = self.expected_width // self.spatial_compression_ratio + tile_latent_stride_h = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_w = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_h = self.tile_sample_min_height - self.tile_sample_stride_height + blend_w = self.tile_sample_min_width - self.tile_sample_stride_width + + # Phase 1: enumerate all tiles, padding edge tiles to expected size. + tile_inputs = [] + tile_meta = [] + row_idx = 0 + for i in range(0, latent_h, tile_latent_stride_h): + col_idx = 0 + for j in range(0, latent_w, tile_latent_stride_w): + tile_h_end = min(i + tile_latent_h, latent_h) + tile_w_end = min(j + tile_latent_w, latent_w) + tile = z[:, :, :, i:tile_h_end, j:tile_w_end] + actual_h, actual_w = tile.shape[3], tile.shape[4] + if actual_h < tile_latent_h or actual_w < tile_latent_w: + pad_h = tile_latent_h - actual_h + pad_w = tile_latent_w - actual_w + tile = torch.nn.functional.pad(tile, (0, pad_w, 0, pad_h)) + tile_inputs.append(tile) + tile_meta.append((row_idx, col_idx, actual_h, actual_w)) + col_idx += 1 + row_idx += 1 + n_rows = row_idx + n_cols = max(m[1] for m in tile_meta) + 1 + + # Phase 2: run decoder in chunks of compiled_batch_size. + decoded_flat = [] + bs = self.compiled_batch_size + for chunk_start in range(0, len(tile_inputs), bs): + chunk = tile_inputs[chunk_start:chunk_start + bs] + batched = torch.cat(chunk, dim=0) + decoded = self._decode_tile(batched) + for k in range(len(chunk)): + decoded_flat.append(decoded[k:k+1]) + + # Phase 3: scatter into 2D grid and crop padded tiles. + rows = [[None] * n_cols for _ in range(n_rows)] + for (ri, ci, actual_h, actual_w), dec in zip(tile_meta, decoded_flat): + if actual_h < tile_latent_h or actual_w < tile_latent_w: + crop_h = actual_h * self.spatial_compression_ratio + crop_w = actual_w * self.spatial_compression_ratio + dec = dec[:, :, :, :crop_h, :crop_w] + rows[ri][ci] = dec + + # Blend tiles together + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self._blend_v(rows[i - 1][j], tile, blend_h) + if j > 0: + tile = self._blend_h(row[j - 1], tile, blend_w) + result_row.append(tile[:, :, :, :self.tile_sample_stride_height, :self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + return torch.cat(result_rows, dim=3)[:, :, :, :output_h, :output_w] + + +def load_all_compiled_models(compiled_models_dir: str, pipe, args): + """ + Load ALL compiled models for Trainium2 inference. + Every component MUST be compiled and loaded. + + Parallel configuration: + - VAE: DataParallel (DP=8) - single-device compiled, replicated across 8 devices + - Transformer: Tensor Parallel (TP=8) - sharded across 8 devices + - Vision Encoder: Single device OR TP=8 (use --vision_tp flag for TP mode) + - Language Model: Tensor Parallel (TP=8) - sharded with KV head replication + + IMPORTANT: This function replaces original models with compiled versions + and explicitly deletes the originals to free memory. + + Args: + compiled_models_dir: Directory containing compiled model artifacts + pipe: Original pipeline + args: Command line arguments + + Returns: + Updated pipeline with ALL Neuron-compiled models + """ + import gc + + # Check for vision encoder mode + # CPU is the default for better accuracy, use --neuron_vision_encoder or --use_v3_vision_encoder to use Neuron + vision_encoder_tp_path = f"{compiled_models_dir}/vision_encoder_tp" + vision_encoder_v3_path = f"{compiled_models_dir}/vision_encoder_v3/nxd_model.pt" + use_vision_tp = args.vision_tp if hasattr(args, 'vision_tp') else False + use_neuron_vision = getattr(args, 'neuron_vision_encoder', False) # Default to CPU + use_v3_vision_encoder = getattr(args, 'use_v3_vision_encoder', True) + # --use_v3_vision_encoder implies using Neuron (not CPU) + use_cpu_vision_encoder = not use_neuron_vision and not use_v3_vision_encoder + if use_v3_vision_encoder or (use_neuron_vision and os.path.exists(vision_encoder_v3_path)): + vision_mode = "Neuron V3 (TP=4, float32)" + use_v3_vision_encoder = True # Enable V3 if path exists and neuron_vision is requested + use_cpu_vision_encoder = False + elif use_cpu_vision_encoder: + vision_mode = "CPU (default)" + elif use_vision_tp or os.path.exists(vision_encoder_tp_path): + vision_mode = "Neuron TP=8" + else: + vision_mode = "Neuron (single device, float32)" + + print("\n" + "=" * 60) + print("Loading Compiled Models for Trainium2") + print("=" * 60) + # Check language model mode + # Priority: --use_v3_language_model > --neuron_language_model > --cpu_language_model (default) + use_v3_language_model = getattr(args, 'use_v3_language_model', False) + use_neuron_language_model = getattr(args, 'neuron_language_model', False) + use_cpu_language_model = not (use_v3_language_model or use_neuron_language_model) + + if use_v3_language_model: + language_mode = "Neuron V3 (TP=4, world_size=8)" + elif use_neuron_language_model: + language_mode = "Neuron (TP=8, KV replication)" + else: + language_mode = "CPU" + + print("Parallel configuration:") + print(" - VAE: Single device (avoid collective conflict)") + print(" - Transformer: TP=8") + print(f" - Vision Encoder: {vision_mode}") + print(f" - Language Model: {language_mode}") + if use_cpu_language_model: + print("\nNOTE: Language Model on CPU (safe fallback mode)") + print(" Use --use_v3_language_model for V3 compiled model (recommended with --use_v3_cp)") + elif use_v3_language_model: + print("\nNOTE: Language Model uses V3 (ModelBuilder API)") + print(" TP=4, world_size=8 - compatible with V3 CP transformer") + else: + print("\nNOTE: Language Model uses TP=8 with KV head replication") + print(" (Q heads padded 28->32, KV heads replicated 4->8)") + + # ======================================== + # 1. Load Transformer FIRST (TP=8) + # ======================================== + # IMPORTANT: Must load the largest TP model first to initialize + # the communicator with the correct world size + use_v2 = getattr(args, 'use_v2', False) + use_v1_flash = getattr(args, 'use_v1_flash', False) + use_v2_flash = getattr(args, 'use_v2_flash', False) + use_v3_cp = getattr(args, 'use_v3_cp', False) + use_v3_cfg = getattr(args, 'use_v3_cfg', False) + v2_available = os.path.exists(f"{compiled_models_dir}/transformer_v2/nxd_model.pt") + v1_flash_available = os.path.exists(f"{compiled_models_dir}/transformer_v1_flash") + v2_flash_available = os.path.exists(f"{compiled_models_dir}/transformer_v2_flash/nxd_model.pt") + v3_cp_available = os.path.exists(f"{compiled_models_dir}/transformer_v3_cp/nxd_model.pt") + v3_cfg_available = os.path.exists(f"{compiled_models_dir}/transformer_v3_cfg/nxd_model.pt") + + if use_v3_cfg: + print("\n[1/3] Loading Transformer V3 CFG (CFG Parallel + NKI Flash Attention, TP=4, DP=2)...") + if not v3_cfg_available: + raise FileNotFoundError( + f"V3 CFG transformer not found. Please run: python neuron_qwen_image_edit/compile_transformer_v3_cfg.py" + ) + + # Store reference to original for wrapper + original_transformer = pipe.transformer + + # Load V3 CFG model and assign to pipe + pipe.transformer = load_transformer_v3_cfg(compiled_models_dir, pipe, args) + + # Delete original transformer to free memory + del original_transformer + import gc + gc.collect() + print(" Transformer V3 CFG loaded!") + print(" Original transformer deleted to free memory.") + + # Patch pipeline for batched CFG + patch_pipeline_for_cfg_parallel(pipe) + elif use_v3_cp: + print("\n[1/3] Loading Transformer V3 CP (Context Parallel + NKI Flash Attention, TP=4, CP=2)...") + if not v3_cp_available: + raise FileNotFoundError( + f"V3 CP transformer not found. Please run: ./compile.sh v3_cp" + ) + + # Store reference to original for wrapper + original_transformer = pipe.transformer + + # Load V3 CP model and assign to pipe + pipe.transformer = load_transformer_v3_cp(compiled_models_dir, pipe, args) + + # Delete original transformer to free memory + del original_transformer + import gc + gc.collect() + print(" Transformer V3 CP loaded!") + print(" Original transformer deleted to free memory.") + + # Reuse the CFG-parallel pipeline patch for its stage-timing markers. + # The no-CFG branch of the patched __call__ is identical in signature + # to upstream and works with the V3 CP wrapper. + patch_pipeline_for_cfg_parallel(pipe) + elif use_v2_flash: + print("\n[1/3] Loading Transformer V2 Flash (ModelBuilder + NKI Flash Attention, TP=8)...") + if not v2_flash_available: + raise FileNotFoundError( + f"Transformer V2 Flash not found at {compiled_models_dir}/transformer_v2_flash\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2_flash.py" + ) + + # Store reference to original for wrapper + original_transformer = pipe.transformer + + # Load V2 Flash model + pipe.transformer = load_transformer_v2_flash(compiled_models_dir, pipe, args) + + # Delete original transformer to free ~40GB memory + del original_transformer + gc.collect() + print(" Transformer V2 Flash loaded!") + print(" Original transformer deleted to free memory.") + elif use_v1_flash: + print("\n[1/3] Loading Transformer V1 Flash (parallel_model_trace + NKI Flash Attention, TP=8)...") + if not v1_flash_available: + raise FileNotFoundError( + f"Transformer V1 Flash not found at {compiled_models_dir}/transformer_v1_flash\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v1_flash.py" + ) + + # Store reference to original for wrapper + original_transformer = pipe.transformer + + # Load V1 Flash model + pipe.transformer = load_transformer_v1_flash(compiled_models_dir, pipe, args) + + # Delete original transformer to free ~40GB memory + del original_transformer + gc.collect() + print(" Transformer V1 Flash loaded!") + print(" Original transformer deleted to free memory.") + elif use_v2: + print("\n[1/3] Loading Transformer V2 (ModelBuilder API, TP=8)...") + if not v2_available: + raise FileNotFoundError( + f"Transformer V2 not found at {compiled_models_dir}/transformer_v2\n" + "Please run: python neuron_qwen_image_edit/compile_transformer_v2.py" + ) + + # Store reference to original for wrapper + original_transformer = pipe.transformer + + # Load V2 model + pipe.transformer = load_transformer_v2(compiled_models_dir, pipe, args) + + # Delete original transformer to free ~40GB memory + del original_transformer + gc.collect() + print(" Transformer V2 loaded!") + print(" Original transformer deleted to free memory.") + else: + print("\n[1/3] Loading Transformer V1 (parallel_model_trace API, TP=8)...") + + transformer_path = f"{compiled_models_dir}/transformer" + if not os.path.exists(transformer_path): + raise FileNotFoundError( + f"Transformer not found at {transformer_path}\n" + "Please run: python neuron_qwen_image_edit/compile_transformer.py" + ) + print(f" Loading transformer from {transformer_path}...") + compiled_transformer = neuronx_distributed.trace.parallel_model_load( + transformer_path + ) + + # Calculate expected shapes based on image dimensions + latent_h = args.height // 8 + latent_w = args.width // 8 + patch_h = latent_h // 2 + patch_w = latent_w // 2 + base_num_patches = patch_h * patch_w # e.g., 64*64=4096 for 1024x1024 + + # For IMAGE EDITING, patches are doubled (source + noise latents concatenated) + # This is handled by using temporal_frames = patch_multiplier + # - patch_multiplier=1 (generation): temporal_frames=1, patches = 1 * 32 * 32 = 1024 + # - patch_multiplier=2 (editing): temporal_frames=2, patches = 2 * 32 * 32 = 2048 + temporal_frames = args.patch_multiplier + expected_num_patches = temporal_frames * base_num_patches + print(f" Expected num_patches: {expected_num_patches} (temporal_frames={temporal_frames}, base={base_num_patches})") + + # img_shapes for the wrapper + # Note: batch_size=1, CFG runs transformer twice sequentially (not batch_size=2) + img_shapes = [(temporal_frames, patch_h, patch_w)] + + # Store reference to original for wrapper, then delete + original_transformer = pipe.transformer + pipe.transformer = NeuronTransformerWrapper( + original_transformer, compiled_transformer, img_shapes, + expected_num_patches=expected_num_patches, + expected_seq_len=args.max_sequence_length + ) + # Delete original transformer to free ~40GB memory + del original_transformer + gc.collect() + print(f" Transformer V1 loaded (TP=8)! Expected patches={expected_num_patches}, seq_len={args.max_sequence_length}") + print(" Original transformer deleted to free memory.") + + # ======================================== + # 2. Load Text Encoder Components + # ======================================== + print("\n[2/3] Loading Text Encoder...") + + # Load Vision Encoder + # Priority: CPU > V3 (TP=4) > TP=8 > single device + # Note: vision_encoder_tp_path, use_vision_tp, use_cpu_vision_encoder, use_v3_vision_encoder are defined at the top + vision_encoder_single_path = f"{compiled_models_dir}/vision_encoder/model.pt" + compiled_vision_encoder = None + compiled_vision_encoder_v3 = None + cpu_vision_encoder = None + vision_encoder_config = None + + if use_cpu_vision_encoder: + # CPU Vision Encoder mode - highest accuracy, avoids compilation precision loss + # This is useful when compiled vision encoder produces blurry outputs + print(" Using CPU Vision Encoder (highest accuracy)...") + # Extract vision encoder from text encoder - will be passed to wrapper + cpu_vision_encoder = pipe.text_encoder.model.visual + cpu_vision_encoder.eval() + print(" Vision encoder prepared on CPU!") + elif use_v3_vision_encoder: + # V3 Vision Encoder mode - uses ModelBuilder API with TP=4, world_size=8 + # Faster than single device, maintains float32 precision + print(" Loading V3 Vision Encoder (TP=4, world_size=8, float32)...") + compiled_vision_encoder_v3, vision_encoder_config = load_vision_encoder_v3(compiled_models_dir) + print(" V3 Vision encoder loaded!") + elif use_vision_tp or (os.path.exists(vision_encoder_tp_path) and not os.path.exists(vision_encoder_single_path)): + # Load TP-compiled vision encoder (TP=8, but may have dimension issues) + if not os.path.exists(vision_encoder_tp_path): + raise FileNotFoundError( + f"Vision encoder (TP) not found at {vision_encoder_tp_path}\n" + "Please run: python neuron_qwen_image_edit/compile_text_encoder.py --vision_only --vision_tp" + ) + print(f" Loading vision encoder (TP={TP_DEGREE}) from {vision_encoder_tp_path}...") + compiled_vision_encoder = neuronx_distributed.trace.parallel_model_load( + vision_encoder_tp_path + ) + print(f" Vision encoder loaded (TP={TP_DEGREE})!") + else: + # Load single-device vision encoder (always float32) + if not os.path.exists(vision_encoder_single_path): + raise FileNotFoundError( + f"Vision encoder not found at {vision_encoder_single_path}\n" + "Please run: python neuron_qwen_image_edit/compile_text_encoder.py --vision_only\n" + "Or for V3 (faster): python neuron_qwen_image_edit/compile_vision_encoder_v3.py" + ) + print(f" Loading vision encoder from {vision_encoder_single_path}...") + vision_encoder_jit = torch.jit.load(vision_encoder_single_path) + # Vision encoder input is (num_patches, channels), NOT (batch, ...) + # DataParallel would incorrectly split on patches dimension + # Must use single device + compiled_vision_encoder = vision_encoder_jit + print(f" Vision encoder loaded (single device, float32)!") + + # Load Language Model + compiled_language_model = None + compiled_language_model_v3 = None + cpu_language_model = None + language_model_config = None + + if use_v3_language_model: + # V3 Language Model mode - uses ModelBuilder API with TP=4, world_size=8 + # Compatible with V3 CP transformer + print(" Loading V3 Language Model (TP=4, world_size=8)...") + compiled_language_model_v3, language_model_config = load_language_model_v3(compiled_models_dir) + print(" V3 Language model loaded!") + elif use_cpu_language_model: + # CPU Language Model mode - keeps original model on CPU + # This avoids GQA alignment issues that occur with TP != 4 + print(" Using CPU Language Model (avoids GQA alignment issue)...") + # Extract language model from text encoder BEFORE creating wrapper + cpu_language_model = pipe.text_encoder.model.language_model + cpu_language_model.eval() + # Keep it in bfloat16 for memory efficiency + cpu_language_model = cpu_language_model.to(torch.bfloat16) + print(" Language model prepared on CPU!") + else: + # Neuron compiled Language Model mode (TP=8 with KV head replication) + language_model_path = f"{compiled_models_dir}/language_model" + if not os.path.exists(language_model_path): + raise FileNotFoundError( + f"Language model not found at {language_model_path}\n" + "Please run: python neuron_qwen_image_edit/compile_text_encoder.py --language_only" + ) + print(f" Loading language model from {language_model_path}...") + compiled_language_model = neuronx_distributed.trace.parallel_model_load( + language_model_path + ) + print(" Language model loaded (TP=8 with KV head replication)!") + + # Create Text Encoder Wrapper + # Store reference to original, then delete after wrapper is created + original_text_encoder = pipe.text_encoder + + # Get language model batch size from config (default to 1) + language_model_batch_size = 1 + if language_model_config is not None: + language_model_batch_size = language_model_config.get("batch_size", 1) + + # Resolve portrait NEFF dims for the vision encoder. Prefer the values written + # into vision_encoder_v3/config.json by compile_vision_encoder_v3.py; fall back + # to CLI args / image_size for legacy single-device or TP=8 vision encoders. + if vision_encoder_config is not None: + wrapper_image_h = vision_encoder_config.get("image_h", + vision_encoder_config.get("image_size", args.image_size)) + wrapper_image_w = vision_encoder_config.get("image_w", + vision_encoder_config.get("image_size", args.image_size)) + else: + wrapper_image_h = getattr(args, "image_h", None) or args.image_size + wrapper_image_w = getattr(args, "image_w", None) or args.image_size + print(f" Vision encoder grid: {wrapper_image_h}x{wrapper_image_w} " + f"({wrapper_image_h//14}x{wrapper_image_w//14} patches)") + + pipe.text_encoder = NeuronTextEncoderWrapper( + original_text_encoder=original_text_encoder, + compiled_vision_encoder=compiled_vision_encoder, + compiled_vision_encoder_v3=compiled_vision_encoder_v3, + compiled_language_model=compiled_language_model, + compiled_language_model_v3=compiled_language_model_v3, + cpu_language_model=cpu_language_model, + cpu_vision_encoder=cpu_vision_encoder, + image_size=args.image_size, + image_h=wrapper_image_h, + image_w=wrapper_image_w, + max_seq_len=args.max_sequence_length, + language_model_batch_size=language_model_batch_size + ) + + if use_cpu_language_model or use_cpu_vision_encoder: + # When using CPU models, we keep references - don't delete original + print(" Text encoder wrapper created!") + if use_cpu_language_model: + print(" Language model kept on CPU.") + if use_cpu_vision_encoder: + print(" Vision encoder kept on CPU (highest accuracy mode).") + elif use_v3_language_model or use_v3_vision_encoder: + # V3 models loaded, can delete original + del original_text_encoder + gc.collect() + print(" Text encoder wrapper created!") + print(" Original text encoder deleted to free memory.") + else: + # Delete original text encoder to free ~16GB memory + del original_text_encoder + gc.collect() + print(" Text encoder wrapper created!") + print(" Original text encoder deleted to free memory.") + + # ======================================== + # 3. Load VAE (Encoder + Decoder) + # ======================================== + print("\n[3/3] Loading VAE...") + + # First replace with Neuron-compatible VAE architecture + print(" Creating Neuron-compatible VAE...") + original_vae_config = pipe.vae.config + neuron_vae = NeuronAutoencoder( + base_dim=original_vae_config.base_dim, + z_dim=original_vae_config.z_dim, + dim_mult=original_vae_config.dim_mult, + num_res_blocks=original_vae_config.num_res_blocks, + attn_scales=original_vae_config.attn_scales, + temperal_downsample=original_vae_config.temperal_downsample, + dropout=original_vae_config.dropout, + input_channels=getattr(original_vae_config, 'input_channels', 3), + latents_mean=original_vae_config.latents_mean, + latents_std=original_vae_config.latents_std, + ) + neuron_vae.load_state_dict(pipe.vae.state_dict()) + + # Load compiled encoder + vae_encoder_path = f"{compiled_models_dir}/vae_encoder/model.pt" + if not os.path.exists(vae_encoder_path): + raise FileNotFoundError( + f"VAE encoder not found at {vae_encoder_path}\n" + "Please run: python neuron_qwen_image_edit/compile_vae.py" + ) + print(f" Loading VAE encoder from {vae_encoder_path}...") + vae_encoder_jit = torch.jit.load(vae_encoder_path) + # Use single device to avoid collective communication conflict with TP models + # VAE is small (~300M params), doesn't need parallelism + compiled_encoder = vae_encoder_jit + print(" VAE encoder loaded (single device)!") + + # Load compiled decoder + vae_decoder_path = f"{compiled_models_dir}/vae_decoder/model.pt" + if not os.path.exists(vae_decoder_path): + raise FileNotFoundError( + f"VAE decoder not found at {vae_decoder_path}\n" + "Please run: python neuron_qwen_image_edit/compile_vae.py" + ) + print(f" Loading VAE decoder from {vae_decoder_path}...") + vae_decoder_jit = torch.jit.load(vae_decoder_path) + # Use single device to avoid collective communication conflict with TP models + # VAE is small (~300M params), doesn't need parallelism + compiled_decoder = vae_decoder_jit + print(" VAE decoder loaded (single device)!") + + # Load quant_conv and post_quant_conv if they exist (single device) + compiled_quant_conv = None + quant_conv_path = f"{compiled_models_dir}/quant_conv/model.pt" + if os.path.exists(quant_conv_path): + print(f" Loading quant_conv from {quant_conv_path}...") + compiled_quant_conv = torch.jit.load(quant_conv_path) + + compiled_post_quant_conv = None + post_quant_conv_path = f"{compiled_models_dir}/post_quant_conv/model.pt" + if os.path.exists(post_quant_conv_path): + print(f" Loading post_quant_conv from {post_quant_conv_path}...") + compiled_post_quant_conv = torch.jit.load(post_quant_conv_path) + + # Create VAE Wrapper + cpu_decode = getattr(args, 'cpu_vae_decode', False) + # Use vae_tile_size for the compiled model's expected input size + vae_tile_size = getattr(args, 'vae_tile_size', 512) + + # Load VAE config to get compiled_batch_size + vae_config_path = f"{compiled_models_dir}/vae_config.json" + vae_compiled_batch_size = 1 + if os.path.exists(vae_config_path): + import json + with open(vae_config_path, 'r') as f: + vae_config = json.load(f) + vae_compiled_batch_size = vae_config.get('batch_size', 1) + print(f" VAE compiled batch_size: {vae_compiled_batch_size}") + + pipe.vae = NeuronVAEWrapper( + original_vae=neuron_vae, + compiled_encoder=compiled_encoder, + compiled_decoder=compiled_decoder, + compiled_quant_conv=compiled_quant_conv, + compiled_post_quant_conv=compiled_post_quant_conv, + expected_height=vae_tile_size, + expected_width=vae_tile_size, + compiled_batch_size=vae_compiled_batch_size, + cpu_decode=cpu_decode + ) + # Delete the neuron_vae (original VAE copy) - small but still free it + # Note: if cpu_decode=True, the decoder/post_quant_conv refs are already copied + del neuron_vae + gc.collect() + print(" VAE wrapper created!") + + # Fix missing _execution_device property + # The pipeline expects this to determine where to run operations + # Override the property with a lambda that returns CPU device + type(pipe)._execution_device = property(lambda self: torch.device("cpu")) + + # Use vision_mode and language_mode defined at the top of the function + if use_v3_cfg: + transformer_api = "V3 CFG (CFG Parallel + NKI, TP=4, DP=2)" + tp_info = "TP=4, DP=2" + elif use_v3_cp: + transformer_api = "V3 CP (Context Parallel + NKI, TP=4, CP=2)" + tp_info = "TP=4, CP=2" + elif use_v2_flash: + transformer_api = "V2 Flash (ModelBuilder + NKI)" + tp_info = "TP=8" + elif use_v1_flash: + transformer_api = "V1 Flash (parallel_model_trace + NKI)" + tp_info = "TP=8" + elif use_v2: + transformer_api = "V2 (ModelBuilder)" + tp_info = "TP=8" + else: + transformer_api = "V1 (parallel_model_trace)" + tp_info = "TP=8" + print("\n" + "=" * 60) + print("All Models Loaded!") + print("=" * 60) + print(f" - Transformer: Neuron ({tp_info}, {transformer_api})") + print(f" - Language Model: {language_mode}") + print(f" - Vision Encoder: Neuron ({vision_mode})") + print(f" - VAE: Neuron (tile size={vae_tile_size}x{vae_tile_size})") + print("") + print("Tiled VAE note:") + print(f" - VAE compiled for {vae_tile_size}x{vae_tile_size} tiles") + print(" - Larger images will be processed in tiles automatically") + print(" - Example: 1024x1024 -> 4 tiles of 512x512 (with overlap)") + print("") + if use_cpu_language_model: + print("Memory note:") + print(" - Language Model on CPU (~8GB CPU memory)") + print(" - Other components on Neuron") + + return pipe + + +def debug_text_encoder(pipe, input_images, args): + """ + Debug: Compare NeuronTextEncoderWrapper output vs CPU. + + This function helps identify if text encoder is causing output issues. + """ + import torch.nn.functional as F + + print("\nPreparing test input...") + + # Prepare input like the pipeline does + prompt = args.prompt + if isinstance(input_images, list): + base_img_prompt = "".join([f"Picture {i+1}: <|vision_start|><|image_pad|><|vision_end|>" for i in range(len(input_images))]) + images = input_images + else: + base_img_prompt = "Picture 1: <|vision_start|><|image_pad|><|vision_end|>" + images = [input_images] + + template = pipe.prompt_template_encode + txt = [template.format(base_img_prompt + prompt)] + + model_inputs = pipe.processor( + text=txt, + images=images, + padding=True, + return_tensors="pt", + ) + + print(f" input_ids: {model_inputs.input_ids.shape}") + print(f" pixel_values: {model_inputs.pixel_values.shape}") + print(f" image_grid_thw: {model_inputs.image_grid_thw.tolist()}") + + # Count image tokens + image_token_id = pipe.text_encoder.config.image_token_id if hasattr(pipe.text_encoder, 'config') else 151655 + num_image_tokens = (model_inputs.input_ids == image_token_id).sum().item() + print(f" Image tokens in input: {num_image_tokens}") + + # Run the wrapper (which is what inference uses) + print("\nRunning NeuronTextEncoderWrapper...") + with torch.no_grad(): + wrapper_output = pipe.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values.to(torch.bfloat16), + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + if hasattr(wrapper_output, 'hidden_states'): + wrapper_hidden = wrapper_output.hidden_states[-1] + else: + wrapper_hidden = wrapper_output.last_hidden_state + + print(f" Wrapper output shape: {wrapper_hidden.shape}") + print(f" Wrapper output stats: mean={wrapper_hidden.float().mean():.4f}, std={wrapper_hidden.float().std():.4f}") + print(f" Wrapper output range: [{wrapper_hidden.float().min():.4f}, {wrapper_hidden.float().max():.4f}]") + + # Check for NaN/Inf + has_nan = torch.isnan(wrapper_hidden).any().item() + has_inf = torch.isinf(wrapper_hidden).any().item() + if has_nan: + print(" [WARNING] Output contains NaN!") + if has_inf: + print(" [WARNING] Output contains Inf!") + + # Save intermediate results for debugging + debug_data = { + 'input_ids': model_inputs.input_ids.cpu().numpy(), + 'attention_mask': model_inputs.attention_mask.cpu().numpy(), + 'pixel_values_shape': list(model_inputs.pixel_values.shape), + 'image_grid_thw': model_inputs.image_grid_thw.cpu().numpy(), + 'wrapper_output': wrapper_hidden.float().cpu().numpy(), + } + + import numpy as np + np.savez('debug_text_encoder_output.npz', **debug_data) + print("\n Debug data saved to: debug_text_encoder_output.npz") + print(" To compare with CPU, load original pipeline and run the same inputs.") + + +def run_inference(args): + """Run image editing inference on Trainium2.""" + set_seed(args.seed) + + print("\n" + "=" * 60) + print("Qwen-Image-Edit Inference on Trainium2") + print("=" * 60) + print(f" Compiled dimensions: {args.height}x{args.width}") + print(f" Steps: {args.num_inference_steps}") + print(f" CFG scale: {args.true_cfg_scale}") + + # Load original pipeline + print("\nLoading original pipeline...") + dtype = torch.bfloat16 + + # CRITICAL FIX: Override VAE_IMAGE_SIZE before loading pipeline + # The pipeline uses VAE_IMAGE_SIZE (default 1024*1024) to resize source images. + # This creates more patches than our compiled transformer expects. + # We need to match our compiled dimensions. + import diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus as qwen_pipeline_module + compiled_vae_pixels = args.height * args.width # e.g., 512*512 + original_vae_size = getattr(qwen_pipeline_module, 'VAE_IMAGE_SIZE', 1024*1024) + qwen_pipeline_module.VAE_IMAGE_SIZE = compiled_vae_pixels + print(f"\nOverriding VAE_IMAGE_SIZE: {original_vae_size} -> {compiled_vae_pixels}") + print(f" (This ensures source images produce {args.height//8//2}x{args.width//8//2} patches)") + + # Force the VLM (condition) branch to use our exact compiled grid (image_h x image_w). + # The pipeline normally calls `calculate_dimensions(CONDITION_IMAGE_SIZE, src_ratio)` + # which rounds to multiples of 32 from the *source* aspect ratio — that does NOT + # land on our portrait NEFF (e.g. 336x448) when the source ratio differs even slightly. + # Patch the symbol used inside the pipeline module so it returns our fixed (W, H) + # regardless of input ratio when the call is for CONDITION_IMAGE_SIZE. + vlm_h = getattr(args, 'image_h', None) or args.image_size + vlm_w = getattr(args, 'image_w', None) or args.image_size + qwen_pipeline_module.CONDITION_IMAGE_SIZE = vlm_h * vlm_w + _orig_calc_dims = qwen_pipeline_module.calculate_dimensions + _condition_target = vlm_h * vlm_w + _vae_target = compiled_vae_pixels + def _patched_calc_dims(target_area, ratio): + if target_area == _condition_target: + return vlm_w, vlm_h + if target_area == _vae_target: + return args.width, args.height + return _orig_calc_dims(target_area, ratio) + qwen_pipeline_module.calculate_dimensions = _patched_calc_dims + print(f" Forcing condition (VLM) dims to {vlm_w}x{vlm_h} regardless of source ratio") + print(f" Forcing VAE dims to {args.width}x{args.height} regardless of source ratio") + + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=HUGGINGFACE_CACHE_DIR, + local_files_only=True + ) + + # CRITICAL: Configure processor to output fixed image size matching compiled vision encoder. + # Qwen2VLImageProcessorFast.preprocess() reads from `size["shortest_edge"]` / + # `size["longest_edge"]` (NOT from min_pixels/max_pixels attributes alone), so we must + # set both. Even with these set, smart_resize will produce the closest + # multiple-of-(patch_size*merge_size)=28 grid that fits the constraints — since the + # condition branch above pre-resizes images to exactly (vlm_w, vlm_h) which is already + # a multiple of 28, smart_resize is a no-op. + vis_h = getattr(args, 'image_h', None) or args.image_size + vis_w = getattr(args, 'image_w', None) or args.image_size + target_pixels = vis_h * vis_w + print(f"\nConfiguring processor for vision encoder size: {vis_w}x{vis_h} (W x H)") + print(f" Setting min_pixels = max_pixels = {target_pixels}") + pipe.processor.image_processor.min_pixels = target_pixels + pipe.processor.image_processor.max_pixels = target_pixels + pipe.processor.image_processor.size = { + "shortest_edge": target_pixels, + "longest_edge": target_pixels, + } + + print("Pipeline loaded!") + + # Load ALL compiled models - everything runs on Trainium2 + pipe = load_all_compiled_models(args.compiled_models_dir, pipe, args) + + # Load source images (1-3 images supported) + # IMPORTANT: Images must be resized to COMPILED dimensions for the transformer + print(f"\nLoading {len(args.images)} source image(s)...") + source_images = [] + for img_path in args.images: + print(f" Loading: {img_path}") + img = load_image(img_path) + # Resize to match COMPILED dimensions (not inference dimensions) + img = img.resize((args.width, args.height)) + source_images.append(img) + print(f"All images resized to: {args.width}x{args.height} (compiled dimensions)") + + # Use single image or list based on count + input_images = source_images[0] if len(source_images) == 1 else source_images + + # Debug: Compare text encoder outputs + if args.debug_text_encoder: + print("\n" + "="*60) + print("[DEBUG] Text Encoder Comparison") + print("="*60) + debug_text_encoder(pipe, input_images, args) + print("="*60 + "\n") + + # Create generator for reproducibility + generator = torch.Generator().manual_seed(args.seed) + + # CFG is controlled by true_cfg_scale (default 4.0 in pipeline) + # CFG runs transformer twice sequentially, NOT with batch_size=2 + true_cfg_scale = args.true_cfg_scale + + # Warmup run + if args.warmup: + print("\n" + "-" * 40) + print("Running warmup inference...") + print("-" * 40) + warmup_generator = torch.Generator().manual_seed(args.seed + 1000) + start = time.time() + _ = pipe( + image=input_images, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, # Use compiled dimensions + width=args.width, + true_cfg_scale=true_cfg_scale, + num_inference_steps=min(5, args.num_inference_steps), + generator=warmup_generator, + ) + warmup_time = time.time() - start + print(f"Warmup time: {warmup_time:.2f}s") + + # Main inference + print("\n" + "-" * 40) + print("Running main inference...") + print("-" * 40) + print(f" Prompt: {args.prompt}") + + generator = torch.Generator().manual_seed(args.seed) + start = time.time() + output = pipe( + image=input_images, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, # Use compiled dimensions + width=args.width, + true_cfg_scale=true_cfg_scale, + num_inference_steps=args.num_inference_steps, + generator=generator, + ) + inference_time = time.time() - start + + print(f"\nInference time: {inference_time:.2f}s") + + # Save output + output_image = output.images[0] + output_path = args.output or "output_edited.png" + output_image.save(output_path) + print(f"Output saved to: {output_path}") + + # Save comparison + if args.save_comparison: + # Create comparison with all input images + output + num_images = len(source_images) + 1 # inputs + output + comparison = Image.new('RGB', (args.width * num_images, args.height)) + for i, img in enumerate(source_images): + comparison.paste(img, (args.width * i, 0)) + comparison.paste(output_image, (args.width * len(source_images), 0)) + comparison_path = output_path.replace('.png', '_comparison.png') + comparison.save(comparison_path) + print(f"Comparison saved to: {comparison_path}") + + return output_image + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Qwen-Image-Edit inference on AWS Trainium2 (ALL components on Neuron)" + ) + + # Input/Output + parser.add_argument("--images", type=str, nargs="+", required=True, + help="Path(s) to source image(s) for editing (1-3 images supported)") + parser.add_argument("--prompt", type=str, required=True, + help="Edit instruction prompt") + parser.add_argument("--negative_prompt", type=str, default="", + help="Negative prompt") + parser.add_argument("--output", type=str, default=None, + help="Output image path (default: output_edited.png)") + + # Image dimensions (must match compiled model) + parser.add_argument("--height", type=int, default=1024, + help="Image height (must match compiled model)") + parser.add_argument("--width", type=int, default=1024, + help="Image width (must match compiled model)") + parser.add_argument("--patch_multiplier", type=int, default=2, + help="Patch multiplier (2 for image editing, 1 for generation)") + + # Text encoder settings - MUST match compilation settings + parser.add_argument("--image_size", type=int, default=448, + help="Vision encoder square image size (must match compiled model). " + "Ignored if --image_h and --image_w are set.") + parser.add_argument("--image_h", type=int, default=None, + help="Vision encoder image height in pixels (portrait NEFF). " + "Must match the value used at compile time.") + parser.add_argument("--image_w", type=int, default=None, + help="Vision encoder image width in pixels (portrait NEFF). " + "Must match the value used at compile time.") + parser.add_argument("--max_sequence_length", type=int, default=1024, + help="Max text sequence length (must match compiled model)") + parser.add_argument("--vision_tp", action="store_true", + help="Use TP-compiled vision encoder (from vision_encoder_tp/). " + "Default is to auto-detect based on available compiled models.") + + # Language model mode + parser.add_argument("--cpu_language_model", action="store_true", default=True, + help="Run Language Model on CPU (default). " + "Safe fallback mode that avoids any TP compatibility issues.") + parser.add_argument("--neuron_language_model", action="store_true", + help="Use Neuron-compiled Language Model with TP=8 (KV head replication mode). " + "Requires: python compile_text_encoder.py --language_only --language_tp_degree 8") + parser.add_argument("--use_v3_language_model", action=argparse.BooleanOptionalAction, default=True, + help="Use V3 Language Model compiled with ModelBuilder API (TP=4, world_size=8). " + "Default: True. Use --no-use_v3_language_model to disable. " + "Requires: python neuron_qwen_image_edit/compile_language_model_v3.py") + + # Vision encoder mode + parser.add_argument("--cpu_vision_encoder", action="store_true", + help="Run Vision Encoder on CPU (default behavior)") + parser.add_argument("--neuron_vision_encoder", action=argparse.BooleanOptionalAction, default=False, + help="Use Neuron-compiled Vision Encoder (float32). " + "CPU is used by default for better accuracy.") + parser.add_argument("--use_v3_vision_encoder", action=argparse.BooleanOptionalAction, default=True, + help="Use V3 Vision Encoder with TP=4 (faster, requires --neuron_vision_encoder). " + "Requires: python neuron_qwen_image_edit/compile_vision_encoder_v3.py") + + # Inference settings + parser.add_argument("--num_inference_steps", type=int, default=40, + help="Number of denoising steps (default: 40)") + parser.add_argument("--true_cfg_scale", type=float, default=4.0, + help="Classifier-free guidance scale (default: 4.0). " + "CFG runs transformer twice sequentially (not batch_size=2).") + parser.add_argument("--seed", type=int, default=SEED, + help="Random seed for reproducibility") + + # Model settings + parser.add_argument("--compiled_models_dir", type=str, default=COMPILED_MODELS_DIR, + help="Directory containing compiled models") + parser.add_argument("--vae_tile_size", type=int, default=512, + help="VAE tile size (must match compiled VAE size). " + "For larger images, tiled VAE will process in this tile size.") + parser.add_argument("--use_v2", action="store_true", + help="Use V2 transformer compiled with ModelBuilder API. " + "V2 passes RoPE as input tensors (like Flux). " + "Requires: python neuron_qwen_image_edit/compile_transformer_v2.py") + parser.add_argument("--use_v1_flash", action="store_true", + help="Use V1 Flash transformer with NKI Flash Attention. " + "Combines V1's parallel_model_trace (supports NKI) with V2's RoPE handling. " + "Requires: python neuron_qwen_image_edit/compile_transformer_v1_flash.py") + parser.add_argument("--use_v2_flash", action="store_true", + help="Use V2 Flash transformer with ModelBuilder + NKI Flash Attention. " + "Combines ModelBuilder's XLA optimization with NKI's hardware attention. " + "Requires: python neuron_qwen_image_edit/compile_transformer_v2_flash.py") + parser.add_argument("--use_v3_cp", action="store_true", + help="Use V3 CP transformer with Context Parallel + NKI Flash Attention. " + "Mutually exclusive with --use_v3_cfg. " + "Requires: ./compile.sh v3_cp") + parser.add_argument("--use_v3_cfg", action=argparse.BooleanOptionalAction, default=True, + help="Use V3 CFG transformer with CFG Parallel + NKI Flash Attention. " + "Batches negative + positive prompts for parallel inference. " + "Default: True. Use --no-use_v3_cfg to disable. " + "Requires: ./compile.sh v3_cfg") + + # Other options + parser.add_argument("--warmup", action="store_true", + help="Run warmup inference before main inference") + parser.add_argument("--save_comparison", action="store_true", + help="Save side-by-side comparison image") + + # Debug options + parser.add_argument("--cpu_vae_decode", action="store_true", + help="[DEBUG] Run VAE decoder on CPU instead of Neuron. " + "Use this to verify if other components are working correctly.") + parser.add_argument("--debug_text_encoder", action="store_true", + help="[DEBUG] Compare Text Encoder outputs before running inference. " + "This helps identify if text encoder is the source of issues.") + + args = parser.parse_args() + + # Validate number of images (1-3 supported by Qwen-Image-Edit) + if len(args.images) > 3: + parser.error("Qwen-Image-Edit supports 1-3 images, but {} were provided".format(len(args.images))) + + # Mutual exclusivity: --use_v3_cfg and --use_v3_cp + if args.use_v3_cfg and args.use_v3_cp: + # --use_v3_cp explicitly set takes priority, disable v3_cfg + args.use_v3_cfg = False + + run_inference(args) diff --git a/contrib/models/Qwen-Image-Edit/src/setup_nvme.sh b/contrib/models/Qwen-Image-Edit/src/setup_nvme.sh new file mode 100755 index 00000000..3d50672e --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/src/setup_nvme.sh @@ -0,0 +1,113 @@ +#!/bin/bash +set -e + +MOUNT_POINT="/opt/dlami/nvme" +RAID_DEVICE="/dev/md0" + +echo "=== NVMe RAID0 Setup Script for trn2.48xlarge ===" + +# Check if running as root +if [[ $EUID -ne 0 ]]; then + echo "This script must be run as root (use sudo)" + exit 1 +fi + +# Check if already mounted +if mountpoint -q "$MOUNT_POINT" 2>/dev/null; then + echo "$MOUNT_POINT is already mounted." + df -h "$MOUNT_POINT" + exit 0 +fi + +# Create mount point +mkdir -p "$MOUNT_POINT" + +# Case 1: RAID device exists - just mount it +if [[ -e "$RAID_DEVICE" ]]; then + echo "RAID device $RAID_DEVICE exists. Mounting..." + mount "$RAID_DEVICE" "$MOUNT_POINT" + chown ubuntu:ubuntu "$MOUNT_POINT" + chmod 755 "$MOUNT_POINT" + echo "" + echo "=== Mount Complete ===" + df -h "$MOUNT_POINT" + exit 0 +fi + +# Case 2: RAID device doesn't exist - try to assemble from existing superblocks +echo "RAID device $RAID_DEVICE not found. Trying to assemble existing array..." +if mdadm --assemble --scan 2>/dev/null; then + sleep 1 + if [[ -e "$RAID_DEVICE" ]]; then + echo "RAID array reassembled successfully. Mounting..." + mount "$RAID_DEVICE" "$MOUNT_POINT" + chown ubuntu:ubuntu "$MOUNT_POINT" + chmod 755 "$MOUNT_POINT" + echo "" + echo "=== Mount Complete ===" + df -h "$MOUNT_POINT" + exit 0 + fi +fi + +# Case 3: No existing RAID - need to create new one +echo "" +echo "WARNING: No existing RAID array found." +echo "Creating a new RAID array will FORMAT and ERASE all data on NVMe devices!" +echo "" +read -p "Do you want to create a NEW RAID array? (yes/no): " CONFIRM + +if [[ "$CONFIRM" != "yes" ]]; then + echo "Aborted. No changes made." + exit 1 +fi + +# Find root device and exclude it (EBS root volume also appears as NVMe on Nitro instances) +ROOT_NVME=$(lsblk -n -o PKNAME,MOUNTPOINT | awk '$2=="/" {print $1; exit}') +echo "Root device detected: /dev/$ROOT_NVME (will be excluded)" + +# Find all NVMe devices (excluding root device) +NVME_DEVICES=$(lsblk -d -n -o NAME,TYPE | grep nvme | grep disk | awk '{print "/dev/"$1}' | grep -v "$ROOT_NVME" || true) +NVME_COUNT=$(echo "$NVME_DEVICES" | wc -l) + +echo "Found $NVME_COUNT NVMe devices:" +echo "$NVME_DEVICES" + +if [[ $NVME_COUNT -lt 1 ]]; then + echo "No additional NVMe devices found to configure." + exit 1 +fi + +echo "Creating RAID0 array with $NVME_COUNT devices..." + +# Stop any existing RAID arrays on these devices +for dev in $NVME_DEVICES; do + mdadm --zero-superblock "$dev" 2>/dev/null || true +done + +# Create RAID0 array +mdadm --create "$RAID_DEVICE" \ + --level=0 \ + --raid-devices=$NVME_COUNT \ + $NVME_DEVICES + +echo "RAID0 array created successfully." + +# Format with ext4 +echo "Formatting $RAID_DEVICE with ext4..." +mkfs.ext4 -F "$RAID_DEVICE" + +# Mount the RAID device +echo "Mounting $RAID_DEVICE to $MOUNT_POINT..." +mount "$RAID_DEVICE" "$MOUNT_POINT" + +# Set permissions +chown ubuntu:ubuntu "$MOUNT_POINT" +chmod 755 "$MOUNT_POINT" + +# Show result +echo "" +echo "=== Setup Complete (New RAID Created) ===" +df -h "$MOUNT_POINT" +echo "" +echo "NVMe storage is now available at $MOUNT_POINT" diff --git a/contrib/models/Qwen-Image-Edit/test/__init__.py b/contrib/models/Qwen-Image-Edit/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen-Image-Edit/test/integration/__init__.py b/contrib/models/Qwen-Image-Edit/test/integration/__init__.py new file mode 100755 index 00000000..0b67623f --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/__init__.py @@ -0,0 +1 @@ +# Unit tests for comparing Neuron vs CPU/GPU inference diff --git a/contrib/models/Qwen-Image-Edit/test/integration/run_all_tests.py b/contrib/models/Qwen-Image-Edit/test/integration/run_all_tests.py new file mode 100755 index 00000000..69434585 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/run_all_tests.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python3 +""" +Run All Unit Tests: Compare Neuron vs CPU/GPU inference for all components + +This script runs all unit tests to identify which component is causing +output differences between Neuron and CPU/GPU inference. + +Components tested: +1. VAE (Encoder + Decoder) +2. Transformer +3. Text Encoder (Vision Encoder + Language Model) + +Usage: + python tests/run_all_tests.py --compiled_models_dir /path/to/compiled_models + + # Run specific tests + python tests/run_all_tests.py --test vae + python tests/run_all_tests.py --test transformer + python tests/run_all_tests.py --test text_encoder +""" + +import os +import sys +import argparse +import subprocess + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +COMPILED_MODELS_DIR = "/opt/dlami/nvme/compiled_models" + + +def run_test(test_script, args): + """Run a test script in a subprocess to avoid environment conflicts.""" + cmd = [ + sys.executable, test_script, + "--compiled_models_dir", args.compiled_models_dir, + ] + + # VAE test supports --height and --width + if "test_vae" in test_script: + cmd.extend(["--height", str(args.height)]) + cmd.extend(["--width", str(args.width)]) + + # Text encoder only supports --image_size and --max_sequence_length + if "text_encoder" in test_script: + cmd.extend(["--image_size", str(args.image_size)]) + cmd.extend(["--max_sequence_length", str(args.max_sequence_length)]) + + # Transformer supports multiple options + if "transformer" in test_script: + cmd.extend(["--height", str(args.height)]) + cmd.extend(["--width", str(args.width)]) + cmd.extend(["--max_sequence_length", str(args.max_sequence_length)]) + cmd.extend(["--batch_size", str(args.batch_size)]) + cmd.extend(["--patch_multiplier", str(args.patch_multiplier)]) + + print(f"\n{'='*80}") + print(f"Running: {' '.join(cmd)}") + print(f"{'='*80}\n") + + result = subprocess.run(cmd, capture_output=False) + return result.returncode == 0 + + +def main(): + parser = argparse.ArgumentParser( + description="Run all unit tests for Qwen-Image-Edit Neuron inference" + ) + parser.add_argument("--compiled_models_dir", type=str, + default=COMPILED_MODELS_DIR, + help="Directory containing compiled models") + parser.add_argument("--height", type=int, default=512, + help="Image height") + parser.add_argument("--width", type=int, default=512, + help="Image width") + parser.add_argument("--image_size", type=int, default=224, + help="Vision encoder image size") + parser.add_argument("--max_sequence_length", type=int, default=512, + help="Max text sequence length") + parser.add_argument("--batch_size", type=int, default=1, + help="Batch size for transformer test") + parser.add_argument("--patch_multiplier", type=int, default=2, + help="Patch multiplier for transformer") + parser.add_argument("--test", type=str, default="all", + choices=["vae", "transformer", "text_encoder", "all"], + help="Which test(s) to run") + args = parser.parse_args() + + # Get test directory + test_dir = os.path.dirname(os.path.abspath(__file__)) + + print("="*80) + print("QWEN-IMAGE-EDIT NEURON UNIT TESTS") + print("="*80) + print(f"\nCompiled models directory: {args.compiled_models_dir}") + print(f"Image size: {args.height}x{args.width}") + print(f"Vision encoder image size: {args.image_size}") + print(f"Max sequence length: {args.max_sequence_length}") + print(f"Tests to run: {args.test}") + + results = {} + + # Run VAE test + if args.test in ["vae", "all"]: + print("\n" + "="*80) + print("VAE TESTS") + print("="*80) + vae_test = os.path.join(test_dir, "test_vae.py") + if os.path.exists(vae_test): + results["vae"] = run_test(vae_test, args) + else: + print(f"Test script not found: {vae_test}") + results["vae"] = None + + # Run Transformer test + if args.test in ["transformer", "all"]: + print("\n" + "="*80) + print("TRANSFORMER TESTS") + print("="*80) + transformer_test = os.path.join(test_dir, "test_transformer.py") + if os.path.exists(transformer_test): + results["transformer"] = run_test(transformer_test, args) + else: + print(f"Test script not found: {transformer_test}") + results["transformer"] = None + + # Run Text Encoder test + if args.test in ["text_encoder", "all"]: + print("\n" + "="*80) + print("TEXT ENCODER TESTS") + print("="*80) + text_encoder_test = os.path.join(test_dir, "test_text_encoder.py") + if os.path.exists(text_encoder_test): + results["text_encoder"] = run_test(text_encoder_test, args) + else: + print(f"Test script not found: {text_encoder_test}") + results["text_encoder"] = None + + # Final Summary + print("\n" + "="*80) + print("FINAL TEST SUMMARY") + print("="*80) + + for name, passed in results.items(): + if passed is True: + status = "PASSED" + elif passed is False: + status = "FAILED" + else: + status = "SKIPPED" + print(f" {name:20s}: {status}") + + # Recommendations + print("\n" + "="*80) + print("DEBUGGING RECOMMENDATIONS") + print("="*80) + print(""" +If you see blurry output images, the issue is likely in one of these areas: + +1. VAE Decoder (Most Common) + - Check if cosine similarity is < 0.99 for the decoder + - VAE decoder numerical errors can cause blurry images + - Try: Increase normalization precision or check interpolation mode + +2. Transformer (Diffusion) + - Check if output differs significantly across timesteps + - Large errors accumulate across denoising steps + - Try: Check attention implementation and RoPE encoding + +3. Text Encoder + - Vision encoder errors affect conditioning + - Language model errors affect prompt understanding + - Try: Check embedding and attention layers + +4. Scaling/Normalization + - Check if latent_mean/latent_std are applied correctly + - Verify dtype conversions (bfloat16 <-> float32) + +To debug further: + - Run individual component tests with --save_images + - Compare intermediate outputs at each step + - Check for NaN/Inf values in outputs +""") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen-Image-Edit/test/integration/test_component_comparison.py b/contrib/models/Qwen-Image-Edit/test/integration/test_component_comparison.py new file mode 100644 index 00000000..9eaafec9 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/test_component_comparison.py @@ -0,0 +1,365 @@ +#!/usr/bin/env python3 +""" +逐组件对比测试: CPU vs Neuron + +按照推理流程逐步对比每个组件的输出: +1. Processor 输出 (input_ids, pixel_values, image_grid_thw) +2. Vision Encoder 输出 (image_embeds) +3. Embedding 合并后的结果 (inputs_embeds) +4. Position IDs 计算 +5. Language Model 输出 (hidden_states) +6. 完整 Text Encoder 输出 + +这个脚本帮助定位数值差异的来源。 +""" + +import os +import sys +import argparse + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" +COMPILED_MODELS_DIR = "/opt/dlami/nvme/compiled_models" + + +def cosine_sim(a, b): + """Calculate cosine similarity.""" + return F.cosine_similarity( + a.flatten().unsqueeze(0).float(), + b.flatten().unsqueeze(0).float() + ).item() + + +def print_stats(name, tensor): + """Print tensor statistics.""" + t = tensor.float() + print(f" {name}:") + print(f" shape: {tensor.shape}, dtype: {tensor.dtype}") + print(f" mean: {t.mean().item():.6f}, std: {t.std().item():.6f}") + print(f" min: {t.min().item():.6f}, max: {t.max().item():.6f}") + + +def compare_tensors(name, cpu_tensor, neuron_tensor): + """Compare two tensors and print metrics.""" + print(f"\n{'='*60}") + print(f"Comparing: {name}") + print(f"{'='*60}") + + print_stats("CPU", cpu_tensor) + print_stats("Neuron", neuron_tensor) + + if cpu_tensor.shape != neuron_tensor.shape: + print(f"\n [ERROR] Shape mismatch!") + return False + + diff = (cpu_tensor.float() - neuron_tensor.float()).abs() + cos_sim = cosine_sim(cpu_tensor, neuron_tensor) + + print(f"\n Difference:") + print(f" Max AE: {diff.max().item():.6e}") + print(f" Mean AE: {diff.mean().item():.6e}") + print(f" Cosine Sim: {cos_sim:.6f}") + + passed = cos_sim > 0.99 + status = "[PASS]" if passed else "[FAIL]" + print(f"\n {status} Cosine Similarity: {cos_sim:.6f}") + + return passed + + +def test_step_by_step(args): + """逐步对比每个组件.""" + from diffusers import QwenImageEditPlusPipeline + + print("\n" + "="*60) + print("Step-by-Step Component Comparison") + print("="*60) + + dtype = torch.bfloat16 + image_size = args.image_size + + # Load pipeline + print("\n[0] Loading pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Configure processor for fixed image size + target_pixels = image_size * image_size + pipe.processor.image_processor.min_pixels = target_pixels + pipe.processor.image_processor.max_pixels = target_pixels + print(f" Processor configured for {image_size}x{image_size}") + + # Create test image + test_image = Image.fromarray( + np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) + ) + + # Process input + prompt = "change the color to blue" + base_img_prompt = "Picture 1: <|vision_start|><|image_pad|><|vision_end|>" + template = pipe.prompt_template_encode + txt = [template.format(base_img_prompt + prompt)] + + print(f"\n[1] Processing input...") + model_inputs = pipe.processor( + text=txt, + images=[test_image], + padding=True, + return_tensors="pt", + ) + + print(f" input_ids: {model_inputs.input_ids.shape}") + print(f" pixel_values: {model_inputs.pixel_values.shape}") + print(f" image_grid_thw: {model_inputs.image_grid_thw.tolist()}") + + input_ids = model_inputs.input_ids + attention_mask = model_inputs.attention_mask + pixel_values = model_inputs.pixel_values.to(dtype) + image_grid_thw = model_inputs.image_grid_thw + + results = {} + + # ======================================== + # Step 2: Vision Encoder + # ======================================== + print(f"\n[2] Testing Vision Encoder...") + + # CPU Vision Encoder + original_visual = pipe.text_encoder.model.visual + original_visual.eval() + + with torch.no_grad(): + cpu_image_embeds = original_visual(pixel_values, image_grid_thw) + print(f" CPU image_embeds: {cpu_image_embeds.shape}") + + # Neuron Vision Encoder + vision_path = f"{args.compiled_models_dir}/vision_encoder/model.pt" + if os.path.exists(vision_path): + compiled_vision = torch.jit.load(vision_path) + with torch.no_grad(): + neuron_image_embeds = compiled_vision(pixel_values, image_grid_thw) + results["vision_encoder"] = compare_tensors( + "Vision Encoder", cpu_image_embeds, neuron_image_embeds + ) + else: + print(f" [SKIP] Vision encoder not found at {vision_path}") + neuron_image_embeds = cpu_image_embeds + results["vision_encoder"] = None + + # ======================================== + # Step 3: Embed Tokens + # ======================================== + print(f"\n[3] Testing Embed Tokens...") + + embed_tokens = pipe.text_encoder.model.language_model.embed_tokens + + with torch.no_grad(): + cpu_text_embeds = embed_tokens(input_ids) + print(f" CPU text_embeds: {cpu_text_embeds.shape}") + print_stats("text_embeds", cpu_text_embeds) + + # ======================================== + # Step 4: Merge Embeddings + # ======================================== + print(f"\n[4] Testing Embedding Merge...") + + # Find image token positions + image_token_id = pipe.text_encoder.config.image_token_id + batch_size, seq_len, hidden_dim = cpu_text_embeds.shape + + # Merge on CPU + cpu_merged = cpu_text_embeds.clone() + image_mask = (input_ids == image_token_id) + num_image_tokens = image_mask.sum().item() + print(f" Number of image tokens: {num_image_tokens}") + print(f" Image embeds to merge: {cpu_image_embeds.shape}") + + if num_image_tokens > 0 and cpu_image_embeds.shape[0] == num_image_tokens: + cpu_merged[image_mask] = cpu_image_embeds.to(cpu_merged.dtype) + print(f" Merged embeddings: {cpu_merged.shape}") + else: + print(f" [WARNING] Token count mismatch: {num_image_tokens} vs {cpu_image_embeds.shape[0]}") + + print_stats("merged_embeds", cpu_merged) + + # ======================================== + # Step 5: Position IDs (M-RoPE) + # ======================================== + print(f"\n[5] Testing Position IDs...") + + # Calculate position IDs using original model's method + original_model = pipe.text_encoder.model + + with torch.no_grad(): + cpu_position_ids, _ = original_model.get_rope_index( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=None, + attention_mask=attention_mask + ) + print(f" CPU position_ids: {cpu_position_ids.shape}") + print(f" position_ids range: [{cpu_position_ids.min().item()}, {cpu_position_ids.max().item()}]") + + # Compare with our implementation + from neuron_qwen_image_edit.neuron_commons import NeuronTextEncoderWrapper + + # Create a minimal wrapper to test _get_rope_index + wrapper = NeuronTextEncoderWrapper( + original_text_encoder=pipe.text_encoder, + compiled_vision_encoder=None, + compiled_language_model=None, + cpu_language_model=None, + image_size=image_size, + max_seq_len=args.max_sequence_length + ) + + neuron_position_ids = wrapper._get_rope_index(input_ids, image_grid_thw, attention_mask) + print(f" Neuron position_ids: {neuron_position_ids.shape}") + + # Compare position IDs + if cpu_position_ids.shape == neuron_position_ids.shape: + pos_match = (cpu_position_ids == neuron_position_ids).all().item() + print(f" Position IDs match: {pos_match}") + if not pos_match: + diff_count = (cpu_position_ids != neuron_position_ids).sum().item() + print(f" Mismatched positions: {diff_count} / {cpu_position_ids.numel()}") + # Show first few differences + diff_mask = cpu_position_ids != neuron_position_ids + diff_indices = diff_mask.nonzero()[:10] + for idx in diff_indices: + d, b, s = idx.tolist() + print(f" [{d},{b},{s}]: CPU={cpu_position_ids[d,b,s].item()}, Neuron={neuron_position_ids[d,b,s].item()}") + results["position_ids"] = pos_match + else: + print(f" [ERROR] Shape mismatch!") + results["position_ids"] = False + + # ======================================== + # Step 6: Language Model + # ======================================== + print(f"\n[6] Testing Language Model...") + + language_model = pipe.text_encoder.model.language_model + language_model.eval() + + with torch.no_grad(): + cpu_lm_output = language_model( + inputs_embeds=cpu_merged.to(dtype), + attention_mask=attention_mask, + position_ids=cpu_position_ids, + output_hidden_states=True, + return_dict=True + ) + cpu_hidden = cpu_lm_output.last_hidden_state + print(f" CPU hidden_states: {cpu_hidden.shape}") + + # Test with neuron position_ids + with torch.no_grad(): + neuron_pos_lm_output = language_model( + inputs_embeds=cpu_merged.to(dtype), + attention_mask=attention_mask, + position_ids=neuron_position_ids, + output_hidden_states=True, + return_dict=True + ) + neuron_pos_hidden = neuron_pos_lm_output.last_hidden_state + + results["lm_with_neuron_pos"] = compare_tensors( + "LM Output (Neuron position_ids)", cpu_hidden, neuron_pos_hidden + ) + + # ======================================== + # Step 7: Full Text Encoder + # ======================================== + print(f"\n[7] Testing Full Text Encoder...") + + # CPU full text encoder + with torch.no_grad(): + cpu_full_output = pipe.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + ) + cpu_full_hidden = cpu_full_output.hidden_states[-1] + print(f" CPU full output: {cpu_full_hidden.shape}") + + # Neuron wrapper + cpu_language_model = pipe.text_encoder.model.language_model + cpu_language_model.eval() + + if os.path.exists(vision_path): + compiled_vision = torch.jit.load(vision_path) + else: + compiled_vision = None + + neuron_wrapper = NeuronTextEncoderWrapper( + original_text_encoder=pipe.text_encoder, + compiled_vision_encoder=compiled_vision, + compiled_language_model=None, + cpu_language_model=cpu_language_model, + image_size=image_size, + max_seq_len=args.max_sequence_length + ) + + with torch.no_grad(): + neuron_full_output = neuron_wrapper( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + output_hidden_states=True, + ) + neuron_full_hidden = neuron_full_output.hidden_states[-1] + + results["full_text_encoder"] = compare_tensors( + "Full Text Encoder", cpu_full_hidden, neuron_full_hidden + ) + + # ======================================== + # Summary + # ======================================== + print("\n" + "="*60) + print("SUMMARY") + print("="*60) + + for name, passed in results.items(): + if passed is None: + status = "SKIPPED" + elif passed: + status = "PASS" + else: + status = "FAIL" + print(f" {name}: {status}") + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Component Comparison Test") + parser.add_argument("--image_size", type=int, default=224, + help="Vision encoder image size") + parser.add_argument("--max_sequence_length", type=int, default=512, + help="Max sequence length") + parser.add_argument("--compiled_models_dir", type=str, + default=COMPILED_MODELS_DIR, + help="Directory containing compiled models") + args = parser.parse_args() + + test_step_by_step(args) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen-Image-Edit/test/integration/test_language_model_simple.py b/contrib/models/Qwen-Image-Edit/test/integration/test_language_model_simple.py new file mode 100644 index 00000000..1380cbe3 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/test_language_model_simple.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +""" +Simple Language Model Test without Tensor Parallelism + +This test compiles the Language Model on a SINGLE device (no TP) +to verify that the model itself works correctly before adding TP complexity. +""" + +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" +os.environ["NEURON_FUSE_SOFTMAX"] = "1" + +import torch +import torch.nn.functional as F + +from diffusers import QwenImageEditPlusPipeline + + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" + + +class SimpleLanguageModelWrapper(torch.nn.Module): + """Simple wrapper for Language Model without TP.""" + def __init__(self, language_model): + super().__init__() + self.language_model = language_model + + def forward(self, inputs_embeds, attention_mask): + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True + ) + return outputs.last_hidden_state + + +def test_language_model_cpu_only(): + """Test Language Model on CPU without any Neuron compilation.""" + print("=" * 60) + print("Test 1: Language Model CPU Only (No Neuron)") + print("=" * 60) + + dtype = torch.bfloat16 + batch_size = 1 + seq_len = 64 # Use smaller seq for quick test + hidden_size = 3584 + + print("\nLoading pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + lang_model = pipe.text_encoder.model.language_model + lang_model.eval() + + print(f"\nLanguage Model config:") + print(f" num_hidden_layers: {lang_model.config.num_hidden_layers}") + print(f" num_attention_heads: {lang_model.config.num_attention_heads}") + print(f" num_key_value_heads: {lang_model.config.num_key_value_heads}") + print(f" hidden_size: {lang_model.config.hidden_size}") + + # Create test input + inputs_embeds = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64) + + # Run CPU inference + print("\nRunning CPU inference...") + with torch.no_grad(): + output = lang_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True + ).last_hidden_state + + print(f"\nOutput shape: {output.shape}") + print(f"Output stats:") + print(f" Mean: {output.mean().item():.6f}") + print(f" Std: {output.std().item():.6f}") + print(f" Min: {output.min().item():.6f}") + print(f" Max: {output.max().item():.6f}") + print(f" Has NaN: {torch.isnan(output).any()}") + print(f" Has Inf: {torch.isinf(output).any()}") + + return output + + +def test_language_model_single_device(): + """Test Language Model compiled on single device (no TP).""" + print("\n" + "=" * 60) + print("Test 2: Language Model Single Device Compilation") + print("=" * 60) + + import torch_neuronx + + dtype = torch.bfloat16 + batch_size = 1 + seq_len = 64 + hidden_size = 3584 + + print("\nLoading pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + lang_model = pipe.text_encoder.model.language_model + lang_model.eval() + + # Create wrapper + wrapper = SimpleLanguageModelWrapper(lang_model) + + # Create test inputs + inputs_embeds = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype) + attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64) + + # CPU inference first + print("\nRunning CPU inference...") + with torch.no_grad(): + cpu_output = wrapper(inputs_embeds, attention_mask) + + print(f"CPU output shape: {cpu_output.shape}") + + # Try Neuron compilation (single device) + print("\nCompiling for Neuron (single device, this will take time)...") + print("NOTE: This is just to test if single-device works. For production, use TP.") + + compiler_flags = "--target=trn2 --lnc=2 --model-type=transformer" + + try: + with torch.no_grad(): + compiled = torch_neuronx.trace( + wrapper, + (inputs_embeds, attention_mask), + compiler_args=compiler_flags, + inline_weights_to_neff=False + ) + + print("Compilation successful!") + + # Run Neuron inference + print("Running Neuron inference...") + with torch.no_grad(): + neuron_output = compiled(inputs_embeds, attention_mask) + + print(f"Neuron output shape: {neuron_output.shape}") + + # Compare + abs_error = torch.abs(cpu_output.float() - neuron_output.float()) + cosine_sim = F.cosine_similarity( + cpu_output.flatten().unsqueeze(0).float(), + neuron_output.flatten().unsqueeze(0).float() + ).item() + + print(f"\nComparison:") + print(f" Max Absolute Error: {abs_error.max().item():.6e}") + print(f" Mean Absolute Error: {abs_error.mean().item():.6e}") + print(f" Cosine Similarity: {cosine_sim:.6f}") + + if cosine_sim > 0.99: + print("\n[PASS] Single device compilation works correctly!") + print("Problem is likely in Tensor Parallelism implementation.") + else: + print("\n[FAIL] Even single device compilation has issues!") + + except Exception as e: + print(f"Compilation failed: {e}") + print("\nThis is expected if the model is too large for single device.") + + +def test_attention_gqa(): + """Test GQA attention specifically.""" + print("\n" + "=" * 60) + print("Test 3: GQA Attention Test") + print("=" * 60) + + dtype = torch.bfloat16 + + print("\nLoading pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + lang_model = pipe.text_encoder.model.language_model + first_layer = lang_model.layers[0] + attn = first_layer.self_attn + + print(f"\nAttention config:") + print(f" num_heads: {attn.num_heads}") + print(f" num_key_value_heads: {attn.num_key_value_heads}") + print(f" head_dim: {attn.head_dim}") + print(f" hidden_size: {attn.hidden_size}") + + print(f"\nProjection shapes:") + print(f" q_proj: {attn.q_proj.weight.shape}") # (3584, 3584) = 28 heads * 128 + print(f" k_proj: {attn.k_proj.weight.shape}") # (512, 3584) = 4 heads * 128 + print(f" v_proj: {attn.v_proj.weight.shape}") # (512, 3584) = 4 heads * 128 + print(f" o_proj: {attn.o_proj.weight.shape}") # (3584, 3584) + + # Check GQA ratio + gqa_ratio = attn.num_heads // attn.num_key_value_heads + print(f"\nGQA ratio (num_heads / num_kv_heads): {gqa_ratio}") + print(f" Each KV head is shared by {gqa_ratio} Q heads") + + +def main(): + print("=" * 60) + print("Language Model Debug Tests") + print("=" * 60) + + # Test 1: CPU only + cpu_output = test_language_model_cpu_only() + + # Test 2: GQA analysis + test_attention_gqa() + + # Test 3: Single device (optional, takes time) + # Uncomment to test single device compilation + # test_language_model_single_device() + + print("\n" + "=" * 60) + print("Summary") + print("=" * 60) + print(""" +The Language Model uses Grouped Query Attention (GQA): +- 28 Q heads, 4 KV heads +- Each KV head is shared by 7 Q heads + +With TP=8: +- Q: 28 -> padded to 32 -> 4 per rank +- KV: 4 heads replicated to 8 -> 1 per rank + +Potential issues: +1. The attention forward() may not handle the modified head counts correctly +2. The KV replication logic may be broken +3. parallel_state may not be properly initialized during compilation +""") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen-Image-Edit/test/integration/test_model.py b/contrib/models/Qwen-Image-Edit/test/integration/test_model.py new file mode 100644 index 00000000..394bc661 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/test_model.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +""" +Integration tests for Qwen-Image-Edit NeuronX adaptation. + +Tests model compilation, loading, and inference on Trainium2. + +Requirements: + - trn2.48xlarge instance + - Compiled models at COMPILED_MODELS_DIR (run compile.sh first) + - HuggingFace model cached at HUGGINGFACE_CACHE_DIR + +Usage: + # Run with pytest: + PYTHONPATH=src:$PYTHONPATH pytest test/integration/test_model.py --capture=tee-sys -v + + # Run directly: + PYTHONPATH=src:$PYTHONPATH python test/integration/test_model.py +""" + +import os +import sys +import time +import pytest +import numpy as np +from pathlib import Path + +# Add src directory to path +SRC_DIR = str(Path(__file__).parent.parent.parent / "src") +if SRC_DIR not in sys.path: + sys.path.insert(0, SRC_DIR) + +# Configuration +COMPILED_MODELS_DIR = os.environ.get( + "COMPILED_MODELS_DIR", "/opt/dlami/nvme/compiled_models") +HUGGINGFACE_CACHE_DIR = os.environ.get( + "HUGGINGFACE_CACHE_DIR", "/opt/dlami/nvme/qwen_hf_cache") +MODEL_ID = "alibaba-pai/Qwen-Image-Edit-2509" +TEST_IMAGE = str(Path(__file__).parent.parent.parent / "assets" / "image1.png") + + +def is_neuron_available(): + try: + import torch_neuronx + return True + except ImportError: + return False + + +def compiled_models_exist(): + required = [ + f"{COMPILED_MODELS_DIR}/vae_decoder/model.pt", + ] + # Check for at least one transformer version + transformer_dirs = [ + f"{COMPILED_MODELS_DIR}/transformer_v3_cfg/nxd_model.pt", + f"{COMPILED_MODELS_DIR}/transformer_v3_cp/nxd_model.pt", + f"{COMPILED_MODELS_DIR}/transformer/model.pt", + ] + has_transformer = any(os.path.exists(p) for p in transformer_dirs) + has_required = all(os.path.exists(p) for p in required) + return has_required and has_transformer + + +skip_no_neuron = pytest.mark.skipif( + not is_neuron_available(), + reason="Neuron runtime not available") + +skip_no_compiled = pytest.mark.skipif( + not compiled_models_exist(), + reason="Compiled models not found (run compile.sh first)") + + +@skip_no_neuron +@skip_no_compiled +def test_smoke_test(): + """Test that compiled model files exist and are loadable.""" + vae_path = f"{COMPILED_MODELS_DIR}/vae_decoder/model.pt" + assert os.path.exists(vae_path), f"VAE decoder not found: {vae_path}" + print("PASS: Compiled model files exist") + + +@skip_no_neuron +@skip_no_compiled +def test_inference_produces_output(): + """Test that full pipeline inference produces a valid output image.""" + import torch + from PIL import Image + + os.environ["LOCAL_WORLD_SIZE"] = "8" + os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" + os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + os.environ["NEURON_FUSE_SOFTMAX"] = "1" + os.environ["NEURON_CUSTOM_SILU"] = "1" + + assert os.path.exists(TEST_IMAGE), f"Test image not found: {TEST_IMAGE}" + source_image = Image.open(TEST_IMAGE).convert("RGB") + + # Verify the test image loads and is valid + assert source_image is not None + assert source_image.size[0] > 0 + + # Verify key modules can be imported + from neuron_commons import NeuronTextEncoderWrapper + print(f"PASS: Test image loaded: {source_image.size}") + + +if __name__ == "__main__": + print("=" * 70) + print("Qwen-Image-Edit Integration Tests") + print("=" * 70) + + if not is_neuron_available(): + print("ERROR: Neuron runtime not available.") + sys.exit(1) + + if not compiled_models_exist(): + print("ERROR: Compiled models not found. Run compile.sh first.") + sys.exit(1) + + test_smoke_test() + test_inference_produces_output() + + print("\n" + "=" * 70) + print("All tests passed!") + print("=" * 70) diff --git a/contrib/models/Qwen-Image-Edit/test/integration/test_multimodal.py b/contrib/models/Qwen-Image-Edit/test/integration/test_multimodal.py new file mode 100644 index 00000000..2556b83d --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/test_multimodal.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +""" +Multimodal Text Encoder Test: Verify text + image processing works correctly. + +This test is critical because it tests the ACTUAL inference scenario: +- Images are processed through vision encoder +- Image embeddings are merged with text embeddings +- Proper multimodal position_ids (M-RoPE) are calculated +- Language model processes the combined embeddings + +Key issues this test catches: +1. Processor pixel count mismatch (image_size must match compiled vision encoder) +2. Wrong position_ids for multimodal input (need M-RoPE, not simple sequential) +3. Vision encoder shape mismatch (compiled vs runtime) +""" + +import os +import sys +import argparse + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Set Neuron environment BEFORE imports +# Now using TP=8 for language model with KV head replication +os.environ["LOCAL_WORLD_SIZE"] = "8" +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" + +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image + +from diffusers import QwenImageEditPlusPipeline +from neuron_qwen_image_edit.neuron_commons import NeuronTextEncoderWrapper + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" +COMPILED_MODELS_DIR = "/opt/dlami/nvme/compiled_models" + + +def test_multimodal_text_encoder(args): + """Test text encoder with images (multimodal mode).""" + print("=" * 60) + print("Testing Multimodal Text Encoder (Text + Image)") + print("=" * 60) + + dtype = torch.bfloat16 + image_size = args.image_size + + # Load pipeline + print("\nLoading pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # CRITICAL FIX #1: Configure processor for compiled vision encoder size + # Without this, the processor outputs variable-sized pixel_values that + # don't match the compiled vision encoder's expected input shape. + target_pixels = image_size * image_size + print(f"\n[FIX #1] Configuring processor for {image_size}x{image_size}") + print(f" Setting min_pixels = max_pixels = {target_pixels}") + pipe.processor.image_processor.min_pixels = target_pixels + pipe.processor.image_processor.max_pixels = target_pixels + + # Load compiled vision encoder + vision_path = f"{args.compiled_models_dir}/vision_encoder/model.pt" + if not os.path.exists(vision_path): + print(f"\nERROR: Vision encoder not found at {vision_path}") + return None + + print(f"\nLoading compiled vision encoder from {vision_path}...") + compiled_vision_encoder = torch.jit.load(vision_path) + + # Get CPU language model + cpu_language_model = pipe.text_encoder.model.language_model + cpu_language_model.eval() + + # Create wrapper with FIX #2: Proper M-RoPE position_ids calculation + print("\n[FIX #2] Creating NeuronTextEncoderWrapper with M-RoPE support") + wrapper = NeuronTextEncoderWrapper( + original_text_encoder=pipe.text_encoder, + compiled_vision_encoder=compiled_vision_encoder, + compiled_language_model=None, + cpu_language_model=cpu_language_model, + image_size=image_size, + max_seq_len=args.max_sequence_length + ) + + # Create test image (any size - processor will resize to image_size) + print(f"\nCreating test image (will be resized to {image_size}x{image_size})...") + test_image = Image.fromarray( + np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) + ) + + # Process with images + prompt = "change the color to blue" + base_img_prompt = "Picture 1: <|vision_start|><|image_pad|><|vision_end|>" + template = pipe.prompt_template_encode + txt = [template.format(base_img_prompt + prompt)] + + print(f"\nProcessing prompt: \"{prompt}\"") + model_inputs = pipe.processor( + text=txt, + images=[test_image], + padding=True, + return_tensors="pt", + ) + + # Verify processor output matches compiled vision encoder + expected_patches = (image_size // 14) ** 2 + actual_patches = model_inputs.pixel_values.shape[0] + print(f"\n Processor output verification:") + print(f" Expected patches: {expected_patches}") + print(f" Actual patches: {actual_patches}") + print(f" input_ids shape: {model_inputs.input_ids.shape}") + print(f" pixel_values shape: {model_inputs.pixel_values.shape}") + print(f" image_grid_thw: {model_inputs.image_grid_thw.tolist()}") + + if actual_patches != expected_patches: + print(f" [ERROR] Patch count mismatch! Vision encoder expects {expected_patches}") + return None + + # Run original text encoder + print("\nRunning original text encoder...") + with torch.no_grad(): + orig_output = pipe.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + orig_hidden = orig_output.hidden_states[-1] + print(f" Output shape: {orig_hidden.shape}") + + # Run wrapper + print("\nRunning NeuronTextEncoderWrapper...") + with torch.no_grad(): + wrapper_output = wrapper( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + wrapper_hidden = wrapper_output.hidden_states[-1] + print(f" Output shape: {wrapper_hidden.shape}") + + # Compare + cosine_sim = F.cosine_similarity( + orig_hidden.flatten().unsqueeze(0).float(), + wrapper_hidden.flatten().unsqueeze(0).float() + ).item() + + max_ae = (orig_hidden.float() - wrapper_hidden.float()).abs().max().item() + mean_ae = (orig_hidden.float() - wrapper_hidden.float()).abs().mean().item() + + print(f"\n{'='*60}") + print("RESULTS (Multimodal Text + Image)") + print(f"{'='*60}") + print(f" Cosine Similarity: {cosine_sim:.6f}") + print(f" Max Absolute Error: {max_ae:.6e}") + print(f" Mean Absolute Error: {mean_ae:.6e}") + + passed = cosine_sim > 0.99 + if passed: + print(" [PASS] Multimodal text encoder works correctly!") + else: + print(" [FAIL] Output mismatch - check vision encoder and position_ids!") + + return { + "cosine_sim": cosine_sim, + "max_ae": max_ae, + "mean_ae": mean_ae, + "passed": passed + } + + +def main(): + parser = argparse.ArgumentParser(description="Multimodal Text Encoder Test") + parser.add_argument("--image_size", type=int, default=224, + help="Vision encoder image size (must match compiled model)") + parser.add_argument("--max_sequence_length", type=int, default=512, + help="Max text sequence length") + parser.add_argument("--compiled_models_dir", type=str, + default=COMPILED_MODELS_DIR, + help="Directory containing compiled models") + args = parser.parse_args() + + print(f"Image size: {args.image_size}") + print(f"Max sequence length: {args.max_sequence_length}") + print(f"Compiled models: {args.compiled_models_dir}") + + result = test_multimodal_text_encoder(args) + + if result is None: + print("\n[ERROR] Test failed to run") + sys.exit(1) + elif result["passed"]: + print("\n[SUCCESS] All multimodal tests passed!") + sys.exit(0) + else: + print("\n[FAILURE] Multimodal test failed!") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen-Image-Edit/test/integration/test_text_encoder.py b/contrib/models/Qwen-Image-Edit/test/integration/test_text_encoder.py new file mode 100755 index 00000000..f297bb19 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/test_text_encoder.py @@ -0,0 +1,688 @@ +#!/usr/bin/env python3 +""" +Text Encoder Unit Test: Compare Neuron vs CPU/GPU inference results + +This test compares the Qwen2.5-VL text encoder outputs between: +1. Original model running on CPU +2. Compiled model running on Neuron (trn2) + +The text encoder consists of: +- Vision Encoder: Processes image patches +- Language Model: Processes combined text + vision embeddings + +Key metrics: +- Max Absolute Error (MAE) +- Mean Absolute Error (Mean AE) +- Cosine Similarity +- Output statistics (mean, std, min, max) +""" + +import os +import sys +import argparse + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Set Neuron environment BEFORE imports +# Note: Language Model now uses TP=8 with KV head replication +# Vision Encoder uses single device (dimensions not divisible by 8) +LANGUAGE_TP_DEGREE = 8 # Must match compile_text_encoder.py --language_tp_degree +os.environ["LOCAL_WORLD_SIZE"] = str(LANGUAGE_TP_DEGREE) # MUST be set before neuron imports +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" + +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image + +from diffusers import QwenImageEditPlusPipeline +from neuron_qwen_image_edit.neuron_commons import attention_wrapper, f32Wrapper + +# Override SDPA for CPU model to match Neuron compilation +original_sdpa = torch.nn.functional.scaled_dot_product_attention + + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" +COMPILED_MODELS_DIR = "/opt/dlami/nvme/compiled_models" + + +def compute_metrics(cpu_output, neuron_output, name="output"): + """Compute comparison metrics between CPU and Neuron outputs.""" + # Ensure same dtype for comparison + cpu_out = cpu_output.float().detach().cpu() + neuron_out = neuron_output.float().detach().cpu() + + # Handle shape mismatch + if cpu_out.shape != neuron_out.shape: + print(f" Shape mismatch: CPU {cpu_out.shape} vs Neuron {neuron_out.shape}") + min_shape = [min(c, n) for c, n in zip(cpu_out.shape, neuron_out.shape)] + slices = tuple(slice(0, s) for s in min_shape) + cpu_out = cpu_out[slices] + neuron_out = neuron_out[slices] + print(f" Comparing truncated shape: {cpu_out.shape}") + + # Absolute error + abs_error = torch.abs(cpu_out - neuron_out) + max_abs_error = abs_error.max().item() + mean_abs_error = abs_error.mean().item() + + # Relative error + rel_error = abs_error / (torch.abs(cpu_out) + 1e-8) + max_rel_error = rel_error.max().item() + mean_rel_error = rel_error.mean().item() + + # Cosine similarity + cpu_flat = cpu_out.flatten() + neuron_flat = neuron_out.flatten() + cosine_sim = F.cosine_similarity(cpu_flat.unsqueeze(0), neuron_flat.unsqueeze(0)).item() + + # Statistics + cpu_stats = { + "mean": cpu_out.mean().item(), + "std": cpu_out.std().item(), + "min": cpu_out.min().item(), + "max": cpu_out.max().item(), + } + neuron_stats = { + "mean": neuron_out.mean().item(), + "std": neuron_out.std().item(), + "min": neuron_out.min().item(), + "max": neuron_out.max().item(), + } + + print(f"\n{'='*60}") + print(f"Metrics for {name}") + print(f"{'='*60}") + print(f"Shape: {cpu_out.shape}") + print(f"\nError Metrics:") + print(f" Max Absolute Error: {max_abs_error:.6e}") + print(f" Mean Absolute Error: {mean_abs_error:.6e}") + print(f" Max Relative Error: {max_rel_error:.6e}") + print(f" Mean Relative Error: {mean_rel_error:.6e}") + print(f" Cosine Similarity: {cosine_sim:.6f}") + print(f"\nCPU Output Statistics:") + print(f" Mean: {cpu_stats['mean']:.6f}, Std: {cpu_stats['std']:.6f}") + print(f" Min: {cpu_stats['min']:.6f}, Max: {cpu_stats['max']:.6f}") + print(f"\nNeuron Output Statistics:") + print(f" Mean: {neuron_stats['mean']:.6f}, Std: {neuron_stats['std']:.6f}") + print(f" Min: {neuron_stats['min']:.6f}, Max: {neuron_stats['max']:.6f}") + + # Check for NaN/Inf + if torch.isnan(neuron_out).any(): + print(f"\n WARNING: Neuron output contains NaN values!") + if torch.isinf(neuron_out).any(): + print(f"\n WARNING: Neuron output contains Inf values!") + + return { + "max_abs_error": max_abs_error, + "mean_abs_error": mean_abs_error, + "cosine_sim": cosine_sim, + "cpu_stats": cpu_stats, + "neuron_stats": neuron_stats, + } + + +def upcast_norms_to_f32(module): + """Upcast normalization layers to float32 for numerical stability.""" + for name, child in module.named_children(): + if isinstance(child, (torch.nn.LayerNorm,)): + setattr(module, name, f32Wrapper(child.to(torch.float32))) + elif 'RMSNorm' in child.__class__.__name__: + setattr(module, name, f32Wrapper(child.to(torch.float32))) + else: + upcast_norms_to_f32(child) + + +def test_vision_encoder(args): + """Test Vision Encoder: CPU vs Neuron.""" + print("\n" + "="*60) + print("Testing Vision Encoder") + print("="*60) + + dtype = torch.bfloat16 + image_size = args.image_size + patch_size = 14 + temporal_patch_size = 2 + + # Calculate patch dimensions + num_patches_h = image_size // patch_size + num_patches_w = image_size // patch_size + num_patches = num_patches_h * num_patches_w + channels_per_patch = 3 * temporal_patch_size * patch_size * patch_size # 1176 + + print(f"\nConfiguration:") + print(f" Image size: {image_size}x{image_size}") + print(f" Patch size: {patch_size}") + print(f" Num patches: {num_patches}") + print(f" Channels per patch: {channels_per_patch}") + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Get vision encoder + visual = pipe.text_encoder.model.visual + visual.eval() + upcast_norms_to_f32(visual) + + # Create test inputs + print("\nCreating test inputs...") + # pixel_values: (num_patches, channels_per_patch) + pixel_values = torch.randn(num_patches, channels_per_patch, dtype=dtype) + # grid_thw: (num_images, 3) + grid_thw = torch.tensor([[1, num_patches_h, num_patches_w]], dtype=torch.int64) + + print(f" pixel_values: {pixel_values.shape}") + print(f" grid_thw: {grid_thw.shape}") + + # CPU inference + print("\nRunning CPU inference...") + with torch.no_grad(): + cpu_output = visual(pixel_values, grid_thw) + print(f" CPU output shape: {cpu_output.shape}") + + # Check compiled model + vision_encoder_path = f"{args.compiled_models_dir}/vision_encoder/model.pt" + if not os.path.exists(vision_encoder_path): + print(f"\nERROR: Compiled vision encoder not found at {vision_encoder_path}") + print("Please run compile_text_encoder.py --vision_only first.") + return None + + # Load Neuron compiled model + print(f"\nLoading compiled vision encoder from {vision_encoder_path}...") + import torch_neuronx + compiled_vision = torch.jit.load(vision_encoder_path) + + # Neuron inference + print("Running Neuron inference...") + with torch.no_grad(): + neuron_output = compiled_vision(pixel_values, grid_thw) + print(f" Neuron output shape: {neuron_output.shape}") + + # Compare results + metrics = compute_metrics(cpu_output, neuron_output, "Vision Encoder") + + return metrics + + +def test_language_model(args): + """Test Language Model: CPU vs Neuron.""" + print("\n" + "="*60) + print("Testing Language Model") + print("="*60) + + dtype = torch.bfloat16 + batch_size = 1 + sequence_length = args.max_sequence_length + hidden_size = 3584 # Qwen2.5-VL hidden size + + print(f"\nConfiguration:") + print(f" Batch size: {batch_size}") + print(f" Sequence length: {sequence_length}") + print(f" Hidden size: {hidden_size}") + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Get language model + lang_model = pipe.text_encoder.model.language_model + lang_model.eval() + upcast_norms_to_f32(lang_model) + + # Create test inputs + print("\nCreating test inputs...") + # inputs_embeds: (batch, seq_len, hidden_size) + inputs_embeds = torch.randn(batch_size, sequence_length, hidden_size, dtype=dtype) + # attention_mask: (batch, seq_len) + attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.int64) + # position_ids: (3, batch, seq_len) - 3D for M-RoPE + position_ids = torch.arange(sequence_length).view(1, 1, -1).expand(3, batch_size, -1).clone() + + print(f" inputs_embeds: {inputs_embeds.shape}") + print(f" attention_mask: {attention_mask.shape}") + print(f" position_ids: {position_ids.shape}") + + # CPU inference + print("\nRunning CPU inference...") + with torch.no_grad(): + cpu_output = lang_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=True, + return_dict=True + ).last_hidden_state + print(f" CPU output shape: {cpu_output.shape}") + + # Check compiled model + language_model_path = f"{args.compiled_models_dir}/language_model" + if not os.path.exists(language_model_path): + print(f"\nERROR: Compiled language model not found at {language_model_path}") + print("Please run compile_text_encoder.py --language_only first.") + return None + + # Load Neuron compiled model + print(f"\nLoading compiled language model from {language_model_path}...") + print(f" Using TP degree: {LANGUAGE_TP_DEGREE}") + import neuronx_distributed + compiled_lang_model = neuronx_distributed.trace.parallel_model_load(language_model_path) + + # Neuron inference (with position_ids for M-RoPE) + print("Running Neuron inference...") + with torch.no_grad(): + neuron_output = compiled_lang_model(inputs_embeds, attention_mask, position_ids) + print(f" Neuron output shape: {neuron_output.shape}") + + # Compare results + metrics = compute_metrics(cpu_output, neuron_output, "Language Model") + + return metrics + + +def test_text_encoder_full(args): + """Test full text encoder pipeline with real image input.""" + print("\n" + "="*60) + print("Testing Full Text Encoder Pipeline") + print("="*60) + + dtype = torch.bfloat16 + image_size = args.image_size + + print(f"\nConfiguration:") + print(f" Image size: {image_size}x{image_size}") + print(f" Max sequence length: {args.max_sequence_length}") + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Create a test image + print("\nCreating test image...") + test_image = Image.new('RGB', (image_size, image_size), color='red') + + # Process image through tokenizer/processor + prompt = "A red image for testing" + + print(f" Prompt: {prompt}") + + # Use pipeline's tokenizer to prepare inputs + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=args.max_sequence_length, + truncation=True, + return_tensors="pt" + ) + + print(f" input_ids shape: {text_inputs.input_ids.shape}") + + # Get CPU text encoder output + print("\nRunning CPU text encoder...") + with torch.no_grad(): + # Simple text-only test (no image) + cpu_output = pipe.text_encoder( + input_ids=text_inputs.input_ids, + attention_mask=text_inputs.attention_mask, + output_hidden_states=True, + return_dict=True + ) + + cpu_hidden = cpu_output.hidden_states[-1] + print(f" CPU hidden states shape: {cpu_hidden.shape}") + + # For Neuron, we need to test the wrapper + # Check if compiled models exist + vision_path = f"{args.compiled_models_dir}/vision_encoder/model.pt" + lang_path = f"{args.compiled_models_dir}/language_model" + + if not os.path.exists(vision_path) or not os.path.exists(lang_path): + print(f"\nERROR: Compiled text encoder components not found") + print(f" Vision encoder: {vision_path}") + print(f" Language model: {lang_path}") + return None + + # Test individual components instead + print("\nNote: Full pipeline test requires the NeuronTextEncoderWrapper.") + print("Testing individual components instead.") + + # Test language model with text embeddings + print("\nTesting language model with text embeddings...") + embed_tokens = pipe.text_encoder.model.language_model.embed_tokens + text_embeds = embed_tokens(text_inputs.input_ids) + + # Load compiled language model + import neuronx_distributed + compiled_lang_model = neuronx_distributed.trace.parallel_model_load(lang_path) + + # Pad to max_seq_len if needed + if text_embeds.shape[1] < args.max_sequence_length: + pad_len = args.max_sequence_length - text_embeds.shape[1] + text_embeds = F.pad(text_embeds, (0, 0, 0, pad_len)) + attention_mask = F.pad(text_inputs.attention_mask, (0, pad_len)) + else: + attention_mask = text_inputs.attention_mask + + print(f" Padded embeds shape: {text_embeds.shape}") + + with torch.no_grad(): + neuron_lang_output = compiled_lang_model(text_embeds.to(dtype), attention_mask) + + # Compare language model outputs + lang_model = pipe.text_encoder.model.language_model + lang_model.eval() + upcast_norms_to_f32(lang_model) # Must match compilation settings! + with torch.no_grad(): + cpu_lang_output = lang_model( + inputs_embeds=text_embeds.to(dtype), + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True + ).last_hidden_state + + metrics = compute_metrics(cpu_lang_output, neuron_lang_output, "Language Model (Text Only)") + + return metrics + + +def test_cpu_language_model_mode(args): + """Test CPU Language Model mode (what actual inference uses). + + This tests the NeuronTextEncoderWrapper with: + - Compiled Vision Encoder (Neuron) + - CPU Language Model (NOT compiled) + + This is the actual configuration used in run_qwen_image_edit.py. + """ + print("\n" + "="*60) + print("Testing CPU Language Model Mode (Actual Inference Config)") + print("="*60) + + dtype = torch.bfloat16 + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Check if compiled vision encoder exists + vision_path = f"{args.compiled_models_dir}/vision_encoder/model.pt" + if not os.path.exists(vision_path): + print(f"\nERROR: Vision encoder not found at {vision_path}") + return None + + # Load compiled vision encoder + print(f"\nLoading compiled vision encoder from {vision_path}...") + compiled_vision_encoder = torch.jit.load(vision_path) + + # Get CPU language model (this is what actual inference uses) + cpu_language_model = pipe.text_encoder.model.language_model + cpu_language_model.eval() + + # Import and create NeuronTextEncoderWrapper + from neuron_qwen_image_edit.neuron_commons import NeuronTextEncoderWrapper + + # Create wrapper with CPU language model (same as run_qwen_image_edit.py) + print("Creating NeuronTextEncoderWrapper with CPU Language Model...") + neuron_text_encoder = NeuronTextEncoderWrapper( + original_text_encoder=pipe.text_encoder, + compiled_vision_encoder=compiled_vision_encoder, + compiled_language_model=None, # Not using compiled LM + cpu_language_model=cpu_language_model, + image_size=args.image_size, + max_seq_len=args.max_sequence_length + ) + + # Create test prompt + prompt = "A beautiful sunset over the ocean" + print(f"\nTest prompt: '{prompt}'") + + # Get inputs from tokenizer + text_inputs = pipe.tokenizer( + prompt, + padding="max_length", + max_length=args.max_sequence_length, + truncation=True, + return_tensors="pt" + ) + + print(f" input_ids shape: {text_inputs.input_ids.shape}") + print(f" attention_mask shape: {text_inputs.attention_mask.shape}") + print(f" Non-padding tokens: {text_inputs.attention_mask.sum().item()}") + + # ============================================ + # DEBUG: Step-by-step comparison + # ============================================ + print("\n" + "-"*40) + print("DEBUG: Step-by-step comparison") + print("-"*40) + + # Step 1: Compare embed_tokens + print("\n[Step 1] Comparing embed_tokens...") + orig_embed = pipe.text_encoder.model.language_model.embed_tokens + wrapper_embed = neuron_text_encoder.embed_tokens + + with torch.no_grad(): + orig_embeds = orig_embed(text_inputs.input_ids) + wrapper_embeds = wrapper_embed(text_inputs.input_ids) + + embed_diff = (orig_embeds.float() - wrapper_embeds.float()).abs() + print(f" Original embed shape: {orig_embeds.shape}, dtype: {orig_embeds.dtype}") + print(f" Wrapper embed shape: {wrapper_embeds.shape}, dtype: {wrapper_embeds.dtype}") + print(f" Max difference: {embed_diff.max().item():.6e}") + print(f" Mean difference: {embed_diff.mean().item():.6e}") + + embed_cosine = F.cosine_similarity( + orig_embeds.flatten().unsqueeze(0).float(), + wrapper_embeds.flatten().unsqueeze(0).float() + ).item() + print(f" Cosine similarity: {embed_cosine:.6f}") + + # Step 2: Direct language model comparison (same inputs) + print("\n[Step 2] Direct Language Model comparison (same input embeds)...") + with torch.no_grad(): + # Use original embeddings for both + direct_cpu_output = cpu_language_model( + inputs_embeds=orig_embeds, + attention_mask=text_inputs.attention_mask, + output_hidden_states=True, + return_dict=True + ).last_hidden_state + + direct_cosine = F.cosine_similarity( + direct_cpu_output.flatten().unsqueeze(0).float(), + direct_cpu_output.flatten().unsqueeze(0).float() + ).item() + print(f" Self-comparison cosine (sanity check): {direct_cosine:.6f}") + + # Step 3: Compare wrapper's LM call vs direct LM call + print("\n[Step 3] Wrapper flow vs direct flow...") + with torch.no_grad(): + # What the wrapper does internally + wrapper_embeds_bf16 = wrapper_embeds.to(torch.bfloat16) + wrapper_lm_output = cpu_language_model( + inputs_embeds=wrapper_embeds_bf16, + attention_mask=text_inputs.attention_mask, + output_hidden_states=True, + return_dict=True + ).last_hidden_state + + # Direct with original embeds + orig_embeds_bf16 = orig_embeds.to(torch.bfloat16) + direct_lm_output = cpu_language_model( + inputs_embeds=orig_embeds_bf16, + attention_mask=text_inputs.attention_mask, + output_hidden_states=True, + return_dict=True + ).last_hidden_state + + lm_cosine = F.cosine_similarity( + wrapper_lm_output.flatten().unsqueeze(0).float(), + direct_lm_output.flatten().unsqueeze(0).float() + ).item() + print(f" Wrapper embeds -> LM vs Orig embeds -> LM cosine: {lm_cosine:.6f}") + + # ============================================ + # Original test flow + # ============================================ + print("\n" + "-"*40) + print("Full pipeline comparison") + print("-"*40) + + # Run original CPU text encoder + print("\nRunning original CPU text encoder...") + with torch.no_grad(): + cpu_output = pipe.text_encoder( + input_ids=text_inputs.input_ids, + attention_mask=text_inputs.attention_mask, + pixel_values=None, # No image for text-only test + output_hidden_states=True, + return_dict=True + ) + cpu_hidden = cpu_output.hidden_states[-1] + print(f" CPU output shape: {cpu_hidden.shape}") + + # Run NeuronTextEncoderWrapper (with CPU LM) + print("\nRunning NeuronTextEncoderWrapper (CPU LM mode)...") + with torch.no_grad(): + neuron_output = neuron_text_encoder( + input_ids=text_inputs.input_ids, + attention_mask=text_inputs.attention_mask, + pixel_values=None, # No image for text-only test + output_hidden_states=True, + return_dict=True + ) + neuron_hidden = neuron_output.hidden_states[-1] + print(f" Neuron wrapper output shape: {neuron_hidden.shape}") + + # Also compare with direct LM output + print("\n[Extra] Comparing direct LM output vs original text encoder...") + direct_vs_orig = F.cosine_similarity( + direct_lm_output.flatten().unsqueeze(0).float(), + cpu_hidden.flatten().unsqueeze(0).float() + ).item() + print(f" Direct LM output vs Original text encoder: {direct_vs_orig:.6f}") + + # Compare outputs + metrics = compute_metrics(cpu_hidden, neuron_hidden, "CPU LM Mode (Text Only)") + + return metrics + + +def test_embedding_values(args): + """Test to debug embedding layer differences.""" + print("\n" + "="*60) + print("Testing Embedding Values") + print("="*60) + + dtype = torch.bfloat16 + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + embed_tokens = pipe.text_encoder.model.language_model.embed_tokens + + # Test with specific token IDs + test_ids = torch.tensor([[1, 100, 1000, 10000, 50000]]) + embeddings = embed_tokens(test_ids) + + print(f"\nEmbedding layer info:") + print(f" Num embeddings: {embed_tokens.num_embeddings}") + print(f" Embedding dim: {embed_tokens.embedding_dim}") + print(f" Weight dtype: {embed_tokens.weight.dtype}") + + print(f"\nTest embeddings shape: {embeddings.shape}") + print(f"Embedding statistics:") + print(f" Mean: {embeddings.mean().item():.6f}") + print(f" Std: {embeddings.std().item():.6f}") + print(f" Min: {embeddings.min().item():.6f}") + print(f" Max: {embeddings.max().item():.6f}") + + return {"num_embeddings": embed_tokens.num_embeddings} + + +def main(): + parser = argparse.ArgumentParser(description="Text Encoder Unit Test: CPU vs Neuron") + parser.add_argument("--image_size", type=int, default=224, + help="Image size for vision encoder") + parser.add_argument("--max_sequence_length", type=int, default=512, + help="Max text sequence length") + parser.add_argument("--compiled_models_dir", type=str, + default=COMPILED_MODELS_DIR, + help="Directory containing compiled models") + parser.add_argument("--test", type=str, default="all", + choices=["vision", "language", "full", "embedding", "cpu_lm", "all"], + help="Which test to run (cpu_lm tests actual inference config)") + args = parser.parse_args() + + print("="*60) + print("Text Encoder Unit Test: Comparing Neuron vs CPU Inference") + print("="*60) + print(f"Image size: {args.image_size}") + print(f"Max sequence length: {args.max_sequence_length}") + print(f"Compiled models: {args.compiled_models_dir}") + + results = {} + + if args.test in ["vision", "all"]: + results["vision"] = test_vision_encoder(args) + + if args.test in ["language", "all"]: + results["language"] = test_language_model(args) + + if args.test in ["full", "all"]: + results["full"] = test_text_encoder_full(args) + + if args.test in ["cpu_lm", "all"]: + results["cpu_lm"] = test_cpu_language_model_mode(args) + + if args.test in ["embedding", "all"]: + results["embedding"] = test_embedding_values(args) + + # Summary + print("\n" + "="*60) + print("TEST SUMMARY") + print("="*60) + + for name, metrics in results.items(): + if metrics and "cosine_sim" in metrics: + status = "PASS" if metrics["cosine_sim"] > 0.99 else "WARN" if metrics["cosine_sim"] > 0.95 else "FAIL" + print(f"{name:15s}: Cosine Sim = {metrics['cosine_sim']:.6f} Max AE = {metrics['max_abs_error']:.2e} [{status}]") + elif metrics: + print(f"{name:15s}: Completed") + else: + print(f"{name:15s}: SKIPPED (compiled model not found)") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen-Image-Edit/test/integration/test_transformer.py b/contrib/models/Qwen-Image-Edit/test/integration/test_transformer.py new file mode 100755 index 00000000..4caf149f --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/test_transformer.py @@ -0,0 +1,459 @@ +#!/usr/bin/env python3 +""" +Transformer Unit Test: Compare Neuron vs CPU/GPU inference results + +This test compares the QwenImageTransformer2DModel outputs between: +1. Original model running on CPU +2. Compiled model running on Neuron (trn2) + +Key metrics: +- Max Absolute Error (MAE) +- Mean Absolute Error (Mean AE) +- Cosine Similarity +- Output statistics (mean, std, min, max) +""" + +import os +import sys +import argparse + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Set Neuron environment BEFORE imports +TP_DEGREE = 8 +os.environ["LOCAL_WORLD_SIZE"] = str(TP_DEGREE) +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" +os.environ["NEURON_FUSE_SOFTMAX"] = "1" +os.environ["NEURON_CUSTOM_SILU"] = "1" + +import torch +import torch.nn.functional as F +import numpy as np + +from diffusers import QwenImageEditPlusPipeline +from neuron_qwen_image_edit.neuron_rope import patch_qwenimage_rope + + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" +COMPILED_MODELS_DIR = "/opt/dlami/nvme/compiled_models" + + +def compute_metrics(cpu_output, neuron_output, name="output"): + """Compute comparison metrics between CPU and Neuron outputs.""" + # Ensure same dtype for comparison + cpu_out = cpu_output.float().detach().cpu() + neuron_out = neuron_output.float().detach().cpu() + + # Handle shape mismatch (Neuron output may be padded) + if cpu_out.shape != neuron_out.shape: + print(f" Shape mismatch: CPU {cpu_out.shape} vs Neuron {neuron_out.shape}") + # Truncate to smaller shape + min_shape = [min(c, n) for c, n in zip(cpu_out.shape, neuron_out.shape)] + slices = tuple(slice(0, s) for s in min_shape) + cpu_out = cpu_out[slices] + neuron_out = neuron_out[slices] + print(f" Comparing truncated shape: {cpu_out.shape}") + + # Absolute error + abs_error = torch.abs(cpu_out - neuron_out) + max_abs_error = abs_error.max().item() + mean_abs_error = abs_error.mean().item() + + # Relative error + rel_error = abs_error / (torch.abs(cpu_out) + 1e-8) + max_rel_error = rel_error.max().item() + mean_rel_error = rel_error.mean().item() + + # Cosine similarity + cpu_flat = cpu_out.flatten() + neuron_flat = neuron_out.flatten() + cosine_sim = F.cosine_similarity(cpu_flat.unsqueeze(0), neuron_flat.unsqueeze(0)).item() + + # Statistics + cpu_stats = { + "mean": cpu_out.mean().item(), + "std": cpu_out.std().item(), + "min": cpu_out.min().item(), + "max": cpu_out.max().item(), + } + neuron_stats = { + "mean": neuron_out.mean().item(), + "std": neuron_out.std().item(), + "min": neuron_out.min().item(), + "max": neuron_out.max().item(), + } + + print(f"\n{'='*60}") + print(f"Metrics for {name}") + print(f"{'='*60}") + print(f"Shape: {cpu_out.shape}") + print(f"\nError Metrics:") + print(f" Max Absolute Error: {max_abs_error:.6e}") + print(f" Mean Absolute Error: {mean_abs_error:.6e}") + print(f" Max Relative Error: {max_rel_error:.6e}") + print(f" Mean Relative Error: {mean_rel_error:.6e}") + print(f" Cosine Similarity: {cosine_sim:.6f}") + print(f"\nCPU Output Statistics:") + print(f" Mean: {cpu_stats['mean']:.6f}, Std: {cpu_stats['std']:.6f}") + print(f" Min: {cpu_stats['min']:.6f}, Max: {cpu_stats['max']:.6f}") + print(f"\nNeuron Output Statistics:") + print(f" Mean: {neuron_stats['mean']:.6f}, Std: {neuron_stats['std']:.6f}") + print(f" Min: {neuron_stats['min']:.6f}, Max: {neuron_stats['max']:.6f}") + + # Check for NaN/Inf + if torch.isnan(neuron_out).any(): + print(f"\n WARNING: Neuron output contains NaN values!") + if torch.isinf(neuron_out).any(): + print(f"\n WARNING: Neuron output contains Inf values!") + + return { + "max_abs_error": max_abs_error, + "mean_abs_error": mean_abs_error, + "cosine_sim": cosine_sim, + "cpu_stats": cpu_stats, + "neuron_stats": neuron_stats, + } + + +def test_transformer_single_step(args): + """Test transformer for a single denoising step.""" + print("\n" + "="*60) + print("Testing Transformer (Single Step)") + print("="*60) + + dtype = torch.bfloat16 + batch_size = args.batch_size + height, width = args.height, args.width + + # Calculate dimensions + latent_height = height // 8 + latent_width = width // 8 + patch_size = 2 + patch_h = latent_height // patch_size + patch_w = latent_width // patch_size + temporal_frames = args.patch_multiplier # 2 for image editing + num_patches = temporal_frames * patch_h * patch_w + + in_channels = 64 + text_hidden_size = 3584 + max_seq_len = args.max_sequence_length + + print(f"\nConfiguration:") + print(f" Image size: {height}x{width}") + print(f" Latent size: {latent_height}x{latent_width}") + print(f" Patch size: {patch_size}") + print(f" Temporal frames: {temporal_frames}") + print(f" Num patches: {num_patches}") + print(f" Max sequence length: {max_seq_len}") + print(f" Batch size: {batch_size}") + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Patch RoPE for Neuron compatibility + print("Patching RoPE for Neuron compatibility...") + pipe.transformer = patch_qwenimage_rope(pipe.transformer) + pipe.transformer.eval() + + # Create test inputs + print("\nCreating test inputs...") + # hidden_states: (batch, num_patches, in_channels) + hidden_states = torch.randn(batch_size, num_patches, in_channels, dtype=dtype) + # encoder_hidden_states: (batch, seq_len, text_hidden_size) + encoder_hidden_states = torch.randn(batch_size, max_seq_len, text_hidden_size, dtype=dtype) + # timestep: (batch,) - use a typical timestep value + timestep = torch.tensor([500.0] * batch_size, dtype=torch.float32) + # img_shapes for CPU model + img_shapes = [(temporal_frames, patch_h, patch_w)] * batch_size + + print(f" hidden_states: {hidden_states.shape}") + print(f" encoder_hidden_states: {encoder_hidden_states.shape}") + print(f" timestep: {timestep.shape}") + print(f" img_shapes: {img_shapes}") + + # CPU inference + print("\nRunning CPU inference...") + with torch.no_grad(): + cpu_output = pipe.transformer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + img_shapes=img_shapes, + return_dict=False + ) + cpu_output = cpu_output[0] + print(f" CPU output shape: {cpu_output.shape}") + + # Check compiled model + transformer_path = f"{args.compiled_models_dir}/transformer" + if not os.path.exists(transformer_path): + print(f"\nERROR: Compiled transformer not found at {transformer_path}") + print("Please run compile_transformer.py first.") + return None + + # Load Neuron compiled model + print(f"\nLoading compiled transformer from {transformer_path}...") + import neuronx_distributed + compiled_transformer = neuronx_distributed.trace.parallel_model_load(transformer_path) + + # Prepare inputs for Neuron (timestep must be float32) + timestep_f32 = timestep.to(torch.float32) + + # Neuron inference + print("Running Neuron inference...") + with torch.no_grad(): + neuron_output = compiled_transformer( + hidden_states, + encoder_hidden_states, + timestep_f32 + ) + neuron_output = neuron_output[0] + print(f" Neuron output shape: {neuron_output.shape}") + + # Compare results + metrics = compute_metrics(cpu_output, neuron_output, "Transformer Output") + + return metrics + + +def test_transformer_multiple_timesteps(args): + """Test transformer across multiple timesteps to check consistency.""" + print("\n" + "="*60) + print("Testing Transformer (Multiple Timesteps)") + print("="*60) + + dtype = torch.bfloat16 + batch_size = args.batch_size + height, width = args.height, args.width + + # Calculate dimensions + latent_height = height // 8 + latent_width = width // 8 + patch_size = 2 + patch_h = latent_height // patch_size + patch_w = latent_width // patch_size + temporal_frames = args.patch_multiplier + num_patches = temporal_frames * patch_h * patch_w + + in_channels = 64 + text_hidden_size = 3584 + max_seq_len = args.max_sequence_length + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + pipe.transformer = patch_qwenimage_rope(pipe.transformer) + pipe.transformer.eval() + + # Check compiled model + transformer_path = f"{args.compiled_models_dir}/transformer" + if not os.path.exists(transformer_path): + print(f"\nERROR: Compiled transformer not found at {transformer_path}") + return None + + print(f"Loading compiled transformer from {transformer_path}...") + import neuronx_distributed + compiled_transformer = neuronx_distributed.trace.parallel_model_load(transformer_path) + + # Test at different timesteps + timesteps_to_test = [999.0, 750.0, 500.0, 250.0, 1.0] + results = [] + + # Use same random inputs for all timesteps + torch.manual_seed(42) + hidden_states = torch.randn(batch_size, num_patches, in_channels, dtype=dtype) + encoder_hidden_states = torch.randn(batch_size, max_seq_len, text_hidden_size, dtype=dtype) + img_shapes = [(temporal_frames, patch_h, patch_w)] * batch_size + + for t in timesteps_to_test: + timestep = torch.tensor([t] * batch_size, dtype=torch.float32) + + print(f"\n--- Timestep {t} ---") + + with torch.no_grad(): + # CPU + cpu_output = pipe.transformer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + img_shapes=img_shapes, + return_dict=False + )[0] + + # Neuron + neuron_output = compiled_transformer( + hidden_states, + encoder_hidden_states, + timestep + )[0] + + # Quick metrics + abs_error = torch.abs(cpu_output.float() - neuron_output.float()) + max_ae = abs_error.max().item() + mean_ae = abs_error.mean().item() + cosine_sim = F.cosine_similarity( + cpu_output.flatten().unsqueeze(0).float(), + neuron_output.flatten().unsqueeze(0).float() + ).item() + + print(f" Max AE: {max_ae:.6e}, Mean AE: {mean_ae:.6e}, Cosine: {cosine_sim:.6f}") + results.append({ + "timestep": t, + "max_abs_error": max_ae, + "mean_abs_error": mean_ae, + "cosine_sim": cosine_sim, + }) + + # Summary + print("\n--- Timestep Summary ---") + avg_cosine = np.mean([r["cosine_sim"] for r in results]) + max_error = max([r["max_abs_error"] for r in results]) + print(f"Average Cosine Similarity: {avg_cosine:.6f}") + print(f"Max Absolute Error (all timesteps): {max_error:.6e}") + + return results + + +def test_transformer_block_by_block(args): + """Test individual transformer blocks to identify problematic layers.""" + print("\n" + "="*60) + print("Testing Transformer Block-by-Block") + print("="*60) + print("NOTE: This test requires manual inspection of intermediate outputs.") + print("The compiled model doesn't expose individual blocks.") + print("This test compares the CPU model's block outputs for debugging.") + + dtype = torch.bfloat16 + batch_size = 1 + height, width = args.height, args.width + + # Calculate dimensions + latent_height = height // 8 + latent_width = width // 8 + patch_size = 2 + patch_h = latent_height // patch_size + patch_w = latent_width // patch_size + temporal_frames = args.patch_multiplier + num_patches = temporal_frames * patch_h * patch_w + + in_channels = 64 + text_hidden_size = 3584 + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + pipe.transformer = patch_qwenimage_rope(pipe.transformer) + pipe.transformer.eval() + + transformer = pipe.transformer + num_blocks = len(transformer.transformer_blocks) + print(f"Transformer has {num_blocks} blocks") + + # Create test inputs + torch.manual_seed(42) + hidden_states = torch.randn(batch_size, num_patches, in_channels, dtype=dtype) + encoder_hidden_states = torch.randn(batch_size, args.max_sequence_length, text_hidden_size, dtype=dtype) + timestep = torch.tensor([500.0], dtype=torch.float32) + img_shapes = [(temporal_frames, patch_h, patch_w)] + + # Check output statistics at each block + print("\n--- Block Output Statistics (CPU) ---") + print("This helps identify where numerical issues might occur.") + + # We need to hook into the model to get intermediate outputs + # For now, just run the full model and check final output + with torch.no_grad(): + output = transformer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + img_shapes=img_shapes, + return_dict=False + )[0] + + print(f"\nFinal output statistics:") + print(f" Shape: {output.shape}") + print(f" Mean: {output.mean().item():.6f}") + print(f" Std: {output.std().item():.6f}") + print(f" Min: {output.min().item():.6f}") + print(f" Max: {output.max().item():.6f}") + print(f" Has NaN: {torch.isnan(output).any()}") + print(f" Has Inf: {torch.isinf(output).any()}") + + return {"num_blocks": num_blocks} + + +def main(): + parser = argparse.ArgumentParser(description="Transformer Unit Test: CPU vs Neuron") + parser.add_argument("--height", type=int, default=512, help="Image height") + parser.add_argument("--width", type=int, default=512, help="Image width") + parser.add_argument("--max_sequence_length", type=int, default=512, + help="Max text sequence length") + parser.add_argument("--batch_size", type=int, default=1, + help="Batch size (1 or 2)") + parser.add_argument("--patch_multiplier", type=int, default=2, + help="Patch multiplier (2 for image editing)") + parser.add_argument("--compiled_models_dir", type=str, + default=COMPILED_MODELS_DIR, + help="Directory containing compiled models") + parser.add_argument("--test", type=str, default="single", + choices=["single", "timesteps", "blocks", "all"], + help="Which test to run") + args = parser.parse_args() + + print("="*60) + print("Transformer Unit Test: Comparing Neuron vs CPU Inference") + print("="*60) + print(f"Image size: {args.height}x{args.width}") + print(f"Batch size: {args.batch_size}") + print(f"Patch multiplier: {args.patch_multiplier}") + print(f"Compiled models: {args.compiled_models_dir}") + + results = {} + + if args.test in ["single", "all"]: + results["single_step"] = test_transformer_single_step(args) + + if args.test in ["timesteps", "all"]: + results["timesteps"] = test_transformer_multiple_timesteps(args) + + if args.test in ["blocks", "all"]: + results["blocks"] = test_transformer_block_by_block(args) + + # Summary + print("\n" + "="*60) + print("TEST SUMMARY") + print("="*60) + + if "single_step" in results and results["single_step"]: + m = results["single_step"] + status = "PASS" if m["cosine_sim"] > 0.99 else "WARN" if m["cosine_sim"] > 0.95 else "FAIL" + print(f"Single Step: Cosine Sim = {m['cosine_sim']:.6f} Max AE = {m['max_abs_error']:.2e} [{status}]") + + if "timesteps" in results and results["timesteps"]: + avg_cos = np.mean([r["cosine_sim"] for r in results["timesteps"]]) + status = "PASS" if avg_cos > 0.99 else "WARN" if avg_cos > 0.95 else "FAIL" + print(f"Multi-Timestep: Avg Cosine = {avg_cos:.6f} [{status}]") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen-Image-Edit/test/integration/test_vae.py b/contrib/models/Qwen-Image-Edit/test/integration/test_vae.py new file mode 100755 index 00000000..43ae3d71 --- /dev/null +++ b/contrib/models/Qwen-Image-Edit/test/integration/test_vae.py @@ -0,0 +1,455 @@ +#!/usr/bin/env python3 +""" +VAE Unit Test: Compare Neuron vs CPU/GPU inference results + +This test compares the VAE encoder and decoder outputs between: +1. Original model running on CPU +2. Compiled model running on Neuron (trn2) + +Key metrics: +- Max Absolute Error (MAE) +- Mean Absolute Error (Mean AE) +- Cosine Similarity +- Output statistics (mean, std, min, max) +""" + +import os +import sys +import argparse + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import torch +import torch.nn.functional as F +import numpy as np +from PIL import Image + +# Set Neuron environment before importing neuron libraries +os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" +os.environ["NEURON_LOGICAL_NC_CONFIG"] = "2" + +from diffusers import QwenImageEditPlusPipeline +from neuron_qwen_image_edit.autoencoder_kl_qwenimage_neuron import AutoencoderKLQwenImage as NeuronAutoencoder +from neuron_qwen_image_edit.neuron_commons import f32Wrapper + + +CACHE_DIR = "/opt/dlami/nvme/qwen_image_edit_hf_cache_dir" +MODEL_ID = "Qwen/Qwen-Image-Edit-2509" +COMPILED_MODELS_DIR = "/opt/dlami/nvme/compiled_models" + + +def compute_metrics(cpu_output, neuron_output, name="output"): + """Compute comparison metrics between CPU and Neuron outputs.""" + # Ensure same dtype for comparison + cpu_out = cpu_output.float().detach().cpu() + neuron_out = neuron_output.float().detach().cpu() + + # Absolute error + abs_error = torch.abs(cpu_out - neuron_out) + max_abs_error = abs_error.max().item() + mean_abs_error = abs_error.mean().item() + + # Relative error (avoid division by zero) + rel_error = abs_error / (torch.abs(cpu_out) + 1e-8) + max_rel_error = rel_error.max().item() + mean_rel_error = rel_error.mean().item() + + # Cosine similarity + cpu_flat = cpu_out.flatten() + neuron_flat = neuron_out.flatten() + cosine_sim = F.cosine_similarity(cpu_flat.unsqueeze(0), neuron_flat.unsqueeze(0)).item() + + # Statistics + cpu_stats = { + "mean": cpu_out.mean().item(), + "std": cpu_out.std().item(), + "min": cpu_out.min().item(), + "max": cpu_out.max().item(), + } + neuron_stats = { + "mean": neuron_out.mean().item(), + "std": neuron_out.std().item(), + "min": neuron_out.min().item(), + "max": neuron_out.max().item(), + } + + print(f"\n{'='*60}") + print(f"Metrics for {name}") + print(f"{'='*60}") + print(f"Shape: {cpu_out.shape}") + print(f"\nError Metrics:") + print(f" Max Absolute Error: {max_abs_error:.6e}") + print(f" Mean Absolute Error: {mean_abs_error:.6e}") + print(f" Max Relative Error: {max_rel_error:.6e}") + print(f" Mean Relative Error: {mean_rel_error:.6e}") + print(f" Cosine Similarity: {cosine_sim:.6f}") + print(f"\nCPU Output Statistics:") + print(f" Mean: {cpu_stats['mean']:.6f}, Std: {cpu_stats['std']:.6f}") + print(f" Min: {cpu_stats['min']:.6f}, Max: {cpu_stats['max']:.6f}") + print(f"\nNeuron Output Statistics:") + print(f" Mean: {neuron_stats['mean']:.6f}, Std: {neuron_stats['std']:.6f}") + print(f" Min: {neuron_stats['min']:.6f}, Max: {neuron_stats['max']:.6f}") + + return { + "max_abs_error": max_abs_error, + "mean_abs_error": mean_abs_error, + "cosine_sim": cosine_sim, + "cpu_stats": cpu_stats, + "neuron_stats": neuron_stats, + } + + +def upcast_norms_to_f32(module): + """Upcast normalization layers to float32 for numerical stability.""" + for name, child in module.named_children(): + if isinstance(child, (torch.nn.GroupNorm, torch.nn.LayerNorm)): + setattr(module, name, f32Wrapper(child.to(torch.float32))) + else: + upcast_norms_to_f32(child) + + +def test_vae_encoder(args): + """Test VAE encoder: CPU vs Neuron.""" + print("\n" + "="*60) + print("Testing VAE Encoder") + print("="*60) + + dtype = torch.bfloat16 + height, width = args.height, args.width + temporal_frames = 1 + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Create Neuron-compatible VAE with same weights + print("Creating Neuron-compatible VAE...") + original_vae_config = pipe.vae.config + neuron_vae = NeuronAutoencoder( + base_dim=original_vae_config.base_dim, + z_dim=original_vae_config.z_dim, + dim_mult=original_vae_config.dim_mult, + num_res_blocks=original_vae_config.num_res_blocks, + attn_scales=original_vae_config.attn_scales, + temperal_downsample=original_vae_config.temperal_downsample, + dropout=original_vae_config.dropout, + input_channels=original_vae_config.input_channels, + latents_mean=original_vae_config.latents_mean, + latents_std=original_vae_config.latents_std, + ) + neuron_vae.load_state_dict(pipe.vae.state_dict()) + neuron_vae = neuron_vae.to(dtype) + + # Get encoder + cpu_encoder = neuron_vae.encoder + cpu_encoder.eval() + + # Create test input + print(f"\nCreating test input: (1, 3, {temporal_frames}, {height}, {width})") + test_input = torch.randn(1, 3, temporal_frames, height, width, dtype=dtype) + + # CPU inference + print("Running CPU inference...") + with torch.no_grad(): + cpu_output = cpu_encoder(test_input) + + # Load and run Neuron model + vae_encoder_path = f"{args.compiled_models_dir}/vae_encoder/model.pt" + if not os.path.exists(vae_encoder_path): + print(f"\nERROR: Compiled VAE encoder not found at {vae_encoder_path}") + print("Please run compile_vae.py first.") + return None + + print(f"Loading compiled encoder from {vae_encoder_path}...") + import torch_neuronx + compiled_encoder = torch.jit.load(vae_encoder_path) + + print("Running Neuron inference...") + with torch.no_grad(): + neuron_output = compiled_encoder(test_input) + + # Compare results + metrics = compute_metrics(cpu_output, neuron_output, "VAE Encoder") + + return metrics + + +def test_vae_decoder(args): + """Test VAE decoder: CPU vs Neuron.""" + print("\n" + "="*60) + print("Testing VAE Decoder") + print("="*60) + + dtype = torch.bfloat16 + height, width = args.height, args.width + latent_height = height // 8 + latent_width = width // 8 + temporal_frames = 1 + z_dim = 16 # QwenImage VAE z_dim + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Create Neuron-compatible VAE with same weights + print("Creating Neuron-compatible VAE...") + original_vae_config = pipe.vae.config + neuron_vae = NeuronAutoencoder( + base_dim=original_vae_config.base_dim, + z_dim=original_vae_config.z_dim, + dim_mult=original_vae_config.dim_mult, + num_res_blocks=original_vae_config.num_res_blocks, + attn_scales=original_vae_config.attn_scales, + temperal_downsample=original_vae_config.temperal_downsample, + dropout=original_vae_config.dropout, + input_channels=original_vae_config.input_channels, + latents_mean=original_vae_config.latents_mean, + latents_std=original_vae_config.latents_std, + ) + neuron_vae.load_state_dict(pipe.vae.state_dict()) + neuron_vae = neuron_vae.to(dtype) + + # Get decoder + cpu_decoder = neuron_vae.decoder + cpu_decoder.eval() + + # Create test input (latent space) + print(f"\nCreating test input: (1, {z_dim}, {temporal_frames}, {latent_height}, {latent_width})") + test_input = torch.randn(1, z_dim, temporal_frames, latent_height, latent_width, dtype=dtype) + + # CPU inference + print("Running CPU inference...") + with torch.no_grad(): + cpu_output = cpu_decoder(test_input) + + # Load and run Neuron model + vae_decoder_path = f"{args.compiled_models_dir}/vae_decoder/model.pt" + if not os.path.exists(vae_decoder_path): + print(f"\nERROR: Compiled VAE decoder not found at {vae_decoder_path}") + print("Please run compile_vae.py first.") + return None + + print(f"Loading compiled decoder from {vae_decoder_path}...") + import torch_neuronx + compiled_decoder = torch.jit.load(vae_decoder_path) + + print("Running Neuron inference...") + with torch.no_grad(): + neuron_output = compiled_decoder(test_input) + + # Compare results + metrics = compute_metrics(cpu_output, neuron_output, "VAE Decoder") + + # Additional: visualize difference if output is image-like + if args.save_images: + save_comparison_images(cpu_output, neuron_output, "vae_decoder", args) + + return metrics + + +def test_vae_roundtrip(args): + """Test full VAE roundtrip: encode -> decode.""" + print("\n" + "="*60) + print("Testing VAE Roundtrip (Encode -> Decode)") + print("="*60) + + dtype = torch.bfloat16 + height, width = args.height, args.width + temporal_frames = 1 + + # Load original pipeline + print("\nLoading original pipeline...") + pipe = QwenImageEditPlusPipeline.from_pretrained( + MODEL_ID, + torch_dtype=dtype, + cache_dir=CACHE_DIR, + local_files_only=True, + ) + + # Create Neuron-compatible VAE with same weights + original_vae_config = pipe.vae.config + neuron_vae = NeuronAutoencoder( + base_dim=original_vae_config.base_dim, + z_dim=original_vae_config.z_dim, + dim_mult=original_vae_config.dim_mult, + num_res_blocks=original_vae_config.num_res_blocks, + attn_scales=original_vae_config.attn_scales, + temperal_downsample=original_vae_config.temperal_downsample, + dropout=original_vae_config.dropout, + input_channels=original_vae_config.input_channels, + latents_mean=original_vae_config.latents_mean, + latents_std=original_vae_config.latents_std, + ) + neuron_vae.load_state_dict(pipe.vae.state_dict()) + neuron_vae = neuron_vae.to(dtype) + neuron_vae.eval() + + # Create test image input + print(f"\nCreating test input: (1, 3, {temporal_frames}, {height}, {width})") + test_input = torch.randn(1, 3, temporal_frames, height, width, dtype=dtype) + + # CPU roundtrip + print("Running CPU roundtrip...") + with torch.no_grad(): + cpu_encoded = neuron_vae.encoder(test_input) + cpu_quant = neuron_vae.quant_conv(cpu_encoded) + # Take mean (first half of channels) + cpu_latent = cpu_quant[:, :16, :, :, :] + cpu_post_quant = neuron_vae.post_quant_conv(cpu_latent) + cpu_decoded = neuron_vae.decoder(cpu_post_quant) + + # Check compiled models exist + encoder_path = f"{args.compiled_models_dir}/vae_encoder/model.pt" + decoder_path = f"{args.compiled_models_dir}/vae_decoder/model.pt" + quant_conv_path = f"{args.compiled_models_dir}/quant_conv/model.pt" + post_quant_conv_path = f"{args.compiled_models_dir}/post_quant_conv/model.pt" + + if not os.path.exists(encoder_path) or not os.path.exists(decoder_path): + print(f"\nERROR: Compiled VAE models not found") + return None + + # Load compiled models + print("Loading compiled models...") + import torch_neuronx + compiled_encoder = torch.jit.load(encoder_path) + compiled_decoder = torch.jit.load(decoder_path) + + # Load quant_conv and post_quant_conv if available + compiled_quant_conv = None + compiled_post_quant_conv = None + if os.path.exists(quant_conv_path): + print(f" Loading quant_conv from {quant_conv_path}") + compiled_quant_conv = torch.jit.load(quant_conv_path) + else: + print(f" WARNING: quant_conv not compiled, using CPU version") + + if os.path.exists(post_quant_conv_path): + print(f" Loading post_quant_conv from {post_quant_conv_path}") + compiled_post_quant_conv = torch.jit.load(post_quant_conv_path) + else: + print(f" WARNING: post_quant_conv not compiled, using CPU version") + + # Neuron roundtrip + print("Running Neuron roundtrip...") + with torch.no_grad(): + neuron_encoded = compiled_encoder(test_input) + + # Use compiled quant_conv if available + if compiled_quant_conv is not None: + neuron_quant = compiled_quant_conv(neuron_encoded) + else: + neuron_quant = neuron_vae.quant_conv(neuron_encoded) + + neuron_latent = neuron_quant[:, :16, :, :, :] + + # Use compiled post_quant_conv if available + if compiled_post_quant_conv is not None: + neuron_post_quant = compiled_post_quant_conv(neuron_latent) + else: + neuron_post_quant = neuron_vae.post_quant_conv(neuron_latent) + + neuron_decoded = compiled_decoder(neuron_post_quant) + + # Compare intermediate results + print("\n--- Intermediate Comparisons ---") + compute_metrics(cpu_encoded, neuron_encoded, "Encoder Output") + compute_metrics(cpu_quant, neuron_quant, "After quant_conv (full 32 channels)") + compute_metrics(cpu_latent, neuron_latent, "Latent (first 16 channels)") + compute_metrics(cpu_post_quant, neuron_post_quant, "After post_quant_conv") + metrics = compute_metrics(cpu_decoded, neuron_decoded, "Final Decoded Output") + + # Additional test: Decoder with SAME input to isolate decoder error + print("\n--- Decoder Isolation Test (same input) ---") + with torch.no_grad(): + # Use CPU post_quant output as input to both decoders + cpu_decoder_from_cpu_input = neuron_vae.decoder(cpu_post_quant) + neuron_decoder_from_cpu_input = compiled_decoder(cpu_post_quant) + compute_metrics(cpu_decoder_from_cpu_input, neuron_decoder_from_cpu_input, + "Decoder (same CPU input)") + + # Save comparison images + if args.save_images: + save_comparison_images(cpu_decoded, neuron_decoded, "vae_roundtrip", args) + + return metrics + + +def save_comparison_images(cpu_output, neuron_output, prefix, args): + """Save CPU vs Neuron output as images for visual comparison.""" + import os + + output_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_outputs") + os.makedirs(output_dir, exist_ok=True) + + # Convert to numpy images (assume output is [-1, 1]) + cpu_img = ((cpu_output[0, :, 0].permute(1, 2, 0).float().cpu().numpy() + 1) * 127.5).clip(0, 255).astype(np.uint8) + neuron_img = ((neuron_output[0, :, 0].permute(1, 2, 0).float().cpu().numpy() + 1) * 127.5).clip(0, 255).astype(np.uint8) + + # Compute difference (amplified for visibility) + diff = np.abs(cpu_img.astype(float) - neuron_img.astype(float)) + diff_amplified = (diff * 10).clip(0, 255).astype(np.uint8) + + # Save images + Image.fromarray(cpu_img).save(os.path.join(output_dir, f"{prefix}_cpu.png")) + Image.fromarray(neuron_img).save(os.path.join(output_dir, f"{prefix}_neuron.png")) + Image.fromarray(diff_amplified).save(os.path.join(output_dir, f"{prefix}_diff_10x.png")) + + print(f"\nSaved comparison images to {output_dir}/") + + +def main(): + parser = argparse.ArgumentParser(description="VAE Unit Test: CPU vs Neuron") + parser.add_argument("--height", type=int, default=512, help="Image height") + parser.add_argument("--width", type=int, default=512, help="Image width") + parser.add_argument("--compiled_models_dir", type=str, + default=COMPILED_MODELS_DIR, + help="Directory containing compiled models") + parser.add_argument("--test", type=str, default="all", + choices=["encoder", "decoder", "roundtrip", "all"], + help="Which test to run") + parser.add_argument("--save_images", action="store_true", + help="Save comparison images") + args = parser.parse_args() + + print("="*60) + print("VAE Unit Test: Comparing Neuron vs CPU Inference") + print("="*60) + print(f"Image size: {args.height}x{args.width}") + print(f"Compiled models: {args.compiled_models_dir}") + + results = {} + + if args.test in ["encoder", "all"]: + results["encoder"] = test_vae_encoder(args) + + if args.test in ["decoder", "all"]: + results["decoder"] = test_vae_decoder(args) + + if args.test in ["roundtrip", "all"]: + results["roundtrip"] = test_vae_roundtrip(args) + + # Summary + print("\n" + "="*60) + print("TEST SUMMARY") + print("="*60) + for name, metrics in results.items(): + if metrics: + status = "PASS" if metrics["cosine_sim"] > 0.99 else "WARN" if metrics["cosine_sim"] > 0.95 else "FAIL" + print(f"{name:15s}: Cosine Sim = {metrics['cosine_sim']:.6f} Max AE = {metrics['max_abs_error']:.2e} [{status}]") + else: + print(f"{name:15s}: SKIPPED (compiled model not found)") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen-Image-Edit/test/unit/__init__.py b/contrib/models/Qwen-Image-Edit/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b