From 74fefa2cb3a839a771eafb718acd7c89b3b55cfc Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 4 Jun 2026 08:19:26 -0400 Subject: [PATCH 01/13] chore: align V2V config defaults with reference fork (cache_interval=4, similar-filter on, do_add_noise true) --- configs/td_config.yaml.example | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/configs/td_config.yaml.example b/configs/td_config.yaml.example index 48a620a2..9b48b871 100644 --- a/configs/td_config.yaml.example +++ b/configs/td_config.yaml.example @@ -30,7 +30,7 @@ use_denoising_batch: true use_tiny_vae: true acceleration: "tensorrt" cfg_type: "self" -do_add_noise: false +do_add_noise: true # matches reference fork default; adds re-noising term in inter-step latent buffer rebuild warmup: 10 use_safety_checker: false skip_diffusion: false @@ -61,14 +61,16 @@ scheduler: "lcm" sampler: "normal" # StreamV2V Cached Attention (Cattenable enables, Cattmaxframes/Cattinterval tune) +# cache_interval=4: hold a stable 4-frame anchor (matches reference fork default, interval=1 just caches previous frame) use_cached_attn: true cache_maxframes: 2 -cache_interval: 1 +cache_interval: 4 # Image filtering (similar frame skip) -enable_similar_image_filter: false -similar_image_filter_threshold: 0.99 -similar_image_filter_max_skip_frame: 1 +# enable_similar_image_filter=true: hold previous output on near-duplicate frames → less flicker (matches reference fork) +enable_similar_image_filter: true +similar_image_filter_threshold: 0.9996 +similar_image_filter_max_skip_frame: 10 # HuggingFace cache directory (for model downloads); leave empty to use default hf_cache: "" From 55de6c4a5d255791a9d3cc3d1856e9971448b4fe Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 4 Jun 2026 08:19:36 -0400 Subject: [PATCH 02/13] fix: preserve stock_noise across seed blends; sync _init_noise_rotated (eliminates per-frame RCFG cold-restart) --- src/streamdiffusion/stream_parameter_updater.py | 14 ++++++++++++-- 1 file changed, 
12 insertions(+), 2 deletions(-) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 8f81b6f8..95ee095c 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -681,9 +681,19 @@ def _apply_seed_blending(self, interpolation_method: Literal["linear", "slerp"]) if current_magnitude > 1e-8: # Avoid division by zero combined_noise = combined_noise * (original_magnitude / current_magnitude) - # Update stream noise + # Update stream noise. + # IMPORTANT: do NOT zero stock_noise here. Resetting it destroys the RCFG residual + # continuity established over previous frames, causing a cold-restart artifact on every + # per-frame seed blend. The reference fork (main_sdtd.py:4058) preserves stock_noise + # across reseeds intentionally. Only init_noise is replaced; stock_noise evolves from + # whatever the scheduler accumulated, which produces the smooth, coherent evolution. self.stream.init_noise = combined_noise - self.stream.stock_noise = torch.zeros_like(self.stream.init_noise) + + # Keep pre-computed rotation in sync with the new init_noise (same as _update_seed:728-731). 
+ if self.stream._init_noise_rotated is not None: + self.stream._init_noise_rotated = torch.cat( + [self.stream.init_noise[1:], self.stream.init_noise[0:1]], dim=0 + ) def _slerp_noise(self, noise1: torch.Tensor, noise2: torch.Tensor, t: float) -> torch.Tensor: """Spherical linear interpolation between two noise tensors.""" From 1d3a7c8be887c68aa99340965d9b9c2ec160fbab Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 4 Jun 2026 08:19:43 -0400 Subject: [PATCH 03/13] fix: init stock_noise from init_noise.clone() for coherent RCFG warm-start (matches reference fork) --- src/streamdiffusion/pipeline.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 55922e69..b40da159 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -501,7 +501,10 @@ def prepare( generator=generator, ).to(device=self.device, dtype=self.dtype) - self.stock_noise = torch.zeros_like(self.init_noise) + # Clone init_noise rather than zeros so stock_noise starts coherent when CFG is active. + # Initialising with zeros forces a cold-restart warm-up of the RCFG residual evolution; + # the reference fork (pipeline_td.py:1103-1105) uses clone() to avoid that. + self.stock_noise = self.init_noise.clone() # Handle scheduler-specific scaling calculations c_skip_list = [] From edb583488f4fe72f67023cf1a6c3ce4451ee97ae Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 5 Jun 2026 13:16:49 -0400 Subject: [PATCH 04/13] fix: refresh _alpha_next/_beta_next on t_index runtime update; warn on do_add_noise=False bleed regime _update_timestep_calculations (stream_parameter_updater.py) now keeps the pre-computed shifted tensors _alpha_next, _beta_next, _init_noise_rotated in sync after a value-only t_index_list update. 
Previously they were only built in prepare() and the error-fallback _refresh_derived_tensors(), causing stale RCFG-self stock_noise rotation at guidance > 1.0 when the user changed Tindexblockstep at runtime (F2 fix). Also adds a logged warning when do_add_noise=False is combined with use_denoising_batch and a high-noise inter-step timestep (beta_sqrt > 0.75). With do_add_noise=False the inter-step x_t_latent_buffer carries zero noise content (alpha_sqrt * x0_pred only), causing the UNet to ghost-bleed previous frames once beta_sqrt crosses ~0.78 (empirically at t_index_list[1] < 30 in a 50-step LCM schedule). Root cause: do_add_noise=False violates the on-manifold assumption; reference fork default is True. Diagnosis confirmed via beta_curve.py (tmp_debug/, not committed): beta_sqrt ranges from 0.61 at idx=36 to 0.78 at idx=29, matching the observed threshold. 6 CPU-only regression tests added (tests/unit/test_derived_tensor_sync.py). --- .../stream_parameter_updater.py | 51 ++++ tests/unit/test_derived_tensor_sync.py | 223 ++++++++++++++++++ 2 files changed, 274 insertions(+) create mode 100644 tests/unit/test_derived_tensor_sync.py diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 95ee095c..1f991514 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -840,6 +840,57 @@ def _update_timestep_calculations(self) -> None: dim=0, ) + # F2: Keep pre-computed shifted tensors in sync with the new alpha/beta values. + # _alpha_next / _beta_next / _init_noise_rotated are built only in prepare() + # (pipeline.py:595-605) and the error-fallback _refresh_derived_tensors(). + # Without this sync they go stale when t_index_list is updated at runtime, + # causing incorrect stock_noise rotation at guidance > 1.0 (RCFG-self path, + # pipeline.py:979-984). 
_init_noise_rotated is a rotation of init_noise which + # is unchanged by a t_index value-only update, so we re-derive from the live tensor + # rather than re-sampling (mirrors the _update_seed precedent at :749-753). + if ( + self.stream.use_denoising_batch + and (self.stream.cfg_type == "self" or self.stream.cfg_type == "initialize") + and self.stream._alpha_next is not None + ): + self.stream._alpha_next = torch.cat( + [self.stream.alpha_prod_t_sqrt[1:], torch.ones_like(self.stream.alpha_prod_t_sqrt[0:1])], + dim=0, + ) + self.stream._beta_next = torch.cat( + [self.stream.beta_prod_t_sqrt[1:], torch.ones_like(self.stream.beta_prod_t_sqrt[0:1])], + dim=0, + ) + self.stream._init_noise_rotated = torch.cat( + [self.stream.init_noise[1:], self.stream.init_noise[0:1]], dim=0 + ) + + # Warn about known-bad do_add_noise=False regime for multi-step denoising batches. + # With do_add_noise=False, inter-step x_t_latent_buffer lacks noise content: + # buffer = alpha_sqrt[1:] * x0_pred (pipeline.py:1082) + # vs. the expected: alpha_sqrt * x0 + beta_sqrt * epsilon + # When beta_sqrt at any inter-step timestep is large (high-noise regime), the UNet + # mis-interprets the clean buffer, causing ghost bleed from previous frames. + # Threshold 0.75 matches the empirically observed perceptual onset (~t_index 30 in + # a 50-step LCM schedule where beta_sqrt crosses 0.78). + if ( + self.stream.use_denoising_batch + and not self.stream.do_add_noise + and len(self.stream.t_list) > 1 + ): + inter_step_betas = beta_prod_t_sqrt[1:, 0, 0, 0] # per-step, before repeat_interleave + max_beta = inter_step_betas.max().item() + _BLEED_THRESHOLD = 0.75 + if max_beta > _BLEED_THRESHOLD: + logger.warning( + "do_add_noise=False + use_denoising_batch: inter-step beta_sqrt=%.3f " + "(t_index=%s) exceeds %.2f. 
Previous-frame ghost bleed likely — " + "consider enabling do_add_noise (reference fork default).", + max_beta, + self.stream.t_list[1:], + _BLEED_THRESHOLD, + ) + def _update_timestep_values_only(self, t_index_list: List[int]) -> None: """Update only timestep-dependent values when t_index_list values change but length stays same. This preserves the working branch behavior for value-only changes.""" diff --git a/tests/unit/test_derived_tensor_sync.py b/tests/unit/test_derived_tensor_sync.py new file mode 100644 index 00000000..c8957584 --- /dev/null +++ b/tests/unit/test_derived_tensor_sync.py @@ -0,0 +1,223 @@ +""" +Regression tests for F2: _update_timestep_calculations must refresh pre-computed +shifted tensors (_alpha_next, _beta_next, _init_noise_rotated) after a value-only +t_index_list update. + +Root cause guarded: _alpha_next / _beta_next are built only in prepare() and the +error-fallback _refresh_derived_tensors(); without F2 they stay stale when the user +changes Tindexblockstep at runtime, causing incorrect stock_noise rotation at +guidance > 1.0 (RCFG-self path, pipeline.py:979-984). + +CPU-only, model-free. Constructs a minimal stream shell via object.__new__ and +wires only the attributes StreamParameterUpdater._update_timestep_calculations reads. 
+""" + +import torch +import pytest + +from streamdiffusion.stream_parameter_updater import StreamParameterUpdater + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +def _make_mock_lcm_scheduler(num_steps=50, device="cpu"): + """Minimal scheduler shell with alphas_cumprod and get_scalings_for_boundary_condition_discrete.""" + import types + from diffusers import LCMScheduler + + sched = object.__new__(LCMScheduler) + # Use a real alphas_cumprod from a linear beta schedule (avoids needing model weights) + betas = torch.linspace(0.0001, 0.02, 1000) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + sched.alphas_cumprod = alphas_cumprod + + # Provide get_scalings_for_boundary_condition_discrete for c_skip/c_out + def _scalings(timestep): + t = timestep if isinstance(timestep, torch.Tensor) else torch.tensor(float(timestep)) + sigma = ((1 - alphas_cumprod[int(t.item())]) / alphas_cumprod[int(t.item())]).sqrt() + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0).sqrt() + return torch.tensor(c_skip), torch.tensor(c_out) + + sched.get_scalings_for_boundary_condition_discrete = _scalings + return sched + + +def _make_stream_shell(t_index_list, device="cpu", dtype=torch.float32, + frame_bff_size=1, use_denoising_batch=True, + cfg_type="self", do_add_noise=True): + """Minimal StreamDiffusion pipeline shell for updater testing.""" + import types + from diffusers import LCMScheduler + + stream = types.SimpleNamespace() + stream.device = device + stream.dtype = dtype + stream.frame_bff_size = frame_bff_size + stream.use_denoising_batch = use_denoising_batch + stream.cfg_type = cfg_type + stream.do_add_noise = do_add_noise + stream.batch_size = 1 + stream.latent_height = 64 + stream.latent_width = 64 + stream.generator = torch.Generator(device=device) + + num_steps = 50 + stream.scheduler =
_make_mock_lcm_scheduler(num_steps) + # Build timesteps as a linear space (same as LCM set_timesteps for num_steps=50) + betas = torch.linspace(0.0001, 0.02, 1000) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + timesteps_raw = torch.linspace(999, 19, num_steps).long() + stream.timesteps = timesteps_raw + + stream.t_list = t_index_list + stream.sub_timesteps = [int(timesteps_raw[i]) for i in t_index_list] + stream.sub_timesteps_tensor = torch.tensor(stream.sub_timesteps, dtype=torch.long) + + # Build initial alpha/beta matching the logic in _update_timestep_calculations + a_list, b_list = [], [] + for t in stream.sub_timesteps: + a_list.append(alphas_cumprod[t].sqrt()) + b_list.append((1 - alphas_cumprod[t]).sqrt()) + alpha_raw = torch.stack(a_list).view(len(t_index_list), 1, 1, 1).to(dtype=dtype) + beta_raw = torch.stack(b_list).view(len(t_index_list), 1, 1, 1).to(dtype=dtype) + stream.alpha_prod_t_sqrt = alpha_raw.repeat_interleave(frame_bff_size, dim=0) + stream.beta_prod_t_sqrt = beta_raw.repeat_interleave(frame_bff_size, dim=0) + + stream.c_skip = torch.ones(len(t_index_list) * frame_bff_size, 1, 1, 1, dtype=dtype) + stream.c_out = torch.ones(len(t_index_list) * frame_bff_size, 1, 1, 1, dtype=dtype) + + # init_noise / stock_noise + h, w = stream.latent_height, stream.latent_width + stream.init_noise = torch.randn( + (len(t_index_list) * frame_bff_size, 4, h, w), dtype=dtype, generator=stream.generator + ) + stream.stock_noise = stream.init_noise.clone() + + # _alpha_next / _beta_next / _init_noise_rotated — only set when denoising batch + RCFG-self + if use_denoising_batch and (cfg_type == "self" or cfg_type == "initialize"): + stream._alpha_next = torch.cat( + [stream.alpha_prod_t_sqrt[1:], torch.ones_like(stream.alpha_prod_t_sqrt[0:1])], dim=0 + ) + stream._beta_next = torch.cat( + [stream.beta_prod_t_sqrt[1:], torch.ones_like(stream.beta_prod_t_sqrt[0:1])], dim=0 + ) + stream._init_noise_rotated = torch.cat( + 
[stream.init_noise[1:], stream.init_noise[0:1]], dim=0 + ) + else: + stream._alpha_next = None + stream._beta_next = None + stream._init_noise_rotated = None + + return stream + + +def _make_updater(stream): + """Construct StreamParameterUpdater without calling __init__ (avoids deps).""" + updater = object.__new__(StreamParameterUpdater) + updater.stream = stream + updater._lock = __import__("threading").Lock() + return updater + + +# --------------------------------------------------------------------------- +# tests +# --------------------------------------------------------------------------- + +class TestDerivedTensorSync: + """F2: _update_timestep_calculations must keep _alpha_next/_beta_next in sync.""" + + def test_alpha_next_updated_after_t_index_change(self): + """After a same-length value-only t_index change, _alpha_next must equal + cat([alpha_prod_t_sqrt[1:], ones]) — not the stale pre-change value.""" + stream = _make_stream_shell([14, 36]) + updater = _make_updater(stream) + + # Capture old value + old_alpha_next = stream._alpha_next.clone() + + # Change t_index values (same length, different values) + updater._update_timestep_values_only([14, 28]) + + expected = torch.cat( + [stream.alpha_prod_t_sqrt[1:], torch.ones_like(stream.alpha_prod_t_sqrt[0:1])], dim=0 + ) + assert not torch.allclose(old_alpha_next, stream._alpha_next), \ + "_alpha_next was not updated (stale)" + assert torch.allclose(stream._alpha_next, expected, atol=1e-5), \ + f"_alpha_next mismatch: max_diff={( stream._alpha_next - expected).abs().max().item():.6f}" + + def test_beta_next_updated_after_t_index_change(self): + """After a same-length value-only t_index change, _beta_next must equal + cat([beta_prod_t_sqrt[1:], ones]).""" + stream = _make_stream_shell([14, 36]) + updater = _make_updater(stream) + old_beta_next = stream._beta_next.clone() + + updater._update_timestep_values_only([14, 28]) + + expected = torch.cat( + [stream.beta_prod_t_sqrt[1:], 
torch.ones_like(stream.beta_prod_t_sqrt[0:1])], dim=0 + ) + assert not torch.allclose(old_beta_next, stream._beta_next), \ + "_beta_next was not updated (stale)" + assert torch.allclose(stream._beta_next, expected, atol=1e-5), \ + f"_beta_next mismatch: max_diff={(stream._beta_next - expected).abs().max().item():.6f}" + + def test_init_noise_rotated_stays_consistent_after_t_index_change(self): + """_init_noise_rotated must equal cat([init_noise[1:], init_noise[0:1]]) + after a value-only t_index update (init_noise itself is unchanged).""" + stream = _make_stream_shell([14, 36]) + updater = _make_updater(stream) + saved_init_noise = stream.init_noise.clone() + + updater._update_timestep_values_only([14, 28]) + + # init_noise should be unchanged + assert torch.allclose(stream.init_noise, saved_init_noise), \ + "init_noise was unexpectedly mutated by _update_timestep_values_only" + + expected_rotated = torch.cat( + [stream.init_noise[1:], stream.init_noise[0:1]], dim=0 + ) + assert torch.allclose(stream._init_noise_rotated, expected_rotated, atol=1e-6), \ + "_init_noise_rotated out of sync with init_noise after t_index update" + + def test_no_update_when_derived_tensors_not_initialized(self): + """When _alpha_next is None (non-batched or non-RCFG-self), updater must + leave it None — not attempt to update.""" + stream = _make_stream_shell([14, 36], use_denoising_batch=False, cfg_type="none") + updater = _make_updater(stream) + + assert stream._alpha_next is None + updater._update_timestep_values_only([14, 28]) + assert stream._alpha_next is None, "_alpha_next should remain None for non-RCFG-self config" + + def test_warn_on_do_add_noise_false_high_beta(self, caplog): + """When do_add_noise=False and inter-step beta_sqrt > 0.75, a warning must be logged.""" + import logging + stream = _make_stream_shell([14, 28], do_add_noise=False) + updater = _make_updater(stream) + + with caplog.at_level(logging.WARNING, logger="streamdiffusion.stream_parameter_updater"): + 
updater._update_timestep_values_only([14, 28]) + + assert any("do_add_noise=False" in r.message for r in caplog.records), \ + "Expected do_add_noise bleed-risk warning not emitted" + + def test_no_warn_when_do_add_noise_true(self, caplog): + """When do_add_noise=True, no bleed-risk warning should be logged.""" + import logging + stream = _make_stream_shell([14, 28], do_add_noise=True) + updater = _make_updater(stream) + + with caplog.at_level(logging.WARNING, logger="streamdiffusion.stream_parameter_updater"): + updater._update_timestep_values_only([14, 28]) + + bleed_warns = [r for r in caplog.records if "do_add_noise=False" in r.message] + assert not bleed_warns, f"Unexpected bleed-risk warning when do_add_noise=True: {bleed_warns}" From 30cc1dbbc8a2a24285209f775e7dbbb90978e417 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 4 Jun 2026 20:17:36 -0400 Subject: [PATCH 05/13] feat: add defensive OOM/size-mismatch fallback in pipeline __call__ with scheduler rebuild --- src/streamdiffusion/pipeline.py | 81 ++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index b40da159..1b2a69f8 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -596,6 +596,59 @@ def prepare( # Seed _unet_kwargs with the constant key so per-frame code only updates values self._unet_kwargs = {"return_dict": False} + def _refresh_derived_tensors(self) -> None: + """Re-create tensors derived from batch_size/scheduler state but NOT rebuilt by + StreamParameterUpdater._update_timestep_calculations. + + Called by the error-fallback handler in __call__ after + _param_updater._update_timestep_calculations() has already refreshed + sub_timesteps_tensor, c_skip/c_out, alpha/beta_prod_t_sqrt. 
+ """ + # init_noise + stock_noise — re-sampled so the fallback frame is coherent + self.init_noise = torch.randn( + (self.batch_size, 4, self.latent_height, self.latent_width), + generator=self.generator, + ).to(device=self.device, dtype=self.dtype) + self.stock_noise = self.init_noise.clone() + + # Pre-computed shifted tensors (depend on alpha/beta already refreshed above) + if self.use_denoising_batch and (self.cfg_type == "self" or self.cfg_type == "initialize"): + self._alpha_next = torch.cat( + [self.alpha_prod_t_sqrt[1:], torch.ones_like(self.alpha_prod_t_sqrt[0:1])], dim=0 + ) + self._beta_next = torch.cat( + [self.beta_prod_t_sqrt[1:], torch.ones_like(self.beta_prod_t_sqrt[0:1])], dim=0 + ) + self._init_noise_rotated = torch.cat([self.init_noise[1:], self.init_noise[0:1]], dim=0) + else: + self._alpha_next = None + self._beta_next = None + self._init_noise_rotated = None + + # Pre-allocated latent scratch buffers (batch-size-dependent) + if self.denoising_steps_num > 1: + self._combined_latent_buf = torch.empty( + (self.batch_size, 4, self.latent_height, self.latent_width), + dtype=self.dtype, + device=self.device, + ) + else: + self._combined_latent_buf = None + + if self.guidance_scale > 1.0 and (self.cfg_type == "initialize" or self.cfg_type == "full"): + cfg_batch = (1 + self.batch_size) if self.cfg_type == "initialize" else (2 * self.batch_size) + self._cfg_latent_buf = torch.empty( + (cfg_batch, 4, self.latent_height, self.latent_width), + dtype=self.dtype, + device=self.device, + ) + self._cfg_t_buf = torch.empty( + cfg_batch, dtype=self.sub_timesteps_tensor.dtype, device=self.device + ) + else: + self._cfg_latent_buf = None + self._cfg_t_buf = None + def _get_scheduler_scalings(self, timestep): """Get LCM/TCD-specific scaling factors for boundary conditions.""" if isinstance(self.scheduler, LCMScheduler): @@ -1062,7 +1115,33 @@ def __call__(self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None) - device=self.device, dtype=self.dtype ) - 
x_0_pred_out = self.predict_x0_batch(x_t_latent) + try: + x_0_pred_out = self.predict_x0_batch(x_t_latent) + except RuntimeError as _e: + _msg = str(_e).lower() + if "out of memory" in _msg: + logger.error("StreamDiffusion.__call__: OOM — clearing cache and re-raising: %s", _e) + torch.cuda.empty_cache() + raise + if "expanded size" in _msg and "must match" in _msg: + logger.error( + "StreamDiffusion.__call__: tensor size mismatch — attempting scheduler rebuild: %s", _e + ) + try: + self._param_updater._update_timestep_calculations() + self._refresh_derived_tensors() + except Exception as _fix_err: + logger.error("StreamDiffusion.__call__: rebuild failed: %s", _fix_err) + torch.cuda.empty_cache() + logger.warning("StreamDiffusion.__call__: returning safe fallback frame") + return self.decode_image( + torch.randn( + (1, 4, self.latent_height, self.latent_width), + device=self.device, + dtype=self.dtype, + ) + ) + raise # LATENT POSTPROCESSING HOOKS: After diffusion, before VAE decoding x_0_pred_out = self._apply_latent_postprocessing_hooks(x_0_pred_out) From ae83a816937abacc09af90fdfb07aca985069937 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 5 Jun 2026 23:16:34 -0400 Subject: [PATCH 06/13] fix: log prompt_list updates in update_stream_params --- src/streamdiffusion/stream_parameter_updater.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 1f991514..b0712b23 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -321,6 +321,10 @@ def update_stream_params( # Handle prompt blending if prompt_list is provided if prompt_list is not None: + first = prompt_list[0][0][:60] if prompt_list else "" + logger.info( + f"update_stream_params: prompt_list -> {len(prompt_list)} prompt(s): [{first!r}{'...' 
if prompt_list and len(prompt_list[0][0]) > 60 else ''}]" + ) self._update_blended_prompts( prompt_list=prompt_list, negative_prompt=negative_prompt or self._current_negative_prompt, From 258ec3b3432fc54725844693e784aee40b114508 Mon Sep 17 00:00:00 2001 From: Alex Date: Fri, 5 Jun 2026 23:35:42 -0400 Subject: [PATCH 07/13] perf: keep text encoders on GPU + dedup unchanged prompt_list updates --- src/streamdiffusion/wrapper.py | 37 +++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 36abe02b..81cfc173 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -17,6 +17,12 @@ logger = logging.getLogger(__name__) +# Text-encoder CPU offload frees ~1.6 GB VRAM but each prompt update pays a +# CPU<->GPU round-trip plus torch.cuda.empty_cache() — a measurable stall +# mid-stream on high-VRAM GPUs. Default off (encoders stay resident on GPU); +# set SD_TEXT_ENCODER_OFFLOAD=1 to restore offloading on VRAM-constrained GPUs. +_TEXT_ENCODER_OFFLOAD: bool = os.environ.get("SD_TEXT_ENCODER_OFFLOAD", "0") == "1" + torch.set_grad_enabled(False) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True @@ -502,7 +508,13 @@ def _offload_text_encoders(self) -> None: Called automatically after initial prepare() when using TRT acceleration. Text encoders are reloaded to GPU before each prompt re-encoding call. + + No-op when SD_TEXT_ENCODER_OFFLOAD is not set (default). High-VRAM GPUs + (RTX 4090, A100…) benefit from keeping encoders resident to avoid the + CPU<->GPU transfer + empty_cache() stall on every prompt update.
""" + if not _TEXT_ENCODER_OFFLOAD: + return pipe = self.stream.pipe if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None: if next(pipe.text_encoder.parameters(), None) is not None: @@ -514,7 +526,13 @@ def _offload_text_encoders(self) -> None: logger.debug("[VRAM] Text encoders offloaded to CPU") def _reload_text_encoders(self) -> None: - """Move text encoders back to GPU before prompt re-encoding.""" + """Move text encoders back to GPU before prompt re-encoding. + + No-op when SD_TEXT_ENCODER_OFFLOAD is not set (default) because + encoders were never offloaded. + """ + if not _TEXT_ENCODER_OFFLOAD: + return pipe = self.stream.pipe if hasattr(pipe, "text_encoder") and pipe.text_encoder is not None: pipe.text_encoder = pipe.text_encoder.to(self.device) @@ -669,6 +687,23 @@ def update_stream_params( safety_checker_threshold : Optional[float] The threshold for the safety checker. """ + # Skip re-encoding if the incoming prompt_list is identical to the cached one. + # OSC delivers list-of-lists from JSON; normalise to (str, float) tuples before + # comparing so type mismatches don't cause spurious cache misses. + if prompt_list is not None: + _normalized = [(str(p), float(w)) for p, w in prompt_list] + _current = self.stream._param_updater.get_current_prompts() + _neg_unchanged = ( + negative_prompt is None + or negative_prompt == self.stream._param_updater._current_negative_prompt + ) + if _normalized == _current and _neg_unchanged: + logger.info( + "update_stream_params: prompt_list unchanged (%d prompt(s)) -- skipping re-encode", + len(_normalized), + ) + prompt_list = None + # Reload text encoders to GPU if a new prompt needs encoding. 
needs_encoding = prompt_list is not None or negative_prompt is not None if needs_encoding: From a1a11ed56bc6fc59f872d9c713917c532e374990 Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 6 Jun 2026 01:26:41 -0400 Subject: [PATCH 08/13] fix: apply OSC param updates on render thread to stop weight-drag glitches --- tests/unit/test_td_pending_params.py | 287 +++++++++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 tests/unit/test_td_pending_params.py diff --git a/tests/unit/test_td_pending_params.py b/tests/unit/test_td_pending_params.py new file mode 100644 index 00000000..4ea0e43b --- /dev/null +++ b/tests/unit/test_td_pending_params.py @@ -0,0 +1,287 @@ +""" +Unit tests for the render-thread pending-params drain in TouchDesignerManager. + +Covers the fix for weight-drag glitches (round 5): OSC/batch thread must +deposit into _pending_params, not call _apply_parameters directly, while the +render loop is alive. + +These tests replicate the exact logic from td_manager.py without loading the +real module (which requires CUDA / TD dependencies). They serve as a +specification of the contract only; because td_manager.py is never imported, +the replica below must be re-synced by hand whenever td_manager.py changes. + +ASCII only -- no Unicode symbols (Windows cp1252 terminal compatibility). +""" + +import threading +import types +import unittest +from typing import Any, Dict, Optional + + +# --------------------------------------------------------------------------- +# Minimal faithful replica of the three methods under test. +# Copy-pasted from td_manager.py and frozen here as an executable spec. NOTE: a +# regression in td_manager.py itself will NOT fail these tests (no import).
+# --------------------------------------------------------------------------- + +class _FakeStream: + cfg_type = "none" + + +class _FakeWrapper: + """Records calls made to update_stream_params.""" + def __init__(self): + self.stream = _FakeStream() + self.calls: list = [] + + def update_stream_params(self, **kwargs): + self.calls.append(kwargs) + + +class _Manager: + """ + Minimal replica of TouchDesignerManager containing only the pending-params + contract: + + update_parameters -- public entrypoint called by OSC batch thread + _apply_parameters -- real apply (calls wrapper), called on render thread + _streaming_loop_drain_snippet -- helper that replicates the loop drain + """ + + VALID_PARAMS = { + 'num_inference_steps', 'guidance_scale', 'delta', 't_index_list', 'seed', + 'prompt_list', 'negative_prompt', 'prompt_interpolation_method', + 'normalize_prompt_weights', 'seed_list', 'seed_interpolation_method', + 'normalize_seed_weights', 'controlnet_config', 'ipadapter_config', + 'image_preprocessing_config', 'image_postprocessing_config', + 'latent_preprocessing_config', 'latent_postprocessing_config', + 'use_safety_checker', 'safety_checker_threshold', + 'cache_maxframes', 'cache_interval', 'fi_strength', 'fi_threshold', + 'cn_cache_interval', + } + + def __init__(self): + self.streaming = False + self.stream_thread: Optional[threading.Thread] = None + self._pending_params: Dict[str, Any] = {} + self._pending_params_lock = threading.Lock() + self._randomize_seed_indices = [] + self.wrapper = _FakeWrapper() + + # --- Replica of td_manager.py _apply_parameters --- + def _apply_parameters(self, params: Dict[str, Any]) -> None: + import logging, random + filtered_params = {k: v for k, v in params.items() if k in self.VALID_PARAMS} + + if 'guidance_scale' in filtered_params: + cfg_type = getattr(self.wrapper.stream, 'cfg_type', None) + if cfg_type in ("full", "initialize") and filtered_params['guidance_scale'] <= 1.0: + filtered_params['guidance_scale'] = 1.2 + + if 
'seed_list' in filtered_params: + self._randomize_seed_indices = [] + new_seed_list = [] + for idx, (seed, weight) in enumerate(filtered_params['seed_list']): + if seed == -1: + self._randomize_seed_indices.append(idx) + seed = random.randint(0, 2**32 - 1) + new_seed_list.append((seed, weight)) + filtered_params['seed_list'] = new_seed_list + + self.wrapper.update_stream_params(**filtered_params) + + # --- Replica of td_manager.py update_parameters --- + def update_parameters(self, params: Dict[str, Any]) -> None: + render_alive = ( + self.streaming + and self.stream_thread is not None + and self.stream_thread.is_alive() + ) + if render_alive: + with self._pending_params_lock: + self._pending_params.update(params) + else: + self._apply_parameters(params) + + # --- Replica of the drain snippet at the top of _streaming_loop --- + def _drain(self): + with self._pending_params_lock: + pending, self._pending_params = self._pending_params, {} + if pending: + self._apply_parameters(pending) + + +def _make_mgr() -> _Manager: + return _Manager() + + +def _alive_thread() -> threading.Thread: + """Return a started, long-lived daemon thread (simulates render loop).""" + e = threading.Event() + t = threading.Thread(target=e.wait, daemon=True) + t.start() + t._stop_event = e # type: ignore[attr-defined] + return t + + +def _stop_thread(t: threading.Thread): + t._stop_event.set() # type: ignore[attr-defined] + t.join(timeout=2) + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + +class TestUpdateParametersDefer(unittest.TestCase): + """ + (a) update_parameters defers when streaming, applies directly when not. 
+ """ + + def test_applies_directly_when_not_streaming(self): + """Pre-start: params go straight to _apply_parameters -> wrapper.""" + mgr = _make_mgr() + mgr.streaming = False + mgr.stream_thread = None + + mgr.update_parameters({"guidance_scale": 1.5}) + + self.assertEqual(len(mgr.wrapper.calls), 1) + self.assertAlmostEqual(mgr.wrapper.calls[0]["guidance_scale"], 1.5) + self.assertEqual(mgr._pending_params, {}) + + def test_defers_when_streaming_alive(self): + """While streaming: params go into _pending_params, wrapper NOT called.""" + mgr = _make_mgr() + t = _alive_thread() + mgr.streaming = True + mgr.stream_thread = t + + try: + mgr.update_parameters({"guidance_scale": 2.0}) + self.assertEqual(len(mgr.wrapper.calls), 0, "wrapper must not be called yet") + self.assertEqual(mgr._pending_params.get("guidance_scale"), 2.0) + finally: + _stop_thread(t) + + def test_applies_directly_when_streaming_flag_set_but_thread_dead(self): + """streaming=True but dead thread -> apply directly (safety net).""" + mgr = _make_mgr() + dead = threading.Thread(target=lambda: None) + dead.start() + dead.join() + mgr.streaming = True + mgr.stream_thread = dead + + mgr.update_parameters({"guidance_scale": 3.0}) + + self.assertEqual(len(mgr.wrapper.calls), 1) + self.assertAlmostEqual(mgr.wrapper.calls[0]["guidance_scale"], 3.0) + + +class TestLatestWinsMerge(unittest.TestCase): + """ + (b) Multiple deferred updates merge with latest-wins per key. 
+ """ + + def test_latest_wins_same_key(self): + mgr = _make_mgr() + t = _alive_thread() + mgr.streaming = True + mgr.stream_thread = t + + try: + mgr.update_parameters({"guidance_scale": 1.0}) + mgr.update_parameters({"guidance_scale": 2.0}) + mgr.update_parameters({"guidance_scale": 3.0}) + + self.assertEqual(mgr._pending_params["guidance_scale"], 3.0) + finally: + _stop_thread(t) + + def test_different_keys_merged(self): + mgr = _make_mgr() + t = _alive_thread() + mgr.streaming = True + mgr.stream_thread = t + + try: + mgr.update_parameters({"guidance_scale": 1.5}) + mgr.update_parameters({"delta": 0.5}) + + self.assertEqual(mgr._pending_params["guidance_scale"], 1.5) + self.assertEqual(mgr._pending_params["delta"], 0.5) + finally: + _stop_thread(t) + + +class TestDrainClearsPending(unittest.TestCase): + """ + (c) After drain, _pending_params is empty and wrapper was called. + """ + + def test_drain_clears_and_applies(self): + mgr = _make_mgr() + mgr._pending_params = {"guidance_scale": 4.0} + + mgr._drain() + + self.assertEqual(mgr._pending_params, {}) + self.assertEqual(len(mgr.wrapper.calls), 1) + self.assertAlmostEqual(mgr.wrapper.calls[0]["guidance_scale"], 4.0) + + def test_drain_no_op_when_empty(self): + mgr = _make_mgr() + mgr._pending_params = {} + + mgr._drain() + + self.assertEqual(mgr._pending_params, {}) + self.assertEqual(len(mgr.wrapper.calls), 0) + + def test_drain_collapses_flood(self): + """Flood of same-key updates -> one wrapper call with the last value.""" + mgr = _make_mgr() + t = _alive_thread() + mgr.streaming = True + mgr.stream_thread = t + + try: + for i in range(20): + mgr.update_parameters({"guidance_scale": float(i)}) + finally: + _stop_thread(t) + mgr.streaming = False + + mgr._drain() + + self.assertEqual(len(mgr.wrapper.calls), 1, "only one wrapper call after draining a flooded queue") + self.assertAlmostEqual(mgr.wrapper.calls[0]["guidance_scale"], 19.0) + self.assertEqual(mgr._pending_params, {}) + + def 
test_drain_applies_multi_key_batch(self): + """Pending dict with multiple keys passes all to wrapper in one call.""" + mgr = _make_mgr() + mgr._pending_params = {"guidance_scale": 1.2, "delta": 0.8} + + mgr._drain() + + self.assertEqual(len(mgr.wrapper.calls), 1) + self.assertAlmostEqual(mgr.wrapper.calls[0]["guidance_scale"], 1.2) + self.assertAlmostEqual(mgr.wrapper.calls[0]["delta"], 0.8) + + def test_invalid_keys_filtered_out(self): + """_apply_parameters strips keys not in the whitelist.""" + mgr = _make_mgr() + mgr._pending_params = {"guidance_scale": 1.5, "nonexistent_key": "bad"} + + mgr._drain() + + self.assertEqual(len(mgr.wrapper.calls), 1) + self.assertNotIn("nonexistent_key", mgr.wrapper.calls[0]) + self.assertAlmostEqual(mgr.wrapper.calls[0]["guidance_scale"], 1.5) + + +if __name__ == "__main__": + unittest.main() From 9e578fa55a389347561a269591744ef888f03c68 Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 6 Jun 2026 08:54:47 -0400 Subject: [PATCH 09/13] refactor: remove dead update_prompt/seed_weights methods + fix wrapper docstring --- .../stream_parameter_updater.py | 50 ------------------- src/streamdiffusion/wrapper.py | 4 +- 2 files changed, 2 insertions(+), 52 deletions(-) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index b0712b23..44f88708 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -397,56 +397,6 @@ def update_stream_params( else: logger.info(f"update_stream_params: Cache maxframes set to {cache_maxframes}") - @torch.inference_mode() - def update_prompt_weights( - self, - prompt_weights: List[float], - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" - ) -> None: - """Update weights for current prompt list without re-encoding prompts.""" - if not self._current_prompt_list: - logger.warning("update_prompt_weights: Warning: No current prompt list to update weights for") - return - - if 
len(prompt_weights) != len(self._current_prompt_list): - logger.warning(f"update_prompt_weights: Warning: Weight count {len(prompt_weights)} doesn't match prompt count {len(self._current_prompt_list)}") - return - - # Update the current prompt list with new weights - updated_prompt_list = [] - for i, (prompt_text, _) in enumerate(self._current_prompt_list): - updated_prompt_list.append((prompt_text, prompt_weights[i])) - - self._current_prompt_list = updated_prompt_list - - # Recompute blended embeddings with new weights - self._apply_prompt_blending(prompt_interpolation_method) - - @torch.inference_mode() - def update_seed_weights( - self, - seed_weights: List[float], - interpolation_method: Literal["linear", "slerp"] = "linear" - ) -> None: - """Update weights for current seed list without regenerating noise.""" - if not self._current_seed_list: - logger.warning("update_seed_weights: Warning: No current seed list to update weights for") - return - - if len(seed_weights) != len(self._current_seed_list): - logger.warning(f"update_seed_weights: Warning: Weight count {len(seed_weights)} doesn't match seed count {len(self._current_seed_list)}") - return - - # Update the current seed list with new weights - updated_seed_list = [] - for i, (seed_value, _) in enumerate(self._current_seed_list): - updated_seed_list.append((seed_value, seed_weights[i])) - - self._current_seed_list = updated_seed_list - - # Recompute blended noise with new weights - self._apply_seed_blending(interpolation_method) - @torch.inference_mode() def _update_blended_prompts( self, diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 81cfc173..2219a8f3 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -68,8 +68,8 @@ class StreamDiffusionWrapper: ## Weight Management: - Prompt weights are normalized by default (sum to 1.0) unless normalize_prompt_weights=False - Seed weights are normalized by default (sum to 1.0) unless 
normalize_seed_weights=False - - Use update_prompt_weights([0.8, 0.2]) to change weights without re-encoding prompts - - Use update_seed_weights([0.3, 0.7]) to change weights without regenerating noise + - To change blend weights, pass the full prompt_list/seed_list to update_stream_params — + unchanged texts/seeds hit the embedding/noise cache, so only re-blending occurs (no re-encode) ## Cache Management: - Prompt embeddings and seed noise tensors are automatically cached for performance From 20823b6b699bedbb8ca86fd9553fcb87abdc151a Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 6 Jun 2026 13:23:26 -0400 Subject: [PATCH 10/13] fix: demote per-frame prompt_list logging to DEBUG during weight drags --- src/streamdiffusion/stream_parameter_updater.py | 7 ++++++- src/streamdiffusion/wrapper.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 44f88708..c248a128 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -322,7 +322,12 @@ def update_stream_params( # Handle prompt blending if prompt_list is provided if prompt_list is not None: first = prompt_list[0][0][:60] if prompt_list else "" - logger.info( + # Log at INFO only when the prompt *texts* change (real new prompt). + # Weight-only changes during a drag produce a different list each frame + # but don't warrant INFO noise — demote those to DEBUG. + _texts_changed = [str(p) for p, _ in prompt_list] != [p for p, _ in self._current_prompt_list] + _log = logger.info if _texts_changed else logger.debug + _log( f"update_stream_params: prompt_list -> {len(prompt_list)} prompt(s): [{first!r}{'...' 
if len(prompt_list[0][0]) > 60 else ''}]" ) self._update_blended_prompts( diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 2219a8f3..f19d39d4 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -698,7 +698,7 @@ def update_stream_params( or negative_prompt == self.stream._param_updater._current_negative_prompt ) if _normalized == _current and _neg_unchanged: - logger.info( + logger.debug( "update_stream_params: prompt_list unchanged (%d prompt(s)) -- skipping re-encode", len(_normalized), ) From 7749b91d87211c82d6d74c096a32769ad619d4ed Mon Sep 17 00:00:00 2001 From: Alex Date: Sun, 7 Jun 2026 06:15:41 -0400 Subject: [PATCH 11/13] fix: correct prompt-list log to show all prompts not just the first --- src/streamdiffusion/stream_parameter_updater.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index c248a128..6bba3b21 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -321,15 +321,13 @@ def update_stream_params( # Handle prompt blending if prompt_list is provided if prompt_list is not None: - first = prompt_list[0][0][:60] if prompt_list else "" # Log at INFO only when the prompt *texts* change (real new prompt). # Weight-only changes during a drag produce a different list each frame # but don't warrant INFO noise — demote those to DEBUG. _texts_changed = [str(p) for p, _ in prompt_list] != [p for p, _ in self._current_prompt_list] _log = logger.info if _texts_changed else logger.debug - _log( - f"update_stream_params: prompt_list -> {len(prompt_list)} prompt(s): [{first!r}{'...' 
if len(prompt_list[0][0]) > 60 else ''}]" - ) + _excerpts = [p[:40] + ("…" if len(p) > 40 else "") for p, _ in prompt_list] + _log(f"update_stream_params: prompt_list -> {len(prompt_list)} prompt(s): {_excerpts!r}") self._update_blended_prompts( prompt_list=prompt_list, negative_prompt=negative_prompt or self._current_negative_prompt, From 0f9d40a1416703772a899676630baca18eea3570 Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 6 Jun 2026 21:37:30 -0400 Subject: [PATCH 12/13] feat: add cosine_weighted prompt interpolation mode with N-way slerp --- src/streamdiffusion/config.py | 8 +- .../stream_parameter_updater.py | 108 ++++++- src/streamdiffusion/wrapper.py | 10 +- tests/unit/test_prompt_interpolation.py | 271 ++++++++++++++++++ 4 files changed, 377 insertions(+), 20 deletions(-) create mode 100644 tests/unit/test_prompt_interpolation.py diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index 001fda3d..95acdc07 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -452,10 +452,10 @@ def _validate_config(config: Dict[str, Any]) -> None: if not isinstance(weight, (int, float)) or weight < 0: raise ValueError(f"_validate_config: Prompt weight {i} must be a non-negative number") - - interpolation_method = blend_config.get('interpolation_method', 'slerp') - if interpolation_method not in ['linear', 'slerp']: - raise ValueError("_validate_config: interpolation_method must be 'linear' or 'slerp'") + + interpolation_method = blend_config.get("interpolation_method", "slerp") + if interpolation_method not in ["linear", "slerp", "cosine_weighted"]: + raise ValueError("_validate_config: interpolation_method must be 'linear', 'slerp', or 'cosine_weighted'") # Validate seed blending configuration if present if 'seed_blending' in config: diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 6bba3b21..66c397d9 100644 --- 
a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -50,6 +50,11 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo self._current_style_images: Dict[str, Any] = {} # Use the shared orchestrator attached via OrchestratorUser self._embedding_orchestrator = self._preprocessing_orchestrator + + # Tracks the last prompt interpolation method used; read by td_manager for + # IPAdapter style-image re-blends (td_manager.py:1147). + self._last_prompt_interpolation_method: str = "slerp" + def get_cache_info(self) -> Dict: """Get cache statistics for monitoring performance.""" total_requests = self._prompt_cache_stats.hits + self._prompt_cache_stats.misses @@ -243,7 +248,7 @@ def update_stream_params( seed: Optional[int] = None, prompt_list: Optional[List[Tuple[str, float]]] = None, negative_prompt: Optional[str] = None, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", + prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", normalize_prompt_weights: Optional[bool] = None, seed_list: Optional[List[Tuple[int, float]]] = None, seed_interpolation_method: Literal["linear", "slerp"] = "linear", @@ -405,7 +410,7 @@ def _update_blended_prompts( self, prompt_list: List[Tuple[str, float]], negative_prompt: str = "", - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", ) -> None: """Update prompt embeddings using multiple weighted prompts.""" # Store current state @@ -447,7 +452,9 @@ def _cache_prompt_embeddings( # Cache hit self._prompt_cache_stats.record_hit() - def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", "slerp"]) -> None: + def _apply_prompt_blending( + self, prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] + ) -> None: """Apply weighted blending of cached prompt embeddings.""" if not 
self._current_prompt_list: return @@ -464,15 +471,26 @@ def _apply_prompt_blending(self, prompt_interpolation_method: Literal["linear", logger.warning("_apply_prompt_blending: Warning: No cached embeddings found") return + # Record last method used (consumed by td_manager IPAdapter re-blend at 1147). + self._last_prompt_interpolation_method = prompt_interpolation_method + # Normalize weights weights = self._normalize_weights(weights, self.normalize_prompt_weights) # Apply interpolation - if prompt_interpolation_method == "slerp" and len(embeddings) == 2: - # Spherical linear interpolation for 2 prompts - embed1, embed2 = embeddings[0], embeddings[1] - t = weights[1].item() # Use second weight as interpolation factor - combined_embeds = self._slerp(embed1, embed2, t) + if prompt_interpolation_method == "slerp": + if len(embeddings) == 2: + # Original 2-way slerp path — identical output to before. + embed1, embed2 = embeddings[0], embeddings[1] + t = weights[1].item() # Use second weight as interpolation factor + combined_embeds = self._slerp(embed1, embed2, t) + else: + # N-way iterative slerp (ported from reference multi_slerp). + combined_embeds = self._multi_slerp(embeddings, weights.tolist()) + elif prompt_interpolation_method == "cosine_weighted": + # Genuine cosine-similarity weighting: emphasise embeddings aligned with the + # weighted consensus direction, de-emphasise outliers, then N-way slerp. + combined_embeds = self._cosine_weighted_blend(embeddings, weights.tolist()) else: # Linear interpolation (weighted average) combined_embeds = torch.zeros_like(embeddings[0]) @@ -559,6 +577,74 @@ def _slerp(self, embed1: torch.Tensor, embed2: torch.Tensor, t: float) -> torch. return result.view(original_shape) + def _multi_slerp(self, embeddings: List[torch.Tensor], weights: List[float]) -> torch.Tensor: + """N-way iterative SLERP over a list of embeddings, ported from the reference fork. 
+ + Embeddings are sorted by weight (desc) and folded pairwise with the standard 2-way + ``_slerp``. The result is scaled by ``max(1, sum(weights))`` so that weights > 1 + amplify the output magnitude rather than being silently clipped. + + Args: + embeddings: List of embedding tensors (all same shape). + weights: Corresponding raw (already-normalised by caller) weights as plain floats. + + Returns: + Interpolated embedding tensor with the same shape as each input. + """ + total_weight = sum(weights) + scale_factor = max(1.0, total_weight) + if len(embeddings) == 1: + return embeddings[0] * scale_factor + scaled_weights = [w / scale_factor for w in weights] + sorted_pairs = sorted(zip(embeddings, scaled_weights), key=lambda x: x[1], reverse=True) + sorted_embeddings, sorted_weights = zip(*sorted_pairs) + result = sorted_embeddings[0] + accumulated_weight = sorted_weights[0] + for i in range(1, len(sorted_embeddings)): + if sorted_weights[i] == 0: + continue + t = sorted_weights[i] / (accumulated_weight + sorted_weights[i]) + result = self._slerp(result, sorted_embeddings[i], t) + accumulated_weight += sorted_weights[i] + return result * scale_factor + + def _cosine_weighted_blend(self, embeddings: List[torch.Tensor], weights: List[float]) -> torch.Tensor: + """Blend embeddings with cosine-similarity weighting toward the weighted consensus direction. + + Computes a weighted-mean direction across all embeddings, then adjusts each embedding's + weight by its cosine similarity to that consensus. Embeddings that agree with the + consensus are up-weighted; outliers are down-weighted. The adjusted weights preserve + total weight mass so that overall embedding magnitude is unchanged. The final blend + is performed with ``_multi_slerp`` for perceptually smooth interpolation. + + This is the genuine implementation of what the reference fork named + "cosine_weighted_interpolation" but never actually computed (its implementation + was a dead-code alias for plain multi_slerp). 
+ + Args: + embeddings: List of embedding tensors (all same shape). + weights: Corresponding raw (already-normalised by caller) weights as plain floats. + + Returns: + Interpolated embedding tensor with the same shape as each input. + """ + if len(embeddings) == 1: + return embeddings[0] + # Work in float32 regardless of model dtype for numerical stability. + ref_device = embeddings[0].device + flats = torch.stack([e.flatten().float() for e in embeddings]) # [N, D] + w = torch.tensor(weights, device=ref_device, dtype=torch.float32) # [N] + # Weighted consensus direction. + mean_dir = F.normalize((flats * w.unsqueeze(1)).sum(0), dim=0) # [D] + # Cosine similarity of each embedding to the consensus. + cos_sims = (F.normalize(flats, dim=1) @ mean_dir).clamp(min=1e-4) # [N] + # Adjust weights by cosine similarity; keep total weight mass constant. + adj = w * cos_sims + total_w = w.sum() + if adj.sum() > 1e-8: + adj = adj * (total_w / adj.sum()) + return self._multi_slerp(embeddings, adj.tolist()) + @torch.inference_mode() def _update_blended_seeds( self, @@ -951,7 +1037,7 @@ def update_prompt_at_index( self, index: int, new_prompt: str, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", ) -> None: """Update a single prompt at the specified index without re-encoding others.""" if not self._validate_index(index, self._current_prompt_list, "update_prompt_at_index"): @@ -1005,7 +1091,7 @@ def add_prompt( self, prompt: str, weight: float = 1.0, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", ) -> None: """Add a new prompt to the current list.""" new_index = len(self._current_prompt_list) @@ -1033,7 +1119,7 @@ def add_prompt( def remove_prompt_at_index( self, index: int, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp" + 
prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", ) -> None: """Remove a prompt at the specified index.""" if not self._validate_index(index, self._current_prompt_list, "remove_prompt_at_index"): diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index f19d39d4..fbcb2f30 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -416,7 +416,7 @@ def prepare( guidance_scale: float = 1.2, delta: float = 1.0, # Blending-specific parameters (only used when prompt is a list) - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", + prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", seed_list: Optional[List[Tuple[int, float]]] = None, seed_interpolation_method: Literal["linear", "slerp"] = "linear", ) -> None: @@ -544,7 +544,7 @@ def update_prompt( self, prompt: Union[str, List[Tuple[str, float]]], negative_prompt: str = "", - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", + prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", clear_blending: bool = True, warn_about_conflicts: bool = True, ) -> None: @@ -564,7 +564,7 @@ def update_prompt( - Blending: [("cat", 0.7), ("dog", 0.3)] negative_prompt : str, optional The negative prompt (used with blending), by default "". - prompt_interpolation_method : Literal["linear", "slerp"], optional + prompt_interpolation_method : Literal["linear", "slerp", "cosine_weighted"], optional Method for interpolating between prompt embeddings (used with blending), by default "slerp". clear_blending : bool, optional Whether to clear existing blending when switching to single prompt, by default True. 
@@ -622,7 +622,7 @@ def update_stream_params( # Prompt blending parameters prompt_list: Optional[List[Tuple[str, float]]] = None, negative_prompt: Optional[str] = None, - prompt_interpolation_method: Literal["linear", "slerp"] = "slerp", + prompt_interpolation_method: Literal["linear", "slerp", "cosine_weighted"] = "slerp", normalize_prompt_weights: Optional[bool] = None, # Seed blending parameters seed_list: Optional[List[Tuple[int, float]]] = None, @@ -662,7 +662,7 @@ def update_stream_params( Example: [("cat", 0.7), ("dog", 0.3)] negative_prompt : Optional[str] The negative prompt to apply to all blended prompts. - prompt_interpolation_method : Literal["linear", "slerp"] + prompt_interpolation_method : Literal["linear", "slerp", "cosine_weighted"] Method for interpolating between prompt embeddings, by default "slerp". normalize_prompt_weights : Optional[bool] Whether to normalize prompt weights in blending to sum to 1, by default None (no change). diff --git a/tests/unit/test_prompt_interpolation.py b/tests/unit/test_prompt_interpolation.py new file mode 100644 index 00000000..dce4aaae --- /dev/null +++ b/tests/unit/test_prompt_interpolation.py @@ -0,0 +1,271 @@ +"""Unit tests for the new prompt interpolation modes added in +stream_parameter_updater.py: + + - ``_multi_slerp`` – N-way iterative SLERP (port of reference multi_slerp) + - ``_cosine_weighted_blend`` – genuine cosine-similarity weighting before N-way SLERP + - ``_apply_prompt_blending`` dispatch for "cosine_weighted" and N>2 "slerp" paths + - ``_last_prompt_interpolation_method`` attribute is recorded and carries across calls + +All tests run on CPU with float32 so no GPU is required. 
+""" + +import types + +import torch + +from streamdiffusion.stream_parameter_updater import StreamParameterUpdater + + +# --------------------------------------------------------------------------- +# Minimal fake stream that satisfies the fields accessed during __init__ and +# _apply_prompt_blending without touching the real pipeline. +# --------------------------------------------------------------------------- + + +def _fake_stream(): + """Return a minimal namespace that looks like a StreamDiffusion instance.""" + stream = types.SimpleNamespace() + stream.device = torch.device("cpu") + stream.dtype = torch.float32 + stream.batch_size = 1 + stream.cfg_type = "none" + stream.guidance_scale = 1.0 + stream.prompt_embeds = None + stream.negative_prompt_embeds = None + # Attributes accessed by OrchestratorUser.attach_orchestrator + stream._preprocessing_orchestrator = None + stream.embedding_hooks = [] + return stream + + +def _make_updater() -> StreamParameterUpdater: + """Construct a StreamParameterUpdater with a fake stream, bypassing __init__ side-effects.""" + stream = _fake_stream() + + # Patch OrchestratorUser.attach_orchestrator to be a no-op so we don't need + # a real PreprocessingOrchestrator. 
+ orig_attach = StreamParameterUpdater.attach_orchestrator + + def _noop_attach(self, s): # noqa: ANN001 + self._preprocessing_orchestrator = None + + StreamParameterUpdater.attach_orchestrator = _noop_attach + try: + updater = StreamParameterUpdater(stream) + finally: + StreamParameterUpdater.attach_orchestrator = orig_attach + + updater._embedding_orchestrator = None + return updater + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _rand_embed(shape=(1, 4, 8), seed=0) -> torch.Tensor: + """Reproducible random embedding on CPU/float32.""" + g = torch.Generator() + g.manual_seed(seed) + return torch.randn(*shape, generator=g) + + +# --------------------------------------------------------------------------- +# _multi_slerp tests +# --------------------------------------------------------------------------- + + +class TestMultiSlerp: + def setup_method(self): + self.upd = _make_updater() + + def test_single_embedding_returns_scaled(self): + e = _rand_embed(seed=1) + result = self.upd._multi_slerp([e], [1.0]) + assert result.shape == e.shape + # scale_factor = max(1, 1.0) = 1 → output identical to input + assert torch.allclose(result, e) + + def test_single_embedding_weight_gt1_scales(self): + e = _rand_embed(seed=2) + result = self.upd._multi_slerp([e], [2.5]) + assert torch.allclose(result, e * 2.5, atol=1e-5) + + def test_two_way_matches_direct_slerp(self): + """With two embeddings, multi_slerp result must equal _slerp(e1, e2, t).""" + e1 = _rand_embed(seed=3) + e2 = _rand_embed(seed=4) + w1, w2 = 0.7, 0.3 + result_multi = self.upd._multi_slerp([e1, e2], [w1, w2]) + # _multi_slerp normalises first: scaled_w = [0.7, 0.3]; sorted desc → [0.7, 0.3] + # t = 0.3 / (0.7 + 0.3) = 0.3 + t_expected = w2 / (w1 + w2) + result_direct = self.upd._slerp(e1, e2, t_expected) + # scale_factor = max(1, 1.0) = 1 → no additional scaling + assert 
torch.allclose(result_multi, result_direct, atol=1e-5) + + def test_three_way_preserves_shape(self): + es = [_rand_embed(seed=i) for i in range(3)] + result = self.upd._multi_slerp(es, [0.5, 0.3, 0.2]) + assert result.shape == es[0].shape + + def test_zero_weight_entry_skipped(self): + """A zero-weight prompt should have no effect.""" + e1 = _rand_embed(seed=5) + e2 = _rand_embed(seed=6) + e_zero = _rand_embed(seed=99) + # With zero weight the third embedding should be entirely ignored + result_with = self.upd._multi_slerp([e1, e2, e_zero], [0.6, 0.4, 0.0]) + result_without = self.upd._multi_slerp([e1, e2], [0.6, 0.4]) + assert torch.allclose(result_with, result_without, atol=1e-5) + + def test_weights_sum_gt1_scales_magnitude(self): + """When sum(weights) > 1 the result magnitude is scaled accordingly.""" + e = _rand_embed(seed=7) + # Single embedding, weight 3.0 → output = e * 3.0 + result = self.upd._multi_slerp([e], [3.0]) + assert torch.allclose(result, e * 3.0, atol=1e-5) + + def test_dtype_preserved(self): + e1 = _rand_embed(seed=8) + e2 = _rand_embed(seed=9) + result = self.upd._multi_slerp([e1, e2], [0.5, 0.5]) + assert result.dtype == e1.dtype + + +# --------------------------------------------------------------------------- +# _cosine_weighted_blend tests +# --------------------------------------------------------------------------- + + +class TestCosineWeightedBlend: + def setup_method(self): + self.upd = _make_updater() + + def test_single_embedding_passthrough(self): + e = _rand_embed(seed=10) + result = self.upd._cosine_weighted_blend([e], [1.0]) + assert torch.allclose(result, e, atol=1e-5) + + def test_identical_direction_matches_multi_slerp(self): + """When all embeddings point in the same direction, cos-sims are all 1 → same as multi_slerp.""" + base = _rand_embed(seed=11) + # Scale copies of the same embedding by small factors (same direction) + e1 = base * 1.0 + e2 = base * 0.5 + weights = [0.6, 0.4] + result_cw = 
self.upd._cosine_weighted_blend([e1, e2], weights) + result_ms = self.upd._multi_slerp([e1, e2], weights) + assert torch.allclose(result_cw, result_ms, atol=1e-4) + + def test_outlier_de_emphasised(self): + """An embedding pointing in the opposite direction to both others should be + de-weighted, pulling the output AWAY from it compared to plain multi_slerp.""" + # Two aligned embeddings and one in the opposite direction + e_main = _rand_embed(seed=12) + e_aligned = _rand_embed(seed=12) * 0.9 # almost identical direction + e_outlier = -e_main.clone() # exact opposite + weights = [0.4, 0.4, 0.2] + + cw_result = self.upd._cosine_weighted_blend([e_main, e_aligned, e_outlier], weights) + ms_result = self.upd._multi_slerp([e_main, e_aligned, e_outlier], weights) + + # cosine_weighted should differ from plain multi_slerp when there's an outlier + assert not torch.allclose(cw_result, ms_result, atol=1e-4), ( + "cosine_weighted_blend should differ from multi_slerp when an outlier is present" + ) + + def test_shape_and_dtype_preserved(self): + es = [_rand_embed(seed=i) for i in range(3)] + result = self.upd._cosine_weighted_blend(es, [0.5, 0.3, 0.2]) + assert result.shape == es[0].shape + assert result.dtype == es[0].dtype + + +# --------------------------------------------------------------------------- +# _apply_prompt_blending dispatch + _last_prompt_interpolation_method +# --------------------------------------------------------------------------- + + +class TestApplyPromptBlendingDispatch: + """Patch the actual blend helpers to just record that they were called, and verify + the dispatch logic chooses the right one.""" + + def setup_method(self): + self.upd = _make_updater() + # Pre-populate a two-embedding cache so _apply_prompt_blending has data. 
+ e1 = _rand_embed(seed=20) + e2 = _rand_embed(seed=21) + e3 = _rand_embed(seed=22) + self.upd._prompt_cache = { + 0: {"embed": e1, "text": "cat"}, + 1: {"embed": e2, "text": "dog"}, + 2: {"embed": e3, "text": "bird"}, + } + self.upd._current_prompt_list = [("cat", 0.5), ("dog", 0.3), ("bird", 0.2)] + self.upd._current_negative_prompt = "" + + def test_slerp_n_gt_2_calls_multi_slerp(self): + called = [] + orig = self.upd._multi_slerp + + def spy(*args, **kwargs): + called.append("multi_slerp") + return orig(*args, **kwargs) + + self.upd._multi_slerp = spy + self.upd._apply_prompt_blending("slerp") + assert "multi_slerp" in called, "slerp with N>2 should delegate to _multi_slerp" + + def test_cosine_weighted_calls_cosine_weighted_blend(self): + called = [] + orig = self.upd._cosine_weighted_blend + + def spy(*args, **kwargs): + called.append("cosine_weighted_blend") + return orig(*args, **kwargs) + + self.upd._cosine_weighted_blend = spy + self.upd._apply_prompt_blending("cosine_weighted") + assert "cosine_weighted_blend" in called, "cosine_weighted method should delegate to _cosine_weighted_blend" + + def test_last_method_recorded_slerp(self): + self.upd._apply_prompt_blending("slerp") + assert self.upd._last_prompt_interpolation_method == "slerp" + + def test_last_method_recorded_cosine_weighted(self): + self.upd._apply_prompt_blending("cosine_weighted") + assert self.upd._last_prompt_interpolation_method == "cosine_weighted" + + def test_last_method_recorded_linear(self): + self.upd._apply_prompt_blending("linear") + assert self.upd._last_prompt_interpolation_method == "linear" + + def test_slerp_2_way_uses_slerp_not_multi_slerp(self): + """With exactly 2 embeddings, 'slerp' must NOT call _multi_slerp.""" + self.upd._current_prompt_list = [("cat", 0.6), ("dog", 0.4)] + multi_called = [] + slerp_called = [] + orig_multi = self.upd._multi_slerp + orig_slerp = self.upd._slerp + + def spy_multi(*a, **kw): + multi_called.append(True) + return orig_multi(*a, **kw) + + 
def spy_slerp(*a, **kw): + slerp_called.append(True) + return orig_slerp(*a, **kw) + + self.upd._multi_slerp = spy_multi + self.upd._slerp = spy_slerp + self.upd._apply_prompt_blending("slerp") + assert not multi_called, "2-way slerp should use _slerp directly, not _multi_slerp" + assert slerp_called, "2-way slerp should call _slerp" + + def test_last_prompt_interpolation_method_default(self): + """Attribute must exist from __init__ with default 'slerp'.""" + fresh = _make_updater() + assert hasattr(fresh, "_last_prompt_interpolation_method") + assert fresh._last_prompt_interpolation_method == "slerp" From c19722a91832194747d5e2578e2edc44db3ff36f Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 6 Jun 2026 21:45:56 -0400 Subject: [PATCH 13/13] fix: warn on unknown prompt interpolation method with once-per-string guard --- .../stream_parameter_updater.py | 14 +++++++ tests/unit/test_prompt_interpolation.py | 38 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/streamdiffusion/stream_parameter_updater.py b/src/streamdiffusion/stream_parameter_updater.py index 66c397d9..64de960e 100644 --- a/src/streamdiffusion/stream_parameter_updater.py +++ b/src/streamdiffusion/stream_parameter_updater.py @@ -54,6 +54,9 @@ def __init__(self, stream_diffusion, wrapper=None, normalize_prompt_weights: boo # Tracks the last prompt interpolation method used; read by td_manager for # IPAdapter style-image re-blends (td_manager.py:1147). self._last_prompt_interpolation_method: str = "slerp" + # Warn-once set: emit one logger.warning per unique unknown method string so + # that per-frame weight-drag calls don't flood the log. + self._warned_unknown_interp_methods: set = set() def get_cache_info(self) -> Dict: """Get cache statistics for monitoring performance.""" @@ -492,6 +495,17 @@ def _apply_prompt_blending( # weighted consensus direction, de-emphasise outliers, then N-way slerp. 
combined_embeds = self._cosine_weighted_blend(embeddings, weights.tolist()) else: + # Unknown method — warn once per unique string so weight-drag updates don't + # flood the log, then fall back to linear interpolation. + if prompt_interpolation_method != "linear" and ( + prompt_interpolation_method not in self._warned_unknown_interp_methods + ): + self._warned_unknown_interp_methods.add(prompt_interpolation_method) + logger.warning( + "_apply_prompt_blending: unknown interpolation method %r - " + "falling back to linear (valid: linear, slerp, cosine_weighted)", + prompt_interpolation_method, + ) # Linear interpolation (weighted average) combined_embeds = torch.zeros_like(embeddings[0]) for embed, weight in zip(embeddings, weights): diff --git a/tests/unit/test_prompt_interpolation.py b/tests/unit/test_prompt_interpolation.py index dce4aaae..d081efa3 100644 --- a/tests/unit/test_prompt_interpolation.py +++ b/tests/unit/test_prompt_interpolation.py @@ -269,3 +269,41 @@ def test_last_prompt_interpolation_method_default(self): fresh = _make_updater() assert hasattr(fresh, "_last_prompt_interpolation_method") assert fresh._last_prompt_interpolation_method == "slerp" + + def test_unknown_method_falls_back_to_linear(self): + """An unrecognised method string must produce the same output as 'linear'.""" + # Capture the linear result first on a fresh updater sharing the same embeds. + upd_linear = _make_updater() + upd_linear._prompt_cache = dict(self.upd._prompt_cache) + upd_linear._current_prompt_list = list(self.upd._current_prompt_list) + upd_linear._current_negative_prompt = "" + upd_linear._apply_prompt_blending("linear") + linear_embed = upd_linear.stream.prompt_embeds.clone() + + # Now run the typo'd string on our main updater. 
+ self.upd._apply_prompt_blending("cosine_weignted") + unknown_embed = self.upd.stream.prompt_embeds + + assert torch.allclose(unknown_embed, linear_embed, atol=1e-5), ( + "Unknown method should fall back to linear interpolation" + ) + + def test_unknown_method_warns_once(self, caplog): + """Exactly one warning per unique unknown string; 'linear' never warns.""" + import logging + + with caplog.at_level(logging.WARNING, logger="streamdiffusion.stream_parameter_updater"): + # Two calls with the same bad string → only one warning record. + self.upd._apply_prompt_blending("cosine_weignted") + self.upd._apply_prompt_blending("cosine_weignted") + + unknown_warnings = [r for r in caplog.records if "cosine_weignted" in r.message] + assert len(unknown_warnings) == 1, ( + f"Expected exactly 1 warning for repeated unknown method, got {len(unknown_warnings)}" + ) + + # A 'linear' call must never produce a warning. + caplog.clear() + with caplog.at_level(logging.WARNING, logger="streamdiffusion.stream_parameter_updater"): + self.upd._apply_prompt_blending("linear") + assert not caplog.records, "No warning expected for the 'linear' method"