From c7f0e1f93b8741f15bf0059c5d45cc1067be9b7c Mon Sep 17 00:00:00 2001 From: Zeyu Zhou Date: Mon, 29 Jun 2026 09:58:18 -0700 Subject: [PATCH] fix: allow router replay trace fallback composition (#2963) Signed-off-by: Zeyu Zhou Signed-off-by: NeMo Bot --- nemo_rl/models/megatron/router_replay.py | 12 ++++--- nemo_rl/utils/r3_trace.py | 5 ++- .../models/megatron/test_router_replay.py | 33 +++++++++++++++++++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/megatron/router_replay.py b/nemo_rl/models/megatron/router_replay.py index a262410f3e..d81763b33d 100644 --- a/nemo_rl/models/megatron/router_replay.py +++ b/nemo_rl/models/megatron/router_replay.py @@ -17,6 +17,7 @@ import inspect import os from collections.abc import Iterable +from functools import wraps from typing import Any, Optional import torch @@ -258,8 +259,7 @@ def _install_missing_route_fallback_patch() -> None: return original_get_replay_topk = RouterReplay.get_replay_topk - expected_params = [ - "self", + expected_non_receiver_params = [ "scores", "topk", "num_groups", @@ -267,15 +267,19 @@ def _install_missing_route_fallback_patch() -> None: "default_compute_topk", ] actual_params = list(inspect.signature(original_get_replay_topk).parameters) - if actual_params != expected_params: + # Wrapper receiver names are arbitrary; guard only Megatron's callable API. + actual_non_receiver_params = actual_params[1:] + if actual_non_receiver_params != expected_non_receiver_params: raise RuntimeError( "Unsupported Megatron RouterReplay.get_replay_topk signature for " "NeMo RL missing-route fallback patch: " - f"expected={expected_params}, actual={actual_params}. " + f"expected_non_receiver_params={expected_non_receiver_params}, " + f"actual={actual_params}. " "Update nemo_rl.models.megatron.router_replay before enabling " "policy.router_replay.enabled." ) + @wraps(original_get_replay_topk) def wrapped_get_replay_topk( replay_instance: Any, scores: torch.Tensor, diff --git a/nemo_rl/utils/r3_trace.py b/nemo_rl/utils/r3_trace.py index 1c254a887d..d0a6a675be 100644 --- a/nemo_rl/utils/r3_trace.py +++ b/nemo_rl/utils/r3_trace.py @@ -24,6 +24,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from contextlib import contextmanager, nullcontext +from functools import wraps from pathlib import Path from typing import Any, Optional @@ -298,8 +299,10 @@ def _verify_router_replay_forward_context() -> Iterator[None]: with _patch_lock: if _router_replay_patch_depth == 0: - _original_get_replay_topk = RouterReplay.get_replay_topk + original_get_replay_topk = RouterReplay.get_replay_topk + _original_get_replay_topk = original_get_replay_topk + @wraps(original_get_replay_topk) def wrapped_get_replay_topk( replay_instance: Any, scores: Any, diff --git a/tests/unit/models/megatron/test_router_replay.py b/tests/unit/models/megatron/test_router_replay.py index 35eeca07c8..4ae0040857 100644 --- a/tests/unit/models/megatron/test_router_replay.py +++ b/tests/unit/models/megatron/test_router_replay.py @@ -602,6 +602,39 @@ def test_r3_trace_forward_verifier_records_actual_replayed_topk(tmp_path, monkey RouterReplay.clear_global_router_replay_instances() +@pytest.mark.mcore +def test_missing_route_fallback_patch_is_idempotent_inside_r3_trace_verifier( + tmp_path, monkeypatch +): + from megatron.core.transformer.moe.router_replay import RouterReplay + + from nemo_rl.models.megatron import router_replay + from nemo_rl.utils.r3_trace import r3_trace_stage + + monkeypatch.setenv("NRL_R3_TRACE", "1") + monkeypatch.setenv("NRL_R3_TRACE_VERIFY_FORWARD", "1") + monkeypatch.setenv("NRL_R3_TRACE_DIR", str(tmp_path)) + RouterReplay.clear_global_router_replay_instances() + + try: + router_replay._install_missing_route_fallback_patch() + assert getattr( + RouterReplay.get_replay_topk, + router_replay._MISSING_ROUTE_FALLBACK_PATCH_ATTR, + False, + ) + + with r3_trace_stage("unit-forward"): + assert getattr( + RouterReplay.get_replay_topk, + router_replay._MISSING_ROUTE_FALLBACK_PATCH_ATTR, + False, + ) + router_replay._install_missing_route_fallback_patch() + finally: + RouterReplay.clear_global_router_replay_instances() + + @pytest.mark.mcore @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_mcore_moe_replay_backward_recompute_matches_parameter_grads(tmp_path):