Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions contrib/models/Qwen-Image-Edit/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
output/
output_edited.png
log-neuron-cc.txt
scratch/
global_metric_store.json
103 changes: 103 additions & 0 deletions contrib/models/Qwen-Image-Edit/OPTIMIZATION.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Qwen-Image-Edit on Trainium2 — optimized (TP4×CP4 + WLO)

NeuronX adaptation of `Qwen/Qwen-Image-Edit-2509` for AWS Trainium2 inference, with the
latency optimizations applied. **Production config: TP=4 × CP=4 (world=16) + WLO**, which
runs the 896×1184 / 8-step / CFG=1 virtual try-on at **~4.5 s end-to-end** — faster than the
H100 vLLM-Omni reference (4.99 s), lossless.

> **Scope / applicability.** This optimization round targets a **few-step
> distillation–finetuned checkpoint run with classifier-free guidance disabled
> (CFG=1)** — i.e. ~8 denoising steps and a *single* (positive-only) transformer forward
> per step, no negative-prompt pass. Two consequences shape every number below:
> - With CFG=1 there is no negative branch to batch, so the **CFG-parallel (DP=2) path
> buys nothing** — the right lever is **Context Parallel (CP)**, which shards the single
> forward across more cores. (For CFG>1, V3 CFG's DP=2 batching of neg+pos is the better
> layout; CP scaling here is specifically for the CFG=1 few-step regime.)
> - At only ~8 steps the fixed per-run cost (text encoder + VAE encode/decode, ~1.1 s) is a
> large fraction of E2E, so VAE/text-encoder latency matters far more than it would for a
> 50-step run — which is why the VAE batched-tile win below is material here.
>
> The latency targets (matching/beating H100's 4.99 s) and the CP=4 sweet-spot choice are
> all stated **for this few-step / CFG=1 workload**; a many-step or CFG>1 run has a
> different optimum.

## What's here

```
src/ full NeuronX model/compile/run source (the optimizations live here)
release_v3cp4_wlo/ the production config: compile.sh / run_tryon.sh / test / README + sample outputs
requirements.txt
```

Input try-on images (cloth/, input_img/) are NOT included — supply your own and pass via the
run script flags (see release_v3cp4_wlo/run_tryon.sh).

## Quick start (production config)

```bash
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
cd <repo>/contrib/models/Qwen-Image-Edit

# 1) compile transformer + vision + LM at world=16 (+WLO); VAE symlinked. ~35 min.
bash release_v3cp4_wlo/compile.sh

# 2) run try-on (~4.5 s). Edit the cloth/model paths in the script or pass via QIE_CLOTH / QIE_MODEL_IMG.
bash release_v3cp4_wlo/run_tryon.sh out.png
```

See `release_v3cp4_wlo/README.md` for the full recipe, the CP scaling rationale (why CP=4),
and the quality test.

## The optimizations (vs the original ~7.9 s baseline)

| optimization | effect | how |
|---|---|---|
| **VAE batched-tile** | VAE encode 559 → 298 ms (−47%), decode 422 → 192 ms (−55%), **−6.1% E2E**; numerically identical | compile the tiled VAE encoder/decoder at `batch=N` (N = tiles/image — 6 for 1024², 12 for two-image edit) and run all tiles in **one** NEFF launch instead of N sequential launches. `_tiled_encode`/`_tiled_decode` enumerate tiles → run in chunks of the compiled batch → scatter+crop+blend back into the grid. Collapses ~37 ms/tile launch overhead; matters at few-step because VAE is a fixed cost. Set via `VAE_BATCH_SIZE` in `compile.sh`. **Both `_tiled_encode` and `_tiled_decode` must be updated** (half-applied regresses the decoder). |
| **WLO** (weight layout opt) | −3.2% step, **bit-exact** | pass `priority_model_key="inference"` to `ModelBuilder.compile()` — was never enabled. `QIE_WLO=1` (default). |
| **CP scaling** (CP=2 → CP=4) | **7.51 → 4.50 s** (−40%), step 793 → 411 ms, lossless | QIE transformer is compute-bound but V3 only used 8 of the chip's 32 logical cores; doubling cores (world=8 → 16) halves the step. Only a lever because CFG=1 runs a single forward (no DP=2 batching to exploit). Compile with `--tp_degree 4 --world_size 16`; run with `QIE_WORLD_SIZE=16 NEURON_RT_NUM_CORES=16`. |

Output is visually equivalent across CP degrees (CP=4 vs CP=2 mean |Δ| 0.78%); not bit-exact
because the CP partition changes bf16 accumulation order. VAE batched-tile and WLO are
numerically clean (identical / bit-exact); CP scaling is lossless in the visual sense.

### Per-optimization latency reduction

Each delta is measured against the *then-current* baseline (they are applied sequentially), so
this is a cumulative path, **not** a simple sum. Workload: 896×1184 / 8-step / **CFG=1**
two-image virtual try-on on `trn2.48xlarge`.

| step | optimization | E2E | what moved |
|---|---|---|---|
| 0 | baseline (V3 CP=2, fp32 reduce, per-tile VAE) | ~7.89 s | — |
| 1 | + bf16 TP all-reduce | ~7.5 s | TP all-reduce bytes halved (~−9% step) |
| 2 | + VAE batched-tile | 7.41 s | VAE enc 559→298 ms, dec 422→192 ms (−6.1% E2E) |
| 3 | + WLO | ~7.3 s | weight layout for inference (−3.2% step, bit-exact) |
| 4 | + CP scaling (CP=2 → CP=4) | **4.50 s** | transformer step 793 → 411 ms (2× cores) |

> bf16 TP all-reduce and VAE batched-tile predate this round (part of the base contribution);
> the **new** work here is **WLO** and **CP scaling**, which take the verified-correct config
> from 7.51 s to **4.50 s**. After CP scaling the transformer is ~2.3 s of the 4.5 s, so the
> fixed text-encoder + VAE (~1.1 s) is now the dominant remaining cost — which is exactly why
> the VAE batched-tile win is load-bearing in this few-step regime.

### Why TP=4 × CP=4 (not CP=8 or TP=8)
- **CP=8** (32 cores) reaches 4.10 s but the gain over CP=4 is only ~0.4 s — transformer
marginal return drops to 0.70× (from 0.52×) and the larger world=32 adds ~200 ms of DP
overhead to the TP=4 vision/LM. CP=4 is the sweet spot.
- **TP=8** triggers a flash-kernel seqlen-shard fallback and the 8-rank TP group has no
replica-group mapping in NeuronX for world<64. TP=4 (6 heads/rank) maps cleanly to the torus.
- CP must be a power of 2 (world ∈ {8,16,32}); the runtime rejects other core counts (e.g.
world=12 → "Unsupported topology").

## Key env vars (compile / run)

- `QIE_WLO=1` (default) — weight layout optimization (bit-exact speedup).
- `QIE_WORLD_SIZE=16`, `NEURON_RT_NUM_CORES=16` — required at run time for CP=4.
- `QIE_ALLREDUCE_BF16=1` (default), `QIE_OPT_LEVEL=2`, `QIE_CC_TILING=4` — existing tuned defaults.
- Use the SHORT prompt `让图2的模特换上图1的下装` for try-on (the long prompt mis-edits on this ckpt).

## Environment

- Instance `trn2.48xlarge` (64 NeuronCores = 32 logical at LNC2)
- Venv `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference` (PyTorch 2.9, neuronx-cc 2.22, neuronx-distributed 0.16)
- `PYTHONPATH=src:$PYTHONPATH` for both compile and run
219 changes: 219 additions & 0 deletions contrib/models/Qwen-Image-Edit/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
# Contrib Model: Qwen-Image-Edit

NeuronX adaptation of [alibaba-pai/Qwen-Image-Edit-2509](https://huggingface.co/alibaba-pai/Qwen-Image-Edit-2509) for AWS Trainium2 inference.

> **Latency-optimized config (recommended for few-step / CFG=1):** TP=4 × CP=4 (world=16) +
> WLO runs the 896×1184 / 8-step / **CFG=1** virtual try-on (few-step distillation–finetuned
> checkpoint) at **~4.5 s end-to-end on trn2.48xlarge**, beating the H100 vLLM-Omni reference
> (4.99 s), lossless. This regime is what the optimizations target — with CFG=1 a single
> forward per step makes Context Parallel the right lever; for CFG>1 use the V3 CFG (DP=2) path
> instead. See [`OPTIMIZATION.md`](OPTIMIZATION.md) for the full writeup (VAE batched-tile,
> bf16 TP all-reduce, WLO, CP scaling — each with measured latency deltas) and
> [`release_v3cp4_wlo/`](release_v3cp4_wlo/) for the ready-to-run production config.

## Model Information

- **HuggingFace ID:** `alibaba-pai/Qwen-Image-Edit-2509`
- **Model Type:** Diffusion model for image editing
- **Architecture:** Multi-component (Qwen2.5-VL Vision Encoder + Language Model + QwenImageTransformer2DModel + 3D VAE)
- **License:** Check HuggingFace model card

## Architecture Details

| Component | Model | Parameters | Neuron Parallelism |
|-----------|-------|------------|-------------------|
| Vision Encoder | Qwen2.5-VL ViT (32 blocks) | ~1.4B | TP=4, float32 (or CPU) |
| Language Model | Qwen2.5-VL LM (28 layers) | ~7B | TP=4, world_size=8 (or CPU) |
| Transformer | QwenImageTransformer2DModel | ~20.4B | TP=4-8, various parallelism modes |
| VAE | 3D AutoencoderKL (causal) | ~300M | Single device, tiled processing |

Key parameters:
- **Transformer**: 48 attention heads, head_dim=128, inner_dim=6144
- **Text Hidden Size**: 3584 (Qwen2.5-VL)
- **Dual-stream blocks**: 20 (separate text/image norms+FFN, joint attention)
- **Single-stream blocks**: 40 (concatenated text+image, parallel MLP+attention)

## Performance

6 compilation APIs with different parallelism strategies:

| Version | Parallelism | Attention | Per Step | Total (50 steps) | Notes |
|---------|------------|-----------|----------|-----------------|-------|
| **V3 CFG** | TP=4, DP=2 | NKI Flash | **~0.75s** | **~53s** | Fastest, recommended |
| V3 CP | TP=4, CP=2 | NKI Flash | ~0.77s | ~55s | Context Parallel |
| V1 Flash | TP=8 | NKI Flash | ~1.2s | ~76s | NKI kernel |
| V2 Flash | TP=8 | NKI Flash | ~1.2s | ~76s | ModelBuilder + NKI |
| V2 | TP=8 | Standard SDPA | ~1.2s | ~76s | ModelBuilder |
| V1 | TP=8 | Standard SDPA | ~2.4s | ~136s | Baseline |

Test: 1024x1024 output, guidance_scale=4.0, trn2.48xlarge, single-image
editing (`patch_multiplier=2`). Total time includes VAE encoding/decoding
and text encoding overhead.

> Note: a two-image merge (`patch_multiplier=3`) processes a longer joint
> sequence (`S = 1024 + 12288 = 13312`) and lands at roughly 1.6× the
> per-step time of single-image editing on the same configuration.

## Prerequisites

- **Instance**: trn2.48xlarge (64 NeuronCores, 1.5TB device memory)
- **Virtual env**: `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference`
- PyTorch 2.9, neuronx-cc 2.22, neuronx-distributed 0.16
- **NVMe**: Mount RAID at `/opt/dlami/nvme/` (run `src/setup_nvme.sh`)

## Usage

### 1. Setup

```bash
# Mount NVMe RAID
sudo bash src/setup_nvme.sh

# Activate virtual environment
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate

# Install dependencies
pip install -r requirements.txt
```

### 2. Download Model

```bash
python src/cache_hf_model.py
```

### 3. Compile All Components

```bash
# Compile V3 CFG (recommended, fastest)
bash src/compile.sh v3_cfg

# Compile V3 CP (Context Parallel)
bash src/compile.sh v3_cp

# Compile all versions
bash src/compile.sh

# Custom dimensions:
# bash src/compile.sh <version> <height> <width> <image_size> <tp_degree> <max_seq_len> <patch_mult> <batch_size>
```

Compilation takes ~60-120 minutes total depending on version.

### 4. Run Inference

`compile.sh` defaults to `patch_multiplier=3` (two-image merge), so the
example below uses two input images. For single-image editing, recompile
with `patch_multiplier=2` first.

```bash
# Two-image merge (matches compile.sh default of patch_multiplier=3)
NEURON_RT_NUM_CORES=8 PYTHONPATH=src:$PYTHONPATH python src/run_qwen_image_edit.py \
--compiled_models_dir /opt/dlami/nvme/compiled_models_qwen_image_edit \
--images assets/image1.png assets/image2.png \
--prompt "merge subjects from image1 and image2 into a single scene" \
--patch_multiplier 3 \
--use_v3_cfg \
--output output.png

# Single-image editing (requires recompilation with patch_multiplier=2)
# bash src/compile.sh v3_cfg 1024 1024 448 8 1024 2 1
# NEURON_RT_NUM_CORES=8 PYTHONPATH=src:$PYTHONPATH python src/run_qwen_image_edit.py \
# --compiled_models_dir /opt/dlami/nvme/compiled_models_qwen_image_edit \
# --images assets/image1.png \
# --prompt "change the sky to sunset" \
# --patch_multiplier 2 \
# --use_v3_cfg \
# --output output.png
```

### Runtime toggles

The transformer dispatch in `compile_transformer_v3_cfg.py` reads four
optional environment variables:

| Variable | Default | Effect |
|----------|---------|--------|
| `QIE_HOISTED_Q_ATTENTION` | `1` | Use the Phase 16 hoisted-Q `attention_cte` fork. Set to `0` to fall back to upstream `nkilib.core.attention.attention_cte`. |
| `QIE_ALLREDUCE_BF16` | `1` | Phase 17: use `bfloat16` instead of `float32` as the reduce dtype for every TP `RowParallelLinear` all-reduce. Halves the bytes on the wire and saves ~137 ms / step (~9% E2E) with no visible image quality regression. Set to `0` to keep the upstream `fp32` reduce. |
| `QIE_USE_NKILIB_ATTENTION` | `1` | When the hoisted-Q fork is disabled, choose between `nkilib` `attention_cte` (`1`) and the legacy `attention_isa_kernel` (`0`). |
| `QIE_SOFTMAX_DTYPE` | `float32` | Softmax accumulation dtype inside `attention_cte`. `bfloat16` is supported but measured no speedup on Trn2 because `mm_out_dtype` must stay `float32` on Gen3. |

## Compatibility Matrix

| Instance/Version | 2.22+ (PyTorch 2.9) | 2.21 and earlier |
|------------------|---------------------|------------------|
| Trn2 (trn2.48xlarge) | Tested | Not tested |
| Trn1 | Not tested | Not tested |
| Inf2 | Not supported | Not supported |

## Testing

```bash
# Run component tests
PYTHONPATH=src:$PYTHONPATH pytest test/integration/ --capture=tee-sys -v

# Run all tests manually
PYTHONPATH=src:$PYTHONPATH python test/integration/run_all_tests.py
```

## Key Implementation Notes

1. **Modulation Layer Sharding**: Uses `ColumnParallelLinear(gather_output=True)` to reduce memory from ~17GB to ~5.2GB per shard.
2. **RoPE Without Complex Numbers**: Neuron doesn't support C64; uses (cos, sin) tuples instead.
3. **M-RoPE Position IDs**: 3D position indices (temporal, height, width) for multimodal tokens.
4. **VAE Interpolation**: Replaces `nearest-exact` with `nearest` for Neuron compatibility.
5. **CFG Parallel**: Batches negative + positive prompts into single forward pass for ~6% speedup over CP.
6. **NKI Flash Attention**: Custom NKI kernel for Trainium2, requires `XLA_DISABLE_FUNCTIONALIZATION=1`.
7. **Hoisted-Q `attention_cte` (Phase 16)**: forked kernel (`attention_cte_qie_hoisted_q.py`) that hoists the Q-tile load out of the K/V section loop. For QIE shapes (`num_sections = 2` and Q identical across sections), this removes the redundant Q reload that the upstream `attention_cte` performs in `section_idx=1`. Bit-exact to the baseline (same `hardware_flops`, same output PNG MD5); −14.7 ms / step, +0.52 pp MFU on the V3 CFG configuration. Toggle via `QIE_HOISTED_Q_ATTENTION` (default `1`).
8. **bf16 TP all-reduce (Phase 17)**: drops the all-reduce dtype on every `RowParallelLinear` (attention output, attention text-output, MLP output) from `float32` to `bfloat16`. The V3 CFG configuration emits 956 TP all-reduces per step totalling ~18 GB; halving the bytes on the wire saves **~137 ms / step (~9% E2E)** — by far the largest single win in this contrib's optimization history. Output images are visually equivalent to the `fp32`-reduce baseline on the 1024×1024 two-image merge workload; not bit-exact (small numerical drift accumulates over 60 blocks × 40 steps, but does not change scene content / faces / composition). Toggle via `QIE_ALLREDUCE_BF16` (default `1`).

## File Structure

```
Qwen-Image-Edit/
README.md
requirements.txt
assets/
image1.png, image2.png # Test input images
src/
run_qwen_image_edit.py # Main inference script
neuron_commons.py # NeuronTextEncoderWrapper, SDPA implementations
neuron_parallel_utils.py # TP sharding utilities
neuron_rope.py # Neuron-compatible RoPE
autoencoder_kl_qwenimage_neuron.py # Neuron-compatible 3D VAE
compile_transformer.py # V1 transformer (TP=8)
compile_transformer_v1_flash.py # V1 Flash (NKI)
compile_transformer_v2.py # V2 (ModelBuilder)
compile_transformer_v2_flash.py # V2 Flash (ModelBuilder + NKI)
compile_transformer_v3_cp.py # V3 Context Parallel (TP=4, CP=2)
compile_transformer_v3_cfg.py # V3 CFG Parallel (TP=4, DP=2)
attention_cte_qie_hoisted_q.py # Phase 16: hoisted-Q attention_cte fork
compile_language_model_v3.py # Language Model V3 (TP=4)
compile_vision_encoder_v3.py # Vision Encoder V3 (TP=4)
compile_text_encoder.py # Vision encoder single-device
compile_vae.py # 3D VAE encoder/decoder
cache_hf_model.py # Download model
compile.sh # Master compilation script
setup_nvme.sh # NVMe RAID setup
test/
integration/
run_all_tests.py # Master test runner
test_vae.py # VAE tests
test_transformer.py # Transformer tests
test_text_encoder.py # Text encoder tests
test_component_comparison.py # Neuron vs CPU comparison
test_language_model_simple.py # Language model tests
test_multimodal.py # Multi-image tests
unit/
```

## Example Checkpoints

* [alibaba-pai/Qwen-Image-Edit-2509](https://huggingface.co/alibaba-pai/Qwen-Image-Edit-2509)

## Maintainer

Henan Wan (whn09)

**Last Updated:** 2026-05-14
Loading