diff --git a/contrib/models/Cosmos3-Text2Image/README.md b/contrib/models/Cosmos3-Text2Image/README.md new file mode 100644 index 00000000..fb1a473f --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/README.md @@ -0,0 +1,362 @@ +# Contrib Model: Cosmos3-Text2Image + +NeuronX Distributed Inference implementation of NVIDIA Cosmos3 omnimodal +Mixture-of-Transformers (MoT) for text-to-image generation. + +## Model Information + +- **Models:** Cosmos3-Nano (16B), Cosmos3-Super-Text2Image (65B) +- **HuggingFace ID:** `nvidia/Cosmos3-Nano`, `nvidia/Cosmos3-Super-Text2Image` +- **Model Type:** Diffusion Transformer (MoT architecture) +- **Task:** Text-to-Image/Video Generation (512x512, 1024x1024, video up to 61 frames) +- **License:** Check HuggingFace model card + +## Architecture Details + +Cosmos3 uses a **Mixture-of-Transformers (MoT)** architecture: +- Dual-stream processing: text (understanding) and vision (generation) pathways +- Joint MMDiT-style attention: text uses causal self-attention, vision attends bidirectionally to all tokens +- Separate SwiGLU MLPs per stream (text MLP + generation MLP in each layer) +- M-RoPE (Multimodal Rotary Position Embedding) with 3 axes (T, H, W) +- QK normalization (per-head RMSNorm) +- GQA (Grouped Query Attention) +- VAE: AutoencoderKLWan with 48 latent channels, patch_size=2, spatial_compression=16 +- Scheduler: UniPCMultistepScheduler (35 steps, flow matching) +- CFG scale: 6.0 + +| | Cosmos3-Nano | Cosmos3-Super | +|--|--|--| +| Parameters | 16B | 65B | +| hidden_size | 4096 | 5120 | +| intermediate_size | 12288 | 25600 | +| Layers | 36 | 64 | +| Q Heads | 32 | 64 | +| KV Heads | 8 | 8 | +| Instance | trn2.3xlarge | trn2.48xlarge | +| TP Degree | 4 | 8 | + +## Validation Results + +**Validated:** 2026-06-12 +**SDK:** 2.30 (torch-neuronx 2.9.0.2.14.27725) + +### Cosmos3-Nano (trn2.3xlarge, TP=4) + +| Metric | 512x512 (35 steps) | 512x512 CFG-parallel | 1024x1024 (50 steps) | 1024x1024 CFG-parallel | 
+|--------|-------|-------|-------|-------| +| Backbone latency | 33.4 ms/call | 49.7 ms (batch=2) | 167.9 ms/call | 131.1 ms (batch=2) | +| E2E generation | **2.79s** | **2.23s** | **8.63s** | **6.79s** | +| Per-step latency | 78.2 ms | 62.2 ms | 167.9 ms | 131.1 ms | +| VAE decode | 50 ms | 50 ms | 231 ms | 231 ms | +| Speedup vs sequential | - | 20% | - | 21% | + +### Cosmos3-Super-Text2Image (trn2.48xlarge, TP=8) + +| Metric | 512x512 (35 steps) | +|--------|-------| +| Backbone latency | 79.5 ms/call | +| E2E generation | **5.81s** | +| Per-step latency | 164.6 ms | +| VAE decode | 50 ms | +| Image quality | High fidelity | + +**Status:** VALIDATED (both variants) + +## Setup + +### Prerequisites + +- AWS Neuron SDK 2.30+ (DLAMI `Deep Learning AMI Neuron (Ubuntu 24.04) 20260522`) +- trn2.3xlarge (Nano) or trn2.48xlarge (Super) +- `diffusers >= 0.39.0.dev0` (for Cosmos3 VAE support) + +### Environment + +```bash +# Activate pre-installed environment +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + +# Install diffusers from source (needed for Cosmos3 VAE) +pip install git+https://github.com/huggingface/diffusers.git +``` + +### Download Model + +```bash +# Nano (33 GB) +huggingface-cli download nvidia/Cosmos3-Nano --local-dir /home/ubuntu/Cosmos3-Nano + +# Super (124 GB) +huggingface-cli download nvidia/Cosmos3-Super-Text2Image --local-dir /home/ubuntu/Cosmos3-Super-Text2Image +``` + +## Usage + +### 1. 
Compile Backbone + +```bash +# Nano at 512x512 (TP=4, ~2 min compile) +python examples/compile.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --tp 4 \ + --output /home/ubuntu/compiled_cosmos3_nano + +# Nano at 1024x1024 (TP=4, ~2 min compile) +python examples/compile.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --tp 4 \ + --height 1024 --width 1024 \ + --output /home/ubuntu/compiled_cosmos3_nano_1024 + +# Super at 512x512 (TP=8, ~5 min compile) +python examples/compile.py \ + --model-path /home/ubuntu/Cosmos3-Super-Text2Image \ + --tp 8 \ + --output /home/ubuntu/compiled_cosmos3_super + +# Super at 1024x1024 (TP=8, ~5 min compile) +python examples/compile.py \ + --model-path /home/ubuntu/Cosmos3-Super-Text2Image \ + --tp 8 \ + --height 1024 --width 1024 \ + --output /home/ubuntu/compiled_cosmos3_super_1024 + +# CFG-parallel (batch=2, ~20% faster generation): +python examples/compile.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --tp 4 --cfg-parallel \ + --output /home/ubuntu/compiled_cosmos3_nano_cfgp + +python examples/compile.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --tp 4 --height 1024 --width 1024 --cfg-parallel \ + --output /home/ubuntu/compiled_cosmos3_nano_1024_cfgp +``` + +### 2. Compile VAE + +```bash +# For 512x512 output +python examples/compile_vae.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --output /home/ubuntu/compiled_vae/vae_512.pt + +# For 1024x1024 output (~10 min compile) +python examples/compile_vae.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --height 1024 --width 1024 \ + --output /home/ubuntu/compiled_vae/vae_1024.pt +``` + +Note: The same compiled VAE works for both Nano and Super at the same resolution (same architecture). + +### 3. 
Generate Images + +```bash +# Nano at 512x512 +python examples/generate.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --compiled-path /home/ubuntu/compiled_cosmos3_nano \ + --vae-path /home/ubuntu/compiled_vae/vae_512.pt \ + --tp 4 \ + --prompt "A cat sitting on a windowsill watching birds" \ + --output cat_512.png + +# Nano at 1024x1024 +python examples/generate.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --compiled-path /home/ubuntu/compiled_cosmos3_nano_1024 \ + --vae-path /home/ubuntu/compiled_vae/vae_1024.pt \ + --tp 4 --height 1024 --width 1024 --steps 50 \ + --prompt "A majestic snow-covered mountain at sunrise" \ + --output mountain_1024.png + +# Super at 512x512 +python examples/generate.py \ + --model-path /home/ubuntu/Cosmos3-Super-Text2Image \ + --compiled-path /home/ubuntu/compiled_cosmos3_super \ + --vae-path /home/ubuntu/compiled_vae/vae_512.pt \ + --tp 8 \ + --prompt "A majestic snow-covered mountain at sunrise with golden light" \ + --output mountain.png + +# CFG-parallel mode (requires backbone compiled with --cfg-parallel): +python examples/generate.py \ + --model-path /home/ubuntu/Cosmos3-Nano \ + --compiled-path /home/ubuntu/compiled_cosmos3_nano_cfgp \ + --vae-path /home/ubuntu/compiled_vae/vae_512.pt \ + --tp 4 --cfg-parallel \ + --prompt "A golden retriever in autumn leaves" \ + --output dog_cfgp.png +``` + +### Python API + +```python +import torch +import torch_neuronx +from src.modeling_cosmos3 import ( + Cosmos3BackboneInferenceConfig, + NeuronCosmos3BackboneApplication, +) +from src.pipeline import ( + build_position_ids, denoise, denormalize_latents, tokenize_prompt, +) +from neuronx_distributed_inference.models.config import NeuronConfig +from transformers import AutoTokenizer +from diffusers import UniPCMultistepScheduler + +# Configure (Nano example) +neuron_config = NeuronConfig(tp_degree=4, world_size=4, torch_dtype=torch.bfloat16) +config = Cosmos3BackboneInferenceConfig(neuron_config=neuron_config) +config.max_text_len 
= 256 +config.num_vision_patches = 256 + +# Load +app = NeuronCosmos3BackboneApplication( + model_path="/home/ubuntu/Cosmos3-Nano/transformer", config=config +) +app.load("/home/ubuntu/compiled_cosmos3_nano") + +# Tokenize +tokenizer = AutoTokenizer.from_pretrained("/home/ubuntu/Cosmos3-Nano", trust_remote_code=True) +cond_ids, cond_len = tokenize_prompt(tokenizer, "A sunset over the ocean") +uncond_ids, uncond_len = tokenize_prompt(tokenizer, "", negative=True) + +# Position IDs +cond_pos = build_position_ids(256, cond_len, T=1, pH=16, pW=16) +uncond_pos = build_position_ids(256, uncond_len, T=1, pH=16, pW=16) + +# Generate +latents = torch.randn(1, 48, 1, 32, 32, dtype=torch.float32) +scheduler = UniPCMultistepScheduler.from_pretrained("/home/ubuntu/Cosmos3-Nano", subfolder="scheduler") + +latents = denoise(app, cond_ids, uncond_ids, cond_pos, uncond_pos, scheduler, latents) +latents = denormalize_latents(latents, "/home/ubuntu/Cosmos3-Nano/vae/config.json") + +# Decode with the 512x512 VAE compiled in step 2 +vae = torch.jit.load("/home/ubuntu/compiled_vae/vae_512.pt") +pixels = vae(latents.float()) # [1, 3, 1, 512, 512] +``` + +## Testing + +```bash +# Run integration tests +export COSMOS3_MODEL_PATH=/home/ubuntu/Cosmos3-Nano +export COSMOS3_COMPILED_PATH=/home/ubuntu/compiled_cosmos3_nano +export COSMOS3_VAE_PATH=/home/ubuntu/compiled_vae/vae_512.pt +export COSMOS3_TP_DEGREE=4 + +pytest test/integration/test_model.py --capture=tee-sys -v + +# Or run manually: +python test/integration/test_model.py +``` + +## Key Implementation Notes + +1. **Channel ordering in patchify/unpatchify**: Uses spatial-first, channels-last + `(p_h, p_w, C)` matching the reference einsum `"cthpwq->thwpqc"`. Getting this + wrong produces a 16x16 repeating tile pattern. + +2. **Temporal margin**: Vision position IDs use `actual_text_len + 15000` as temporal + offset, matching `unified_3d_mrope_temporal_modality_margin=15000` in the reference. + +3. 
**Tokenization**: Must use the full chat template format with system prompt + + resolution template + special tokens (eos + vision_start). + +4. **Warmup both CFG paths**: The backbone must be warmed up with both conditional and + unconditional inputs before timing. Without this, first-call overhead adds ~2.6s. + +5. **Super model compilation**: Models with > 36 layers require `--verify-hlo=false` + and `--modular-flow-mac-threshold=10` to avoid pre-partition HBM verification failure. + +## Compatibility Matrix + +| Instance | Nano (TP=4) | Super (TP=8) | +|----------|-------------|--------------| +| trn2.3xlarge (LNC=2) | **Working** (512, 1024) | N/A (HBM limit) | +| trn2.48xlarge (LNC=2) | Working | **Working** (512, 1024) | + +## Video Generation (Experimental) + +The Cosmos3 backbone is **modality-agnostic** — the same compiled model that generates +images can generate video by providing temporal position IDs (T > 1 in the M-RoPE +encoding). No recompilation is needed if the total patch count matches an existing +compiled model. + +### How It Works + +The backbone processes a flat sequence of text + vision patches. For images, vision +patches come from a 2D spatial grid. For video, patches span a 3D grid (T × H × W): + +| Modality | T_lat | pH × pW | Total Patches | Use Compiled Model | +|----------|-------|---------|---------------|--------------------| +| Image 512×512 | 1 | 16×16 | 256 | compile at `--height 512 --width 512` | +| Image 1024×1024 | 1 | 32×32 | 1024 | compile at `--height 1024 --width 1024` | +| Video 13f@512 | 4 | 16×16 | 1024 | **Reuse 1024p image model!** | +| Video 29f@512 | 8 | 16×16 | 2048 | compile at 2048 patches | +| Video 61f@512 | 16 | 16×16 | 4096 | compile at 4096 patches | + +The temporal latent count is: `T_lat = (raw_frames - 1) // 4 + 1` (VAE temporal +compression factor = 4). 
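The frame and patch counts in the table above can be reproduced with a short sketch. The helper names `latent_frames` and `video_patches` are hypothetical and not part of `src.pipeline`; the constants are the temporal compression factor of 4 and the 32-pixel patch stride (spatial compression 16 times patch size 2) stated in this README.

```python
def latent_frames(raw_frames: int, temporal_compression: int = 4) -> int:
    """Number of latent frames for a clip: T_lat = (raw_frames - 1) // 4 + 1."""
    if raw_frames < 1:
        raise ValueError("raw_frames must be at least 1")
    return (raw_frames - 1) // temporal_compression + 1


def video_patches(raw_frames: int, height: int, width: int, stride: int = 32) -> int:
    """Total vision patches: T_lat * (height // 32) * (width // 32)."""
    if height % stride != 0 or width % stride != 0:
        raise ValueError(f"height and width must be divisible by {stride}")
    return latent_frames(raw_frames) * (height // stride) * (width // stride)


if __name__ == "__main__":
    # A single frame is the image case; the other rows are the video configs.
    for frames in (1, 13, 29, 61):
        print(frames, latent_frames(frames), video_patches(frames, 512, 512))
```

A 13-frame clip at 512x512 and a single 1024x1024 image both come out at 1024 patches, which is why the two can share one compiled backbone.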
+ +### Example: Generate Video with Existing 1024p Model + +```python +import torch +from src.pipeline import build_position_ids, patchify, unpatchify, denoise + +# Assumes `backbone`, `scheduler`, `cond_ids`, and `uncond_ids` are set up as in the +# Python API section, with the backbone compiled and loaded for 1024 vision patches; +# `actual_text_len` and `uncond_text_len` are the token counts from `tokenize_prompt`. + +# Use T_lat=4 (13 raw frames) at 512x512 → 4×16×16 = 1024 patches +# Same compiled model as 1024x1024 image generation! +T_lat = 4 +pH, pW = 16, 16 + +# Build video position IDs (temporal axis spans T_lat values) +cond_pos = build_position_ids(256, actual_text_len, T=T_lat, pH=pH, pW=pW) +uncond_pos = build_position_ids(256, uncond_text_len, T=T_lat, pH=pH, pW=pW) + +# Initial noise with temporal dimension +latents = torch.randn(1, 48, T_lat, 32, 32, dtype=torch.float32) + +# Denoise (pipeline handles patchify/unpatchify with T>1 automatically) +latents = denoise(backbone, cond_ids, uncond_ids, cond_pos, uncond_pos, + scheduler, latents, num_steps=35) +``` + +### Measured Video Performance (Nano, TP=4, trn2.3xlarge) + +| Video Config | Raw Frames | T_lat | Patches | Per-call Latency | Total (35 steps) | +|-------------|-----------|-------|---------|-----------------|-------------------| +| 13f @ 512×512 | 13 | 4 | 1024 | ~83 ms | 5.83s | +| 29f @ 512×512 | 29 | 8 | 2048 | ~121 ms | 8.45s | +| 61f @ 512×512 | 61 | 16 | 4096 | ~239 ms | 16.73s | + +### Limitations + +- **VAE decode**: The compiled image VAE only handles T=1. Per-frame decoding works as + an approximation. A proper 3D video VAE compilation is needed for production quality. +- **Maximum sequence length**: Tested up to 8192 patches (seq_len=8448) on trn2.3xlarge. + Longer videos (189 frames = 41k patches) require context parallelism. +- **Video quality**: Without the proper 3D VAE, temporal consistency between decoded + frames depends on the per-frame decode approximation. + +## Supported Resolutions + +The backbone can be compiled at any resolution divisible by 32. Compile time and +latency scale with sequence length (number of vision patches). 
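As a rough sketch of how a resolution maps to the compiled sequence length, assuming the default `--max-text-len` of 256 from `compile.py`. The function name `backbone_seq_len` is hypothetical and not part of the repository; it mirrors the divisibility check and patch formula used by the compile script.

```python
def backbone_seq_len(height: int, width: int, max_text_len: int = 256) -> tuple[int, int]:
    """Return (vision_patches, total_seq_len) for an image resolution."""
    stride = 16 * 2  # spatial compression (16) times patch size (2)
    if height % stride != 0 or width % stride != 0:
        raise ValueError(
            f"Height ({height}) and width ({width}) must be divisible by {stride}"
        )
    vision_patches = (height // stride) * (width // stride)
    return vision_patches, max_text_len + vision_patches


if __name__ == "__main__":
    for side in (512, 768, 1024):
        patches, seq_len = backbone_seq_len(side, side)
        print(f"{side}x{side}: {patches} vision patches, seq_len={seq_len}")
```

The table that follows lists the measured compile time and latency for these three square resolutions.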
+ +| Resolution | Vision Patches | Total Seq Len | Compile Time (Nano) | Latency/Step (Nano) | +|-----------|---------------|---------------|--------------------|--------------------| +| 512x512 | 256 | 512 | ~2 min | 33.4 ms | +| 768x768 | 576 | 832 | ~2 min | ~50 ms | +| 1024x1024 | 1024 | 1280 | ~2 min | 80.6 ms | + +The VAE must be compiled separately for each target resolution. + + +## Maintainer + +Annapurna Labs + +**Last Updated:** 2026-06-12 diff --git a/contrib/models/Cosmos3-Text2Image/examples/compile.py b/contrib/models/Cosmos3-Text2Image/examples/compile.py new file mode 100644 index 00000000..1de96633 --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/examples/compile.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 +""" +Compile Cosmos3 backbone for Neuron. + +Supports both Cosmos3-Nano (16B, TP=4) and Cosmos3-Super-Text2Image (65B, TP=8). + +Usage: + # Nano at 512x512 on trn2.3xlarge (TP=4): + python compile.py --model-path /path/to/Cosmos3-Nano --tp 4 --output /path/to/compiled + + # Nano at 1024x1024: + python compile.py --model-path /path/to/Cosmos3-Nano --tp 4 --height 1024 --width 1024 --output /path/to/compiled_1024p + + # Super at 512x512 on trn2.48xlarge (TP=8): + python compile.py --model-path /path/to/Cosmos3-Super-Text2Image --tp 8 --output /path/to/compiled + + # Super at 1024x1024: + python compile.py --model-path /path/to/Cosmos3-Super-Text2Image --tp 8 --height 1024 --width 1024 --output /path/to/compiled_1024p + +Environment: + source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate +""" + +import argparse +import json +import os +import sys +import time + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +os.environ.setdefault("NEURON_COMPILE_CACHE_URL", "/tmp/neuron_cache") + +import torch +import torch_neuronx +from modeling_cosmos3 import ( + Cosmos3BackboneInferenceConfig, + NeuronCosmos3BackboneApplication, +) +from neuronx_distributed_inference.models.config import 
NeuronConfig + + +# Model configurations +MODEL_CONFIGS = { + "Cosmos3-Nano": { + "hidden_size": 4096, + "intermediate_size": 12288, + "num_hidden_layers": 36, + "num_attention_heads": 32, + "num_key_value_heads": 8, + }, + "Cosmos3-Super-Text2Image": { + "hidden_size": 5120, + "intermediate_size": 25600, + "num_hidden_layers": 64, + "num_attention_heads": 64, + "num_key_value_heads": 8, + }, +} + +# Resolution presets (height, width) -> num_vision_patches +# Formula: num_patches = (height // 32) * (width // 32) +# Where 32 = scale_factor_spatial(16) * patch_size(2) +RESOLUTION_PRESETS = { + "512x512": (512, 512, 256), # 16×16 = 256 patches + "768x768": (768, 768, 576), # 24×24 = 576 patches + "1024x1024": (1024, 1024, 1024), # 32×32 = 1024 patches +} + + +def detect_model_variant(model_path: str) -> str: + """Detect model variant from config.json.""" + config_path = os.path.join(model_path, "transformer", "config.json") + if not os.path.exists(config_path): + config_path = os.path.join(model_path, "config.json") + with open(config_path) as f: + cfg = json.load(f) + hidden_size = cfg.get("hidden_size", cfg.get("d_model", 4096)) + if hidden_size >= 5120: + return "Cosmos3-Super-Text2Image" + return "Cosmos3-Nano" + + +def main(): + parser = argparse.ArgumentParser(description="Compile Cosmos3 backbone for Neuron") + parser.add_argument("--model-path", required=True, help="Path to HF model weights") + parser.add_argument( + "--output", required=True, help="Output path for compiled model" + ) + parser.add_argument( + "--tp", type=int, default=None, help="TP degree (auto-detected if not set)" + ) + parser.add_argument( + "--max-text-len", type=int, default=256, help="Max text token length" + ) + parser.add_argument( + "--height", type=int, default=512, help="Target image height in pixels" + ) + parser.add_argument( + "--width", type=int, default=512, help="Target image width in pixels" + ) + parser.add_argument( + "--num-vision-patches", + type=int, + default=None, + 
help="Override: number of vision patches (auto-calculated from height/width if not set)", + ) + parser.add_argument( + "--cfg-parallel", + action="store_true", + help="Enable CFG-parallel (batch=2): pack cond+uncond in a single call for ~20%% speedup", + ) + args = parser.parse_args() + + # Calculate vision patches from resolution + if args.num_vision_patches is not None: + num_vision_patches = args.num_vision_patches + else: + # Formula: (height / scale_factor_spatial / patch_size) * (width / scale_factor_spatial / patch_size) + # = (height / 32) * (width / 32) + pH = args.height // 32 + pW = args.width // 32 + num_vision_patches = pH * pW + if args.height % 32 != 0 or args.width % 32 != 0: + raise ValueError( + f"Height ({args.height}) and width ({args.width}) must be divisible by 32 " + f"(scale_factor_spatial=16 × patch_size=2)" + ) + + total_seq = args.max_text_len + num_vision_patches + + # Detect model variant + variant = detect_model_variant(args.model_path) + model_cfg = MODEL_CONFIGS[variant] + print(f"Detected model: {variant}") + print( + f" hidden_size={model_cfg['hidden_size']}, layers={model_cfg['num_hidden_layers']}" + ) + + # Auto-select TP + tp = args.tp + if tp is None: + tp = 4 if variant == "Cosmos3-Nano" else 8 + print(f" TP degree: {tp}") + print(f" Resolution: {args.height}x{args.width}") + print( + f" Vision patches: {num_vision_patches} (patch grid: {args.height // 32}x{args.width // 32})" + ) + print( + f" Total sequence length: {total_seq} (text={args.max_text_len} + vision={num_vision_patches})" + ) + if args.cfg_parallel: + print(f" CFG-parallel: ENABLED (batch=2, cond+uncond in single call)") + + # Create config + neuron_config = NeuronConfig( + tp_degree=tp, world_size=tp, torch_dtype=torch.bfloat16 + ) + config = Cosmos3BackboneInferenceConfig( + neuron_config=neuron_config, + cfg_parallel_enabled=args.cfg_parallel, + head_dim=128, + vocab_size=151936, + patch_channels=192, + latent_channels=48, + rope_theta=5000000.0, + 
mrope_section=[24, 20, 20], + **model_cfg, + ) + config.max_text_len = args.max_text_len + config.num_vision_patches = num_vision_patches + + # Compile + transformer_path = os.path.join(args.model_path, "transformer") + print(f"\nCompiling {variant} backbone...") + print(f" Weights: {transformer_path}") + print(f" Output: {args.output}") + + t0 = time.time() + app = NeuronCosmos3BackboneApplication(model_path=transformer_path, config=config) + app.compile(args.output) + elapsed = time.time() - t0 + + print(f"\nCompilation complete in {elapsed:.1f}s") + print(f"Compiled model saved to: {args.output}") + print(f"\nTo generate images at {args.height}x{args.width}, run:") + print(f" python generate.py --model-path {args.model_path} \\") + print(f" --compiled-path {args.output} \\") + print(f" --vae-path /vae_decoder.pt \\") + print(f" --height {args.height} --width {args.width} \\") + print(f' --prompt "your prompt"') + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Cosmos3-Text2Image/examples/compile_vae.py b/contrib/models/Cosmos3-Text2Image/examples/compile_vae.py new file mode 100644 index 00000000..70485dfb --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/examples/compile_vae.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +""" +Compile the Cosmos3 VAE decoder for Neuron using torch_neuronx.trace(). + +The VAE decoder is a standard AutoencoderKLWan (48 latent channels). +It runs on a single NeuronCore (no TP needed). + +Key challenge: F.interpolate (nearest-exact) is not supported in XLA tracing. 
+Solution: Monkey-patch all nn.Upsample modules with repeat_interleave equivalents. + +Usage: + # For 512x512 images: + python compile_vae.py --model-path /path/to/Cosmos3-Nano --output /path/to/vae_512.pt + + # For 1024x1024 images: + python compile_vae.py --model-path /path/to/Cosmos3-Nano --height 1024 --width 1024 \ + --output /path/to/vae_1024.pt + +Environment: + source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate + + IMPORTANT: Requires diffusers >= 0.39.0.dev0 for Cosmos3 VAE support. + Install with: pip install git+https://github.com/huggingface/diffusers.git +""" + +import argparse +import os +import time + +import torch +import torch.nn as nn +import torch_neuronx +from diffusers import AutoencoderKLWan + + +def patch_upsample_modules(model: nn.Module) -> int: + """Replace all nn.Upsample modules with XLA-compatible repeat_interleave versions. + + nn.Upsample uses F.interpolate which isn't supported in XLA tracing. + We replace with repeat_interleave which achieves the same nearest-neighbor upsampling. + + Returns: + Number of modules patched. 
+ """ + + class NeuronUpsample(nn.Module): + """XLA-compatible nearest-neighbor upsampling via repeat_interleave.""" + + def __init__(self, scale_factor): + super().__init__() + if isinstance(scale_factor, (tuple, list)): + self.scale_h = int(scale_factor[0]) + self.scale_w = int(scale_factor[1]) + else: + self.scale_h = int(scale_factor) + self.scale_w = int(scale_factor) + + def forward(self, x): + # H and W are the last two dims for both [B, C, H, W] and + # [B, C, T, H, W], so negative indices cover both layouts without + # a view that would interleave C and T when T > 1. + x = x.repeat_interleave(self.scale_h, dim=-2) + x = x.repeat_interleave(self.scale_w, dim=-1) + return x + + count = 0 + for name, module in model.named_modules(): + if isinstance(module, nn.Upsample): + parts = name.split(".") + parent = model + for p in parts[:-1]: + parent = parent[int(p)] if p.isdigit() else getattr(parent, p) + attr_name = parts[-1] + replacement = NeuronUpsample(module.scale_factor) + if attr_name.isdigit(): + parent[int(attr_name)] = replacement + else: + setattr(parent, attr_name, replacement) + count += 1 + return count + + +class VAEDecodeWrapper(nn.Module): + """Wrapper that bypasses temporal caching for single-frame decode. 
+ + Pipeline: post_quant_conv -> decoder(no cache, first_chunk=True) -> unpatchify -> clamp + """ + + def __init__(self, vae): + super().__init__() + self.post_quant_conv = vae.post_quant_conv + self.decoder = vae.decoder + self.patch_size = vae.config.patch_size + + def forward(self, z): + # z: [1, 48, 1, H_latent, W_latent] + x = self.post_quant_conv(z) + out = self.decoder(x, feat_cache=None, feat_idx=[0], first_chunk=True) + out = self._unpatchify(out) + out = torch.clamp(out, min=-1.0, max=1.0) + return out + + def _unpatchify(self, x): + """Inline unpatchify matching diffusers implementation.""" + p = self.patch_size + b, c_pp, t, h, w = x.shape + c = c_pp // (p * p) + x = x.view(b, c, p, p, t, h, w) + x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous() + x = x.view(b, c, t, h * p, w * p) + return x + + +def main(): + parser = argparse.ArgumentParser(description="Compile Cosmos3 VAE decoder") + parser.add_argument( + "--model-path", + required=True, + help="Path to HF model (parent dir with vae/ subfolder)", + ) + parser.add_argument( + "--output", required=True, help="Output path for compiled VAE (.pt)" + ) + parser.add_argument("--height", type=int, default=512, help="Target image height") + parser.add_argument("--width", type=int, default=512, help="Target image width") + args = parser.parse_args() + + vae_path = os.path.join(args.model_path, "vae") + print(f"Loading VAE from: {vae_path}") + + vae = AutoencoderKLWan.from_pretrained(vae_path, torch_dtype=torch.float32) + vae.eval() + print(f" Parameters: {sum(p.numel() for p in vae.parameters()) / 1e6:.1f}M") + + # Patch upsample modules for XLA compatibility + num_patched = patch_upsample_modules(vae) + print(f" Patched {num_patched} nn.Upsample modules for Neuron compatibility") + + # Latent dimensions for target resolution + # spatial_compression = 16, latent_channels = 48, temporal = 1 for images + H_latent = args.height // 16 + W_latent = args.width // 16 + T = 1 + + print(f"\nTarget resolution: 
{args.height}x{args.width}") + print(f"Latent shape: [1, 48, {T}, {H_latent}, {W_latent}]") + print(f"Output shape: [1, 3, {T}, {args.height}, {args.width}]") + + # Create wrapper and example input + wrapper = VAEDecodeWrapper(vae) + wrapper.eval() + example_input = torch.randn(1, 48, T, H_latent, W_latent, dtype=torch.float32) + + # Verify CPU output + print("\nVerifying CPU decode...") + with torch.no_grad(): + cpu_output = wrapper(example_input) + print(f" CPU output shape: {cpu_output.shape}") + print( + f" CPU output range: [{cpu_output.min().item():.3f}, {cpu_output.max().item():.3f}]" + ) + + # Compile for Neuron + print(f"\nCompiling VAE decoder for {args.height}x{args.width}...") + print(f" (3D convolutions, ~700M params, may take 5-15 minutes)") + t0 = time.time() + + compiled = torch_neuronx.trace( + wrapper, + example_input, + compiler_args=[ + "--auto-cast", + "matmult", + "--model-type=unet-inference", + ], + ) + elapsed = time.time() - t0 + print(f" Compilation complete in {elapsed:.1f}s") + + # Save + os.makedirs(os.path.dirname(os.path.abspath(args.output)) or ".", exist_ok=True) + torch.jit.save(compiled, args.output) + neff_size = os.path.getsize(args.output) / (1024**2) + print(f" Saved to: {args.output} ({neff_size:.1f} MB)") + + # Quick benchmark + print("\nBenchmarking...") + with torch.no_grad(): + neuron_output = compiled(example_input) + diff = (cpu_output - neuron_output).abs().max().item() + print(f" Max diff vs CPU: {diff:.6f}") + + # Warmup + measure + for _ in range(3): + with torch.no_grad(): + _ = compiled(example_input) + times = [] + for _ in range(5): + t0 = time.time() + with torch.no_grad(): + _ = compiled(example_input) + times.append((time.time() - t0) * 1000) + avg_ms = sum(times) / len(times) + print(f" Average latency: {avg_ms:.1f}ms") + print( + f"\nDone! 
VAE decoder for {args.height}x{args.width}: {avg_ms:.1f}ms on 1 NeuronCore" + ) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Cosmos3-Text2Image/examples/generate.py b/contrib/models/Cosmos3-Text2Image/examples/generate.py new file mode 100644 index 00000000..a7cc03bc --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/examples/generate.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +""" +Generate images with Cosmos3 on Neuron. + +Supports both Cosmos3-Nano (16B) and Cosmos3-Super-Text2Image (65B). + +Usage: + # Generate with Nano: + python generate.py \ + --model-path /path/to/Cosmos3-Nano \ + --compiled-path /path/to/compiled \ + --vae-path /path/to/vae_decoder.pt \ + --prompt "A cat sitting on a windowsill" \ + --output generated.png + + # Generate with Super: + python generate.py \ + --model-path /path/to/Cosmos3-Super-Text2Image \ + --compiled-path /path/to/compiled_super \ + --vae-path /path/to/vae_decoder.pt \ + --tp 8 \ + --prompt "A majestic mountain at sunrise" \ + --output generated.png + +Environment: + source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate +""" + +import argparse +import json +import os +import sys +import time + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src")) + +os.environ.setdefault("NEURON_COMPILE_CACHE_URL", "/tmp/neuron_cache") + +import torch +import torch_neuronx +from modeling_cosmos3 import ( + Cosmos3BackboneInferenceConfig, + NeuronCosmos3BackboneApplication, +) +from pipeline import ( + build_position_ids, + denoise, + denormalize_latents, + patchify, + tokenize_prompt, +) +from neuronx_distributed_inference.models.config import NeuronConfig +from transformers import AutoTokenizer +from diffusers import UniPCMultistepScheduler +from PIL import Image + + +def main(): + parser = argparse.ArgumentParser( + description="Generate images with Cosmos3 on Neuron" + ) + parser.add_argument("--model-path", required=True, help="Path to HF model 
weights") + parser.add_argument( + "--compiled-path", required=True, help="Path to compiled backbone" + ) + parser.add_argument( + "--vae-path", required=True, help="Path to compiled VAE decoder (.pt)" + ) + parser.add_argument( + "--tp", type=int, default=4, help="TP degree (4 for Nano, 8 for Super)" + ) + parser.add_argument( + "--prompt", + type=str, + default="A beautiful sunset over the ocean", + help="Text prompt", + ) + parser.add_argument( + "--negative-prompt", type=str, default="", help="Negative prompt" + ) + parser.add_argument("--steps", type=int, default=35, help="Denoising steps") + parser.add_argument("--cfg-scale", type=float, default=6.0, help="CFG scale") + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument( + "--output", type=str, default="generated.png", help="Output image path" + ) + parser.add_argument("--height", type=int, default=512, help="Output height") + parser.add_argument("--width", type=int, default=512, help="Output width") + parser.add_argument("--max-text-len", type=int, default=256, help="Max text tokens") + parser.add_argument( + "--cfg-parallel", + action="store_true", + help="Use CFG-parallel mode (backbone compiled with --cfg-parallel)", + ) + args = parser.parse_args() + + MAX_TEXT = args.max_text_len + NUM_VIS = (args.height // 32) * (args.width // 32) # patches for 512x512 = 256 + + # Detect model variant from hidden_size + config_path = os.path.join(args.model_path, "transformer", "config.json") + with open(config_path) as f: + model_cfg = json.load(f) + + hidden_size = model_cfg.get("hidden_size", 4096) + is_super = hidden_size >= 5120 + + print("=" * 60) + print( + f"Cosmos3 {'Super' if is_super else 'Nano'} Image Generation on Neuron (TP={args.tp})" + ) + print("=" * 60) + + # --- Load backbone --- + print( + f"\n[1/4] Loading backbone ({model_cfg.get('num_hidden_layers', 36)} layers, TP={args.tp})..." 
+ ) + t0 = time.time() + neuron_config = NeuronConfig( + tp_degree=args.tp, world_size=args.tp, torch_dtype=torch.bfloat16 + ) + config = Cosmos3BackboneInferenceConfig( + neuron_config=neuron_config, + cfg_parallel_enabled=args.cfg_parallel, + hidden_size=hidden_size, + intermediate_size=model_cfg.get("intermediate_size", 12288), + num_hidden_layers=model_cfg.get("num_hidden_layers", 36), + num_attention_heads=model_cfg.get("num_attention_heads", 32), + num_key_value_heads=model_cfg.get("num_key_value_heads", 8), + head_dim=128, + vocab_size=model_cfg.get("vocab_size", 151936), + patch_channels=192, + latent_channels=48, + rope_theta=model_cfg.get("rope_theta", 5000000.0), + mrope_section=model_cfg.get("mrope_section", [24, 20, 20]), + ) + config.max_text_len = MAX_TEXT + config.num_vision_patches = NUM_VIS + + transformer_path = os.path.join(args.model_path, "transformer") + app = NeuronCosmos3BackboneApplication(model_path=transformer_path, config=config) + app.load(args.compiled_path) + print(f" Loaded in {time.time() - t0:.1f}s") + + # --- Load VAE --- + print("\n[2/4] Loading VAE...") + t0 = time.time() + vae = torch.jit.load(args.vae_path) + print(f" Loaded in {time.time() - t0:.1f}s") + + # --- Load tokenizer --- + print("\n[3/4] Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + + # --- Tokenize --- + print(f"\nPrompt: '{args.prompt}'") + cond_ids, cond_len = tokenize_prompt( + tokenizer, args.prompt, height=args.height, width=args.width, max_len=MAX_TEXT + ) + uncond_ids, uncond_len = tokenize_prompt( + tokenizer, + "", + height=args.height, + width=args.width, + max_len=MAX_TEXT, + negative=True, + ) + print(f" Cond tokens: {cond_len}, Uncond tokens: {uncond_len}") + + # --- Position IDs --- + T = 1 # single frame for images + pH = args.height // 32 # 512 -> 16 + pW = args.width // 32 + + cond_pos = build_position_ids(MAX_TEXT, cond_len, T, pH, pW) + uncond_pos = build_position_ids(MAX_TEXT, 
uncond_len, T, pH, pW) + + # --- Warmup --- + print("\n[4/4] Warming up backbone (both CFG paths)...") + t0 = time.time() + if args.cfg_parallel: + dummy_patches = torch.randn(2, NUM_VIS, 192, dtype=torch.bfloat16) + dummy_ts = torch.tensor([0.5, 0.5], dtype=torch.bfloat16) + for _ in range(2): + _ = app( + torch.cat([cond_ids, uncond_ids], dim=0), + dummy_patches, + dummy_ts, + cond_pos, + ) + else: + dummy_patches = torch.randn(1, NUM_VIS, 192, dtype=torch.bfloat16) + dummy_ts = torch.tensor([0.5], dtype=torch.bfloat16) + for _ in range(2): + _ = app(cond_ids, dummy_patches, dummy_ts, cond_pos) + _ = app(uncond_ids, dummy_patches, dummy_ts, uncond_pos) + print(f" Warmup done in {time.time() - t0:.1f}s") + + # --- Generate --- + H_latent = args.height // 16 + W_latent = args.width // 16 + + gen = torch.manual_seed(args.seed) + latents = torch.randn( + 1, 48, T, H_latent, W_latent, generator=gen, dtype=torch.float32 + ) + + scheduler = UniPCMultistepScheduler.from_pretrained( + args.model_path, subfolder="scheduler" + ) + + print(f"\n Denoising: {args.steps} steps, CFG={args.cfg_scale}") + import logging + + logging.basicConfig(level=logging.INFO) + + latents = denoise( + backbone=app, + cond_ids=cond_ids, + uncond_ids=uncond_ids, + cond_pos=cond_pos, + uncond_pos=uncond_pos, + scheduler=scheduler, + latents=latents, + num_steps=args.steps, + cfg_scale=args.cfg_scale, + cfg_parallel=args.cfg_parallel, + ) + + # --- Denormalize + VAE decode --- + vae_config_path = os.path.join(args.model_path, "vae", "config.json") + latents = denormalize_latents(latents, vae_config_path) + + print(" VAE decoding...") + t0 = time.time() + with torch.no_grad(): + pixels = vae(latents.float()) + vae_time = time.time() - t0 + print(f" VAE: {vae_time * 1000:.0f}ms") + + # --- Save --- + pixels = pixels.squeeze(2).squeeze(0) + pixels = ((pixels + 1.0) / 2.0).clamp(0, 1) + pixels = (pixels * 255).to(torch.uint8).permute(1, 2, 0).numpy() + img = Image.fromarray(pixels, mode="RGB") + 
img.save(args.output) + + print(f"\nSaved: {args.output}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Cosmos3-Text2Image/src/__init__.py b/contrib/models/Cosmos3-Text2Image/src/__init__.py new file mode 100644 index 00000000..9d482d0f --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/src/__init__.py @@ -0,0 +1,6 @@ +from .modeling_cosmos3 import ( + Cosmos3BackboneInferenceConfig, + NeuronCosmos3BackboneApplication, + NeuronCosmos3Transformer, +) +from .pipeline import patchify, unpatchify, generate_position_ids diff --git a/contrib/models/Cosmos3-Text2Image/src/modeling_cosmos3.py b/contrib/models/Cosmos3-Text2Image/src/modeling_cosmos3.py new file mode 100644 index 00000000..adad5739 --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/src/modeling_cosmos3.py @@ -0,0 +1,1139 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Cosmos3-Nano MoT (Mixture-of-Transformers) generation backbone for NxDI. +# Adapted from NxDI Flux implementation with Cosmos3-specific MoT dual-stream +# attention, separate MLPs, additive timestep conditioning, and GQA. 
+ +import logging +import math +import os +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + SPMDRank, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_tensor_model_parallel_region_with_dim, + reduce_from_tensor_model_parallel_region, + scatter_to_process_group_spmd, +) +from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_group, + get_tensor_model_parallel_size, + get_world_group, +) + +from neuronx_distributed_inference.models.diffusers.embeddings import ( + NeuronTimestepEmbedding, + Timesteps, + get_1d_rotary_pos_embed, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.utils.distributed import get_dp_rank_spmd + +from nkilib.core.attention.attention_cte import attention_cte + +from neuronx_distributed.utils.utils import hardware +from torch_neuronx.utils import get_platform_target + +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) +from neuronx_distributed_inference.models.model_wrapper import ( + BaseModelInstance, + ModelWrapper, +) + +_HARDWARE = hardware(get_platform_target()) + +if not os.environ.get("NEURON_PLATFORM_TARGET_OVERRIDE"): + os.environ["NEURON_PLATFORM_TARGET_OVERRIDE"] = get_platform_target() + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Attention Kernel Wrapper (bidirectional, from Flux) +# ============================================================================= + 
+ +def attention_wrapper_bidirectional(query, key, value): + """ + Bidirectional attention using NKI attention_cte kernel. + + Input shapes: query, key, value all have shape [bs, n_head, seq_len, d_head] + Output shape: [bs, n_head, q_len, d_head] + + Uses tp_q=True, tp_k=True to let the kernel handle transposes internally. + """ + bs, n_head, q_len, d_head = query.shape + k_len = key.shape[2] + + q = query.reshape((bs * n_head, q_len, d_head)) + k = key.reshape((bs * n_head, k_len, d_head)) + v = value.reshape((bs * n_head, k_len, d_head)) + + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + use_sharded_attention_kernel = vc_size == 2 + scale = 1 / math.sqrt(d_head) + + if use_sharded_attention_kernel: + attn_output = attention_cte[2]( + q, + k, + v, + scale, + causal_mask=False, + tp_q=True, + tp_k=True, + tp_out=False, + ) + else: + attn_output = attention_cte( + q, + k, + v, + scale, + causal_mask=False, + tp_q=True, + tp_k=True, + tp_out=False, + ) + + attn_output = attn_output.reshape((bs, n_head, q_len, d_head)) + return attn_output + + +def attention_wrapper_causal(query, key, value): + """ + Causal attention using NKI attention_cte kernel. + + Same as bidirectional but with causal_mask=True. + Used for the text (understanding) pathway. 
+ """ + bs, n_head, q_len, d_head = query.shape + k_len = key.shape[2] + + q = query.reshape((bs * n_head, q_len, d_head)) + k = key.reshape((bs * n_head, k_len, d_head)) + v = value.reshape((bs * n_head, k_len, d_head)) + + vc_size = int(os.getenv("NEURON_RT_VIRTUAL_CORE_SIZE", "1")) + use_sharded_attention_kernel = vc_size == 2 + scale = 1 / math.sqrt(d_head) + + if use_sharded_attention_kernel: + attn_output = attention_cte[2]( + q, + k, + v, + scale, + causal_mask=True, + tp_q=True, + tp_k=True, + tp_out=False, + ) + else: + attn_output = attention_cte( + q, + k, + v, + scale, + causal_mask=True, + tp_q=True, + tp_k=True, + tp_out=False, + ) + + attn_output = attn_output.reshape((bs, n_head, q_len, d_head)) + return attn_output + + +# ============================================================================= +# M-RoPE (Multimodal Rotary Position Embedding) +# Matches Cosmos3VLTextRotaryEmbedding from diffusers reference implementation. +# ============================================================================= + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Standard half-rotation for RoPE: split last dim in half, negate+swap.""" + half = x.shape[-1] // 2 + return torch.cat((-x[..., half:], x[..., :half]), dim=-1) + + +class Cosmos3MRoPE(nn.Module): + """ + M-RoPE for Cosmos3 matching Cosmos3VLTextRotaryEmbedding. + + Key differences from the previous (incorrect) implementation: + 1. Uses a SINGLE inv_freq basis computed over full head_dim (not per-axis) + 2. Applies interleaved M-RoPE mixing (T/H/W frequencies interleaved) + 3. Returns (cos, sin) each of shape [seq_len, head_dim] + + Position IDs: [seq_len, 3] -> (t, h, w) per token. 
+ """ + + def __init__( + self, head_dim: int, mrope_section: List[int], rope_theta: float = 5000000.0 + ): + super().__init__() + self.head_dim = head_dim + self.mrope_section = mrope_section # [24, 20, 20] for Cosmos3 + self.rope_theta = rope_theta + # Single inv_freq basis over full head_dim + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def _apply_interleaved_mrope(self, freqs: torch.Tensor) -> torch.Tensor: + """Reorganize chunked [TTT...HHH...WWW] frequency layout into interleaved + [THTHWHTHW...TT], preserving frequency continuity across the 3 grids. + + Uses a gather-based approach for XLA compatibility (no in-place mutation). + + Args: + freqs: [3, seq_len, head_dim//2] - per-axis frequencies + + Returns: + [seq_len, head_dim//2] - interleaved frequencies + """ + # Build output by selecting from T, H, W based on interleaving pattern. + # For mrope_section=[24, 20, 20], head_dim//2=64: + # positions 1,4,7,...,58 (step 3, 20 positions) -> H (freqs[1]) + # positions 2,5,8,...,59 (step 3, 20 positions) -> W (freqs[2]) + # positions 0,3,6,...,57 (step 3, 20 positions) -> T (freqs[0]) + # Remaining positions 60-63 also use T (freqs[0]), 24 T positions in total + half_dim = freqs.shape[-1] + seq_len = freqs.shape[1] + + # Build source index on CPU (constant, will be baked into XLA graph) + # This determines which axis (0=T, 1=H, 2=W) each position draws from + source_cpu = torch.zeros(half_dim, dtype=torch.long) + for dim in range(1, 3): # H=1, W=2 + length = self.mrope_section[dim] * 3 + for i in range(dim, length, 3): + source_cpu[i] = dim + + # Move to device and expand for gather + source = source_cpu.to(device=freqs.device) + source_expanded = source.unsqueeze(0).expand(seq_len, -1) # [seq_len, hd//2] + + # freqs: [3, seq_len, half_dim] -> permute to [seq_len, half_dim, 3] + freqs_perm = freqs.permute(1, 2, 0) # [seq_len, half_dim, 3] + + # Gather along last dim using 
source indices + result = freqs_perm.gather(2, source_expanded.unsqueeze(-1)).squeeze( + -1 + ) # [seq_len, half_dim] + + return result + + def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + position_ids: [seq_len, 3] - (t, h, w) positions per token + + Returns: + (cos, sin): each [seq_len, head_dim] + """ + # position_ids: [seq_len, 3] -> transpose to [3, seq_len] + pos = position_ids.T.float() # [3, seq_len] + + # Move inv_freq to input device (critical for XLA tracing) + inv_freq = self.inv_freq.to(device=position_ids.device) + + # inv_freq: [head_dim//2] -> expand to [3, head_dim//2, 1] + inv_freq_expanded = inv_freq[None, :, None].expand( + 3, -1, 1 + ) # [3, head_dim//2, 1] + pos_expanded = pos[:, None, :] # [3, 1, seq_len] + + # freqs: [3, seq_len, head_dim//2] + freqs = (inv_freq_expanded @ pos_expanded).transpose( + 1, 2 + ) # [3, seq_len, hd//2] + + # Apply interleaved M-RoPE mixing + freqs_mixed = self._apply_interleaved_mrope(freqs) # [seq_len, head_dim//2] + + # Expand to full head_dim by doubling (for rotate_half compatibility) + emb = torch.cat((freqs_mixed, freqs_mixed), dim=-1) # [seq_len, head_dim] + + return emb.cos(), emb.sin() + + +def apply_rotary_emb_cosmos3( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> torch.Tensor: + """ + Apply M-RoPE to query or key tensor using rotate_half convention. 
+ + Args: + x: [B, H, S, D] - query or key + cos: [S, D] - cosine component + sin: [S, D] - sine component + + Returns: + [B, H, S, D] - rotated tensor + """ + # Broadcast cos/sin: [S, D] -> [1, 1, S, D] + cos = cos[None, None] + sin = sin[None, None] + + # Standard rotate_half application (matches reference _rotate_half) + out = (x.float() * cos + _rotate_half(x.float()) * sin).to(x.dtype) + return out + + +# ============================================================================= +# NeuronCosmos3Attention (Joint MMDiT with GQA) +# ============================================================================= + + +class NeuronCosmos3Attention(nn.Module): + """ + Joint attention for Cosmos3 MoT with: + - Separate Q/K/V projections per stream (text: to_q/to_k/to_v, gen: add_q/k/v_proj) + - GQA (e.g. 32 Q heads, 8 KV heads for Nano) + - QK normalization (per-head RMSNorm) + - Separate output projections (to_out, to_add_out) + - Causal attention for the text stream; bidirectional attention over all + tokens (text + gen) for the generation stream + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + reduce_dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.head_dim = head_dim + self.num_kv_groups = num_attention_heads // num_key_value_heads + + tp_degree = get_tensor_model_parallel_size() + + # Pad heads for TP alignment + self.padded_num_heads = math.ceil(num_attention_heads / tp_degree) * tp_degree + self.padded_num_kv_heads = ( + math.ceil(num_key_value_heads / tp_degree) * tp_degree + ) + self.padded_q_dim = self.padded_num_heads * head_dim + self.padded_kv_dim = self.padded_num_kv_heads * head_dim + + # Per-TP-rank head counts + self.heads_per_rank = self.padded_num_heads // tp_degree + self.kv_heads_per_rank = self.padded_num_kv_heads // tp_degree + + # --- Text stream projections --- + self.to_q = ColumnParallelLinear( + 
hidden_size, + self.padded_q_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_k = ColumnParallelLinear( + hidden_size, + self.padded_kv_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_v = ColumnParallelLinear( + hidden_size, + self.padded_kv_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_out = RowParallelLinear( + self.padded_q_dim, + hidden_size, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + + # --- Generation stream projections --- + self.add_q_proj = ColumnParallelLinear( + hidden_size, + self.padded_q_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.add_k_proj = ColumnParallelLinear( + hidden_size, + self.padded_kv_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.add_v_proj = ColumnParallelLinear( + hidden_size, + self.padded_kv_dim, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.to_add_out = RowParallelLinear( + self.padded_q_dim, + hidden_size, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + + # --- QK Normalization (per-head RMSNorm) --- + self.norm_q = CustomRMSNorm(head_dim, eps=1e-6) + self.norm_k = CustomRMSNorm(head_dim, eps=1e-6) + self.norm_added_q = CustomRMSNorm(head_dim, eps=1e-6) + self.norm_added_k = CustomRMSNorm(head_dim, eps=1e-6) + + def _repeat_kv(self, kv: torch.Tensor) -> torch.Tensor: + """Expand KV heads to match Q heads for GQA. 
[B, kv_heads, S, D] -> [B, q_heads, S, D]""" + if self.num_kv_groups == 1: + return kv + bs, n_kv_heads, seq_len, head_dim = kv.shape + kv = kv[:, :, None, :, :].expand( + bs, n_kv_heads, self.num_kv_groups, seq_len, head_dim + ) + return kv.reshape(bs, n_kv_heads * self.num_kv_groups, seq_len, head_dim) + + def forward( + self, + text_hidden: torch.Tensor, + gen_hidden: torch.Tensor, + rotary_emb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Dual-pathway attention matching reference Cosmos3AttnProcessor: + - Text (understanding): CAUSAL self-attention over text tokens only + - Generation: BIDIRECTIONAL attention over ALL tokens (text + gen) + + Args: + text_hidden: [B, T_text, hidden_size] - normed text hidden states + gen_hidden: [B, T_gen, hidden_size] - normed gen hidden states + rotary_emb: (cos_und, sin_und, cos_gen, sin_gen) + cos/sin_und: [T_text, head_dim] + cos/sin_gen: [T_gen, head_dim] + + Returns: + text_attn_out: [B, T_text, hidden_size] + gen_attn_out: [B, T_gen, hidden_size] + """ + batch_size = text_hidden.shape[0] + text_len = text_hidden.shape[1] + gen_len = gen_hidden.shape[1] + + cos_und, sin_und, cos_gen, sin_gen = rotary_emb + + # --- Project text stream --- + text_q = self.to_q(text_hidden) # [B, T_text, padded_q_dim/tp] + text_k = self.to_k(text_hidden) # [B, T_text, padded_kv_dim/tp] + text_v = self.to_v(text_hidden) # [B, T_text, padded_kv_dim/tp] + + # --- Project gen stream --- + gen_q = self.add_q_proj(gen_hidden) + gen_k = self.add_k_proj(gen_hidden) + gen_v = self.add_v_proj(gen_hidden) + + # --- Reshape to [B, heads, S, head_dim] --- + text_q = text_q.view( + batch_size, text_len, self.heads_per_rank, self.head_dim + ).transpose(1, 2) + text_k = text_k.view( + batch_size, text_len, self.kv_heads_per_rank, self.head_dim + ).transpose(1, 2) + text_v = text_v.view( + batch_size, text_len, self.kv_heads_per_rank, self.head_dim + ).transpose(1, 2) + + gen_q = gen_q.view( 
+ batch_size, gen_len, self.heads_per_rank, self.head_dim + ).transpose(1, 2) + gen_k = gen_k.view( + batch_size, gen_len, self.kv_heads_per_rank, self.head_dim + ).transpose(1, 2) + gen_v = gen_v.view( + batch_size, gen_len, self.kv_heads_per_rank, self.head_dim + ).transpose(1, 2) + + # --- QK Normalization --- + text_q = self.norm_q(text_q) + text_k = self.norm_k(text_k) + gen_q = self.norm_added_q(gen_q) + gen_k = self.norm_added_k(gen_k) + + # --- Apply M-RoPE (separate per pathway) --- + text_q = apply_rotary_emb_cosmos3(text_q, cos_und, sin_und) + text_k = apply_rotary_emb_cosmos3(text_k, cos_und, sin_und) + gen_q = apply_rotary_emb_cosmos3(gen_q, cos_gen, sin_gen) + gen_k = apply_rotary_emb_cosmos3(gen_k, cos_gen, sin_gen) + + # --- GQA: expand KV heads to match Q heads --- + text_k = self._repeat_kv(text_k) + text_v = self._repeat_kv(text_v) + gen_k = self._repeat_kv(gen_k) + gen_v = self._repeat_kv(gen_v) + + # --- Text pathway: CAUSAL self-attention (text only) --- + if _HARDWARE == hardware.TRN1: + text_attn_output = F.scaled_dot_product_attention( + text_q, text_k, text_v, dropout_p=0.0, is_causal=True + ) + else: + text_attn_output = attention_wrapper_causal(text_q, text_k, text_v) + + # --- Generation pathway: BIDIRECTIONAL attention to ALL tokens --- + all_k = torch.cat([text_k, gen_k], dim=2) # [B, heads, T_text+T_gen, D] + all_v = torch.cat([text_v, gen_v], dim=2) + + if _HARDWARE == hardware.TRN1: + gen_attn_output = F.scaled_dot_product_attention( + gen_q, all_k, all_v, dropout_p=0.0, is_causal=False + ) + else: + gen_attn_output = attention_wrapper_bidirectional(gen_q, all_k, all_v) + + # --- Reshape outputs --- + text_attn_output = text_attn_output.transpose(1, 2).reshape( + batch_size, text_len, self.heads_per_rank * self.head_dim + ) + gen_attn_output = gen_attn_output.transpose(1, 2).reshape( + batch_size, gen_len, self.heads_per_rank * self.head_dim + ) + + # --- Separate output projections --- + text_out = self.to_out(text_attn_output) + 
gen_out = self.to_add_out(gen_attn_output) + + return text_out, gen_out + + +# ============================================================================= +# SwiGLU MLP (used for both text and gen MLPs) +# ============================================================================= + + +class NeuronCosmos3SwiGLU(nn.Module): + """ + SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x)). + Same as Llama/Qwen MLP structure. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + reduce_dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + self.gate_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=False, + gather_output=False, + reduce_dtype=reduce_dtype, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + reduce_dtype=reduce_dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +# ============================================================================= +# NeuronCosmos3MoTBlock (single MoT layer) +# ============================================================================= + + +class NeuronCosmos3MoTBlock(nn.Module): + """ + Single Cosmos3 MoT layer with: + - Separate pre-attention RMSNorms per stream + - Joint attention (MMDiT-style) + - Separate post-attention RMSNorms per stream + - Separate SwiGLU MLPs per stream (text: mlp, gen: mlp_moe_gen) + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rms_norm_eps: float = 1e-6, + reduce_dtype: torch.dtype = torch.bfloat16, + ): + super().__init__() + + # --- Pre-attention norms (separate per stream) --- + self.input_layernorm = CustomRMSNorm(hidden_size, 
eps=rms_norm_eps) + self.input_layernorm_moe_gen = CustomRMSNorm(hidden_size, eps=rms_norm_eps) + + # --- Joint attention --- + self.self_attn = NeuronCosmos3Attention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + reduce_dtype=reduce_dtype, + ) + + # --- Post-attention norms (separate per stream) --- + self.post_attention_layernorm = CustomRMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm_moe_gen = CustomRMSNorm( + hidden_size, eps=rms_norm_eps + ) + + # --- Separate MLPs --- + self.mlp = NeuronCosmos3SwiGLU(hidden_size, intermediate_size, reduce_dtype) + self.mlp_moe_gen = NeuronCosmos3SwiGLU( + hidden_size, intermediate_size, reduce_dtype + ) + + def forward( + self, + hidden_states: torch.Tensor, + text_len: int, + rotary_emb: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + """ + Args: + hidden_states: [B, T_text + T_gen, hidden_size] - packed sequence + text_len: number of text tokens (split point) + rotary_emb: (cos_und, sin_und, cos_gen, sin_gen) + cos/sin_und: [T_text, head_dim] + cos/sin_gen: [T_gen, head_dim] + + Returns: + [B, T_text + T_gen, hidden_size] - updated packed sequence + """ + # --- Split streams --- + text_hidden = hidden_states[:, :text_len, :] + gen_hidden = hidden_states[:, text_len:, :] + + # --- Pre-norm --- + text_normed = self.input_layernorm(text_hidden) + gen_normed = self.input_layernorm_moe_gen(gen_hidden) + + # --- Joint attention --- + text_attn_out, gen_attn_out = self.self_attn( + text_normed, gen_normed, rotary_emb + ) + + # --- Residual --- + text_hidden = text_hidden + text_attn_out + gen_hidden = gen_hidden + gen_attn_out + + # --- Post-norm + separate MLPs --- + text_hidden = text_hidden + self.mlp(self.post_attention_layernorm(text_hidden)) + gen_hidden = gen_hidden + self.mlp_moe_gen( + self.post_attention_layernorm_moe_gen(gen_hidden) + ) + + # --- Re-pack --- + return 
torch.cat([text_hidden, gen_hidden], dim=1) + + +# ============================================================================= +# NeuronCosmos3Transformer (full backbone) +# ============================================================================= + + +class NeuronCosmos3Transformer(nn.Module): + """ + Full Cosmos3 MoT transformer backbone for generation (Nano and Super). + + Architecture: + - embed_tokens: shared text embedding + - vae2llm (proj_in): project VAE latent patches to hidden_size + - time_embedder: sinusoidal timestep -> MLP -> hidden_size (additive) + - num_hidden_layers MoT layers (NeuronCosmos3MoTBlock; 36 for Nano, 64 for Super) + - norm_moe_gen: final RMSNorm on the generation stream + - llm2vae (proj_out): project hidden_size -> patch channels (velocity output) + + Forward signature: + text_ids, vision_patches, timestep, position_ids -> velocity_prediction + """ + + def __init__(self, config: "Cosmos3BackboneInferenceConfig"): + super().__init__() + self.config = config + + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + num_layers = config.num_hidden_layers + num_attention_heads = config.num_attention_heads + num_key_value_heads = config.num_key_value_heads + head_dim = config.head_dim + rms_norm_eps = config.rms_norm_eps + vocab_size = config.vocab_size + patch_channels = config.patch_channels # 192 (48 latent channels * 2x2 spatial patch) + + reduce_dtype = config.neuron_config.torch_dtype + + self.data_parallel_group = get_data_parallel_group() + self.global_rank = SPMDRank(world_size=get_world_group().size()) + self.cfg_parallel_enabled = getattr(config, "cfg_parallel_enabled", False) + + # --- Token embedding (text) --- + self.embed_tokens = nn.Embedding(vocab_size, hidden_size) + + # --- Patch projection (VAE latent -> hidden) --- + # proj_in in the HF model: [patch_channels, hidden_size] with bias + self.proj_in = ColumnParallelLinear( + patch_channels, + hidden_size, + bias=True, + gather_output=True, + reduce_dtype=reduce_dtype, + ) + + # --- Timestep embedding (sinusoidal + MLP, additive) 
--- + # Cosmos3 time_embedder: Timesteps(256) -> MLP(256 -> hidden_size) + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.time_embedder = NeuronTimestepEmbedding( + in_channels=256, + time_embed_dim=hidden_size, + reduce_dtype=reduce_dtype, + ) + + # --- MoT Transformer layers --- + self.layers = nn.ModuleList( + [ + NeuronCosmos3MoTBlock( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + rms_norm_eps=rms_norm_eps, + reduce_dtype=reduce_dtype, + ) + for _ in range(num_layers) + ] + ) + + # --- Output projection --- + # norm_moe_gen: final RMSNorm for generation stream output + # (norm is for text stream / lm_head, not used here) + self.norm_moe_gen = CustomRMSNorm(hidden_size, eps=rms_norm_eps) + # proj_out: [hidden_size, patch_channels] with bias + # Use ColumnParallelLinear with gather to shard across TP + # Weight shape: [patch_channels, hidden_size] = [192, 4096] + # Each rank: [48, 4096] input, gathered to [192] output + self.proj_out = ColumnParallelLinear( + hidden_size, + patch_channels, + bias=True, + gather_output=True, + reduce_dtype=reduce_dtype, + ) + + def forward( + self, + text_ids: torch.Tensor, + vision_patches: torch.Tensor, + timestep: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + text_ids: [B, T_text] - tokenized text (int64) + vision_patches: [B, T_gen, patch_channels] - patchified noisy latents + timestep: [B] - diffusion timestep (float, 0 to 1) + position_ids: [T_text + T_gen, 3] - M-RoPE positions (t, h, w) + + Returns: + velocity: [B, T_gen, patch_channels] - predicted velocity + """ + batch_size = text_ids.shape[0] + text_len = text_ids.shape[1] + gen_len = vision_patches.shape[1] + + # --- CFG Parallel: scatter inputs --- + if self.cfg_parallel_enabled and batch_size == 2: + dp_rank = get_dp_rank_spmd( + 
global_rank=self.global_rank.get_rank(), + tp_degree=get_tensor_model_parallel_size(), + ) + text_ids = scatter_to_process_group_spmd( + text_ids, + partition_dim=0, + rank=dp_rank, + process_group=self.data_parallel_group, + ) + vision_patches = scatter_to_process_group_spmd( + vision_patches, + partition_dim=0, + rank=dp_rank, + process_group=self.data_parallel_group, + ) + timestep = scatter_to_process_group_spmd( + timestep, + partition_dim=0, + rank=dp_rank, + process_group=self.data_parallel_group, + ) + batch_size = 1 + + # --- 1. Embed text --- + text_embeds = self.embed_tokens(text_ids) # [B, T_text, hidden_size] + + # --- 2. Project vision patches --- + vision_embeds = self.proj_in(vision_patches) # [B, T_gen, hidden_size] + + # --- 3. Additive timestep conditioning on vision tokens --- + timestep_proj = self.time_proj(timestep) # [B, 256] + t_emb = self.time_embedder( + timestep_proj.to(vision_embeds.dtype) + ) # [B, hidden_size] + vision_embeds = vision_embeds + t_emb.unsqueeze(1) # broadcast add + + # --- 4. Pack into single sequence --- + hidden_states = torch.cat( + [text_embeds, vision_embeds], dim=1 + ) # [B, T_text+T_gen, H] + + # --- 5. 
Compute M-RoPE and split into text/gen portions --- + mrope = Cosmos3MRoPE( + head_dim=self.config.head_dim, + mrope_section=self.config.mrope_section, + rope_theta=self.config.rope_theta, + ) + # Returns (cos, sin) each [T_text + T_gen, head_dim] + cos_full, sin_full = mrope(position_ids) + cos_full = cos_full.to( + dtype=self.config.neuron_config.torch_dtype, device=hidden_states.device + ) + sin_full = sin_full.to( + dtype=self.config.neuron_config.torch_dtype, device=hidden_states.device + ) + # Split into text (understanding) and gen portions + cos_und = cos_full[:text_len] # [T_text, head_dim] + sin_und = sin_full[:text_len] # [T_text, head_dim] + cos_gen = cos_full[text_len:] # [T_gen, head_dim] + sin_gen = sin_full[text_len:] # [T_gen, head_dim] + rotary_emb = (cos_und, sin_und, cos_gen, sin_gen) + + # --- 6. Run through MoT layers --- + hidden_states, _ = ModuleMarkerStartWrapper()(hidden_states, hidden_states) + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, text_len, rotary_emb) + # Layer boundary markers for compiler optimization (every 2 layers) + if i % 2 == 1 and i < len(self.layers) - 1: + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + + # --- 7. 
Extract vision portion and project to velocity --- + vision_output = hidden_states[:, text_len:, :] # [B, T_gen, hidden_size] + vision_output = self.norm_moe_gen(vision_output) + velocity = self.proj_out(vision_output) # [B, T_gen, patch_channels] + + # --- CFG Parallel: gather outputs --- + if self.cfg_parallel_enabled: + velocity = gather_from_tensor_model_parallel_region_with_dim( + velocity, + gather_dim=0, + process_group=self.data_parallel_group, + ) + + return velocity + + +# ============================================================================= +# Config +# ============================================================================= + + +class Cosmos3BackboneInferenceConfig(InferenceConfig): + """Config for the Cosmos3 generation backbone.""" + + def __init__(self, *args, cfg_parallel_enabled: bool = False, **kwargs): + # Set Cosmos3-Nano defaults BEFORE super().__init__ (which calls validate_config) + self.hidden_size = kwargs.pop("hidden_size", 4096) + self.intermediate_size = kwargs.pop("intermediate_size", 12288) + self.num_hidden_layers = kwargs.pop("num_hidden_layers", 36) + self.num_attention_heads = kwargs.pop("num_attention_heads", 32) + self.num_key_value_heads = kwargs.pop("num_key_value_heads", 8) + self.head_dim = kwargs.pop("head_dim", 128) + self.rms_norm_eps = kwargs.pop("rms_norm_eps", 1e-6) + self.vocab_size = kwargs.pop("vocab_size", 151936) + self.patch_channels = kwargs.pop( + "patch_channels", 192 + ) # 48 latent_ch * 2*2 patch + self.latent_channels = kwargs.pop("latent_channels", 48) + self.rope_theta = kwargs.pop("rope_theta", 5000000.0) + self.mrope_section = kwargs.pop("mrope_section", [24, 20, 20]) + self.cfg_parallel_enabled = cfg_parallel_enabled + + super().__init__(*args, **kwargs) + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", + "intermediate_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "head_dim", + "rms_norm_eps", + "vocab_size", + 
"patch_channels", + "rope_theta", + "mrope_section", + ] + + +# ============================================================================= +# Model Wrapper +# ============================================================================= + + +class ModelWrapperCosmos3Backbone(ModelWrapper): + """Wrapper for Cosmos3 backbone: handles input generation and forward dispatch.""" + + def __init__( + self, + config: InferenceConfig, + model_cls, + tag="", + compiler_args=None, + priority_model_idx=None, + model_init_kwargs=None, + ): + super().__init__( + config, + model_cls, + tag, + compiler_args, + priority_model_idx, + model_init_kwargs or {}, + ) + # For large models (Super, 64 layers), the NxDI framework appends + # --verify-hlo=true which fails before partitioning for models > 24 GB/rank. + # Replace verify-hlo=true with verify-hlo=false to skip the pre-partition check. + if config.num_hidden_layers > 36: + self.compiler_args = self.compiler_args.replace( + "--verify-hlo=true", "--verify-hlo=false" + ) + logger.info( + f"Large model: disabled verify-hlo (compiler_args: {self.compiler_args})" + ) + + self.mrope = Cosmos3MRoPE( + head_dim=config.head_dim, + mrope_section=config.mrope_section, + rope_theta=config.rope_theta, + ) + + def input_generator(self) -> List[Tuple[torch.Tensor, ...]]: + """Generate example inputs for compilation.""" + dtype = self.config.neuron_config.torch_dtype + text_len = self.config.max_text_len + gen_len = self.config.num_vision_patches + patch_channels = self.config.patch_channels + + batch_size = 2 if self.config.cfg_parallel_enabled else 1 + + model_inputs = ( + # text_ids: [B, T_text] + torch.zeros([batch_size, text_len], dtype=torch.long), + # vision_patches: [B, T_gen, patch_channels] + torch.randn([batch_size, gen_len, patch_channels], dtype=dtype), + # timestep: [B] + torch.randn([batch_size], dtype=dtype), + # position_ids: [T_text + T_gen, 3] + torch.zeros([text_len + gen_len, 3], dtype=torch.long), + ) + return 
[model_inputs] + + def get_model_instance(self): + def _create_model(): + model = self.model_cls(self.config) + model = model.to(dtype=self.config.neuron_config.torch_dtype) + model.eval() + return model + + model_instance = BaseModelInstance( + module_cls=_create_model, input_output_aliases={} + ) + return model_instance + + def forward(self, text_ids, vision_patches, timestep, position_ids): + """Override ModelWrapper.forward().""" + if self.model is None: + raise RuntimeError("Forward called before load. Run load() first.") + + timestep = timestep.to(self.config.neuron_config.torch_dtype) + + output = self._forward(text_ids, vision_patches, timestep, position_ids) + return output + + +# ============================================================================= +# Application (compile/load infrastructure) +# ============================================================================= + + +class NeuronCosmos3BackboneApplication(NeuronApplicationBase): + """ + Application class for the Cosmos3 MoT backbone. + Handles compilation, weight loading, and forward dispatch. + """ + + _model_cls = NeuronCosmos3Transformer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_wrapper = ModelWrapperCosmos3Backbone + + self.model = self.model_wrapper( + config=self.config, + model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + self.dtype = self.config.neuron_config.torch_dtype + + def forward(self, *model_inputs, **kwargs): + return self.models[0](*model_inputs, **kwargs) + + def get_compiler_args(self): + compiler_args = "--model-type=transformer -O1" + compiler_args += " --tensorizer-options='--enable-ccop-compute-overlap'" + compiler_args += " --auto-cast=none" + # For large models (Super, 64 layers): force low MAC threshold so modular flow + # always partitions the graph. 
Without this, the compiler may fail HBM verification + # before partitioning kicks in. + if self.config.num_hidden_layers > 36: + compiler_args += ( + " --internal-hlo2tensorizer-options=" + "'--modular-flow-mac-threshold=10 --recursive-layer-det=false'" + ) + + os.environ["LOCAL_WORLD_SIZE"] = str(self.config.neuron_config.world_size) + if _HARDWARE == hardware.TRN2: + os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2" + return compiler_args + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: + """ + Convert HF Diffusers Cosmos3 state_dict to NxDI format. + + HF Diffusers keys (from transformer/ directory): + layers.N.self_attn.to_q.weight + layers.N.self_attn.to_k.weight + layers.N.self_attn.to_v.weight + layers.N.self_attn.to_out.weight + layers.N.self_attn.norm_q.weight + layers.N.self_attn.norm_k.weight + layers.N.self_attn.add_q_proj.weight + layers.N.self_attn.add_k_proj.weight + layers.N.self_attn.add_v_proj.weight + layers.N.self_attn.to_add_out.weight + layers.N.self_attn.norm_added_q.weight + layers.N.self_attn.norm_added_k.weight + layers.N.input_layernorm.weight + layers.N.input_layernorm_moe_gen.weight + layers.N.post_attention_layernorm.weight + layers.N.post_attention_layernorm_moe_gen.weight + layers.N.mlp.gate_proj.weight + layers.N.mlp.up_proj.weight + layers.N.mlp.down_proj.weight + layers.N.mlp_moe_gen.gate_proj.weight + layers.N.mlp_moe_gen.up_proj.weight + layers.N.mlp_moe_gen.down_proj.weight + embed_tokens.weight + proj_in.weight / proj_in.bias + proj_out.weight / proj_out.bias + time_embedder.linear_1.weight / bias + time_embedder.linear_2.weight / bias + + NxDI keys (this model): + Same structure -- we keep the HF naming since our module names match. + Only need to: + 1. Pass keys through unchanged (the time_embedder.linear_1/2 mapping is an identity) and skip non-generation keys + 2. Add global_rank tensor + 3. 
Ensure contiguous tensors + """ + new_state_dict = {} + + # Key mapping from HF Diffusers to NxDI module names + # Most keys map directly since our module structure mirrors HF + key_mapping = { + # Timestep embedder + "time_embedder.linear_1.weight": "time_embedder.linear_1.weight", + "time_embedder.linear_1.bias": "time_embedder.linear_1.bias", + "time_embedder.linear_2.weight": "time_embedder.linear_2.weight", + "time_embedder.linear_2.bias": "time_embedder.linear_2.bias", + } + + # Keys to skip (not used in generation backbone) + skip_prefixes = [ + "lm_head.", # text generation head (reasoning only) + "audio_", # audio modality + "action_", # action modality + "sound_", # sound tokenizer + "visual.", # vision encoder (separate NEFF if needed) + ] + + for key, value in state_dict.items(): + # Skip non-generation keys + if any(key.startswith(prefix) for prefix in skip_prefixes): + continue + + # Apply key mapping if exists + new_key = key_mapping.get(key, key) + new_state_dict[new_key] = value.clone().detach().contiguous() + + # Add global rank tensor (required by NxDI parallel layers) + new_state_dict["global_rank.rank"] = torch.arange( + 0, config.neuron_config.world_size, dtype=torch.int32 + ) + + return new_state_dict diff --git a/contrib/models/Cosmos3-Text2Image/src/pipeline.py b/contrib/models/Cosmos3-Text2Image/src/pipeline.py new file mode 100644 index 00000000..054fa972 --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/src/pipeline.py @@ -0,0 +1,351 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Cosmos3 diffusion pipeline utilities: patchify/unpatchify, position IDs, +# tokenization helpers, and the denoising loop. 
+ +import json +import logging +import time +from typing import Optional, Tuple + +import torch + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Patchify / Unpatchify +# ============================================================================= + + +def patchify(latents: torch.Tensor, patch_size: int = 2) -> torch.Tensor: + """ + Convert VAE latents to patch tokens for the transformer. + + Cosmos3 uses patch_spatial=2: each 2x2 spatial block of the latent is + flattened into a single token with spatial-first, channels-last ordering. + + Matches the reference einsum: "cthpwq->thwpqc" + + Args: + latents: [B, C, T, H, W] - VAE latent space (C=48, T=1 for images) + patch_size: spatial patch size (default 2) + + Returns: + patches: [B, N_patches, patch_channels] + N_patches = T * (H // patch_size) * (W // patch_size) + patch_channels = patch_size * patch_size * C + """ + B, C, T, H, W = latents.shape + assert H % patch_size == 0 and W % patch_size == 0, ( + f"Latent spatial dims ({H}, {W}) must be divisible by patch_size={patch_size}" + ) + + pH = H // patch_size + pW = W // patch_size + + # Reshape: [B, C, T, H, W] -> [B, C, T, pH, ps, pW, ps] + latents = latents.reshape(B, C, T, pH, patch_size, pW, patch_size) + # Permute to match reference "cthpwq->thwpqc": [B, T, pH, pW, ps_h, ps_w, C] + latents = latents.permute(0, 2, 3, 5, 4, 6, 1) + # Flatten patches: [B, T*pH*pW, ps*ps*C] + patches = latents.reshape(B, T * pH * pW, C * patch_size * patch_size) + + return patches + + +def unpatchify( + patches: torch.Tensor, + T: int, + H: int, + W: int, + channels: int = 48, + patch_size: int = 2, +) -> torch.Tensor: + """ + Convert patch tokens back to VAE latent space. 
+ + Matches the reference einsum: "thwpqc->cthpwq" + + Args: + patches: [B, N_patches, patch_channels] - velocity prediction + T: temporal dimension of latent + H: height of latent (before patching) + W: width of latent (before patching) + channels: number of latent channels (48) + patch_size: spatial patch size (2) + + Returns: + latents: [B, C, T, H, W] + """ + B = patches.shape[0] + pH = H // patch_size + pW = W // patch_size + + # Reshape: [B, T*pH*pW, ps*ps*C] -> [B, T, pH, pW, ps_h, ps_w, C] + patches = patches.reshape(B, T, pH, pW, patch_size, patch_size, channels) + # Permute back (inverse of "cthpwq->thwpqc"): [B, C, T, pH, ps_h, pW, ps_w] + latents = patches.permute(0, 6, 1, 2, 4, 3, 5) + # Flatten spatial: [B, C, T, H, W] + latents = latents.reshape(B, channels, T, H, W) + + return latents + + +# ============================================================================= +# Position IDs for M-RoPE +# ============================================================================= + + +def build_position_ids( + max_text_len: int, + actual_text_len: int, + T: int, + pH: int, + pW: int, + temporal_margin: int = 15000, +) -> torch.Tensor: + """ + Build M-RoPE 3D position IDs for text + vision tokens. + + Text tokens: all 3 axes (T, H, W) share the same incrementing IDs [0..max_text_len-1]. + Vision tokens: temporal = actual_text_len + temporal_margin, H=[0..pH-1], W=[0..pW-1]. + + The temporal_margin (15000) separates text and vision position spaces, matching + the reference `unified_3d_mrope_temporal_modality_margin` parameter. 
+ + Args: + max_text_len: padded text sequence length (model input dimension) + actual_text_len: actual number of meaningful text tokens (for temporal offset) + T: temporal patches (1 for images) + pH: spatial height patches + pW: spatial width patches + temporal_margin: gap between text and vision temporal positions (default 15000) + + Returns: + position_ids: [max_text_len + T*pH*pW, 3] - (t, h, w) per token + """ + vision_t_offset = actual_text_len + temporal_margin + + # Text: all 3 axes share same incrementing IDs + text_ids = torch.arange(max_text_len, dtype=torch.long) + text_pos = text_ids.unsqueeze(1).expand(-1, 3) # [max_text_len, 3] + + # Vision: 3D grid via meshgrid + t_coords = torch.arange(T, dtype=torch.long) + vision_t_offset + h_coords = torch.arange(pH, dtype=torch.long) + w_coords = torch.arange(pW, dtype=torch.long) + grid_t, grid_h, grid_w = torch.meshgrid(t_coords, h_coords, w_coords, indexing="ij") + vis_pos = torch.stack([grid_t.flatten(), grid_h.flatten(), grid_w.flatten()], dim=1) + + return torch.cat([text_pos, vis_pos], dim=0) + + +def generate_position_ids(text_len: int, T: int, pH: int, pW: int) -> torch.Tensor: + """Simple position ID generation (text_len used as both max and actual).""" + return build_position_ids(text_len, text_len, T, pH, pW) + + +# ============================================================================= +# Tokenization (matching reference Cosmos3OmniPipeline.tokenize_prompt) +# ============================================================================= + +SYSTEM_PROMPT = ( + "You are a helpful assistant who will generate images from a give prompt." +) +RESOLUTION_TEMPLATE = "This image is of {height}x{width} resolution." +NEGATIVE_RESOLUTION_TEMPLATE = "This image is not of {height}x{width} resolution." 
+ + +def tokenize_prompt( + tokenizer, + prompt: str, + height: int = 512, + width: int = 512, + max_len: int = 256, + negative: bool = False, +) -> Tuple[torch.Tensor, int]: + """ + Tokenize a prompt following the Cosmos3 reference pipeline format. + + The reference pipeline uses: + 1. System prompt: "You are a helpful assistant who will generate images..." + 2. User content: prompt + resolution template (or inverse for negative) + 3. apply_chat_template with add_generation_prompt=True + 4. Append eos_token_id + <|vision_start|> token + + Args: + tokenizer: Qwen2 tokenizer (from the model) + prompt: text prompt (or empty for negative) + height: image height in pixels + width: image width in pixels + max_len: maximum token length (padded) + negative: if True, use inverse resolution template + + Returns: + (padded_ids, actual_len): padded token tensor [1, max_len] and actual length + """ + if negative: + user_text = NEGATIVE_RESOLUTION_TEMPLATE.format(height=height, width=width) + else: + user_text = ( + prompt.rstrip(".") + + ". 
" + + RESOLUTION_TEMPLATE.format(height=height, width=width) + ) + + conversations = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_text}, + ] + result = tokenizer.apply_chat_template( + conversations, + tokenize=True, + add_generation_prompt=True, + add_vision_id=False, + return_dict=True, + ) + + eos_id = tokenizer.eos_token_id + vision_start_id = tokenizer.convert_tokens_to_ids("<|vision_start|>") + input_ids = list(result["input_ids"]) + [eos_id, vision_start_id] + actual_len = len(input_ids) + + # Pad + pad_id = tokenizer.pad_token_id or 0 + if actual_len > max_len: + input_ids = input_ids[:max_len] + actual_len = max_len + else: + input_ids = input_ids + [pad_id] * (max_len - actual_len) + + return torch.tensor([input_ids], dtype=torch.long), actual_len + + +# ============================================================================= +# Denoising Loop +# ============================================================================= + + +def denoise( + backbone, + cond_ids: torch.Tensor, + uncond_ids: torch.Tensor, + cond_pos: torch.Tensor, + uncond_pos: torch.Tensor, + scheduler, + latents: torch.Tensor, + num_steps: int = 35, + cfg_scale: float = 6.0, + latent_channels: int = 48, + patch_size: int = 2, + cfg_parallel: bool = False, +) -> torch.Tensor: + """ + Run the denoising loop with CFG. 
+ + This is the optimized inner loop that: + - Pre-computes timestep tensors + - Uses return_dict=False for scheduler + - Keeps CFG math in bf16 + - Optionally uses CFG-parallel (batch=2 single call) + + Args: + backbone: compiled NeuronCosmos3BackboneApplication + cond_ids: [1, max_text_len] - conditional token IDs + uncond_ids: [1, max_text_len] - unconditional token IDs + cond_pos: [total_seq, 3] - conditional position IDs + uncond_pos: [total_seq, 3] - unconditional position IDs (sequential mode only) + scheduler: UniPCMultistepScheduler (already configured) + latents: [1, C, T, H, W] - initial noisy latents (float32) + num_steps: denoising steps + cfg_scale: classifier-free guidance scale + latent_channels: number of VAE latent channels (48) + patch_size: spatial patch size (2) + cfg_parallel: if True, pack cond+uncond into batch=2 single call; position_ids has no batch dim, so cond_pos is used for both halves + + Returns: + latents: [1, C, T, H, W] - denoised latents (float32) + """ + _, _, T, H_latent, W_latent = latents.shape + + scheduler.set_timesteps(num_steps) + timesteps = scheduler.timesteps + latents = latents * scheduler.init_noise_sigma + + # Pre-compute timestep tensors + if cfg_parallel: + ts_tensors = [ + torch.tensor([t.item() * 0.001, t.item() * 0.001], dtype=torch.bfloat16) + for t in timesteps + ] + # Pack text IDs into batch=2 + text_ids_batch = torch.cat([cond_ids, uncond_ids], dim=0) # [2, max_text_len] + else: + ts_tensors = [ + torch.tensor([t.item() * 0.001], dtype=torch.bfloat16) for t in timesteps + ] + + start = time.time() + for i, t_val in enumerate(timesteps): + vis_patches = patchify(latents.to(torch.bfloat16), patch_size=patch_size) + + if cfg_parallel: + # Single call with batch=2: [cond_patches, uncond_patches] + vis_batch = vis_patches.expand(2, -1, -1).contiguous() + output = backbone(text_ids_batch, vis_batch, ts_tensors[i], cond_pos) + v_cond = output[0:1] + v_uncond = output[1:2] + else: + v_cond = backbone(cond_ids, vis_patches, ts_tensors[i], cond_pos) + v_uncond = backbone(uncond_ids, vis_patches, 
ts_tensors[i], uncond_pos) + + velocity = v_uncond + cfg_scale * (v_cond - v_uncond) + vel_latent = unpatchify( + velocity.float(), + T, + H_latent, + W_latent, + channels=latent_channels, + patch_size=patch_size, + ) + + latents = scheduler.step(vel_latent, t_val, latents, return_dict=False)[0] + + if i < 2 or i == num_steps - 1 or (i + 1) % 10 == 0: + logger.info( + f" Step {i + 1}/{num_steps}: v_norm={vel_latent.norm():.1f}, " + f"lat_norm={latents.norm():.1f}" + ) + + elapsed = time.time() - start + mode_str = "CFG-parallel" if cfg_parallel else "sequential" + logger.info( + f" Denoising ({mode_str}): {elapsed:.2f}s ({elapsed / num_steps * 1000:.1f}ms/step)" + ) + + return latents + + +def denormalize_latents(latents: torch.Tensor, vae_config_path: str) -> torch.Tensor: + """ + Denormalize latents from scheduler space to VAE space. + + The denoising operates in a normalized space (zero mean, unit variance). + The VAE expects native latent space. Transform: latents = latents * std + mean. 
+ + Args: + latents: [1, 48, T, H, W] - denoised latents + vae_config_path: path to vae/config.json + + Returns: + denormalized latents ready for VAE decoding + """ + with open(vae_config_path) as f: + vae_cfg = json.load(f) + + lat_mean = torch.tensor(vae_cfg["latents_mean"]).view(1, 48, 1, 1, 1).float() + lat_std = torch.tensor(vae_cfg["latents_std"]).view(1, 48, 1, 1, 1).float() + + return latents.float() * lat_std + lat_mean diff --git a/contrib/models/Cosmos3-Text2Image/test/__init__.py b/contrib/models/Cosmos3-Text2Image/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Cosmos3-Text2Image/test/integration/__init__.py b/contrib/models/Cosmos3-Text2Image/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Cosmos3-Text2Image/test/integration/test_model.py b/contrib/models/Cosmos3-Text2Image/test/integration/test_model.py new file mode 100644 index 00000000..ad13c4e6 --- /dev/null +++ b/contrib/models/Cosmos3-Text2Image/test/integration/test_model.py @@ -0,0 +1,438 @@ +#!/usr/bin/env python3 +""" +Integration tests for Cosmos3-Text2Image NeuronX implementation. + +Tests model compilation, loading, and image generation quality. + +Unlike LLM contribs that test token accuracy, this tests: +1. Model loads and produces output (smoke test) +2. Output has correct shape and valid pixel range +3. Generated image has expected statistical properties (not noise/blank) +4. Performance meets latency targets + +Usage: + # Run with pytest: + pytest test/integration/test_model.py --capture=tee-sys -v + + # Run manually: + python test/integration/test_model.py + +Configuration: + Set MODEL_PATH, COMPILED_PATH, and VAE_PATH below to match your setup. 
+""" + +import os +import sys +import time + +import pytest +import torch + +# Add src to path +sys.path.insert(0, str(os.path.join(os.path.dirname(__file__), "..", "..", "src"))) + +from modeling_cosmos3 import ( + Cosmos3BackboneInferenceConfig, + NeuronCosmos3BackboneApplication, +) +from pipeline import ( + build_position_ids, + denoise, + denormalize_latents, + patchify, + tokenize_prompt, + unpatchify, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +# ============================================================================= +# Test Configuration - Update these paths for your environment +# ============================================================================= + +# Nano configuration (trn2.3xlarge, TP=4) +MODEL_PATH = os.environ.get("COSMOS3_MODEL_PATH", "/home/ubuntu/Cosmos3-Nano") +COMPILED_PATH = os.environ.get("COSMOS3_COMPILED_PATH", "/home/ubuntu/compiled_cosmos3") +VAE_PATH = os.environ.get( + "COSMOS3_VAE_PATH", "/home/ubuntu/compiled_vae/vae_decoder.pt" +) +TP_DEGREE = int(os.environ.get("COSMOS3_TP_DEGREE", "4")) + +# Model params (default to Cosmos3-Nano; override via COSMOS3_* environment variables) +HIDDEN_SIZE = int(os.environ.get("COSMOS3_HIDDEN_SIZE", "4096")) +INTERMEDIATE_SIZE = int(os.environ.get("COSMOS3_INTERMEDIATE_SIZE", "12288")) +NUM_LAYERS = int(os.environ.get("COSMOS3_NUM_LAYERS", "36")) +NUM_HEADS = int(os.environ.get("COSMOS3_NUM_HEADS", "32")) +NUM_KV_HEADS = int(os.environ.get("COSMOS3_NUM_KV_HEADS", "8")) + +MAX_TEXT = 256 +NUM_VIS = 256 # 16x16 patches for 512x512 + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture(scope="module") +def backbone(): + """Load compiled backbone model.""" + import torch_neuronx + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, world_size=TP_DEGREE, torch_dtype=torch.bfloat16 + ) + config = Cosmos3BackboneInferenceConfig( + 
neuron_config=neuron_config, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_KV_HEADS, + head_dim=128, + vocab_size=151936, + patch_channels=192, + latent_channels=48, + rope_theta=5000000.0, + mrope_section=[24, 20, 20], + ) + config.max_text_len = MAX_TEXT + config.num_vision_patches = NUM_VIS + + transformer_path = os.path.join(MODEL_PATH, "transformer") + app = NeuronCosmos3BackboneApplication(model_path=transformer_path, config=config) + app.load(COMPILED_PATH) + + # Warmup + dummy = torch.randn(1, NUM_VIS, 192, dtype=torch.bfloat16) + dummy_ts = torch.tensor([0.5], dtype=torch.bfloat16) + pos = torch.zeros(MAX_TEXT + NUM_VIS, 3, dtype=torch.long) + ids = torch.zeros(1, MAX_TEXT, dtype=torch.long) + for _ in range(2): + _ = app.forward(ids, dummy, dummy_ts, pos) + + return app + + +@pytest.fixture(scope="module") +def vae(): + """Load compiled VAE decoder.""" + import torch_neuronx + + return torch.jit.load(VAE_PATH) + + +@pytest.fixture(scope="module") +def tokenizer(): + """Load tokenizer.""" + from transformers import AutoTokenizer + + return AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + +# ============================================================================= +# Tests +# ============================================================================= + + +def test_backbone_loads(backbone): + """Smoke test: backbone loads and is callable.""" + assert backbone is not None + print("PASS: Backbone loaded successfully") + + +def test_backbone_output_shape(backbone): + """Test that backbone produces correct output shape.""" + ids = torch.zeros(1, MAX_TEXT, dtype=torch.long) + patches = torch.randn(1, NUM_VIS, 192, dtype=torch.bfloat16) + ts = torch.tensor([0.5], dtype=torch.bfloat16) + pos = torch.zeros(MAX_TEXT + NUM_VIS, 3, dtype=torch.long) + + output = backbone.forward(ids, patches, ts, pos) + + assert output.shape == (1, 
NUM_VIS, 192), ( + f"Expected (1, {NUM_VIS}, 192), got {output.shape}" + ) + assert output.dtype == torch.bfloat16 + print(f"PASS: Output shape {output.shape}, dtype {output.dtype}") + + +def test_backbone_nonzero_output(backbone): + """Test that backbone produces non-trivial output (not all zeros).""" + ids = torch.ones(1, MAX_TEXT, dtype=torch.long) # non-zero token IDs + patches = torch.randn(1, NUM_VIS, 192, dtype=torch.bfloat16) + ts = torch.tensor([0.5], dtype=torch.bfloat16) + pos = torch.zeros(MAX_TEXT + NUM_VIS, 3, dtype=torch.long) + + output = backbone.forward(ids, patches, ts, pos) + + assert output.abs().max() > 0.01, ( + "Output is near-zero (model may not be loading weights)" + ) + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + print(f"PASS: Output norm={output.norm():.2f}, max={output.abs().max():.4f}") + + +def test_patchify_unpatchify_roundtrip(): + """Test patchify/unpatchify are inverse operations.""" + latents = torch.randn(1, 48, 1, 32, 32, dtype=torch.float32) + + patches = patchify(latents, patch_size=2) + assert patches.shape == (1, 256, 192), f"Patch shape: {patches.shape}" + + reconstructed = unpatchify(patches, T=1, H=32, W=32, channels=48, patch_size=2) + assert reconstructed.shape == latents.shape + + # Should be exact roundtrip + assert torch.allclose(latents, reconstructed, atol=1e-6), ( + "Patchify/unpatchify not invertible" + ) + print("PASS: Patchify/unpatchify roundtrip exact") + + +def test_tokenization(tokenizer): + """Test tokenization produces expected format.""" + cond_ids, cond_len = tokenize_prompt( + tokenizer, + "A cat sitting on a windowsill", + height=512, + width=512, + max_len=MAX_TEXT, + ) + uncond_ids, uncond_len = tokenize_prompt( + tokenizer, "", height=512, width=512, max_len=MAX_TEXT, negative=True + ) + + assert cond_ids.shape == (1, MAX_TEXT) + assert uncond_ids.shape == (1, MAX_TEXT) + assert 30 < cond_len < MAX_TEXT, f"Cond len 
{cond_len} unexpected" + assert 20 < uncond_len < MAX_TEXT, f"Uncond len {uncond_len} unexpected" + + # Check special tokens at end + eos_id = tokenizer.eos_token_id + vision_start = tokenizer.convert_tokens_to_ids("<|vision_start|>") + assert cond_ids[0, cond_len - 1].item() == vision_start + assert cond_ids[0, cond_len - 2].item() == eos_id + + print(f"PASS: Tokenization - cond={cond_len} tokens, uncond={uncond_len} tokens") + + +def test_position_ids(): + """Test position ID generation.""" + pos = build_position_ids(max_text_len=256, actual_text_len=50, T=1, pH=16, pW=16) + + assert pos.shape == (256 + 256, 3), f"Position shape: {pos.shape}" + + # Text: all 3 axes incrementing + assert pos[0, 0] == 0 and pos[0, 1] == 0 and pos[0, 2] == 0 + assert pos[100, 0] == 100 and pos[100, 1] == 100 + + # Vision: temporal offset = 50 + 15000 = 15050 + vision_start = 256 + assert pos[vision_start, 0] == 15050, f"Vision temporal: {pos[vision_start, 0]}" + assert pos[vision_start, 1] == 0 # H starts at 0 + assert pos[vision_start, 2] == 0 # W starts at 0 + + print("PASS: Position IDs correct") + + +def test_backbone_latency(backbone): + """Test backbone call latency is within expected range.""" + ids = torch.zeros(1, MAX_TEXT, dtype=torch.long) + patches = torch.randn(1, NUM_VIS, 192, dtype=torch.bfloat16) + ts = torch.tensor([0.5], dtype=torch.bfloat16) + pos = torch.zeros(MAX_TEXT + NUM_VIS, 3, dtype=torch.long) + + # Warmup (already done in fixture, but just in case) + for _ in range(3): + _ = backbone.forward(ids, patches, ts, pos) + + # Measure + times = [] + for _ in range(10): + t0 = time.perf_counter() + _ = backbone.forward(ids, patches, ts, pos) + times.append((time.perf_counter() - t0) * 1000) + + avg_ms = sum(times) / len(times) + # Nano: ~33ms, Super: ~80ms. Allow generous threshold. 
+ threshold = 200.0 # ms + assert avg_ms < threshold, f"Latency {avg_ms:.1f}ms exceeds {threshold}ms threshold" + print(f"PASS: Backbone latency {avg_ms:.1f}ms (threshold: {threshold}ms)") + + +def test_full_generation(backbone, vae, tokenizer): + """End-to-end generation test: produces a valid image.""" + from diffusers import UniPCMultistepScheduler + + # Tokenize + prompt = "A red apple on a white table" + cond_ids, cond_len = tokenize_prompt(tokenizer, prompt, max_len=MAX_TEXT) + uncond_ids, uncond_len = tokenize_prompt( + tokenizer, "", max_len=MAX_TEXT, negative=True + ) + + # Position IDs + cond_pos = build_position_ids(MAX_TEXT, cond_len, T=1, pH=16, pW=16) + uncond_pos = build_position_ids(MAX_TEXT, uncond_len, T=1, pH=16, pW=16) + + # Latents + gen = torch.manual_seed(42) + latents = torch.randn(1, 48, 1, 32, 32, generator=gen, dtype=torch.float32) + + # Scheduler + scheduler = UniPCMultistepScheduler.from_pretrained( + MODEL_PATH, subfolder="scheduler" + ) + + # Denoise (use fewer steps for speed) + num_steps = 10 + latents = denoise( + backbone=backbone, + cond_ids=cond_ids, + uncond_ids=uncond_ids, + cond_pos=cond_pos, + uncond_pos=uncond_pos, + scheduler=scheduler, + latents=latents, + num_steps=num_steps, + cfg_scale=6.0, + ) + + assert latents.shape == (1, 48, 1, 32, 32) + assert not torch.isnan(latents).any(), "Latents contain NaN after denoising" + + # Denormalize + VAE + vae_config_path = os.path.join(MODEL_PATH, "vae", "config.json") + latents = denormalize_latents(latents, vae_config_path) + + with torch.no_grad(): + pixels = vae(latents.float()) + + assert pixels.shape == (1, 3, 1, 512, 512), f"Pixel shape: {pixels.shape}" + + # Valid pixel range + pixels_01 = ((pixels.squeeze(2).squeeze(0) + 1.0) / 2.0).clamp(0, 1) + assert pixels_01.min() >= 0 and pixels_01.max() <= 1 + + # Not blank (should have variance across pixels) + pixel_std = pixels_01.std() + assert pixel_std > 0.05, f"Image appears blank (std={pixel_std:.4f})" + + # Not uniform 
noise (should have spatial structure) + # Check that some channels have local correlation + center_patch = pixels_01[:, 200:300, 200:300] + corner_patch = pixels_01[:, 0:100, 0:100] + # Patches shouldn't be identical (image has spatial structure) + diff = (center_patch.mean() - corner_patch.mean()).abs() + # This is a weak check - just ensure the image isn't perfectly uniform + print( + f"PASS: Full generation - shape={pixels.shape}, std={pixel_std:.4f}, patch_diff={diff:.4f}" + ) + + +# ============================================================================= +# Manual runner +# ============================================================================= + + +if __name__ == "__main__": + import logging + + logging.basicConfig(level=logging.INFO) + + print("=" * 80) + print("Cosmos3-Text2Image Integration Tests") + print("=" * 80) + + print(f"\nConfiguration:") + print(f" MODEL_PATH: {MODEL_PATH}") + print(f" COMPILED_PATH: {COMPILED_PATH}") + print(f" VAE_PATH: {VAE_PATH}") + print(f" TP_DEGREE: {TP_DEGREE}") + + # Unit tests (no model needed) + print("\n" + "-" * 40) + print("Unit Tests (no model required)") + print("-" * 40) + + print("\n1. Patchify/Unpatchify roundtrip...") + test_patchify_unpatchify_roundtrip() + + print("\n2. 
Position IDs...") + test_position_ids() + + # Load model + print("\n" + "-" * 40) + print("Loading model...") + print("-" * 40) + + import torch_neuronx + + neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, world_size=TP_DEGREE, torch_dtype=torch.bfloat16 + ) + config = Cosmos3BackboneInferenceConfig( + neuron_config=neuron_config, + hidden_size=HIDDEN_SIZE, + intermediate_size=INTERMEDIATE_SIZE, + num_hidden_layers=NUM_LAYERS, + num_attention_heads=NUM_HEADS, + num_key_value_heads=NUM_KV_HEADS, + head_dim=128, + vocab_size=151936, + patch_channels=192, + latent_channels=48, + rope_theta=5000000.0, + mrope_section=[24, 20, 20], + ) + config.max_text_len = MAX_TEXT + config.num_vision_patches = NUM_VIS + + transformer_path = os.path.join(MODEL_PATH, "transformer") + backbone_model = NeuronCosmos3BackboneApplication( + model_path=transformer_path, config=config + ) + backbone_model.load(COMPILED_PATH) + + # Warmup + dummy = torch.randn(1, NUM_VIS, 192, dtype=torch.bfloat16) + dummy_ts = torch.tensor([0.5], dtype=torch.bfloat16) + pos = torch.zeros(MAX_TEXT + NUM_VIS, 3, dtype=torch.long) + ids = torch.zeros(1, MAX_TEXT, dtype=torch.long) + for _ in range(3): + _ = backbone_model.forward(ids, dummy, dummy_ts, pos) + + vae_model = torch.jit.load(VAE_PATH) + + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + # Integration tests + print("\n" + "-" * 40) + print("Integration Tests") + print("-" * 40) + + print("\n3. Backbone loads...") + test_backbone_loads(backbone_model) + + print("\n4. Output shape...") + test_backbone_output_shape(backbone_model) + + print("\n5. Non-zero output...") + test_backbone_nonzero_output(backbone_model) + + print("\n6. Tokenization...") + test_tokenization(tok) + + print("\n7. Backbone latency...") + test_backbone_latency(backbone_model) + + print("\n8. 
Full generation (10 steps)...") + test_full_generation(backbone_model, vae_model, tok) + + print("\n" + "=" * 80) + print("ALL TESTS PASSED") + print("=" * 80) diff --git a/contrib/models/Cosmos3-Text2Image/test/unit/__init__.py b/contrib/models/Cosmos3-Text2Image/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b
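
Note on the patch layout: the `patchify` and `unpatchify` docstrings in `src/pipeline.py` describe a `"cthpwq->thwpqc"` ordering (tokens ordered by T, then patch row, then patch column; features ordered by in-patch row, in-patch column, then channel). The sketch below restates that ordering as plain index arithmetic, which may help when checking the reshape/permute sequence by hand. It is an illustration only and not part of the PR; `patch_index` is a hypothetical helper name, and the formula is derived from the reshape and permute calls shown in the diff rather than from the NVIDIA reference code.

```python
# Illustrative sketch only (not part of the PR): pure-Python index arithmetic
# for the "cthpwq->thwpqc" patch layout described in pipeline.patchify.
# `patch_index` is a hypothetical helper name introduced for this example.


def patch_index(c, t, h, w, channels, latent_h, latent_w, patch_size=2):
    """Map a latent coordinate (c, t, h, w) to (token_index, feature_index)."""
    patches_h = latent_h // patch_size
    patches_w = latent_w // patch_size
    # Split each spatial coordinate into (patch index, offset inside the patch).
    patch_row, in_row = divmod(h, patch_size)
    patch_col, in_col = divmod(w, patch_size)
    # Tokens are ordered T-major, then patch row, then patch column.
    token = (t * patches_h + patch_row) * patches_w + patch_col
    # Features are ordered in-patch row, in-patch column, then channel (channels last).
    feature = (in_row * patch_size + in_col) * channels + c
    return token, feature


if __name__ == "__main__":
    # 512x512 image: 32x32 latent with 48 channels and a single frame.
    C, H, W = 48, 32, 32
    seen = {
        patch_index(c, 0, h, w, C, H, W)
        for c in range(C)
        for h in range(H)
        for w in range(W)
    }
    # Every latent element should land in a distinct (token, feature) slot.
    print("bijective:", len(seen) == C * H * W)
    print("tokens:", max(tok for tok, _ in seen) + 1)
    print("features:", max(feat for _, feat in seen) + 1)
```

For the 512x512 configuration used in the tests, this mapping fills a 256 x 192 grid, which agrees with the `(1, 256, 192)` shape asserted in `test_patchify_unpatchify_roundtrip`.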