fix: prevent autocast crash with fp8 weights in WAN model addcmul ops (CORE-47)#13343

Closed
deepme987 wants to merge 2 commits into master from deepme987/fix-fp8-addcmul-autocast

Conversation

Contributor

@deepme987 deepme987 commented Apr 10, 2026

Summary

  • Fixes sporadic RuntimeError: Unexpected floating ScalarType in at::autocast::prioritize when using WAN 2.1 models with fp8_e4m3fn_fast weight dtype
  • Wraps all torch.addcmul calls in WAN attention blocks (WanAttentionBlock, Head, WanAttentionBlockAudio) with an autocast-disabled context when autocast is active
  • PyTorch 2.8's autocast prioritize function handles kFloat, kDouble, and lower_precision_fp (BF16/FP16) but crashes on FP8 ScalarTypes since is_autocast_eligible() returns true for FP8 CUDA tensors

Root Cause

When fp8_e4m3fn_fast weights are used:

  1. self.modulation (nn.Parameter) is initialized in FP8
  2. It gets cast to x.dtype at runtime via cast_to()
  3. But torch.addcmul's autocast dispatch path can still encounter the FP8 type in at::autocast::prioritize, which has no promotion rule for it
  4. The error is sporadic because it depends on tensor memory layout, tiling configuration (ComfyUI-TiledDiffusion), and which repeat_e code path is taken

Approach

Added _addcmul() wrapper that checks torch.is_autocast_enabled() and disables autocast for the operation when active. This matches the existing pattern in comfy/ldm/modules/sub_quadratic_attention.py and other modules. When autocast is not enabled, there is zero overhead (no context manager).

Test plan

  • Run WAN 2.1 VACE FusionX workflow with fp8_e4m3fn_fast weights using uni_pc sampler — should no longer crash
  • Run same workflow with fp16 weights — should still work identically
  • Run standard WAN 2.1 t2v workflow — no regression
  • Verify via ephemeral cloud deployment


coderabbitai bot commented Apr 10, 2026

📝 Walkthrough


A module-level helper _addcmul(x, y, z) was added to comfy/ldm/wan/model.py. It calls torch.addcmul but disables autocast for x.device.type during the operation when autocast is enabled. All direct torch.addcmul calls used for modulation arithmetic were replaced with _addcmul in WanAttentionBlock.forward, WanAttentionBlockAudio.forward, and Head.forward. No control flow, tensor shapes, surrounding expressions, or public interfaces were changed.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 66.67%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description check (✅ Passed): The pull request description clearly explains the problem (RuntimeError with fp8 weights), the root cause (PyTorch autocast limitations with FP8), the solution (autocast-disabled wrapper), and includes a test plan.
  • Title check (✅ Passed): The title accurately describes the main change: introducing an autocast-disabling wrapper for addcmul operations in WAN models to prevent crashes with fp8 weights.


@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (1)
comfy/ldm/wan/model.py (1)

177-183: Use torch.is_autocast_enabled(x.device.type) for device-specific autocast detection.

The wrapper correctly addresses the PyTorch autocast prioritization issue with FP8 ScalarTypes. Using enabled=False is the right approach here (vs. forcing a specific dtype which would fail on CPU per existing codebase patterns).

The current code uses torch.is_autocast_enabled() without arguments, which is the deprecated form in PyTorch 2.8. For proper device-specific handling, use torch.is_autocast_enabled(x.device.type) instead—this works for all device types (CUDA, CPU, MPS, etc.) and aligns with the modern PyTorch API, though FP8 is primarily a CUDA concern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@comfy/ldm/wan/model.py` around lines 177 - 183, The _addcmul wrapper uses the
deprecated torch.is_autocast_enabled() without device info; change it to call
torch.is_autocast_enabled(x.device.type) so autocast is checked per-device
(CUDA/CPU/MPS) before entering the with torch.autocast(..., enabled=False) block
in function _addcmul, keeping the existing behavior of disabling autocast when
enabled for that specific device.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0849a7d6-25a0-42f5-950b-156b1dff19b9

📥 Commits

Reviewing files that changed from the base of the PR and between 2d861fb and 8e4bc0e.

📒 Files selected for processing (1)
  • comfy/ldm/wan/model.py

@deepme987
Contributor Author

Updated Analysis

Testing confirmed the error also occurs with default (non-FP8) weight dtype, which means this isn't solely an FP8 issue. The root cause is that CUDA autocast gets enabled upstream (likely by the SAM2 library used in SAM3 nodes — SAM2's predictor.py wraps operations in torch.autocast("cuda", dtype=torch.bfloat16)), and when autocast is active, torch.addcmul goes through PyTorch's autocast promote dispatch (registered via AT_FORALL_PROMOTE in autocast_mode.cpp), which calls prioritize(). The prioritize() function only handles kFloat, kDouble, and lower_precision_fp — if any tensor has an unexpected type (or current/next don't match the expected patterns), it crashes.

The _addcmul wrapper that disables autocast is the correct fix regardless of what enables autocast upstream, since addcmul in the WAN model operates on tensors that are already in the correct dtype and don't need autocast promotion.

deepme987 force-pushed the deepme987/fix-fp8-addcmul-autocast branch from 338706c to 7d6fda7 on April 10, 2026 at 03:59
Wraps torch.addcmul calls in WAN attention blocks with autocast-disabled
context to prevent 'Unexpected floating ScalarType in at::autocast::prioritize'
RuntimeError. This occurs when upstream nodes (e.g. SAM3) leave CUDA autocast
enabled - PyTorch 2.8's autocast promote dispatch for addcmul hits an unhandled
dtype in the prioritize function.

Uses torch.is_autocast_enabled(device_type) (non-deprecated API) and only
applies the workaround when autocast is actually active (zero overhead otherwise).
deepme987 force-pushed the deepme987/fix-fp8-addcmul-autocast branch from 7d6fda7 to ec913cc on April 10, 2026 at 04:09
@alexisrolland alexisrolland changed the title fix: prevent autocast crash with fp8 weights in WAN model addcmul ops fix: prevent autocast crash with fp8 weights in WAN model addcmul ops (CORE-47) Apr 10, 2026
@deepme987 deepme987 closed this Apr 15, 2026