Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def _spinup_nemo_gym(base_urls, model_name):
base_urls=base_urls,
invalid_tool_call_patterns=invalid_tool_call_patterns,
thinking_tags=thinking_tags,
require_routed_experts=router_replay_enabled(policy_config),
initial_global_config_dict=nemo_gym_dict,
)
nemo_gym_opts = {}
Expand Down
73 changes: 56 additions & 17 deletions nemo_rl/environments/nemo_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class NemoGymConfig(TypedDict):
thinking_tags: NotRequired[
List[str] | None
] # Thinking tags to check for malformed usage
require_routed_experts: NotRequired[
bool
] # Require Gym output items to carry R3 routed_experts


def _detect_invalid_tool_call_and_malformed_thinking(
Expand Down Expand Up @@ -343,15 +346,48 @@ def _postprocess_nemo_gym_to_nemo_rl_result(
prompt_token_ids = output_item_dict.pop("prompt_token_ids")
generation_token_ids = output_item_dict.pop("generation_token_ids")
generation_log_probs = output_item_dict.pop("generation_log_probs")
routed_experts_raw = output_item_dict.pop("routed_experts", None)
new_prompt_token_ids = prompt_token_ids[len(seen_token_ids) :]

nemo_rl_message_log.append(
{
"role": "user",
"content": "",
"token_ids": torch.tensor(new_prompt_token_ids),
}
)
routed_experts = None
if routed_experts_raw is not None:
routed_experts = torch.as_tensor(routed_experts_raw, dtype=torch.int32)
if routed_experts.dim() != 3:
raise ValueError(
"NeMo Gym returned routed_experts with invalid shape. "
"Expected [tokens, num_moe_layers, topk], got "
f"{tuple(routed_experts.shape)}."
)
expected_tokens = len(prompt_token_ids) + len(generation_token_ids)
if routed_experts.shape[0] < expected_tokens:
raise ValueError(
"NeMo Gym returned too few routed_experts rows for a "
"trainable output item: "
f"routes={routed_experts.shape[0]}, expected_at_least="
f"{expected_tokens}."
)
elif self.cfg.get("require_routed_experts", False):
raise ValueError(
"policy.router_replay.enabled=true requires NeMo Gym output "
"items to include routed_experts, but the field was missing. "
"Make sure the Gym repo includes routed_experts propagation "
"and the NeMo-RL vLLM OpenAI-compatible server is configured "
"with enable_return_routed_experts."
)

prompt_start = len(seen_token_ids)
prompt_end = len(prompt_token_ids)
generation_start = prompt_end
generation_end = prompt_end + len(generation_token_ids)

user_message = {
"role": "user",
"content": "",
"token_ids": torch.tensor(new_prompt_token_ids),
}
if routed_experts is not None:
user_message["routed_experts"] = routed_experts[prompt_start:prompt_end]
nemo_rl_message_log.append(user_message)
# Valid tool calls go through the structured API (tool_calls field) and get
# executed by NeMo-Gym. If tool call patterns appear in the text content instead,
# the call was invalid and never executed — flag it so training can penalize it.
Expand All @@ -365,16 +401,19 @@ def _postprocess_nemo_gym_to_nemo_rl_result(
)
)

nemo_rl_message_log.append(
{
"role": "assistant",
"content": "",
"token_ids": torch.tensor(generation_token_ids),
"generation_logprobs": torch.tensor(generation_log_probs),
"is_invalid_tool_call": is_invalid_tool_call,
"has_malformed_thinking": has_malformed_thinking,
}
)
assistant_message = {
"role": "assistant",
"content": "",
"token_ids": torch.tensor(generation_token_ids),
"generation_logprobs": torch.tensor(generation_log_probs),
"is_invalid_tool_call": is_invalid_tool_call,
"has_malformed_thinking": has_malformed_thinking,
}
if routed_experts is not None:
assistant_message["routed_experts"] = routed_experts[
generation_start:generation_end
]
nemo_rl_message_log.append(assistant_message)

seen_token_ids.extend(new_prompt_token_ids)
seen_token_ids.extend(generation_token_ids)
Expand Down
89 changes: 89 additions & 0 deletions nemo_rl/models/generation/vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,95 @@ def pad_and_align_routed_expert_indices(
return (full, stats) if return_stats else full


def attach_routed_experts_to_chat_response_choices(
response: Any,
final_request_output: Any,
*,
device: torch.device,
logger: Any = None,
) -> Any:
"""Attach aligned routed experts to OpenAI chat response choices."""
outputs_by_index = {
output.index: output for output in getattr(final_request_output, "outputs", [])
}
prompt_token_count = len(
getattr(final_request_output, "prompt_token_ids", []) or []
)

choices = list(getattr(response, "choices", []))
attached_choice_indices = set()
for choice in choices:
generation_details = outputs_by_index.get(choice.index)
if generation_details is None:
continue
attached_choice_indices.add(choice.index)

generation_token_count = len(getattr(generation_details, "token_ids", []) or [])
routed_result = pad_and_align_routed_expert_indices(
final_request_output,
generation_details,
valid_length=prompt_token_count + generation_token_count,
padded_length=prompt_token_count + generation_token_count,
device=device,
require_complete_routed_experts=True,
return_stats=True,
)
if not isinstance(routed_result, tuple):
raise RuntimeError(
"Expected routed_experts alignment to return stats for the "
"OpenAI-compatible chat endpoint."
)
routed_experts, r3_stats = routed_result
if routed_experts is None:
raise RuntimeError(
"vLLM was asked to return routed experts for the "
"OpenAI-compatible chat endpoint but the generation "
Comment thread
zyzhou5 marked this conversation as resolved.
"output did not include routed_experts."
)
if r3_stats["missing_routes"] > 0 and logger is not None:
logger.warning(
"R3 router replay fallback: vLLM returned incomplete "
"routed_experts for chat choice_idx=%d, "
"missing_token_routes=%d, actual_routes=%d, "
"expected_routes=%d. Megatron will use its own router "
"for those missing token routes.",
choice.index,
r3_stats["missing_routes"],
r3_stats["actual_routes"],
r3_stats["expected_routes"],
)
choice.message.routed_experts = routed_experts.to(dtype=torch.int32).tolist()

if len(attached_choice_indices) != len(choices):
missing_choice_indices = sorted(
choice.index
for choice in choices
if choice.index not in attached_choice_indices
)
raise RuntimeError(
"vLLM was asked to return routed experts for the "
"OpenAI-compatible chat endpoint but response choices could not be "
"matched to generation outputs: "
f"missing_choice_indices={missing_choice_indices}."
)

return response


def model_dump_chat_response_with_routed_experts(response: Any) -> dict[str, Any]:
"""Dump a vLLM OpenAI chat response while preserving dynamic R3 fields."""
response_dict = response.model_dump()
for choice, choice_dict in zip(
getattr(response, "choices", []), response_dict.get("choices", [])
):
routed_experts = getattr(
getattr(choice, "message", None), "routed_experts", None
)
if routed_experts is not None:
choice_dict.setdefault("message", {})["routed_experts"] = routed_experts
return response_dict


def aggregate_spec_decode_counters(
worker_metrics: list[dict[str, float | list[float]]],
) -> dict[str | tuple[str, int], float]:
Expand Down
60 changes: 53 additions & 7 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
verify_right_padding,
)
from nemo_rl.models.generation.vllm.utils import (
attach_routed_experts_to_chat_response_choices,
format_prompt_for_vllm_generation,
model_dump_chat_response_with_routed_experts,
pad_and_align_routed_expert_indices,
)
from nemo_rl.models.generation.vllm.vllm_worker import BaseVllmGenerationWorker
Expand Down Expand Up @@ -216,6 +218,14 @@ def __init__(
self.llm = None
self.vllm_device_ids = None

def _return_routed_experts_enabled(self) -> bool:
engine_args = getattr(self, "llm_async_engine_args", None)
if bool(getattr(engine_args, "enable_return_routed_experts", False)):
return True
return bool(
self.cfg.get("vllm_kwargs", {}).get("enable_return_routed_experts", False)
)

def _reserve_port(self) -> None:
"""Bind and listen on a TCP socket to reserve a free port from the OS.

Expand Down Expand Up @@ -671,7 +681,45 @@ class NeMoRLChatCompletionRequest(
# vLLM 0.20 routes both /v1/chat/completions and /tokenize through
# OpenAIServingRender.preprocess_chat, so the prefix-token override
# belongs on the render subclass.
class NeMoRLOpenAIServingChat(OpenAIServingChat):
worker_self = self

class NeMoRLOpenAIServingChatMixin:
async def chat_completion_full_generator(
Comment thread
zyzhou5 marked this conversation as resolved.
self,
request,
result_generator,
*args,
**kwargs,
):
final_res = None

async def capture_result_generator():
nonlocal final_res
async for res in result_generator:
final_res = res
yield res

response = await super().chat_completion_full_generator(
request,
capture_result_generator(),
*args,
**kwargs,
)
if (
not worker_self._return_routed_experts_enabled()
or not isinstance(response, ChatCompletionResponse)
or final_res is None
):
return response

return attach_routed_experts_to_chat_response_choices(
response,
final_res,
device=torch.device("cpu"),
logger=LOGGER,
)

class NeMoRLOpenAIServingChat(NeMoRLOpenAIServingChatMixin, OpenAIServingChat):
pass

class NeMoRLOpenAIServingRender(NeMoRLOpenAIServingMixin, OpenAIServingRender):
Expand Down Expand Up @@ -753,7 +801,9 @@ async def create_chat_completion(
)

elif isinstance(generator, ChatCompletionResponse):
return JSONResponse(content=generator.model_dump())
return JSONResponse(
content=model_dump_chat_response_with_routed_experts(generator)
)

return StreamingResponse(content=generator, media_type="text/event-stream")

Expand Down Expand Up @@ -1071,11 +1121,7 @@ async def process_single_sample(sample_idx):
generation_details = final_request_output.outputs[0]
generated_token_ids = list(generation_details.token_ids)
num_generated_tokens = len(generated_token_ids)
return_routed_experts = bool(
self.cfg.get("vllm_kwargs", {}).get(
"enable_return_routed_experts", False
)
)
return_routed_experts = self._return_routed_experts_enabled()

original_input_ids_single_row = input_ids_batch[sample_idx]
final_output_tensor_len = current_input_actual_length + num_generated_tokens
Expand Down
90 changes: 90 additions & 0 deletions tests/unit/environments/test_nemo_gym_router_replay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from nemo_rl.environments.nemo_gym import NemoGym


class _Tokenizer:
def batch_decode(self, batch):
return [" ".join(map(str, token_ids)) for token_ids in batch]


def _routes(num_tokens: int) -> list[list[list[int]]]:
return [[[token_idx, token_idx + 100]] for token_idx in range(num_tokens)]


def test_nemo_gym_postprocess_slices_routed_experts():
nemo_gym_result = {
"response": {
"output": [
{
"prompt_token_ids": [1, 2],
"generation_token_ids": [3],
"generation_log_probs": [-0.1],
"routed_experts": _routes(3),
},
{
"prompt_token_ids": [1, 2, 3, 4, 5],
"generation_token_ids": [6, 7],
"generation_log_probs": [-0.2, -0.3],
"routed_experts": _routes(7),
},
]
},
"responses_create_params": {"input": []},
}

class _MockSelf:
cfg = {"require_routed_experts": True}

result = (
NemoGym.__ray_metadata__.modified_class._postprocess_nemo_gym_to_nemo_rl_result(
_MockSelf(), nemo_gym_result, _Tokenizer()
)
)

message_log = result["message_log"]
assert message_log[0]["token_ids"].tolist() == [1, 2]
assert message_log[0]["routed_experts"].tolist() == _routes(2)
assert message_log[1]["token_ids"].tolist() == [3]
assert message_log[1]["routed_experts"].tolist() == _routes(3)[2:3]
assert message_log[2]["token_ids"].tolist() == [4, 5]
assert message_log[2]["routed_experts"].tolist() == _routes(7)[3:5]
assert message_log[3]["token_ids"].tolist() == [6, 7]
assert message_log[3]["routed_experts"].tolist() == _routes(7)[5:7]


def test_nemo_gym_postprocess_requires_routed_experts_when_configured():
nemo_gym_result = {
"response": {
"output": [
{
"prompt_token_ids": [1, 2],
"generation_token_ids": [3],
"generation_log_probs": [-0.1],
},
]
},
"responses_create_params": {"input": []},
}

class _MockSelf:
cfg = {"require_routed_experts": True}

with pytest.raises(ValueError, match="requires NeMo Gym output items"):
NemoGym.__ray_metadata__.modified_class._postprocess_nemo_gym_to_nemo_rl_result(
_MockSelf(), nemo_gym_result, _Tokenizer()
)
Loading
Loading