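In SkyRL v0.2.0's FSDP backend, there's a bit of an asymmetry in how the bf16 flag is applied:
- HFModelWrapper.__init__ and TrainerConfig: default bf16 to True
- Policy init_model: hardcodes bf16=False
- Critic init_model: hardcodes bf16=False
- Ref init_model: passes self.cfg.bf16 through
This leads to a case where, by default: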
- Policy and Critic: mixed precision (weights fp32, forward pass autocasts to bf16)
- Ref: pure bf16 (weights bf16, forward pass autocasts to bf16)
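Here is a minimal sketch of the resulting precision per role. It assumes only the behavior described above: policy and critic force bf16=False, and the ref passes the config flag through. The Precision type and effective_precision function are hypothetical illustrations, not SkyRL code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Precision:
    weight_dtype: str   # dtype the parameters are stored in
    compute_dtype: str  # dtype the forward pass autocasts to


def effective_precision(role: str, cfg_bf16: bool = True) -> Precision:
    """Hypothetical model of how each FSDP worker ends up configured.

    cfg_bf16 mirrors trainer.bf16, whose default is True.
    """
    if role in ("policy", "critic"):
        bf16 = False          # init_model hardcodes bf16=False
    elif role == "ref":
        bf16 = cfg_bf16       # init_model passes self.cfg.bf16 through
    else:
        raise ValueError(f"unknown role: {role!r}")
    # Every role autocasts the forward pass to bf16; only weight storage differs.
    return Precision(weight_dtype="bf16" if bf16 else "fp32", compute_dtype="bf16")


for role in ("policy", "critic", "ref"):
    print(role, effective_precision(role))
```

With cfg_bf16=False, all three roles come out identical, which is the state the workarounds aim for.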
The asymmetry produces no observable failure on small models (e.g. ≤14B) and short sequences (e.g. ≤16k tokens), but on large dense models with long sequences, the ref's pure-bf16 attention overflows; one bad key/value position then poisons every later position in the sequence (each later position attends back to the bad one).
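To see why a single overflow spreads, here is a toy causal softmax in plain Python. It is an illustration only, not the attention kernel SkyRL actually uses. It plants one +inf score against key position bad_pos and reports which query rows become NaN.

```python
import math


def causal_softmax(scores):
    """Row-wise softmax of a square score matrix under a causal mask.

    scores[i][j] is the dot product of query i with key j; query i may
    only attend to keys 0..i.
    """
    probs = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]                  # causal mask
        m = max(visible)                        # subtract the max for stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


def nan_rows(n, bad_pos):
    """Mark which query rows turn NaN when key bad_pos overflows to +inf."""
    scores = [[0.0] * n for _ in range(n)]
    for i in range(bad_pos, n):
        scores[i][bad_pos] = math.inf           # saturated dot product
    return [any(math.isnan(p) for p in row) for row in causal_softmax(scores)]


print(nan_rows(6, 2))  # [False, False, True, True, True, True]
```

The max-subtraction step computes inf - inf = NaN, so every row that can see the bad key is lost. Earlier rows never attend to it and stay finite.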
The cause:
- The ref's weights are stored in bf16 while the policy's are fp32, so rounding error compounds across many layers (Qwen3-32B, for example, has 64)
- This only matters with use_kl_loss=true and kl_loss_coef>0, which wire the reference model into the policy loss
- Eventually an attention dot product saturates to +/- inf
- That turns log_probs_base into NaN
- The NaN propagates through the KL term, so the final loss is NaN
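Here is a rough sketch of the rounding side. It assumes bf16 means an IEEE float32 truncated to its top 16 bits with round-to-nearest-even (1 sign, 8 exponent, 7 mantissa bits), and it emulates bf16 with the standard library rather than calling a framework. The second function is a deliberately worst-case toy. A per-layer gain of 1 + 3*2^-10 rounds to exactly 1.0 in bf16, so every layer loses the same sliver and the error compounds instead of cancelling. The result is roughly 0.29% error after one layer and about 17% after 64. Real weight rounding errors have mixed signs and partially cancel, so treat this as an exaggerated illustration, not a measurement of Qwen3-32B.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    bf16 keeps the top 16 bits of an IEEE-754 float32; finite inputs only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round half to even on the dropped bits
    bits &= 0xFFFF0000                     # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def compounded_drift(gain: float, layers: int) -> float:
    """Relative error of a `layers`-deep product when each weight is stored in bf16."""
    exact = 1.0
    rounded = 1.0
    for _ in range(layers):
        exact *= gain              # float64 reference
        rounded *= to_bf16(gain)   # same weight, stored in bf16
    return abs(exact - rounded) / abs(exact)


gain = 1 + 3 * 2**-10              # rounds to exactly 1.0 in bf16
print(f"1 layer:   {compounded_drift(gain, 1):.4%}")
print(f"64 layers: {compounded_drift(gain, 64):.2%}")
```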
Workarounds
Make the ref match the bf16=False that policy and critic already hardcode, by either:
- Setting trainer.bf16=false on the launch CLI
- Patching HFModelWrapper.__init__ globally to force bf16=False
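Here is a sketch of the patching workaround, written as a generic helper so it can be shown without SkyRL installed. The force_init_kwarg helper and the stand-in HFModelWrapper class are hypothetical. The only assumption about the real HFModelWrapper.__init__ is that it takes a parameter literally named bf16, as described above. The helper binds each call against the original signature, so bf16 is overridden whether the caller passed it positionally, by keyword, or not at all.

```python
import functools
import inspect


def force_init_kwarg(cls, name, value):
    """Patch cls.__init__ so parameter `name` is always `value`.

    Returns the original __init__ so the patch can be undone.
    """
    original_init = cls.__init__
    sig = inspect.signature(original_init)
    if name not in sig.parameters:
        raise TypeError(f"{cls.__name__}.__init__ has no parameter {name!r}")

    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.arguments[name] = value  # covers positional, keyword, and defaulted cases
        original_init(*bound.args, **bound.kwargs)

    cls.__init__ = patched_init
    return original_init


class HFModelWrapper:  # hypothetical stand-in for the real class
    def __init__(self, model_name, bf16=True):
        self.model_name = model_name
        self.bf16 = bf16


original = force_init_kwarg(HFModelWrapper, "bf16", False)
print(HFModelWrapper("ref-model").bf16)             # False
print(HFModelWrapper("ref-model", bf16=True).bf16)  # False
HFModelWrapper.__init__ = original                  # undo the patch
```

The patch has to run before any worker builds its models, and in every process that builds them. If workers live in separate (e.g. Ray) processes, patching only the driver has no effect on them.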