Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions configs/td_config.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use_denoising_batch: true
use_tiny_vae: true
acceleration: "tensorrt"
cfg_type: "self"
do_add_noise: false
do_add_noise: true # matches reference fork default; adds re-noising term in inter-step latent buffer rebuild
warmup: 10
use_safety_checker: false
skip_diffusion: false
Expand Down Expand Up @@ -61,14 +61,16 @@ scheduler: "lcm"
sampler: "normal"

# StreamV2V Cached Attention (Cattenable enables, Cattmaxframes/Cattinterval tune)
# cache_interval=4: hold a stable 4-frame anchor (matches the reference fork default; interval=1 only caches the previous frame)
use_cached_attn: true
cache_maxframes: 2
cache_interval: 1
cache_interval: 4

# Image filtering (similar frame skip)
enable_similar_image_filter: false
similar_image_filter_threshold: 0.99
similar_image_filter_max_skip_frame: 1
# enable_similar_image_filter=true: hold previous output on near-duplicate frames → less flicker (matches reference fork)
enable_similar_image_filter: true
similar_image_filter_threshold: 0.9996
similar_image_filter_max_skip_frame: 10

# HuggingFace cache directory (for model downloads); leave empty to use default
hf_cache: ""
Expand Down
8 changes: 4 additions & 4 deletions src/streamdiffusion/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,10 +452,10 @@ def _validate_config(config: Dict[str, Any]) -> None:

if not isinstance(weight, (int, float)) or weight < 0:
raise ValueError(f"_validate_config: Prompt weight {i} must be a non-negative number")
interpolation_method = blend_config.get('interpolation_method', 'slerp')
if interpolation_method not in ['linear', 'slerp']:
raise ValueError("_validate_config: interpolation_method must be 'linear' or 'slerp'")

interpolation_method = blend_config.get("interpolation_method", "slerp")
if interpolation_method not in ["linear", "slerp", "cosine_weighted"]:
raise ValueError("_validate_config: interpolation_method must be 'linear', 'slerp', or 'cosine_weighted'")

# Validate seed blending configuration if present
if 'seed_blending' in config:
Expand Down
86 changes: 84 additions & 2 deletions src/streamdiffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,10 @@ def prepare(
generator=generator,
).to(device=self.device, dtype=self.dtype)

self.stock_noise = torch.zeros_like(self.init_noise)
# Clone init_noise rather than zeros so stock_noise starts coherent when CFG is active.
# Initialising with zeros forces a cold-restart warm-up of the RCFG residual evolution;
# the reference fork (pipeline_td.py:1103-1105) uses clone() to avoid that.
self.stock_noise = self.init_noise.clone()

# Handle scheduler-specific scaling calculations
c_skip_list = []
Expand Down Expand Up @@ -593,6 +596,59 @@ def prepare(
# Seed _unet_kwargs with the constant key so per-frame code only updates values
self._unet_kwargs = {"return_dict": False}

def _refresh_derived_tensors(self) -> None:
    """Rebuild the batch-size/scheduler-derived tensors that
    StreamParameterUpdater._update_timestep_calculations leaves untouched.

    The error-fallback path in __call__ invokes this right after
    _param_updater._update_timestep_calculations(), so alpha_prod_t_sqrt,
    beta_prod_t_sqrt and sub_timesteps_tensor are read here as-is.

    Attributes written: init_noise, stock_noise, _alpha_next, _beta_next,
    _init_noise_rotated, _combined_latent_buf, _cfg_latent_buf, _cfg_t_buf.
    """
    latent_shape = (self.batch_size, 4, self.latent_height, self.latent_width)

    # Draw new noise with the stored generator, then move/cast it; stock_noise
    # starts out as an independent copy of the new init_noise.
    sampled = torch.randn(latent_shape, generator=self.generator)
    self.init_noise = sampled.to(device=self.device, dtype=self.dtype)
    self.stock_noise = self.init_noise.clone()

    # Shifted-by-one views of the schedule and the noise are only kept for
    # batched denoising with the "self" / "initialize" CFG modes.
    wants_shifted = self.use_denoising_batch and self.cfg_type in ("self", "initialize")
    if wants_shifted:
        alpha = self.alpha_prod_t_sqrt
        beta = self.beta_prod_t_sqrt
        noise = self.init_noise
        # Drop the first entry and pad the tail with ones.
        self._alpha_next = torch.cat([alpha[1:], torch.ones_like(alpha[0:1])], dim=0)
        self._beta_next = torch.cat([beta[1:], torch.ones_like(beta[0:1])], dim=0)
        # Rotate the noise batch left by one position.
        self._init_noise_rotated = torch.cat([noise[1:], noise[0:1]], dim=0)
    else:
        self._alpha_next = None
        self._beta_next = None
        self._init_noise_rotated = None

    # Scratch buffer for the combined latent; only allocated for multi-step denoising.
    self._combined_latent_buf = (
        torch.empty(latent_shape, dtype=self.dtype, device=self.device)
        if self.denoising_steps_num > 1
        else None
    )

    # CFG scratch buffers: "initialize" uses 1 + batch rows, "full" uses 2 * batch rows.
    if self.guidance_scale > 1.0 and self.cfg_type in ("initialize", "full"):
        if self.cfg_type == "initialize":
            cfg_rows = 1 + self.batch_size
        else:
            cfg_rows = 2 * self.batch_size
        self._cfg_latent_buf = torch.empty(
            (cfg_rows, 4, self.latent_height, self.latent_width),
            dtype=self.dtype,
            device=self.device,
        )
        self._cfg_t_buf = torch.empty(
            cfg_rows, dtype=self.sub_timesteps_tensor.dtype, device=self.device
        )
    else:
        self._cfg_latent_buf = None
        self._cfg_t_buf = None

def _get_scheduler_scalings(self, timestep):
"""Get LCM/TCD-specific scaling factors for boundary conditions."""
if isinstance(self.scheduler, LCMScheduler):
Expand Down Expand Up @@ -1059,7 +1115,33 @@ def __call__(self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None) -
device=self.device, dtype=self.dtype
)

x_0_pred_out = self.predict_x0_batch(x_t_latent)
try:
x_0_pred_out = self.predict_x0_batch(x_t_latent)
except RuntimeError as _e:
_msg = str(_e).lower()
if "out of memory" in _msg:
logger.error("StreamDiffusion.__call__: OOM — clearing cache and re-raising: %s", _e)
torch.cuda.empty_cache()
raise
if "expanded size" in _msg and "must match" in _msg:
logger.error(
"StreamDiffusion.__call__: tensor size mismatch — attempting scheduler rebuild: %s", _e
)
try:
self._param_updater._update_timestep_calculations()
self._refresh_derived_tensors()
except Exception as _fix_err:
logger.error("StreamDiffusion.__call__: rebuild failed: %s", _fix_err)
torch.cuda.empty_cache()
logger.warning("StreamDiffusion.__call__: returning safe fallback frame")
return self.decode_image(
torch.randn(
(1, 4, self.latent_height, self.latent_width),
device=self.device,
dtype=self.dtype,
)
)
raise

# LATENT POSTPROCESSING HOOKS: After diffusion, before VAE decoding
x_0_pred_out = self._apply_latent_postprocessing_hooks(x_0_pred_out)
Expand Down
Loading