diff --git a/demo/realtime-img2img/controlnet_registry.yaml b/demo/realtime-img2img/controlnet_registry.yaml index 7da63eee..9d0b69d5 100644 --- a/demo/realtime-img2img/controlnet_registry.yaml +++ b/demo/realtime-img2img/controlnet_registry.yaml @@ -119,5 +119,120 @@ available_controlnets: default_preprocessor: "feedback" default_scale: 0.6 description: "Uses image feedback for enhanced details (SDXL)" + preprocessor_params: + image_resolution: 512 + + - id: "depth_xinsir_sdxl" + name: "Depth Detection (xinsir)" + model_id: "xinsir/controlnet-depth-sdxl-1.0" + default_preprocessor: "depth_tensorrt" + default_scale: 0.8 + description: "Estimates depth information from images — xinsir SDXL variant" + preprocessor_params: + detect_resolution: 518 + image_resolution: 512 + + - id: "scribble_sdxl" + name: "Scribble" + model_id: "xinsir/controlnet-scribble-sdxl-1.0" + default_preprocessor: "scribble_tensorrt" + default_scale: 0.8 + description: "Produces sketch-like scribble edge conditioning (SDXL)" + preprocessor_params: + image_resolution: 512 + + sd21: + - id: "canny_sd21" + name: "Canny Edge Detection" + model_id: "thibaud/controlnet-sd21-canny-diffusers" + default_preprocessor: "canny" + default_scale: 0.8 + description: "Detects edges and outlines in images (SD2.1)" + preprocessor_params: + low_threshold: 100 + high_threshold: 200 + + - id: "depth_sd21" + name: "Depth Estimation" + model_id: "thibaud/controlnet-sd21-depth-diffusers" + default_preprocessor: "depth_tensorrt" + default_scale: 0.8 + description: "Estimates depth from images (SD2.1)" + preprocessor_params: + detect_resolution: 518 + image_resolution: 512 + + - id: "openpose_sd21" + name: "OpenPose" + model_id: "thibaud/controlnet-sd21-openpose-diffusers" + default_preprocessor: "pose_tensorrt" + default_scale: 0.8 + description: "Detects human body pose (SD2.1)" + preprocessor_params: + detect_resolution: 640 + image_resolution: 512 + + - id: "scribble_sd21" + name: "Scribble" + model_id: 
"thibaud/controlnet-sd21-scribble-diffusers" + default_preprocessor: "scribble_tensorrt" + default_scale: 0.8 + description: "Generates from rough sketches (SD2.1)" + preprocessor_params: + image_resolution: 512 + + - id: "hed_sd21" + name: "HED Soft Edge" + model_id: "thibaud/controlnet-sd21-hed-diffusers" + default_preprocessor: "hed_tensorrt" + default_scale: 0.8 + description: "Soft edge / HED boundary detection (SD2.1)" + preprocessor_params: + image_resolution: 512 + + - id: "normalbae_sd21" + name: "Normal Map (BAE)" + model_id: "thibaud/controlnet-sd21-normalbae-diffusers" + default_preprocessor: "normal_bae_tensorrt" + default_scale: 0.8 + description: "Surface normal estimation (SD2.1)" + preprocessor_params: + image_resolution: 512 + + - id: "lineart_sd21" + name: "Lineart" + model_id: "thibaud/controlnet-sd21-lineart-diffusers" + default_preprocessor: "standard_lineart" + default_scale: 0.8 + description: "Line-art extraction (SD2.1)" + preprocessor_params: + gaussian_sigma: 6.0 + intensity_threshold: 8 + + - id: "zoedepth_sd21" + name: "ZoeDepth" + model_id: "thibaud/controlnet-sd21-zoedepth-diffusers" + default_preprocessor: "depth_tensorrt" + default_scale: 0.8 + description: "Metric depth estimation (SD2.1)" + preprocessor_params: + detect_resolution: 518 + image_resolution: 512 + + - id: "color_sd21" + name: "Color" + model_id: "thibaud/controlnet-sd21-color-diffusers" + default_preprocessor: "passthrough" + default_scale: 0.8 + description: "Color/palette conditioning (SD2.1)" + preprocessor_params: + image_resolution: 512 + + - id: "ade20k_sd21" + name: "Segmentation (ADE20K)" + model_id: "thibaud/controlnet-sd21-ade20k-diffusers" + default_preprocessor: "passthrough" + default_scale: 0.8 + description: "Semantic segmentation conditioning (SD2.1)" preprocessor_params: image_resolution: 512 \ No newline at end of file diff --git a/demo/realtime-img2img/routes/controlnet.py b/demo/realtime-img2img/routes/controlnet.py index e25cbcbf..6698d9d4 
100644 --- a/demo/realtime-img2img/routes/controlnet.py +++ b/demo/realtime-img2img/routes/controlnet.py @@ -239,16 +239,23 @@ async def get_available_controlnets_endpoint(app_instance=Depends(get_app_instan model_type = "sd15" # Default fallback # Try to determine model type from pipeline config or uploaded config - if app_instance.pipeline and hasattr(app_instance.pipeline, 'config') and app_instance.pipeline.config: - model_id = app_instance.pipeline.config.get('model_id', '') - if 'sdxl' in model_id.lower() or 'xl' in model_id.lower(): + if app_instance.pipeline and hasattr(app_instance.pipeline, "config") and app_instance.pipeline.config: + model_id = app_instance.pipeline.config.get("model_id", "") + ml = model_id.lower() + if "sdxl" in ml or "xl" in ml: model_type = "sdxl" + elif "sd-turbo" in ml or "sd21" in ml or "sd2.1" in ml or "2-1" in ml or "stable-diffusion-2" in ml: + model_type = "sd21" elif app_instance.app_state.uploaded_config: # If no pipeline yet, try to get model type from uploaded config - model_id = app_instance.app_state.uploaded_config.get('model_id_or_path', '') - if 'sdxl' in model_id.lower() or 'xl' in model_id.lower(): + model_id = app_instance.app_state.uploaded_config.get("model_id_or_path", "") + ml = model_id.lower() + if "sdxl" in ml or "xl" in ml: model_type = "sdxl" - + elif "sd-turbo" in ml or "sd21" in ml or "sd2.1" in ml or "2-1" in ml or "stable-diffusion-2" in ml: + model_type = "sd21" + + # Handle case where available_controlnets dependency returns None if available_controlnets is None: logging.warning("get_available_controlnets: available_controlnets dependency returned None") diff --git a/src/streamdiffusion/preprocessing/processors/canny.py b/src/streamdiffusion/preprocessing/processors/canny.py index 7c25e9ab..6058ef63 100644 --- a/src/streamdiffusion/preprocessing/processors/canny.py +++ b/src/streamdiffusion/preprocessing/processors/canny.py @@ -4,6 +4,7 @@ import torch from typing import Union from .base import 
BasePreprocessor +from .category_params import EDGE_SMOOTHNESS_PARAM, apply_edge_smoothness #TODO provide gpu native edge detection class CannyPreprocessor(BasePreprocessor): @@ -29,8 +30,9 @@ def get_preprocessor_metadata(cls): "type": "int", "default": 200, "range": [1, 255], - "description": "Upper threshold for edge detection. Higher values are more selective." - } + "description": "Upper threshold for edge detection. Higher values are more selective.", + }, + **EDGE_SMOOTHNESS_PARAM, }, "use_cases": ["Line art", "Architecture", "Technical drawings", "Clean edge detection"] } @@ -60,10 +62,21 @@ def _process_core(self, image: Image.Image) -> Image.Image: gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY) else: gray = image_np - - low_threshold = self.params.get('low_threshold', 100) - high_threshold = self.params.get('high_threshold', 200) - + + # Optional smoothness pre-blur (category-standard edge param). + # Applied before cv2.Canny so that coarser smoothing suppresses high-frequency + # texture, yielding sparser / softer edges without changing threshold semantics. 
+ smoothness = float(self.params.get("smoothness", 0.0)) + if smoothness > 0.0: + sigma = smoothness * 2.0 + radius = max(1, int(sigma * 3.0 + 0.5)) + k_size = 2 * radius + 1 + gray = cv2.GaussianBlur(gray, (k_size, k_size), sigma) + + low_threshold = self.params.get("low_threshold", 100) + high_threshold = self.params.get("high_threshold", 200) + + edges = cv2.Canny(gray, low_threshold, high_threshold) edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB) @@ -77,18 +90,26 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: gray_tensor = 0.299 * image_tensor[0] + 0.587 * image_tensor[1] + 0.114 * image_tensor[2] else: gray_tensor = image_tensor[0] if image_tensor.shape[0] == 1 else image_tensor - + gray_cpu = gray_tensor.cpu() gray_np = (gray_cpu * 255).clamp(0, 255).to(torch.uint8).numpy() - - low_threshold = self.params.get('low_threshold', 100) - high_threshold = self.params.get('high_threshold', 200) - + + # Optional smoothness pre-blur (category-standard edge param). 
+ smoothness = float(self.params.get("smoothness", 0.0)) + if smoothness > 0.0: + sigma = smoothness * 2.0 + radius = max(1, int(sigma * 3.0 + 0.5)) + k_size = 2 * radius + 1 + gray_np = cv2.GaussianBlur(gray_np, (k_size, k_size), sigma) + + low_threshold = self.params.get("low_threshold", 100) + high_threshold = self.params.get("high_threshold", 200) + edges = cv2.Canny(gray_np, low_threshold, high_threshold) - + edges_tensor = torch.from_numpy(edges).float() / 255.0 edges_tensor = edges_tensor.to(device=self.device, dtype=self.dtype) - + edges_rgb = edges_tensor.unsqueeze(0).repeat(3, 1, 1) - - return edges_rgb \ No newline at end of file + + return edges_rgb diff --git a/src/streamdiffusion/preprocessing/processors/category_params.py b/src/streamdiffusion/preprocessing/processors/category_params.py new file mode 100644 index 00000000..49e4c587 --- /dev/null +++ b/src/streamdiffusion/preprocessing/processors/category_params.py @@ -0,0 +1,245 @@ +""" +Category-level parameter contracts for ControlNet preprocessors. + +Canonical metadata fragments (merged into get_preprocessor_metadata) and GPU / NumPy +helpers that implement each category's standard post-processing step. Preprocessors +import what they need — no inheritance change required. 
+ +Category contracts (keyed to the production xinsir SDXL CN set): + EDGE_SMOOTHNESS_PARAM edge-based (canny, scribble_tensorrt, …) + DEPTH_GRADE_PARAMS depth-based (depth_tensorrt, …) + POSE_DRAW_PARAMS bodypose (pose_tensorrt, …) + SEGMENTATION_PARAMS segmentation (future; mediapipe_segmentation already matches) + +GPU helpers: apply_edge_smoothness, apply_depth_grade +NumPy helper: apply_depth_grade_numpy +""" + +from __future__ import annotations + +import math + +import numpy as np +import torch +import torch.nn.functional as F + + +# --------------------------------------------------------------------------- +# Canonical metadata fragments +# --------------------------------------------------------------------------- + +EDGE_SMOOTHNESS_PARAM: dict = { + "smoothness": { + "type": "float", + "default": 0.0, + "range": [0.0, 1.0], + "description": ( + "Optional pre-blur applied before edge extraction. " + "0 = off (sharpest edges); 1 = heaviest smoothing (σ≈2, ~13×13 kernel)." + ), + }, +} + +DEPTH_GRADE_PARAMS: dict = { + "gamma": { + "type": "float", + "default": 1.0, + "range": [0.1, 3.0], + "description": ( + "Gamma applied to the [0,1] depth map after auto-normalization. " + ">1 compresses the far field (deepens contrast near the camera); " + "<1 lifts it (stretches shadow detail)." 
+ ), + }, + "black_level": { + "type": "float", + "default": 0.0, + "range": [0.0, 1.0], + "description": ("Normalization floor — depth values at or below this level map to 0 (far field)."), + }, + "white_level": { + "type": "float", + "default": 1.0, + "range": [0.0, 1.0], + "description": ("Normalization ceiling — depth values at or above this level map to 1 (near field)."), + }, + "invert": { + "type": "bool", + "default": False, + "description": "Swap near/far (1 − depth) before grading.", + }, +} + +POSE_DRAW_PARAMS: dict = { + "keypoint_threshold": { + "type": "float", + "default": 0.5, + "range": [0.0, 1.0], + "description": "Confidence cutoff for drawing skeleton joints and keypoints.", + }, + "joint_thickness": { + "type": "int", + "default": 10, + "range": [1, 30], + "description": "Thickness of skeleton limb lines (pixels).", + }, + "keypoint_radius": { + "type": "int", + "default": 10, + "range": [1, 30], + "description": "Radius of keypoint dots (pixels).", + }, +} + +# Passthrough — zero parameters, intentionally. +# The input image is forwarded unchanged to the ControlNet; no pre-processing is applied. +# Use this when the source is already a conditioning map (depth pass, scribble, skeleton, …). +# Listed here so the empty contract is explicit rather than accidentally omitted. +PASSTHROUGH_PARAMS: dict = {} + +# Segmentation — not in the production set today; defined here so future seg CNs align. +# mediapipe_segmentation already implements this exact set of parameters. 
+SEGMENTATION_PARAMS: dict = { + "threshold": { + "type": "float", + "default": 0.5, + "range": [0.0, 1.0], + "description": "Mask binarization threshold.", + }, + "blur_radius": { + "type": "int", + "default": 0, + "range": [0, 20], + "description": "Edge blur radius on the segmentation mask (pixels).", + }, + "invert_mask": { + "type": "bool", + "default": False, + "description": "Invert foreground/background.", + }, +} + + +# --------------------------------------------------------------------------- +# GPU helpers (operate on torch.Tensor, no CPU round-trip) +# --------------------------------------------------------------------------- + + +def apply_edge_smoothness(t: torch.Tensor, strength: float) -> torch.Tensor: + """Apply an adaptive separable Gaussian pre-blur to a grayscale / edge-map tensor. + + Designed to be inserted *before* the native Gaussian/Sobel block so that increasing + `strength` progressively suppresses high-frequency texture, yielding sparser / softer + edge maps. strength=0 is a fast no-op (early return, no allocation). + + Args: + t: Input tensor. Accepts (H, W), (C, H, W), or (1, C, H, W). + strength: Blur intensity in [0, 1]. Maps to σ ∈ [0, 2] (3σ gives the kernel radius). + At strength=1: σ=2, radius=6, k_size=13. + + Returns: + Blurred tensor with the same shape and dtype as *t*. 
+ """ + if strength <= 0.0: + return t + + orig_shape = t.shape + orig_dtype = t.dtype + + # Promote to float32 (1, C, H, W) for conv2d + x = t.float() + if x.dim() == 2: # (H, W) + x = x.unsqueeze(0).unsqueeze(0) + elif x.dim() == 3: # (C, H, W) + x = x.unsqueeze(0) + # else already (1, C, H, W) — or (B, C, H, W); we treat as single image + + sigma = float(strength) * 2.0 + radius = max(1, int(math.ceil(3.0 * sigma))) + k_size = 2 * radius + 1 + + coords = torch.arange(k_size, dtype=torch.float32, device=t.device) - radius + kernel_1d = torch.exp(-(coords**2) / (2.0 * sigma**2)) + kernel_1d = kernel_1d / kernel_1d.sum() + + c = x.shape[1] + # Separable 1-D horizontal / vertical Gaussian convolutions + k_h = kernel_1d.view(1, 1, k_size, 1).expand(c, 1, k_size, 1).contiguous() + k_w = kernel_1d.view(1, 1, 1, k_size).expand(c, 1, 1, k_size).contiguous() + + x = F.conv2d(x, k_h, padding=(radius, 0), groups=c) + x = F.conv2d(x, k_w, padding=(0, radius), groups=c) + + # Restore original shape + if len(orig_shape) == 2: + x = x.squeeze(0).squeeze(0) + elif len(orig_shape) == 3: + x = x.squeeze(0) + + return x.to(dtype=orig_dtype) + + +def apply_depth_grade( + depth: torch.Tensor, + gamma: float = 1.0, + black_level: float = 0.0, + white_level: float = 1.0, + invert: bool = False, +) -> torch.Tensor: + """Apply normalization + gamma grade to a depth map in [0, 1]. + + Operation order: + 1. Optional invert (1 − d) — swap near/far. + 2. Level remap: (d − black_level) / (white_level − black_level), clamped to [0, 1]. + 3. Gamma: d^gamma (1.0 = identity). + + Args: + depth: Depth tensor in [0, 1], any shape. + gamma: Gamma exponent (1.0 = identity). + black_level: New zero point — depth values at/below this map to 0. + white_level: New full-scale point — depth values at/above this map to 1. + invert: Swap near/far before grading. + + Returns: + Graded depth tensor with the same shape and dtype as *depth*. 
+ """ + orig_dtype = depth.dtype + d = depth.float() + + if invert: + d = 1.0 - d + + span = max(float(white_level) - float(black_level), 1e-6) + d = ((d - float(black_level)) / span).clamp(0.0, 1.0) + + if abs(float(gamma) - 1.0) > 1e-6: + d = d.pow(float(gamma)) + + return d.clamp(0.0, 1.0).to(dtype=orig_dtype) + + +# --------------------------------------------------------------------------- +# NumPy helper (CPU / _process_core paths) +# --------------------------------------------------------------------------- + + +def apply_depth_grade_numpy( + depth: np.ndarray, + gamma: float = 1.0, + black_level: float = 0.0, + white_level: float = 1.0, + invert: bool = False, +) -> np.ndarray: + """NumPy equivalent of apply_depth_grade. Expects *depth* in [0, 1] float.""" + d = depth.astype(np.float32) + + if invert: + d = 1.0 - d + + span = max(float(white_level) - float(black_level), 1e-6) + d = ((d - float(black_level)) / span).clip(0.0, 1.0) + + if abs(float(gamma) - 1.0) > 1e-6: + d = np.power(d, float(gamma)) + + return d.clip(0.0, 1.0) diff --git a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py index 993ee242..f3a3f659 100644 --- a/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py +++ b/src/streamdiffusion/preprocessing/processors/depth_tensorrt.py @@ -8,98 +8,14 @@ from PIL import Image from typing import Union, Optional from .base import BasePreprocessor - -try: - import tensorrt as trt - from polygraphy.backend.common import bytes_from_path - from polygraphy.backend.trt import engine_from_bytes - from collections import OrderedDict - TENSORRT_AVAILABLE = True -except ImportError: - TENSORRT_AVAILABLE = False - - -# Map of numpy dtype -> torch dtype -numpy_to_torch_dtype_dict = { - np.uint8: torch.uint8, - np.int8: torch.int8, - np.int16: torch.int16, - np.int32: torch.int32, - np.int64: torch.int64, - np.float16: torch.float16, - np.float32: torch.float32, - np.float64: 
torch.float64, - np.complex64: torch.complex64, - np.complex128: torch.complex128, -} -if np.version.full_version >= "1.24.0": - numpy_to_torch_dtype_dict[np.bool_] = torch.bool -else: - numpy_to_torch_dtype_dict[np.bool] = torch.bool - - -class TensorRTEngine: - """Simplified TensorRT engine wrapper for depth estimation inference (optimized)""" - - def __init__(self, engine_path): - self.engine_path = engine_path - self.engine = None - self.context = None - self.tensors = OrderedDict() - self._cuda_stream = None # Cache CUDA stream - - def load(self): - """Load TensorRT engine from file""" - print(f"Loading TensorRT engine: {self.engine_path}") - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - - def activate(self): - """Create execution context""" - self.context = self.engine.create_execution_context() - # Cache CUDA stream for reuse - self._cuda_stream = torch.cuda.current_stream().cuda_stream - - def allocate_buffers(self, device="cuda"): - """Allocate input/output buffers""" - for idx in range(self.engine.num_io_tensors): - name = self.engine.get_tensor_name(idx) - shape = self.context.get_tensor_shape(name) - dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - - if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: - self.context.set_input_shape(name, shape) - - tensor = torch.empty( - tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype] - ).to(device=device) - self.tensors[name] = tensor - - def infer(self, feed_dict, stream=None): - """Run inference with optional stream parameter""" - # Use cached stream if none provided - if stream is None: - stream = self._cuda_stream - - # Copy input data to tensors - for name, buf in feed_dict.items(): - self.tensors[name].copy_(buf) - - # Set tensor addresses - for name, tensor in self.tensors.items(): - self.context.set_tensor_address(name, tensor.data_ptr()) - - # Execute inference - success = self.context.execute_async_v3(stream) - if not success: - raise ValueError("ERROR: 
TensorRT inference failed.") - - return self.tensors +from .category_params import DEPTH_GRADE_PARAMS, apply_depth_grade, apply_depth_grade_numpy +from .trt_base import TENSORRT_AVAILABLE, TensorRTEngine # shared engine wrapper class DepthAnythingTensorrtPreprocessor(BasePreprocessor): """ Depth Anything TensorRT preprocessor for ControlNet - + Uses TensorRT-optimized Depth Anything model for fast depth estimation. """ @classmethod @@ -108,18 +24,18 @@ def get_preprocessor_metadata(cls): "display_name": "Depth Estimation (TensorRT)", "description": "Fast TensorRT-optimized depth estimation using Depth Anything model. Significantly faster than standard depth estimation.", "parameters": { - + **DEPTH_GRADE_PARAMS, }, - "use_cases": ["High-performance depth estimation", "Real-time applications", "3D-aware generation"] + "use_cases": ["High-performance depth estimation", "Real-time applications", "3D-aware generation"], } - def __init__(self, + def __init__(self, engine_path: str = None, detect_resolution: int = 518, image_resolution: int = 512, **kwargs): """ Initialize TensorRT depth preprocessor - + Args: engine_path: Path to TensorRT engine file detect_resolution: Resolution for depth detection (should match engine input) @@ -131,16 +47,16 @@ def __init__(self, "TensorRT and polygraphy libraries are required for TensorRT depth preprocessing. " "Install them with: pip install tensorrt polygraphy" ) - + super().__init__( engine_path=engine_path, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs ) - + self._engine = None - + @property def engine(self): """Lazy loading of the TensorRT engine""" @@ -151,54 +67,64 @@ def engine(self): "engine_path is required for TensorRT depth preprocessing. " "Please provide it in the preprocessor_params config." 
) - + if not os.path.exists(engine_path): raise FileNotFoundError(f"TensorRT engine not found: {engine_path}") - + print(f"Loading TensorRT depth estimation engine: {engine_path}") - + self._engine = TensorRTEngine(engine_path) self._engine.load() self._engine.activate() self._engine.allocate_buffers() - + return self._engine - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply TensorRT depth estimation to the input image """ detect_resolution = self.params.get('detect_resolution', 518) - + image_tensor = torch.from_numpy(np.array(image)).float() / 255.0 image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) - + image_resized = F.interpolate( - image_tensor, - size=(detect_resolution, detect_resolution), - mode='bilinear', + image_tensor, + size=(detect_resolution, detect_resolution), + mode='bilinear', align_corners=False ) - + if torch.cuda.is_available(): image_resized = image_resized.cuda() - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized}, cuda_stream) depth = result['output'] - + depth = np.reshape(depth.cpu().numpy(), (detect_resolution, detect_resolution)) - depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 - depth = depth.astype(np.uint8) - + + # Auto-normalize TRT output to [0, 1], then apply category-standard depth grade. 
+ depth_min, depth_max = float(depth.min()), float(depth.max()) + depth_norm = (depth - depth_min) / max(depth_max - depth_min, 1e-6) + depth_norm = apply_depth_grade_numpy( + depth_norm, + gamma=float(self.params.get("gamma", 1.0)), + black_level=float(self.params.get("black_level", 0.0)), + white_level=float(self.params.get("white_level", 1.0)), + invert=bool(self.params.get("invert", False)), + ) + depth = (depth_norm * 255.0).astype(np.uint8) + original_size = image.size depth = cv2.resize(depth, original_size) - + depth_rgb = cv2.cvtColor(depth, cv2.COLOR_GRAY2RGB) result = Image.fromarray(depth_rgb) - + return result - + def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU to avoid CPU transfers @@ -207,20 +133,29 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: image_tensor = image_tensor.unsqueeze(0) if not image_tensor.is_cuda: image_tensor = image_tensor.cuda() - + detect_resolution = self.params.get('detect_resolution', 518) - + image_resized = torch.nn.functional.interpolate( - image_tensor, size=(detect_resolution, detect_resolution), + image_tensor, size=(detect_resolution, detect_resolution), mode='bilinear', align_corners=False ) - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized}, cuda_stream) depth_tensor = result['output'] - + depth_tensor = depth_tensor.squeeze() if depth_tensor.dim() > 2 else depth_tensor + + # Auto-normalize TRT output to [0, 1], then apply category-standard depth grade. 
depth_min, depth_max = depth_tensor.min(), depth_tensor.max() - depth_normalized = (depth_tensor - depth_min) / (depth_max - depth_min) - - return depth_normalized.repeat(3, 1, 1).unsqueeze(0) \ No newline at end of file + depth_normalized = (depth_tensor - depth_min) / (depth_max - depth_min + 1e-6) + depth_normalized = apply_depth_grade( + depth_normalized, + gamma=float(self.params.get("gamma", 1.0)), + black_level=float(self.params.get("black_level", 0.0)), + white_level=float(self.params.get("white_level", 1.0)), + invert=bool(self.params.get("invert", False)), + ) + + return depth_normalized.repeat(3, 1, 1).unsqueeze(0) diff --git a/src/streamdiffusion/preprocessing/processors/passthrough.py b/src/streamdiffusion/preprocessing/processors/passthrough.py index e4d1125f..f4cc45cf 100644 --- a/src/streamdiffusion/preprocessing/processors/passthrough.py +++ b/src/streamdiffusion/preprocessing/processors/passthrough.py @@ -20,11 +20,18 @@ class PassthroughPreprocessor(BasePreprocessor): def get_preprocessor_metadata(cls): return { "display_name": "Passthrough", - "description": "Passes the input image through with minimal processing. Used for tile ControlNet or when you want to use the input image directly.", - "parameters": { - - }, - "use_cases": ["Tile ControlNet", "Image-to-image with structure preservation", "Upscaling with control"] + "description": ( + "Sends the input image directly to the ControlNet with no preprocessing. " + "Use when the input is already a pre-rendered conditioning map — e.g. a " + "depth pass, hand-drawn scribble, or OpenPose skeleton rendered externally." 
+ ), + "parameters": {}, + "use_cases": [ + "Pre-rendered depth / normal maps", + "Hand-drawn scribble inputs", + "Externally generated pose skeletons", + "Image-to-image with structure preservation", + ], } def __init__(self, diff --git a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py index 7662c37c..7c99d8f5 100644 --- a/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py +++ b/src/streamdiffusion/preprocessing/processors/pose_tensorrt.py @@ -8,137 +8,76 @@ from PIL import Image from typing import Union, Optional, List, Tuple from .base import BasePreprocessor - -try: - import tensorrt as trt - from polygraphy.backend.common import bytes_from_path - from polygraphy.backend.trt import engine_from_bytes - from collections import OrderedDict - TENSORRT_AVAILABLE = True -except ImportError: - TENSORRT_AVAILABLE = False - - -# Map of numpy dtype -> torch dtype -numpy_to_torch_dtype_dict = { - np.uint8: torch.uint8, - np.int8: torch.int8, - np.int16: torch.int16, - np.int32: torch.int32, - np.int64: torch.int64, - np.float16: torch.float16, - np.float32: torch.float32, - np.float64: torch.float64, - np.complex64: torch.complex64, - np.complex128: torch.complex128, -} -if np.version.full_version >= "1.24.0": - numpy_to_torch_dtype_dict[np.bool_] = torch.bool -else: - numpy_to_torch_dtype_dict[np.bool] = torch.bool - - -class TensorRTEngine: - """Simplified TensorRT engine wrapper for pose estimation inference (optimized)""" - - def __init__(self, engine_path): - self.engine_path = engine_path - self.engine = None - self.context = None - self.tensors = OrderedDict() - self._cuda_stream = None # Cache CUDA stream - - def load(self): - """Load TensorRT engine from file""" - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - - def activate(self): - """Create execution context""" - self.context = self.engine.create_execution_context() - # Cache CUDA stream for reuse - 
self._cuda_stream = torch.cuda.current_stream().cuda_stream - - def allocate_buffers(self, device="cuda"): - """Allocate input/output buffers""" - for idx in range(self.engine.num_io_tensors): - name = self.engine.get_tensor_name(idx) - shape = self.context.get_tensor_shape(name) - dtype = trt.nptype(self.engine.get_tensor_dtype(name)) - - if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: - self.context.set_input_shape(name, shape) - - tensor = torch.empty( - tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype] - ).to(device=device) - self.tensors[name] = tensor - - def infer(self, feed_dict, stream=None): - """Run inference with optional stream parameter""" - # Use cached stream if none provided - if stream is None: - stream = self._cuda_stream - - # Copy input data to tensors - for name, buf in feed_dict.items(): - self.tensors[name].copy_(buf) - - # Set tensor addresses - for name, tensor in self.tensors.items(): - self.context.set_tensor_address(name, tensor.data_ptr()) - - # Execute inference - success = self.context.execute_async_v3(stream) - if not success: - raise ValueError("TensorRT inference failed.") - - return self.tensors +from .category_params import POSE_DRAW_PARAMS +from .trt_base import TENSORRT_AVAILABLE, TensorRTEngine # shared engine wrapper class PoseVisualization: """Pose drawing utilities ported from ComfyUI YoloNasPose node""" - + @staticmethod - def draw_skeleton(image, keypoints, edge_links, edge_colors, joint_thickness=10, keypoint_radius=10): + def draw_skeleton( + image, + keypoints, + edge_links, + edge_colors, + joint_thickness=10, + keypoint_radius=10, + keypoint_threshold=0.5, + ): """Draw pose skeleton on image""" overlay = image.copy() - + # Draw edges/links between keypoints for (kp1, kp2), color in zip(edge_links, edge_colors): if kp1 < len(keypoints) and kp2 < len(keypoints): # Check if both keypoints are valid (confidence > threshold) if len(keypoints[kp1]) >= 3 and len(keypoints[kp2]) >= 3: conf1, conf2 = 
keypoints[kp1][2], keypoints[kp2][2] - if conf1 > 0.5 and conf2 > 0.5: + if conf1 > keypoint_threshold and conf2 > keypoint_threshold: p1 = (int(keypoints[kp1][0]), int(keypoints[kp1][1])) p2 = (int(keypoints[kp2][0]), int(keypoints[kp2][1])) cv2.line(overlay, p1, p2, color=color, thickness=joint_thickness, lineType=cv2.LINE_AA) - + # Draw keypoints for keypoint in keypoints: - if len(keypoint) >= 3 and keypoint[2] > 0.5: # confidence threshold + if len(keypoint) >= 3 and keypoint[2] > keypoint_threshold: x, y = int(keypoint[0]), int(keypoint[1]) cv2.circle(overlay, (x, y), keypoint_radius, (0, 255, 0), -1, cv2.LINE_AA) - + return cv2.addWeighted(overlay, 0.75, image, 0.25, 0) @staticmethod - def draw_poses(image, poses, edge_links, edge_colors, joint_thickness=10, keypoint_radius=10): + def draw_poses( + image, + poses, + edge_links, + edge_colors, + joint_thickness=10, + keypoint_radius=10, + keypoint_threshold=0.5, + ): """Draw multiple poses on image""" result = image.copy() - + for pose in poses: result = PoseVisualization.draw_skeleton( - result, pose, edge_links, edge_colors, joint_thickness, keypoint_radius + result, + pose, + edge_links, + edge_colors, + joint_thickness, + keypoint_radius, + keypoint_threshold, ) - + return result def iterate_over_batch_predictions(predictions, batch_size): """Process batch predictions from TensorRT output""" num_detections, batch_boxes, batch_scores, batch_joints = predictions - + for image_index in range(batch_size): num_detection_in_image = int(num_detections[image_index, 0]) @@ -156,29 +95,41 @@ def iterate_over_batch_predictions(predictions, batch_size): yield image_index, pred_boxes, pred_scores, pred_joints # precompute edge links define skeleton connections (COCO format) -edge_links = [[0, 17], [13, 15], [14, 16], [12, 14], [12, 17], [5, 6], - [11, 13], [7, 9], [5, 7], [17, 11], [6, 8], [8, 10], +edge_links = [[0, 17], [13, 15], [14, 16], [12, 14], [12, 17], [5, 6], + [11, 13], [7, 9], [5, 7], [17, 11], [6, 8], [8, 
10], [1, 3], [0, 1], [0, 2], [2, 4]] edge_colors = [ - [255, 0, 0], [255, 85, 0], [170, 255, 0], [85, 255, 0], [85, 255, 0], - [85, 0, 255], [255, 170, 0], [0, 177, 58], [0, 179, 119], [179, 179, 0], + [255, 0, 0], [255, 85, 0], [170, 255, 0], [85, 255, 0], [85, 255, 0], + [85, 0, 255], [255, 170, 0], [0, 177, 58], [0, 179, 119], [179, 179, 0], [0, 119, 179], [0, 179, 179], [119, 0, 179], [179, 0, 179], [178, 0, 118], [178, 0, 118] ] -def show_predictions_from_batch_format(predictions): - """Convert predictions to pose visualization format""" + + +def show_predictions_from_batch_format( + predictions, + keypoint_threshold: float = 0.5, + joint_thickness: int = 10, + keypoint_radius: int = 10, +): + """Convert predictions to pose visualization format. + + Args: + predictions: Raw TRT engine output list (num_dets, boxes, scores, joints). + keypoint_threshold: Confidence cutoff for drawing joints (category-standard param). + joint_thickness: Skeleton limb line thickness in pixels. + keypoint_radius: Keypoint dot radius in pixels. 
+ """ try: image_index, pred_boxes, pred_scores, pred_joints = next( iter(iterate_over_batch_predictions(predictions, 1))) except Exception as e: raise RuntimeError(f"show_predictions_from_batch_format: Error in iterate_over_batch_predictions: {e}") - - # Handle case where no poses are detected if pred_joints.shape[0] == 0: return np.zeros((640, 640, 3)) - + # Add middle joint between shoulders (keypoints 5 and 6) try: # Calculate middle joints for all poses at once @@ -187,49 +138,57 @@ def show_predictions_from_batch_format(predictions): new_pred_joints = np.concatenate([pred_joints, middle_joints[:, np.newaxis]], axis=1) except Exception as e: raise RuntimeError(f"show_predictions_from_batch_format: Error processing poses: {e}") - + # Create black background for pose visualization black_image = np.zeros((640, 640, 3)) - + try: image = PoseVisualization.draw_poses( - image=black_image, - poses=new_pred_joints, - edge_links=edge_links, - edge_colors=edge_colors, - joint_thickness=10, - keypoint_radius=10 + image=black_image, + poses=new_pred_joints, + edge_links=edge_links, + edge_colors=edge_colors, + joint_thickness=joint_thickness, + keypoint_radius=keypoint_radius, + keypoint_threshold=keypoint_threshold, ) except Exception as e: raise RuntimeError(f"show_predictions_from_batch_format: Error in pose drawing: {e}") - + return image class YoloNasPoseTensorrtPreprocessor(BasePreprocessor): """ YoloNas Pose TensorRT preprocessor for ControlNet - + Uses TensorRT-optimized YoloNas Pose model for fast pose estimation. """ - + @classmethod def get_preprocessor_metadata(cls): return { "display_name": "Pose Detection (TensorRT)", "description": "Fast TensorRT-optimized pose detection using YOLO-NAS Pose model. 
Detects human pose keypoints with high performance.", - "parameters": {}, - "use_cases": ["Human pose control", "Character animation", "Pose-guided generation", "Real-time pose detection"] + "parameters": { + **POSE_DRAW_PARAMS, + }, + "use_cases": [ + "Human pose control", + "Character animation", + "Pose-guided generation", + "Real-time pose detection", + ], } - - def __init__(self, + + def __init__(self, engine_path: str = None, detect_resolution: int = 640, image_resolution: int = 512, **kwargs): """ Initialize TensorRT pose preprocessor - + Args: engine_path: Path to TensorRT engine file detect_resolution: Resolution for pose detection (should match engine input) @@ -241,18 +200,18 @@ def __init__(self, "TensorRT and polygraphy libraries are required for TensorRT pose preprocessing. " "Install them with: pip install tensorrt polygraphy" ) - + super().__init__( engine_path=engine_path, detect_resolution=detect_resolution, image_resolution=image_resolution, **kwargs ) - + self._engine = None self._device = "cuda" if torch.cuda.is_available() else "cpu" self._is_cuda_available = torch.cuda.is_available() - + @property def engine(self): """Lazy loading of the TensorRT engine""" @@ -263,56 +222,65 @@ def engine(self): "engine_path is required for TensorRT pose preprocessing. " "Please provide it in the preprocessor_params config." 
) - + if not os.path.exists(engine_path): raise FileNotFoundError(f"TensorRT engine not found: {engine_path}") - + self._engine = TensorRTEngine(engine_path) self._engine.load() self._engine.activate() self._engine.allocate_buffers() - + return self._engine - + def _process_core(self, image: Image.Image) -> Image.Image: """ Apply TensorRT pose estimation to the input image """ detect_resolution = self.params.get('detect_resolution', 640) - + image_tensor = torch.from_numpy(np.array(image)).float() / 255.0 image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0) - + image_resized = F.interpolate( - image_tensor, - size=(detect_resolution, detect_resolution), - mode='bilinear', + image_tensor, + size=(detect_resolution, detect_resolution), + mode='bilinear', align_corners=False ) - + image_resized_uint8 = (image_resized * 255.0).type(torch.uint8) - + if self._is_cuda_available: image_resized_uint8 = image_resized_uint8.cuda() - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized_uint8}, cuda_stream) - - predictions = [result[key].cpu().numpy() for key in result.keys() if key != 'input'] - + + predictions = [result[key].cpu().numpy() for key in result.keys() if key != "input"] + + keypoint_threshold = float(self.params.get("keypoint_threshold", 0.5)) + joint_thickness = int(self.params.get("joint_thickness", 10)) + keypoint_radius = int(self.params.get("keypoint_radius", 10)) + try: - pose_image = show_predictions_from_batch_format(predictions) + pose_image = show_predictions_from_batch_format( + predictions, + keypoint_threshold=keypoint_threshold, + joint_thickness=joint_thickness, + keypoint_radius=keypoint_radius, + ) except Exception: # Fallback to black image on error pose_image = np.zeros((detect_resolution, detect_resolution, 3)) - + pose_image = pose_image.clip(0, 255).astype(np.uint8) pose_image = cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB) - + result = Image.fromarray(pose_image) - + return result - + 
def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: """ Process tensor directly on GPU to avoid CPU transfers @@ -321,31 +289,40 @@ def _process_tensor_core(self, image_tensor: torch.Tensor) -> torch.Tensor: image_tensor = image_tensor.unsqueeze(0) if not image_tensor.is_cuda: image_tensor = image_tensor.cuda() - + detect_resolution = self.params.get('detect_resolution', 640) - + image_resized = torch.nn.functional.interpolate( - image_tensor, size=(detect_resolution, detect_resolution), + image_tensor, size=(detect_resolution, detect_resolution), mode='bilinear', align_corners=False ) - + image_resized_uint8 = (image_resized * 255.0).type(torch.uint8) - + cuda_stream = torch.cuda.current_stream().cuda_stream result = self.engine.infer({"input": image_resized_uint8}, cuda_stream) - - predictions = [result[key].cpu().numpy() for key in result.keys() if key != 'input'] - + + predictions = [result[key].cpu().numpy() for key in result.keys() if key != "input"] + + keypoint_threshold = float(self.params.get("keypoint_threshold", 0.5)) + joint_thickness = int(self.params.get("joint_thickness", 10)) + keypoint_radius = int(self.params.get("keypoint_radius", 10)) + try: - pose_image = show_predictions_from_batch_format(predictions) + pose_image = show_predictions_from_batch_format( + predictions, + keypoint_threshold=keypoint_threshold, + joint_thickness=joint_thickness, + keypoint_radius=keypoint_radius, + ) pose_image = pose_image.clip(0, 255).astype(np.uint8) pose_image = cv2.cvtColor(pose_image, cv2.COLOR_BGR2RGB) - + pose_tensor = torch.from_numpy(pose_image).float() / 255.0 pose_tensor = pose_tensor.permute(2, 0, 1).unsqueeze(0).cuda() - + except Exception: # Fallback to black tensor on error pose_tensor = torch.zeros(1, 3, detect_resolution, detect_resolution).cuda() - - return pose_tensor \ No newline at end of file + + return pose_tensor diff --git a/src/streamdiffusion/preprocessing/processors/scribble_tensorrt.py 
# src/streamdiffusion/preprocessing/processors/scribble_tensorrt.py
"""
Scribble TensorRT preprocessor — GPU-native scribble edge maps via TRT.

Reuses the HED TRT engine (no second build needed). Overrides _postprocess
to apply a tensor-only NMS + binarization that approximates the scribble=True
post-processing from controlnet_aux HEDdetector.__call__:

    1. Smooth the normalized edge map (3x3 average pool)
    2. NMS — keep only pixels that are the maximum of their 3x3 neighbourhood
    3. Threshold (``scribble_threshold`` parameter) → binary edge mask
"""

import logging

import torch
import torch.nn.functional as F

from .category_params import EDGE_SMOOTHNESS_PARAM, apply_edge_smoothness
from .hed_tensorrt import HEDTensorrtPreprocessor
from .trt_base import _first_output


logger = logging.getLogger(__name__)

# Single source of truth for the binarization threshold default. The UI
# metadata and the runtime fallback in ``_postprocess`` must agree; otherwise a
# config that omits ``scribble_threshold`` silently runs with a value outside
# the advertised [0.0, 0.05] range.
_DEFAULT_SCRIBBLE_THRESHOLD = 0.01


# ---------------------------------------------------------------------------
# GPU scribble NMS helper
# ---------------------------------------------------------------------------


def _scribble_nms_gpu(edge_map: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """
    Approximate tensor version of controlnet_aux nms() used by scribble=True mode.

    Performs:
      1. Light smoothing (3×3 average pool — avoids kornia dependency)
      2. Local-max suppression: a pixel survives only if it equals the maximum
         of its 3×3 neighbourhood (isotropic approximation of directional NMS)
      3. Threshold at `threshold`

    Args:
        edge_map: (H, W) float tensor, expected in [0, 1]. Runs on whatever
            device the tensor lives on (CUDA in the TRT pipeline).
        threshold: binarization threshold applied to the smoothed, thinned map
            (default 0.5, kept for backward compatibility of this helper; the
            preprocessor passes its own ``scribble_threshold``).

    Returns:
        (H, W) float32 tensor containing only 0.0 and 1.0, on the input's device.
    """
    x = edge_map.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)

    # Step 1: smooth
    x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)

    # Step 2: NMS — the reference implementation scans 4 directions
    # (horizontal, vertical, two diagonals). We approximate with an isotropic
    # 3×3 max-pool (good enough for thin-line extraction).
    x_max = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
    is_max = (x == x_max).float()
    thinned = x * is_max

    # Step 3: binarize. Index the two dims added above explicitly instead of
    # calling squeeze(), which would also drop a spatial dim of size 1.
    return (thinned[0, 0] > threshold).float()


# ---------------------------------------------------------------------------
# Preprocessor
# ---------------------------------------------------------------------------


class ScribbleTensorrtPreprocessor(HEDTensorrtPreprocessor):
    """
    Scribble edge maps via TRT — reuses the HED engine, overrides postprocess.

    The 'scribble' mode in controlnet_aux HEDdetector runs the same HED
    network but adds an NMS + binarization step. Here we approximate that
    step with tensor operations, so post-processing needs no CPU round-trip.

    No second engine build is needed: engine_filename points at hed.engine.
    """

    # Deliberately points at the HED engine — no separate build
    engine_filename = "hed.engine"
    onnx_filename = "hed.onnx"  # kept consistent; export is never re-run if engine exists
    default_detect_resolution = 512

    @classmethod
    def get_preprocessor_metadata(cls):
        """Return the UI/registry metadata dict (display name, parameters, use cases)."""
        return {
            "display_name": "Scribble Edge Detection (TensorRT)",
            "description": (
                "GPU-native scribble-style edge maps. Uses the HED TRT engine with "
                "GPU NMS + binarization post-processing (no CPU round-trips). "
                "Compatible with scribble ControlNets."
            ),
            "parameters": {
                "scribble_threshold": {
                    "type": "float",
                    # was 0.5 — post-NMS ridge values live near zero (~0.005–0.05)
                    "default": _DEFAULT_SCRIBBLE_THRESHOLD,
                    # was [0.0, 1.0] — spreads useful control across full travel
                    "range": [0.0, 0.05],
                    "description": (
                        "Binarization threshold for scribble edge NMS. Operates on the post-NMS ridge map "
                        "whose values are small (~0.005–0.05); lower keeps more edges."
                    ),
                },
                **EDGE_SMOOTHNESS_PARAM,
            },
            "use_cases": [
                "Scribble ControlNet conditioning",
                "Sketch-style edge maps (real-time)",
            ],
        }

    def _postprocess(self, engine_outputs: dict) -> torch.Tensor:
        """
        Apply scribble NMS + threshold to the HED output, return 3-channel CHW.

        Args:
            engine_outputs: mapping of engine output tensors; the tensor picked
                by ``_first_output`` is used. Supported shapes are
                (1, 1, H, W), (1, H, W) and (H, W).

        Returns:
            (3, H, W) float tensor in {0.0, 1.0} (binary scribble map).

        Raises:
            ValueError: if the selected output cannot be reduced to a single
                (H, W) map (e.g. batch size > 1 or more than one channel).
        """
        out = _first_output(engine_outputs).float()

        if out.dim() == 4:
            out = out.squeeze(1)
        if out.dim() == 3:
            out = out.squeeze(0)  # (H, W)
        if out.dim() != 2:
            # Fail with a clear message instead of an opaque pooling error below.
            raise ValueError(
                "ScribbleTensorrtPreprocessor expects a single-image, single-channel "
                f"edge map; got tensor of shape {tuple(out.shape)} after squeezing"
            )

        # Normalize to [0, 1] before NMS (skipped for a flat map to avoid 0/0)
        v_min, v_max = out.min(), out.max()
        if v_max > v_min:
            out = (out - v_min) / (v_max - v_min)
        out = out.clamp(0.0, 1.0)

        # Optional smoothness pre-blur (category-standard edge param) applied before
        # NMS so that increasing smoothness suppresses fine texture while preserving
        # the structural ridges that NMS retains.
        smoothness = float(self.params.get("smoothness", 0.0))
        if smoothness > 0.0:
            out = apply_edge_smoothness(out, smoothness)  # (H, W) in, (H, W) out

        # Fall back to the same default the metadata advertises, so configs that
        # omit the key behave like the UI default.
        threshold = float(self.params.get("scribble_threshold", _DEFAULT_SCRIBBLE_THRESHOLD))
        scribble = _scribble_nms_gpu(out, threshold=threshold)  # (H, W)

        # Expand to 3-channel RGB
        return scribble.unsqueeze(0).repeat(3, 1, 1)  # (3, H, W)