fix: prevent autocast crash with fp8 weights in WAN model addcmul ops (CORE-47)#13343
Conversation
📝 Walkthrough: A module-level helper
🚥 Pre-merge checks: ❌ 1 failed (1 warning)
✅ Passed checks (2 passed)
🧹 Nitpick comments (1)
comfy/ldm/wan/model.py (1)
177-183: Use `torch.is_autocast_enabled(x.device.type)` for device-specific autocast detection.

The wrapper correctly addresses the PyTorch autocast prioritization issue with FP8 ScalarTypes. Using `enabled=False` is the right approach here (vs. forcing a specific dtype, which would fail on CPU per existing codebase patterns).

The current code uses `torch.is_autocast_enabled()` without arguments, which is the deprecated form in PyTorch 2.8. For proper device-specific handling, use `torch.is_autocast_enabled(x.device.type)` instead; this works for all device types (CUDA, CPU, MPS, etc.) and aligns with the modern PyTorch API, though FP8 is primarily a CUDA concern.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `comfy/ldm/wan/model.py` around lines 177-183: the `_addcmul` wrapper uses the deprecated `torch.is_autocast_enabled()` without device info; change it to call `torch.is_autocast_enabled(x.device.type)` so autocast is checked per device (CUDA/CPU/MPS) before entering the `with torch.autocast(..., enabled=False)` block in function `_addcmul`, keeping the existing behavior of disabling autocast when it is enabled for that specific device.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 0849a7d6-25a0-42f5-950b-156b1dff19b9
📒 Files selected for processing (1)
comfy/ldm/wan/model.py
Updated Analysis: Testing confirmed the error also occurs with
Force-pushed 338706c to 7d6fda7 (Compare)
Wraps `torch.addcmul` calls in WAN attention blocks with an autocast-disabled context to prevent an `Unexpected floating ScalarType in at::autocast::prioritize` RuntimeError. This occurs when upstream nodes (e.g. SAM3) leave CUDA autocast enabled: PyTorch 2.8's autocast promote dispatch for `addcmul` hits an unhandled dtype in the `prioritize` function. Uses `torch.is_autocast_enabled(device_type)` (the non-deprecated API) and only applies the workaround when autocast is actually active (zero overhead otherwise).
Force-pushed 7d6fda7 to ec913cc (Compare)
Summary

- Fixes `RuntimeError: Unexpected floating ScalarType in at::autocast::prioritize` when using WAN 2.1 models with the `fp8_e4m3fn_fast` weight dtype
- Wraps `torch.addcmul` calls in WAN attention blocks (`WanAttentionBlock`, `Head`, `WanAttentionBlockAudio`) in an autocast-disabled context when autocast is active
- The autocast `prioritize` function handles `kFloat`, `kDouble`, and `lower_precision_fp` (BF16/FP16) but crashes on FP8 ScalarTypes, since `is_autocast_eligible()` returns true for FP8 CUDA tensors

Root Cause
When `fp8_e4m3fn_fast` weights are used:

- `self.modulation` (an `nn.Parameter`) is initialized in FP8
- it is cast to `x.dtype` at runtime via `cast_to()`
- `torch.addcmul`'s autocast dispatch path can still encounter the FP8 type in `at::autocast::prioritize`, which has no promotion rule for it
- the `repeat_e` code path is taken
Approach

Added an `_addcmul()` wrapper that checks `torch.is_autocast_enabled()` and disables autocast for the operation when active. This matches the existing pattern in `comfy/ldm/modules/sub_quadratic_attention.py` and other modules. When autocast is not enabled, there is zero overhead (no context manager).

Test plan
- `fp8_e4m3fn_fast` weights using the `uni_pc` sampler: should no longer crash
- `fp16` weights: should still work identically