From db0eba2303f6c4ce58ea863f86736fcd6d40a683 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Tue, 28 Apr 2026 11:39:07 -0400 Subject: [PATCH 1/4] Add S3Diff one-step super-resolution contrib S3Diff (ECCV 2024) performs degradation-guided 4x super-resolution in a single denoising step using SD-Turbo with dynamic LoRA modulation. A DEResNet encoder estimates degradation and produces per-layer modulation matrices injected between LoRA A/B weights. Uses torch_neuronx.trace() (no TP needed, model is ~2 GB). Validated on trn2.3xlarge, SDK 2.29: 0.544s/image, ~21x CPU speedup. LoRA components use --model-type=unet-inference to avoid NaN from --auto-cast=matmult on small einsum operations. --- contrib/models/S3Diff/README.md | 136 ++++ contrib/models/S3Diff/src/__init__.py | 3 + contrib/models/S3Diff/src/generate_s3diff.py | 159 +++++ contrib/models/S3Diff/src/modeling_s3diff.py | 659 ++++++++++++++++++ contrib/models/S3Diff/test/__init__.py | 0 .../S3Diff/test/integration/__init__.py | 0 .../S3Diff/test/integration/test_model.py | 159 +++++ 7 files changed, 1116 insertions(+) create mode 100644 contrib/models/S3Diff/README.md create mode 100644 contrib/models/S3Diff/src/__init__.py create mode 100644 contrib/models/S3Diff/src/generate_s3diff.py create mode 100644 contrib/models/S3Diff/src/modeling_s3diff.py create mode 100644 contrib/models/S3Diff/test/__init__.py create mode 100644 contrib/models/S3Diff/test/integration/__init__.py create mode 100644 contrib/models/S3Diff/test/integration/test_model.py diff --git a/contrib/models/S3Diff/README.md b/contrib/models/S3Diff/README.md new file mode 100644 index 00000000..c0dfcbc0 --- /dev/null +++ b/contrib/models/S3Diff/README.md @@ -0,0 +1,136 @@ +# Contrib Model: S3Diff + +S3Diff one-step 4x super-resolution on AWS Neuron using `torch_neuronx.trace()`. + +## Model Information + +- **HuggingFace ID:** `zhangap/S3Diff` (weights), base model `stabilityai/sd-turbo` +- **Model Type:** One-step diffusion model for image super-resolution +- **Parameters:** ~1.3B total (~2 GB on disk) +- **Architecture:** SD-Turbo UNet with degradation-guided dynamic LoRA modulation (DEResNet encoder, CLIP text encoder, VAE encoder/decoder with per-layer LoRA, UNet with per-layer LoRA) +- **Paper:** "Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors" (ECCV 2024) +- **License:** Check model cards for SD-Turbo and S3Diff + +## Key Architecture Notes + +S3Diff is unusual among diffusion models: + +1. **Single denoising step**: Only one UNet forward pass per image (at t=999), making it extremely fast. +2. **Dynamic LoRA modulation**: A DEResNet encoder estimates input degradation and produces per-layer LoRA scaling matrices. These `[rank, rank]` modulation matrices are injected between `lora_A` and `lora_B` via einsum operations, conditioning the UNet on the specific degradation pattern of each input. +3. **Two LoRA ranks**: VAE uses rank=16 (6 blocks), UNet uses rank=32 (10 blocks). +4. **Small model**: Total size ~2 GB, fits on a single NeuronCore with no tensor parallelism needed. + +This contrib uses `torch_neuronx.trace()` rather than NxDI tensor parallelism, which is appropriate for the model's small size and non-autoregressive architecture. + +## Validation Results + +**Validated:** 2026-04-28 +**Instance:** trn2.3xlarge (LNC=2) +**SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9 + +### Benchmark Results (128x128 -> 512x512, single step) + +| Component | Time | +|-----------|------| +| DEResNet | 3.8ms | +| Modulation (CPU) | 0.5ms | +| VAE Encode | 83.2ms | +| UNet x2 (CFG) | 218.8ms | +| VAE Decode | 164.6ms | +| **Total** | **0.471s** | + +| Metric | Value | +|--------|-------| +| Resolution | 128x128 -> 512x512 (4x SR) | +| Inference steps | 1 (one-step model) | +| Warm generation time | 0.544s | +| Throughput | ~1.8 img/s | +| Total compile time | ~21 min | +| CPU baseline | 11.53s | +| Speedup vs CPU | ~21x | + +### Accuracy Validation + +Visual quality validated against CPU reference output. The model produces high-quality 4x upscaled images with correct degradation-aware enhancement. + +## Usage + +```python +from S3Diff.src.modeling_s3diff import S3DiffNeuronPipeline +from PIL import Image + +pipeline = S3DiffNeuronPipeline( + sd_turbo_path="/shared/sd-turbo/", + s3diff_weights_path="/shared/s3diff/s3diff.pkl", + de_net_path="/shared/s3diff/de_net.pth", + compile_dir="/tmp/s3diff/compiled/", + lr_size=128, +) +pipeline.load() +pipeline.compile() + +lr_image = Image.open("input_128x128.png").convert("RGB") +sr_image = pipeline(lr_image) +sr_image.save("output_512x512.png") +``` + +Or use the provided script: + +```bash +python src/generate_s3diff.py \ + --download \ + --input_image input.png \ + --output_image output.png \ + --compile_dir /tmp/s3diff/compiled/ +``` + +## Setup + +```bash +# Activate NxDI environment +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Install dependencies +pip install diffusers transformers peft accelerate torchvision + +# Download weights +python src/generate_s3diff.py --download + +# Or manually: +# SD-Turbo: huggingface-cli download stabilityai/sd-turbo --local-dir /shared/sd-turbo/ +# S3Diff: huggingface-cli download zhangap/S3Diff --local-dir /shared/s3diff/ +# DEResNet: git clone https://github.com/ArcticHare105/S3Diff.git /tmp/s3diff_repo +# cp /tmp/s3diff_repo/assets/mm-realsr/de_net.pth /shared/s3diff/ +``` + +## Compatibility Matrix + +| Instance/Version | SDK 2.29 | SDK 2.28 | +|------------------|----------|----------| +| trn2.3xlarge | VALIDATED | Not tested | + +## Example Checkpoints + +* [zhangap/S3Diff](https://huggingface.co/zhangap/S3Diff) -- S3Diff LoRA weights +* [stabilityai/sd-turbo](https://huggingface.co/stabilityai/sd-turbo) -- Base SD-Turbo model + +## Testing Instructions + +```bash +export SD_TURBO_PATH=/shared/sd-turbo/ +export S3DIFF_WEIGHTS=/shared/s3diff/s3diff.pkl +export DE_NET_WEIGHTS=/shared/s3diff/de_net.pth + +cd contrib/models/S3Diff/ +pytest test/integration/test_model.py -v + +# Or standalone +python test/integration/test_model.py +``` + +## Known Issues + +- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. The VAE encoder, UNet, and VAE decoder all use `--model-type=unet-inference` instead, which avoids this issue. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`. +- **Compilation time**: ~21 minutes total (UNet is the slowest at ~12 min). Compiled models are cached for reuse. +- **CFG is sequential**: Two separate UNet passes (positive + negative prompt), not batched. Batching with batch_size=2 would halve UNet wall time but requires recompilation. +- **Neuron runtime HBM**: Once loaded, compiled models stay in HBM even if the Python object is deleted (within the same process). Plan memory accordingly. diff --git a/contrib/models/S3Diff/src/__init__.py b/contrib/models/S3Diff/src/__init__.py new file mode 100644 index 00000000..3d29296f --- /dev/null +++ b/contrib/models/S3Diff/src/__init__.py @@ -0,0 +1,3 @@ +# S3Diff one-step super-resolution on AWS Neuron +# Uses torch_neuronx.trace() for compilation -- not NxDI TP sharding +# (model is ~2 GB, fits on a single NeuronCore) diff --git a/contrib/models/S3Diff/src/generate_s3diff.py b/contrib/models/S3Diff/src/generate_s3diff.py new file mode 100644 index 00000000..db98bd14 --- /dev/null +++ b/contrib/models/S3Diff/src/generate_s3diff.py @@ -0,0 +1,159 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +S3Diff one-step 4x super-resolution on AWS Neuron. + +Downloads required weights (SD-Turbo, S3Diff LoRA, DEResNet), compiles all +components, and runs super-resolution inference. + +Usage: + python generate_s3diff.py \ + --input_image /path/to/lr_image.png \ + --output_image /path/to/sr_output.png \ + --compile_dir /tmp/s3diff/compiled/ + +Requirements: + pip install diffusers transformers peft accelerate torchvision +""" + +import argparse +import os +import time + +import torch +from PIL import Image + +try: + from .modeling_s3diff import S3DiffNeuronPipeline +except ImportError: + from modeling_s3diff import S3DiffNeuronPipeline + + +DEFAULT_SD_TURBO_PATH = "/shared/sd-turbo/" +DEFAULT_S3DIFF_WEIGHTS = "/shared/s3diff/s3diff.pkl" +DEFAULT_DE_NET_WEIGHTS = "/shared/s3diff/de_net.pth" +DEFAULT_COMPILE_DIR = "/tmp/s3diff/compiled/" + + +def download_weights(sd_turbo_path, s3diff_weights_path, de_net_path): + """Download model weights if not already present.""" + from huggingface_hub import hf_hub_download, snapshot_download + + if not os.path.exists(sd_turbo_path): + print("Downloading SD-Turbo...") + snapshot_download("stabilityai/sd-turbo", local_dir=sd_turbo_path) + + if not os.path.exists(s3diff_weights_path): + print("Downloading S3Diff weights...") + os.makedirs(os.path.dirname(s3diff_weights_path), exist_ok=True) + hf_hub_download( + "zhangap/S3Diff", + filename="s3diff.pkl", + local_dir=os.path.dirname(s3diff_weights_path), + ) + + if not os.path.exists(de_net_path): + print("Downloading DEResNet weights...") + os.makedirs(os.path.dirname(de_net_path), exist_ok=True) + # DEResNet weights are in the S3Diff GitHub repo + import subprocess + + repo_dir = "/tmp/s3diff_repo" + if not os.path.exists(repo_dir): + subprocess.run( + [ + "git", + "clone", + "https://github.com/ArcticHare105/S3Diff.git", + repo_dir, + ], + check=True, + ) + import shutil + + shutil.copy2( + os.path.join(repo_dir, "assets", "mm-realsr", "de_net.pth"), + de_net_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="S3Diff 4x SR on AWS Neuron") + parser.add_argument( + "--input_image", + type=str, + default=None, + help="Path to input low-resolution image (128x128 recommended)", + ) + parser.add_argument( + "--output_image", type=str, default="sr_output.png", help="Output path" + ) + parser.add_argument("--sd_turbo_path", type=str, default=DEFAULT_SD_TURBO_PATH) + parser.add_argument("--s3diff_weights", type=str, default=DEFAULT_S3DIFF_WEIGHTS) + parser.add_argument("--de_net_weights", type=str, default=DEFAULT_DE_NET_WEIGHTS) + parser.add_argument("--compile_dir", type=str, default=DEFAULT_COMPILE_DIR) + parser.add_argument("--num_images", type=int, default=3) + parser.add_argument("--warmup_rounds", type=int, default=5) + parser.add_argument("--download", action="store_true", help="Download weights") + args = parser.parse_args() + + if args.download: + download_weights(args.sd_turbo_path, args.s3diff_weights, args.de_net_weights) + + # Create a test image if none provided + if args.input_image is None: + print("No input image provided, creating a 128x128 test pattern...") + import numpy as np + + test_img = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8) + lr_image = Image.fromarray(test_img) + else: + lr_image = Image.open(args.input_image).convert("RGB") + print(f"Input image: {lr_image.size}") + + # Build pipeline + pipeline = S3DiffNeuronPipeline( + sd_turbo_path=args.sd_turbo_path, + s3diff_weights_path=args.s3diff_weights, + de_net_path=args.de_net_weights, + compile_dir=args.compile_dir, + lr_size=lr_image.size[0], + ) + + print("\nLoading model...") + pipeline.load() + + print("\nCompiling...") + t0 = time.time() + pipeline.compile() + compile_time = time.time() - t0 + print(f"Total compilation: {compile_time:.1f}s") + + # Warmup + print(f"\nWarming up ({args.warmup_rounds} rounds)...") + for _ in range(args.warmup_rounds): + pipeline(lr_image) + + # Benchmark + print(f"\nGenerating {args.num_images} images...") + total_time = 0 + for i in range(args.num_images): + t0 = time.time() + sr_image = pipeline(lr_image) + elapsed = time.time() - t0 + total_time += elapsed + print(f" Image {i + 1}: {elapsed:.3f}s") + + avg_time = total_time / args.num_images + print(f"\nResults:") + print(f" Average time: {avg_time:.3f}s") + print(f" Throughput: {1.0 / avg_time:.2f} img/s") + print(f" Compilation: {compile_time:.1f}s") + + sr_image.save(args.output_image) + print(f" Saved: {args.output_image}") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/S3Diff/src/modeling_s3diff.py b/contrib/models/S3Diff/src/modeling_s3diff.py new file mode 100644 index 00000000..5cc9dda4 --- /dev/null +++ b/contrib/models/S3Diff/src/modeling_s3diff.py @@ -0,0 +1,659 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +S3Diff one-step 4x super-resolution on AWS Neuron. + +Model: Yukang/S3Diff (weights from zhangap/S3Diff on HuggingFace) +Paper: "Degradation-Guided One-Step Image Super-Resolution with Diffusion Priors" (ECCV 2024) + +Architecture: SD-Turbo UNet with dynamic LoRA modulation. A DEResNet encoder +estimates input degradation, and per-layer LoRA scaling factors are computed +from these scores to condition the UNet on the specific degradation pattern. + +This module provides the full pipeline: load, compile, and run inference. +Uses torch_neuronx.trace() since the model is small (~2 GB) and does not +benefit from tensor parallelism. +""" + +import math +import os +import time +from typing import Any, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision import transforms + +# Scale factor and padding constants +SF = 4 +PAD_H = 512 +PAD_W = 512 + +# --------------------------------------------------------------------------- +# DEResNet -- degradation estimator +# --------------------------------------------------------------------------- + + +class ResidualBlockNoBN(nn.Module): + def __init__(self, num_feat=64): + super().__init__() + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return x + self.conv2(self.relu(self.conv1(x))) + + +class DEResNet(nn.Module): + """Degradation Estimation ResNet. Outputs per-degradation scores in [0, 1].""" + + def __init__( + self, + num_in_ch=3, + num_degradation=2, + num_feats=[64, 64, 64, 128], + num_blocks=[2, 2, 2, 2], + downscales=[1, 1, 2, 1], + ): + super().__init__() + num_stage = len(num_feats) + self.conv_first = nn.ModuleList() + for _ in range(num_degradation): + self.conv_first.append(nn.Conv2d(num_in_ch, num_feats[0], 3, 1, 1)) + self.body = nn.ModuleList() + for _ in range(num_degradation): + body = [] + for stage in range(num_stage): + for _ in range(num_blocks[stage]): + body.append(ResidualBlockNoBN(num_feats[stage])) + if downscales[stage] == 1: + if ( + stage < num_stage - 1 + and num_feats[stage] != num_feats[stage + 1] + ): + body.append( + nn.Conv2d(num_feats[stage], num_feats[stage + 1], 3, 1, 1) + ) + elif downscales[stage] == 2: + body.append( + nn.Conv2d( + num_feats[stage], + num_feats[min(stage + 1, num_stage - 1)], + 3, + 2, + 1, + ) + ) + self.body.append(nn.Sequential(*body)) + self.num_degradation = num_degradation + self.fc_degree = nn.ModuleList() + for _ in range(num_degradation): + self.fc_degree.append( + nn.Sequential( + nn.Linear(num_feats[-1], 512), + nn.ReLU(inplace=True), + nn.Linear(512, 1), + nn.Sigmoid(), + ) + ) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, x): + degrees = [] + for i in range(self.num_degradation): + x_out = self.conv_first[i](x) + feat = self.body[i](x_out) + feat = self.avg_pool(feat).squeeze(-1).squeeze(-1) + degrees.append(self.fc_degree[i](feat).squeeze(-1)) + return torch.stack(degrees, dim=1) + + +# --------------------------------------------------------------------------- +# Custom LoRA forward with degradation modulation +# --------------------------------------------------------------------------- + + +def my_lora_fwd(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: + """LoRA forward that injects degradation modulation between lora_A and lora_B. + + For Conv2d LoRA layers: einsum('...khw,...kr->...rhw', lora_A(x), de_mod) + For Linear LoRA layers: einsum('...lk,...kr->...lr', lora_A(x), de_mod) + + The de_mod tensor is a [B, rank, rank] matrix set per-layer by the wrapper + modules before the forward pass. + """ + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + if self.disable_adapters: + if self.merged: + self.unmerge() + result = self.base_layer(x, *args, **kwargs) + elif adapter_names is not None: + result = self._mixed_batch_forward( + x, *args, adapter_names=adapter_names, **kwargs + ) + elif self.merged: + result = self.base_layer(x, *args, **kwargs) + else: + result = self.base_layer(x, *args, **kwargs) + torch_result_dtype = result.dtype + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(lora_A.weight.dtype) + if not self.use_dora[active_adapter]: + _tmp = lora_A(dropout(x)) + if isinstance(lora_A, torch.nn.Conv2d): + _tmp = torch.einsum("...khw,...kr->...rhw", _tmp, self.de_mod) + elif isinstance(lora_A, torch.nn.Linear): + _tmp = torch.einsum("...lk,...kr->...lr", _tmp, self.de_mod) + else: + raise NotImplementedError + result = result + lora_B(_tmp) * scaling + else: + x = dropout(x) + result = result + self._apply_dora( + x, lora_A, lora_B, scaling, active_adapter + ) + result = result.to(torch_result_dtype) + return result + + +# --------------------------------------------------------------------------- +# Neuron tracing wrappers +# --------------------------------------------------------------------------- + + +class TextEncoderWrapper(nn.Module): + def __init__(self, text_encoder): + super().__init__() + self.text_encoder = text_encoder + + def forward(self, input_ids): + return self.text_encoder(input_ids)[0] + + +class VAEEncoderWrapper(nn.Module): + """Wraps VAE encoder to accept de_mod_all [B, 6, rank, rank] as explicit input.""" + + def __init__(self, vae, vae_lora_layers, lora_rank_vae): + super().__init__() + self.vae = vae + self.vae_lora_layers = vae_lora_layers + self.lora_rank_vae = lora_rank_vae + self.layer_block_map = {} + for layer_name in vae_lora_layers: + split_name = layer_name.split(".") + if split_name[1] == "down_blocks": + self.layer_block_map[layer_name] = int(split_name[2]) + elif split_name[1] == "mid_block": + self.layer_block_map[layer_name] = 4 + else: + self.layer_block_map[layer_name] = 5 + + def forward(self, pixel_values, de_mod_all): + for layer_name, module in self.vae.named_modules(): + if layer_name in self.vae_lora_layers: + block_idx = self.layer_block_map[layer_name] + module.de_mod = de_mod_all[:, block_idx] + latent = ( + self.vae.encode(pixel_values).latent_dist.sample() + * self.vae.config.scaling_factor + ) + return latent + + +class UNetWrapper(nn.Module): + """Wraps UNet to accept de_mod_all [B, 10, rank, rank] as explicit input.""" + + def __init__(self, unet, unet_lora_layers, lora_rank_unet): + super().__init__() + self.unet = unet + self.unet_lora_layers = unet_lora_layers + self.lora_rank_unet = lora_rank_unet + self.layer_block_map = {} + for layer_name in unet_lora_layers: + split_name = layer_name.split(".") + if split_name[0] == "down_blocks": + self.layer_block_map[layer_name] = int(split_name[1]) + elif split_name[0] == "mid_block": + self.layer_block_map[layer_name] = 4 + elif split_name[0] == "up_blocks": + self.layer_block_map[layer_name] = int(split_name[1]) + 5 + else: + self.layer_block_map[layer_name] = 9 + + def forward(self, latent, timestep, encoder_hidden_states, de_mod_all): + for layer_name, module in self.unet.named_modules(): + if layer_name in self.unet_lora_layers: + block_idx = self.layer_block_map[layer_name] + module.de_mod = de_mod_all[:, block_idx] + return self.unet( + latent, timestep, encoder_hidden_states=encoder_hidden_states + ).sample + + +class VAEDecoderWrapper(nn.Module): + """Simple wrapper for VAE decoder (no LoRA).""" + + def __init__(self, vae): + super().__init__() + self.vae = vae + + def forward(self, latent): + return self.vae.decode(latent / self.vae.config.scaling_factor).sample + + +# --------------------------------------------------------------------------- +# S3Diff Neuron Pipeline +# --------------------------------------------------------------------------- + + +class S3DiffNeuronPipeline: + """End-to-end S3Diff super-resolution pipeline on Neuron. + + Handles model loading, compilation, and inference. + + Args: + sd_turbo_path: Path to SD-Turbo checkpoint (stabilityai/sd-turbo) + s3diff_weights_path: Path to s3diff.pkl weights + de_net_path: Path to de_net.pth weights + compile_dir: Directory for compiled model cache + lr_size: Input low-resolution size (default: 128) + """ + + def __init__( + self, + sd_turbo_path: str, + s3diff_weights_path: str, + de_net_path: str, + compile_dir: str = "/tmp/s3diff/compiled/", + lr_size: int = 128, + ): + self.sd_turbo_path = sd_turbo_path + self.s3diff_weights_path = s3diff_weights_path + self.de_net_path = de_net_path + self.compile_dir = compile_dir + self.lr_size = lr_size + self.hr_size = lr_size * SF + + # Will be set during load() + self.de_net_neuron = None + self.text_enc_neuron = None + self.vae_enc_neuron = None + self.unet_neuron = None + self.vae_dec_neuron = None + self.tokenizer = None + self.sched = None + self.compute_modulation = None + + def load(self): + """Load all model components and build the modulation network.""" + from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel + from peft import LoraConfig + from transformers import AutoTokenizer, CLIPTextModel + + print("Loading SD-Turbo + S3Diff LoRA weights...") + + self.tokenizer = AutoTokenizer.from_pretrained( + self.sd_turbo_path, subfolder="tokenizer" + ) + text_encoder = CLIPTextModel.from_pretrained( + self.sd_turbo_path, subfolder="text_encoder" + ).eval() + vae = AutoencoderKL.from_pretrained(self.sd_turbo_path, subfolder="vae") + unet = UNet2DConditionModel.from_pretrained( + self.sd_turbo_path, subfolder="unet" + ) + + sd = torch.load(self.s3diff_weights_path, map_location="cpu") + self.lora_rank_unet = sd["rank_unet"] + self.lora_rank_vae = sd["rank_vae"] + print(f"LoRA ranks: unet={self.lora_rank_unet}, vae={self.lora_rank_vae}") + + # Add LoRA adapters and load trained weights + vae_lora_config = LoraConfig( + r=self.lora_rank_vae, + init_lora_weights="gaussian", + target_modules=sd["vae_lora_target_modules"], + ) + vae.add_adapter(vae_lora_config, adapter_name="vae_skip") + _sd_vae = vae.state_dict() + for k in sd["state_dict_vae"]: + _sd_vae[k] = sd["state_dict_vae"][k] + vae.load_state_dict(_sd_vae) + + unet_lora_config = LoraConfig( + r=self.lora_rank_unet, + init_lora_weights="gaussian", + target_modules=sd["unet_lora_target_modules"], + ) + unet.add_adapter(unet_lora_config) + _sd_unet = unet.state_dict() + for k in sd["state_dict_unet"]: + _sd_unet[k] = sd["state_dict_unet"][k] + unet.load_state_dict(_sd_unet) + + # Monkey-patch LoRA forward methods + vae_lora_layers = [] + for name, module in vae.named_modules(): + if "base_layer" in name: + vae_lora_layers.append(name[: -len(".base_layer")]) + for name, module in vae.named_modules(): + if name in vae_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + unet_lora_layers = [] + for name, module in unet.named_modules(): + if "base_layer" in name: + unet_lora_layers.append(name[: -len(".base_layer")]) + for name, module in unet.named_modules(): + if name in unet_lora_layers: + module.forward = my_lora_fwd.__get__(module, module.__class__) + + vae.eval() + unet.eval() + + # Modulation MLPs (tiny, run on CPU) + num_embeddings = 64 + block_embedding_dim = 64 + W = nn.Parameter(sd["w"], requires_grad=False) + + vae_de_mlp = nn.Sequential(nn.Linear(num_embeddings * 4, 256), nn.ReLU(True)) + unet_de_mlp = nn.Sequential(nn.Linear(num_embeddings * 4, 256), nn.ReLU(True)) + vae_block_mlp = nn.Sequential(nn.Linear(block_embedding_dim, 64), nn.ReLU(True)) + unet_block_mlp = nn.Sequential( + nn.Linear(block_embedding_dim, 64), nn.ReLU(True) + ) + vae_fuse_mlp = nn.Linear(256 + 64, self.lora_rank_vae**2) + unet_fuse_mlp = nn.Linear(256 + 64, self.lora_rank_unet**2) + vae_block_embeddings = nn.Embedding(6, block_embedding_dim) + unet_block_embeddings = nn.Embedding(10, block_embedding_dim) + + for name, module in [ + ("vae_de_mlp", vae_de_mlp), + ("unet_de_mlp", unet_de_mlp), + ("vae_block_mlp", vae_block_mlp), + ("unet_block_mlp", unet_block_mlp), + ("vae_fuse_mlp", vae_fuse_mlp), + ("unet_fuse_mlp", unet_fuse_mlp), + ]: + _ssd = module.state_dict() + for k in sd[f"state_dict_{name}"]: + _ssd[k] = sd[f"state_dict_{name}"][k] + module.load_state_dict(_ssd) + vae_block_embeddings.load_state_dict( + sd["state_embeddings"]["state_dict_vae_block"] + ) + unet_block_embeddings.load_state_dict( + sd["state_embeddings"]["state_dict_unet_block"] + ) + + for m in [ + vae_de_mlp, + unet_de_mlp, + vae_block_mlp, + unet_block_mlp, + vae_fuse_mlp, + unet_fuse_mlp, + ]: + m.eval() + + # DEResNet + de_net = DEResNet(num_in_ch=3, num_degradation=2) + de_net.load_state_dict(torch.load(self.de_net_path, map_location="cpu")) + de_net.eval() + + # Scheduler + self.sched = DDPMScheduler.from_pretrained( + self.sd_turbo_path, subfolder="scheduler" + ) + self.sched.set_timesteps(1, device="cpu") + + # Build wrappers + self._text_enc_wrapper = TextEncoderWrapper(text_encoder) + self._vae_enc_wrapper = VAEEncoderWrapper( + vae, vae_lora_layers, self.lora_rank_vae + ) + self._unet_wrapper = UNetWrapper(unet, unet_lora_layers, self.lora_rank_unet) + self._vae_dec_wrapper = VAEDecoderWrapper(vae) + self._de_net = de_net + + # Build modulation closure + lora_rank_vae = self.lora_rank_vae + lora_rank_unet = self.lora_rank_unet + + def compute_modulation(deg_score): + deg_proj = deg_score[..., None] * W[None, None, :] * 2 * np.pi + deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1) + deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1) + + vae_de_c_embed = vae_de_mlp(deg_proj) + unet_de_c_embed = unet_de_mlp(deg_proj) + + vae_block_c_embeds = vae_block_mlp(vae_block_embeddings.weight) + unet_block_c_embeds = unet_block_mlp(unet_block_embeddings.weight) + + B = deg_score.shape[0] + vae_embeds = vae_fuse_mlp( + torch.cat( + [ + vae_de_c_embed.unsqueeze(1).repeat( + 1, vae_block_c_embeds.shape[0], 1 + ), + vae_block_c_embeds.unsqueeze(0).repeat(B, 1, 1), + ], + -1, + ) + ) + unet_embeds = unet_fuse_mlp( + torch.cat( + [ + unet_de_c_embed.unsqueeze(1).repeat( + 1, unet_block_c_embeds.shape[0], 1 + ), + unet_block_c_embeds.unsqueeze(0).repeat(B, 1, 1), + ], + -1, + ) + ) + + return ( + vae_embeds.reshape(B, 6, lora_rank_vae, lora_rank_vae), + unet_embeds.reshape(B, 10, lora_rank_unet, lora_rank_unet), + ) + + self.compute_modulation = compute_modulation + + print(f"VAE LoRA layers: {len(vae_lora_layers)}") + print(f"UNet LoRA layers: {len(unet_lora_layers)}") + print("Model loaded successfully.") + + def compile(self): + """Compile all components with torch_neuronx.trace().""" + import torch_neuronx + + os.makedirs(self.compile_dir, exist_ok=True) + lr_h, lr_w = self.lr_size, self.lr_size + hr_h, hr_w = self.hr_size, self.hr_size + lat_h, lat_w = hr_h // 8, hr_w // 8 + + # DEResNet + path = os.path.join(self.compile_dir, "de_net.pt") + if os.path.exists(path): + print("DEResNet: loading cached...") + self.de_net_neuron = torch.jit.load(path) + else: + print("DEResNet: compiling...") + t0 = time.time() + self.de_net_neuron = torch_neuronx.trace( + self._de_net, + torch.randn(1, 3, lr_h, lr_w), + compiler_args=["--auto-cast", "matmult", "-O1"], + ) + torch.jit.save(self.de_net_neuron, path) + print(f" Done in {time.time() - t0:.1f}s") + + # Text encoder + path = os.path.join(self.compile_dir, "text_encoder.pt") + if os.path.exists(path): + print("Text encoder: loading cached...") + self.text_enc_neuron = torch.jit.load(path) + else: + print("Text encoder: compiling...") + t0 = time.time() + self.text_enc_neuron = torch_neuronx.trace( + self._text_enc_wrapper, + torch.zeros(1, 77, dtype=torch.long), + compiler_args=["--auto-cast", "matmult", "-O1"], + ) + torch.jit.save(self.text_enc_neuron, path) + print(f" Done in {time.time() - t0:.1f}s") + + # VAE encoder (with LoRA) + path = os.path.join(self.compile_dir, "vae_encoder.pt") + if os.path.exists(path): + print("VAE encoder: loading cached...") + self.vae_enc_neuron = torch.jit.load(path) + else: + print("VAE encoder: compiling...") + t0 = time.time() + self.vae_enc_neuron = torch_neuronx.trace( + self._vae_enc_wrapper, + ( + torch.randn(1, 3, hr_h, hr_w), + torch.randn(1, 6, self.lora_rank_vae, self.lora_rank_vae), + ), + compiler_args=["--model-type=unet-inference", "-O1"], + ) + torch.jit.save(self.vae_enc_neuron, path) + print(f" Done in {time.time() - t0:.1f}s") + + # UNet (with LoRA) + path = os.path.join(self.compile_dir, "unet.pt") + if os.path.exists(path): + print("UNet: loading cached...") + self.unet_neuron = torch.jit.load(path) + else: + print("UNet: compiling...") + t0 = time.time() + self.unet_neuron = torch_neuronx.trace( + self._unet_wrapper, + ( + torch.randn(1, 4, lat_h, lat_w), + torch.tensor([999], dtype=torch.long), + torch.randn(1, 77, 1024), + torch.randn(1, 10, self.lora_rank_unet, self.lora_rank_unet), + ), + compiler_args=["--model-type=unet-inference", "-O1"], + ) + torch.jit.save(self.unet_neuron, path) + print(f" Done in {time.time() - t0:.1f}s") + + # VAE decoder (no LoRA) + path = os.path.join(self.compile_dir, "vae_decoder.pt") + if os.path.exists(path): + print("VAE decoder: loading cached...") + self.vae_dec_neuron = torch.jit.load(path) + else: + print("VAE decoder: compiling...") + t0 = time.time() + self.vae_dec_neuron = torch_neuronx.trace( + self._vae_dec_wrapper, + torch.randn(1, 4, lat_h, lat_w), + compiler_args=["--model-type=unet-inference", "-O1"], + ) + torch.jit.save(self.vae_dec_neuron, path) + print(f" Done in {time.time() - t0:.1f}s") + + print("All components compiled.") + + @torch.no_grad() + def __call__( + self, + lr_image: Image.Image, + pos_prompt: str = "high quality, highly detailed, clean", + neg_prompt: str = "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed", + cfg_scale: float = 1.07, + ) -> Image.Image: + """Run 4x super-resolution on a low-resolution PIL image. + + Args: + lr_image: Input PIL image (should be lr_size x lr_size) + pos_prompt: Positive text prompt + neg_prompt: Negative text prompt + cfg_scale: Classifier-free guidance scale + + Returns: + Super-resolved PIL image (4x the input resolution) + """ + to_tensor = transforms.ToTensor() + im_lr = to_tensor(lr_image).unsqueeze(0) + + # Preprocess: resize 4x + normalize + pad + ori_h, ori_w = im_lr.shape[2:] + im_lr_resize = F.interpolate( + im_lr, + size=(ori_h * SF, ori_w * SF), + mode="bilinear", + align_corners=False, + ) + im_lr_resize_norm = (im_lr_resize * 2 - 1.0).clamp(-1, 1) + resize_h, resize_w = im_lr_resize_norm.shape[2:] + im_lr_resize_norm = F.pad( + im_lr_resize_norm, + pad=(0, PAD_W - resize_w, 0, PAD_H - resize_h), + mode="reflect", + ) + + # 1. DEResNet -> degradation scores + deg_score = self.de_net_neuron(im_lr) + + # 2. Compute modulation on CPU + vae_de_mod_all, unet_de_mod_all = self.compute_modulation(deg_score) + + # 3. Text encoding + pos_tokens = self.tokenizer( + pos_prompt, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ).input_ids + neg_tokens = self.tokenizer( + neg_prompt, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ).input_ids + pos_enc = self.text_enc_neuron(pos_tokens) + neg_enc = self.text_enc_neuron(neg_tokens) + + # 4. VAE Encode (with de_mod) + latent = self.vae_enc_neuron(im_lr_resize_norm, vae_de_mod_all) + + # 5. UNet x2 for CFG (with de_mod) + timestep = torch.tensor([999], dtype=torch.long) + pos_pred = self.unet_neuron(latent, timestep, pos_enc, unet_de_mod_all) + neg_pred = self.unet_neuron(latent, timestep, neg_enc, unet_de_mod_all) + model_pred = neg_pred + cfg_scale * (pos_pred - neg_pred) + + # 6. Scheduler step (CPU) + x_denoised = self.sched.step( + model_pred.cpu(), torch.tensor([999]), latent.cpu(), return_dict=True + ).prev_sample + + # 7. VAE Decode + output = self.vae_dec_neuron(x_denoised).clamp(-1, 1) + output = output[:, :, :resize_h, :resize_w] + return transforms.ToPILImage()((output[0] * 0.5 + 0.5).cpu().clamp(0, 1)) diff --git a/contrib/models/S3Diff/test/__init__.py b/contrib/models/S3Diff/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/S3Diff/test/integration/__init__.py b/contrib/models/S3Diff/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/S3Diff/test/integration/test_model.py b/contrib/models/S3Diff/test/integration/test_model.py new file mode 100644 index 00000000..c00ad580 --- /dev/null +++ b/contrib/models/S3Diff/test/integration/test_model.py @@ -0,0 +1,159 @@ +""" +Integration tests for S3Diff one-step super-resolution on Neuron. + +Tests: +1. test_smoke_pipeline_loads: Pipeline loads without errors +2. test_sr_produces_correct_size: 128x128 -> 512x512 +3. test_warm_generation_time: < 2s per image + +Requirements: + - trn2.3xlarge + - Neuron SDK 2.29+ + - diffusers, transformers, peft, torchvision + - Weights downloaded (see generate_s3diff.py --download) + +Run: + pytest test_model.py -v + # Or standalone: + python test_model.py +""" + +import gc +import os +import sys +import time + +import numpy as np +import pytest +import torch +from PIL import Image + +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +SRC_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, "..", "..", "src")) +sys.path.insert(0, SRC_DIR) + +from modeling_s3diff import S3DiffNeuronPipeline + +SD_TURBO_PATH = os.environ.get("SD_TURBO_PATH", "/shared/sd-turbo/") +S3DIFF_WEIGHTS = os.environ.get("S3DIFF_WEIGHTS", "/shared/s3diff/s3diff.pkl") +DE_NET_WEIGHTS = os.environ.get("DE_NET_WEIGHTS", "/shared/s3diff/de_net.pth") +COMPILE_DIR = os.environ.get("S3DIFF_COMPILE_DIR", "/tmp/s3diff_test/compiled/") +LR_SIZE = 128 +HR_SIZE = 512 + + +def make_test_image(size=128): + """Create a deterministic test image.""" + np.random.seed(42) + data = np.random.randint(0, 255, (size, size, 3), dtype=np.uint8) + return Image.fromarray(data) + + +@pytest.fixture(scope="module") +def pipeline(): + """Create, load, compile, and warm up the S3Diff pipeline.""" + pipe = S3DiffNeuronPipeline( + sd_turbo_path=SD_TURBO_PATH, + s3diff_weights_path=S3DIFF_WEIGHTS, + de_net_path=DE_NET_WEIGHTS, + compile_dir=COMPILE_DIR, + lr_size=LR_SIZE, + ) + pipe.load() + pipe.compile() + + # Warmup + test_img = make_test_image() + pipe(test_img) + + yield pipe + + del pipe + gc.collect() + + +def test_smoke_pipeline_loads(pipeline): + """Pipeline loads without errors and has required components.""" + assert pipeline is not None + assert pipeline.de_net_neuron is not None + assert pipeline.text_enc_neuron is not None + assert pipeline.vae_enc_neuron is not None + assert pipeline.unet_neuron is not None + assert pipeline.vae_dec_neuron is not None + assert pipeline.tokenizer is not None + assert pipeline.sched is not None + assert pipeline.compute_modulation is not None + + +def test_sr_produces_correct_size(pipeline): + """128x128 input produces 512x512 output.""" + test_img = make_test_image(LR_SIZE) + sr_img = pipeline(test_img) + + assert isinstance(sr_img, Image.Image) + assert sr_img.size == (HR_SIZE, HR_SIZE), ( + f"Expected ({HR_SIZE}, {HR_SIZE}), got {sr_img.size}" + ) + + # Verify reasonable pixel distribution (not blank) + sr_array = np.array(sr_img) + assert sr_array.shape == (HR_SIZE, HR_SIZE, 3) + assert sr_array.std() > 10, "Output appears blank or uniform" + + # Save for inspection + os.makedirs(os.path.join(COMPILE_DIR, "test_outputs"), exist_ok=True) + sr_img.save(os.path.join(COMPILE_DIR, "test_outputs", "test_sr.png")) + print(f"SR output saved, pixel std={sr_array.std():.1f}") + + +def test_warm_generation_time(pipeline): + """Warm generation should complete in < 2s.""" + test_img = make_test_image(LR_SIZE) + + t0 = time.time() + pipeline(test_img) + elapsed = time.time() - t0 + print(f"Warm generation time: {elapsed:.3f}s") + + assert elapsed < 2.0, f"Generation took {elapsed:.3f}s, expected < 2s" + + +# Standalone runner +if __name__ == "__main__": + print("=" * 60) + print("S3Diff Integration Tests") + print("=" * 60) + + pipe = S3DiffNeuronPipeline( + sd_turbo_path=SD_TURBO_PATH, + s3diff_weights_path=S3DIFF_WEIGHTS, + de_net_path=DE_NET_WEIGHTS, + compile_dir=COMPILE_DIR, + lr_size=LR_SIZE, + ) + + print("\n[1/5] Loading...") + pipe.load() + + print("\n[2/5] Compiling...") + t0 = time.time() + pipe.compile() + print(f" Compilation: {time.time() - t0:.1f}s") + + print("\n[3/5] Warmup...") + test_img = make_test_image() + pipe(test_img) + + print("\n[4/5] test_smoke_pipeline_loads") + test_smoke_pipeline_loads(pipe) + print(" PASSED") + + print("\n[5/5] test_sr_produces_correct_size") + test_sr_produces_correct_size(pipe) + print(" PASSED") + + print("\n[6/6] test_warm_generation_time") + test_warm_generation_time(pipe) + print(" PASSED") + + print("\nAll tests passed!") From 4758bab6dbbf82e0dbf465292ec37d21d22462f8 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Wed, 6 May 2026 08:26:18 -0400 Subject: [PATCH 2/4] Document S3Diff resolution limitation and torch.compile alternative trace() approach is validated at 128x128->512x512 only. Higher resolutions produce NaN or degraded output due to BF16 accumulation in LoRA einsum at larger spatial dims. Reference torch.compile alternative for multi-resolution (1K/2K/4K) use cases. --- contrib/models/S3Diff/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/contrib/models/S3Diff/README.md b/contrib/models/S3Diff/README.md index c0dfcbc0..f42952a8 100644 --- a/contrib/models/S3Diff/README.md +++ b/contrib/models/S3Diff/README.md @@ -130,7 +130,9 @@ python test/integration/test_model.py ## Known Issues -- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. The VAE encoder, UNet, and VAE decoder all use `--model-type=unet-inference` instead, which avoids this issue. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`. +- **Fixed resolution only (128x128 -> 512x512)**: This implementation uses `torch_neuronx.trace()` which compiles static tensor shapes. Each input resolution requires a separate compilation. The pipeline is validated at 128x128 input only. +- **Higher resolutions (256->1024, 512->2048, etc.) produce degraded or NaN output** when using trace(). The LoRA modulation einsum operations accumulate BF16 rounding errors at larger spatial dimensions. For multi-resolution or high-resolution (1K/2K/4K) super-resolution, use `torch.compile(backend="neuron")` on the UNet's `Transformer2DModel` blocks instead, with latent tiling for resolutions above 1K. See [xniwangaws/NeuronStuff/s3diff-benchmark](https://github.com/xniwangaws/NeuronStuff/tree/main/s3diff-benchmark) for a validated multi-resolution implementation using PyTorch Native (6.14s @ 1K, 60s @ 2K, 303s @ 4K). +- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. The VAE encoder, UNet, and VAE decoder all use `--model-type=unet-inference` instead, which avoids this issue at 128x128. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`. - **Compilation time**: ~21 minutes total (UNet is the slowest at ~12 min). Compiled models are cached for reuse. - **CFG is sequential**: Two separate UNet passes (positive + negative prompt), not batched. Batching with batch_size=2 would halve UNet wall time but requires recompilation. - **Neuron runtime HBM**: Once loaded, compiled models stay in HBM even if the Python object is deleted (within the same process). Plan memory accordingly. From 94c7731b414b70f70dcea5818a2d951f634315a8 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Wed, 6 May 2026 09:19:29 -0400 Subject: [PATCH 3/4] Add tiling support for arbitrary input resolutions Implements overlapping tile processing with Gaussian blending so that images larger than 512x512 HR (e.g., 256->1024, 512->2048) are handled without recompilation. All components remain compiled at fixed 512x512 tile size; larger images are split, processed per-tile, and blended. Benchmarks on trn2.3xlarge: - 128->512: 0.545s (single tile) - 256->1024: 4.8s (9 tiles) - 512->2048: 13.3s (25 tiles) Tests: 5/5 pass including new tiling tests. --- contrib/models/S3Diff/README.md | 65 +++-- contrib/models/S3Diff/src/generate_s3diff.py | 38 ++- contrib/models/S3Diff/src/modeling_s3diff.py | 249 +++++++++++++++--- .../S3Diff/test/integration/test_model.py | 53 +++- 4 files changed, 343 insertions(+), 62 deletions(-) diff --git a/contrib/models/S3Diff/README.md b/contrib/models/S3Diff/README.md index f42952a8..1d25e176 100644 --- a/contrib/models/S3Diff/README.md +++ b/contrib/models/S3Diff/README.md @@ -1,6 +1,7 @@ # Contrib Model: S3Diff S3Diff one-step 4x super-resolution on AWS Neuron using `torch_neuronx.trace()`. +Supports arbitrary input resolutions via tiling with Gaussian blending. ## Model Information @@ -19,16 +20,25 @@ S3Diff is unusual among diffusion models: 2. **Dynamic LoRA modulation**: A DEResNet encoder estimates input degradation and produces per-layer LoRA scaling matrices. These `[rank, rank]` modulation matrices are injected between `lora_A` and `lora_B` via einsum operations, conditioning the UNet on the specific degradation pattern of each input. 3. **Two LoRA ranks**: VAE uses rank=16 (6 blocks), UNet uses rank=32 (10 blocks). 4. **Small model**: Total size ~2 GB, fits on a single NeuronCore with no tensor parallelism needed. +5. **Arbitrary resolution via tiling**: All components are compiled at a fixed tile size (512x512 pixels). Images larger than this are split into overlapping tiles, processed independently, and blended with Gaussian weights for seamless output. This contrib uses `torch_neuronx.trace()` rather than NxDI tensor parallelism, which is appropriate for the model's small size and non-autoregressive architecture. ## Validation Results -**Validated:** 2026-04-28 +**Validated:** 2026-05-06 **Instance:** trn2.3xlarge (LNC=2) **SDK:** Neuron SDK 2.29 (DLAMI 20260410), PyTorch 2.9 -### Benchmark Results (128x128 -> 512x512, single step) +### Benchmark Results (multi-resolution, single step) + +| Input Size | Output Size | Tiles | Time | Throughput | +|-----------|-------------|-------|------|------------| +| 128x128 | 512x512 | 1 | 0.545s | 1.8 img/s | +| 256x256 | 1024x1024 | 9 | 4.809s | 0.21 img/s | +| 512x512 | 2048x2048 | 25 | 13.346s | 0.075 img/s | + +### Component Timing (single tile, 512x512 output) | Component | Time | |-----------|------| @@ -37,21 +47,18 @@ This contrib uses `torch_neuronx.trace()` rather than NxDI tensor parallelism, w | VAE Encode | 83.2ms | | UNet x2 (CFG) | 218.8ms | | VAE Decode | 164.6ms | -| **Total** | **0.471s** | +| **Total (single tile)** | **0.471s** | | Metric | Value | |--------|-------| -| Resolution | 128x128 -> 512x512 (4x SR) | | Inference steps | 1 (one-step model) | -| Warm generation time | 0.544s | -| Throughput | ~1.8 img/s | | Total compile time | ~21 min | -| CPU baseline | 11.53s | +| CPU baseline (128->512) | 11.53s | | Speedup vs CPU | ~21x | ### Accuracy Validation -Visual quality validated against CPU reference output. The model produces high-quality 4x upscaled images with correct degradation-aware enhancement. +Visual quality validated against CPU reference output. The model produces high-quality 4x upscaled images with correct degradation-aware enhancement. Tiled outputs show pixel std > 20 (indicating good visual detail) with seamless blending at tile boundaries. ## Usage @@ -64,24 +71,38 @@ pipeline = S3DiffNeuronPipeline( s3diff_weights_path="/shared/s3diff/s3diff.pkl", de_net_path="/shared/s3diff/de_net.pth", compile_dir="/tmp/s3diff/compiled/", - lr_size=128, + lr_size=128, # DEResNet fixed input (always 128) + tile_size=512, # HR tile size (default) + tile_overlap=128, # Overlap for blending (default) ) pipeline.load() pipeline.compile() -lr_image = Image.open("input_128x128.png").convert("RGB") -sr_image = pipeline(lr_image) -sr_image.save("output_512x512.png") +# Works with any input size -- tiling is automatic +lr_image = Image.open("input.png").convert("RGB") +sr_image = pipeline(lr_image) # 4x upscaled output +sr_image.save("output.png") ``` Or use the provided script: ```bash +# 128x128 -> 512x512 (single tile, fast) +python src/generate_s3diff.py \ + --input_image input_128.png \ + --output_image output_512.png + +# 256x256 -> 1024x1024 (tiled) +python src/generate_s3diff.py \ + --input_image input_256.png \ + --output_image output_1024.png + +# Custom tile settings python src/generate_s3diff.py \ - --download \ --input_image input.png \ --output_image output.png \ - --compile_dir /tmp/s3diff/compiled/ + --tile_size 512 \ + --tile_overlap 128 ``` ## Setup @@ -103,6 +124,17 @@ python src/generate_s3diff.py --download # cp /tmp/s3diff_repo/assets/mm-realsr/de_net.pth /shared/s3diff/ ``` +## Tiling Design + +For images whose HR output (input x 4) exceeds 512x512 pixels, the pipeline automatically: + +1. **Upscales** the input image to HR resolution via bicubic interpolation +2. **Splits** the HR image into overlapping 512x512 tiles (128px overlap by default) +3. **Processes** each tile independently through VAE encode -> UNet -> VAE decode +4. **Blends** tile outputs using Gaussian weights (smooth center-to-edge falloff) + +This approach avoids recompilation for different resolutions and produces seamless outputs. The degradation estimation (DEResNet) runs once on the full image resized to 128x128, producing global modulation parameters shared across all tiles. + ## Compatibility Matrix | Instance/Version | SDK 2.29 | SDK 2.28 | @@ -130,9 +162,8 @@ python test/integration/test_model.py ## Known Issues -- **Fixed resolution only (128x128 -> 512x512)**: This implementation uses `torch_neuronx.trace()` which compiles static tensor shapes. Each input resolution requires a separate compilation. The pipeline is validated at 128x128 input only. -- **Higher resolutions (256->1024, 512->2048, etc.) produce degraded or NaN output** when using trace(). The LoRA modulation einsum operations accumulate BF16 rounding errors at larger spatial dimensions. For multi-resolution or high-resolution (1K/2K/4K) super-resolution, use `torch.compile(backend="neuron")` on the UNet's `Transformer2DModel` blocks instead, with latent tiling for resolutions above 1K. See [xniwangaws/NeuronStuff/s3diff-benchmark](https://github.com/xniwangaws/NeuronStuff/tree/main/s3diff-benchmark) for a validated multi-resolution implementation using PyTorch Native (6.14s @ 1K, 60s @ 2K, 303s @ 4K). -- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. The VAE encoder, UNet, and VAE decoder all use `--model-type=unet-inference` instead, which avoids this issue at 128x128. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`. +- **LoRA + `--auto-cast=matmult` produces NaN**: The LoRA modulation einsum operations are numerically unstable when `--auto-cast=matmult` casts them to BF16. All components with LoRA use `--model-type=unet-inference` instead. Only DEResNet and text encoder (no LoRA) use `--auto-cast=matmult`. - **Compilation time**: ~21 minutes total (UNet is the slowest at ~12 min). Compiled models are cached for reuse. - **CFG is sequential**: Two separate UNet passes (positive + negative prompt), not batched. Batching with batch_size=2 would halve UNet wall time but requires recompilation. - **Neuron runtime HBM**: Once loaded, compiled models stay in HBM even if the Python object is deleted (within the same process). Plan memory accordingly. +- **Tiling artifacts at very high resolution**: At 4K+ output, very minor blending seams may be visible in uniform regions. Increasing `--tile_overlap` to 192 or 256 reduces this at the cost of more tiles. diff --git a/contrib/models/S3Diff/src/generate_s3diff.py b/contrib/models/S3Diff/src/generate_s3diff.py index db98bd14..eed8f6a7 100644 --- a/contrib/models/S3Diff/src/generate_s3diff.py +++ b/contrib/models/S3Diff/src/generate_s3diff.py @@ -5,7 +5,9 @@ S3Diff one-step 4x super-resolution on AWS Neuron. Downloads required weights (SD-Turbo, S3Diff LoRA, DEResNet), compiles all -components, and runs super-resolution inference. +components, and runs super-resolution inference. Supports arbitrary input +resolutions via tiling (images whose 4x upscaled size exceeds 512x512 are +automatically processed with overlapping tiles). Usage: python generate_s3diff.py \ @@ -13,6 +15,11 @@ --output_image /path/to/sr_output.png \ --compile_dir /tmp/s3diff/compiled/ + # Multi-resolution examples: + # 128x128 input -> 512x512 output (single tile, ~0.5s) + # 256x256 input -> 1024x1024 output (9 tiles, ~4.8s) + # 512x512 input -> 2048x2048 output (25 tiles, ~13.3s) + Requirements: pip install diffusers transformers peft accelerate torchvision """ @@ -84,7 +91,7 @@ def main(): "--input_image", type=str, default=None, - help="Path to input low-resolution image (128x128 recommended)", + help="Path to input low-resolution image (any size; will be 4x upscaled)", ) parser.add_argument( "--output_image", type=str, default="sr_output.png", help="Output path" @@ -95,6 +102,18 @@ def main(): parser.add_argument("--compile_dir", type=str, default=DEFAULT_COMPILE_DIR) parser.add_argument("--num_images", type=int, default=3) parser.add_argument("--warmup_rounds", type=int, default=5) + parser.add_argument( + "--tile_size", + type=int, + default=512, + help="Pixel-space tile size for VAE/UNet (default: 512). Must be divisible by 8.", + ) + parser.add_argument( + "--tile_overlap", + type=int, + default=128, + help="Pixel-space overlap between tiles (default: 128). Must be divisible by 8.", + ) parser.add_argument("--download", action="store_true", help="Download weights") args = parser.parse_args() @@ -110,15 +129,24 @@ def main(): lr_image = Image.fromarray(test_img) else: lr_image = Image.open(args.input_image).convert("RGB") - print(f"Input image: {lr_image.size}") - # Build pipeline + lr_w, lr_h = lr_image.size + hr_w, hr_h = lr_w * 4, lr_h * 4 + print(f"Input image: {lr_w}x{lr_h} -> Output: {hr_w}x{hr_h}") + if hr_h > args.tile_size or hr_w > args.tile_size: + print( + f"Tiling enabled (tile_size={args.tile_size}, overlap={args.tile_overlap})" + ) + + # Build pipeline (lr_size is always 128 for DEResNet) pipeline = S3DiffNeuronPipeline( sd_turbo_path=args.sd_turbo_path, s3diff_weights_path=args.s3diff_weights, de_net_path=args.de_net_weights, compile_dir=args.compile_dir, - lr_size=lr_image.size[0], + lr_size=128, + tile_size=args.tile_size, + tile_overlap=args.tile_overlap, ) print("\nLoading model...") diff --git a/contrib/models/S3Diff/src/modeling_s3diff.py b/contrib/models/S3Diff/src/modeling_s3diff.py index 5cc9dda4..d6eabdd1 100644 --- a/contrib/models/S3Diff/src/modeling_s3diff.py +++ b/contrib/models/S3Diff/src/modeling_s3diff.py @@ -28,10 +28,11 @@ from PIL import Image from torchvision import transforms -# Scale factor and padding constants +# Scale factor and tiling constants SF = 4 -PAD_H = 512 -PAD_W = 512 +TILE_SIZE = 512 # Pixel-space tile size (validated with trace()) +TILE_OVERLAP = 128 # Pixel-space overlap between tiles (must be divisible by 8) +LATENT_SCALE = 8 # VAE spatial downscale factor # --------------------------------------------------------------------------- # DEResNet -- degradation estimator @@ -253,6 +254,51 @@ def forward(self, latent): return self.vae.decode(latent / self.vae.config.scaling_factor).sample +# --------------------------------------------------------------------------- +# Tiling utilities +# --------------------------------------------------------------------------- + + +def _make_gaussian_weight(h: int, w: int) -> torch.Tensor: + """Create a 2D Gaussian blending weight mask for tile overlap regions. + + The weight is 1.0 at center and falls off toward edges, ensuring smooth + blending where tiles overlap. + """ + y = torch.linspace(-1, 1, h) + x = torch.linspace(-1, 1, w) + yy, xx = torch.meshgrid(y, x, indexing="ij") + # Gaussian with sigma that gives ~0.1 weight at edges + d = (xx**2 + yy**2) / 2.0 + weight = torch.exp(-d * 3.0) # sigma ~0.58, edge weight ~0.05 + return weight.unsqueeze(0).unsqueeze(0) # [1, 1, H, W] + + +def _compute_tile_positions(total_size: int, tile_size: int, overlap: int): + """Compute tile start positions along one dimension. + + Returns list of (start, end) tuples covering the full dimension + with the specified overlap between adjacent tiles. + """ + if total_size <= tile_size: + return [(0, total_size)] + + stride = tile_size - overlap + positions = [] + start = 0 + while start < total_size: + end = min(start + tile_size, total_size) + # If the last tile is too small, shift it back + if end - start < tile_size and start > 0: + start = total_size - tile_size + end = total_size + positions.append((start, end)) + if end >= total_size: + break + start += stride + return positions + + # --------------------------------------------------------------------------- # S3Diff Neuron Pipeline # --------------------------------------------------------------------------- @@ -261,14 +307,24 @@ def forward(self, latent): class S3DiffNeuronPipeline: """End-to-end S3Diff super-resolution pipeline on Neuron. - Handles model loading, compilation, and inference. + Handles model loading, compilation, and inference. Supports arbitrary + input resolutions via tiling: images larger than the tile size (512x512 + pixel HR) are split into overlapping tiles, processed independently, + and blended with Gaussian weights. Args: sd_turbo_path: Path to SD-Turbo checkpoint (stabilityai/sd-turbo) s3diff_weights_path: Path to s3diff.pkl weights de_net_path: Path to de_net.pth weights compile_dir: Directory for compiled model cache - lr_size: Input low-resolution size (default: 128) + lr_size: DEResNet input size (default: 128). The DEResNet is compiled + at this fixed size. Input images are resized to this before + degradation estimation. + tile_size: Pixel-space tile size for VAE/UNet (default: 512). + Must be divisible by 8. + tile_overlap: Pixel-space overlap between tiles (default: 128). + Must be divisible by 8. Larger overlap = smoother blending + but slower processing. """ def __init__( @@ -278,13 +334,19 @@ def __init__( de_net_path: str, compile_dir: str = "/tmp/s3diff/compiled/", lr_size: int = 128, + tile_size: int = TILE_SIZE, + tile_overlap: int = TILE_OVERLAP, ): self.sd_turbo_path = sd_turbo_path self.s3diff_weights_path = s3diff_weights_path self.de_net_path = de_net_path self.compile_dir = compile_dir self.lr_size = lr_size - self.hr_size = lr_size * SF + self.tile_size = tile_size + self.tile_overlap = tile_overlap + + assert tile_size % LATENT_SCALE == 0, "tile_size must be divisible by 8" + assert tile_overlap % LATENT_SCALE == 0, "tile_overlap must be divisible by 8" # Will be set during load() self.de_net_neuron = None @@ -479,13 +541,17 @@ def compute_modulation(deg_score): print("Model loaded successfully.") def compile(self): - """Compile all components with torch_neuronx.trace().""" + """Compile all components with torch_neuronx.trace(). + + Components are compiled at fixed tile_size (default 512x512 pixels). + Larger images are processed via tiling at inference time. + """ import torch_neuronx os.makedirs(self.compile_dir, exist_ok=True) lr_h, lr_w = self.lr_size, self.lr_size - hr_h, hr_w = self.hr_size, self.hr_size - lat_h, lat_w = hr_h // 8, hr_w // 8 + tile_h, tile_w = self.tile_size, self.tile_size + lat_h, lat_w = tile_h // LATENT_SCALE, tile_w // LATENT_SCALE # DEResNet path = os.path.join(self.compile_dir, "de_net.pt") @@ -530,7 +596,7 @@ def compile(self): self.vae_enc_neuron = torch_neuronx.trace( self._vae_enc_wrapper, ( - torch.randn(1, 3, hr_h, hr_w), + torch.randn(1, 3, tile_h, tile_w), torch.randn(1, 6, self.lora_rank_vae, self.lora_rank_vae), ), compiler_args=["--model-type=unet-inference", "-O1"], @@ -587,8 +653,12 @@ def __call__( ) -> Image.Image: """Run 4x super-resolution on a low-resolution PIL image. + Supports arbitrary input sizes via tiling. Images whose HR size + (input * 4) exceeds tile_size are automatically split into overlapping + tiles, processed independently, and blended with Gaussian weights. + Args: - lr_image: Input PIL image (should be lr_size x lr_size) + lr_image: Input PIL image (any size; will be 4x upscaled) pos_prompt: Positive text prompt neg_prompt: Negative text prompt cfg_scale: Classifier-free guidance scale @@ -599,29 +669,22 @@ def __call__( to_tensor = transforms.ToTensor() im_lr = to_tensor(lr_image).unsqueeze(0) - # Preprocess: resize 4x + normalize + pad + # Resize LR image for DEResNet (fixed lr_size) ori_h, ori_w = im_lr.shape[2:] - im_lr_resize = F.interpolate( + im_lr_for_de = F.interpolate( im_lr, - size=(ori_h * SF, ori_w * SF), + size=(self.lr_size, self.lr_size), mode="bilinear", align_corners=False, ) - im_lr_resize_norm = (im_lr_resize * 2 - 1.0).clamp(-1, 1) - resize_h, resize_w = im_lr_resize_norm.shape[2:] - im_lr_resize_norm = F.pad( - im_lr_resize_norm, - pad=(0, PAD_W - resize_w, 0, PAD_H - resize_h), - mode="reflect", - ) - # 1. DEResNet -> degradation scores - deg_score = self.de_net_neuron(im_lr) + # 1. DEResNet -> degradation scores (on fixed lr_size input) + deg_score = self.de_net_neuron(im_lr_for_de) - # 2. Compute modulation on CPU + # 2. Compute modulation on CPU (same for all tiles) vae_de_mod_all, unet_de_mod_all = self.compute_modulation(deg_score) - # 3. Text encoding + # 3. Text encoding (same for all tiles) pos_tokens = self.tokenizer( pos_prompt, max_length=self.tokenizer.model_max_length, @@ -639,21 +702,139 @@ def __call__( pos_enc = self.text_enc_neuron(pos_tokens) neg_enc = self.text_enc_neuron(neg_tokens) - # 4. VAE Encode (with de_mod) - latent = self.vae_enc_neuron(im_lr_resize_norm, vae_de_mod_all) + # Prepare HR image (4x bicubic upscale + normalize) + hr_h, hr_w = ori_h * SF, ori_w * SF + im_hr = F.interpolate( + im_lr, + size=(hr_h, hr_w), + mode="bilinear", + align_corners=False, + ) + im_hr_norm = (im_hr * 2 - 1.0).clamp(-1, 1) + + # Determine if tiling is needed + if hr_h <= self.tile_size and hr_w <= self.tile_size: + # Single tile path (pad to tile_size if needed) + pad_h = self.tile_size - hr_h + pad_w = self.tile_size - hr_w + if pad_h > 0 or pad_w > 0: + im_hr_norm = F.pad( + im_hr_norm, + pad=(0, pad_w, 0, pad_h), + mode="reflect", + ) + output = self._process_tile( + im_hr_norm, + vae_de_mod_all, + unet_de_mod_all, + pos_enc, + neg_enc, + cfg_scale, + ) + output = output[:, :, :hr_h, :hr_w] + else: + # Tiled path for large images + output = self._process_tiled( + im_hr_norm, + hr_h, + hr_w, + vae_de_mod_all, + unet_de_mod_all, + pos_enc, + neg_enc, + cfg_scale, + ) - # 5. UNet x2 for CFG (with de_mod) + return transforms.ToPILImage()((output[0] * 0.5 + 0.5).cpu().clamp(0, 1)) + + def _process_tile( + self, tile_pixels, vae_de_mod, unet_de_mod, pos_enc, neg_enc, cfg_scale + ): + """Process a single tile_size x tile_size pixel tile through the full pipeline.""" + # VAE Encode + latent = self.vae_enc_neuron(tile_pixels, vae_de_mod) + + # UNet x2 for CFG timestep = torch.tensor([999], dtype=torch.long) - pos_pred = self.unet_neuron(latent, timestep, pos_enc, unet_de_mod_all) - neg_pred = self.unet_neuron(latent, timestep, neg_enc, unet_de_mod_all) + pos_pred = self.unet_neuron(latent, timestep, pos_enc, unet_de_mod) + neg_pred = self.unet_neuron(latent, timestep, neg_enc, unet_de_mod) model_pred = neg_pred + cfg_scale * (pos_pred - neg_pred) - # 6. Scheduler step (CPU) + # Scheduler step (CPU) x_denoised = self.sched.step( model_pred.cpu(), torch.tensor([999]), latent.cpu(), return_dict=True ).prev_sample - # 7. VAE Decode + # VAE Decode output = self.vae_dec_neuron(x_denoised).clamp(-1, 1) - output = output[:, :, :resize_h, :resize_w] - return transforms.ToPILImage()((output[0] * 0.5 + 0.5).cpu().clamp(0, 1)) + return output + + def _process_tiled( + self, + im_hr_norm, + hr_h, + hr_w, + vae_de_mod, + unet_de_mod, + pos_enc, + neg_enc, + cfg_scale, + ): + """Process a large image via overlapping tiles with Gaussian blending.""" + tile_size = self.tile_size + overlap = self.tile_overlap + + # Compute tile positions + row_positions = _compute_tile_positions(hr_h, tile_size, overlap) + col_positions = _compute_tile_positions(hr_w, tile_size, overlap) + + # Prepare output accumulator and weight map + output_acc = torch.zeros(1, 3, hr_h, hr_w) + weight_acc = torch.zeros(1, 1, hr_h, hr_w) + + # Gaussian weight for blending + gauss_weight = _make_gaussian_weight(tile_size, tile_size) + + n_tiles = len(row_positions) * len(col_positions) + tile_idx = 0 + + for y_start, y_end in row_positions: + for x_start, x_end in col_positions: + tile_idx += 1 + th = y_end - y_start + tw = x_end - x_start + + # Extract tile (pad if at edge and smaller than tile_size) + tile = im_hr_norm[:, :, y_start:y_end, x_start:x_end] + pad_h = tile_size - th + pad_w = tile_size - tw + if pad_h > 0 or pad_w > 0: + tile = F.pad(tile, pad=(0, pad_w, 0, pad_h), mode="reflect") + + # Process tile + tile_output = self._process_tile( + tile, + vae_de_mod, + unet_de_mod, + pos_enc, + neg_enc, + cfg_scale, + ) + + # Crop back to actual tile dimensions + tile_output = tile_output[:, :, :th, :tw] + + # Blend with Gaussian weight + w = gauss_weight[:, :, :th, :tw] + output_acc[:, :, y_start:y_end, x_start:x_end] += tile_output.cpu() * w + weight_acc[:, :, y_start:y_end, x_start:x_end] += w + + if n_tiles > 1: + print( + f" Tile {tile_idx}/{n_tiles} " + f"[{y_start}:{y_end}, {x_start}:{x_end}]" + ) + + # Normalize by accumulated weights + output = output_acc / weight_acc.clamp(min=1e-8) + return output.clamp(-1, 1) diff --git a/contrib/models/S3Diff/test/integration/test_model.py b/contrib/models/S3Diff/test/integration/test_model.py index c00ad580..51274a6a 100644 --- a/contrib/models/S3Diff/test/integration/test_model.py +++ b/contrib/models/S3Diff/test/integration/test_model.py @@ -4,7 +4,8 @@ Tests: 1. test_smoke_pipeline_loads: Pipeline loads without errors 2. test_sr_produces_correct_size: 128x128 -> 512x512 -3. test_warm_generation_time: < 2s per image +3. test_warm_generation_time: < 2s per image (single tile) +4. test_tiled_sr_256_to_1024: 256x256 -> 1024x1024 via tiling Requirements: - trn2.3xlarge @@ -118,6 +119,38 @@ def test_warm_generation_time(pipeline): assert elapsed < 2.0, f"Generation took {elapsed:.3f}s, expected < 2s" +def test_tiled_sr_256_to_1024(pipeline): + """256x256 input produces 1024x1024 output via tiling.""" + test_img = make_test_image(256) + sr_img = pipeline(test_img) + + assert isinstance(sr_img, Image.Image) + assert sr_img.size == (1024, 1024), f"Expected (1024, 1024), got {sr_img.size}" + + # Verify reasonable pixel distribution (not blank/NaN) + sr_array = np.array(sr_img) + assert sr_array.shape == (1024, 1024, 3) + assert sr_array.std() > 10, "Output appears blank or uniform" + assert not np.any(np.isnan(sr_array.astype(np.float32))), "Output contains NaN" + + # Save for inspection + os.makedirs(os.path.join(COMPILE_DIR, "test_outputs"), exist_ok=True) + sr_img.save(os.path.join(COMPILE_DIR, "test_outputs", "test_sr_1024.png")) + print(f"Tiled SR output (1024x1024) saved, pixel std={sr_array.std():.1f}") + + +def test_tiled_sr_timing(pipeline): + """Tiled 256->1024 generation should complete in < 10s.""" + test_img = make_test_image(256) + + t0 = time.time() + pipeline(test_img) + elapsed = time.time() - t0 + print(f"Tiled generation time (256->1024): {elapsed:.3f}s") + + assert elapsed < 10.0, f"Tiled generation took {elapsed:.3f}s, expected < 10s" + + # Standalone runner if __name__ == "__main__": print("=" * 60) @@ -132,23 +165,23 @@ def test_warm_generation_time(pipeline): lr_size=LR_SIZE, ) - print("\n[1/5] Loading...") + print("\n[1/6] Loading...") pipe.load() - print("\n[2/5] Compiling...") + print("\n[2/6] Compiling...") t0 = time.time() pipe.compile() print(f" Compilation: {time.time() - t0:.1f}s") - print("\n[3/5] Warmup...") + print("\n[3/6] Warmup...") test_img = make_test_image() pipe(test_img) - print("\n[4/5] test_smoke_pipeline_loads") + print("\n[4/6] test_smoke_pipeline_loads") test_smoke_pipeline_loads(pipe) print(" PASSED") - print("\n[5/5] test_sr_produces_correct_size") + print("\n[5/6] test_sr_produces_correct_size") test_sr_produces_correct_size(pipe) print(" PASSED") @@ -156,4 +189,12 @@ def test_warm_generation_time(pipeline): test_warm_generation_time(pipe) print(" PASSED") + print("\n[7/7] test_tiled_sr_256_to_1024") + test_tiled_sr_256_to_1024(pipe) + print(" PASSED") + + print("\n[8/8] test_tiled_sr_timing") + test_tiled_sr_timing(pipe) + print(" PASSED") + print("\nAll tests passed!") From f8910f0f3ddf17916e2ae26fdb5d45eca5a90de0 Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Thu, 28 May 2026 14:58:28 -0400 Subject: [PATCH 4/4] Add 4K benchmark results and SDK 2.30 validation for S3Diff Add 1024x1024 -> 4096x4096 benchmark (121 tiles, 64.09s) to results table. Note that the original model was trained at 128->512 only; higher resolutions use tiled processing with unchanged compiled NEFFs and scale linearly. Also update compatibility matrix with SDK 2.30 validation. --- contrib/models/S3Diff/README.md | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/contrib/models/S3Diff/README.md b/contrib/models/S3Diff/README.md index 1d25e176..80d8779c 100644 --- a/contrib/models/S3Diff/README.md +++ b/contrib/models/S3Diff/README.md @@ -37,6 +37,12 @@ This contrib uses `torch_neuronx.trace()` rather than NxDI tensor parallelism, w | 128x128 | 512x512 | 1 | 0.545s | 1.8 img/s | | 256x256 | 1024x1024 | 9 | 4.809s | 0.21 img/s | | 512x512 | 2048x2048 | 25 | 13.346s | 0.075 img/s | +| 1024x1024 | 4096x4096 | 121 | 64.09s | 0.016 img/s | + +> **Note:** The original S3Diff model was trained on 128x128 → 512x512 crops only. Higher +> resolution outputs (2K, 4K) use tiled processing with the same compiled 512x512 NEFFs — +> no recompilation is needed. Per-tile latency is constant at ~0.530s regardless of total +> image size. Scaling is linear: latency = tiles x 0.530s + overhead. ### Component Timing (single tile, 512x512 output) @@ -97,12 +103,17 @@ python src/generate_s3diff.py \ --input_image input_256.png \ --output_image output_1024.png -# Custom tile settings +# 1024x1024 -> 4096x4096 (121 tiles, ~64s) +python src/generate_s3diff.py \ + --input_image input_1024.png \ + --output_image output_4096.png + +# Custom tile settings (increase overlap for fewer seams at 4K) python src/generate_s3diff.py \ --input_image input.png \ --output_image output.png \ --tile_size 512 \ - --tile_overlap 128 + --tile_overlap 192 ``` ## Setup @@ -137,9 +148,9 @@ This approach avoids recompilation for different resolutions and produces seamle ## Compatibility Matrix -| Instance/Version | SDK 2.29 | SDK 2.28 | +| Instance/Version | SDK 2.29 | SDK 2.30 | |------------------|----------|----------| -| trn2.3xlarge | VALIDATED | Not tested | +| trn2.3xlarge | VALIDATED | VALIDATED | ## Example Checkpoints