Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions nemo_rl/models/megatron/router_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import inspect
import os
from collections.abc import Iterable
from functools import wraps
from typing import Any, Optional

import torch
Expand Down Expand Up @@ -258,24 +259,27 @@ def _install_missing_route_fallback_patch() -> None:
return

original_get_replay_topk = RouterReplay.get_replay_topk
expected_params = [
"self",
expected_non_receiver_params = [
"scores",
"topk",
"num_groups",
"group_topk",
"default_compute_topk",
]
actual_params = list(inspect.signature(original_get_replay_topk).parameters)
if actual_params != expected_params:
# Wrapper receiver names are arbitrary; guard only Megatron's callable API.
actual_non_receiver_params = actual_params[1:]
if actual_non_receiver_params != expected_non_receiver_params:
raise RuntimeError(
"Unsupported Megatron RouterReplay.get_replay_topk signature for "
"NeMo RL missing-route fallback patch: "
f"expected={expected_params}, actual={actual_params}. "
f"expected_non_receiver_params={expected_non_receiver_params}, "
f"actual={actual_params}. "
"Update nemo_rl.models.megatron.router_replay before enabling "
"policy.router_replay.enabled."
)

@wraps(original_get_replay_topk)
def wrapped_get_replay_topk(
replay_instance: Any,
scores: torch.Tensor,
Expand Down
5 changes: 4 additions & 1 deletion nemo_rl/utils/r3_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from collections import defaultdict
from collections.abc import Iterator, Sequence
from contextlib import contextmanager, nullcontext
from functools import wraps
from pathlib import Path
from typing import Any, Optional

Expand Down Expand Up @@ -298,8 +299,10 @@ def _verify_router_replay_forward_context() -> Iterator[None]:

with _patch_lock:
if _router_replay_patch_depth == 0:
_original_get_replay_topk = RouterReplay.get_replay_topk
original_get_replay_topk = RouterReplay.get_replay_topk
_original_get_replay_topk = original_get_replay_topk

@wraps(original_get_replay_topk)
def wrapped_get_replay_topk(
replay_instance: Any,
scores: Any,
Expand Down
33 changes: 33 additions & 0 deletions tests/unit/models/megatron/test_router_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,39 @@ def test_r3_trace_forward_verifier_records_actual_replayed_topk(tmp_path, monkey
RouterReplay.clear_global_router_replay_instances()


@pytest.mark.mcore
def test_missing_route_fallback_patch_is_idempotent_inside_r3_trace_verifier(
    tmp_path, monkeypatch
):
    """Check the fallback-patch marker stays visible inside an R3 trace stage.

    The test installs the missing-route fallback patch, confirms that
    ``RouterReplay.get_replay_topk`` carries the patch marker attribute, then
    enters ``r3_trace_stage`` with tracing and forward verification enabled via
    environment variables. Inside the stage the marker must still be readable
    from ``RouterReplay.get_replay_topk`` and a second install call must
    complete without raising.

    NOTE(review): presumably the trace stage re-wraps ``get_replay_topk`` when
    ``NRL_R3_TRACE_VERIFY_FORWARD`` is set -- confirm against
    ``nemo_rl.utils.r3_trace``.
    """
    from megatron.core.transformer.moe.router_replay import RouterReplay

    from nemo_rl.models.megatron import router_replay
    from nemo_rl.utils.r3_trace import r3_trace_stage

    def fallback_marker():
        # Look the method up on the class each time so that any wrapper
        # installed after the previous check is the object being inspected.
        current_method = RouterReplay.get_replay_topk
        marker_name = router_replay._MISSING_ROUTE_FALLBACK_PATCH_ATTR
        return getattr(current_method, marker_name, False)

    # Enable tracing plus forward verification and direct output to tmp_path.
    trace_env = {
        "NRL_R3_TRACE": "1",
        "NRL_R3_TRACE_VERIFY_FORWARD": "1",
        "NRL_R3_TRACE_DIR": str(tmp_path),
    }
    for env_name, env_value in trace_env.items():
        monkeypatch.setenv(env_name, env_value)
    RouterReplay.clear_global_router_replay_instances()

    try:
        router_replay._install_missing_route_fallback_patch()
        assert fallback_marker()

        with r3_trace_stage("unit-forward"):
            # The marker must remain visible while the trace stage is active.
            assert fallback_marker()
            # A repeated install inside the stage must not raise.
            router_replay._install_missing_route_fallback_patch()
    finally:
        RouterReplay.clear_global_router_replay_instances()


@pytest.mark.mcore
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_mcore_moe_replay_backward_recompute_matches_parameter_grads(tmp_path):
Expand Down
Loading