diff --git a/docs/guides/quantization-aware-rl.md b/docs/guides/quantization-aware-rl.md index 319d283488..9ccd06148c 100644 --- a/docs/guides/quantization-aware-rl.md +++ b/docs/guides/quantization-aware-rl.md @@ -1,6 +1,6 @@ # Quantization-Aware RL (QARL) -Quantization-Aware RL (QARL) integrates [NVIDIA Model Optimizer (ModelOpt)](https://github.com/NVIDIA/Model-Optimizer) into the NeMo RL training loop, enabling quantization-aware training and generation for both GRPO and on-policy distillation workflows. QARL automatically quantizes a standard model at initialization, maintains quantizer state (amax values) throughout training, and transfers quantized state to vLLM during weight refit. By default, vLLM generation uses fake-quantized modules. For NVFP4 W4A16 rollout experiments, NeMo RL can instead stream packed real-quant ModelOpt NVFP4 weights into vLLM. +Quantization-Aware RL (QARL) integrates [NVIDIA Model Optimizer (ModelOpt)](https://github.com/NVIDIA/Model-Optimizer) into the NeMo RL training loop, enabling quantization-aware training and generation for both GRPO and on-policy distillation workflows. QARL automatically quantizes a standard model at initialization, maintains quantizer state (amax values) throughout training, and transfers quantized state to vLLM during weight refit. By default, vLLM generation uses fake-quantized modules. For NVFP4 W4A16 rollout experiments, NeMo RL can instead stream packed real-quant ModelOpt NVFP4 weights and scales into vLLM. ## Overview @@ -9,7 +9,7 @@ In a standard NeMo RL loop, model weights are trained in full precision and refi There are two vLLM rollout modes: - **Fake-quant rollout**: vLLM receives folded full-precision weights and runs fake-quantized layers. This is the default when `policy.generation.quant_cfg` is set. -- **Real-quant rollout**: vLLM is initialized with ModelOpt NVFP4 kernels and receives packed NVFP4 weights plus scale tensors during every refit. 
Enable this with `policy.generation.real_quant: true`. +- **Real-quant rollout**: vLLM is initialized with ModelOpt NVFP4 kernels and receives packed NVFP4 weights plus scale tensors during every refit. Enable this with `policy.generation.real_quant: true`. W4A16 real-quant rollout supports dense ModelOpt NVFP4 layers and fused MoE weights exported through Megatron-Bridge. See [Verified Configurations](#verified-configurations) for the workflow + recipe combinations that have been empirically validated, and [Supported Quantization Formats](#supported-quantization-formats) for the full set of available formats. W4A4 (`NVFP4_DEFAULT_CFG`) converges for on-policy distillation but has been observed to have convergence issues on GRPO; W4A16 (NVFP4 weights, native-dtype activations) works for GRPO. @@ -26,8 +26,9 @@ The following workflow + quantization recipe combinations have been validated en | QA-Distillation | W4A4 | `examples/modelopt/quant_configs/nano3_nvfp4_default.yaml` | ✅ Converges | `examples/modelopt/qa_distillation_nano3_megatron.yaml` | | QA-GRPO | W4A16 | `NVFP4_MLP_WEIGHT_ONLY_CFG` | ✅ Smoke tested on MoE | `examples/modelopt/qa_grpo_qwen3_30ba3b_megatron.yaml` | | QA-GRPO real quantization rollout | W4A16 | `examples/modelopt/quant_configs/nvfp4_a16_mlp_only.yaml` with `policy.generation.real_quant: true` | ✅ Converges | `examples/configs/recipes/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a16.yaml` | +| QA-GRPO real quantization rollout | W4A16 | `examples/modelopt/quant_configs/nano3_nvfp4_weightonly.yaml` with `policy.generation.real_quant: true` and `policy.generation.real_quant_ignore: NANO3_NVFP4_IGNORE` | ✅ Converges; tested on hybrid MoE/Mamba | `examples/configs/recipes/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.yaml` | -The `nvfp4_a16.yaml` custom YAML enables NVFP4 e2m1 weight quantization (with dynamic e4m3 micro-block scales) and leaves activations unquantized; weights are still exercised through both Megatron training 
and vLLM generation. The `nvfp4_a16_mlp_only.yaml` recipe restricts W4A16 to MLP weights for real-quant rollout. The `nvfp4_w4a8_fp8.yaml` recipe uses the same NVFP4 weight format and enables FP8 e4m3 input activation fake quantization. +The `nvfp4_a16.yaml` custom YAML enables NVFP4 e2m1 weight quantization (with dynamic e4m3 micro-block scales) and leaves activations unquantized; weights are still exercised through both Megatron training and vLLM generation. The `nvfp4_a16_mlp_only.yaml` recipe restricts W4A16 to MLP weights for real-quant rollout. The Nano3 `nano3_nvfp4_weightonly.yaml` recipe applies the same W4A16 weight-only format to the supported MLP/MoE weights while keeping Nano3-sensitive Mamba, attention, gate/router, shared-expert, norm, and selected layer paths in BF16 through the `NANO3_NVFP4_IGNORE` profile. The `nvfp4_w4a8_fp8.yaml` and `nano3_nvfp4_w4a8_fp8.yaml` recipes use the same NVFP4 weight format and enable FP8 e4m3 input activation fake quantization. ## ModelOpt Layer Spec Toggle @@ -92,7 +93,7 @@ sbatch \ Real-quant rollout is intended for checking the deployment-style vLLM path during RL, not only the fake-quant training path. With `policy.generation.real_quant: true`, the Megatron policy worker exports ModelOpt QAT weights as packed NVFP4 tensors during refit, and the vLLM worker loads them into ModelOpt NVFP4 layers. This exercises vLLM's real FP4 kernel path during rollout while the policy training worker remains a QAT model. -This path is validated for W4A16. +This path is validated for W4A16. Dense models can use the default real-quant ignore profile. Hybrid MoE/Mamba models such as Nano3 should use a model-specific ignore profile so unsupported or numerically sensitive paths stay in BF16. 
### Minimal Configuration @@ -108,12 +109,34 @@ policy: real_quant: true ``` +For Nano3 W4A16 real-quant rollout, use the Nano3 weight-only recipe and the named ignore profile: + +```yaml +policy: + quant_cfg: examples/modelopt/quant_configs/nano3_nvfp4_weightonly.yaml + + generation: + backend: vllm + quant_cfg: examples/modelopt/quant_configs/nano3_nvfp4_weightonly.yaml + real_quant: true + real_quant_ignore: NANO3_NVFP4_IGNORE + vllm_cfg: + gpu_memory_utilization: 0.35 + enable_prefix_caching: false +``` + The ready-to-run 1-node DAPO smoke recipe is: ```text examples/configs/recipes/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a16.yaml ``` +The ready-to-run Nano3 4-node x 4-GPU smoke recipe is: + +```text +examples/configs/recipes/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.yaml +``` + Use the matching BF16 recipe as the baseline: ```text @@ -130,6 +153,14 @@ uv run --extra mcore --extra modelopt --extra vllm \ --config examples/configs/recipes/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a16.yaml ``` +For Nano3: + +```bash +uv run --extra mcore --extra modelopt --extra vllm \ + examples/run_grpo.py \ + --config examples/configs/recipes/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.yaml +``` + For a BF16 comparison run: ```bash @@ -164,7 +195,8 @@ For long runs on queues with short wall times, enable periodic checkpointing and A healthy W4A16 real-rollout run should include these lines or equivalent vLLM logs: ```text -quantization=modelopt_fp4 +quantization=modelopt +Detected ModelOpt NVFP4 checkpoint Using NvFp4LinearBackend.MARLIN for NVFP4 GEMM MegatronQuantPolicyWorker[rank=0]: Packed ... 
groups of tensors ``` @@ -183,10 +215,11 @@ For an initial sanity check, compare the first `Generation KL Error` with the BF | Symptom | Likely Cause | Action | |---|---|---| -| vLLM does not log `quantization=modelopt_fp4` | `policy.generation.real_quant` is not set or generation is not using vLLM | Check the YAML under `policy.generation` | +| vLLM does not log `quantization=modelopt` | `policy.generation.real_quant` is not set or generation is not using vLLM | Check the YAML under `policy.generation` | | `Using rollout logprobs` appears | The run is bypassing policy/reference logprob computation | Do not use rollout logprobs for real-quant validation | | First-step W4A16 `Generation KL Error` is much higher than BF16 | Stale converted Megatron checkpoint or refit/export mismatch | Clear checkpoints and rerun; confirm packed tensors are streamed | | `negative scales` warning appears | Invalid or stale NVFP4 scale tensors reached vLLM | Clear checkpoints and verify `nvfp4_a16_mlp_only.yaml` is used for both policy and generation | +| Nano3 first-step KL is high while dense W4A16 is healthy | Nano3-sensitive paths were quantized or the vLLM ignore set does not match the policy recipe | Use `nano3_nvfp4_weightonly.yaml` for policy and generation, and set `policy.generation.real_quant_ignore: NANO3_NVFP4_IGNORE` | | CUDA invalid argument during refit or generation | vLLM consumed malformed packed tensors or stale IPC state | Restart from a fresh job and inspect the first real-quant refit logs | ## Fake-Quant NVFP4 Rollout (W4A8) @@ -258,7 +291,7 @@ Generation-specific parameters are added under `policy.generation`: |---|---| | `quant_cfg` | Quantization config used by the vLLM generation worker. For QARL, this should normally match `policy.quant_cfg`. | | `real_quant` | When `true`, vLLM uses ModelOpt NVFP4 real kernels and receives packed quantized weights during refit. When unset or `false`, vLLM uses fake-quantized generation. 
| -| `real_quant_ignore` | Optional list of vLLM parameter name patterns that should stay in native dtype during real-quant rollout. If omitted, NeMo RL uses the default ModelOpt NVFP4 ignore set for sensitive layers such as attention and output heads. | +| `real_quant_ignore` | Optional list of vLLM parameter name patterns, or a named profile, that should stay in native dtype during real-quant rollout. If omitted, NeMo RL uses the default ModelOpt NVFP4 ignore set for sensitive layers such as attention and output heads. Use `NANO3_NVFP4_IGNORE` for Nano3 hybrid MoE/Mamba W4A16 real-quant rollout. | ## Megatron Checkpoint Directory @@ -311,7 +344,7 @@ uv run --extra mcore --extra modelopt \ - **Generation**: Currently only vLLM is supported for generation. - **DTensor backend**: Quantization support for the DTensor policy worker is not yet implemented. -- **Real-quant rollout**: W4A16 real rollout is supported for dense vLLM ModelOpt NVFP4 layers. +- **Real-quant rollout**: W4A16 real rollout is supported for dense and fused-MoE vLLM ModelOpt NVFP4 layers. Hybrid MoE/Mamba recipes should keep unsupported or sensitive non-MLP paths in BF16 via `real_quant_ignore`. - **W4A8 rollout**: W4A8 is supported through fake-quant rollout. - **Input quantization**: Only per-tensor input (activation) quantization is supported. - **Model support**: Dense Transformer, MoE (Mixture of Experts), and hybrid MoE/Mamba models are supported on the Megatron policy + vLLM generation path when Megatron-Bridge and ModelOpt support the model architecture and quantization recipe. MoE/Mamba support is currently covered by smoke-tested example configs rather than broad convergence guarantees. 
diff --git a/examples/configs/recipes/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.yaml b/examples/configs/recipes/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.yaml new file mode 100644 index 0000000000..5a64cc22c1 --- /dev/null +++ b/examples/configs/recipes/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.yaml @@ -0,0 +1,41 @@ +defaults: ../../../../examples/modelopt/qa_grpo_nano3_megatron.yaml +grpo: + max_num_steps: 1 + val_period: 0 +checkpointing: + checkpoint_dir: results/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real +policy: + model_name: /lustre/fs1/portfolios/coreai/projects/coreai_dlalgo_ci/artifacts/model/nvidia_nvidia-nemotron-3-nano-30b-a3b-bf16/hf/hf-cbd3fa9_orig + tokenizer: + name: ${policy.model_name} + quant_calib_size: 16 + quant_sequence_length: 1024 + generation: + real_quant: true + real_quant_ignore: NANO3_NVFP4_IGNORE + vllm_cfg: + gpu_memory_utilization: 0.35 + enable_prefix_caching: false + vllm_kwargs: + tokenizer: ${policy.tokenizer.name} +data: + max_input_seq_length: 1024 + train: + dataset_name: DAPOMath17K + default: + prompt_file: null +env: + dapo: + num_workers: 2 + math: + num_workers: 2 + math_verify_impl: dapo_math_verify +logger: + log_dir: logs/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real + wandb: + name: grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real + tensorboard: + log_dir: tb_logs-grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real +cluster: + gpus_per_node: 4 + num_nodes: 4 diff --git a/examples/modelopt/quant_configs/nano3_nvfp4_default.yaml b/examples/modelopt/quant_configs/nano3_nvfp4_default.yaml index 7c0ab7c844..fd08281b27 100644 --- a/examples/modelopt/quant_configs/nano3_nvfp4_default.yaml +++ b/examples/modelopt/quant_configs/nano3_nvfp4_default.yaml @@ -80,6 +80,8 @@ quantize: enable: false - quantizer_name: '*router*' enable: false + - quantizer_name: '*.gate.*' + enable: false - quantizer_name: '*mlp.gate.*' enable: false - quantizer_name: 
'*mlp.shared_expert_gate.*' @@ -88,13 +90,29 @@ quantize: enable: false - quantizer_name: '*mixer.conv1d*' enable: false + - quantizer_name: '*.mixer.in_proj.*' + enable: false + - quantizer_name: '*.mixer.out_proj.*' + enable: false + - quantizer_name: '*.shared_expert.*' + enable: false + - quantizer_name: '*.shared_experts.*' + enable: false + - quantizer_name: '*.norm.*' + enable: false - quantizer_name: '*output_layer*' enable: false - quantizer_name: 'output.*' enable: false # Nano3-specific BF16 layers: attention projections and explicit # pre-attention layers 4/11/18/25/32/41. - - quantizer_name: '*.[q|k|v|o]_proj.*' + - quantizer_name: '*.q_proj.*' + enable: false + - quantizer_name: '*.k_proj.*' + enable: false + - quantizer_name: '*.v_proj.*' + enable: false + - quantizer_name: '*.o_proj.*' enable: false - quantizer_name: '*.qkv_proj.*' enable: false diff --git a/examples/modelopt/quant_configs/nano3_nvfp4_w4a8_fp8.yaml b/examples/modelopt/quant_configs/nano3_nvfp4_w4a8_fp8.yaml new file mode 100644 index 0000000000..63d44447ec --- /dev/null +++ b/examples/modelopt/quant_configs/nano3_nvfp4_w4a8_fp8.yaml @@ -0,0 +1,98 @@ +# Nano3 (NemotronH) NVFP4 W4A8 custom quantization recipe. +# +# Defines W4A8: NVFP4 (e2m1, dynamic e4m3 micro-block scales) on weights +# and FP8 e4m3 on layer-input activations. +# +# This mirrors the Nano3 exclusions used by the W4A16/W4A4 recipes so Mamba +# projections, attention projections, shared experts, gates/routers, norms, and +# explicit BF16 layers stay out of the fake-quant path. +metadata: + recipe_type: ptq + description: Nano3 NVFP4 weights with FP8 input activations for W4A8 QAT. 
+quantize: + algorithm: max + quant_cfg: + - quantizer_name: '*' + enable: false + - quantizer_name: '*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + - parent_class: 'nn.BatchNorm1d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm2d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.BatchNorm3d' + quantizer_name: '*' + enable: false + - parent_class: 'nn.LeakyReLU' + quantizer_name: '*' + enable: false + - quantizer_name: '*lm_head*' + enable: false + - quantizer_name: '*proj_out.*' + enable: false + - quantizer_name: '*block_sparse_moe.gate*' + enable: false + - quantizer_name: '*router*' + enable: false + - quantizer_name: '*.gate.*' + enable: false + - quantizer_name: '*mlp.gate.*' + enable: false + - quantizer_name: '*mlp.shared_expert_gate.*' + enable: false + - quantizer_name: '*linear_attn.conv1d*' + enable: false + - quantizer_name: '*mixer.conv1d*' + enable: false + - quantizer_name: '*.mixer.in_proj.*' + enable: false + - quantizer_name: '*.mixer.out_proj.*' + enable: false + - quantizer_name: '*.shared_expert.*' + enable: false + - quantizer_name: '*.shared_experts.*' + enable: false + - quantizer_name: '*.norm.*' + enable: false + - quantizer_name: '*output_layer*' + enable: false + - quantizer_name: 'output.*' + enable: false + - quantizer_name: '*.q_proj.*' + enable: false + - quantizer_name: '*.k_proj.*' + enable: false + - quantizer_name: '*.v_proj.*' + enable: false + - quantizer_name: '*.o_proj.*' + enable: false + - quantizer_name: '*.qkv_proj.*' + enable: false + - quantizer_name: '*.linear_proj.*' + enable: false + - quantizer_name: '*.linear_qkv.*' + enable: false + - quantizer_name: '*.layers.4.*' + enable: false + - quantizer_name: '*.layers.11.*' + enable: false + - quantizer_name: '*.layers.18.*' + enable: false + - quantizer_name: '*.layers.25.*' + enable: 
false + - quantizer_name: '*.layers.32.*' + enable: false + - quantizer_name: '*.layers.41.*' + enable: false diff --git a/examples/modelopt/quant_configs/nano3_nvfp4_weightonly.yaml b/examples/modelopt/quant_configs/nano3_nvfp4_weightonly.yaml index 0b56ce1ead..0c046565ce 100644 --- a/examples/modelopt/quant_configs/nano3_nvfp4_weightonly.yaml +++ b/examples/modelopt/quant_configs/nano3_nvfp4_weightonly.yaml @@ -71,6 +71,8 @@ quantize: enable: false - quantizer_name: '*router*' enable: false + - quantizer_name: '*.gate.*' + enable: false - quantizer_name: '*mlp.gate.*' enable: false - quantizer_name: '*mlp.shared_expert_gate.*' @@ -79,13 +81,29 @@ quantize: enable: false - quantizer_name: '*mixer.conv1d*' enable: false + - quantizer_name: '*.mixer.in_proj.*' + enable: false + - quantizer_name: '*.mixer.out_proj.*' + enable: false + - quantizer_name: '*.shared_expert.*' + enable: false + - quantizer_name: '*.shared_experts.*' + enable: false + - quantizer_name: '*.norm.*' + enable: false - quantizer_name: '*output_layer*' enable: false - quantizer_name: 'output.*' enable: false # Nano3-specific BF16 layers: attention projections and explicit # pre-attention layers 4/11/18/25/32/41. - - quantizer_name: '*.[q|k|v|o]_proj.*' + - quantizer_name: '*.q_proj.*' + enable: false + - quantizer_name: '*.k_proj.*' + enable: false + - quantizer_name: '*.v_proj.*' + enable: false + - quantizer_name: '*.o_proj.*' enable: false - quantizer_name: '*.qkv_proj.*' enable: false diff --git a/nemo_rl/modelopt/models/generation/vllm_modelopt_patch.py b/nemo_rl/modelopt/models/generation/vllm_modelopt_patch.py index 95e0e28aad..d71aca26fd 100644 --- a/nemo_rl/modelopt/models/generation/vllm_modelopt_patch.py +++ b/nemo_rl/modelopt/models/generation/vllm_modelopt_patch.py @@ -13,29 +13,67 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""vLLM ModelOpt NVFP4 patches for dense rollout weight reloads.""" +"""vLLM ModelOpt NVFP4 patches for rollout weight reloads.""" + +from typing import Any import torch from torch.nn import Parameter _DENSE_HF_PARAMS = ("weight", "weight_scale", "weight_scale_2") +_MOE_HF_PARAMS = ( + "w13_weight", + "w13_weight_scale", + "w13_weight_scale_2", + "w13_input_scale", + "w2_weight", + "w2_weight_scale", + "w2_weight_scale_2", + "w2_input_scale", +) +_MOE_INPUT_SCALE_PARAMS = ("w13_input_scale", "w2_input_scale") +_MOE_MARLIN_TENSOR_PARAMS = _MOE_HF_PARAMS[:3] + _MOE_HF_PARAMS[4:7] _MODELOPT_W4A16_QUANT_MODES = frozenset({"w4a16_nvfp4", "nvfp4_w4a16"}) _MODELOPT_W4A16_ATTR = "_nrl_weight_only_w4a16" +_MODELOPT_W4A16_MOE_MARLIN_TILE_N = 64 +_MODELOPT_PARAM_META_ATTR = "_nrl_modelopt_param_meta" +_MODELOPT_WEIGHT_LOADERS_ATTR = "_nrl_modelopt_weight_loaders" +_MODELOPT_RELOAD_PARAM_ATTRS = ("quant_method",) _ORIGINAL_NVFP4_CONFIG_FROM_CONFIG_ATTR = "_nrl_original_from_config" _ORIGINAL_LINEAR_APPLY_ATTR = "_nrl_original_apply" +_ORIGINAL_FUSED_MOE_INIT_ATTR = "_nrl_original_init" +_ORIGINAL_FUSED_MOE_ROUNDUP_SIZES_ATTR = "_nrl_original_maybe_roundup_sizes" +_ORIGINAL_FUSED_MOE_PROCESS_WEIGHTS_ATTR = "_nrl_original_process_weights_after_loading" +_ORIGINAL_KV_CACHE_PROCESS_WEIGHTS_ATTR = ( + "_nrl_original_kv_cache_process_weights_after_loading" +) +_MODELOPT_PROCESS_WEIGHTS_CALL_COUNT_ATTR = "_nrl_process_weights_call_count" +_MODELOPT_PROCESSED_TENSOR_REFS_ATTR = "_nrl_modelopt_processed_tensor_refs" def _unwrap_vllm_model(model: torch.nn.Module) -> torch.nn.Module: return model.model if hasattr(model, "model") else model -def _canonicalize_nvfp4_weight_scale(layer: torch.nn.Module) -> None: - weight_scale = layer.weight_scale - scale = weight_scale.data.to(torch.float32).abs().to(weight_scale.dtype) - weight_scale.data.copy_(scale) +def _canonicalize_nvfp4_scale_parameter( + layer: torch.nn.Module, + param_name: str, +) -> None: + param = getattr(layer, param_name) + 
scale = param.data.to(torch.float32).abs().to(param.dtype) + param.data.copy_(scale) + + +def _canonicalize_dense_nvfp4_weight_scale(layer: torch.nn.Module) -> None: + _canonicalize_nvfp4_scale_parameter(layer, "weight_scale") -def _requests_w4a16_modelopt_config(config: dict) -> bool: +def _canonicalize_moe_nvfp4_weight_scales(layer: torch.nn.Module) -> None: + _canonicalize_nvfp4_scale_parameter(layer, "w13_weight_scale") + _canonicalize_nvfp4_scale_parameter(layer, "w2_weight_scale") + + +def _requests_w4a16_modelopt_config(config: dict[str, Any]) -> bool: quant_mode = config.get("quant_mode") if ( isinstance(quant_mode, str) @@ -49,11 +87,40 @@ def _requests_w4a16_modelopt_config(config: dict) -> bool: return isinstance(nested, dict) and _requests_w4a16_modelopt_config(nested) -def _is_w4a16_modelopt_quant_config(quant_config) -> bool: +def _is_w4a16_modelopt_quant_config(quant_config: object) -> bool: return bool(getattr(quant_config, _MODELOPT_W4A16_ATTR, False)) -def _modelopt_nvfp4_config_from_config(cls, *args, **kwargs): +def _ceil_div(value: int, divisor: int) -> int: + return (value + divisor - 1) // divisor + + +def _is_w4a16_marlin_moe_quant_method(quant_method: object) -> bool: + backend = getattr(quant_method, "nvfp4_backend", None) + return ( + quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoE" + and _is_w4a16_modelopt_quant_config(getattr(quant_method, "quant_config", None)) + and getattr(backend, "value", backend) == "MARLIN" + ) + + +def _stash_original(cls: type, attr: str, value: object) -> None: + if not hasattr(cls, attr): + setattr(cls, attr, value) + + +def _require_fp4_marlin_supported() -> None: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + is_fp4_marlin_supported, + ) + + if not is_fp4_marlin_supported(): + raise RuntimeError( + "ModelOpt NVFP4 W4A16 rollout requires vLLM FP4 Marlin support." 
+ ) + + +def _modelopt_nvfp4_config_from_config(cls, *args: Any, **kwargs: Any) -> object: original_from_config = getattr(cls, _ORIGINAL_NVFP4_CONFIG_FROM_CONFIG_ATTR) quant_config = original_from_config(*args, **kwargs) @@ -66,7 +133,10 @@ def _modelopt_nvfp4_config_from_config(cls, *args, **kwargs): return quant_config -def _convert_nvfp4_linear_kernel_format(quant_method, layer: torch.nn.Module) -> None: +def _convert_nvfp4_linear_kernel_format( + quant_method: object, + layer: torch.nn.Module, +) -> None: kernel = getattr(quant_method, "kernel", None) if kernel is not None: kernel.process_weights_after_loading(layer) @@ -79,23 +149,21 @@ def _convert_nvfp4_linear_kernel_format(quant_method, layer: torch.nn.Module) -> convert_to_nvfp4_linear_kernel_format(quant_method.backend, layer) -def _convert_w4a16_linear_kernel_format(layer: torch.nn.Module) -> None: - from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( - prepare_fp4_layer_for_marlin, - ) - - prepare_fp4_layer_for_marlin(layer) - - -def _capture_modelopt_dense_param_reload_meta(layer: torch.nn.Module) -> None: - if not hasattr(layer, "_nrl_modelopt_param_meta"): - layer._nrl_modelopt_param_meta = {} - layer._nrl_modelopt_weight_loaders = {} - elif not hasattr(layer, "_nrl_modelopt_weight_loaders"): - layer._nrl_modelopt_weight_loaders = {} - - for param_name in _DENSE_HF_PARAMS: - if param_name in layer._nrl_modelopt_param_meta: +def _capture_modelopt_param_reload_meta( + layer: torch.nn.Module, + param_names: tuple[str, ...], +) -> None: + param_meta = getattr(layer, _MODELOPT_PARAM_META_ATTR, None) + if param_meta is None: + param_meta = {} + setattr(layer, _MODELOPT_PARAM_META_ATTR, param_meta) + weight_loaders = getattr(layer, _MODELOPT_WEIGHT_LOADERS_ATTR, None) + if weight_loaders is None: + weight_loaders = {} + setattr(layer, _MODELOPT_WEIGHT_LOADERS_ATTR, weight_loaders) + + for param_name in param_names: + if param_name in param_meta: continue param = getattr(layer, 
param_name) meta = { @@ -108,14 +176,50 @@ def _capture_modelopt_dense_param_reload_meta(layer: torch.nn.Module) -> None: meta["input_dim"] = param._input_dim if hasattr(param, "_output_dim"): meta["output_dim"] = param._output_dim - layer._nrl_modelopt_param_meta[param_name] = meta + attrs = { + attr: getattr(param, attr) + for attr in _MODELOPT_RELOAD_PARAM_ATTRS + if hasattr(param, attr) + } + if attrs: + meta["attrs"] = attrs + param_meta[param_name] = meta if hasattr(param, "weight_loader"): - layer._nrl_modelopt_weight_loaders[param_name] = param.weight_loader + weight_loaders[param_name] = param.weight_loader + + +def _is_first_modelopt_process_weights_call(layer: torch.nn.Module) -> bool: + count = getattr(layer, _MODELOPT_PROCESS_WEIGHTS_CALL_COUNT_ATTR, 0) + setattr(layer, _MODELOPT_PROCESS_WEIGHTS_CALL_COUNT_ATTR, count + 1) + return count == 0 + + +def _set_or_update_processed_tensor_ref( + layer: torch.nn.Module, + param_name: str, + data: torch.Tensor, + is_first_call: bool, +) -> None: + refs = getattr(layer, _MODELOPT_PROCESSED_TENSOR_REFS_ATTR, None) + if refs is None: + refs = {} + setattr(layer, _MODELOPT_PROCESSED_TENSOR_REFS_ATTR, refs) + + ref = refs.get(param_name) + if is_first_call or ref is None: + setattr(layer, param_name, Parameter(data, requires_grad=False)) + refs[param_name] = getattr(layer, param_name).data + return + + ref.copy_(data) + setattr(layer, param_name, Parameter(ref, requires_grad=False)) def _modelopt_dense_process_w4a16_weights(self, layer: torch.nn.Module) -> None: """Convert dense ModelOpt NVFP4 W4A16 weights for Marlin weight-only GEMM.""" - _capture_modelopt_dense_param_reload_meta(layer) + _require_fp4_marlin_supported() + _capture_modelopt_param_reload_meta(layer, _DENSE_HF_PARAMS) + is_first_call = _is_first_modelopt_process_weights_call(layer) weight_global_scale = layer.weight_scale_2.max().to(torch.float32) layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) @@ -130,8 +234,19 @@ 
def _modelopt_dense_process_w4a16_weights(self, layer: torch.nn.Module) -> None: if hasattr(layer, attr): delattr(layer, attr) - _canonicalize_nvfp4_weight_scale(layer) - _convert_w4a16_linear_kernel_format(layer) + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_fp4_layer_for_marlin, + ) + + _canonicalize_dense_nvfp4_weight_scale(layer) + prepare_fp4_layer_for_marlin(layer) + for param_name in ("weight", "weight_scale", "weight_global_scale"): + _set_or_update_processed_tensor_ref( + layer, + param_name, + getattr(layer, param_name).data, + is_first_call=is_first_call, + ) def _modelopt_dense_process_weights(self, layer: torch.nn.Module) -> None: @@ -140,7 +255,7 @@ def _modelopt_dense_process_weights(self, layer: torch.nn.Module) -> None: _modelopt_dense_process_w4a16_weights(self, layer) return - _capture_modelopt_dense_param_reload_meta(layer) + _capture_modelopt_param_reload_meta(layer, _DENSE_HF_PARAMS) input_global_scale = torch.ones( (), @@ -162,7 +277,6 @@ def _modelopt_dense_process_weights(self, layer: torch.nn.Module) -> None: requires_grad=False, ) - _canonicalize_nvfp4_weight_scale(layer) _convert_nvfp4_linear_kernel_format(self, layer) @@ -195,17 +309,309 @@ def _modelopt_dense_apply( return self.kernel.apply_weights(layer=layer, x=x, bias=bias) -def prepare_modelopt_for_weight_reload(model, device=None) -> None: - """Prepare a dense ModelOpt-vLLM model for one weight reload cycle.""" +def _modelopt_moe_init(self, quant_config: object, moe_config: object) -> None: + if not _is_w4a16_modelopt_quant_config(quant_config): + original_init = getattr(type(self), _ORIGINAL_FUSED_MOE_INIT_ATTR) + original_init(self, quant_config, moe_config) + return + + try: + from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( + FusedMoEMethodBase, + ) + except ImportError: + from vllm.model_executor.layers.fused_moe.layer import FusedMoEMethodBase + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe 
import MarlinExperts + from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import NvFp4MoeBackend + + _require_fp4_marlin_supported() + FusedMoEMethodBase.__init__(self, moe_config) + self.__dict__.update( + quant_config=quant_config, + allow_flashinfer=False, + cutlass_nvfp4_supported=False, + flashinfer_moe_backend=None, + use_marlin=True, + backend="marlin", + nvfp4_backend=NvFp4MoeBackend.MARLIN, + experts_cls=MarlinExperts, + use_global_sf=False, + ) + + +def _modelopt_moe_maybe_roundup_sizes( + self, + hidden_size: int, + intermediate_size_per_partition: int, + act_dtype: torch.dtype, + moe_parallel_config: object, +) -> tuple[int, int]: + """Round W4A16 MoE intermediate shards to Marlin's FP4 tile size.""" + original_roundup = getattr(type(self), _ORIGINAL_FUSED_MOE_ROUNDUP_SIZES_ATTR) + hidden_size, intermediate_size_per_partition = original_roundup( + self, + hidden_size, + intermediate_size_per_partition, + act_dtype, + moe_parallel_config, + ) + + if _is_w4a16_marlin_moe_quant_method(self): + tile = _MODELOPT_W4A16_MOE_MARLIN_TILE_N + intermediate_size_per_partition = ( + _ceil_div( + intermediate_size_per_partition, + tile, + ) + * tile + ) + + return hidden_size, intermediate_size_per_partition + + +def _zero_modelopt_moe_padding(layer: torch.nn.Module) -> None: + moe_config = getattr(layer, "moe_config", None) + padded_size = getattr(moe_config, "intermediate_size_per_partition", None) + unpadded_size = getattr( + moe_config, + "intermediate_size_per_partition_unpadded", + padded_size, + ) + if ( + not isinstance(padded_size, int) + or not isinstance(unpadded_size, int) + or padded_size <= unpadded_size + ): + return + + tp_rank = getattr(layer, "tp_rank", None) + tp_size = getattr(layer, "tp_size", None) + if isinstance(tp_rank, int) and isinstance(tp_size, int) and tp_size > 0: + full_unpadded_size = unpadded_size * tp_size + valid_size = max( + 0, + min(padded_size, full_unpadded_size - padded_size * tp_rank), + ) + else: + valid_size = 
unpadded_size + if valid_size >= padded_size: + return + + quant_config = getattr(layer, "quant_config", None) + group_size = getattr(quant_config, "group_size", None) + if not isinstance(group_size, int) or group_size <= 0: + raise RuntimeError( + "Missing or invalid ModelOpt NVFP4 group_size for padded MoE reload" + ) + with torch.no_grad(): + for param_name in ("w13_weight", "w13_weight_scale"): + tensor = getattr(layer, param_name).data + if tensor.ndim >= 2: + for shard_start in range(0, tensor.shape[1], padded_size): + start = shard_start + valid_size + end = min(shard_start + padded_size, tensor.shape[1]) + if start < end: + tensor.narrow(1, start, end - start).zero_() + for tensor, start, end in ( + ( + layer.w2_weight.data, + _ceil_div(valid_size, 2), + _ceil_div(padded_size, 2), + ), + ( + layer.w2_weight_scale.data, + _ceil_div(valid_size, group_size), + _ceil_div(padded_size, group_size), + ), + ): + if start < end and tensor.ndim > 2 and start < tensor.shape[2]: + tensor.narrow(2, start, min(end, tensor.shape[2]) - start).zero_() + + +def _modelopt_moe_process_w4a16_marlin_weights( + self, + layer: torch.nn.Module, + is_first_call: bool, +) -> None: + from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + convert_to_nvfp4_moe_kernel_format, + make_nvfp4_moe_kernel, + ) + + w13_weight_scale_2 = layer.w13_weight_scale_2.data + if w13_weight_scale_2.dim() == 2: + w13_weight_scale_2 = w13_weight_scale_2[:, 0] + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) + + _canonicalize_moe_nvfp4_weight_scales(layer) + + ( + w13, + w13_scale, + w13_scale_2, + _a13_scale, + w2, + w2_scale, + w2_scale_2, + _a2_scale, + ) = convert_to_nvfp4_moe_kernel_format( + nvfp4_backend=self.nvfp4_backend, + layer=layer, + w13=layer.w13_weight.data, + w13_scale=layer.w13_weight_scale.data, + w13_scale_2=layer.w13_weight_scale_2.data, + a13_scale=None, + w2=layer.w2_weight.data, + w2_scale=layer.w2_weight_scale.data, + 
w2_scale_2=layer.w2_weight_scale_2.data, + a2_scale=None, + is_act_and_mul=self.moe.is_act_and_mul, + ) + + for param_name, data in zip( + _MOE_MARLIN_TENSOR_PARAMS, + (w13, w13_scale, w13_scale_2, w2, w2_scale, w2_scale_2), + strict=True, + ): + _set_or_update_processed_tensor_ref( + layer, + param_name, + data, + is_first_call=is_first_call, + ) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + moe_quant_config = self.get_fused_moe_quant_config(layer) + if moe_quant_config is None: + raise RuntimeError("ModelOpt NVFP4 MoE quant config is missing") + self.moe_quant_config = moe_quant_config + if self.experts_cls is None: + raise RuntimeError("ModelOpt NVFP4 MoE experts class is missing") + routing_tables = ( + layer._maybe_init_expert_routing_tables() + if hasattr(layer, "_maybe_init_expert_routing_tables") + else None + ) + self.moe_kernel = make_nvfp4_moe_kernel( + moe_quant_config=moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + shared_experts=getattr(layer, "shared_experts", None), + routing_tables=routing_tables, + ) + self.moe_kernel.fused_experts.process_weights_after_loading(layer) + + +def _modelopt_kv_cache_process_weights(self, layer: torch.nn.Module) -> None: + """Update KV-cache quantization scales without deleting reload parameters.""" + from vllm.platforms import current_platform + + def copy_scalar(dst: torch.Tensor, value: float | torch.Tensor) -> None: + if isinstance(value, torch.Tensor): + dst.copy_(value.to(device=dst.device, dtype=dst.dtype)) + else: + dst.fill_(value) + + if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + k_scale = 1.0 + v_scale = 1.0 + else: + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale 
= scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError("Only support per-tensor scaling factor for fp8 KV cache") + + if layer.q_scale < 0.0: + copy_scalar(layer._q_scale, k_scale) + layer._q_scale_float = k_scale + + copy_scalar(layer._k_scale, k_scale) + copy_scalar(layer._v_scale, v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + + if layer.q_scale > 0.0: + q_scale = layer.q_scale + if current_platform.is_fp8_fnuz(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale + if current_platform.is_fp8_fnuz(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + def is_singleton_float(value: float | torch.Tensor) -> bool: + return isinstance(value, float) or ( + isinstance(value, torch.Tensor) + and value.numel() == 1 + and value.is_floating_point() + ) + + if not is_singleton_float(q_scale) or not is_singleton_float(prob_scale): + raise ValueError("Only support per-tensor scaling factor for fp8 Q/prob") + + copy_scalar(layer._q_scale, q_scale) + layer._q_scale_float = ( + q_scale.item() if isinstance(q_scale, torch.Tensor) else q_scale + ) + copy_scalar(layer._prob_scale, prob_scale) + + +def _modelopt_moe_process_weights(self, layer: torch.nn.Module) -> None: + """Convert MoE ModelOpt NVFP4 weights after initial load or refit.""" + if not _is_w4a16_marlin_moe_quant_method(self): + original_process_weights = getattr( + type(self), + _ORIGINAL_FUSED_MOE_PROCESS_WEIGHTS_ATTR, + ) + original_process_weights(self, layer) + return + + _capture_modelopt_param_reload_meta(layer, _MOE_HF_PARAMS) + is_first_call = _is_first_modelopt_process_weights_call(layer) + _zero_modelopt_moe_padding(layer) + _modelopt_moe_process_w4a16_marlin_weights( + self, + layer, + 
is_first_call=is_first_call, + ) + + +def prepare_modelopt_for_weight_reload( + model: torch.nn.Module, + device: torch.device | str | None = None, +) -> None: + """Prepare a ModelOpt-vLLM model for one weight reload cycle.""" inner_model = _unwrap_vllm_model(model) for module in inner_model.modules(): - layer_meta = getattr(module, "_nrl_modelopt_param_meta", None) + layer_meta = getattr(module, _MODELOPT_PARAM_META_ATTR, None) if layer_meta is None: continue + weight_loaders = getattr(module, _MODELOPT_WEIGHT_LOADERS_ATTR, {}) for param_name, meta in layer_meta.items(): param = getattr(module, param_name, None) - weight_loader = module._nrl_modelopt_weight_loaders.get(param_name) + weight_loader = weight_loaders.get(param_name) param_class = meta["param_class"] + default_to_one = param_name in _MOE_INPUT_SCALE_PARAMS if ( param is None or tuple(param.shape) != tuple(meta["shape"]) @@ -223,6 +629,8 @@ def prepare_modelopt_for_weight_reload(model, device=None) -> None: dtype=meta["dtype"], device=device or meta["device"], ) + if default_to_one: + data.fill_(1.0) if param_class is not Parameter and weight_loader is not None: kwargs = {"data": data, "weight_loader": weight_loader} if "input_dim" in meta: @@ -234,16 +642,33 @@ def prepare_modelopt_for_weight_reload(model, device=None) -> None: replacement = Parameter(data, requires_grad=False) if weight_loader is not None: replacement.weight_loader = weight_loader + for attr, value in meta.get("attrs", {}).items(): + setattr(replacement, attr, value) setattr(module, param_name, replacement) + elif default_to_one: + # W4A16 exports do not stream MoE activation scales. Reset to + # the neutral scale before loading; W4A8 streams overwrite it. 
+ param.data.fill_(1.0) -def modelopt_process_weights_after_loading(model) -> None: - """Run vLLM ModelOpt post-load processing for dense quantized layers.""" +def modelopt_process_weights_after_loading(model: torch.nn.Module) -> None: + """Run vLLM ModelOpt post-load processing for quantized layers.""" actual_model = _unwrap_vllm_model(model) for module in actual_model.modules(): + scheme = getattr(module, "scheme", None) + if ( + scheme is not None + and getattr(type(scheme), "process_weights_after_loading", None) + is _modelopt_kv_cache_process_weights + ): + scheme.process_weights_after_loading(module) + quant_method = getattr(module, "quant_method", None) - if quant_method.__class__.__name__ == "ModelOptNvFp4LinearMethod": + if quant_method.__class__.__name__ in ( + "ModelOptNvFp4LinearMethod", + "ModelOptNvFp4FusedMoE", + ): quant_method.process_weights_after_loading(module) @@ -251,7 +676,7 @@ def modelopt_process_weights_after_loading(model) -> None: def apply_modelopt_nvfp4_patches() -> None: - """Patch vLLM's dense ModelOpt NVFP4 method for rollout refits.""" + """Patch vLLM's ModelOpt NVFP4 methods for rollout refits.""" global _patched if _patched: @@ -262,23 +687,63 @@ def apply_modelopt_nvfp4_patches() -> None: ModelOptNvFp4LinearMethod, ) - if not hasattr(ModelOptNvFp4Config, _ORIGINAL_NVFP4_CONFIG_FROM_CONFIG_ATTR): - setattr( - ModelOptNvFp4Config, - _ORIGINAL_NVFP4_CONFIG_FROM_CONFIG_ATTR, - ModelOptNvFp4Config._from_config, + try: + from vllm.model_executor.layers.quantization.kv_cache import ( + BaseKVCacheMethod, + ) + except ImportError: + BaseKVCacheMethod = None + try: + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4FusedMoE, ) + except ImportError: + ModelOptNvFp4FusedMoE = None + + _stash_original( + ModelOptNvFp4Config, + _ORIGINAL_NVFP4_CONFIG_FROM_CONFIG_ATTR, + ModelOptNvFp4Config._from_config, + ) ModelOptNvFp4Config._from_config = classmethod(_modelopt_nvfp4_config_from_config) - if not 
hasattr(ModelOptNvFp4LinearMethod, _ORIGINAL_LINEAR_APPLY_ATTR): - setattr( - ModelOptNvFp4LinearMethod, - _ORIGINAL_LINEAR_APPLY_ATTR, - ModelOptNvFp4LinearMethod.apply, - ) + _stash_original( + ModelOptNvFp4LinearMethod, + _ORIGINAL_LINEAR_APPLY_ATTR, + ModelOptNvFp4LinearMethod.apply, + ) ModelOptNvFp4LinearMethod.process_weights_after_loading = ( _modelopt_dense_process_weights ) ModelOptNvFp4LinearMethod.apply = _modelopt_dense_apply + if ModelOptNvFp4FusedMoE is not None: + moe_patches = { + "__init__": (_ORIGINAL_FUSED_MOE_INIT_ATTR, _modelopt_moe_init), + "process_weights_after_loading": ( + _ORIGINAL_FUSED_MOE_PROCESS_WEIGHTS_ATTR, + _modelopt_moe_process_weights, + ), + } + if hasattr(ModelOptNvFp4FusedMoE, "maybe_roundup_sizes"): + moe_patches["maybe_roundup_sizes"] = ( + _ORIGINAL_FUSED_MOE_ROUNDUP_SIZES_ATTR, + _modelopt_moe_maybe_roundup_sizes, + ) + for method_name, (original_attr, replacement) in moe_patches.items(): + _stash_original( + ModelOptNvFp4FusedMoE, + original_attr, + getattr(ModelOptNvFp4FusedMoE, method_name), + ) + setattr(ModelOptNvFp4FusedMoE, method_name, replacement) + if BaseKVCacheMethod is not None: + _stash_original( + BaseKVCacheMethod, + _ORIGINAL_KV_CACHE_PROCESS_WEIGHTS_ATTR, + BaseKVCacheMethod.process_weights_after_loading, + ) + BaseKVCacheMethod.process_weights_after_loading = ( + _modelopt_kv_cache_process_weights + ) _patched = True diff --git a/nemo_rl/modelopt/models/generation/vllm_quant_backend.py b/nemo_rl/modelopt/models/generation/vllm_quant_backend.py index 7298f56094..888ae65a0a 100644 --- a/nemo_rl/modelopt/models/generation/vllm_quant_backend.py +++ b/nemo_rl/modelopt/models/generation/vllm_quant_backend.py @@ -22,7 +22,6 @@ from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer from nemo_rl.modelopt.utils import ( - iter_quant_ignore_name_candidates, matches_quant_ignore_pattern, ) from nemo_rl.models.generation.vllm.vllm_backend import VllmInternalWorkerExtension @@ -88,10 
+87,6 @@ def _load_weights(self, weights): self.model_runner.vllm_config.model_config.hf_config.quantization_config ) ignore_patterns = quant_config.get("ignore", []) or [] - # Built lazily on first use: only the rare ignored, floating-point - # weights (typically just lm_head) need a parameter lookup, so most - # refit chunks skip the full named_parameters() scan entirely. - params = None filtered = [] for name, weight in weights: suffix = name.rsplit(".", 1)[-1] @@ -99,23 +94,6 @@ def _load_weights(self, weights): if ignored and suffix in {"weight_scale", "weight_scale_2"}: continue - if ignored and suffix == "weight" and weight.is_floating_point(): - if params is None: - params = dict(self.model_runner.model.named_parameters()) - copied = False - for candidate in iter_quant_ignore_name_candidates(name): - param = params.get(candidate) - if param is not None and tuple(param.shape) == tuple( - weight.shape - ): - param.data.copy_( - weight.to(device=param.device, dtype=param.dtype) - ) - copied = True - break - if copied: - continue - filtered.append((name, weight)) weights = filtered if not weights: diff --git a/nemo_rl/modelopt/models/generation/vllm_quant_patch.py b/nemo_rl/modelopt/models/generation/vllm_quant_patch.py index de58172722..2919096ae1 100644 --- a/nemo_rl/modelopt/models/generation/vllm_quant_patch.py +++ b/nemo_rl/modelopt/models/generation/vllm_quant_patch.py @@ -47,9 +47,9 @@ def _tolerate_dummy_weight_nan_amax(): Scoping this monkey-patch to the prolog (instead of editing `MaxCalibrator.collect` in modelopt) keeps modelopt's source pristine - and limits the workaround to the single dummy-weight code path that - needs it. Genuine numerical NaN at runtime — when the calibrator is - no longer active — would still be caught by the production callsite. + and limits the behavior to the single dummy-weight code path that needs + it. 
Genuine numerical NaN at runtime — when the calibrator is no longer + active — would still be caught by the production callsite. Nonfinite dummy activations are sanitized before calibration reduce. The patch is active only inside the dummy-weight prolog, before runtime diff --git a/nemo_rl/modelopt/models/generation/vllm_quant_worker.py b/nemo_rl/modelopt/models/generation/vllm_quant_worker.py index d7c81d3a34..70cd551055 100644 --- a/nemo_rl/modelopt/models/generation/vllm_quant_worker.py +++ b/nemo_rl/modelopt/models/generation/vllm_quant_worker.py @@ -30,9 +30,17 @@ _EXTRA_ENV_VARS = ( "VLLM_QUANT_CFG", "VLLM_MODELOPT_REAL_QUANT", + "PYTHONPATH", ) +def _quant_cfg_for_worker_env(quant_cfg: str) -> str: + expanded = os.path.expanduser(quant_cfg) + if os.path.isfile(expanded): + return os.path.abspath(expanded) + return quant_cfg + + def _configure_quant_engine_kwargs( cfg: VllmConfig, llm_kwargs: dict[str, Any], @@ -48,6 +56,7 @@ def _configure_quant_engine_kwargs( from nemo_rl.modelopt.utils import build_vllm_modelopt_nvfp4_config apply_modelopt_nvfp4_patches() + os.environ.pop("VLLM_QUANT_CFG", None) os.environ["VLLM_MODELOPT_REAL_QUANT"] = "1" hf_overrides = llm_kwargs.setdefault("hf_overrides", {}) @@ -61,8 +70,10 @@ def _configure_quant_engine_kwargs( ) # Expert fakequant needs a decomposed MoE path; explicit user config still wins. 
llm_kwargs.setdefault("moe_backend", "triton") + os.environ.pop("VLLM_MODELOPT_REAL_QUANT", None) + os.environ.pop("VLLM_QUANT_CFG", None) if cfg["quant_cfg"]: - os.environ["VLLM_QUANT_CFG"] = cfg["quant_cfg"] + os.environ["VLLM_QUANT_CFG"] = _quant_cfg_for_worker_env(cfg["quant_cfg"]) @ray.remote( diff --git a/nemo_rl/modelopt/models/policy/workers/megatron_quant_policy_worker.py b/nemo_rl/modelopt/models/policy/workers/megatron_quant_policy_worker.py index 395dc10d18..2e58f0c66c 100644 --- a/nemo_rl/modelopt/models/policy/workers/megatron_quant_policy_worker.py +++ b/nemo_rl/modelopt/models/policy/workers/megatron_quant_policy_worker.py @@ -13,22 +13,23 @@ # limitations under the License. +import hashlib +import json import os +from collections.abc import Generator, Mapping from contextlib import contextmanager -from typing import Generator +from pathlib import Path -import modelopt.torch.quantization as mtq import ray import torch +from modelopt.torch.quantization.nn.modules.quant_module import QuantModule + +import nemo_rl.models.policy.workers.megatron_policy_worker as megatron_policy_worker from megatron.bridge.training.post_training.checkpointing import ( has_modelopt_state, load_modelopt_state, ) from megatron.core.utils import unwrap_model -from modelopt.torch.quantization.nn.modules.quant_module import QuantModule -from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer - -import nemo_rl.models.policy.workers.megatron_policy_worker as megatron_policy_worker from nemo_rl.modelopt.models.policy.workers.utils import ( get_quantization_layer_spec, get_quantization_mamba_stack_spec, @@ -42,27 +43,36 @@ ) -@contextmanager -def _w4a16_modelopt_exporter(): - """Temporarily adapt Bridge's NVFP4 exporter for W4A16 rollout metadata. - - Get this removed when Bridge fixed the logic here. 
- """ - from megatron.bridge.models.conversion import modelopt_utils - from modelopt.torch.export.quant_utils import QUANTIZATION_W4A16_NVFP4 +def _quant_checkpoint_cache_suffix(config: Mapping[str, object]) -> str: + """Build a short suffix for HF->Megatron checkpoints with ModelOpt state.""" + keys = ( + "quant_cfg", + "quant_calib_data", + "quant_calib_size", + "quant_batch_size", + "quant_sequence_length", + "disable_modelopt_layer_spec", + ) + payload = {key: config.get(key) for key in keys} + quant_cfg = payload["quant_cfg"] + path = Path(quant_cfg).expanduser() if isinstance(quant_cfg, str) else None + if path is not None and path.is_file(): + payload["quant_cfg"] = { + "path": path.resolve().as_posix(), + "sha256": hashlib.sha256(path.read_bytes()).hexdigest(), + } + digest = hashlib.sha256( + json.dumps(payload, sort_keys=True, separators=(",", ":")).encode() + ).hexdigest()[:12] + return f"_modelopt_{digest}" - original_get_exporter = modelopt_utils.get_modelopt_quant_exporter - def _get_modelopt_quant_exporter(quant_mode: str): - if quant_mode.lower() == "w4a16_nvfp4": - return QUANTIZATION_W4A16_NVFP4, modelopt_utils.quantize_nvfp4_weight - return original_get_exporter(quant_mode) +def _tensor_quantizer_cls() -> type: + # Keep ModelOpt imports lazy so non-quant policy workers do not load + # ModelOpt's Megatron plugin path during module import. 
+ from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer - modelopt_utils.get_modelopt_quant_exporter = _get_modelopt_quant_exporter - try: - yield - finally: - modelopt_utils.get_modelopt_quant_exporter = original_get_exporter + return TensorQuantizer @ray.remote( @@ -106,9 +116,6 @@ def __init__(self, config, *args, **kwargs): self.reference_state_dict[name] = item.detach().to( device="cpu", non_blocking=True, copy=True ) - if self.rank == 0: - print(f"Quantized model: {self.model}") - mtq.print_quant_summary(self.model) def _quantize(self, model): """Quantize the model if the model is not quantized yet.""" @@ -131,8 +138,8 @@ def _patch_validate_model_paths(self): """Patch validate_model_paths to handle quantized checkpoint paths. In cases like distillation where the teacher model is the same as the student model, - we need to save an extra quantized checkpoint. This patch checks for modelopt state - and redirects to a _quantized suffix path. It also handles pre-quantized model symlinks. + we need to save an extra quantized checkpoint. This patch routes auto-converted HF + checkpoints to a ModelOpt-specific cache path. It also handles pre-quantized model symlinks. 
""" if getattr(megatron_policy_worker.validate_model_paths, "_is_patched", False): return @@ -143,11 +150,14 @@ def _validate_model_paths(config): original_validate_model_paths(config) ) + if config.get("pretrained_checkpoint") is not None: + return hf_model_name, pretrained_path, pt_checkpoint_exists + + pretrained_path += _quant_checkpoint_cache_suffix(config) iter0_path = os.path.join(pretrained_path, "iter_0000000") - if pt_checkpoint_exists and not has_modelopt_state(iter0_path): - pretrained_path += "_quantized" - iter0_path = os.path.join(pretrained_path, "iter_0000000") - pt_checkpoint_exists = os.path.exists(iter0_path) + pt_checkpoint_exists = os.path.exists(iter0_path) and has_modelopt_state( + iter0_path + ) pre_quantized_model_path = os.environ.get( "NRL_PRE_QUANTIZED_MEGATRON_MODEL_PATH" @@ -224,6 +234,7 @@ def hide_tensor_quantizers(self): yield return + TensorQuantizer = _tensor_quantizer_cls() inner_module = self.model.module original_named_modules = inner_module.named_modules @@ -251,6 +262,7 @@ def disable_forward_pre_hook(self, param_sync=True): @contextmanager def disable_quantization(self): """Context manager that temporarily disables quantization.""" + TensorQuantizer = _tensor_quantizer_cls() quantizers = [] try: for _, module in self.model.named_modules(): @@ -306,6 +318,7 @@ def without_model_config(self): triggers spurious validation/serialization errors. We strip it for the duration of the call and restore it on exit. """ + TensorQuantizer = _tensor_quantizer_cls() configs = {} try: for name, module in self.model.named_modules(): @@ -323,6 +336,7 @@ def get_quantizer_stats(self) -> dict: Useful for verifying that calibration ran and amax values are valid. 
""" + TensorQuantizer = _tensor_quantizer_cls() total = 0 enabled = 0 with_amax = 0 @@ -360,31 +374,31 @@ def save_checkpoint(self, *args, **kwargs): return super().save_checkpoint(*args, **kwargs) def _use_real_quant_refit(self) -> bool: - generation_cfg = self.cfg.get("generation") or {} + generation_cfg = self.cfg["generation"] return ( - generation_cfg.get("backend") == "vllm" + generation_cfg["backend"] == "vllm" and generation_cfg.get("quant_cfg") is not None and bool(generation_cfg.get("real_quant")) ) - def _iter_real_quant_refit_params(self, kv_scales=None): + def _iter_real_quant_refit_params( + self, + kv_scales: dict[str, float] | None = None, + ) -> Generator[tuple[str, torch.Tensor], None, None]: """Export packed NVFP4 weights and scales for real-quant vLLM rollout.""" - from nemo_rl.modelopt.utils import DEFAULT_NVFP4_IGNORE - - generation_cfg = self.cfg.get("generation") or {} - vllm_cfg = generation_cfg.get("vllm_cfg", {}) - ignore = generation_cfg.get("real_quant_ignore") - if ignore is None: - ignore = DEFAULT_NVFP4_IGNORE - with _w4a16_modelopt_exporter(): - yield from self.megatron_bridge.export_hf_weights_modelopt( - [self.model], - quant_mode="w4a16_nvfp4", - cpu=True, - show_progress=False, - conversion_tasks=self.refit_conversion_tasks, - ignore_patterns=ignore, - ) + from nemo_rl.modelopt.utils import resolve_real_quant_ignore + + generation_cfg = self.cfg["generation"] + vllm_cfg = generation_cfg["vllm_cfg"] + ignore = resolve_real_quant_ignore(generation_cfg.get("real_quant_ignore")) + yield from self.megatron_bridge.export_hf_weights_modelopt( + [self.model], + quant_mode="w4a16_nvfp4", + cpu=True, + show_progress=False, + conversion_tasks=self.refit_conversion_tasks, + ignore_patterns=ignore, + ) if self.draft_model is not None: from nemo_rl.models.megatron.draft import export_eagle_weights_to_hf @@ -392,7 +406,7 @@ def _iter_real_quant_refit_params(self, kv_scales=None): for name, tensor in 
export_eagle_weights_to_hf(self.draft_model): yield f"draft.{name}", tensor - if not vllm_cfg.get("kv_cache_dtype", "").startswith("fp8"): + if not vllm_cfg["kv_cache_dtype"].startswith("fp8"): return from nemo_rl.models.generation.vllm.quantization.fp8_train_utils import ( @@ -417,28 +431,23 @@ def _iter_real_quant_refit_params(self, kv_scales=None): ) @staticmethod - def _find_weight_quantizer(module, param_weight): - """Find the enabled weight quantizer that corresponds to ``param_weight``. - - Uses ModelOpt's ``QuantModule.iter_weights_for_calibration`` to discover - ``(weight, weight_quantizer)`` pairs, then matches by identity. - This handles standard ``weight`` / ``weight_quantizer`` as well as - custom names like ``gate_up_proj`` / ``gate_up_proj_weight_quantizer``. - - Returns the matching ``TensorQuantizer`` or ``None``. - """ + def _find_weight_quantizer( + module: object, + param_weight: object, + ) -> object | None: + """Find the enabled ModelOpt weight quantizer for ``param_weight``.""" if module is None or param_weight is None: return None - if not isinstance(module, QuantModule): - return None - for weight, wq in module.iter_weights_for_calibration(): - if ( - param_weight is weight - and isinstance(wq, TensorQuantizer) - and wq.is_enabled - ): - return wq - return None + + from megatron.bridge.models.conversion.modelopt_utils import ( + find_modelopt_weight_quantizer_and_module, + ) + + weight_quantizer, _ = find_modelopt_weight_quantizer_and_module( + module, + param_weight, + ) + return weight_quantizer @staticmethod def _iter_hf_input_amax_names(mapping): @@ -448,6 +457,7 @@ def _iter_hf_input_amax_names(mapping): @staticmethod def _get_enabled_input_amax(task): + TensorQuantizer = _tensor_quantizer_cls() input_quantizer = getattr(task.megatron_module, "input_quantizer", None) if not isinstance(input_quantizer, TensorQuantizer): return None @@ -514,7 +524,7 @@ def param_weight(self): return None try: return self._wq(w.float()).to(w.dtype) - 
except Exception as e: + except RuntimeError as e: raise RuntimeError( f"Failed to apply weight quantizer for param " f"'{self._task.param_name}': {e}" diff --git a/nemo_rl/modelopt/models/policy/workers/utils.py b/nemo_rl/modelopt/models/policy/workers/utils.py index d8f9065a85..0a9c2a2fcb 100644 --- a/nemo_rl/modelopt/models/policy/workers/utils.py +++ b/nemo_rl/modelopt/models/policy/workers/utils.py @@ -17,23 +17,16 @@ from functools import partial from pathlib import Path -import modelopt.torch.quantization as mtq import torch import torch.nn as nn +from torch.utils.data import DataLoader, Dataset + from megatron.bridge.models.gpt_provider import transformer_engine_layer_spec from megatron.bridge.models.mamba.mamba_provider import ( modelopt_mamba_stack_spec, transformer_engine_mamba_stack_spec, ) from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec -from modelopt.torch.quantization.config import need_calibration -from modelopt.torch.utils.dataset_utils import ( - create_forward_loop, - get_dataset_dataloader, -) -from modelopt.torch.utils.plugins import megatron_prefill -from torch.utils.data import DataLoader, Dataset - from nemo_rl.algorithms.utils import get_tokenizer as _base_get_tokenizer from nemo_rl.modelopt.utils import resolve_quant_cfg @@ -86,9 +79,13 @@ def get_forward_loop_func( ): """Gets the forward loop function for the model.""" if not is_megatron: + from modelopt.torch.utils.dataset_utils import create_forward_loop + return create_forward_loop(dataloader=calib_dataloader) def _forward_loop(model): + from modelopt.torch.utils.plugins import megatron_prefill + for batch in calib_dataloader: megatron_prefill(model, batch["input_ids"], skip_return_logits=True) @@ -117,6 +114,11 @@ def quantize_model( auto_quantize_bits: The effective bits constraint for auto_quantize. data: the name of the calibration dataset. 
""" + # Keep ModelOpt imports local because this module is imported by workers + # that do not always run quantization. + import modelopt.torch.quantization as mtq + from modelopt.torch.quantization.config import need_calibration + mtq_cfg = resolve_quant_cfg(quant_cfg) use_calibration = need_calibration(mtq_cfg) if not use_calibration: @@ -145,6 +147,8 @@ def quantize_model( batch_size=1, ) else: + from modelopt.torch.utils.dataset_utils import get_dataset_dataloader + calib_dataloader = get_dataset_dataloader( dataset_name=data, tokenizer=tokenizer, diff --git a/nemo_rl/modelopt/utils.py b/nemo_rl/modelopt/utils.py index c6771e888b..d044aa1a64 100644 --- a/nemo_rl/modelopt/utils.py +++ b/nemo_rl/modelopt/utils.py @@ -24,7 +24,6 @@ ".weight_scale", ".weight_scale_2", ) - # Layers kept in native dtype by the real-quant vLLM rollout. Shared between the # vLLM quantization_config and the Megatron export-side ignore patterns. DEFAULT_NVFP4_IGNORE = [ @@ -37,6 +36,39 @@ "*self_attn*", ] +# Nano3 keeps non-MLP/MoE GEMM paths and several sensitive layers in BF16. 
+NANO3_NVFP4_IGNORE = [ + *DEFAULT_NVFP4_IGNORE, + "*proj_out.*", + "*.gate.*", + "*mlp.shared_expert_gate.*", + "*linear_attn.conv1d*", + "*mixer.conv1d*", + "*.mixer.in_proj*", + "*.mixer.out_proj*", + "*.shared_expert.*", + "*.shared_experts.*", + "*.norm.*", + "*.q_proj*", + "*.k_proj*", + "*.v_proj*", + "*.o_proj*", + "*.qkv_proj*", + "*.linear_proj*", + "*.linear_qkv*", + "*.layers.4.*", + "*.layers.11.*", + "*.layers.18.*", + "*.layers.25.*", + "*.layers.32.*", + "*.layers.41.*", +] + +REAL_QUANT_IGNORE_ALIASES = { + "DEFAULT_NVFP4_IGNORE": DEFAULT_NVFP4_IGNORE, + "NANO3_NVFP4_IGNORE": NANO3_NVFP4_IGNORE, +} + def _iter_quant_ignore_suffix_variants(name: str) -> Iterator[str]: """Yield ``name`` and, if it ends in a known quant suffix, the stripped form.""" @@ -69,9 +101,25 @@ def matches_quant_ignore_pattern(name: str, patterns: list[str]) -> bool: ) +def resolve_real_quant_ignore(ignore: list[str] | str | None = None) -> list[str]: + """Resolve a real-quant ignore profile or explicit pattern list.""" + if ignore is None: + return list(DEFAULT_NVFP4_IGNORE) + if isinstance(ignore, str): + alias = REAL_QUANT_IGNORE_ALIASES.get(ignore) + if alias is None: + known = ", ".join(sorted(REAL_QUANT_IGNORE_ALIASES)) + raise ValueError( + f"Unknown real_quant_ignore profile '{ignore}'. " + f"Expected one of: {known}, or provide a list of patterns." + ) + return list(alias) + return list(ignore) + + def build_vllm_modelopt_nvfp4_config( *, - ignore: list[str] | None = None, + ignore: list[str] | str | None = None, ) -> dict[str, Any]: """Build the HuggingFace quantization_config consumed by vLLM ModelOpt NVFP4. 
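Reviewer note: a minimal, self-contained sketch of how the `real_quant_ignore` resolution added in this hunk appears to behave. The pattern lists below are shortened stand-ins (hypothetical, not the real contents of `DEFAULT_NVFP4_IGNORE` or `NANO3_NVFP4_IGNORE`), and `resolve_ignore` is an illustrative name that mirrors `resolve_real_quant_ignore` rather than importing it.

```python
from __future__ import annotations

# Stand-in pattern lists (assumption): the real lists live in
# nemo_rl/modelopt/utils.py and contain more entries.
DEFAULT_IGNORE = ["*self_attn*"]
NANO3_IGNORE = [*DEFAULT_IGNORE, "*.mixer.in_proj*", "*.layers.4.*"]

ALIASES = {
    "DEFAULT_NVFP4_IGNORE": DEFAULT_IGNORE,
    "NANO3_NVFP4_IGNORE": NANO3_IGNORE,
}


def resolve_ignore(ignore: list[str] | str | None = None) -> list[str]:
    """Resolve a named ignore profile or an explicit pattern list."""
    if ignore is None:
        # No setting: fall back to a copy of the default profile.
        return list(DEFAULT_IGNORE)
    if isinstance(ignore, str):
        # A string is treated as a profile name, never as a single pattern.
        alias = ALIASES.get(ignore)
        if alias is None:
            known = ", ".join(sorted(ALIASES))
            raise ValueError(
                f"Unknown real_quant_ignore profile '{ignore}'. "
                f"Expected one of: {known}, or provide a list of patterns."
            )
        return list(alias)
    # An explicit list is copied so callers cannot mutate shared state.
    return list(ignore)
```

Every branch returns a fresh list, so the module-level profiles are never handed out by reference. A bare string such as `"*lm_head*"` is rejected as an unknown profile, so a single pattern still has to be wrapped in a list.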
@@ -93,7 +141,7 @@ def build_vllm_modelopt_nvfp4_config( "targets": ["Linear"], } }, - "ignore": ignore if ignore is not None else list(DEFAULT_NVFP4_IGNORE), + "ignore": resolve_real_quant_ignore(ignore), "quant_algo": "NVFP4", "quant_mode": "w4a16_nvfp4", "weight_only": True, diff --git a/nemo_rl/models/generation/vllm/config.py b/nemo_rl/models/generation/vllm/config.py index c3548460d6..28d97ae9e4 100644 --- a/nemo_rl/models/generation/vllm/config.py +++ b/nemo_rl/models/generation/vllm/config.py @@ -60,4 +60,4 @@ class VllmConfig(GenerationConfig): # NVFP4 kernels and stream packed quantized weights instead of fake-quant # modules. This is intended for ModelOpt NVFP4 rollout experiments. real_quant: NotRequired[bool] - real_quant_ignore: NotRequired[list[str]] + real_quant_ignore: NotRequired[list[str] | str] diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 0eec29f04a..b7111efce7 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -22,6 +22,8 @@ import ray import torch +from transformers import PreTrainedTokenizerBase + from megatron.bridge.training.checkpointing import ( maybe_finalize_async_save, save_checkpoint, @@ -40,8 +42,6 @@ from megatron.core.optimizer import ChainedOptimizer from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.utils import get_model_config -from transformers import PreTrainedTokenizerBase - from nemo_rl.algorithms.logits_sampling_utils import TrainingSamplingParams from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.data_plane.worker_mixin import TQWorkerMixin @@ -1366,9 +1366,6 @@ def broadcast_weights_for_collective( post_iter_func=lambda x: x[1], ) - def _use_real_quant_refit(self) -> bool: - return False - def prepare_for_lp_inference(self): self.model = self.move_model(self.model, "cuda", move_grads=False) 
self.model.eval() diff --git a/tests/functional/modelopt_quant_rollout.sh b/tests/functional/modelopt_quant_rollout.sh index 857103ad65..0471b634fd 100644 --- a/tests/functional/modelopt_quant_rollout.sh +++ b/tests/functional/modelopt_quant_rollout.sh @@ -60,7 +60,7 @@ run_quant_rollout_case() { cd "$PROJECT_ROOT" NRL_MEGATRON_CHECKPOINT_DIR="$megatron_cache_dir" \ - uv run --extra modelopt --group test \ + uv run --no-sync --extra modelopt --group test \ coverage run -a --data-file="$PROJECT_ROOT/tests/.coverage" --source="$PROJECT_ROOT/nemo_rl" \ "$PROJECT_ROOT/examples/run_grpo.py" \ --config "$PROJECT_ROOT/examples/modelopt/qa_grpo_math_megatron.yaml" \ @@ -96,9 +96,9 @@ run_quant_rollout_case() { "$@" \ 2>&1 | tee "$run_log" - uv run --extra modelopt --group test tests/json_dump_tb_logs.py "$log_dir" --output_path "$metrics_json" + uv run --no-sync --extra modelopt --group test tests/json_dump_tb_logs.py "$log_dir" --output_path "$metrics_json" - uv run --extra modelopt --group test tests/check_metrics.py "$metrics_json" \ + uv run --no-sync --extra modelopt --group test tests/check_metrics.py "$metrics_json" \ "data[\"train/gen_kl_error\"][\"1\"] < $gen_kl_error_step1_max" \ "max(data[\"train/token_mult_prob_error\"]) < $token_mult_prob_error_max" @@ -109,7 +109,6 @@ run_quant_rollout_case() { assert_not_grep "VLLM_QUANT_CFG" "$run_log" else assert_grep "FakeQuantWorker" "$run_log" - assert_grep "VLLM_QUANT_CFG" "$run_log" assert_not_grep "Detected ModelOpt NVFP4 checkpoint" "$run_log" fi } diff --git a/tests/test_suites/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.sh b/tests/test_suites/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.sh new file mode 100755 index 0000000000..2c2d6561cd --- /dev/null +++ b/tests/test_suites/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# One-step Nano3 W4A16 real-quant rollout check. 
This validates that Megatron +# exports ModelOpt NVFP4 packed tensors, vLLM loads them through the real +# ModelOpt kernel path, and generation/policy logprobs stay aligned. +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source "$SCRIPT_DIR/common.env" + +# ===== BEGIN CONFIG ===== +NUM_NODES=4 +GPUS_PER_NODE=4 +STEPS_PER_RUN=1 +MAX_STEPS=1 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=180 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +cd "$PROJECT_ROOT" +export PYTHONPATH="$PROJECT_ROOT/3rdparty/Megatron-Bridge-workspace/Megatron-Bridge/src:${PYTHONPATH:-}" + +uv run --no-sync examples/run_grpo.py \ + --config "$CONFIG_PATH" \ + grpo.max_num_steps=$MAX_STEPS \ + cluster.num_nodes=$NUM_NODES \ + cluster.gpus_per_node=$GPUS_PER_NODE \ + logger.log_dir="$LOG_DIR" \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name="$EXP_NAME" \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=True \ + checkpointing.checkpoint_dir="$CKPT_DIR" \ + "$@" \ + 2>&1 | tee "$RUN_LOG" + +uv run --no-sync tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS" + +grep -q "VllmQuantInternalWorkerExtension" "$RUN_LOG" +grep -q "Detected ModelOpt NVFP4 checkpoint" "$RUN_LOG" +grep -q "quantization=modelopt" "$RUN_LOG" +! grep -q "FakeQuantWorker" "$RUN_LOG" +! 
grep -q "VLLM_QUANT_CFG" "$RUN_LOG" + +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' "$JSON_METRICS") -ge $MAX_STEPS ]]; then + uv run --no-sync tests/check_metrics.py "$JSON_METRICS" \ + 'data["train/gen_kl_error"]["1"] < 0.003' \ + 'max(data["train/token_mult_prob_error"]) < 1.05' \ + 'data["train/loss"]["1"] > 0.0' \ + 'data["train/num_valid_samples"]["1"] > 0' +fi diff --git a/tests/test_suites/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a8-fake.sh b/tests/test_suites/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a8-fake.sh index ad48db1209..917a0a5fd7 100755 --- a/tests/test_suites/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a8-fake.sh +++ b/tests/test_suites/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a8-fake.sh @@ -67,7 +67,6 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma 'max(data["train/token_mult_prob_error"]) < 1.06' grep -q "FakeQuantWorker" "$RUN_LOG" - grep -q "VLLM_QUANT_CFG" "$RUN_LOG" ! grep -q "Detected ModelOpt NVFP4 checkpoint" "$RUN_LOG" # Clean up checkpoint directory after successful run to save space. 
diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 2c2fa417dd..ab1c1d6b09 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -86,6 +86,7 @@ tests/test_suites/llm/grpo-qwen3-8b-base-1n8g-fp8-kvcache-megatron.sh tests/test_suites/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron.sh tests/test_suites/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a16.sh tests/test_suites/llm/grpo-qwen2.5-0.5b-dapo-1n8g-megatron-qa-nvfp4-w4a8-fake.sh +tests/test_suites/llm/grpo-nanov3-30ba3b-4n4g-megatron-qa-nvfp4-w4a16-real.sh # Non-colocated tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh diff --git a/tests/unit/models/generation/test_vllm_modelopt_real_quant_config.py b/tests/unit/models/generation/test_vllm_modelopt_real_quant_config.py index 207d397755..f9e2006ef9 100644 --- a/tests/unit/models/generation/test_vllm_modelopt_real_quant_config.py +++ b/tests/unit/models/generation/test_vllm_modelopt_real_quant_config.py @@ -21,10 +21,11 @@ import torch from nemo_rl.modelopt.models.generation.vllm_modelopt_patch import ( - _canonicalize_nvfp4_weight_scale, _convert_nvfp4_linear_kernel_format, _modelopt_dense_apply, _modelopt_dense_process_weights, + _modelopt_kv_cache_process_weights, + _zero_modelopt_moe_padding, apply_modelopt_nvfp4_patches, modelopt_process_weights_after_loading, prepare_modelopt_for_weight_reload, @@ -34,6 +35,7 @@ iter_quant_ignore_name_candidates, matches_quant_ignore_pattern, resolve_quant_cfg, + resolve_real_quant_ignore, ) @@ -94,6 +96,33 @@ class FakeTensorQuantizer(torch.nn.Module): ].tensor_quantizer = tensor_quantizer_module +def _make_real_quant_extension(backend, model, ignore): + extension = object.__new__(backend.VllmQuantInternalWorkerExtension) + extension.model_runner = types.SimpleNamespace( + model=model, + vllm_config=types.SimpleNamespace( + model_config=types.SimpleNamespace( + hf_config=types.SimpleNamespace(quantization_config={"ignore": ignore}) + ) + ), 
+ ) + return extension + + +def _patch_real_quant_load(monkeypatch, backend, forwarded=None): + monkeypatch.setattr( + backend.VllmQuantInternalWorkerExtension, + "_is_real_quant_model", + lambda self: True, + ) + if forwarded is not None: + monkeypatch.setattr( + backend.VllmInternalWorkerExtension, + "_load_weights", + lambda self, weights: forwarded.extend(weights) or "loaded", + ) + + def test_w4a16_real_quant_config_keeps_weight_only_default(): cfg = build_vllm_modelopt_nvfp4_config() @@ -127,6 +156,19 @@ def test_real_quant_config_allows_explicit_ignore_override(): assert cfg["ignore"] == ["lm_head"] +def test_real_quant_config_resolves_named_ignore_profiles(): + cfg = build_vllm_modelopt_nvfp4_config(ignore="NANO3_NVFP4_IGNORE") + + assert "*.mixer.in_proj*" in cfg["ignore"] + assert matches_quant_ignore_pattern( + "model.layers.0.mixer.in_proj.weight", + cfg["ignore"], + ) + + with pytest.raises(ValueError, match="Unknown real_quant_ignore profile"): + resolve_real_quant_ignore("missing-profile") + + def test_default_ignore_patterns_match_expected_layers(): ignore_patterns = build_vllm_modelopt_nvfp4_config()["ignore"] @@ -381,7 +423,7 @@ def test_vllm_modelopt_backend_applies_real_quant_patch_on_import(monkeypatch): assert calls == ["patched"] -def test_real_quant_load_weights_copies_ignored_float_weights(monkeypatch): +def test_real_quant_load_weights_forwards_ignored_float_weights(monkeypatch): backend = _import_vllm_quant_backend(monkeypatch) class TinyModel(torch.nn.Module): @@ -391,34 +433,9 @@ def __init__(self): self.keep = torch.nn.Linear(2, 2, bias=False) model = TinyModel() - extension = object.__new__(backend.VllmQuantInternalWorkerExtension) - extension.model_runner = types.SimpleNamespace( - model=model, - vllm_config=types.SimpleNamespace( - model_config=types.SimpleNamespace( - hf_config=types.SimpleNamespace( - quantization_config={"ignore": ["lm_head"]} - ) - ) - ), - ) - forwarded = [] - - def fake_base_load_weights(self, weights): - 
forwarded.extend(weights) - return "loaded" - - monkeypatch.setattr( - backend.VllmQuantInternalWorkerExtension, - "_is_real_quant_model", - lambda self: True, - ) - monkeypatch.setattr( - backend.VllmInternalWorkerExtension, - "_load_weights", - fake_base_load_weights, - ) + extension = _make_real_quant_extension(backend, model, ["lm_head"]) + _patch_real_quant_load(monkeypatch, backend, forwarded) ignored_weight = torch.full_like(model.lm_head.weight, 7.0) kept_weight = torch.full_like(model.keep.weight, 3.0) @@ -434,76 +451,37 @@ def fake_base_load_weights(self, weights): == "loaded" ) - torch.testing.assert_close(model.lm_head.weight, ignored_weight) - assert [name for name, _ in forwarded] == ["keep.weight"] - torch.testing.assert_close(forwarded[0][1], kept_weight) + assert [name for name, _ in forwarded] == ["lm_head.weight", "keep.weight"] + torch.testing.assert_close(forwarded[0][1], ignored_weight) + torch.testing.assert_close(forwarded[1][1], kept_weight) -def test_real_quant_load_weights_returns_when_only_ignored_weights(monkeypatch): +def test_real_quant_load_weights_returns_when_only_ignored_scales(monkeypatch): backend = _import_vllm_quant_backend(monkeypatch) model = torch.nn.Module() model.lm_head = torch.nn.Linear(2, 2, bias=False) - extension = object.__new__(backend.VllmQuantInternalWorkerExtension) - extension.model_runner = types.SimpleNamespace( - model=model, - vllm_config=types.SimpleNamespace( - model_config=types.SimpleNamespace( - hf_config=types.SimpleNamespace( - quantization_config={"ignore": ["lm_head"]} - ) - ) - ), - ) - monkeypatch.setattr( - backend.VllmQuantInternalWorkerExtension, - "_is_real_quant_model", - lambda self: True, - ) + extension = _make_real_quant_extension(backend, model, ["lm_head"]) + _patch_real_quant_load(monkeypatch, backend) assert ( extension._load_weights( [ - ("lm_head.weight", torch.full_like(model.lm_head.weight, 1.5)), ("lm_head.weight_scale", torch.ones(1)), ("lm_head.weight_scale_2", torch.ones(1)), 
] ) is None ) - torch.testing.assert_close( - model.lm_head.weight, - torch.full_like(model.lm_head.weight, 1.5), - ) -def test_real_quant_load_weights_forwards_ignored_shape_mismatch(monkeypatch): +def test_real_quant_load_weights_forwards_ignored_weights_to_vllm_loader(monkeypatch): backend = _import_vllm_quant_backend(monkeypatch) model = torch.nn.Module() model.lm_head = torch.nn.Linear(2, 2, bias=False) - extension = object.__new__(backend.VllmQuantInternalWorkerExtension) - extension.model_runner = types.SimpleNamespace( - model=model, - vllm_config=types.SimpleNamespace( - model_config=types.SimpleNamespace( - hf_config=types.SimpleNamespace( - quantization_config={"ignore": ["lm_head"]} - ) - ) - ), - ) forwarded = [] - - monkeypatch.setattr( - backend.VllmQuantInternalWorkerExtension, - "_is_real_quant_model", - lambda self: True, - ) - monkeypatch.setattr( - backend.VllmInternalWorkerExtension, - "_load_weights", - lambda self, weights: forwarded.extend(weights) or "loaded", - ) + extension = _make_real_quant_extension(backend, model, ["lm_head"]) + _patch_real_quant_load(monkeypatch, backend, forwarded) mismatched = torch.ones(1, dtype=model.lm_head.weight.dtype) @@ -935,21 +913,6 @@ def test_resolve_quant_cfg_rejects_recipe_without_quant_cfg(monkeypatch): resolve_quant_cfg("missing-quant-cfg") -def test_vllm_reload_canonicalizes_nvfp4_scales_before_kernel_conversion(): - layer = torch.nn.Module() - layer.weight_scale = torch.nn.Parameter( - torch.tensor([[1.0, -2.0], [-0.5, 4.0]]), - requires_grad=False, - ) - - _canonicalize_nvfp4_weight_scale(layer) - - torch.testing.assert_close( - layer.weight_scale, - torch.tensor([[1.0, 2.0], [0.5, 4.0]]), - ) - - def test_prepare_modelopt_for_weight_reload_restores_deleted_dense_params(): layer = torch.nn.Module() layer.weight = torch.nn.Parameter(torch.ones(2, 2), requires_grad=False) @@ -1047,20 +1010,60 @@ def fake_loader(param, loaded_weight): torch.testing.assert_close(layer.weight, torch.ones(2, 2)) -def 
test_modelopt_process_weights_after_loading_runs_dense_quant_method(): +def test_modelopt_process_weights_after_loading_runs_modelopt_quant_methods(): calls = [] class ModelOptNvFp4LinearMethod: def process_weights_after_loading(self, layer): calls.append(layer) + class ModelOptNvFp4FusedMoE: + def process_weights_after_loading(self, layer): + calls.append(layer) + + model = torch.nn.Module() + model.dense_layer = torch.nn.Module() + model.dense_layer.quant_method = ModelOptNvFp4LinearMethod() + model.moe_layer = torch.nn.Module() + model.moe_layer.quant_method = ModelOptNvFp4FusedMoE() + + modelopt_process_weights_after_loading(model) + + assert calls == [model.dense_layer, model.moe_layer] + + +def test_modelopt_process_weights_after_loading_runs_patched_kv_scheme(monkeypatch): + platforms = types.ModuleType("vllm.platforms") + platforms.current_platform = types.SimpleNamespace(is_fp8_fnuz=lambda: False) + monkeypatch.setitem(sys.modules, "vllm.platforms", platforms) + + class BaseKVCacheMethod: + process_weights_after_loading = _modelopt_kv_cache_process_weights + model = torch.nn.Module() - model.layer = torch.nn.Module() - model.layer.quant_method = ModelOptNvFp4LinearMethod() + model.kv_layer = torch.nn.Module() + model.kv_layer.scheme = BaseKVCacheMethod() + model.kv_layer.kv_cache_dtype = "fp8" + model.kv_layer.calculate_kv_scales = False + model.kv_layer.k_scale = torch.tensor(2.0) + model.kv_layer.v_scale = torch.tensor(3.0) + model.kv_layer.q_scale = torch.tensor(4.0) + model.kv_layer.prob_scale = torch.tensor(5.0) + model.kv_layer._k_scale = torch.zeros(()) + model.kv_layer._v_scale = torch.zeros(()) + model.kv_layer._q_scale = torch.zeros(()) + model.kv_layer._prob_scale = torch.zeros(()) modelopt_process_weights_after_loading(model) - assert calls == [model.layer] + torch.testing.assert_close(model.kv_layer._k_scale, torch.tensor(2.0)) + torch.testing.assert_close(model.kv_layer._v_scale, torch.tensor(3.0)) + 
torch.testing.assert_close(model.kv_layer._q_scale, torch.tensor(4.0)) + torch.testing.assert_close(model.kv_layer._prob_scale, torch.tensor(5.0)) + assert model.kv_layer._k_scale_float == 2.0 + assert model.kv_layer._v_scale_float == 3.0 + assert model.kv_layer._q_scale_float == 4.0 + assert model.kv_layer.calculate_kv_scales is False def test_apply_modelopt_nvfp4_patches_updates_vllm_method(monkeypatch): @@ -1160,7 +1163,7 @@ def test_modelopt_dense_process_uses_vllm_kernel_api(): layer.weight._output_dim = 0 layer.weight.weight_loader = lambda param, loaded_weight: None layer.weight_scale = torch.nn.Parameter( - torch.tensor([[1.0, -2.0], [-0.5, 4.0]]), + torch.tensor([[1.0, 2.0], [0.5, 4.0]]), requires_grad=False, ) layer.weight_scale_2 = torch.nn.Parameter(torch.tensor([2.0]), requires_grad=False) @@ -1219,6 +1222,7 @@ def fake_apply(**kwargs): calls.append(("apply", kwargs)) return "out" + marlin_utils.is_fp4_marlin_supported = lambda: True marlin_utils.prepare_fp4_layer_for_marlin = fake_prepare marlin_utils.apply_fp4_marlin_linear = fake_apply monkeypatch.setitem( @@ -1230,7 +1234,7 @@ def fake_apply(**kwargs): layer = torch.nn.Module() layer.weight = torch.nn.Parameter(torch.ones(2, 1), requires_grad=False) layer.weight_scale = torch.nn.Parameter( - torch.tensor([[1.0, -2.0], [-0.5, 4.0]]), + torch.tensor([[1.0, 2.0], [0.5, 4.0]]), requires_grad=False, ) layer.weight_scale_2 = torch.nn.Parameter(torch.tensor([2.0]), requires_grad=False) @@ -1256,14 +1260,52 @@ def fake_apply(**kwargs): assert result == "out" assert calls[0] == ("prepare", layer) assert calls[1][0] == "apply" + assert calls[1][1]["weight_global_scale"] is layer.weight_global_scale assert not hasattr(layer, "input_scale") assert not hasattr(layer, "input_global_scale") assert not hasattr(layer, "alpha") assert not hasattr(layer, "input_global_scale_inv") assert not hasattr(layer, "weight_scale_2") torch.testing.assert_close(layer.weight_global_scale, torch.tensor(2.0)) + assert 
set(layer._nrl_modelopt_processed_tensor_refs) == { + "weight", + "weight_scale", + "weight_global_scale", + } + + +def test_zero_modelopt_moe_padding_uses_tp_rank_valid_size(): + def make_layer(tp_rank: int): + layer = torch.nn.Module() + layer.moe_config = types.SimpleNamespace( + intermediate_size_per_partition=512, + intermediate_size_per_partition_unpadded=464, + ) + layer.quant_config = types.SimpleNamespace(group_size=16) + layer.tp_rank = tp_rank + layer.tp_size = 4 + layer.w13_weight = torch.nn.Parameter(torch.ones(1, 512, 1)) + layer.w13_weight_scale = torch.nn.Parameter(torch.ones(1, 512, 1)) + layer.w2_weight = torch.nn.Parameter(torch.ones(1, 1, 256)) + layer.w2_weight_scale = torch.nn.Parameter(torch.ones(1, 1, 32)) + return layer + + rank0 = make_layer(tp_rank=0) + _zero_modelopt_moe_padding(rank0) + torch.testing.assert_close(rank0.w13_weight, torch.ones_like(rank0.w13_weight)) + torch.testing.assert_close(rank0.w2_weight, torch.ones_like(rank0.w2_weight)) torch.testing.assert_close( - layer.weight_scale, - torch.tensor([[1.0, 2.0], [0.5, 4.0]]), - ) - assert calls[1][1]["weight_global_scale"] is layer.weight_global_scale + rank0.w2_weight_scale, + torch.ones_like(rank0.w2_weight_scale), + ) + + rank3 = make_layer(tp_rank=3) + _zero_modelopt_moe_padding(rank3) + assert torch.all(rank3.w13_weight[:, :320] == 1) + assert torch.all(rank3.w13_weight[:, 320:] == 0) + assert torch.all(rank3.w13_weight_scale[:, :320] == 1) + assert torch.all(rank3.w13_weight_scale[:, 320:] == 0) + assert torch.all(rank3.w2_weight[:, :, :160] == 1) + assert torch.all(rank3.w2_weight[:, :, 160:] == 0) + assert torch.all(rank3.w2_weight_scale[:, :, :20] == 1) + assert torch.all(rank3.w2_weight_scale[:, :, 20:] == 0) diff --git a/tests/unit/models/policy/test_megatron_quant_worker.py b/tests/unit/models/policy/test_megatron_quant_worker.py index 07a8d0b9ff..a769c5860e 100644 --- a/tests/unit/models/policy/test_megatron_quant_worker.py +++ 
b/tests/unit/models/policy/test_megatron_quant_worker.py @@ -90,7 +90,7 @@ def _make_real_quant_worker(): "quant_cfg": "examples/modelopt/quant_configs/nvfp4_a16_mlp_only.yaml", "real_quant": True, "real_quant_ignore": ["lm_head"], - "vllm_cfg": {}, + "vllm_cfg": {"kv_cache_dtype": "auto"}, } } worker.model = object() @@ -126,7 +126,6 @@ def test_modelopt_layer_spec_config_selects_layer_specs(): from megatron.core.post_training.modelopt.gpt.model_specs import ( get_gpt_modelopt_spec, ) - from nemo_rl.modelopt.models.policy.workers.utils import ( get_quantization_layer_spec, get_quantization_mamba_stack_spec, @@ -226,8 +225,8 @@ def __init__(self): monkeypatch.setattr( megatron_quant_policy_worker, - "TensorQuantizer", - FakeTensorQuantizer, + "_tensor_quantizer_cls", + lambda: FakeTensorQuantizer, ) monkeypatch.setattr( MegatronPolicyWorkerImpl, diff --git a/tests/unit/models/policy/test_modelopt_worker_utils.py b/tests/unit/models/policy/test_modelopt_worker_utils.py index 0a40199cca..efa57f3c2c 100644 --- a/tests/unit/models/policy/test_modelopt_worker_utils.py +++ b/tests/unit/models/policy/test_modelopt_worker_utils.py @@ -109,6 +109,22 @@ def __len__(self): ) +def _modelopt_quant_module(): + return importlib.import_module("modelopt.torch.quantization") + + +def _modelopt_config_module(): + return importlib.import_module("modelopt.torch.quantization.config") + + +def _modelopt_dataset_utils_module(): + return importlib.import_module("modelopt.torch.utils.dataset_utils") + + +def _modelopt_plugins_module(): + return importlib.import_module("modelopt.torch.utils.plugins") + + def test_get_tokenizer_applies_modelopt_calibration_defaults(monkeypatch): tokenizer = types.SimpleNamespace(padding_side="right", model_max_length=0) monkeypatch.setattr(worker_utils, "_base_get_tokenizer", lambda cfg: tokenizer) @@ -127,7 +143,7 @@ def test_megatron_forward_loop_prefills_batch_input_ids(monkeypatch): batch_size=1, ) monkeypatch.setattr( - worker_utils, + 
_modelopt_plugins_module(), "megatron_prefill", lambda model, input_ids, skip_return_logits: seen.append( (model, input_ids.clone(), skip_return_logits) @@ -151,16 +167,22 @@ def test_quantize_model_skips_forward_loop_for_weight_only_config(monkeypatch): "resolve_quant_cfg", lambda quant_cfg: {"quant_cfg": [{"name": quant_cfg}]}, ) - monkeypatch.setattr(worker_utils, "need_calibration", lambda cfg: False) monkeypatch.setattr( - worker_utils.mtq, + _modelopt_config_module(), "need_calibration", lambda cfg: False + ) + monkeypatch.setattr( + _modelopt_quant_module(), "quantize", lambda model_arg, cfg, forward_loop: calls.append( (model_arg, cfg, forward_loop) ) or model_arg, ) - monkeypatch.setattr(worker_utils.mtq, "print_quant_summary", lambda model: None) + monkeypatch.setattr( + _modelopt_quant_module(), + "print_quant_summary", + lambda model: None, + ) worker_utils.quantize_model( model, @@ -177,7 +199,7 @@ def test_quantize_model_skips_forward_loop_for_weight_only_config(monkeypatch): def test_quantize_model_requires_calibration_data(monkeypatch): model = torch.nn.Linear(1, 1) monkeypatch.setattr(worker_utils, "resolve_quant_cfg", lambda quant_cfg: {}) - monkeypatch.setattr(worker_utils, "need_calibration", lambda cfg: True) + monkeypatch.setattr(_modelopt_config_module(), "need_calibration", lambda cfg: True) with pytest.raises(ValueError, match="policy.quant_calib_data"): worker_utils.quantize_model( @@ -194,7 +216,7 @@ def test_quantize_model_uses_random_calibration_loop(monkeypatch): calls = [] monkeypatch.setattr(worker_utils, "resolve_quant_cfg", lambda quant_cfg: {}) - monkeypatch.setattr(worker_utils, "need_calibration", lambda cfg: True) + monkeypatch.setattr(_modelopt_config_module(), "need_calibration", lambda cfg: True) monkeypatch.setattr( worker_utils, "get_forward_loop_func", @@ -205,11 +227,15 @@ def test_quantize_model_uses_random_calibration_loop(monkeypatch): ), ) monkeypatch.setattr( - worker_utils.mtq, + _modelopt_quant_module(), 
"quantize", lambda model_arg, cfg, forward_loop: calls.append(forward_loop) or model_arg, ) - monkeypatch.setattr(worker_utils.mtq, "print_quant_summary", lambda model: None) + monkeypatch.setattr( + _modelopt_quant_module(), + "print_quant_summary", + lambda model: None, + ) worker_utils.quantize_model( model, @@ -232,9 +258,9 @@ def test_quantize_model_uses_named_calibration_dataset(monkeypatch): calls = [] monkeypatch.setattr(worker_utils, "resolve_quant_cfg", lambda quant_cfg: {}) - monkeypatch.setattr(worker_utils, "need_calibration", lambda cfg: True) + monkeypatch.setattr(_modelopt_config_module(), "need_calibration", lambda cfg: True) monkeypatch.setattr( - worker_utils, + _modelopt_dataset_utils_module(), "get_dataset_dataloader", lambda **kwargs: calls.append(("dataset", kwargs)) or dataloader, ) @@ -248,14 +274,18 @@ def test_quantize_model_uses_named_calibration_dataset(monkeypatch): ), ) monkeypatch.setattr( - worker_utils.mtq, + _modelopt_quant_module(), "quantize", lambda model_arg, cfg, forward_loop: calls.append( ("quantize", model_arg, cfg, forward_loop) ) or model_arg, ) - monkeypatch.setattr(worker_utils.mtq, "print_quant_summary", lambda model: None) + monkeypatch.setattr( + _modelopt_quant_module(), + "print_quant_summary", + lambda model: None, + ) worker_utils.quantize_model( model,