diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 5435df757a..77a6d55eff 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -496,6 +496,7 @@ def _spinup_nemo_gym(base_urls, model_name): base_urls=base_urls, invalid_tool_call_patterns=invalid_tool_call_patterns, thinking_tags=thinking_tags, + require_routed_experts=router_replay_enabled(policy_config), initial_global_config_dict=nemo_gym_dict, ) nemo_gym_opts = {} diff --git a/nemo_rl/environments/nemo_gym.py b/nemo_rl/environments/nemo_gym.py index 611751af36..b3f8dcbfbd 100644 --- a/nemo_rl/environments/nemo_gym.py +++ b/nemo_rl/environments/nemo_gym.py @@ -77,6 +77,9 @@ class NemoGymConfig(TypedDict): thinking_tags: NotRequired[ List[str] | None ] # Thinking tags to check for malformed usage + require_routed_experts: NotRequired[ + bool + ] # Require Gym output items to carry R3 routed_experts def _detect_invalid_tool_call_and_malformed_thinking( @@ -343,15 +346,48 @@ def _postprocess_nemo_gym_to_nemo_rl_result( prompt_token_ids = output_item_dict.pop("prompt_token_ids") generation_token_ids = output_item_dict.pop("generation_token_ids") generation_log_probs = output_item_dict.pop("generation_log_probs") + routed_experts_raw = output_item_dict.pop("routed_experts", None) new_prompt_token_ids = prompt_token_ids[len(seen_token_ids) :] - nemo_rl_message_log.append( - { - "role": "user", - "content": "", - "token_ids": torch.tensor(new_prompt_token_ids), - } - ) + routed_experts = None + if routed_experts_raw is not None: + routed_experts = torch.as_tensor(routed_experts_raw, dtype=torch.int32) + if routed_experts.dim() != 3: + raise ValueError( + "NeMo Gym returned routed_experts with invalid shape. " + "Expected [tokens, num_moe_layers, topk], got " + f"{tuple(routed_experts.shape)}." + ) + expected_tokens = len(prompt_token_ids) + len(generation_token_ids) + if routed_experts.shape[0] < expected_tokens: + raise ValueError( + "NeMo Gym returned too few routed_experts rows for a " + "trainable output item: " + f"routes={routed_experts.shape[0]}, expected_at_least=" + f"{expected_tokens}." + ) + elif self.cfg.get("require_routed_experts", False): + raise ValueError( + "policy.router_replay.enabled=true requires NeMo Gym output " + "items to include routed_experts, but the field was missing. " + "Make sure the Gym repo includes routed_experts propagation " + "and the NeMo-RL vLLM OpenAI-compatible server is configured " + "with enable_return_routed_experts." + ) + + prompt_start = len(seen_token_ids) + prompt_end = len(prompt_token_ids) + generation_start = prompt_end + generation_end = prompt_end + len(generation_token_ids) + + user_message = { + "role": "user", + "content": "", + "token_ids": torch.tensor(new_prompt_token_ids), + } + if routed_experts is not None: + user_message["routed_experts"] = routed_experts[prompt_start:prompt_end] + nemo_rl_message_log.append(user_message) # Valid tool calls go through the structured API (tool_calls field) and get # executed by NeMo-Gym. If tool call patterns appear in the text content instead, # the call was invalid and never executed — flag it so training can penalize it. @@ -365,16 +401,19 @@ def _postprocess_nemo_gym_to_nemo_rl_result( ) ) - nemo_rl_message_log.append( - { - "role": "assistant", - "content": "", - "token_ids": torch.tensor(generation_token_ids), - "generation_logprobs": torch.tensor(generation_log_probs), - "is_invalid_tool_call": is_invalid_tool_call, - "has_malformed_thinking": has_malformed_thinking, - } - ) + assistant_message = { + "role": "assistant", + "content": "", + "token_ids": torch.tensor(generation_token_ids), + "generation_logprobs": torch.tensor(generation_log_probs), + "is_invalid_tool_call": is_invalid_tool_call, + "has_malformed_thinking": has_malformed_thinking, + } + if routed_experts is not None: + assistant_message["routed_experts"] = routed_experts[ + generation_start:generation_end + ] + nemo_rl_message_log.append(assistant_message) seen_token_ids.extend(new_prompt_token_ids) seen_token_ids.extend(generation_token_ids) diff --git a/nemo_rl/models/generation/vllm/utils.py b/nemo_rl/models/generation/vllm/utils.py index 349d36fabf..9635fa389d 100644 --- a/nemo_rl/models/generation/vllm/utils.py +++ b/nemo_rl/models/generation/vllm/utils.py @@ -192,6 +192,95 @@ def pad_and_align_routed_expert_indices( return (full, stats) if return_stats else full +def attach_routed_experts_to_chat_response_choices( + response: Any, + final_request_output: Any, + *, + device: torch.device, + logger: Any = None, +) -> Any: + """Attach aligned routed experts to OpenAI chat response choices.""" + outputs_by_index = { + output.index: output for output in getattr(final_request_output, "outputs", []) + } + prompt_token_count = len( + getattr(final_request_output, "prompt_token_ids", []) or [] + ) + + choices = list(getattr(response, "choices", [])) + attached_choice_indices = set() + for choice in choices: + generation_details = outputs_by_index.get(choice.index) + if generation_details is None: + continue + attached_choice_indices.add(choice.index) + + generation_token_count = len(getattr(generation_details, "token_ids", []) or []) + routed_result = pad_and_align_routed_expert_indices( + final_request_output, + generation_details, + valid_length=prompt_token_count + generation_token_count, + padded_length=prompt_token_count + generation_token_count, + device=device, + require_complete_routed_experts=True, + return_stats=True, + ) + if not isinstance(routed_result, tuple): + raise RuntimeError( + "Expected routed_experts alignment to return stats for the " + "OpenAI-compatible chat endpoint." + ) + routed_experts, r3_stats = routed_result + if routed_experts is None: + raise RuntimeError( + "vLLM was asked to return routed experts for the " + "OpenAI-compatible chat endpoint but the generation " + "output did not include routed_experts." + ) + if r3_stats["missing_routes"] > 0 and logger is not None: + logger.warning( + "R3 router replay fallback: vLLM returned incomplete " + "routed_experts for chat choice_idx=%d, " + "missing_token_routes=%d, actual_routes=%d, " + "expected_routes=%d. Megatron will use its own router " + "for those missing token routes.", + choice.index, + r3_stats["missing_routes"], + r3_stats["actual_routes"], + r3_stats["expected_routes"], + ) + choice.message.routed_experts = routed_experts.to(dtype=torch.int32).tolist() + + if len(attached_choice_indices) != len(choices): + missing_choice_indices = sorted( + choice.index + for choice in choices + if choice.index not in attached_choice_indices + ) + raise RuntimeError( + "vLLM was asked to return routed experts for the " + "OpenAI-compatible chat endpoint but response choices could not be " + "matched to generation outputs: " + f"missing_choice_indices={missing_choice_indices}." + ) + + return response + + +def model_dump_chat_response_with_routed_experts(response: Any) -> dict[str, Any]: + """Dump a vLLM OpenAI chat response while preserving dynamic R3 fields.""" + response_dict = response.model_dump() + for choice, choice_dict in zip( + getattr(response, "choices", []), response_dict.get("choices", []) + ): + routed_experts = getattr( + getattr(choice, "message", None), "routed_experts", None + ) + if routed_experts is not None: + choice_dict.setdefault("message", {})["routed_experts"] = routed_experts + return response_dict + + def aggregate_spec_decode_counters( worker_metrics: list[dict[str, float | list[float]]], ) -> dict[str | tuple[str, int], float]: diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 591b929fbf..fab2e1330e 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -41,7 +41,9 @@ verify_right_padding, ) from nemo_rl.models.generation.vllm.utils import ( + attach_routed_experts_to_chat_response_choices, format_prompt_for_vllm_generation, + model_dump_chat_response_with_routed_experts, pad_and_align_routed_expert_indices, ) from nemo_rl.models.generation.vllm.vllm_worker import BaseVllmGenerationWorker @@ -216,6 +218,14 @@ def __init__( self.llm = None self.vllm_device_ids = None + def _return_routed_experts_enabled(self) -> bool: + engine_args = getattr(self, "llm_async_engine_args", None) + if bool(getattr(engine_args, "enable_return_routed_experts", False)): + return True + return bool( + self.cfg.get("vllm_kwargs", {}).get("enable_return_routed_experts", False) + ) + def _reserve_port(self) -> None: """Bind and listen on a TCP socket to reserve a free port from the OS. @@ -671,7 +681,45 @@ class NeMoRLChatCompletionRequest( # vLLM 0.20 routes both /v1/chat/completions and /tokenize through # OpenAIServingRender.preprocess_chat, so the prefix-token override # belongs on the render subclass. - class NeMoRLOpenAIServingChat(OpenAIServingChat): + worker_self = self + + class NeMoRLOpenAIServingChatMixin: + async def chat_completion_full_generator( + self, + request, + result_generator, + *args, + **kwargs, + ): + final_res = None + + async def capture_result_generator(): + nonlocal final_res + async for res in result_generator: + final_res = res + yield res + + response = await super().chat_completion_full_generator( + request, + capture_result_generator(), + *args, + **kwargs, + ) + if ( + not worker_self._return_routed_experts_enabled() + or not isinstance(response, ChatCompletionResponse) + or final_res is None + ): + return response + + return attach_routed_experts_to_chat_response_choices( + response, + final_res, + device=torch.device("cpu"), + logger=LOGGER, + ) + + class NeMoRLOpenAIServingChat(NeMoRLOpenAIServingChatMixin, OpenAIServingChat): pass class NeMoRLOpenAIServingRender(NeMoRLOpenAIServingMixin, OpenAIServingRender): @@ -753,7 +801,9 @@ async def create_chat_completion( ) elif isinstance(generator, ChatCompletionResponse): - return JSONResponse(content=generator.model_dump()) + return JSONResponse( + content=model_dump_chat_response_with_routed_experts(generator) + ) return StreamingResponse(content=generator, media_type="text/event-stream") @@ -1071,11 +1121,7 @@ async def process_single_sample(sample_idx): generation_details = final_request_output.outputs[0] generated_token_ids = list(generation_details.token_ids) num_generated_tokens = len(generated_token_ids) - return_routed_experts = bool( - self.cfg.get("vllm_kwargs", {}).get( - "enable_return_routed_experts", False - ) - ) + return_routed_experts = self._return_routed_experts_enabled() original_input_ids_single_row = input_ids_batch[sample_idx] final_output_tensor_len = current_input_actual_length + num_generated_tokens diff --git a/tests/unit/environments/test_nemo_gym_router_replay.py b/tests/unit/environments/test_nemo_gym_router_replay.py new file mode 100644 index 0000000000..fdc7a021f7 --- /dev/null +++ b/tests/unit/environments/test_nemo_gym_router_replay.py @@ -0,0 +1,90 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nemo_rl.environments.nemo_gym import NemoGym + + +class _Tokenizer: + def batch_decode(self, batch): + return [" ".join(map(str, token_ids)) for token_ids in batch] + + +def _routes(num_tokens: int) -> list[list[list[int]]]: + return [[[token_idx, token_idx + 100]] for token_idx in range(num_tokens)] + + +def test_nemo_gym_postprocess_slices_routed_experts(): + nemo_gym_result = { + "response": { + "output": [ + { + "prompt_token_ids": [1, 2], + "generation_token_ids": [3], + "generation_log_probs": [-0.1], + "routed_experts": _routes(3), + }, + { + "prompt_token_ids": [1, 2, 3, 4, 5], + "generation_token_ids": [6, 7], + "generation_log_probs": [-0.2, -0.3], + "routed_experts": _routes(7), + }, + ] + }, + "responses_create_params": {"input": []}, + } + + class _MockSelf: + cfg = {"require_routed_experts": True} + + result = ( + NemoGym.__ray_metadata__.modified_class._postprocess_nemo_gym_to_nemo_rl_result( + _MockSelf(), nemo_gym_result, _Tokenizer() + ) + ) + + message_log = result["message_log"] + assert message_log[0]["token_ids"].tolist() == [1, 2] + assert message_log[0]["routed_experts"].tolist() == _routes(2) + assert message_log[1]["token_ids"].tolist() == [3] + assert message_log[1]["routed_experts"].tolist() == _routes(3)[2:3] + assert message_log[2]["token_ids"].tolist() == [4, 5] + assert message_log[2]["routed_experts"].tolist() == _routes(7)[3:5] + assert message_log[3]["token_ids"].tolist() == [6, 7] + assert message_log[3]["routed_experts"].tolist() == _routes(7)[5:7] + + +def test_nemo_gym_postprocess_requires_routed_experts_when_configured(): + nemo_gym_result = { + "response": { + "output": [ + { + "prompt_token_ids": [1, 2], + "generation_token_ids": [3], + "generation_log_probs": [-0.1], + }, + ] + }, + "responses_create_params": {"input": []}, + } + + class _MockSelf: + cfg = {"require_routed_experts": True} + + with pytest.raises(ValueError, match="requires NeMo Gym output items"): + NemoGym.__ray_metadata__.modified_class._postprocess_nemo_gym_to_nemo_rl_result( + _MockSelf(), nemo_gym_result, _Tokenizer() + ) diff --git a/tests/unit/models/generation/test_vllm_utils.py b/tests/unit/models/generation/test_vllm_utils.py index 8e21abbf46..674a23b5fb 100644 --- a/tests/unit/models/generation/test_vllm_utils.py +++ b/tests/unit/models/generation/test_vllm_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from types import SimpleNamespace import pytest import torch @@ -21,8 +22,10 @@ from nemo_rl.models.generation.vllm.utils import ( R3_MISSING_ROUTE_SENTINEL, aggregate_spec_decode_counters, + attach_routed_experts_to_chat_response_choices, compute_spec_decode_metrics, format_prompt_for_vllm_generation, + model_dump_chat_response_with_routed_experts, pad_and_align_routed_expert_indices, ) @@ -298,6 +301,95 @@ class Output: ) +def test_attach_routed_experts_to_chat_response_choices_reassociates_by_choice_index(): + final_res = SimpleNamespace( + prompt_token_ids=[101, 102, 103], + prompt_routed_experts=torch.tensor([[[10]], [[11]]], dtype=torch.int32), + outputs=[ + SimpleNamespace( + index=1, + token_ids=[201, 202], + routed_experts=torch.tensor([[[31]], [[32]]], dtype=torch.int32), + ), + SimpleNamespace( + index=0, + token_ids=[200], + routed_experts=torch.tensor([[[30]]], dtype=torch.int32), + ), + ], + ) + response = SimpleNamespace( + choices=[ + SimpleNamespace(index=0, message=SimpleNamespace()), + SimpleNamespace(index=1, message=SimpleNamespace()), + ] + ) + + attach_routed_experts_to_chat_response_choices( + response, + final_res, + device=torch.device("cpu"), + ) + + assert response.choices[0].message.routed_experts == [ + [[10]], + [[11]], + [[30]], + [[0]], + ] + assert response.choices[1].message.routed_experts == [ + [[10]], + [[11]], + [[31]], + [[32]], + [[0]], + ] + + +def test_attach_routed_experts_to_chat_response_choices_requires_routed_experts(): + final_res = SimpleNamespace( + prompt_token_ids=[101, 102], + outputs=[SimpleNamespace(index=0, token_ids=[200])], + ) + response = SimpleNamespace( + choices=[SimpleNamespace(index=0, message=SimpleNamespace())] + ) + + with pytest.raises(RuntimeError, match="did not include routed_experts"): + attach_routed_experts_to_chat_response_choices( + response, + final_res, + device=torch.device("cpu"), + ) + + +def test_model_dump_chat_response_with_routed_experts_preserves_dynamic_field(): + routed_experts = [[[1]], [[2]]] + + class Response: + choices = [ + SimpleNamespace( + message=SimpleNamespace(routed_experts=routed_experts), + ) + ] + + def model_dump(self): + return { + "choices": [ + { + "message": { + "role": "assistant", + "content": "hello", + } + } + ] + } + + response_dict = model_dump_chat_response_with_routed_experts(Response()) + + assert response_dict["choices"][0]["message"]["routed_experts"] == routed_experts + + @pytest.mark.vllm def test_vllm_speculative_decoding_patch_removed(): # The speculative decoding patch was fixed upstream in vLLM >= 0.14.0: