diff --git a/configs/td_config.yaml.example b/configs/td_config.yaml.example index 48a620a26..89248c3b8 100644 --- a/configs/td_config.yaml.example +++ b/configs/td_config.yaml.example @@ -78,6 +78,12 @@ engine_dir: "engines/td" # ControlNet configuration (disabled) use_controlnet: false +# cn_cache_interval: reuse CN residuals every N frames instead of recomputing each frame. +# 1 = disabled (default, always recompute). 2+ = skip forward on intermediate frames. +# Safe to change live; invalidated automatically on control-image or scale change. +# Note: cache key does NOT include t_index_list — avoid changing batch config mid-stream +# while caching is active. Low practical risk but noted. +cn_cache_interval: 1 # IPAdapter configuration (disabled) use_ipadapter: false diff --git a/demo/realtime-img2img/controlnet_registry.yaml b/demo/realtime-img2img/controlnet_registry.yaml index 7da63eee6..cd4352bec 100644 --- a/demo/realtime-img2img/controlnet_registry.yaml +++ b/demo/realtime-img2img/controlnet_registry.yaml @@ -55,7 +55,7 @@ available_controlnets: - id: "tile_sd15" name: "Tile/Feedback" model_id: "lllyasviel/control_v11f1e_sd15_tile" - default_preprocessor: "feedback" + default_preprocessor: "passthrough" default_scale: 0.6 description: "Uses image feedback for enhanced details" preprocessor_params: @@ -116,8 +116,27 @@ available_controlnets: - id: "tile_sdxl" name: "Tile/Feedback" model_id: "xinsir/controlnet-tile-sdxl-1.0" - default_preprocessor: "feedback" + default_preprocessor: "passthrough" default_scale: 0.6 description: "Uses image feedback for enhanced details (SDXL)" + preprocessor_params: + image_resolution: 512 + + - id: "depth_xinsir_sdxl" + name: "Depth Detection (xinsir)" + model_id: "xinsir/controlnet-depth-sdxl-1.0" + default_preprocessor: "depth_tensorrt" + default_scale: 0.8 + description: "Estimates depth information from images — xinsir SDXL variant" + preprocessor_params: + detect_resolution: 518 + image_resolution: 512 + + - id: "scribble_sdxl" + name: "Scribble" + model_id: "xinsir/controlnet-scribble-sdxl-1.0" + default_preprocessor: "scribble_tensorrt" + default_scale: 0.8 + description: "Produces sketch-like scribble edge conditioning (SDXL)" preprocessor_params: image_resolution: 512 \ No newline at end of file diff --git a/demo/realtime-img2img/util.py b/demo/realtime-img2img/util.py index 0ef21421b..e64ddc36a 100644 --- a/demo/realtime-img2img/util.py +++ b/demo/realtime-img2img/util.py @@ -29,21 +29,31 @@ def bytes_to_pil(image_bytes: bytes) -> Image.Image: def bytes_to_pt(image_bytes: bytes) -> torch.Tensor: """ - Convert JPEG/PNG bytes directly to PyTorch tensor using torchvision - + Convert JPEG bytes directly to a GPU float32 tensor via torchvision nvJPEG. + + Decodes on CUDA when available (nvJPEG path), eliminating the CPU decode + + host→device DMA transfer that the CPU path incurs. Falls back to CPU decode + on machines without CUDA. + Args: - image_bytes: Raw image bytes (JPEG/PNG format) - + image_bytes: Raw JPEG bytes (PNG bytes fall back to CPU automatically + since nvJPEG only handles JPEG) + + Returns: - torch.Tensor: Image tensor with shape (C, H, W), values in [0, 1], dtype float32 + torch.Tensor: Image tensor with shape (C, H, W), values in [0, 1], + dtype float32, on the same device as the decode. """ - # Convert bytes to tensor for torchvision byte_tensor = torch.frombuffer(image_bytes, dtype=torch.uint8) - - # Decode JPEG/PNG directly to tensor (C, H, W) format, uint8 [0, 255] - image_tensor = decode_jpeg(byte_tensor) - - # Convert to float32 and normalize to [0, 1] + + # Decode directly on GPU when CUDA is available — nvJPEG avoids the + # CPU decode + H2D copy incurred by the plain decode_jpeg(byte_tensor) call. + if torch.cuda.is_available(): + image_tensor = decode_jpeg(byte_tensor, device="cuda") + else: + image_tensor = decode_jpeg(byte_tensor) + + # Normalise to [0, 1] on the decode device (fused kernel on GPU). image_tensor = image_tensor.float() / 255.0 return image_tensor diff --git a/examples/benchmark/ab_bench.py b/examples/benchmark/ab_bench.py new file mode 100644 index 000000000..a59a1ed90 --- /dev/null +++ b/examples/benchmark/ab_bench.py @@ -0,0 +1,452 @@ +""" +ab_bench.py — A/B benchmark harness for StreamDiffusion performance work. + +Runs warmup + timed frames, records per-frame CUDA-event timings and per-region +profiler stats, and writes a JSON keyed by git-SHA + config-hash so before/after +runs can be compared without manual bookkeeping. + +USAGE +----- +# Bare-pipeline (inline defaults, no GPU_PROFILER env needed for frame timing): +python examples/benchmark/ab_bench.py + +# With a full config (includes ControlNet / IPAdapter / ESRGAN): +GPU_PROFILER=1 python examples/benchmark/ab_bench.py --config path/to/config.yaml + +# Selective config override: +GPU_PROFILER=1 python examples/benchmark/ab_bench.py \\ + --config configs/cn_tile.yaml \\ + --iterations 200 \\ + --warmup 20 \\ + --image /path/to/input.jpg \\ + --style-image /path/to/style.jpg \\ + --output-dir examples/benchmark/results + +# Save output frames as PNGs for visual before/after comparison: +python examples/benchmark/ab_bench.py --save-goldens --n-golden-frames 5 + +READING RESULTS +--------------- +Each run writes: + /__.json + +The JSON contains: + - "run": metadata (sha, config_hash, config_path, timestamp, iterations, warmup) + - "frame_ms": per-frame CUDA timings {p50, p95, p99, mean, min, max} + - "fps": fps stats derived from frame timings + - "regions": per-region profiler stats (only present when GPU_PROFILER=1) + +COMPARISON +---------- +Diff two runs: + python -c " + import json, sys + a, b = [json.load(open(p)) for p in sys.argv[1:3]] + for k in ['p50','p95','p99']: + va, vb = a['frame_ms'][k], b['frame_ms'][k] + print(f'frame {k}: {va:.2f} -> {vb:.2f} ms ({vb-va:+.2f} ms)') + " results/sha1_hash1_*.json results/sha2_hash2_*.json + +GPU_PROFILER env vars +--------------------- + GPU_PROFILER=1 - enable region timing + GPU_PROFILER_NVTX=0 - disable NVTX (required when CUDA graphs active) + GPU_PROFILER_EVENTS=1 - (default) CUDA-event timing +""" + +from __future__ import annotations + +import hashlib +import json +import subprocess +import sys +import time +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + + +if TYPE_CHECKING: + import PIL.Image # noqa: F401 — used in type annotations only + +import numpy as np +import torch +from tqdm import tqdm + + +# ── repo root on sys.path so streamdiffusion is importable without install ───── +_REPO_ROOT = Path(__file__).resolve().parents[2] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from streamdiffusion.tools.gpu_profiler import configure as _prof_configure # noqa: E402 +from streamdiffusion.tools.gpu_profiler import profiler # noqa: E402 + + +# ────────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────────── + + +def _git_sha() -> str: + """Return the current HEAD short SHA, or 'unknown' when git is unavailable.""" + try: + result = subprocess.run( + ["git", "rev-parse", "--short", "HEAD"], + capture_output=True, + text=True, + cwd=str(_REPO_ROOT), + ) + return result.stdout.strip() if result.returncode == 0 else "unknown" + except FileNotFoundError: + return "unknown" + + +def _config_hash(config: Dict[str, Any]) -> str: + """Return an 8-char SHA-1 of the sorted JSON-serialised config.""" + blob = json.dumps(config, sort_keys=True, default=str).encode() + return hashlib.sha1(blob).hexdigest()[:8] + + +def _make_synthetic_image(width: int, height: int) -> PIL.Image.Image: + """Create a solid grey PIL image so the benchmark runs without a real photo.""" + import PIL.Image + import PIL.ImageDraw + + img = PIL.Image.new("RGB", (width, height), color=(128, 128, 128)) + draw = PIL.ImageDraw.Draw(img) + # Add a simple pattern so it's not a zero tensor (avoids edge-case norms). + for x in range(0, width, 64): + draw.line([(x, 0), (x, height)], fill=(160, 160, 160), width=1) + for y in range(0, height, 64): + draw.line([(0, y), (width, y)], fill=(160, 160, 160), width=1) + return img + + +def _load_or_synth_image(path: Optional[str], width: int, height: int) -> PIL.Image.Image: + if path: + import PIL.Image + + return PIL.Image.open(path).convert("RGB").resize((width, height)) + return _make_synthetic_image(width, height) + + +def _percentile_stats(samples: List[float]) -> Dict[str, float]: + arr = np.array(samples) + return { + "mean": float(np.mean(arr)), + "p50": float(np.percentile(arr, 50)), + "p95": float(np.percentile(arr, 95)), + "p99": float(np.percentile(arr, 99)), + "min": float(np.min(arr)), + "max": float(np.max(arr)), + } + + +def _fps_stats(frame_ms_stats: Dict[str, float]) -> Dict[str, float]: + """Convert ms-per-frame stats to FPS stats (note: p50 ms → median FPS, etc.).""" + return {k: round(1000.0 / v, 2) if v > 0 else 0.0 for k, v in frame_ms_stats.items()} + + +# ────────────────────────────────────────────────────────────────────────────── +# Core benchmark loop +# ────────────────────────────────────────────────────────────────────────────── + + +def _to_pil(frame: Any) -> Optional[PIL.Image.Image]: + """Best-effort conversion of a pipeline output to PIL Image for golden saving.""" + import PIL.Image + + if isinstance(frame, PIL.Image.Image): + return frame + if isinstance(frame, np.ndarray): + arr = frame + if arr.dtype != np.uint8: + arr = (arr * 255).clip(0, 255).astype(np.uint8) + if arr.ndim == 3 and arr.shape[0] in (1, 3, 4): + arr = arr.transpose(1, 2, 0) # CHW → HWC + return PIL.Image.fromarray(arr.squeeze()) + if hasattr(frame, "cpu"): # torch.Tensor + arr = frame.cpu().float().numpy() + if arr.max() <= 1.0: + arr = (arr * 255).clip(0, 255).astype(np.uint8) + else: + arr = arr.clip(0, 255).astype(np.uint8) + if arr.ndim == 3 and arr.shape[0] in (1, 3, 4): + arr = arr.transpose(1, 2, 0) + return PIL.Image.fromarray(arr.squeeze()) + return None + + +def _run_loop( + stream: Any, + image_tensor: Any, + iterations: int, + warmup: int, + n_capture: int = 0, +) -> Tuple[List[float], List[Any]]: + """Warmup then time `iterations` frames. + + Returns + ------- + frame_times : list of per-frame ms values (length == iterations) + captured : first ``n_capture`` raw pipeline outputs (empty when n_capture=0) + """ + + # ── warmup (no timing) ───────────────────────────────────────────────── + print(f"[ab_bench] Warming up ({warmup} frames)…") + for _ in range(warmup): + stream(image=image_tensor) + + torch.cuda.synchronize() + + # ── timed loop ───────────────────────────────────────────────────────── + print(f"[ab_bench] Timing {iterations} frames…") + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + frame_times: List[float] = [] + captured: List[Any] = [] + + for _ in tqdm(range(iterations)): + start_evt.record() + output = stream(image=image_tensor) + end_evt.record() + torch.cuda.synchronize() + frame_times.append(start_evt.elapsed_time(end_evt)) + if n_capture > 0 and len(captured) < n_capture: + captured.append(output) + + return frame_times, captured + + +# ────────────────────────────────────────────────────────────────────────────── +# Public entry point +# ────────────────────────────────────────────────────────────────────────────── + + +def run( + config: Optional[str] = None, + iterations: int = 200, + warmup: int = 10, + image: Optional[str] = None, + style_image: Optional[str] = None, + output_dir: str = "examples/benchmark/results", + # ── bare-pipeline defaults (ignored when --config is provided) ───────── + model_id: str = "KBlueLeaf/kohaku-v2.1", + width: int = 512, + height: int = 512, + prompt: str = "1girl with brown dog hair, thick glasses, smiling", + negative_prompt: str = "bad image, bad quality", + acceleration: str = "tensorrt", + # ── behaviour flags ──────────────────────────────────────────────────── + gpu_profiler: bool = False, + nvtx: bool = False, + # ── golden capture ───────────────────────────────────────────────────── + save_goldens: bool = False, + n_golden_frames: int = 5, +) -> None: + """ + A/B benchmark harness for StreamDiffusion performance improvements. + + Parameters + ---------- + config : str, optional + Path to a StreamDiffusion YAML/JSON config file. When provided the + full config (including ControlNet, IPAdapter, ESRGAN) drives the run. + CLI flags below are ignored except --iterations/--warmup/--image/ + --style-image/--output-dir. + iterations : int + Number of timed frames (after warmup). Default 200. + warmup : int + Warmup frames (not timed). Overrides config.warmup when --config given. + Default 10. + image : str, optional + Path to input image. Synthetic grey image used when absent. + style_image : str, optional + Path to style/reference image for IPAdapter. Synthetic image used when + absent and an IPAdapter is active. + output_dir : str + Directory for JSON result files. Created if absent. + model_id : str + HuggingFace model ID for the bare-pipeline (no --config) path. + width / height : int + Image resolution for the bare-pipeline path. + prompt / negative_prompt : str + Prompts for the bare-pipeline path. + acceleration : str + Acceleration backend for the bare-pipeline path ("tensorrt", "xformers", "none"). + gpu_profiler : bool + Activate region-level profiling (equivalent to GPU_PROFILER=1). + NVTX is off by default; set --nvtx to enable (breaks CUDA graphs). + nvtx : bool + Enable NVTX markers (only useful for Nsight Systems; disable with CUDA graphs). + save_goldens : bool + Capture the first ``n_golden_frames`` output frames and save them as + PNG files alongside the JSON result. Useful for visual before/after + comparison after pipeline changes (e.g. the antialias resize fix). + Files are named ``__golden_NN.png`` in ``output_dir``. + n_golden_frames : int + Number of output frames to capture when ``--save-goldens`` is set. + Default 5. + """ + # ── activate profiler (env var takes precedence, flag is an alternative) ─ + _prof_configure(enabled=gpu_profiler, nvtx=nvtx) # reads GPU_PROFILER env internally + # After configure, profiler is either active or a null-op depending on env/flag. + + # ── build the stream ─────────────────────────────────────────────────── + bench_config: Dict[str, Any] = {} + + if config is not None: + # Config-file path (CN / IPA / ESRGAN configs live here) + from streamdiffusion.config import create_wrapper_from_config, load_config + + bench_config = load_config(config) + # CLI overrides for iteration control + if warmup != 10: + bench_config["warmup"] = warmup + print(f"[ab_bench] Building stream from config: {config}") + stream = create_wrapper_from_config(bench_config) + _width = bench_config.get("width", 512) + _height = bench_config.get("height", 512) + _has_ipa = bool(bench_config.get("ipadapters") or bench_config.get("use_ipadapter")) + else: + # Bare-pipeline inline defaults (no config file) + from streamdiffusion import StreamDiffusionWrapper + + bench_config = { + "model_id": model_id, + "width": width, + "height": height, + "acceleration": acceleration, + "warmup": warmup, + "prompt": prompt, + "negative_prompt": negative_prompt, + } + print(f"[ab_bench] Building bare-pipeline stream (model={model_id})") + stream = StreamDiffusionWrapper( + model_id_or_path=model_id, + t_index_list=[32, 45], + mode="img2img", + frame_buffer_size=1, + width=width, + height=height, + warmup=warmup, + acceleration=acceleration, + use_tiny_vae=True, + enable_similar_image_filter=False, + use_denoising_batch=True, + cfg_type="self", + seed=2, + ) + stream.prepare( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=50, + guidance_scale=1.2, + delta=0.5, + ) + _width, _height = width, height + _has_ipa = False + + # ── prepare input image ──────────────────────────────────────────────── + pil_image = _load_or_synth_image(image, _width, _height) + image_tensor = stream.preprocess_image(pil_image) + print(f"[ab_bench] Input: {'file ' + image if image else 'synthetic'} ({_width}×{_height})") + + # ── inject style image for IPAdapter if needed ───────────────────────── + if _has_ipa: + pil_style = _load_or_synth_image(style_image, _width, _height) + print(f"[ab_bench] Style: {'file ' + style_image if style_image else 'synthetic'}") + stream.update_style_image(pil_style) + + # ── run the loop ─────────────────────────────────────────────────────── + n_capture = n_golden_frames if save_goldens else 0 + frame_times, captured_frames = _run_loop(stream, image_tensor, iterations, warmup, n_capture) + + # ── flush profiler and collect region stats ──────────────────────────── + profiler.flush() + profiler.report() + + # ── compute stats ───────────────────────────────────────────────────── + frame_stats = _percentile_stats(frame_times) + fps_stats = _fps_stats(frame_stats) + + sha = _git_sha() + cfg_hash = _config_hash(bench_config) + timestamp = time.strftime("%Y%m%d_%H%M%S") + result_fname = f"{sha}_{cfg_hash}_{timestamp}.json" + + # ── collect region stats if profiler is active ───────────────────────── + region_stats: List[Dict] = [] + try: + inner = object.__getattribute__(profiler, "_inner") + if hasattr(inner, "_regions"): + region_stats = [s.to_dict() for s in inner._regions.values()] + except (AttributeError, TypeError): + pass + + result: Dict[str, Any] = { + "run": { + "git_sha": sha, + "config_hash": cfg_hash, + "config_path": config, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), + "iterations": iterations, + "warmup": warmup, + "width": _width, + "height": _height, + }, + "frame_ms": {k: round(v, 3) for k, v in frame_stats.items()}, + "fps": fps_stats, + "regions": region_stats, + } + + # ── write JSON ───────────────────────────────────────────────────────── + out_path = Path(output_dir) + out_path.mkdir(parents=True, exist_ok=True) + result_path = out_path / result_fname + + # Also export profiler's own JSON if active + if region_stats: + profiler.export_stats(str(out_path / f"{sha}_{cfg_hash}_{timestamp}_regions.json")) + + with open(result_path, "w") as fh: + json.dump(result, fh, indent=2) + + # ── save goldens ─────────────────────────────────────────────────────── + if save_goldens and captured_frames: + saved = 0 + for i, frame in enumerate(captured_frames): + pil = _to_pil(frame) + if pil is None: + print(f"[ab_bench] --save-goldens: cannot serialise frame {i} (type {type(frame).__name__}), skipping") + continue + golden_path = out_path / f"{sha}_{cfg_hash}_golden_{i:02d}.png" + pil.save(str(golden_path)) + saved += 1 + print(f"[ab_bench] {saved}/{len(captured_frames)} goldens saved to {out_path}/") + + # ── human-readable summary ───────────────────────────────────────────── + print() + print("=" * 60) + print(f"[ab_bench] Results ({iterations} frames, {warmup} warmup)") + print(f" SHA: {sha} cfg: {cfg_hash}") + print(f" frame p50: {frame_stats['p50']:.2f} ms ({fps_stats['p50']:.1f} FPS)") + print(f" frame p95: {frame_stats['p95']:.2f} ms ({fps_stats['p95']:.1f} FPS)") + print(f" frame p99: {frame_stats['p99']:.2f} ms ({fps_stats['p99']:.1f} FPS)") + print(f" frame mean: {frame_stats['mean']:.2f} ms ({fps_stats['mean']:.1f} FPS)") + if region_stats: + top = sorted(region_stats, key=lambda r: r.get("total_ms", 0), reverse=True)[:8] + print() + print(f" {'Region':<30} {'p50':>8} {'p95':>8} {'count':>6}") + print(" " + "-" * 58) + for r in top: + print(f" {r['name']:<30} {r['p50_ms']:>7.2f}ms {r['p95_ms']:>7.2f}ms {r['count']:>6}") + print() + print(f" JSON -> {result_path}") + print("=" * 60) + + +if __name__ == "__main__": + import fire + + fire.Fire(run) diff --git a/scripts/profiling/profile_nsys.py b/scripts/profiling/profile_nsys.py index c73107a09..c2e8f9036 100644 --- a/scripts/profiling/profile_nsys.py +++ b/scripts/profiling/profile_nsys.py @@ -76,6 +76,24 @@ action="store_true", help="Print subprocess commands without executing (td_main target only)", ) +parser.add_argument( + "--cn-scale", + type=float, + default=0.0, + metavar="SCALE", + help="[benchmark] ControlNet conditioning scale (0 = disabled, default). " + "When > 0, activates the first registered ControlNet at this scale using a dummy " + "gray control image. Lets you measure CN per-frame cost alongside the UNet baseline.", +) +parser.add_argument( + "--cn-cache-interval", + type=int, + default=1, + metavar="N", + help="[benchmark] ControlNet residual cache interval (default 1 = disabled). " + "N>1: CN forward runs once every N frames; residuals reused between. " + "Requires --cn-scale > 0.", +) parser.add_argument( "--config", default="", @@ -256,6 +274,33 @@ dummy_img = PIL.Image.new("RGB", (_WIDTH, _HEIGHT), (128, 128, 128)) +# ── ControlNet activation (--cn-scale > 0) ──────────────────────────────────── +if args.cn_scale > 0.0: + try: + cn_mod = getattr(stream.stream, "_controlnet_module", None) + if cn_mod is None: + raise RuntimeError("No _controlnet_module found — ensure config has a ControlNet") + # Set scale + cn_mod.update_controlnet_scale(0, args.cn_scale) + # update_control_image_efficient bails if _preprocessing_orchestrator is None (offline mode). + # Bypass it: directly inject a dummy control tensor ([1,3,H,W] fp16 on GPU) so the hook's + # 'img is not None' gate passes. prepare_frame_tensors will expand it to the right batch. + with cn_mod._collections_lock: + if len(cn_mod.controlnet_images) > 0: + dummy_cn = torch.ones(1, 3, _HEIGHT, _WIDTH, dtype=torch.float16, device="cuda") * 0.5 + cn_mod.controlnet_images[0] = dummy_cn + cn_mod._prepared_tensors = [] + cn_mod._images_version += 1 + else: + raise RuntimeError("ControlNet registered but controlnet_images list is empty") + print(f"[profile] ControlNet[0] enabled: scale={args.cn_scale}, image=dummy gray tensor {_WIDTH}x{_HEIGHT}") + if args.cn_cache_interval > 1: + cn_mod.set_cn_cache_interval(args.cn_cache_interval) + print(f"[profile] ControlNet residual cache: interval={args.cn_cache_interval} (CN forward every {args.cn_cache_interval} frames)") + except Exception as _cn_err: + print(f"[profile] WARNING: Could not activate ControlNet — {_cn_err}") + print(" Make sure the config includes a ControlNet and its engine is built.") + # ── Preprocess once ──────────────────────────────────────────────────────────── image_tensor = stream.preprocess_image(dummy_img) diff --git a/src/streamdiffusion/acceleration/sfast/__init__.py b/src/streamdiffusion/acceleration/sfast/__init__.py index 962a7f5aa..a804ddecf 100644 --- a/src/streamdiffusion/acceleration/sfast/__init__.py +++ b/src/streamdiffusion/acceleration/sfast/__init__.py @@ -24,8 +24,13 @@ def accelerate_with_stable_fast( config.enable_triton = True except ImportError: print("Triton not installed, skip") - # CUDA Graph is suggested for small batch sizes and small resolutions to reduce CPU overhead. - config.enable_cuda_graph = True + # CUDA Graph reduces CPU overhead for small batches/resolutions. + # Disable when the UNet is a TRT engine (which has its own CUDA-graph regime) + # to avoid double-capture overhead and potential replay conflicts. + # TRT engines expose `dump_profile`; standard nn.Module does not. + _unet = getattr(stream.pipe, "unet", None) + _trt_active = _unet is not None and hasattr(_unet, "dump_profile") + config.enable_cuda_graph = not _trt_active stream.pipe = compile(stream.pipe, config) stream.unet = stream.pipe.unet stream.vae = stream.pipe.vae diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py index dd3a97e0b..52a3109a6 100644 --- a/src/streamdiffusion/acceleration/tensorrt/builder.py +++ b/src/streamdiffusion/acceleration/tensorrt/builder.py @@ -11,6 +11,7 @@ from .models.models import BaseModel from .utilities import ( + BUILD_TRT_LOGGER, build_engine, export_onnx, optimize_onnx, @@ -288,7 +289,7 @@ def _quant_fn(): try: import tensorrt as trt - _rt = trt.Runtime(trt.Logger(trt.Logger.WARNING)) + _rt = trt.Runtime(BUILD_TRT_LOGGER) with open(engine_path, "rb") as _f: _eng = _rt.deserialize_cuda_engine(_f.read()) _insp = _eng.create_engine_inspector() diff --git a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py index ca01cc5c4..3abc3f004 100644 --- a/src/streamdiffusion/acceleration/tensorrt/engine_manager.py +++ b/src/streamdiffusion/acceleration/tensorrt/engine_manager.py @@ -114,6 +114,7 @@ def get_engine_path( fp8: bool = False, resolution: Optional[tuple] = None, builder_optimization_level: Optional[int] = None, + build_static_batch: Optional[bool] = None, ) -> Path: """ Generate engine path using wrapper.py's current logic. @@ -163,6 +164,14 @@ def get_engine_path( prefix += "--controlnet" if fp8: prefix += "--fp8v3" + # Encode the actual batch-profile policy so that a static-batch engine + # and a dynamic-batch engine never share the same directory. + # The capacity range (min_batch / max_batch above) is the same for both, + # so without this suffix a stale dynamic engine is silently reused after + # the static-batch switch — and TRT emits "l2tc doesn't take effect" + # because the loaded engine has a symbolic batch dim. + if build_static_batch is not None: + prefix += f"--sbatch{int(build_static_batch)}" prefix += optlvl_suffix diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py index b6b110946..f770416bd 100644 --- a/src/streamdiffusion/acceleration/tensorrt/utilities.py +++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py @@ -51,7 +51,76 @@ TRT_LOGGER = trt.Logger(trt.Logger.ERROR) -from ...model_detection import detect_model + +class _BuildLogFilter(trt.ILogger): + """Forwards TRT build messages to polygraphy's logger, dropping known-benign messages: + + - Myelin tactic-skip spam (TRT 10.x catches this exception, skips the tactic, and + still builds a correct engine). Counted so builds emit a one-line summary instead + of ~140 identical [E] lines per VAE engine. + - Logger-mismatch notice ("logger passed into createInferBuilder differs") — a + singleton bookkeeping warning with no effect on engine correctness. Counted so + users see a single summary line instead of repeated [W] noise. + """ + + # Myelin tactic-skip: ALL tokens must appear in the same message. + _BENIGN = ("setupProxyGraph", "g.nodes.size() == 0") + # Logger-mismatch singleton warning: any of these tokens suffices. + _BENIGN_WARN = ("logger passed into createInferBuilder differs",) + + def __init__(self, inner): + trt.ILogger.__init__(self) + self._inner = inner + self.suppressed = 0 + self.suppressed_warn = 0 + + def log(self, severity, msg): + if all(s in msg for s in self._BENIGN): + self.suppressed += 1 + return + if any(s in msg for s in self._BENIGN_WARN): + self.suppressed_warn += 1 + return + self._inner.log(severity, msg) + + +# Single shared instance. TensorRT registers ONE logger globally (first +# builder/runtime/refitter wins); reusing one instance for every trt.Builder, +# trt.Runtime, and trt.Refitter we create avoids the "logger differs from one +# already registered" warning while still filtering the benign myelin spam. +BUILD_TRT_LOGGER = _BuildLogFilter(TRT_LOGGER) + +_BUILD_LOGGER_REGISTERED = False + + +def _ensure_build_logger_registered() -> None: + """Force BUILD_TRT_LOGGER to win the global TRT logger registration race. + + TRT registers exactly ONE ILogger globally (first trt.Builder / trt.Runtime / + trt.Refitter wins via ``nvinfer1::getLogger()``). Calling this once — before any + polygraphy ``engine_from_bytes()`` or any other ``trt.Builder()`` — guarantees that + subsequent "logger differs" warnings (from loads that use TRT_LOGGER, or from + standalone compile tools with a fresh logger) route through BUILD_TRT_LOGGER and are + silently suppressed by its ``_BENIGN_WARN`` filter. + + Idempotent: the throwaway builder is created at most once per process. + """ + global _BUILD_LOGGER_REGISTERED + if _BUILD_LOGGER_REGISTERED: + return + _BUILD_LOGGER_REGISTERED = True + try: + trt.Builder(BUILD_TRT_LOGGER) # registers BUILD_TRT_LOGGER as the global TRT logger + except Exception: + pass # no CUDA device or TRT init failure — skip; filter still active for any msgs received + + +# Register on import so the first polygraphy engine_from_bytes() (which uses TRT_LOGGER) +# cannot claim the global slot before BUILD_TRT_LOGGER. +_ensure_build_logger_registered() + + +from ...model_detection import detect_model # noqa: E402 # --------------------------------------------------------------------------- @@ -534,7 +603,7 @@ def map_name(name): # Construct refit dictionary refit_dict = {} - refitter = trt.Refitter(self.engine, TRT_LOGGER) + refitter = trt.Refitter(self.engine, BUILD_TRT_LOGGER) all_weights = refitter.get_all() for layer_name, role in zip(all_weights[0], all_weights[1]): # for speciailized roles, use a unique name in the map: @@ -624,7 +693,9 @@ def build( # set_preview_feature, or SPARSE_WEIGHTS. We use the raw API (same as # the FP8 path) so all parameters are available for both precision paths. - build_logger = trt.Logger(trt.Logger.WARNING) + build_logger = BUILD_TRT_LOGGER + suppressed_before = build_logger.suppressed + suppressed_warn_before = build_logger.suppressed_warn builder = trt.Builder(build_logger) network_flags = 0 @@ -689,6 +760,18 @@ def build( serialized = builder.build_serialized_network(network, config) if serialized is None: raise RuntimeError(f"TRT FP16 engine build failed for {onnx_path}. Check TRT logs above for details.") + suppressed = build_logger.suppressed - suppressed_before + if suppressed: + logger.info( + f"[TRT Build] Suppressed {suppressed} benign myelin tactic-skip " + f"messages (TRT Error Code 9 / setupProxyGraph) — engine built normally." + ) + suppressed_warn = build_logger.suppressed_warn - suppressed_warn_before + if suppressed_warn: + logger.info( + f"[TRT Build] Suppressed {suppressed_warn} benign logger-mismatch " + f"notice(s) (createInferBuilder singleton warning) — no impact on engine." + ) with open(self.engine_path, "wb") as f: f.write(serialized) @@ -732,8 +815,9 @@ def _build_fp8( gpu_profile: Hardware-aware build parameters from detect_gpu_profile(). dynamic_shapes: Whether the engine uses dynamic input shapes. """ - build_logger = trt.Logger(trt.Logger.WARNING) - + build_logger = BUILD_TRT_LOGGER + suppressed_before = build_logger.suppressed + suppressed_warn_before = build_logger.suppressed_warn builder = trt.Builder(build_logger) # STRONGLY_TYPED: required for FP8. Tells TRT to use the data-type annotations @@ -804,6 +888,18 @@ def _build_fp8( serialized = builder.build_serialized_network(network, config) if serialized is None: raise RuntimeError(f"TRT FP8 engine build failed for {onnx_path}. Check TRT logs above for details.") + suppressed = build_logger.suppressed - suppressed_before + if suppressed: + logger.info( + f"[TRT Build] Suppressed {suppressed} benign myelin tactic-skip " + f"messages (TRT Error Code 9 / setupProxyGraph) — engine built normally." + ) + suppressed_warn = build_logger.suppressed_warn - suppressed_warn_before + if suppressed_warn: + logger.info( + f"[TRT Build] Suppressed {suppressed_warn} benign logger-mismatch " + f"notice(s) (createInferBuilder singleton warning) — no impact on engine." + ) with open(self.engine_path, "wb") as f: f.write(serialized) diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 001fda3db..de699faba 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -1,3 +1,4 @@ +import logging import os import sys import yaml @@ -5,6 +6,8 @@ from typing import Dict, List, Optional, Union, Any, Tuple from pathlib import Path +logger = logging.getLogger(__name__) + def load_config(config_path: Union[str, Path]) -> Dict[str, Any]: """Load StreamDiffusion configuration from YAML or JSON file""" config_path = Path(config_path) @@ -147,12 +150,19 @@ def _extract_wrapper_params(config: Dict[str, Any]) -> Dict[str, Any]: else: param_map['use_ipadapter'] = config.get('use_ipadapter', False) param_map['ipadapter_config'] = config.get('ipadapter_config') - + param_map['use_cached_attn'] = config.get('use_cached_attn', False) - + param_map['cache_maxframes'] = config.get('cache_maxframes', 1) param_map['cache_interval'] = config.get('cache_interval', 1) - + # cn_cache_interval: ControlNet residual reuse interval. + # 1 (default) = disabled, run CN every frame. + # N > 1 = run CN once every N frames; reuse residuals between (control latency = N-1 frames). + param_map['cn_cache_interval'] = config.get('cn_cache_interval', 1) + # max_cache_maxframes: allocation cap for the KVO/FI cache ring buffers (VRAM). + # cache_maxframes is the live logical write window; this is the hard upper bound. + param_map['max_cache_maxframes'] = config.get('max_cache_maxframes', 4) + # Pipeline hook configurations (Phase 4: Configuration Integration) hook_configs = _prepare_pipeline_hook_configs(config) param_map.update(hook_configs) @@ -206,6 +216,22 @@ def _prepare_controlnet_configs(config: Dict[str, Any]) -> List[Dict[str, Any]]: 'control_guidance_start': cn_config.get('control_guidance_start', 0.0), 'control_guidance_end': cn_config.get('control_guidance_end', 1.0), } + + # --- Profile knob injection --- + # Thread the active UI build profile into preprocessor_params so self-building + # TRT preprocessors (HED, Scribble, NormalBae) apply the same + # builder_optimization_level as the main UNet/VAE build. FP8 is flagged so + # the preprocessor can log a one-time info message and fall back to FP16. + # Per-CN overrides in the YAML take precedence over the top-level value. + pp = dict(controlnet_config["preprocessor_params"] or {}) + global_opt_level = config.get("builder_optimization_level") + if global_opt_level is not None and "builder_optimization_level" not in pp: + pp["builder_optimization_level"] = global_opt_level + if config.get("fp8", False) and "build_fp8" not in pp: + pp["build_fp8"] = True + if pp: + controlnet_config["preprocessor_params"] = pp + controlnet_configs.append(controlnet_config) return controlnet_configs diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py index 6c8655b13..d463f27fb 100644 --- a/src/streamdiffusion/modules/controlnet_module.py +++ b/src/streamdiffusion/modules/controlnet_module.py @@ -13,6 +13,7 @@ PreprocessingOrchestrator, ) from streamdiffusion.preprocessing.orchestrator_user import OrchestratorUser +from streamdiffusion.tools.gpu_profiler import profiler @dataclass @@ -69,6 +70,14 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) -> # Cache engine type detection to avoid repeated hasattr calls self._engine_type_cache: Dict[str, bool] = {} + # Residual caching: skip the CN engine forward for intermediate frames and reuse the last result. + # interval=1 disables caching (run every frame). interval=2 halves CN cost; control latency = 1 frame. + self._cn_cache_interval: int = 1 + self._cn_frame_counter: int = 0 + self._cn_cached_residuals: Optional[UnetKwargsDelta] = None + self._cn_cache_images_version: int = -1 + self._cn_cache_scale_hash: Optional[tuple] = None + # ---------- Public API (used by wrapper in a later step) ---------- def install(self, stream) -> None: self._stream = stream @@ -91,6 +100,11 @@ def install(self, stream) -> None: self._is_sdxl = None self._sdxl_conditioning_valid = False self._engine_type_cache.clear() + # Reset residual cache on re-install so stale tensors are never reused. + self._cn_frame_counter = 0 + self._cn_cached_residuals = None + self._cn_cache_images_version = -1 + self._cn_cache_scale_hash = None def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None: model = self._load_pytorch_controlnet_model(cfg.model_id, cfg.conditioning_channels) @@ -210,6 +224,22 @@ def update_controlnet_scale(self, index: int, scale: float) -> None: if 0 <= index < len(self.controlnet_scales): self.controlnet_scales[index] = float(scale) + def set_cn_cache_interval(self, n: int) -> None: + """Set the residual cache interval. + + interval=1 (default): disabled — CN runs every frame. + interval=N: CN runs once, residuals reused for the next N-1 frames, + then re-run, repeating. Control latency = N-1 frames. N=2 halves cost. + + Changing the interval resets the frame counter and drops any cached delta. + """ + n = max(1, int(n)) + with self._collections_lock: + if n != self._cn_cache_interval: + self._cn_cache_interval = n + self._cn_frame_counter = 0 + self._cn_cached_residuals = None + def update_controlnet_enabled(self, index: int, enabled: bool) -> None: with self._collections_lock: if 0 <= index < len(self.enabled_list): @@ -406,9 +436,25 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: if enabled: active_data.append((cn, img, scale, i)) + # Snapshot invalidation keys for residual cache (captured under lock for consistency). + curr_images_version = self._images_version + scale_hash = tuple(self.controlnet_scales) + if not active_data: return UnetKwargsDelta() + # Residual cache hit: reuse the last forward result when control + # inputs are unchanged and this is an intermediate frame. + if ( + self._cn_cache_interval > 1 + and self._cn_cached_residuals is not None + and self._cn_cache_images_version == curr_images_version + and self._cn_cache_scale_hash == scale_hash + and self._cn_frame_counter % self._cn_cache_interval != 0 + ): + self._cn_frame_counter += 1 + return self._cn_cached_residuals + # Cache TRT engines lookup to avoid rebuilding every frame if not self._engines_cache_valid: self._engines_by_id.clear() @@ -452,82 +498,93 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: prepared_images = self._prepared_tensors for cn, img, scale, idx_i in active_data: - # Swap to TRT engine if available for this model_id (use cached lookup) - model_id = getattr(cn, 'model_id', None) - if model_id and model_id in self._engines_by_id: - cn = self._engines_by_id[model_id] - - # Use pre-prepared tensor - current_img = prepared_images[idx_i] if idx_i < len(prepared_images) else img - if current_img is None: - continue - - # Check if this is TensorRT engine (use cached result to avoid repeated hasattr calls) - cache_key = id(cn) # Use object id as unique identifier - if cache_key in self._engine_type_cache: - is_trt_engine = self._engine_type_cache[cache_key] - else: - is_trt_engine = hasattr(cn, 'engine') and hasattr(cn, 'stream') - self._engine_type_cache[cache_key] = is_trt_engine - - # Get optimized SDXL conditioning (uses caching to avoid repeated tensor operations) - added_cond_kwargs = self._get_cached_sdxl_conditioning(ctx) - - try: - if is_trt_engine: - # TensorRT engine path - if added_cond_kwargs: - down_samples, mid_sample = cn( - sample=x_t, - timestep=t_list, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=current_img, - conditioning_scale=float(scale), - **added_cond_kwargs - ) - else: - down_samples, mid_sample = cn( - sample=x_t, - timestep=t_list, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=current_img, - conditioning_scale=float(scale) - ) + with profiler.region("cn.prep"): + # Swap to TRT engine if available for this model_id (use cached lookup) + model_id = getattr(cn, "model_id", None) + if model_id and model_id in self._engines_by_id: + cn = self._engines_by_id[model_id] + + # Use pre-prepared tensor + current_img = prepared_images[idx_i] if idx_i < len(prepared_images) else img + if current_img is None: + continue + + # Check if this is TensorRT engine (use cached result to avoid repeated hasattr calls) + cache_key = id(cn) # Use object id as unique identifier + if cache_key in self._engine_type_cache: + is_trt_engine = self._engine_type_cache[cache_key] else: - # PyTorch ControlNet path - if added_cond_kwargs: - down_samples, mid_sample = cn( - sample=x_t, - timestep=t_list, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=current_img, - conditioning_scale=float(scale), - return_dict=False, - added_cond_kwargs=added_cond_kwargs - ) + is_trt_engine = hasattr(cn, "engine") and hasattr(cn, "stream") + self._engine_type_cache[cache_key] = is_trt_engine + + # Get optimized SDXL conditioning (uses caching to avoid repeated tensor operations) + added_cond_kwargs = self._get_cached_sdxl_conditioning(ctx) + + with profiler.region("cn.forward"): + try: + if is_trt_engine: + # TensorRT engine path + if added_cond_kwargs: + down_samples, mid_sample = cn( + sample=x_t, + timestep=t_list, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=current_img, + conditioning_scale=float(scale), + **added_cond_kwargs, + ) + else: + down_samples, mid_sample = cn( + sample=x_t, + timestep=t_list, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=current_img, + conditioning_scale=float(scale), + ) else: - down_samples, mid_sample = cn( - sample=x_t, - timestep=t_list, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=current_img, - conditioning_scale=float(scale), - return_dict=False + # PyTorch ControlNet path + if added_cond_kwargs: + down_samples, mid_sample = cn( + sample=x_t, + timestep=t_list, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=current_img, + conditioning_scale=float(scale), + return_dict=False, + added_cond_kwargs=added_cond_kwargs, + ) + else: + down_samples, mid_sample = cn( + sample=x_t, + timestep=t_list, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=current_img, + conditioning_scale=float(scale), + return_dict=False, + ) + except Exception as e: + import traceback + + __import__("logging").getLogger(__name__).error( + "ControlNetModule: controlnet forward failed: %s", e + ) + try: + __import__("logging").getLogger(__name__).error( + "ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s", + ( + tuple(encoder_hidden_states.shape) + if isinstance(encoder_hidden_states, torch.Tensor) + else None + ), + (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None), + scale, + self._is_sdxl, + is_trt_engine, ) - except Exception as e: - import traceback - __import__('logging').getLogger(__name__).error("ControlNetModule: controlnet forward failed: %s", e) - try: - __import__('logging').getLogger(__name__).error("ControlNetModule: call_summary: cond_shape=%s, img_shape=%s, scale=%s, is_sdxl=%s, is_trt=%s", - (tuple(encoder_hidden_states.shape) if isinstance(encoder_hidden_states, torch.Tensor) else None), - (tuple(current_img.shape) if isinstance(current_img, torch.Tensor) else None), - scale, - self._is_sdxl, - is_trt_engine) - except Exception: - pass - __import__('logging').getLogger(__name__).error(traceback.format_exc()) - continue + except Exception: + pass + __import__("logging").getLogger(__name__).error(traceback.format_exc()) + continue down_samples_list.append(down_samples) mid_samples_list.append(mid_sample) @@ -535,23 +592,30 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta: return UnetKwargsDelta() if len(down_samples_list) == 1: - return UnetKwargsDelta( + _result = UnetKwargsDelta( down_block_additional_residuals=down_samples_list[0], mid_block_additional_residual=mid_samples_list[0], ) + else: + # Merge multiple ControlNet residuals + merged_down = down_samples_list[0] + merged_mid = mid_samples_list[0] + for ds, ms in zip(down_samples_list[1:], mid_samples_list[1:]): + for j in range(len(merged_down)): + merged_down[j] = merged_down[j] + ds[j] + merged_mid = merged_mid + ms + _result = UnetKwargsDelta( + down_block_additional_residuals=merged_down, + mid_block_additional_residual=merged_mid, + ) - # Merge multiple ControlNet residuals - merged_down = down_samples_list[0] - merged_mid = mid_samples_list[0] - for ds, ms in zip(down_samples_list[1:], mid_samples_list[1:]): - for j in range(len(merged_down)): - merged_down[j] = merged_down[j] + ds[j] - merged_mid = merged_mid + ms - - return UnetKwargsDelta( - down_block_additional_residuals=merged_down, - mid_block_additional_residual=merged_mid, - ) + # Residual cache write: store result for reuse on upcoming intermediate frames. + if self._cn_cache_interval > 1: + self._cn_cached_residuals = _result + self._cn_cache_images_version = curr_images_version + self._cn_cache_scale_hash = scale_hash + self._cn_frame_counter += 1 + return _result return _unet_hook diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 55922e695..b78d376d6 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -503,6 +503,13 @@ def prepare( self.stock_noise = torch.zeros_like(self.init_noise) + # Ping-pong buffers for stock_noise rotation. + # Replaces the per-frame torch.cat((init_noise[0:1], stock_noise[:-1])) + # which allocates a new tensor every frame. Two preallocated buffers + # are alternated so src and dst never alias. + self._stock_noise_bufs = [self.stock_noise.clone(), torch.empty_like(self.stock_noise)] + self._stock_noise_pong = 0 # next write-target index (0 or 1) + # Handle scheduler-specific scaling calculations c_skip_list = [] c_out_list = [] @@ -553,6 +560,21 @@ def prepare( self.c_skip = self.c_skip.to(self.device) self.c_out = self.c_out.to(self.device) + # Precompute per-step expanded timestep tensors for the TCD / non-batched sequential loop. + # Avoids per-step t.view(1).repeat(frame_bff_size) tensor allocations inside predict_x0_batch. + # Only valid when sub_timesteps_tensor is a 1-D sequence (not the collapsed-scalar LCM path). + _use_seq_loop = not (self.use_denoising_batch and isinstance(self.scheduler, LCMScheduler)) + if _use_seq_loop and self.sub_timesteps_tensor.dim() >= 1: + self._sub_timesteps_expanded = ( + self.sub_timesteps_tensor + .view(-1) + .unsqueeze(1) + .expand(-1, self.frame_bff_size) + .contiguous() + ) # shape [loop_steps, frame_bff_size] + else: + self._sub_timesteps_expanded = None + # Pre-compute shifted alpha/beta/init_noise (eliminates 5 mallocs + 8 kernel launches per frame) if self.use_denoising_batch and (self.cfg_type == "self" or self.cfg_type == "initialize"): self._alpha_next = torch.cat( @@ -946,7 +968,14 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: self._combined_latent_buf[: self.frame_bff_size].copy_(x_t_latent) self._combined_latent_buf[self.frame_bff_size :].copy_(prev_latent_batch) x_t_latent = self._combined_latent_buf - self.stock_noise = torch.cat((self.init_noise[0:1], self.stock_noise[:-1]), dim=0) + # Ping-pong rotation: eliminates torch.cat malloc every frame. + # _stock_noise_bufs[pong] is always a different tensor from self.stock_noise, + # so src and dst never alias. + _sn_dst = self._stock_noise_bufs[self._stock_noise_pong] + _sn_dst[0].copy_(self.init_noise[0]) + _sn_dst[1:].copy_(self.stock_noise[:-1]) + self.stock_noise = _sn_dst + self._stock_noise_pong = 1 - self._stock_noise_pong with profiler.region("unet_step"): x_0_pred_batch, model_pred = self.unet_step(x_t_latent, t_list) @@ -966,20 +995,20 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: # Standard scheduler loop for TCD and non-batched LCM sample = x_t_latent for idx, timestep in enumerate(self.sub_timesteps_tensor): - # Ensure timestep tensor on device with correct dtype - if not isinstance(timestep, torch.Tensor): - t = torch.tensor(timestep, device=self.device, dtype=torch.long) + # Resolve scalar timestep tensor on device. + # Avoid per-step t.view(1).repeat() allocations by using the + # precomputed _sub_timesteps_expanded table from prepare(). + t = timestep if isinstance(timestep, torch.Tensor) else torch.tensor( + timestep, device=self.device, dtype=torch.long + ) + if self._sub_timesteps_expanded is not None: + t_expanded = self._sub_timesteps_expanded[idx] # [frame_bff_size], no alloc else: - t = timestep.to(self.device) + t_expanded = t.view(1).repeat(self.frame_bff_size) # fallback # For TCD, use the same UNet calling logic as LCM to ensure ControlNet hooks are processed if isinstance(self.scheduler, TCDScheduler): # Use unet_step to process ControlNet hooks and get proper noise prediction - t_expanded = t.view( - 1, - ).repeat( - self.frame_bff_size, - ) with profiler.region("unet_step"): x_0_pred, model_pred = self.unet_step(sample, t_expanded, idx) @@ -991,13 +1020,8 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor: ) else: # Original LCM logic for non-batched mode - t = t.view( - 1, - ).repeat( - self.frame_bff_size, - ) with profiler.region("unet_step"): - x_0_pred, model_pred = self.unet_step(sample, t, idx) + x_0_pred, model_pred = self.unet_step(sample, t_expanded, idx) if idx < len(self.sub_timesteps_tensor) - 1: if self.do_add_noise: if self._noise_buf is None: diff --git a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py index cf05fbfaa..5e7b18cf9 100644 --- a/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py +++ b/src/streamdiffusion/preprocessing/preprocessing_orchestrator.py @@ -455,10 +455,9 @@ def _process_single_ipadapter(self, } return None - - except Exception as e: - import traceback - traceback.print_exc() + + except Exception: + logger.exception("PreprocessingOrchestrator: IPAdapter preprocessor %d failed", index) return None #Helper methods @@ -788,14 +787,31 @@ def _process_single_preprocessor_group(self, processed_image = self.prepare_control_image( control_variants['tensor'], preprocessor, stream_width, stream_height ) - return { - 'prep_key': prep_key, - 'indices': indices, - 'processed_image': processed_image - } - except Exception: - pass # Fall through to PIL processing - + return {"prep_key": prep_key, "indices": indices, "processed_image": processed_image} + except Exception as tensor_exc: + # Lazy-init dedup set to suppress per-frame log spam. + if not hasattr(self, "_tensor_fail_warned"): + self._tensor_fail_warned = set() + if prep_key not in self._tensor_fail_warned: + logger.warning( + "PreprocessingOrchestrator: tensor path failed for '%s': %s", + prep_key, + tensor_exc, + ) + logger.debug( + "PreprocessingOrchestrator: tensor path traceback for '%s'", + prep_key, + exc_info=True, + ) + self._tensor_fail_warned.add(prep_key) + + # GPU-native preprocessors (e.g. self-building TRT) route the PIL + # path through the same _process_tensor_core — the fallback would + # fail identically. Skip it and surface the error immediately. + if getattr(preprocessor, "gpu_native", False): + return None + + # PIL processing fallback if control_variants['image'] is not None: processed_image = self.prepare_control_image( @@ -810,6 +826,6 @@ def _process_single_preprocessor_group(self, return None except Exception as e: - logger.error(f"PreprocessingOrchestrator: Preprocessor {prep_key} failed: {e}") + logger.exception("PreprocessingOrchestrator: Preprocessor '%s' failed: %s", prep_key, e) return None diff --git a/src/streamdiffusion/preprocessing/processors/__init__.py b/src/streamdiffusion/preprocessing/processors/__init__.py index 5674ab4af..adf53e158 100644 --- a/src/streamdiffusion/preprocessing/processors/__init__.py +++ b/src/streamdiffusion/preprocessing/processors/__init__.py @@ -13,6 +13,7 @@ from .faceid_embedding import FaceIDEmbeddingPreprocessor from .feedback import FeedbackPreprocessor from .latent_feedback import LatentFeedbackPreprocessor +from .scribble import ScribblePreprocessor from .sharpen import SharpenPreprocessor from .upscale import UpscalePreprocessor from .blur import BlurPreprocessor @@ -33,6 +34,30 @@ YoloNasPoseTensorrtPreprocessor = None POSE_TENSORRT_AVAILABLE = False +try: + from .hed_tensorrt import HEDTensorrtPreprocessor + + HED_TENSORRT_AVAILABLE = True +except ImportError: + HEDTensorrtPreprocessor = None + HED_TENSORRT_AVAILABLE = False + +try: + from .scribble_tensorrt import ScribbleTensorrtPreprocessor + + SCRIBBLE_TENSORRT_AVAILABLE = True +except ImportError: + ScribbleTensorrtPreprocessor = None + SCRIBBLE_TENSORRT_AVAILABLE = False + +try: + from .normal_bae_tensorrt import NormalBaeTensorrtPreprocessor + + NORMAL_BAE_TENSORRT_AVAILABLE = True +except ImportError: + NormalBaeTensorrtPreprocessor = None + NORMAL_BAE_TENSORRT_AVAILABLE = False + try: from .temporal_net_tensorrt import TemporalNetTensorRTPreprocessor TEMPORAL_NET_TENSORRT_AVAILABLE = True @@ -65,6 +90,7 @@ "external": ExternalPreprocessor, "soft_edge": SoftEdgePreprocessor, "hed": HEDPreprocessor, + "scribble": ScribblePreprocessor, "feedback": FeedbackPreprocessor, "latent_feedback": LatentFeedbackPreprocessor, "sharpen": SharpenPreprocessor, @@ -90,6 +116,16 @@ if MEDIAPIPE_SEGMENTATION_AVAILABLE: _preprocessor_registry["mediapipe_segmentation"] = MediaPipeSegmentationPreprocessor +# Add GPU-native TRT ControlNet preprocessors (HED, Scribble, NormalBae) +if HED_TENSORRT_AVAILABLE: + _preprocessor_registry["hed_tensorrt"] = HEDTensorrtPreprocessor + +if SCRIBBLE_TENSORRT_AVAILABLE: + _preprocessor_registry["scribble_tensorrt"] = ScribbleTensorrtPreprocessor + +if NORMAL_BAE_TENSORRT_AVAILABLE: + _preprocessor_registry["normal_bae_tensorrt"] = NormalBaeTensorrtPreprocessor + def get_preprocessor_class(name: str) -> type: """ @@ -168,6 +204,7 @@ def list_preprocessors(): "ExternalPreprocessor", "SoftEdgePreprocessor", "HEDPreprocessor", + "ScribblePreprocessor", "IPAdapterEmbeddingPreprocessor", "FaceIDEmbeddingPreprocessor", "FeedbackPreprocessor", @@ -193,6 +230,15 @@ def list_preprocessors(): if MEDIAPIPE_SEGMENTATION_AVAILABLE: __all__.append("MediaPipeSegmentationPreprocessor") +if HED_TENSORRT_AVAILABLE: + __all__.append("HEDTensorrtPreprocessor") + +if SCRIBBLE_TENSORRT_AVAILABLE: + __all__.append("ScribbleTensorrtPreprocessor") + +if NORMAL_BAE_TENSORRT_AVAILABLE: + __all__.append("NormalBaeTensorrtPreprocessor") + # region Custom Processor Discovery import logging diff --git a/src/streamdiffusion/preprocessing/processors/base.py b/src/streamdiffusion/preprocessing/processors/base.py index 218a459f5..71fdcf9f1 100644 --- a/src/streamdiffusion/preprocessing/processors/base.py +++ b/src/streamdiffusion/preprocessing/processors/base.py @@ -1,17 +1,28 @@ +import logging from abc import ABC, abstractmethod -from typing import Union, Dict, Any, Tuple, Optional +from typing import Any, Dict, Optional, Set, Tuple, Union import torch import torch.nn.functional as F import numpy as np from PIL import Image +from streamdiffusion.tools.gpu_profiler import profiler + + +_pil_fallback_warned: Set[str] = set() # per-class warning dedup +_base_logger = logging.getLogger(__name__) + class BasePreprocessor(ABC): """ Base class for ControlNet preprocessors with template method pattern """ - - + + # Set to True on subclasses whose _process_tensor_core path is genuinely GPU-native + # (i.e. does NOT call tensor_to_pil / pil_to_tensor or any CPU op). + # Used by the residency guard test and the one-time PIL-fallback warning below. + gpu_native: bool = False + def __init__(self, normalization_context: str = 'controlnet', **kwargs): """ Initialize the preprocessor @@ -53,7 +64,8 @@ def process(self, image: Union[Image.Image, np.ndarray, torch.Tensor]) -> Image. Template method - handles all common operations """ image = self.validate_input(image) - processed = self._process_core(image) + with profiler.region("proc.core"): + processed = self._process_core(image) return self._ensure_target_size(processed) def process_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor: @@ -61,7 +73,8 @@ def process_tensor(self, image_tensor: torch.Tensor) -> torch.Tensor: Template method for GPU tensor processing """ tensor = self.validate_tensor_input(image_tensor) - processed = self._process_tensor_core(tensor) + with profiler.region("proc.core"): + processed = self._process_tensor_core(tensor) return self._ensure_target_size_tensor(processed) @abstractmethod @@ -73,12 +86,29 @@ def _process_core(self, image: Image.Image) -> Image.Image: def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """ - Optional GPU processing (fallback to PIL if not overridden) + Optional GPU processing (fallback to PIL if not overridden). + + D8 residency guard: emits a one-time per-class warning when this + fallback fires so that silent CPU round-trips surface immediately. + Subclasses that are genuinely GPU-native must override this method + AND set `gpu_native = True` on the class. """ - pil_image = self.tensor_to_pil(tensor) + cls_name = type(self).__name__ + if cls_name not in _pil_fallback_warned: + _pil_fallback_warned.add(cls_name) + _base_logger.warning( + f"[GPU-residency] {cls_name}._process_tensor_core is using the PIL " + "fallback (tensor → CPU → PIL → _process_core → tensor). " + "Set gpu_native=True and override _process_tensor_core to eliminate " + "this CPU round-trip. (This warning fires once per class.)" + ) + with profiler.region("proc.tensor_to_pil"): + pil_image = self.tensor_to_pil(tensor) processed_pil = self._process_core(pil_image) - return self.pil_to_tensor(processed_pil) - + with profiler.region("proc.pil_to_tensor"): + return self.pil_to_tensor(processed_pil) + + def _ensure_target_size(self, image: Image.Image) -> Image.Image: """ Centralized PIL resize logic @@ -99,7 +129,9 @@ def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor: if current_size != target_size: if tensor.dim() == 3: tensor = tensor.unsqueeze(0) - tensor = F.interpolate(tensor, size=target_size, mode='bilinear', align_corners=False) + # antialias=True applies a Gaussian pre-filter before downscaling, reducing + # aliasing artifacts in ControlNet conditioning maps (no-op when upscaling). + tensor = F.interpolate(tensor, size=target_size, mode="bilinear", align_corners=False, antialias=True) if tensor.shape[0] == 1: tensor = tensor.squeeze(0) return tensor @@ -125,13 +157,17 @@ def validate_tensor_input(self, image_tensor: torch.Tensor) -> torch.Tensor: if image_tensor.dim() == 3 and image_tensor.shape[0] not in [1, 3]: # Likely HWC format, convert to CHW image_tensor = image_tensor.permute(2, 0, 1) - + + # Normalize to [0,1] range only if tensor is uint8 [0,255]. + # Check dtype BEFORE the .to() cast so we don't force a D2H sync via + # .max() > 1.0 (scalar reduction onto CPU every frame). + # [-1,1] and [0,1] float tensors from the pipeline are left unchanged. + _was_uint8 = image_tensor.dtype == torch.uint8 + # Ensure correct device and dtype image_tensor = image_tensor.to(device=self.device, dtype=self.dtype) - - # Normalize to [0,1] range only if tensor is in [0,255] uint8 range - # Preserves [-1,1] and [0,1] ranges (max <= 1.0) - if image_tensor.max() > 1.0: + + if _was_uint8: image_tensor = image_tensor / 255.0 return image_tensor diff --git a/src/streamdiffusion/preprocessing/processors/canny.py b/src/streamdiffusion/preprocessing/processors/canny.py index 7c25e9ab4..a20e92630 100644 --- a/src/streamdiffusion/preprocessing/processors/canny.py +++ b/src/streamdiffusion/preprocessing/processors/canny.py @@ -12,7 +12,10 @@ class CannyPreprocessor(BasePreprocessor): Detects edges in the input image using the Canny edge detection algorithm. """ - + + gpu_native = True # _process_tensor_core uses conv2d — no CPU/PIL round-trip + + @classmethod def get_preprocessor_metadata(cls): return { diff --git a/src/streamdiffusion/preprocessing/processors/depth.py b/src/streamdiffusion/preprocessing/processors/depth.py index fbf57dc83..c940d7189 100644 --- a/src/streamdiffusion/preprocessing/processors/depth.py +++ b/src/streamdiffusion/preprocessing/processors/depth.py @@ -1,8 +1,7 @@ import numpy as np from PIL import Image -import torch -from typing import Union, Optional -from .base import BasePreprocessor + +from .base import BasePreprocessor, _base_logger, _pil_fallback_warned try: import torch @@ -99,9 +98,21 @@ def _process_core(self, image: Image.Image) -> Image.Image: def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ - Process tensor directly on GPU for depth estimation + Process tensor directly on GPU for depth estimation. + + Note: HF depth pipeline requires PIL input, so this method round-trips through + PIL (tensor → CPU → PIL → depth model → GPU tensor) every frame. + For fully GPU-resident depth estimation use DepthAnythingTensorrtPreprocessor. """ - detect_resolution = self.params.get('detect_resolution', 512) + cls_name = type(self).__name__ + if cls_name not in _pil_fallback_warned: + _pil_fallback_warned.add(cls_name) + _base_logger.warning( + f"[GPU-residency] {cls_name}._process_tensor_core round-trips through PIL " + "every frame (HF depth pipeline requires PIL input). For full GPU residency " + "use the TensorRT variant (depth_tensorrt). (Fires once per class.)" + ) + detect_resolution = self.params.get("detect_resolution", 512) current_size = image_tensor.shape[-2:] if current_size != (detect_resolution, detect_resolution): diff --git a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py index 993ee242d..0949e3d18 100644 --- a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py +++ b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py @@ -8,95 +8,11 @@ from PIL import Image from typing import Union, Optional from .base import BasePreprocessor - -try: - import tensorrt as trt - from polygraphy.backend.common import bytes_from_path - from polygraphy.backend.trt import engine_from_bytes - from collections import OrderedDict - TENSORRT_AVAILABLE = True -except ImportError: - TENSORRT_AVAILABLE = False - - -# Map of numpy dtype -> torch dtype -numpy_to_torch_dtype_dict = { - np.uint8: torch.uint8, - np.int8: torch.int8, - np.int16: torch.int16, - np.int32: torch.int32, - np.int64: torch.int64, - np.float16: torch.float16, - np.float32: torch.float32, - np.float64: torch.float64, - np.complex64: torch.complex64, - np.complex128: torch.complex128, -} -if np.version.full_version >= "1.24.0": - numpy_to_torch_dtype_dict[np.bool_] = torch.bool -else: - numpy_to_torch_dtype_dict[np.bool] = torch.bool - - -class TensorRTEngine: - """Simplified TensorRT engine wrapper for depth estimation inference (optimized)""" - - def __init__(self, engine_path): - self.engine_path = engine_path - self.engine = None - self.context = None - self.tensors = OrderedDict() - self._cuda_stream = None # Cache CUDA stream - - def load(self): - """Load TensorRT engine from file""" - print(f"Loading TensorRT engine: {self.engine_path}") - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - - def activate(self): - """Create execution context""" - self.context = self.engine.create_execution_context() - # Cache CUDA stream for reuse - self._cuda_stream = torch.cuda.current_stream().cuda_stream - - def allocate_buffers(self, device="cuda"): - """Allocate input/output buffers""" - for idx in range(self.engine.num_io_tensors): - name = self.engine.get_tensor_name(idx) - shape = self.context.get_tensor_shape(name) - dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - - if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: - self.context.set_input_shape(name, shape) - - tensor = torch.empty( - tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype] - ).to(device=device) - self.tensors[name] = tensor - - def infer(self, feed_dict, stream=None): - """Run inference with optional stream parameter""" - # Use cached stream if none provided - if stream is None: - stream = self._cuda_stream - - # Copy input data to tensors - for name, buf in feed_dict.items(): - self.tensors[name].copy_(buf) - - # Set tensor addresses - for name, tensor in self.tensors.items(): - self.context.set_tensor_address(name, tensor.data_ptr()) - - # Execute inference - success = self.context.execute_async_v3(stream) - if not success: - raise ValueError("ERROR: TensorRT inference failed.") - - return self.tensors +from .trt_base import TENSORRT_AVAILABLE, TensorRTEngine # shared engine wrapper class DepthAnythingTensorrtPreprocessor(BasePreprocessor): + gpu_native = True # _process_tensor_core runs full pipeline on GPU — no PIL round-trip """ Depth Anything TensorRT preprocessor for ControlNet diff --git a/src/streamdiffusion/preprocessing/processors/feedback.py b/src/streamdiffusion/preprocessing/processors/feedback.py index 72a37a7f0..01461b703 100644 --- a/src/streamdiffusion/preprocessing/processors/feedback.py +++ b/src/streamdiffusion/preprocessing/processors/feedback.py @@ -22,7 +22,10 @@ class FeedbackPreprocessor(PipelineAwareProcessor): The preprocessor accesses the pipeline's prev_image_result to get the previous output. For the first frame (when no previous output exists), it falls back to the input image. """ - + + gpu_native = True # _process_tensor_core blends tensors on GPU — no CPU/PIL round-trip + + @classmethod def get_preprocessor_metadata(cls): return { diff --git a/src/streamdiffusion/preprocessing/processors/hed.py b/src/streamdiffusion/preprocessing/processors/hed.py index 78c770878..c82adf648 100644 --- a/src/streamdiffusion/preprocessing/processors/hed.py +++ b/src/streamdiffusion/preprocessing/processors/hed.py @@ -1,8 +1,8 @@ import torch import numpy as np from PIL import Image -from typing import Union, Optional -from .base import BasePreprocessor + +from .base import BasePreprocessor, _base_logger, _pil_fallback_warned try: from controlnet_aux import HEDdetector @@ -90,11 +90,20 @@ def _process_core(self, image: Image.Image) -> Image.Image: def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ - GPU-optimized HED processing using tensors - - Note: controlnet_aux doesn't support direct tensor input, so we convert to PIL and back. - This is still reasonably fast due to optimized conversions in the base class. + HED processing via tensor I/O. + + Note: controlnet_aux HEDdetector requires PIL input, so this method round-trips + through PIL (tensor → CPU → PIL → HEDdetector → GPU tensor) every frame. + For fully GPU-resident edge detection use the TensorRT variant (hed_tensorrt). """ + cls_name = type(self).__name__ + if cls_name not in _pil_fallback_warned: + _pil_fallback_warned.add(cls_name) + _base_logger.warning( + f"[GPU-residency] {cls_name}._process_tensor_core round-trips through PIL " + "every frame (controlnet_aux requires PIL input). For full GPU residency " + "use the TensorRT variant (hed_tensorrt). (Fires once per class.)" + ) # Convert tensor to PIL, process, then back to tensor pil_image = self.tensor_to_pil(image_tensor) processed_pil = self._process_core(pil_image) diff --git a/src/streamdiffusion/preprocessing/processors/hed_tensorrt.py b/src/streamdiffusion/preprocessing/processors/hed_tensorrt.py new file mode 100644 index 000000000..783921ea2 --- /dev/null +++ b/src/streamdiffusion/preprocessing/processors/hed_tensorrt.py @@ -0,0 +1,161 @@ +""" +HED TensorRT preprocessor — GPU-native edge detection via TRT engine. + +The HED network (ControlNetHED_Apache2) is wrapped in HEDExportWrapper so +that ONNX export sees a single output tensor (full-resolution edge map) rather +than the native 5-output multi-scale tuple. The wrapper input/output contract: + + input : float32 (B, 3, H, W) in [0, 1] ← same as validate_tensor_input output + output : float32 (B, 1, H, W) in [0, 1] ← sigmoid edge map at full resolution +""" + +import logging +from pathlib import Path + +import torch + +from .trt_base import SelfBuildingTRTPreprocessor, _first_output + + +logger = logging.getLogger(__name__) + +try: + from controlnet_aux import HEDdetector + + CONTROLNET_AUX_AVAILABLE = True +except ImportError: + CONTROLNET_AUX_AVAILABLE = False + + +# --------------------------------------------------------------------------- +# ONNX export wrapper — returns only the full-resolution output +# --------------------------------------------------------------------------- + + +class HEDExportWrapper(torch.nn.Module): + """ + Thin wrapper around ControlNetHED_Apache2 for ONNX export. + + The native forward returns a 5-element tuple of tensors at decreasing + resolutions. ONNX export requires a single output of consistent shape. + This wrapper returns only output[0] (the full-resolution sigmoid map). + """ + + def __init__(self, netNetwork: torch.nn.Module): + super().__init__() + self.netNetwork = netNetwork + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, 3, H, W) in [0, 1] + outputs = self.netNetwork(x) + # outputs is a list/tuple of 5 tensors at (H, H/2, H/4, H/8, H/16); + # take the first = full-resolution edge map shape (B, 1, H, W) + return outputs[0] if isinstance(outputs, (list, tuple)) else outputs + + +# --------------------------------------------------------------------------- +# Preprocessor +# --------------------------------------------------------------------------- + + +class HEDTensorrtPreprocessor(SelfBuildingTRTPreprocessor): + """ + HED edge detection via a self-built TensorRT engine. + + GPU-native: no CPU / PIL round-trip on the tensor path. + The engine is built on first use and cached in engines/preprocessors/hed.engine + (or the path supplied via preprocessor_params.engine_path in the YAML config). + """ + + engine_filename = "hed.engine" + onnx_filename = "hed.onnx" + default_detect_resolution = 512 + + @classmethod + def get_preprocessor_metadata(cls): + return { + "display_name": "HED Edge Detection (TensorRT)", + "description": ( + "GPU-native HED (Holistically-Nested Edge Detection) via TensorRT. " + "Self-builds its engine from the controlnet_aux model on first run. " + "No CPU/PIL round-trips — satisfies the GPU-residency constraint." + ), + "parameters": {}, + "use_cases": [ + "HED ControlNet conditioning", + "Structured edge maps (real-time)", + ], + } + + def __init__(self, **kwargs): + if not CONTROLNET_AUX_AVAILABLE: + raise ImportError( + "controlnet_aux is required for HEDTensorrtPreprocessor. Install with: pip install controlnet_aux" + ) + super().__init__(**kwargs) + + # ------------------------------------------------------------------ + # ONNX export + # ------------------------------------------------------------------ + + def _export_onnx(self, onnx_path: Path) -> None: + """Load HEDdetector, wrap it, and export to ONNX.""" + logger.info("HEDTensorrtPreprocessor: loading HEDdetector for ONNX export…") + detector = HEDdetector.from_pretrained("lllyasviel/Annotators") + + if not hasattr(detector, "netNetwork"): + raise RuntimeError( + "HEDTensorrtPreprocessor: HEDdetector has no 'netNetwork' attribute. " + "controlnet_aux version may be incompatible." + ) + + wrapper = HEDExportWrapper(detector.netNetwork).to(self.device).eval() + res = self.default_detect_resolution + dummy = torch.zeros(1, 3, res, res, device=self.device) + + with torch.no_grad(): + torch.onnx.export( + wrapper, + dummy, + str(onnx_path), + opset_version=17, + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": {0: "batch", 2: "height", 3: "width"}, + "output": {0: "batch", 2: "height", 3: "width"}, + }, + ) + + logger.info(f"HEDTensorrtPreprocessor: ONNX exported → {onnx_path}") + # Free GPU memory used by the export model + del wrapper, detector + torch.cuda.empty_cache() + + # ------------------------------------------------------------------ + # Post-process TRT output → CHW GPU tensor + # ------------------------------------------------------------------ + + def _postprocess(self, engine_outputs: dict) -> torch.Tensor: + """ + Convert TRT output to a 3-channel [0, 1] edge map GPU tensor (CHW). + + Input : engine_outputs["output"] shape (B, 1, H, W) or (B, H, W) + Output : (3, H, W) in [0, 1] + """ + out = _first_output(engine_outputs).float() + + # Collapse batch + channel dims if present + if out.dim() == 4: + out = out.squeeze(1) # (B, H, W) — B should be 1 + if out.dim() == 3: + out = out.squeeze(0) # (H, W) + + # Normalize to [0, 1] (edge map may already be in this range) + v_min, v_max = out.min(), out.max() + if v_max > v_min: + out = (out - v_min) / (v_max - v_min) + out = out.clamp(0.0, 1.0) + + # Expand to 3-channel RGB → (3, H, W) + return out.unsqueeze(0).repeat(3, 1, 1) diff --git a/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py b/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py index 8b7d28a08..896d11e5e 100644 --- a/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py +++ b/src/streamdiffusion/preprocessing/processors/ipadapter_embedding.py @@ -1,7 +1,9 @@ -from typing import Union, Tuple, Optional, Any +from typing import Any, Optional, Tuple, Union + import torch from PIL import Image from .base import BasePreprocessor +from streamdiffusion.tools.gpu_profiler import profiler class IPAdapterEmbeddingPreprocessor(BasePreprocessor): @@ -28,34 +30,73 @@ def __init__(self, ipadapter: Any, **kwargs): # Create dedicated CUDA stream for IPAdapter processing to avoid TensorRT conflicts self._ipadapter_stream = torch.cuda.Stream() if torch.cuda.is_available() else None - + # CUDA event for GPU-side stream sync — CPU thread NOT blocked (vs .synchronize()). + # Lazily allocated on first _process_core call so the constructor stays CUDA-context free. + self._completion_event: Optional[torch.cuda.Event] = None + + # Per-preprocessor embedding cache: avoids CLIP re-encode when the style image is + # unchanged across consecutive frames (the common streaming scenario). + # Keyed by tensor.data_ptr() — stable while the storage is live, changes on realloc. + self._last_input_ptr: int = -1 + self._cached_embeds: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + def _process_core(self, image: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]: """Returns (positive_embeds, negative_embeds) instead of processed image""" if self._ipadapter_stream is not None: + # Lazy-init the completion event (avoids CUDA context init in constructor). + if self._completion_event is None: + self._completion_event = torch.cuda.Event() + # Use dedicated stream to avoid TensorRT stream capture conflicts with torch.cuda.stream(self._ipadapter_stream): - image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image]) - - # Wait for stream completion and move tensors to default stream - self._ipadapter_stream.synchronize() - - # Ensure tensors are accessible from default stream - if hasattr(image_embeds, 'record_stream'): + with profiler.region("ipa.clip_encode"): + image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image]) + # Record the event on the IPA stream immediately after encode. + # The default stream will GPU-wait on this event; the CPU thread is NOT blocked. + with profiler.region("ipa.sync"): + self._completion_event.record() + + # GPU-side dependency: the default stream defers until the IPA stream event fires. + # Replaces the blocking _ipadapter_stream.synchronize() — CPU thread continues now. + torch.cuda.current_stream().wait_event(self._completion_event) + + # Mark tensors as owned by the default stream (cross-stream memory safety) + if hasattr(image_embeds, "record_stream"): image_embeds.record_stream(torch.cuda.current_stream()) if hasattr(negative_embeds, 'record_stream'): negative_embeds.record_stream(torch.cuda.current_stream()) else: # Fallback for non-CUDA environments - image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image]) - + with profiler.region("ipa.clip_encode"): + image_embeds, negative_embeds = self.ipadapter.get_image_embeds(images=[image]) + + return image_embeds, negative_embeds def _process_tensor_core(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """GPU-optimized path for tensor inputs""" - # Convert tensor to PIL for IPAdapter processing + """GPU-optimized path for tensor inputs. + + Checks the per-preprocessor embedding cache before running CLIP encode. + CLIP requires PIL input, so the GPU→CPU tensor_to_pil step is unavoidable; + caching avoids it on frames where the style image is unchanged. + """ + # Cache check: skip CLIP re-encode if the tensor's storage pointer is unchanged. + # data_ptr() is stable as long as the tensor storage is not reallocated, which + # is the common case in streaming (same style-image tensor reused across frames). + current_ptr = tensor.data_ptr() if tensor.is_cuda else id(tensor) + if self._cached_embeds is not None and current_ptr == self._last_input_ptr: + return self._cached_embeds + pil_image = self.tensor_to_pil(tensor) - return self._process_core(pil_image) - + result = self._process_core(pil_image) + + # Update cache for next frame + self._last_input_ptr = current_ptr + self._cached_embeds = result + return result + + def process(self, image: Union[Image.Image, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: """Override base process to return embeddings tuple instead of PIL Image""" if isinstance(image, torch.Tensor): diff --git a/src/streamdiffusion/preprocessing/processors/normal_bae_tensorrt.py b/src/streamdiffusion/preprocessing/processors/normal_bae_tensorrt.py new file mode 100644 index 000000000..7558da481 --- /dev/null +++ b/src/streamdiffusion/preprocessing/processors/normal_bae_tensorrt.py @@ -0,0 +1,309 @@ +""" +NormalBae TensorRT preprocessor — GPU-native surface normal estimation. + +Implementation strategy +----------------------- +The NormalBaeDetector from controlnet_aux uses an NNET architecture with a +complex multi-scale decoder output (nested lists of tensors at different scales), +which complicates ONNX export. MCP verification confirmed zero prior usage +of this model in the repo, so its ONNX-exportability was unverified. + +Probe at module import time: + + PRIMARY: self-building TRT engine (ONNX export succeeds at class load time). + NormalBaeExportWrapper encapsulates: + self.norm(x) → self.model(normed) → extract high-res 3ch normal + so the engine takes plain [0,1] RGB and returns a [0,1] 3ch normal map. + + FALLBACK: if ONNX export fails (or TRT is unavailable), the class falls back + to running the torch model directly under no_grad — the same pattern + as SoftEdgePreprocessor, which is MCP-confirmed GPU-native. This + still satisfies the GPU-residency constraint (no CPU/PIL round-trips). + +In either case `gpu_native = True` is set, and the dangling 'normal_bae' +registry reference in get_preprocessor_for_controlnet is resolved. +""" + +import logging +from pathlib import Path +from typing import Optional + +import torch +from PIL import Image + +from .base import BasePreprocessor +from .trt_base import TENSORRT_AVAILABLE, SelfBuildingTRTPreprocessor, _first_output + + +logger = logging.getLogger(__name__) + +try: + from controlnet_aux import NormalBaeDetector + + CONTROLNET_AUX_AVAILABLE = True +except ImportError: + CONTROLNET_AUX_AVAILABLE = False + + +# --------------------------------------------------------------------------- +# Probe whether ONNX export works for this version of controlnet_aux +# --------------------------------------------------------------------------- + +_TRT_STRATEGY_AVAILABLE: Optional[bool] = None # None = not yet probed + + +def _probe_normal_bae_onnx_export(device: str = "cuda") -> bool: + """ + Try a lightweight ONNX export of NormalBaeExportWrapper. + Returns True if it succeeds, False otherwise. + Cached after first call. + """ + global _TRT_STRATEGY_AVAILABLE + if _TRT_STRATEGY_AVAILABLE is not None: + return _TRT_STRATEGY_AVAILABLE + + if not CONTROLNET_AUX_AVAILABLE or not TENSORRT_AVAILABLE: + _TRT_STRATEGY_AVAILABLE = False + return False + + import os + import tempfile + + try: + det = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") + wrapper = NormalBaeExportWrapper(det.model, det.norm).to(device).eval() + dummy = torch.zeros(1, 3, 64, 64, device=device) # small for probe + tmp = tempfile.mktemp(suffix=".onnx") + with torch.no_grad(): + torch.onnx.export( + wrapper, + dummy, + tmp, + opset_version=17, + input_names=["input"], + output_names=["output"], + ) + _TRT_STRATEGY_AVAILABLE = os.path.exists(tmp) and os.path.getsize(tmp) > 0 + if os.path.exists(tmp): + os.unlink(tmp) + del wrapper, det + torch.cuda.empty_cache() + except Exception as exc: + logger.warning( + f"NormalBaeTensorrtPreprocessor: ONNX probe failed ({exc}); will use torch-direct GPU fallback instead." + ) + _TRT_STRATEGY_AVAILABLE = False + + return _TRT_STRATEGY_AVAILABLE + + +# --------------------------------------------------------------------------- +# ONNX export wrapper +# --------------------------------------------------------------------------- + + +class NormalBaeExportWrapper(torch.nn.Module): + """ + Wraps NormalBaeDetector internals for single-pass ONNX export. + + Replicates the core of NormalBaeDetector.__call__: + normed = self.norm(x) + out = self.model(normed) # NNET + normal = out[0][-1][:, :3] # highest-res decoder output, 3 channels + return ((normal + 1) * 0.5).clamp(0, 1) # [-1,1] → [0,1] + + Input : (B, 3, H, W) float32 [0, 1] + Output : (B, 3, H, W) float32 [0, 1] + """ + + def __init__(self, nnet_model: torch.nn.Module, norm_transform: torch.nn.Module): + super().__init__() + self.nnet_model = nnet_model + self.norm = norm_transform + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + out = self.nnet_model(x) + # out[0] is a list of 4 tensors at scales [64², 128², 256², 512²] + # out[0][-1] is the highest-resolution output (B, 4, H, W) + normal = out[0][-1][:, :3] # (B, 3, H, W) + return ((normal + 1.0) * 0.5).clamp(0.0, 1.0) + + +# --------------------------------------------------------------------------- +# Torch-direct GPU fallback (used when TRT strategy is unavailable) +# --------------------------------------------------------------------------- + + +class _NormalBaeTorchGPU(BasePreprocessor): + """ + GPU-direct NormalBae using the torch model under no_grad. + Mirrors the SoftEdgePreprocessor pattern (MCP-confirmed GPU-native). + No CPU / PIL round-trips. + """ + + gpu_native = True + _detector_cache: dict = {} + + @classmethod + def get_preprocessor_metadata(cls): + return { + "display_name": "Normal Map Estimation (torch GPU)", + "description": ( + "GPU-native surface normal estimation using NormalBaeDetector " + "run directly under torch.no_grad. No TRT engine required." + ), + "parameters": {}, + "use_cases": ["Normal ControlNet conditioning"], + } + + def __init__(self, **kwargs): + if not CONTROLNET_AUX_AVAILABLE: + raise ImportError( + "controlnet_aux is required for normal map preprocessing. Install with: pip install controlnet_aux" + ) + super().__init__(**kwargs) + self._detector = None + self._load_model() + + def _load_model(self): + cache_key = f"normal_bae_{self.device}" + if cache_key in self._detector_cache: + self._detector = self._detector_cache[cache_key] + return + logger.info("NormalBae (torch-GPU): loading NormalBaeDetector…") + det = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") + det.model.to(self.device).eval() + det.norm.to(self.device) + self._detector = det + self._detector_cache[cache_key] = det + + def _process_core(self, image: Image.Image) -> Image.Image: + tensor = self.pil_to_tensor(image) + result = self._process_tensor_core(tensor) + return self.tensor_to_pil(result) + + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: + if not hasattr(self, "_detector") or self._detector is None: + raise RuntimeError( + f"{self.__class__.__name__}._process_tensor_core: model not initialized — " + "_load_model() was never called. This is a bug; please report it." + ) + with torch.no_grad(): + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) + image_tensor = image_tensor.to(device=self.device, dtype=torch.float32) + + # Apply NormalBae normalization + normed = self._detector.norm(image_tensor) + out = self._detector.model(normed) + + # Extract highest-res 3-channel output + normal = out[0][-1][:, :3] # (B, 3, H, W) + normal = ((normal + 1.0) * 0.5).clamp(0.0, 1.0) + + return normal.squeeze(0) # (3, H, W) + + +# --------------------------------------------------------------------------- +# Public class — chooses strategy at construction time +# --------------------------------------------------------------------------- + + +class NormalBaeTensorrtPreprocessor(SelfBuildingTRTPreprocessor): + """ + Normal map estimation — GPU-native via TRT engine (primary) or torch-direct (fallback). + + The class name retains the '_tensorrt' suffix so the existing engine-path + wiring in StreamDiffusionExt (the "tensorrt" in name gate, line 3572) works + correctly. When TRT is unavailable or ONNX export fails, construction + transparently returns a _NormalBaeTorchGPU instance instead. + """ + + engine_filename = "normal_bae.engine" + onnx_filename = "normal_bae.onnx" + default_detect_resolution = 512 + + @classmethod + def get_preprocessor_metadata(cls): + return { + "display_name": "Normal Map Estimation (TensorRT)", + "description": ( + "GPU-native surface normal estimation. Self-builds a TRT engine " + "from the controlnet_aux NormalBaeDetector model on first run. " + "Falls back to torch-direct GPU mode if TRT export is unavailable." + ), + "parameters": {}, + "use_cases": ["Normal ControlNet conditioning"], + } + + def __new__(cls, **kwargs): + """ + If TRT strategy is available return a SelfBuildingTRTPreprocessor subclass; + otherwise return the torch-direct GPU fallback transparently. + """ + if not CONTROLNET_AUX_AVAILABLE: + raise ImportError( + "controlnet_aux is required for NormalBaeTensorrtPreprocessor. " + "Install with: pip install controlnet_aux" + ) + + device = kwargs.get("device", "cuda" if torch.cuda.is_available() else "cpu") + use_trt = TENSORRT_AVAILABLE and _probe_normal_bae_onnx_export(device) + + if not use_trt: + # Return a fully-constructed fallback. + # + # Using object.__new__(_NormalBaeTorchGPU) here would cause CPython's + # type.__call__ to skip __init__ entirely, because _NormalBaeTorchGPU is + # NOT a subclass of NormalBaeTensorrtPreprocessor. The resulting object + # would have no self._detector, self.params, or self.device and raise + # AttributeError on the first frame. Calling the class directly runs + # _NormalBaeTorchGPU.__init__ correctly (Finding A fix). + return _NormalBaeTorchGPU(**kwargs) + + obj = object.__new__(cls) + return obj + + def __init__(self, **kwargs): + # __new__ now returns a fully-constructed _NormalBaeTorchGPU when TRT is + # unavailable, so CPython never calls this __init__ for the fallback path. + # The guard that was here ("if type(self) is _NormalBaeTorchGPU: return") + # was dead code and has been removed. + super().__init__(**kwargs) + + def _export_onnx(self, onnx_path: Path) -> None: + logger.info("NormalBaeTensorrtPreprocessor: loading NormalBaeDetector for ONNX export…") + det = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") + wrapper = NormalBaeExportWrapper(det.model, det.norm).to(self.device).eval() + + res = self.default_detect_resolution + dummy = torch.zeros(1, 3, res, res, device=self.device) + + with torch.no_grad(): + torch.onnx.export( + wrapper, + dummy, + str(onnx_path), + opset_version=17, + input_names=["input"], + output_names=["output"], + dynamic_axes={ + "input": {0: "batch", 2: "height", 3: "width"}, + "output": {0: "batch", 2: "height", 3: "width"}, + }, + ) + + logger.info(f"NormalBaeTensorrtPreprocessor: ONNX exported → {onnx_path}") + del wrapper, det + torch.cuda.empty_cache() + + def _postprocess(self, engine_outputs: dict) -> torch.Tensor: + """ + Convert TRT output (B, 3, H, W) [0,1] to CHW GPU tensor. + The export wrapper already applies the [0,1] normalisation. + """ + out = _first_output(engine_outputs).float() + if out.dim() == 4: + out = out.squeeze(0) # (3, H, W) + return out.clamp(0.0, 1.0) diff --git a/src/streamdiffusion/preprocessing/processors/passthrough.py b/src/streamdiffusion/preprocessing/processors/passthrough.py index e4d1125fe..c93011d0e 100644 --- a/src/streamdiffusion/preprocessing/processors/passthrough.py +++ b/src/streamdiffusion/preprocessing/processors/passthrough.py @@ -15,7 +15,10 @@ class PassthroughPreprocessor(BasePreprocessor): - Reference ControlNet - Custom ControlNets that don't need preprocessing """ - + + gpu_native = True # _process_tensor_core is a no-op identity — no CPU/PIL round-trip + + @classmethod def get_preprocessor_metadata(cls): return { diff --git a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py index 7662c37cc..072b5cb94 100644 --- a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py +++ b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py @@ -8,91 +8,7 @@ from PIL import Image from typing import Union, Optional, List, Tuple from .base import BasePreprocessor - -try: - import tensorrt as trt - from polygraphy.backend.common import bytes_from_path - from polygraphy.backend.trt import engine_from_bytes - from collections import OrderedDict - TENSORRT_AVAILABLE = True -except ImportError: - TENSORRT_AVAILABLE = False - - -# Map of numpy dtype -> torch dtype -numpy_to_torch_dtype_dict = { - np.uint8: torch.uint8, - np.int8: torch.int8, - np.int16: torch.int16, - np.int32: torch.int32, - np.int64: torch.int64, - np.float16: torch.float16, - np.float32: torch.float32, - np.float64: torch.float64, - np.complex64: torch.complex64, - np.complex128: torch.complex128, -} -if np.version.full_version >= "1.24.0": - numpy_to_torch_dtype_dict[np.bool_] = torch.bool -else: - numpy_to_torch_dtype_dict[np.bool] = torch.bool - - -class TensorRTEngine: - """Simplified TensorRT engine wrapper for pose estimation inference (optimized)""" - - def __init__(self, engine_path): - self.engine_path = engine_path - self.engine = None - self.context = None - self.tensors = OrderedDict() - self._cuda_stream = None # Cache CUDA stream - - def load(self): - """Load TensorRT engine from file""" - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - - def activate(self): - """Create execution context""" - self.context = self.engine.create_execution_context() - # Cache CUDA stream for reuse - self._cuda_stream = torch.cuda.current_stream().cuda_stream - - def allocate_buffers(self, device="cuda"): - """Allocate input/output buffers""" - for idx in range(self.engine.num_io_tensors): - name = self.engine.get_tensor_name(idx) - shape = self.context.get_tensor_shape(name) - dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - - if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: - self.context.set_input_shape(name, shape) - - tensor = torch.empty( - tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype] - ).to(device=device) - self.tensors[name] = tensor - - def infer(self, feed_dict, stream=None): - """Run inference with optional stream parameter""" - # Use cached stream if none provided - if stream is None: - stream = self._cuda_stream - - # Copy input data to tensors - for name, buf in feed_dict.items(): - self.tensors[name].copy_(buf) - - # Set tensor addresses - for name, tensor in self.tensors.items(): - self.context.set_tensor_address(name, tensor.data_ptr()) - - # Execute inference - success = self.context.execute_async_v3(stream) - if not success: - raise ValueError("TensorRT inference failed.") - - return self.tensors +from .trt_base import TENSORRT_AVAILABLE, TensorRTEngine # shared engine wrapper class PoseVisualization: @@ -207,6 +123,9 @@ def show_predictions_from_batch_format(predictions): class YoloNasPoseTensorrtPreprocessor(BasePreprocessor): + # TRT inference stays on GPU; keypoint-to-image rasterization has a tiny CPU hop + # (~17 sparse keypoints → cv2 draw → re-upload). Accepted by design (D5). + gpu_native = True """ YoloNas Pose TensorRT preprocessor for ControlNet diff --git a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py index 18adfbb26..182e923f8 100644 --- a/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py +++ b/src/streamdiffusion/preprocessing/processors/realesrgan_trt.py @@ -13,6 +13,7 @@ from collections import OrderedDict from .base import BasePreprocessor +from streamdiffusion.tools.gpu_profiler import profiler # Try to import spandrel for model loading try: @@ -102,24 +103,28 @@ def infer(self, feed_dict, stream=None): # Use provided stream or current stream context if stream is None: stream = torch.cuda.current_stream().cuda_stream - - # Copy input data to tensors + + # Copy input data to tensors (safe outside lock — per-engine buffers, single caller path) for name, buf in feed_dict.items(): self.tensors[name].copy_(buf) - # Set tensor addresses - for name, tensor in self.tensors.items(): - addr = tensor.data_ptr() - self.context.set_tensor_address(name, addr) - + # set_tensor_address + execute_async inside the lock. + # TRT execution contexts are not thread-safe; holding the lock only during enqueue + # (not during the GPU execution itself) keeps the critical section minimal. + # torch.cuda.synchronize() is removed: GPU stream ordering serialises downstream + # PyTorch ops (.clamp / .clone) that the caller enqueues on the same stream after + # this method returns — no explicit CPU-side wait needed. with self._inference_lock: - success = self.context.execute_async_v3(stream) - + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + with profiler.region("esrgan.infer"): + success = self.context.execute_async_v3(stream) + if not success: raise RuntimeError("RealESRGANEngine: TensorRT inference failed") - - torch.cuda.synchronize() - + + return self.tensors logger = logging.getLogger(__name__) @@ -167,14 +172,12 @@ def __init__(self, enable_tensorrt: bool = True, force_rebuild: bool = False, ** # Model state self.pytorch_model = None self._engine = None # Lazy loading like depth processor - + self._model_ready = False # Guards one-time lazy load + # Thread safety for engine initialization import threading self._engine_lock = threading.Lock() - - # Initialize - self._ensure_model_ready() - + @property def engine(self): """Lazy loading of the TensorRT engine""" @@ -213,7 +216,17 @@ def _download_file(self, url: str, save_path: Path): for data in response.iter_content(chunk_size=1024): size = file.write(data) progress_bar.update(size) - + + def _ensure_loaded_once(self): + """Idempotent, thread-safe lazy loader — called at the top of every process path.""" + if self._model_ready: + return + with self._engine_lock: + if self._model_ready: # double-checked locking + return + self._ensure_model_ready() + self._model_ready = True + def _ensure_model_ready(self): """Ensure PyTorch model is downloaded and loaded""" # Download model if needed @@ -399,6 +412,7 @@ def _process_with_pytorch(self, tensor: torch.Tensor) -> torch.Tensor: def _process_core(self, image: Image.Image) -> Image.Image: """Core processing using PIL Image""" + self._ensure_loaded_once() # Convert to tensor for processing tensor = self.pil_to_tensor(image) if tensor.dim() == 3: @@ -441,6 +455,7 @@ def _ensure_target_size_tensor(self, tensor: torch.Tensor) -> torch.Tensor: def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: """Core tensor processing""" + self._ensure_loaded_once() if tensor.dim() == 3: tensor = tensor.unsqueeze(0) squeeze_output = True diff --git a/src/streamdiffusion/preprocessing/processors/scribble.py b/src/streamdiffusion/preprocessing/processors/scribble.py new file mode 100644 index 000000000..ae73f8711 --- /dev/null +++ b/src/streamdiffusion/preprocessing/processors/scribble.py @@ -0,0 +1,50 @@ +from PIL import Image + +from .hed import HEDPreprocessor + + +class ScribblePreprocessor(HEDPreprocessor): + """ + Scribble preprocessor for ControlNet conditioning + + Produces sketch-like scribble edge maps using the HED model with scribble mode enabled. + Reuses the HED model cache so no extra model download is needed when HED is already loaded. + Compatible with xinsir/controlnet-scribble-sdxl-1.0 and similar scribble ControlNets. + """ + + @classmethod + def get_preprocessor_metadata(cls): + return { + "display_name": "Scribble (HED)", + "description": "Produces scribble-style edge maps using HED in scribble mode. Compatible with scribble ControlNets.", + "parameters": { + "safe": { + "type": "bool", + "default": True, + "description": "Whether to use safe mode for edge detection", + } + }, + "use_cases": ["Scribble ControlNet conditioning", "Sketch-style edge maps"], + } + + def _process_core(self, image: Image.Image) -> Image.Image: + """Apply HED in scribble mode to produce sketch-like edge maps""" + target_width, target_height = self.get_target_dimensions() + + result = self.model(image, output_type="pil", scribble=True) + + if not isinstance(result, Image.Image): + import numpy as np + + if isinstance(result, np.ndarray): + result = Image.fromarray(result) + else: + raise ValueError(f"ScribblePreprocessor: unexpected result type: {type(result)}") + + if result.size != (target_width, target_height): + result = result.resize((target_width, target_height), Image.LANCZOS) + + return result + + # _process_tensor_core is inherited from HEDPreprocessor (PIL round-trip via tensor_to_pil / + # _process_core / pil_to_tensor) — same GPU class as openpose/lineart/hed. Acceptable for v1. diff --git a/src/streamdiffusion/preprocessing/processors/scribble_tensorrt.py b/src/streamdiffusion/preprocessing/processors/scribble_tensorrt.py new file mode 100644 index 000000000..af3389915 --- /dev/null +++ b/src/streamdiffusion/preprocessing/processors/scribble_tensorrt.py @@ -0,0 +1,131 @@ +""" +Scribble TensorRT preprocessor — GPU-native scribble edge maps via TRT. + +Reuses the HED TRT engine (no second build needed). Overrides _postprocess +to apply a GPU-native NMS + binarization that replicates the scribble=True +post-processing from controlnet_aux HEDdetector.__call__: + + 1. Gaussian-blur the sigmoid edge map (smooth noise) + 2. Directional NMS — keep only local maxima (thin lines) + 3. Threshold at 0.5 → binary edge mask +""" + +import logging + +import torch +import torch.nn.functional as F + +from .hed_tensorrt import HEDTensorrtPreprocessor +from .trt_base import _first_output + + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# GPU scribble NMS helper +# --------------------------------------------------------------------------- + + +def _scribble_nms_gpu(edge_map: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + """ + Approximate GPU version of controlnet_aux nms() used by scribble=True mode. + + Performs: + 1. Light Gaussian blur (3×3 average pool — avoids kornia dependency) + 2. 4-directional local-max suppression (keep only ridge pixels) + 3. Threshold at `threshold` + + Args: + edge_map: (H, W) float32 tensor in [0, 1] on GPU + threshold: binarization threshold (default 0.5) + + Returns: + (H, W) float32 binary tensor on GPU + """ + x = edge_map.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) + + # Step 1: smooth + x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) + + # Step 2: directional NMS — keep pixel if it is the local max along + # each of the 4 scanning directions (horizontal, vertical, two diagonals). + # We approximate with isotropic max-pool (good enough for thin-line extraction). + x_max = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) + is_max = (x == x_max).float() + thinned = x * is_max + + # Step 3: binarize + binary = (thinned.squeeze() > threshold).float() + + return binary + + +# --------------------------------------------------------------------------- +# Preprocessor +# --------------------------------------------------------------------------- + + +class ScribbleTensorrtPreprocessor(HEDTensorrtPreprocessor): + """ + Scribble edge maps via TRT — reuses the HED engine, overrides postprocess. + + The 'scribble' mode in controlnet_aux HEDdetector runs the same HED + network but adds an NMS + binarization step. Here we replicate that + step with GPU tensor operations, so the full pipeline stays on CUDA. + + No second engine build is needed: engine_filename points at hed.engine. + """ + + # Deliberately points at the HED engine — no separate build + engine_filename = "hed.engine" + onnx_filename = "hed.onnx" # kept consistent; export is never re-run if engine exists + default_detect_resolution = 512 + + @classmethod + def get_preprocessor_metadata(cls): + return { + "display_name": "Scribble Edge Detection (TensorRT)", + "description": ( + "GPU-native scribble-style edge maps. Uses the HED TRT engine with " + "GPU NMS + binarization post-processing (no CPU round-trips). " + "Compatible with scribble ControlNets." + ), + "parameters": { + "scribble_threshold": { + "type": "float", + "default": 0.5, + "description": "Binarization threshold for scribble edges (0–1)", + }, + }, + "use_cases": [ + "Scribble ControlNet conditioning", + "Sketch-style edge maps (real-time)", + ], + } + + def _postprocess(self, engine_outputs: dict) -> torch.Tensor: + """ + Apply scribble NMS + threshold to the HED output, return 3-channel CHW. + + Input : engine_outputs["output"] shape (B, 1, H, W) or (B, H, W) + Output : (3, H, W) in {0.0, 1.0} (binary scribble map) + """ + out = _first_output(engine_outputs).float() + + if out.dim() == 4: + out = out.squeeze(1) + if out.dim() == 3: + out = out.squeeze(0) # (H, W) + + # Normalize to [0, 1] before NMS + v_min, v_max = out.min(), out.max() + if v_max > v_min: + out = (out - v_min) / (v_max - v_min) + out = out.clamp(0.0, 1.0) + + threshold = float(self.params.get("scribble_threshold", 0.5)) + scribble = _scribble_nms_gpu(out, threshold=threshold) # (H, W) + + # Expand to 3-channel RGB + return scribble.unsqueeze(0).repeat(3, 1, 1) # (3, H, W) diff --git a/src/streamdiffusion/preprocessing/processors/soft_edge.py b/src/streamdiffusion/preprocessing/processors/soft_edge.py index 67537982b..63db56adb 100644 --- a/src/streamdiffusion/preprocessing/processors/soft_edge.py +++ b/src/streamdiffusion/preprocessing/processors/soft_edge.py @@ -158,7 +158,8 @@ class SoftEdgePreprocessor(BasePreprocessor): Uses multi-scale Sobel operations for extremely fast soft edge detection that mimics HED output quality at 50x+ the speed. """ - + + gpu_native = True # _process_tensor_core uses torch ops under no_grad — no PIL round-trip _model_cache = {} @classmethod diff --git a/src/streamdiffusion/preprocessing/processors/standard_lineart.py b/src/streamdiffusion/preprocessing/processors/standard_lineart.py index bc732ea04..ad2bdff1f 100644 --- a/src/streamdiffusion/preprocessing/processors/standard_lineart.py +++ b/src/streamdiffusion/preprocessing/processors/standard_lineart.py @@ -16,7 +16,10 @@ class StandardLineartPreprocessor(BasePreprocessor): Uses Gaussian blur and intensity calculations to detect lines without requiring pre-trained models. GPU-accelerated with PyTorch for optimal real-time performance. """ - + + gpu_native = True # _process_tensor_core uses torch ops — no CPU/PIL round-trip + + @classmethod def get_preprocessor_metadata(cls): return { @@ -176,50 +179,84 @@ def remove_pad(x): return x[:H_target, :W_target, ...] return img_padded, remove_pad - - def _process_core(self, image: Image.Image) -> Image.Image: + + def _compute_lineart_hwc(self, input_image: torch.Tensor) -> torch.Tensor: """ - Apply standard line art detection to the input image + Core line art computation on an HWC float tensor in [0, 255] on self.device. + + Args: + input_image: HWC float32 tensor in [0, 255] already on device, already padded + to the detect_resolution (with remove_pad closure returned separately). + + Returns: + HWC float32 tensor in [0, 255] (3-channel) on the same device. """ - start_time = time.time() - - if isinstance(image, Image.Image): - input_image_cpu = np.array(image, dtype=np.uint8) - else: - input_image_cpu = image.astype(np.uint8) - - input_image = torch.from_numpy(input_image_cpu).float().to(self.device) - - detect_resolution = self.params.get('detect_resolution', 512) - gaussian_sigma = self.params.get('gaussian_sigma', 6.0) - intensity_threshold = self.params.get('intensity_threshold', 8) - - input_image, remove_pad = self._resize_image_with_pad_torch(input_image, detect_resolution) - - x = input_image - - g = self._gaussian_blur_torch(x, gaussian_sigma) - - intensity = torch.min(g - x, dim=2)[0] + gaussian_sigma = self.params.get("gaussian_sigma", 6.0) + intensity_threshold = self.params.get("intensity_threshold", 8) + + g = self._gaussian_blur_torch(input_image, gaussian_sigma) + + intensity = torch.min(g - input_image, dim=2)[0] intensity = torch.clamp(intensity, 0, 255) threshold_mask = intensity > intensity_threshold - if torch.any(threshold_mask): - median_val = torch.median(intensity[threshold_mask]) - normalization_factor = max(16, float(median_val)) - else: - normalization_factor = 16 - + # Sync-free: nanmedian over thresholded pixels equals median(intensity[threshold_mask]). + # All-False mask → every element is nan → nan_to_num floors to 16. + # normalization_factor stays as a 0-dim CUDA tensor — no .item() / host sync. + masked = torch.where(threshold_mask, intensity, torch.full_like(intensity, float("nan"))) + median_val = torch.nanmedian(masked) + normalization_factor = torch.clamp_min(torch.nan_to_num(median_val, nan=16.0), 16.0) + + intensity = intensity / normalization_factor intensity = intensity * 127 detected_map = torch.clamp(intensity, 0, 255).byte() detected_map = detected_map.unsqueeze(-1) detected_map = self._ensure_hwc3_torch(detected_map.float()) - + return detected_map + + def _process_core(self, image: Image.Image) -> Image.Image: + """ + Apply standard line art detection to the input image (PIL I/O path). + """ + time.time() + + if isinstance(image, Image.Image): + input_image_cpu = np.array(image, dtype=np.uint8) + else: + input_image_cpu = image.astype(np.uint8) + + input_image = torch.from_numpy(input_image_cpu).float().to(self.device) + + detect_resolution = self.params.get("detect_resolution", 512) + input_image, remove_pad = self._resize_image_with_pad_torch(input_image, detect_resolution) + + detected_map = self._compute_lineart_hwc(input_image) detected_map = remove_pad(detected_map) detected_map_cpu = detected_map.byte().cpu().numpy() - lineart_image = Image.fromarray(detected_map_cpu) - - return lineart_image \ No newline at end of file + return Image.fromarray(detected_map_cpu) + + def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor: + """ + GPU-native line art detection — no PIL round-trip. + + Receives a CHW float32 tensor in [0, 1] on device (guaranteed by validate_tensor_input). + Returns a CHW float32 tensor in [0, 1] on the same device. + """ + detect_resolution = self.params.get("detect_resolution", 512) + + # CHW [0,1] → HWC [0,255] + hwc = tensor.permute(1, 2, 0) * 255.0 + + # Ensure on the right device + if hwc.device != torch.device(self.device): + hwc = hwc.to(self.device) + + hwc, remove_pad = self._resize_image_with_pad_torch(hwc, detect_resolution) + detected_map = self._compute_lineart_hwc(hwc) + detected_map = remove_pad(detected_map) + + # HWC [0,255] → CHW [0,1] + return detected_map.permute(2, 0, 1) / 255.0 diff --git a/src/streamdiffusion/preprocessing/processors/trt_base.py b/src/streamdiffusion/preprocessing/processors/trt_base.py new file mode 100644 index 000000000..56b8852a6 --- /dev/null +++ b/src/streamdiffusion/preprocessing/processors/trt_base.py @@ -0,0 +1,602 @@ +""" +Shared TensorRT infrastructure for ControlNet preprocessors. + +Provides: + TensorRTEngine — low-level TRT engine wrapper (load/activate/infer). + Extracted from the verbatim copies in depth_tensorrt.py + and pose_tensorrt.py; those files now import from here. + + SelfBuildingTRTPreprocessor — base class for preprocessors that self-build their TRT + engine from a torch model at first use. Subclasses only + need to implement two hooks: + _export_onnx(onnx_path) — model-specific ONNX export + _postprocess(engine_outputs) -> T — GPU-only output shaping + plus three class attributes: + engine_filename, onnx_filename, default_detect_resolution +""" + +import logging +import threading +from abc import abstractmethod +from collections import OrderedDict +from pathlib import Path +from typing import Dict, Optional + +import torch +import torch.nn.functional as F +from PIL import Image + +from .base import BasePreprocessor + + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional TRT / polygraphy imports +# --------------------------------------------------------------------------- +try: + import numpy as np + import tensorrt as trt + from polygraphy.backend.common import bytes_from_path + from polygraphy.backend.trt import engine_from_bytes + + numpy_to_torch_dtype_dict: Dict = { + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, + } + if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool + else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool # type: ignore[attr-defined] + + TENSORRT_AVAILABLE = True +except ImportError: + TENSORRT_AVAILABLE = False + numpy_to_torch_dtype_dict = {} + + +# --------------------------------------------------------------------------- +# Shared TensorRT engine wrapper +# --------------------------------------------------------------------------- + + +class TensorRTEngine: + """ + Thin wrapper around a TensorRT ICudaEngine + IExecutionContext. + + Identical to the copies in depth_tensorrt.py and pose_tensorrt.py; + those modules import this class instead of redefining it. + """ + + # Max number of distinct input-shape configurations whose GPU buffers are kept alive. + # Covers typical resolution-switching scenarios (e.g., 256/512/768/1024). + _BUF_CACHE_MAXSIZE: int = 4 + + def __init__(self, engine_path: str): + self.engine_path = engine_path + self.engine = None + self.context = None + self.tensors = OrderedDict() + self._cuda_stream: Optional[int] = None # raw CUDA stream handle (int) for TRT + self._dedicated_stream: Optional[torch.cuda.Stream] = None # backing non-default stream + self._pre_exec_event: Optional[torch.cuda.Event] = None # current→dedicated barrier + self._post_exec_event: Optional[torch.cuda.Event] = None # dedicated→current barrier + # LRU cache: shape-signature → {name: tensor}. + # Avoids repeated GPU malloc/free when a small set of input shapes alternates. + self._buf_cache: OrderedDict = OrderedDict() + + def load(self): + logger.info(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self): + self.context = self.engine.create_execution_context() + # Create a dedicated non-default CUDA stream for this engine so that + # execute_async_v3 / enqueueV3 does NOT run on stream 0x0 (the default/null + # stream). Using the default stream forces TensorRT to insert an implicit + # cudaStreamSynchronize on every enqueue call (TRT warning: + # "Using default stream in enqueueV3() may lead to performance issues"). + # Cross-stream ordering with the surrounding PyTorch context is maintained + # via the CUDA events created below; see infer() for the sync protocol. + self._dedicated_stream = torch.cuda.Stream() + self._cuda_stream = self._dedicated_stream.cuda_stream # raw int handle for TRT + self._pre_exec_event = torch.cuda.Event() # current stream → dedicated stream barrier + self._post_exec_event = torch.cuda.Event() # dedicated stream → current stream barrier + + def allocate_buffers(self, device: str = "cuda", input_shape: tuple = None): + """ + Allocate GPU buffers for all engine I/O tensors. + + For dynamic-shape engines the caller must pass ``input_shape`` (concrete + NCHW tuple) so input dims are resolved before output shapes are queried. + Without it, ``get_tensor_shape`` returns -1 for dynamic dims and the + subsequent ``torch.empty`` call fails or allocates with a stale shape. + + Args: + device: CUDA device string (default ``"cuda"``) + input_shape: Concrete ``(N, C, H, W)`` shape for the engine's INPUT tensor. + Required when the engine was built with a dynamic-shape + optimization profile. + """ + # Pass 1: set all INPUT shapes so TRT can resolve downstream output shapes. + for idx in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(idx) + if self.engine.get_tensor_mode(name) != trt.TensorIOMode.INPUT: + continue + if input_shape is not None: + self.context.set_input_shape(name, input_shape) + else: + static_shape = tuple(self.context.get_tensor_shape(name)) + if any(d < 0 for d in static_shape): + raise RuntimeError( + f"TensorRTEngine.allocate_buffers: tensor '{name}' has dynamic " + f"shape {static_shape} but no input_shape was provided. " + "Pass input_shape=(N, C, H, W) when using a dynamic engine." + ) + self.context.set_input_shape(name, static_shape) + + # Pass 2: allocate buffers for ALL tensors (output shapes resolved by TRT now). + for idx in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(idx) + dtype = trt.nptype(self.engine.get_tensor_dtype(name)) + is_input = self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT + + if is_input and input_shape is not None: + shape = input_shape + else: + shape = tuple(self.context.get_tensor_shape(name)) + if any(d < 0 for d in shape): + raise RuntimeError( + f"TensorRTEngine.allocate_buffers: tensor '{name}' still has " + f"unresolved dynamic dims {shape} after setting input shapes. " + "Provide input_shape to resolve all dimensions." + ) + + # Allocate directly on the target device — avoids CPU alloc + H2D copy + # that `torch.empty(...).to(device=device)` would incur. + tensor = torch.empty( + shape, + dtype=numpy_to_torch_dtype_dict[dtype], + device=device, + ) + self.tensors[name] = tensor + + def infer(self, feed_dict: dict, stream=None) -> OrderedDict: + if stream is None: + stream = self._cuda_stream + + # --- Per-request shape reconciliation with LRU buffer cache --- + # Fast path: no shape change (the common streaming case — zero overhead). + # Slow path: shape changed → consult LRU before allocating new GPU memory. + new_input_shapes = {name: tuple(buf.shape) for name, buf in feed_dict.items() if name in self.tensors} + shapes_match = all(new_input_shapes[n] == tuple(self.tensors[n].shape) for n in new_input_shapes) + + if not shapes_match: + # Hashable signature for this input shape configuration. + shape_sig = tuple(sorted(new_input_shapes.items())) + + if shape_sig in self._buf_cache: + # LRU hit: reuse pre-allocated GPU buffers, no malloc needed. + self._buf_cache.move_to_end(shape_sig) # promote to MRU + cached = self._buf_cache[shape_sig] + for name in list(self.tensors.keys()): + if name in cached: + self.tensors[name] = cached[name] + # Re-apply input shapes to TRT context (context state is NOT cached). + for name, shape in new_input_shapes.items(): + self.context.set_input_shape(name, shape) + else: + # LRU miss: reallocate changed inputs, re-derive output shapes. + for name, fed_shape in new_input_shapes.items(): + if fed_shape != tuple(self.tensors[name].shape): + self.context.set_input_shape(name, fed_shape) + self.tensors[name] = torch.empty( + fed_shape, + dtype=self.tensors[name].dtype, + device=self.tensors[name].device, + ) + + # Re-query and reallocate output buffers with TRT-resolved shapes. + for out_idx in range(self.engine.num_io_tensors): + out_name = self.engine.get_tensor_name(out_idx) + if self.engine.get_tensor_mode(out_name) == trt.TensorIOMode.OUTPUT: + new_out_shape = tuple(self.context.get_tensor_shape(out_name)) + if new_out_shape != tuple(self.tensors[out_name].shape): + self.tensors[out_name] = torch.empty( + new_out_shape, + dtype=self.tensors[out_name].dtype, + device=self.tensors[out_name].device, + ) + + # Store the new buffer set in the LRU cache. + self._buf_cache[shape_sig] = OrderedDict(self.tensors) + if len(self._buf_cache) > self._BUF_CACHE_MAXSIZE: + self._buf_cache.popitem(last=False) # evict LRU (oldest) entry + + # --- Copy inputs with dtype validation --- + for name, buf in feed_dict.items(): + if self.tensors[name].dtype != buf.dtype: + raise ValueError( + f"TensorRTEngine.infer: dtype mismatch for tensor '{name}': " + f"engine expects {self.tensors[name].dtype}, got {buf.dtype}. " + f"(engine: {self.engine_path})" + ) + self.tensors[name].copy_(buf) + + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + # --- Cross-stream synchronization --- + # The input copy_() calls above ran on the CURRENT (default) stream. + # execute_async_v3 runs on _dedicated_stream. Record a barrier event on + # the current stream so the dedicated stream cannot start reading inputs + # until the copies have landed. (If no dedicated stream was created yet — + # e.g. engine not activated — fall back to the supplied stream directly.) + if self._pre_exec_event is not None and self._dedicated_stream is not None: + self._pre_exec_event.record() # on current stream + self._dedicated_stream.wait_event(self._pre_exec_event) + exec_stream = self._cuda_stream + else: + exec_stream = stream + + success = self.context.execute_async_v3(exec_stream) + if not success: + raise ValueError("TensorRTEngine: inference failed.") + + # Output tensors were written by execute on _dedicated_stream. + # _postprocess (the next call after infer()) runs on the current stream and + # reads those tensors. Make the current stream GPU-wait for execute + # completion, then record_stream so PyTorch's caching allocator knows the + # buffers are live on the current stream (prevents premature reuse). + if self._post_exec_event is not None and self._dedicated_stream is not None: + self._post_exec_event.record(self._dedicated_stream) + torch.cuda.current_stream().wait_event(self._post_exec_event) + for tensor in self.tensors.values(): + tensor.record_stream(torch.cuda.current_stream()) + + return self.tensors + + +# --------------------------------------------------------------------------- +# Output-key helper — guards against TRT renaming the output tensor +# --------------------------------------------------------------------------- + + +def _first_output(engine_outputs: dict) -> torch.Tensor: + """ + Return the ``'output'`` tensor from TRT engine outputs, or the first + non-``'input'`` key if ``'output'`` is absent. + + TRT may rename output tensors depending on the ONNX model and opset. + Using this helper instead of a hard-coded ``engine_outputs["output"]`` + guards against a bare ``KeyError`` when the tensor name doesn't match. + + Args: + engine_outputs: Dict returned by :meth:`TensorRTEngine.infer`. + + Returns: + The output tensor. + + Raises: + KeyError: if no output tensor is found (e.g. all keys are inputs). + """ + if "output" in engine_outputs: + return engine_outputs["output"] + candidates = [v for k, v in engine_outputs.items() if not k.startswith("input")] + if candidates: + return candidates[0] + raise KeyError(f"TRT engine returned no recognizable output tensor. Available keys: {list(engine_outputs.keys())}") + + +# --------------------------------------------------------------------------- +# Self-building TRT preprocessor base +# --------------------------------------------------------------------------- + + +class SelfBuildingTRTPreprocessor(BasePreprocessor): + """ + Base class for TRT preprocessors that build their own engine on first use. + + Subclass interface + ------------------ + Class attributes (override in subclass): + engine_filename : str = "engine.engine" + onnx_filename : str = "engine.onnx" + default_detect_resolution : int = 512 + + Abstract methods (implement in subclass): + _export_onnx(onnx_path: Path) -> None + Export the underlying torch model to ONNX at onnx_path. + + _postprocess(engine_outputs: dict) -> torch.Tensor + Convert raw TRT output tensors to a CHW GPU tensor in [0, 1]. + + Engine-path precedence + ---------------------- + 1. params["engine_path"] — TD always supplies this via StreamDiffusionExt config-gen + 2. /engines/preprocessors/ — offline fallback + + Build-registry hook + ------------------- + td_manager._ensure_preprocessor_engines calls: + cls.build_engine_for_path(engine_path, device) + which instantiates the preprocessor and runs _ensure_engine(). + """ + + gpu_native = True + # One-time FP8→FP16 fallback log: keyed by class name so each subclass logs once. + _fp8_warned_classes: set = set() + + # Subclasses set these: + engine_filename: str = "engine.engine" + onnx_filename: str = "engine.onnx" + default_detect_resolution: int = 512 + + def __init__(self, **kwargs): + if not TENSORRT_AVAILABLE: + raise ImportError( + "TensorRT and polygraphy are required for TRT preprocessors. " + "Install with: pip install tensorrt polygraphy" + ) + super().__init__(**kwargs) + self._engine: Optional[TensorRTEngine] = None + self._engine_lock = threading.Lock() + + # ------------------------------------------------------------------ + # PIL fallback path — goes through tensor for GPU residency + # ------------------------------------------------------------------ + + def _process_core(self, image: Image.Image) -> Image.Image: + tensor = self.pil_to_tensor(image) + result = self._process_tensor_core(tensor) + return self.tensor_to_pil(result) + + # ------------------------------------------------------------------ + # Engine path resolution + # ------------------------------------------------------------------ + + def _get_engine_path(self) -> Path: + from_params = self.params.get("engine_path") + if from_params: + return Path(from_params) + # Default fallback: /engines/preprocessors/ + repo_root = Path(__file__).resolve().parent.parent.parent.parent.parent + return repo_root / "engines" / "preprocessors" / self.engine_filename + + def _get_onnx_path(self, engine_path: Path) -> Path: + return engine_path.parent / self.onnx_filename + + # ------------------------------------------------------------------ + # Engine lifecycle + # ------------------------------------------------------------------ + + @property + def engine(self) -> TensorRTEngine: + """Lazy-load the TRT engine (double-checked locking).""" + if self._engine is None: + with self._engine_lock: + if self._engine is None: + cls_name = self.__class__.__name__ + engine_path = self._get_engine_path() + try: + self._ensure_engine() + except Exception as exc: + raise RuntimeError(f"{cls_name}: engine build/export failed for {engine_path}: {exc}") from exc + if not engine_path.exists(): + raise FileNotFoundError(f"{cls_name}: engine not found after build: {engine_path}") + try: + trt_engine = TensorRTEngine(str(engine_path)) + trt_engine.load() + trt_engine.activate() + trt_engine.allocate_buffers( + device=self.device, + input_shape=( + 1, + 3, + self.default_detect_resolution, + self.default_detect_resolution, + ), + ) + self._engine = trt_engine + except Exception as exc: + raise RuntimeError( + f"{cls_name}: engine load/activate/allocate failed for {engine_path}: {exc}" + ) from exc + return self._engine + + def _ensure_engine(self) -> None: + """Build the TRT engine from scratch if it doesn't exist yet.""" + engine_path = self._get_engine_path() + if engine_path.exists(): + return + + engine_path.parent.mkdir(parents=True, exist_ok=True) + onnx_path = self._get_onnx_path(engine_path) + + try: + logger.info(f"{self.__class__.__name__}: exporting ONNX → {onnx_path}") + self._export_onnx(onnx_path) + logger.info(f"{self.__class__.__name__}: building TRT engine → {engine_path}") + self._build_tensorrt_engine(onnx_path, engine_path) + logger.info(f"{self.__class__.__name__}: engine built ({engine_path.stat().st_size / 1024 / 1024:.1f} MB)") + finally: + # Always clean up the ONNX intermediary + if onnx_path.exists(): + onnx_path.unlink() + + def _build_tensorrt_engine(self, onnx_path: Path, engine_path: Path) -> None: + """Build TRT engine from ONNX using trt.Builder with FP16 + dynamic shapes. + + FP16 is always used; FP8 builds produce a one-time info log and fall back to FP16 + (no calibration infrastructure for preprocessor engines). The active UI profile's + ``builder_optimization_level`` is applied via the shared GPU-profile helper so + build quality matches the main UNet/VAE build for the selected profile. + """ + if not onnx_path.exists(): + raise FileNotFoundError(f"ONNX model not found: {onnx_path}") + + builder = trt.Builder(trt.Logger(trt.Logger.WARNING)) + network = builder.create_network() + parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING)) + + with open(onnx_path, "rb") as f: + if not parser.parse(f.read()): + errors = [str(parser.get_error(i)) for i in range(parser.num_errors)] + raise RuntimeError(f"{self.__class__.__name__}: ONNX parse failed: {errors}") + + config = builder.create_builder_config() + config.set_flag(trt.BuilderFlag.FP16) + + # FP8 guard: one-time log, always build FP16 (no Q/DQ calibration infra for preprocessors). + if self.params.get("build_fp8", False): + cls_name = self.__class__.__name__ + if cls_name not in SelfBuildingTRTPreprocessor._fp8_warned_classes: + logger.info( + "%s: FP8 Q/DQ is not applied to preprocessor engines " + "(no calibration infrastructure for tiny conv detectors). " + "Building FP16 instead.", + cls_name, + ) + SelfBuildingTRTPreprocessor._fp8_warned_classes.add(cls_name) + + # Apply builder_optimization_level via the shared GPU-profile helper. + # This honours the active UI profile (Flexible/Quality/Performance/Fast Build) + # at build time. The preprocessor engine is always dynamic + FP16 regardless. + opt_level = self.params.get("builder_optimization_level") + try: + from streamdiffusion.acceleration.tensorrt.utilities import ( + _apply_gpu_profile_to_config, + detect_gpu_profile, + ) + + gpu_profile = detect_gpu_profile() + # dynamic_shapes=True: tiling / l2tc helpers suppressed automatically + _apply_gpu_profile_to_config(config, gpu_profile, dynamic_shapes=True) + # Per-UI-profile override takes precedence over hardware-detected default + if opt_level is not None: + config.builder_optimization_level = int(opt_level) + logger.info( + "%s: builder_optimization_level set to %d (from UI profile)", + self.__class__.__name__, + int(opt_level), + ) + except Exception as exc: + logger.debug( + "%s: GPU profile helper not available (%s); using TRT defaults.", + self.__class__.__name__, + exc, + ) + if opt_level is not None: + try: + config.builder_optimization_level = int(opt_level) + except AttributeError: + logger.debug( + "%s: config.builder_optimization_level not supported by this TRT version.", + self.__class__.__name__, + ) + + profile = builder.create_optimization_profile() + res = self.default_detect_resolution + profile.set_shape( + "input", + (1, 3, res // 2, res // 2), # min + (1, 3, res, res), # opt + (1, 3, res * 2, res * 2), # max + ) + config.add_optimization_profile(profile) + + serialized = builder.build_serialized_network(network, config) + if serialized is None: + raise RuntimeError(f"{self.__class__.__name__}: TRT engine build returned None") + + with open(engine_path, "wb") as f: + f.write(serialized) + + # ------------------------------------------------------------------ + # Subclass hooks (must override) + # ------------------------------------------------------------------ + + @abstractmethod + def _export_onnx(self, onnx_path: Path) -> None: + """Export the underlying torch model to ONNX at onnx_path.""" + + @abstractmethod + def _postprocess(self, engine_outputs: dict) -> torch.Tensor: + """Convert raw TRT outputs to a CHW GPU tensor in [0, 1].""" + + # ------------------------------------------------------------------ + # Core tensor processing + # ------------------------------------------------------------------ + + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: + """ + Resize → TRT infer → postprocess. All on GPU, no PIL round-trip. + """ + if image_tensor.dim() == 3: + image_tensor = image_tensor.unsqueeze(0) + if not image_tensor.is_cuda: + image_tensor = image_tensor.to(self.device) + + detect_resolution = self.params.get("detect_resolution", self.default_detect_resolution) + image_resized = F.interpolate( + image_tensor.float(), + size=(detect_resolution, detect_resolution), + mode="bilinear", + align_corners=False, + ) + + # Match the dtype the engine expects on its input tensor + engine_input = self.engine.tensors.get("input") + if engine_input is not None and image_resized.dtype != engine_input.dtype: + image_resized = image_resized.to(dtype=engine_input.dtype) + + # Execute on the engine's dedicated non-default CUDA stream. + # Passing no stream lets infer() use self.engine._cuda_stream (the dedicated + # stream handle). Cross-stream sync (copy_ → execute → _postprocess) is + # handled inside TensorRTEngine.infer() via CUDA events. + outputs = self.engine.infer({"input": image_resized}) + result = self._postprocess(outputs) + + # Ensure result is CHW (strip batch dim if present) + if result.dim() == 4: + result = result.squeeze(0) + return result + + # ------------------------------------------------------------------ + # Class-level build hook called by td_manager._ensure_preprocessor_engines + # ------------------------------------------------------------------ + + @classmethod + def build_engine_for_path(cls, engine_path: str, device: str = "cuda") -> bool: + """ + Build (export + compile) the TRT engine and write it to engine_path. + + Called by td_manager._ensure_preprocessor_engines for preprocessors + that use the 'self_build' strategy in the build_registry. + + Returns True on success, False on failure. + """ + try: + instance = cls(engine_path=engine_path, device=device) + instance._ensure_engine() + return Path(engine_path).exists() + except Exception as exc: + logger.exception( + "%s.build_engine_for_path failed for %s: %s", + cls.__name__, + engine_path, + exc, + ) + return False + + def __del__(self): + if hasattr(self, "_engine") and self._engine is not None: + del self._engine diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 8f81b6f89..501e25472 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -256,6 +256,7 @@ def update_stream_params( latent_postprocessing_config: Optional[List[Dict[str, Any]]] = None, cache_maxframes: Optional[int] = None, cache_interval: Optional[int] = None, + cn_cache_interval: Optional[int] = None, ) -> None: """Update streaming parameters efficiently in a single call.""" @@ -393,6 +394,13 @@ def update_stream_params( else: logger.info(f"update_stream_params: Cache maxframes set to {cache_maxframes}") + # ControlNet residual cache interval — delegate to CN module if present. + if cn_cache_interval is not None: + cn_mod = self._get_controlnet_pipeline() + if cn_mod is not None: + cn_mod.set_cn_cache_interval(int(cn_cache_interval)) + logger.info(f"update_stream_params: cn_cache_interval -> {int(cn_cache_interval)}") + @torch.inference_mode() def update_prompt_weights( self, diff --git a/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py b/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py index 355ad6fc1..ac0ed7465 100644 --- a/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py +++ b/src/streamdiffusion/tools/compile_depth_anything_tensorrt.py @@ -20,10 +20,15 @@ logger = logging.getLogger(__name__) try: - import tensorrt as trt + import tensorrt as trt # noqa: E402 + + from streamdiffusion.acceleration.tensorrt.utilities import BUILD_TRT_LOGGER + + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False + BUILD_TRT_LOGGER = None logger.warning("TensorRT not available. Please install it first.") try: @@ -120,10 +125,9 @@ def build_tensorrt_engine( logger.info(f"Building TensorRT engine: {engine_path}") try: - trt_logger = trt.Logger(trt.Logger.INFO) - builder = trt.Builder(trt_logger) + builder = trt.Builder(BUILD_TRT_LOGGER) network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) - parser = trt.OnnxParser(network, trt_logger) + parser = trt.OnnxParser(network, BUILD_TRT_LOGGER) # Parse ONNX with open(onnx_path, 'rb') as f: diff --git a/src/streamdiffusion/tools/compile_raft_tensorrt.py b/src/streamdiffusion/tools/compile_raft_tensorrt.py index 8734987ef..6c6650786 100644 --- a/src/streamdiffusion/tools/compile_raft_tensorrt.py +++ b/src/streamdiffusion/tools/compile_raft_tensorrt.py @@ -9,9 +9,14 @@ try: import tensorrt as trt + + from streamdiffusion.acceleration.tensorrt.utilities import BUILD_TRT_LOGGER + + TENSORRT_AVAILABLE = True except ImportError: TENSORRT_AVAILABLE = False + BUILD_TRT_LOGGER = None logger.error("TensorRT not available. Please install it first.") try: @@ -143,10 +148,11 @@ def build_tensorrt_engine( logger.info("This may take several minutes...") try: - builder = trt.Builder(trt.Logger(trt.Logger.INFO)) + builder = trt.Builder(BUILD_TRT_LOGGER) network = builder.create_network() # EXPLICIT_BATCH deprecated/ignored in TRT 10.x - parser = trt.OnnxParser(network, trt.Logger(trt.Logger.WARNING)) - + parser = trt.OnnxParser(network, BUILD_TRT_LOGGER) + + logger.info("Parsing ONNX model...") with open(onnx_path, 'rb') as model: if not parser.parse(model.read()): diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 36abe02b0..45489cb5b 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -127,6 +127,7 @@ def __init__( cache_interval: int = 1, min_cache_maxframes: int = 1, max_cache_maxframes: int = 4, + cn_cache_interval: int = 1, fp8: bool = False, static_shapes: bool = False, fp8_allow_fp16_fallback: bool = False, @@ -360,6 +361,7 @@ def __init__( cache_interval=cache_interval, min_cache_maxframes=min_cache_maxframes, max_cache_maxframes=max_cache_maxframes, + cn_cache_interval=cn_cache_interval, fp8=fp8, ) @@ -623,6 +625,8 @@ def update_stream_params( safety_checker_threshold: Optional[float] = None, cache_maxframes: Optional[int] = None, cache_interval: Optional[int] = None, + # ControlNet residual cache interval (1=off, N>1=reuse residuals for N-1 frames) + cn_cache_interval: Optional[int] = None, ) -> None: """ Update streaming parameters efficiently in a single call. @@ -696,6 +700,7 @@ def update_stream_params( latent_postprocessing_config=latent_postprocessing_config, cache_maxframes=cache_maxframes, cache_interval=cache_interval, + cn_cache_interval=cn_cache_interval, ) finally: if needs_encoding: @@ -1087,6 +1092,7 @@ def _load_model( cache_interval: int = 1, min_cache_maxframes: int = 1, max_cache_maxframes: int = 4, + cn_cache_interval: int = 1, fp8: bool = False, ) -> StreamDiffusion: """ @@ -1567,6 +1573,9 @@ def _load_model( fp8=fp8, resolution=(self.height, self.width), builder_optimization_level=self.builder_optimization_level, + # Must match the hardcoded build_static_batch value below so the cache + # key reflects the actual TRT profile policy (static vs dynamic batch). + build_static_batch=True, ) vae_encoder_path = engine_manager.get_engine_path( EngineType.VAE_ENCODER, @@ -1902,8 +1911,13 @@ def _load_model( vae_dtype = stream.vae.dtype try: + # Note: the UNet always builds with build_static_batch=True / + # build_dynamic_shape=False regardless of self.static_shapes. + # static_shapes only controls the VAE enc/dec build flags. logger.warning( - f"[TRT] UNet engine: fp8={fp8}, static_shapes={self.static_shapes}, engine_path={unet_path}" + f"[TRT] UNet engine: fp8={fp8}, " + f"build_static_batch=True, build_dynamic_shape=False, " + f"engine_path={unet_path}" ) _unet_build_opts = { "opt_image_height": self.height, @@ -2134,6 +2148,9 @@ def _load_model( cn_module.add_controlnet(cn_cfg, control_image=cfg.get("control_image")) # Expose for later updates if needed by caller code stream._controlnet_module = cn_module + # Apply startup cache interval from config (1 = disabled, no-op). + if cn_cache_interval > 1: + cn_module.set_cn_cache_interval(cn_cache_interval) if acceleration == "tensorrt": try: diff --git a/tests/manual/smoke_self_build_preprocessors.py b/tests/manual/smoke_self_build_preprocessors.py new file mode 100644 index 000000000..96d6f49c6 --- /dev/null +++ b/tests/manual/smoke_self_build_preprocessors.py @@ -0,0 +1,222 @@ +""" +Manual GPU smoke test for the self-building TRT preprocessors (HED / Scribble / NormalBae). + +Purpose: + Exercise the paths that depth_tensorrt (static, constant-res) did NOT cover: + - SelfBuildingTRTPreprocessor._build_tensorrt_engine (FP8→FP16 fallback log, + opt-level threading, dynamic TRT profile) + - TensorRTEngine.allocate_buffers resolving -1 dims from input_shape + - TensorRTEngine.infer dynamic-shape RECONCILE (384 after 512) — the alignment + guarantee that commit 850e8eb added and depth never exercised + - TensorRTEngine._first_output postprocess path + - NormalBaeTensorrtPreprocessor.__new__ fallback (Finding A) on real GPU + +Prerequisites: + - Run from the repo root inside the project venv: + python tests/manual/smoke_self_build_preprocessors.py + - TensorRT must be installed and a CUDA GPU must be present. + - controlnet_aux must be installed (for NormalBae). + +Committed as a manual (run-only) GPU smoke test; output engines are written to a tempdir and discarded. +""" + +import logging +import sys +import tempfile +from pathlib import Path + +import torch + + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s — %(message)s", + stream=sys.stdout, +) +logger = logging.getLogger("smoke_self_build") + +PASS = "[PASS]" +FAIL = "[FAIL]" +SKIP = "[SKIP]" + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def _assert(condition: bool, msg: str) -> None: + if not condition: + raise AssertionError(msg) + + +# --------------------------------------------------------------------------- +# Part A — HED and Scribble self-build + dual-resolution reconcile + dtype guard +# --------------------------------------------------------------------------- + + +def _smoke_self_build(name: str, cls, tmpdir: Path, *, build_fp8: bool) -> None: + """Run the self-build + dual-resolution + dtype-guard assertions for one preprocessor.""" + engine_path = str(tmpdir / f"{name}.engine") + logger.info(f"\n{'=' * 60}") + logger.info(f"Smoke: {name} fp8={build_fp8} engine={engine_path}") + logger.info(f"{'=' * 60}") + + # --- 1. Instantiate (does NOT build yet — lazy) --- + proc = cls( + engine_path=engine_path, + build_fp8=build_fp8, + builder_optimization_level=4, + device="cuda", + ) + + # --- 2. Access .engine → triggers _ensure_engine → _build_tensorrt_engine --- + logger.info(f"[{name}] Triggering engine build via .engine access…") + eng = proc.engine + _assert(eng is not None, f"{name}: .engine returned None after build") + _assert(Path(engine_path).exists(), f"{name}: engine file not written to disk") + logger.info(f"{PASS} [{name}] engine built and saved: {engine_path}") + + # The FP8→FP16 one-time log is emitted by _build_tensorrt_engine internally when + # build_fp8=True but the GPU/TRT combination does not support STRONGLY_TYPED at the + # requested opt level. We can't assert it here without capturing logs, but it was + # previously validated by trt_base unit tests and is only meaningful for finding regressions + # during live builds. + + # --- 3. process_tensor at 512 (dynamic allocate_buffers -1 resolution path) --- + t512 = torch.rand(1, 3, 512, 512, device="cuda", dtype=torch.float16) + out512 = proc.process_tensor(t512) + _assert( + tuple(out512.shape) == (3, 512, 512), + f"{name}: 512 output shape {tuple(out512.shape)} != (3,512,512)", + ) + logger.info(f"{PASS} [{name}] process_tensor 512→{tuple(out512.shape)}") + + # --- 4. process_tensor at 384 (dynamic reconcile — the key unexercised branch) --- + t384 = torch.rand(1, 3, 384, 384, device="cuda", dtype=torch.float16) + out384 = proc.process_tensor(t384) + _assert( + tuple(out384.shape) == (3, 384, 384), + f"{name}: 384 output shape {tuple(out384.shape)} != (3,384,384) — dynamic reconcile broken", + ) + logger.info(f"{PASS} [{name}] process_tensor 384→{tuple(out384.shape)} (dynamic reconcile OK)") + + # --- 5. dtype mismatch guard: float32 into a float16 engine must raise ValueError --- + t_f32 = torch.rand(1, 3, 512, 512, device="cuda", dtype=torch.float32) + try: + proc.process_tensor(t_f32) + raise AssertionError(f"{name}: float32 input did NOT raise ValueError — dtype guard missing") + except ValueError as exc: + _assert("dtype mismatch" in str(exc), f"{name}: ValueError does not say 'dtype mismatch': {exc}") + logger.info(f"{PASS} [{name}] float32 input raised ValueError('dtype mismatch') as expected") + + +def run_hed_scribble(tmpdir: Path) -> None: + """Run HED and Scribble smoke tests under the Performance knobs (build_fp8=True).""" + try: + from streamdiffusion.preprocessing.processors.hed_tensorrt import HEDTensorrtPreprocessor + from streamdiffusion.preprocessing.processors.scribble_tensorrt import ScribbleTensorrtPreprocessor + except ImportError as e: + logger.warning(f"{SKIP} Could not import HED/Scribble preprocessors: {e}") + return + + for name, cls in [("hed_tensorrt", HEDTensorrtPreprocessor), ("scribble_tensorrt", ScribbleTensorrtPreprocessor)]: + try: + _smoke_self_build(name, cls, tmpdir, build_fp8=True) + except AssertionError as e: + logger.error(f"{FAIL} [{name}] {e}") + raise + + +# --------------------------------------------------------------------------- +# Part A — NormalBae __new__ fallback (Finding A) on real GPU +# --------------------------------------------------------------------------- + + +def run_normalbae_fallback(tmpdir: Path) -> None: + """Verify the NormalBae fallback path initializes correctly on a real GPU.""" + try: + import streamdiffusion.preprocessing.processors.normal_bae_tensorrt as nmod + from streamdiffusion.preprocessing.processors.normal_bae_tensorrt import NormalBaeTensorrtPreprocessor + except ImportError as e: + logger.warning(f"{SKIP} Could not import NormalBae preprocessor: {e}") + return + + try: + from controlnet_aux import NormalBaeDetector # noqa: F401 — presence check only + except ImportError: + logger.warning(f"{SKIP} controlnet_aux not installed — NormalBae fallback test skipped") + return + + logger.info(f"\n{'=' * 60}") + logger.info("Smoke: NormalBae fallback (__new__ Finding A) on real GPU") + logger.info(f"{'=' * 60}") + + # Reset the probe cache so patching takes effect. + nmod._TRT_STRATEGY_AVAILABLE = None + + # Force the fallback path: pretend ONNX export probing returned False. + from unittest.mock import patch + + with patch.object(nmod, "_probe_normal_bae_onnx_export", return_value=False): + with patch.object(nmod, "TENSORRT_AVAILABLE", False): + obj = NormalBaeTensorrtPreprocessor(device="cuda", detect_resolution=512) + + # Verify the object is usable (Finding A: before the fix, it had no _detector). + _assert(hasattr(obj, "_detector"), "fallback object missing '_detector' attribute") + _assert(obj._detector is not None, "fallback object's '_detector' is None") + _assert(hasattr(obj, "device"), "fallback object missing 'device' attribute") + logger.info(f"{PASS} [normal_bae_tensorrt] fallback __new__ produces fully-initialized object") + + # Run one frame (GPU) through the fallback detector. + t = torch.rand(1, 3, 512, 512, device="cuda") + try: + out = obj.process_tensor(t) + _assert(out is not None, "fallback process_tensor returned None") + logger.info( + f"{PASS} [normal_bae_tensorrt] fallback process_tensor ran without AttributeError: shape={tuple(out.shape)}" + ) + except AttributeError as exc: + raise AssertionError(f"fallback AttributeError still present — Finding A not fixed: {exc}") from exc + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + if not torch.cuda.is_available(): + print(f"{FAIL} No CUDA GPU available — cannot run GPU smoke tests") + sys.exit(1) + + logger.info(f"GPU: {torch.cuda.get_device_name(0)}") + logger.info("TensorRT: ", end="") + try: + import tensorrt as trt # noqa: F401 + + logger.info(f"{trt.__version__}") + except ImportError: + print(f"{FAIL} TensorRT not installed") + sys.exit(1) + + with tempfile.TemporaryDirectory(prefix="smoke_self_build_") as tmpdir: + tmp = Path(tmpdir) + logger.info(f"Temp engine dir: {tmp}") + + # Part A — HED + Scribble + run_hed_scribble(tmp) + + # Part A — NormalBae fallback + run_normalbae_fallback(tmp) + + logger.info("\n" + "=" * 60) + logger.info("All smoke assertions passed.") + logger.info("=" * 60) + + +if __name__ == "__main__": + # Add repo src to path when run directly. + repo_root = Path(__file__).resolve().parents[2] + sys.path.insert(0, str(repo_root / "src")) + main() diff --git a/tests/unit/test_cn_preprocessor_residency.py b/tests/unit/test_cn_preprocessor_residency.py new file mode 100644 index 000000000..fe142e97f --- /dev/null +++ b/tests/unit/test_cn_preprocessor_residency.py @@ -0,0 +1,152 @@ +""" +GPU-Residency guard for ControlNet-coupled preprocessors. + +Verification test for plan cozy-booping-wilkinson.md, Verification step 1: +"new test imports the resolver, iterates every CN-coupled type, resolves the +preprocessor name, instantiates via get_preprocessor, and asserts gpu_native +is True. Fails today for hed/scribble/normal; passes after the port." + +Run with: pytest tests/unit/test_cn_preprocessor_residency.py -v +""" + +import pytest + + +# --------------------------------------------------------------------------- +# CN-coupled type → expected preprocessor mappings +# (matches the plan table and the updated CN_MODEL_REGISTRY 'preprocessor' fields) +# --------------------------------------------------------------------------- + +CN_COUPLED_PREPROCESSORS = [ + # (cn_type_label, preprocessor_name, expect_gpu_native) + ("canny", "canny", True), + ("soft_edge", "soft_edge", True), + ("lineart", "standard_lineart", True), + ("tile", "feedback", True), + ("color", "passthrough", True), + ("depth", "depth_tensorrt", True), + ("openpose", "pose_tensorrt", True), + ("hed", "hed_tensorrt", True), + ("scribble", "scribble_tensorrt", True), + ("normalbae", "normal_bae_tensorrt", True), +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _check_gpu_native(preprocessor_name: str) -> bool: + """ + Return True if the named preprocessor class has gpu_native = True. + Returns False if the class is not registered or has gpu_native = False. + """ + from streamdiffusion.preprocessing.processors import get_preprocessor_class + + try: + cls = get_preprocessor_class(preprocessor_name) + return getattr(cls, "gpu_native", False) + except (ValueError, Exception): + return False + + +def _is_registered(preprocessor_name: str) -> bool: + from streamdiffusion.preprocessing.processors import list_preprocessors + + return preprocessor_name in list_preprocessors() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("cn_type,preprocessor_name,should_be_gpu_native", CN_COUPLED_PREPROCESSORS) +def test_cn_preprocessor_is_registered(cn_type, preprocessor_name, should_be_gpu_native): + """Every CN-coupled preprocessor must be registered (no dangling names).""" + assert _is_registered(preprocessor_name), ( + f"Preprocessor '{preprocessor_name}' for CN type '{cn_type}' is NOT registered. " + "This would crash get_preprocessor_class with 'Unknown preprocessor'." + ) + + +@pytest.mark.parametrize("cn_type,preprocessor_name,should_be_gpu_native", CN_COUPLED_PREPROCESSORS) +def test_cn_preprocessor_gpu_native_flag(cn_type, preprocessor_name, should_be_gpu_native): + """Every CN-coupled preprocessor class must declare gpu_native = True.""" + if not _is_registered(preprocessor_name): + pytest.skip(f"'{preprocessor_name}' not registered (see test above)") + + actual = _check_gpu_native(preprocessor_name) + assert actual == should_be_gpu_native, ( + f"Preprocessor '{preprocessor_name}' (for CN type '{cn_type}') " + f"has gpu_native={actual!r}, expected {should_be_gpu_native!r}.\n" + "If gpu_native=True, the class runs _process_tensor_core on GPU with no PIL round-trip.\n" + "If you see False, either:\n" + " (a) the class still uses the base-class PIL fallback — override _process_tensor_core\n" + " (b) you forgot to set gpu_native=True on the class" + ) + + +class TestAutopreprocessResolver: + """Tests for D9 — registry-driven + heuristic + passthrough fallback.""" + + def test_registry_lookup_sd15_canny(self): + """Exact registry match returns the 'preprocessor' field directly.""" + # Inline import mirrors model_utils__td.py location in the Scripts directory. + import importlib.util + import os + + scripts_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "Scripts") + spec = importlib.util.spec_from_file_location( + "model_utils_td", + os.path.join(scripts_dir, "StreamDiffusionTD__Text__model_utils__td.py"), + ) + if spec is None or spec.loader is None: + pytest.skip("model_utils__td.py not found — run from repo root") + + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + # SD1.5 canny → depth_tensorrt is NOT the fallback; canny is correct + result = mod.get_preprocessor_for_controlnet("lllyasviel/control_v11p_sd15_canny", "Local") + assert result == "canny", f"Expected 'canny', got '{result}'" + + def test_registry_lookup_sd15_normalbae(self): + import importlib.util + import os + + scripts_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "Scripts") + spec = importlib.util.spec_from_file_location( + "model_utils_td", + os.path.join(scripts_dir, "StreamDiffusionTD__Text__model_utils__td.py"), + ) + if spec is None or spec.loader is None: + pytest.skip("model_utils__td.py not found") + + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + result = mod.get_preprocessor_for_controlnet("lllyasviel/control_v11p_sd15_normalbae", "Local") + assert result == "normal_bae_tensorrt", ( + f"Expected 'normal_bae_tensorrt', got '{result}' — dangling 'normal_bae' reference not fixed" + ) + + def test_fallback_to_passthrough_for_unknown_model(self): + import importlib.util + import os + + scripts_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", "Scripts") + spec = importlib.util.spec_from_file_location( + "model_utils_td", + os.path.join(scripts_dir, "StreamDiffusionTD__Text__model_utils__td.py"), + ) + if spec is None or spec.loader is None: + pytest.skip("model_utils__td.py not found") + + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + # Completely unknown ID — should fall through to 'passthrough', not crash + result = mod.get_preprocessor_for_controlnet("some-custom/totally-unknown-controlnet-v1", "Local") + assert result == "passthrough", f"Expected 'passthrough' fallback for unknown model, got '{result}'" diff --git a/tests/unit/test_normal_bae_fallback.py b/tests/unit/test_normal_bae_fallback.py new file mode 100644 index 000000000..3478315d2 --- /dev/null +++ b/tests/unit/test_normal_bae_fallback.py @@ -0,0 +1,164 @@ +""" +Regression tests for Finding A: NormalBae torch-direct fallback initialization. + +Before the fix, NormalBaeTensorrtPreprocessor.__new__ returned +``object.__new__(_NormalBaeTorchGPU)``. Because _NormalBaeTorchGPU is not a +subclass of NormalBaeTensorrtPreprocessor, CPython's type.__call__ skipped +__init__ entirely. The resulting object had no _detector, params, or device — +AttributeError on every frame. + +After the fix, __new__ calls ``_NormalBaeTorchGPU(**kwargs)`` directly so +__init__ runs correctly. + +Requires: controlnet_aux installed (skipped otherwise). +CPU-only: patches NormalBaeDetector.from_pretrained to avoid model downloads. +Run with: pytest tests/unit/test_normal_bae_fallback.py -v +""" + +import unittest +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Skip guard — controlnet_aux must be importable +# --------------------------------------------------------------------------- + +try: + from controlnet_aux import NormalBaeDetector as _NDA # noqa: F401 — import probe only + + _CONTROLNET_AUX_OK = True +except ImportError: + _CONTROLNET_AUX_OK = False + +pytestmark = pytest.mark.skipif( + not _CONTROLNET_AUX_OK, + reason="controlnet_aux not installed — skipping NormalBae fallback tests", +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_stub_detector(): + """Return a NormalBaeDetector stub that survives _load_model().""" + stub = MagicMock() + stub.model = MagicMock() + stub.model.to.return_value = stub.model + stub.model.eval.return_value = stub.model + stub.norm = MagicMock() + stub.norm.to.return_value = stub.norm + return stub + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestNormalBaeFallbackFullyInitialized(unittest.TestCase): + """Verify the fallback path produces a correctly-initialized _NormalBaeTorchGPU.""" + + def setUp(self): + # Reset the module-level probe cache before each test. + import streamdiffusion.preprocessing.processors.normal_bae_tensorrt as mod + + mod._TRT_STRATEGY_AVAILABLE = None + + def test_fallback_has_device_params_and_detector(self): + """ + Finding A regression: fallback object must have device, params, _detector. + Before the fix this test raised AttributeError on the first frame. + """ + import streamdiffusion.preprocessing.processors.normal_bae_tensorrt as mod + from streamdiffusion.preprocessing.processors.normal_bae_tensorrt import ( + NormalBaeTensorrtPreprocessor, + _NormalBaeTorchGPU, + ) + + stub_det = _make_stub_detector() + + with ( + patch.object(mod, "_probe_normal_bae_onnx_export", return_value=False), + patch.object(mod, "TENSORRT_AVAILABLE", False), + patch.object(mod, "NormalBaeDetector") as MockNDA, + ): + MockNDA.from_pretrained.return_value = stub_det + obj = NormalBaeTensorrtPreprocessor(device="cpu") + + self.assertIsInstance( + obj, + _NormalBaeTorchGPU, + "fallback must be a _NormalBaeTorchGPU instance", + ) + self.assertTrue(hasattr(obj, "device"), "fallback must have 'device' attribute") + self.assertTrue(hasattr(obj, "params"), "fallback must have 'params' attribute") + self.assertTrue(hasattr(obj, "_detector"), "fallback must have '_detector' attribute") + self.assertIsNotNone(obj._detector, "_detector must not be None after __init__") + + def test_fallback_device_is_passed_through(self): + """Constructor kwargs (device, detect_resolution) must flow to the fallback object.""" + import streamdiffusion.preprocessing.processors.normal_bae_tensorrt as mod + from streamdiffusion.preprocessing.processors.normal_bae_tensorrt import ( + NormalBaeTensorrtPreprocessor, + ) + + stub_det = _make_stub_detector() + + with ( + patch.object(mod, "_probe_normal_bae_onnx_export", return_value=False), + patch.object(mod, "TENSORRT_AVAILABLE", False), + patch.object(mod, "NormalBaeDetector") as MockNDA, + ): + MockNDA.from_pretrained.return_value = stub_det + obj = NormalBaeTensorrtPreprocessor(device="cpu", detect_resolution=384) + + # params is populated by BasePreprocessor.__init__ from **kwargs + self.assertEqual( + obj.params.get("detect_resolution"), + 384, + "detect_resolution kwarg must appear in fallback's params", + ) + + +class TestNormalBaeUninitializedGuard(unittest.TestCase): + """Verify that the defensive guard in _process_tensor_core raises clearly.""" + + def setUp(self): + import streamdiffusion.preprocessing.processors.normal_bae_tensorrt as mod + + mod._TRT_STRATEGY_AVAILABLE = None + + def test_none_detector_raises_runtime_error_not_attribute_error(self): + """ + Simulating the pre-fix uninitialized state: _detector=None must raise + RuntimeError with a clear message instead of a bare AttributeError. + """ + import torch + + import streamdiffusion.preprocessing.processors.normal_bae_tensorrt as mod + from streamdiffusion.preprocessing.processors.normal_bae_tensorrt import ( + _NormalBaeTorchGPU, + ) + + stub_det = _make_stub_detector() + + with patch.object(mod, "NormalBaeDetector") as MockNDA: + MockNDA.from_pretrained.return_value = stub_det + obj = _NormalBaeTorchGPU(device="cpu") + + # Explicitly unset _detector to replicate pre-fix broken state + obj._detector = None + + with self.assertRaises(RuntimeError) as ctx: + obj._process_tensor_core(torch.zeros(3, 64, 64)) + + err = str(ctx.exception) + self.assertIn("_load_model", err, "error message should mention _load_model") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/test_trt_engine_guards.py b/tests/unit/test_trt_engine_guards.py new file mode 100644 index 000000000..435ce5f2a --- /dev/null +++ b/tests/unit/test_trt_engine_guards.py @@ -0,0 +1,240 @@ +""" +Tests for TensorRTEngine dynamic-shape and dtype guards (Findings C, D). + +These tests use MagicMock for all TRT internals and verify only the pure-Python +guard logic added to allocate_buffers() and infer(). + +The entire module is skipped when TensorRT is not installed. + +Run with: pytest tests/unit/test_trt_engine_guards.py -v +""" + +from collections import OrderedDict +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Import guard — skip all tests if TRT is unavailable +# --------------------------------------------------------------------------- + +try: + from streamdiffusion.preprocessing.processors.trt_base import ( + TENSORRT_AVAILABLE, + TensorRTEngine, + ) + + if TENSORRT_AVAILABLE: + import tensorrt as trt +except ImportError: + TENSORRT_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not TENSORRT_AVAILABLE, + reason="TensorRT not installed — skipping engine guard tests", +) + + +# --------------------------------------------------------------------------- +# Fake-engine factory +# --------------------------------------------------------------------------- + + +def _make_fake_engine( + input_shape=(1, 3, 512, 512), + output_shape=(1, 1, 512, 512), + dtype=None, +): + """ + Build a TensorRTEngine with fully-mocked TRT internals. + + Pre-populates tensors as if ``allocate_buffers`` already ran at + ``input_shape`` / ``output_shape`` with ``dtype`` (default float16). + """ + import torch + + dtype = dtype or torch.float16 + + eng = TensorRTEngine.__new__(TensorRTEngine) + eng.engine_path = "/fake/test.engine" + eng._cuda_stream = None + + eng.engine = MagicMock() + eng.context = MagicMock() + eng.context.execute_async_v3.return_value = True + + names = ["input", "output"] + modes = [trt.TensorIOMode.INPUT, trt.TensorIOMode.OUTPUT] + current_shapes = { + "input": input_shape, + "output": output_shape, + } + + eng.engine.num_io_tensors = 2 + eng.engine.get_tensor_name.side_effect = lambda idx: names[idx] + eng.engine.get_tensor_mode.side_effect = lambda n: modes[names.index(n)] + eng.context.get_tensor_shape.side_effect = lambda n: current_shapes[n] + + eng.tensors = OrderedDict( + input=torch.zeros(*input_shape, dtype=dtype), + output=torch.zeros(*output_shape, dtype=dtype), + ) + + return eng, current_shapes + + +# --------------------------------------------------------------------------- +# infer() dtype guard (Finding D) +# --------------------------------------------------------------------------- + + +class TestInferDtypeGuard: + def test_dtype_mismatch_raises_valueerror_with_context(self): + """float32 input into a float16 engine must raise ValueError naming the tensor.""" + import torch + + eng, _ = _make_fake_engine() + feed = {"input": torch.zeros(1, 3, 512, 512, dtype=torch.float32)} + + with pytest.raises(ValueError, match="dtype mismatch") as exc_info: + eng.infer(feed) + + msg = str(exc_info.value) + assert "input" in msg, "error should name the mismatched tensor" + assert "engine_path" in msg or "engine:" in msg.lower() or "/fake/" in msg, ( + "error should include engine path for diagnosability" + ) + + def test_correct_dtype_and_shape_succeeds(self): + """Matching dtype + shape: infer should succeed and return output.""" + import torch + + eng, _ = _make_fake_engine() + feed = {"input": torch.zeros(1, 3, 512, 512, dtype=torch.float16)} + result = eng.infer(feed) + assert "output" in result + + def test_dtype_mismatch_float16_input_into_float32_engine(self): + """float16 input into a float32 engine must also raise ValueError.""" + import torch + + eng, _ = _make_fake_engine(dtype=torch.float32) + feed = {"input": torch.zeros(1, 3, 512, 512, dtype=torch.float16)} + + with pytest.raises(ValueError, match="dtype mismatch"): + eng.infer(feed) + + +# --------------------------------------------------------------------------- +# infer() shape reconciliation — alignment guarantee (Finding C) +# --------------------------------------------------------------------------- + + +class TestInferShapeReconciliation: + def test_shape_change_reallocates_output_to_new_resolution(self): + """ + Feeding a 384×384 input after allocating at 512×512 must reallocate + the output buffer to match the new resolution (the alignment guarantee). + """ + import torch + + eng, current_shapes = _make_fake_engine( + input_shape=(1, 3, 512, 512), + output_shape=(1, 1, 512, 512), + ) + + new_input_shape = (1, 3, 384, 384) + new_output_shape = (1, 1, 384, 384) + + # After set_input_shape(384), context.get_tensor_shape should return new sizes + def dynamic_shape(name): + if name == "input": + return new_input_shape + return new_output_shape + + eng.context.get_tensor_shape.side_effect = dynamic_shape + + feed = {"input": torch.zeros(*new_input_shape, dtype=torch.float16)} + result = eng.infer(feed) + + assert tuple(result["output"].shape) == new_output_shape, ( + f"output shape {tuple(result['output'].shape)} != expected {new_output_shape} " + "— dynamic-shape reconciliation is broken" + ) + + def test_same_shape_does_not_trigger_realloc(self): + """Feeding the same shape as allocated must not call set_input_shape.""" + import torch + + eng, _ = _make_fake_engine() + feed = {"input": torch.zeros(1, 3, 512, 512, dtype=torch.float16)} + eng.infer(feed) + + eng.context.set_input_shape.assert_not_called() + + +# --------------------------------------------------------------------------- +# allocate_buffers() dynamic-shape guard (Finding C) +# --------------------------------------------------------------------------- + + +class TestAllocateBuffersGuard: + def _make_engine_skeleton(self): + """Bare TensorRTEngine without pre-allocated tensors.""" + eng = TensorRTEngine.__new__(TensorRTEngine) + eng.engine_path = "/fake/dynamic.engine" + eng.tensors = OrderedDict() + eng.engine = MagicMock() + eng.context = MagicMock() + return eng + + def test_dynamic_dims_without_input_shape_raises_runtime_error(self): + """ + allocate_buffers on a dynamic engine without input_shape must raise + RuntimeError that names the problematic tensor and says 'dynamic'. + """ + eng = self._make_engine_skeleton() + + eng.engine.num_io_tensors = 1 + eng.engine.get_tensor_name.side_effect = lambda idx: "input" + eng.engine.get_tensor_mode.side_effect = lambda n: trt.TensorIOMode.INPUT + # Simulate dynamic engine: shape has -1 dims + eng.context.get_tensor_shape.return_value = (-1, 3, -1, -1) + + with pytest.raises(RuntimeError, match="dynamic") as exc_info: + eng.allocate_buffers() + + msg = str(exc_info.value) + assert "input" in msg, "error should name the problematic tensor" + + def test_static_dims_without_input_shape_succeeds(self): + """ + A fully-static engine (no -1 dims) must work without input_shape. + Verifies we didn't break the existing static-engine path. + """ + import numpy as np + + eng = self._make_engine_skeleton() + static_shape = (1, 3, 512, 512) + + eng.engine.num_io_tensors = 1 + eng.engine.get_tensor_name.side_effect = lambda idx: "input" + eng.engine.get_tensor_mode.side_effect = lambda n: trt.TensorIOMode.INPUT + eng.engine.get_tensor_dtype.side_effect = lambda n: trt.DataType.FLOAT + + # Return static (concrete) shape + eng.context.get_tensor_shape.return_value = static_shape + + # Patch trt.nptype to return a known numpy dtype + import streamdiffusion.preprocessing.processors.trt_base as trt_base_mod + + original_nptype = trt_base_mod.trt.nptype + try: + trt_base_mod.trt.nptype = lambda _: np.float32 + eng.allocate_buffers() + finally: + trt_base_mod.trt.nptype = original_nptype + + assert "input" in eng.tensors + assert tuple(eng.tensors["input"].shape) == static_shape