diff --git a/examples/train/mini_swe_agent/README.md b/examples/train/mini_swe_agent/README.md index b130e635e0..fc1d52fbc4 100644 --- a/examples/train/mini_swe_agent/README.md +++ b/examples/train/mini_swe_agent/README.md @@ -58,7 +58,7 @@ For issues with SkyRL or the Mini-SWE-Agent integration, please [open an Issue]( ### Common Issues -- **Context length errors**: If you see `ValueError: The decoder prompt (length xxxx) is longer than the maximum model length`, increase `max_input_length` and `max_generate_length` or reduce steps in `swebench.yaml`. +- **Context length errors**: If you see `ValueError: The decoder prompt (length xxxx) is longer than the maximum model length`, increase the vLLM `engine_init_kwargs.max_model_len`, reduce `max_input_length`, or reduce steps in `swebench.yaml`. `max_generate_length` is the assistant-token budget for a trajectory and does not increase the model context window. - **All zero rewards**: If rewards are consistently zero, the task may be too difficult. Consider: - Filtering data for a better mix of easy/hard samples diff --git a/examples/train/mini_swe_agent/mini_swe_generator.py b/examples/train/mini_swe_agent/mini_swe_generator.py index 923c2345d8..d2892f7000 100644 --- a/examples/train/mini_swe_agent/mini_swe_generator.py +++ b/examples/train/mini_swe_agent/mini_swe_generator.py @@ -1,28 +1,37 @@ import asyncio -from dataclasses import dataclass -from typing import Dict, List, Optional, Any, Tuple -import yaml import traceback -import ray +from dataclasses import dataclass from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple -from minisweagent.models import get_model +import ray +import yaml from minisweagent.agents.default import DefaultAgent -from minisweagent.run.utils.save import save_traj from minisweagent.config import get_config_path -from .mini_swe_utils import evaluate_trajectory, get_sb_environment +from minisweagent.models import get_model +from minisweagent.run.utils.save import save_traj -from skyrl.train.config import GeneratorConfig, SkyRLGymConfig -from skyrl.train.generators.skyrl_gym_generator import SkyRLGymGenerator, GeneratorOutput, GeneratorInput -from skyrl.train.generators.base import TrajectoryID, TrainingPhase, BatchMetadata from skyrl.backends.skyrl_train.inference_engines.base import ConversationType -from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient -from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend +from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import ( + InferenceEngineClient, +) +from skyrl.backends.skyrl_train.inference_engines.utils import ( + get_sampling_params_for_backend, +) +from skyrl.train.config import GeneratorConfig, SkyRLGymConfig +from skyrl.train.generators.base import BatchMetadata, TrainingPhase, TrajectoryID +from skyrl.train.generators.skyrl_gym_generator import ( + GeneratorInput, + GeneratorOutput, + SkyRLGymGenerator, +) from skyrl.train.generators.utils import ( - get_rollout_metrics, get_response_ids_and_loss_mask_from_messages, + get_rollout_metrics, ) +from .mini_swe_utils import evaluate_trajectory, get_sb_environment + @dataclass class MiniSWEGeneratorConfig(GeneratorConfig): @@ -199,15 +208,27 @@ async def minisweagent_agent_loop( # Extract prompt ids prompt_ids = initial_input_ids - # Calculate maximum response tokens allowed - max_response_tokens = max_tokens + max_input_length - initial_prompt_length + # Truncate by assistant-token budget first. Environment/user observations are kept only + # insofar as they fit the secondary packed-sequence guard below; they do not consume + # max_generate_length because their loss mask is 0. + assistant_tokens = 0 + assistant_budget_response_tokens = len(response_ids) + assistant_budget_exceeded = False + for idx, mask in enumerate(loss_mask): + assistant_tokens += int(bool(mask)) + if assistant_tokens > max_tokens: + assistant_budget_response_tokens = idx + assistant_budget_exceeded = True + break + + # Keep the packed prompt+response sequence bounded for training tensor sizes. + packed_response_tokens = max(0, max_tokens + max_input_length - initial_prompt_length) + max_response_tokens = min(assistant_budget_response_tokens, packed_response_tokens) - # Determine stop reason stop_reason = "complete" # Default for trial completion - if len(response_ids) > max_response_tokens: + if assistant_budget_exceeded or len(response_ids) > packed_response_tokens: stop_reason = "length" - # Truncate to maximum allowed length response_ids = response_ids[:max_response_tokens] loss_mask = loss_mask[:max_response_tokens] diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index f28d997ba4..15f8c22bb2 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -424,6 +424,9 @@ class FullyAsyncConfig(BaseConfig): @dataclass class SamplingParams(BaseConfig): max_generate_length: int = 1024 + """Trajectory-level assistant/generated-token budget. In multi-turn generators, + environment observation tokens are loss-masked and do not count against this budget. + The vLLM request field is ``max_tokens`` and may be reduced per turn to fit context.""" repetition_penalty: float = 1.0 temperature: float = 1.0 top_p: float = 1.0 @@ -496,7 +499,9 @@ class InferenceEngineConfig(BaseConfig): """When True, pass ``language_model_only=True`` to the vLLM engine so that multimodal models (e.g. Qwen3.5) skip vision encoder initialization.""" engine_init_kwargs: Dict[str, Any] = field(default_factory=dict) - """Pass-through kwargs for the vLLM engine. Names must match the engine's args.""" + """Pass-through kwargs for the vLLM engine. Names must match the engine's args. If + ``max_model_len`` is set, rollout requests are capped so input tokens plus per-request + generated tokens fit within that window.""" override_existing_update_group: str = "auto" """``"auto"``, ``"enable"``, or ``"disable"``.""" external_proxy_url: Optional[str] = None @@ -528,7 +533,8 @@ class GeneratorConfig(BaseConfig): batched: bool = False max_turns: int = 1 max_input_length: Optional[int] = None - """Max generator input length for multi-turn conversations. For single-turn, set equal to ``max_prompt_length``.""" + """Max input/context length allowed before each generation turn. For single-turn, set + equal to ``max_prompt_length``. Distinct from ``sampling_params.max_generate_length``.""" chat_template: ChatTemplateConfig = field(default_factory=ChatTemplateConfig) chat_template_kwargs: Dict[str, Any] = field(default_factory=dict) """Kwargs passed to ``tokenizer.apply_chat_template``.""" diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index b10695c5b3..c3257ad2c5 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -295,7 +295,9 @@ generator: n_samples_per_prompt: 5 async_engine: true batched: false - max_input_length: ${trainer.max_prompt_length} # max generator input length used for multi-turn conversations - for single turn set equal to max_prompt_length + # Max input/context length checked before each generation turn. For single-turn, set equal to max_prompt_length. + # This is distinct from sampling_params.max_generate_length, which budgets assistant-generated tokens. + max_input_length: ${trainer.max_prompt_length} # VLLM_ENABLE_V1_MULTIPROCESSING=0 for reproducibility vllm_v1_disable_multiproc: true enable_prefix_caching: true @@ -334,11 +336,14 @@ generator: # Inference engine arguments. Arguments are passed directly to the vLLM engine, so names must match # the engine's args. To specify an engine arg in the CLI override, use the format: +generator.engine_init_kwargs.arg_name=value + # If max_model_len is set, each rollout request's max_tokens is capped so prompt+completion fits this window. engine_init_kwargs: {} override_existing_update_group: "auto" # "auto", "enable", "disable" # sampling params for generation phase sampling_params: + # Trajectory-level assistant/generated-token budget. Multi-turn environment observations are loss-masked + # and do not count against this value. max_generate_length: 1024 repetition_penalty: 1.0 temperature: 1.0 @@ -395,4 +400,4 @@ generator: environment: env_class: "gsm8k" # NOTE: environment specific defaults for environment.skyrl_gym are set at the following path: - # skyrl_gym: config/skyrl_gym_config/default.yaml \ No newline at end of file + # skyrl_gym: config/skyrl_gym_config/default.yaml diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 68b635e8c9..068131422a 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -8,7 +8,7 @@ import asyncio import copy from concurrent.futures import ThreadPoolExecutor -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union from uuid import uuid4 @@ -33,9 +33,13 @@ ) from skyrl.train.generators.utils import ( apply_overlong_filtering, + compute_request_max_tokens, get_custom_chat_template, get_generation_prompt_ids, + get_max_model_len, get_rollout_metrics, + normalize_sampling_params, + sampling_params_with_max_tokens, ) from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput @@ -330,11 +334,12 @@ async def agent_loop( loss_mask = [] # this excludes the prompt rollout_logprobs = None - # `sampling_params` if provided is a dict in the format expected by the inference engine backend - # we cast default config to a dict for consistency - current_sampling_params: dict = ( - sampling_params if sampling_params is not None else asdict(self.generator_cfg.sampling_params) - ) + # `sampling_params` if provided is a dict in the format expected by the inference engine backend. + # When absent, normalize the config dataclass into the backend shape here so agent_loop() direct + # callers receive the same behavior as the main generator path. + base_sampling_params = normalize_sampling_params(self.generator_cfg, sampling_params) + max_model_len = get_max_model_len(self.generator_cfg) + generated_tokens_used = 0 # Accumulate per-step rewards. Format: (reward, response_end_token_idx) per_step_rewards: List[Tuple[float, Optional[int]]] = [] @@ -343,7 +348,7 @@ async def agent_loop( agent_loop_output = StepWiseOutput(step_outputs=[]) if is_step_wise else None - get_logprobs = self.generator_cfg.sampling_params.logprobs is not None + get_logprobs = base_sampling_params.get("logprobs", None) is not None agent_loop_state = AgentLoopState( chat_history=chat_history, input_ids=initial_input_ids, @@ -352,6 +357,7 @@ async def agent_loop( response_end_idx=None, done=False, ) + new_obs: ConversationType = [] while not agent_loop_state.done: @@ -373,8 +379,20 @@ async def agent_loop( agent_loop_state.loss_mask = [] agent_loop_state.rollout_logprobs = None + request_max_tokens = compute_request_max_tokens( + max_tokens - generated_tokens_used, + len(agent_loop_state.input_ids), + max_model_len, + ) + if request_max_tokens <= 0: + stop_reason = "length" + break + + current_sampling_params = sampling_params_with_max_tokens(base_sampling_params, request_max_tokens) engine_input = InferenceEngineInput( - prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params + prompt_token_ids=[agent_loop_state.input_ids], + session_ids=[session_id], + sampling_params=current_sampling_params, ) engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) output = engine_output["responses"][0] @@ -440,6 +458,22 @@ async def agent_loop( if turn_output.rollout_expert_indices is not None and agent_loop_state.rollout_expert_indices is None: agent_loop_state.rollout_expert_indices = [] + turn_generated_tokens = sum(turn_output.get_turn_loss_mask()) + if not self.use_conversation_multi_turn and not retokenize_chat_history: + new_resp_tokens = turn_output.output_ids.copy() + if new_resp_tokens and new_resp_tokens[-1] == self.tokenizer.eos_token_id: + new_resp_tokens = new_resp_tokens[:-1] + turn_generated_tokens = len(new_resp_tokens) + generated_tokens_used += turn_generated_tokens + if generated_tokens_used >= max_tokens and not agent_loop_state.done: + # The trajectory-level assistant budget is exhausted. Do not append the + # next observation, since there will be no following generation request. + stop_reason = "length" + agent_loop_state.done = True + new_obs = [] + turn_output.new_obs = [] + turn_output.obs_ids = [] + if is_step_wise: # current response + observation ids turn_response_ids = turn_output.output_ids + turn_output.obs_ids @@ -489,6 +523,23 @@ async def agent_loop( rollout_expert_indices_out = None response_ids = None + if not is_step_wise and not retokenize_chat_history and agent_loop_state.response_end_idx is None: + agent_loop_output = TrajectoryOutput( + response_ids=[], + reward=[], + stop_reason=stop_reason, + loss_mask=[], + prompt_ids=prompt_ids, + rollout_logprobs=[] if get_logprobs else None, + env_metrics=env_metrics, + rollout_expert_indices=None, + ) + return self._post_process_agent_loop_output( + agent_loop_output, + env_extras, + trajectory_id, + ) + # Prepare the final loss_mask, response_ids and rollout_logprobs . # We remove the final observation messages /token IDs here # Note that during the agent loop, we still add the final observation messages/ tokens because we terminate the agent loop if the input length @@ -531,17 +582,21 @@ async def agent_loop( if not self.use_conversation_multi_turn: assert response_ids is not None and loss_mask is not None if stop_reason != "length" and response_ids and response_ids[-1] != self.tokenizer.eos_token_id: - response_ids.append(self.tokenizer.eos_token_id) - # TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss - # masked with 0, why bother adding it? - loss_mask.append(1) - if rollout_logprobs is not None: - rollout_logprobs.append(0.0) - if rollout_expert_indices_out is not None and rollout_expert_indices_out: - layer_num = len(rollout_expert_indices_out[0]) - topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0 - rollout_expert_indices_out.append([[0] * topk for _ in range(layer_num)]) - appended_eos_token = True + if generated_tokens_used < max_tokens: + response_ids.append(self.tokenizer.eos_token_id) + # TODO(Charlie): this should be 0? Otherwise logprobs will be extremely off. But if it is loss + # masked with 0, why bother adding it? + loss_mask.append(1) + if rollout_logprobs is not None: + rollout_logprobs.append(0.0) + if rollout_expert_indices_out is not None and rollout_expert_indices_out: + layer_num = len(rollout_expert_indices_out[0]) + topk = len(rollout_expert_indices_out[0][0]) if layer_num > 0 else 0 + rollout_expert_indices_out.append([[0] * topk for _ in range(layer_num)]) + generated_tokens_used += 1 + appended_eos_token = True + else: + stop_reason = "length" if self.generator_cfg.step_wise_trajectories: for per_step_output, (reward, resp_end_idx) in zip(agent_loop_output.step_outputs, per_step_rewards): @@ -713,7 +768,10 @@ async def generate_batched( tokenize=True, return_dict=False, ) - engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params) + request_sampling_params = sampling_params_with_max_tokens( + normalize_sampling_params(self.generator_cfg, sampling_params), max_tokens + ) + engine_input = InferenceEngineInput(prompt_token_ids=prompt_token_ids, sampling_params=request_sampling_params) engine_output = await self.inference_engine_client.generate(engine_input, model=self.policy_model_name) outputs = engine_output["responses"] responses = engine_output["response_ids"] @@ -788,7 +846,8 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False if self.generator_cfg.step_wise_trajectories: assert trajectory_ids is not None, "`trajectory_ids` is a required field for step wise training" sampling_params: Optional[dict] = input_batch.get("sampling_params", None) - max_tokens = self.generator_cfg.sampling_params.max_generate_length + base_sampling_params = normalize_sampling_params(self.generator_cfg, sampling_params) + max_tokens = int(base_sampling_params.get("max_tokens", self.generator_cfg.sampling_params.max_generate_length)) max_input_length = self.generator_cfg.max_input_length if self.batched: diff --git a/skyrl/train/generators/skyrl_vlm_generator.py b/skyrl/train/generators/skyrl_vlm_generator.py index 76779aac6e..1e5472d88d 100644 --- a/skyrl/train/generators/skyrl_vlm_generator.py +++ b/skyrl/train/generators/skyrl_vlm_generator.py @@ -3,7 +3,6 @@ """ import copy -from dataclasses import asdict from typing import Any, Dict, List, Optional, Tuple, TypedDict from uuid import uuid4 @@ -25,6 +24,12 @@ SkyRLGymGenerator, TrajectoryOutput, ) +from skyrl.train.generators.utils import ( + compute_request_max_tokens, + get_max_model_len, + normalize_sampling_params, + sampling_params_with_max_tokens, +) class RenderedConversation(TypedDict): @@ -97,11 +102,11 @@ async def agent_loop( prompt_ids = initial_render["prompt_ids"] latest_features = initial_render["features"] - current_sampling_params: dict = ( - sampling_params if sampling_params is not None else asdict(self.generator_cfg.sampling_params) - ) - get_logprobs = self.generator_cfg.sampling_params.logprobs is not None - stop_strs = current_sampling_params.get("stop", None) + base_sampling_params = normalize_sampling_params(self.generator_cfg, sampling_params) + max_model_len = get_max_model_len(self.generator_cfg) + generated_tokens_used = 0 + get_logprobs = base_sampling_params.get("logprobs", None) is not None + stop_strs = base_sampling_params.get("stop", None) # ── Accumulators ─────────────────────────────────────────────── response_ids: List[int] = [] @@ -142,6 +147,17 @@ async def agent_loop( stop_reason = "length" break + request_max_tokens = compute_request_max_tokens( + max_tokens - generated_tokens_used, + len(input_ids), + max_model_len, + ) + if request_max_tokens <= 0: + stop_reason = "length" + break + + current_sampling_params = sampling_params_with_max_tokens(base_sampling_params, request_max_tokens) + # 2. Generate engine_input = InferenceEngineInput( prompt_token_ids=[input_ids], @@ -171,6 +187,12 @@ async def agent_loop( step_reward: float = env_step_output["reward"] done = env_step_output["done"] + generated_tokens_used += len(gen_ids) - int(added_eos) + if generated_tokens_used >= max_tokens and not done: + stop_reason = "length" + done = True + new_obs = [] + # 4. Append assistant message to conversation conversation.append({"role": "assistant", "content": gen_text}) diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index ec56abbea9..e6ff6caba5 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -9,6 +9,7 @@ from skyrl.backends.skyrl_train.inference_engines.base import ConversationType from skyrl.train.config import ChatTemplateConfig +from skyrl.train.config.config import GeneratorConfig from skyrl.train.generators.base import ( BatchMetadata, GeneratorInput, @@ -102,6 +103,54 @@ def _validate_template_file_path(file_path: str) -> str: } +def get_max_model_len(generator_cfg: GeneratorConfig) -> Optional[int]: + """Return the configured inference context window, if one is explicitly set.""" + max_model_len = generator_cfg.inference_engine.engine_init_kwargs.get("max_model_len") + if max_model_len is None: + return None + try: + return int(max_model_len) + except (TypeError, ValueError) as exc: + raise ValueError( + "generator.inference_engine.engine_init_kwargs.max_model_len must be an integer when set, " + f"got {max_model_len!r}" + ) from exc + + +def normalize_sampling_params( + generator_cfg: GeneratorConfig, sampling_params: Optional[Dict[str, Any]] = None +) -> Dict[str, Any]: + """Return backend-shaped sampling params without mutating the caller's dict.""" + if sampling_params is not None: + request_sampling_params = copy.deepcopy(sampling_params) + max_generate_length = request_sampling_params.pop("max_generate_length", None) + if "max_tokens" not in request_sampling_params and max_generate_length is not None: + request_sampling_params["max_tokens"] = max_generate_length + return request_sampling_params + from skyrl.backends.skyrl_train.inference_engines.utils import get_sampling_params_for_backend # noqa: I001 + + return get_sampling_params_for_backend(generator_cfg.inference_engine.backend, generator_cfg.sampling_params) + + +def sampling_params_with_max_tokens(sampling_params: Dict[str, Any], max_tokens: int) -> Dict[str, Any]: + """Return a copy of backend sampling params with the per-request decode cap set.""" + request_sampling_params = copy.deepcopy(sampling_params) + request_sampling_params["max_tokens"] = max_tokens + return request_sampling_params + + +def compute_request_max_tokens( + remaining_generate_tokens: int, + input_length: int, + max_model_len: Optional[int], +) -> int: + """Compute the safe max_tokens for the next inference request.""" + request_max_tokens = remaining_generate_tokens + if max_model_len is not None: + request_max_tokens = min(request_max_tokens, max_model_len - input_length) + return max(0, request_max_tokens) + + def get_custom_chat_template(chat_template_config: Optional[Union[dict, ChatTemplateConfig]] = None) -> Optional[str]: """ Get custom chat template based on the new config structure. diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index b59fda9766..c48eec4eb8 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -428,6 +428,45 @@ def validate_generator_cfg(cfg: SkyRLTrainConfig): "for multi-turn generation" ) + max_model_len = ie_cfg.engine_init_kwargs.get("max_model_len") + if max_model_len is not None: + try: + max_model_len = int(max_model_len) + except (TypeError, ValueError) as exc: + raise ValueError( + "`generator.inference_engine.engine_init_kwargs.max_model_len` must be an integer " + f"when set, got {max_model_len!r}" + ) from exc + if max_model_len <= 0: + raise ValueError( + "`generator.inference_engine.engine_init_kwargs.max_model_len` must be positive " + f"when set, got {max_model_len}" + ) + if cfg.trainer.max_prompt_length > max_model_len: + raise ValueError( + "`trainer.max_prompt_length` cannot exceed " + "`generator.inference_engine.engine_init_kwargs.max_model_len`: " + f"{cfg.trainer.max_prompt_length} > {max_model_len}" + ) + if cfg.generator.max_input_length > max_model_len: + raise ValueError( + "`generator.max_input_length` cannot exceed " + "`generator.inference_engine.engine_init_kwargs.max_model_len`: " + f"{cfg.generator.max_input_length} > {max_model_len}" + ) + if cfg.trainer.max_prompt_length + cfg.generator.sampling_params.max_generate_length > max_model_len: + logger.warning( + "`trainer.max_prompt_length + generator.sampling_params.max_generate_length` exceeds " + "`generator.inference_engine.engine_init_kwargs.max_model_len`. Multi-turn generation will " + "cap each request's `max_tokens` to fit the remaining model context window." + ) + if cfg.trainer.max_prompt_length + cfg.generator.eval_sampling_params.max_generate_length > max_model_len: + logger.warning( + "`trainer.max_prompt_length + generator.eval_sampling_params.max_generate_length` exceeds " + "`generator.inference_engine.engine_init_kwargs.max_model_len`. Evaluation generation will " + "cap each request's `max_tokens` to fit the remaining model context window." + ) + if ie_cfg.enable_pd: assert ie_cfg.num_prefill > 0, "num_prefill must be > 0 when enable_pd=True" assert ( diff --git a/tests/train/generators/test_skyrl_gym_generator.py b/tests/train/generators/test_skyrl_gym_generator.py index d8bee68a1a..57c97b16a0 100644 --- a/tests/train/generators/test_skyrl_gym_generator.py +++ b/tests/train/generators/test_skyrl_gym_generator.py @@ -610,6 +610,212 @@ def mock_encode(text, **kwargs): assert output.stop_reason == "length", f"Expected stop_reason='length', got '{output.stop_reason}'" +@pytest.mark.asyncio +@patch("skyrl_gym.make") +async def test_agent_loop_caps_each_request_to_remaining_generate_budget( + mock_make, mock_tokenizer, mock_llm, mock_env, generator_cfg, mock_env_cfg +): + """Observation tokens should not consume the trajectory-level assistant generation budget.""" + mock_make.return_value = mock_env + generator_cfg.batched = False + generator_cfg.max_turns = 3 + generator_cfg.sampling_params.max_generate_length = 5 + generator_cfg.use_conversation_multi_turn = True + mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {}) + + step_count = 0 + + def mock_step_multi_turn(_output): + nonlocal step_count + step_count += 1 + return BaseTextEnvStepOutput( + observations=[{"role": "user", "content": "next"}], + reward=1.0, + done=step_count >= 2, + metadata={}, + ) + + mock_env.step.side_effect = mock_step_multi_turn + requested_max_tokens = [] + + async def llm_generate_side_effect(input_batch, model=None): + request_max_tokens = input_batch["sampling_params"]["max_tokens"] + requested_max_tokens.append(request_max_tokens) + response = [10, 11, 12] if len(requested_max_tokens) == 1 else [20, 21] + response = response[:request_max_tokens] + return { + "responses": ["step"] * len(input_batch["prompt_token_ids"]), + "stop_reasons": ["length" if len(response) == request_max_tokens else "stop"], + "response_logprobs": None, + "response_ids": [response], + } + + mock_llm.generate = AsyncMock(side_effect=llm_generate_side_effect) + + generator = SkyRLGymGenerator( + generator_cfg=generator_cfg, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_llm, + tokenizer=mock_tokenizer, + ) + generator.base_conversation_token_ids = [] + + output = await generator.agent_loop( + [{"role": "user", "content": "Initial prompt"}], + "test_env", + {}, + max_tokens=5, + max_input_length=512, + ) + + assert requested_max_tokens == [5, 2] + assert sum(output.loss_mask) == 5 + assert output.stop_reason == "length" + + +@pytest.mark.asyncio +@patch("skyrl_gym.make") +async def test_agent_loop_caps_request_to_max_model_len( + mock_make, mock_tokenizer, mock_llm, mock_env, generator_cfg, mock_env_cfg +): + """Per-request max_tokens should also fit inside vLLM's prompt+completion window.""" + mock_make.return_value = mock_env + generator_cfg.batched = False + generator_cfg.max_turns = 1 + generator_cfg.sampling_params.max_generate_length = 5 + generator_cfg.inference_engine.engine_init_kwargs["max_model_len"] = 6 + generator_cfg.use_conversation_multi_turn = True + mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {}) + mock_env.step.side_effect = lambda _output: BaseTextEnvStepOutput( + observations=[], + reward=1.0, + done=True, + metadata={}, + ) + requested_max_tokens = [] + + async def llm_generate_side_effect(input_batch, model=None): + requested_max_tokens.append(input_batch["sampling_params"]["max_tokens"]) + return { + "responses": ["done"], + "stop_reasons": ["stop"], + "response_logprobs": None, + "response_ids": [[10, 11]], + } + + mock_llm.generate = AsyncMock(side_effect=llm_generate_side_effect) + + generator = SkyRLGymGenerator( + generator_cfg=generator_cfg, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_llm, + tokenizer=mock_tokenizer, + ) + + output = await generator.agent_loop( + [{"role": "user", "content": "Initial prompt"}], + "test_env", + {}, + max_tokens=5, + max_input_length=512, + ) + + assert requested_max_tokens == [2] + assert output.response_ids == [10, 11] + assert sum(output.loss_mask) == 2 + + +@pytest.mark.asyncio +@patch("skyrl_gym.make") +async def test_agent_loop_stops_when_max_model_len_has_no_decode_room( + mock_make, mock_tokenizer, mock_llm, mock_env, generator_cfg, mock_env_cfg +): + """A full context window should stop cleanly without issuing an invalid vLLM request.""" + mock_make.return_value = mock_env + generator_cfg.batched = False + generator_cfg.max_turns = 1 + generator_cfg.sampling_params.max_generate_length = 5 + generator_cfg.inference_engine.engine_init_kwargs["max_model_len"] = len(MOCK_TOKENIZER_ENCODED_IDS) + generator_cfg.use_conversation_multi_turn = True + mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {}) + + generator = SkyRLGymGenerator( + generator_cfg=generator_cfg, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_llm, + tokenizer=mock_tokenizer, + ) + + output = await generator.agent_loop( + [{"role": "user", "content": "Initial prompt"}], + "test_env", + {}, + max_tokens=5, + max_input_length=512, + ) + + mock_llm.generate.assert_not_called() + mock_env.step.assert_not_called() + assert output.response_ids == [] + assert output.loss_mask == [] + assert output.reward == [] + assert output.stop_reason == "length" + + +@pytest.mark.asyncio +@patch("skyrl_gym.make") +async def test_generate_uses_input_sampling_params_max_tokens_budget( + mock_make, mock_tokenizer, mock_llm, mock_env, generator_cfg, mock_env_cfg +): + """Eval/custom sampling params should control the trajectory budget when provided.""" + mock_make.return_value = mock_env + generator_cfg.batched = False + generator_cfg.max_turns = 3 + generator_cfg.sampling_params.max_generate_length = 50 + generator_cfg.use_conversation_multi_turn = True + mock_env.init.return_value = ([{"role": "user", "content": "Initial input"}], {}) + mock_env.get_metrics.return_value = {} + mock_env.step.side_effect = lambda _output: BaseTextEnvStepOutput( + observations=[{"role": "user", "content": "next"}], + reward=1.0, + done=False, + metadata={}, + ) + requested_max_tokens = [] + + async def llm_generate_side_effect(input_batch, model=None): + requested_max_tokens.append(input_batch["sampling_params"]["max_tokens"]) + return { + "responses": ["step"], + "stop_reasons": ["length"], + "response_logprobs": None, + "response_ids": [[10, 11]], + } + + mock_llm.generate = AsyncMock(side_effect=llm_generate_side_effect) + + generator = SkyRLGymGenerator( + generator_cfg=generator_cfg, + skyrl_gym_cfg=mock_env_cfg, + inference_engine_client=mock_llm, + tokenizer=mock_tokenizer, + ) + + output = await generator.generate( + { + "prompts": [[{"role": "user", "content": "Initial prompt"}]], + "env_classes": [mock_env_cfg.env_class], + "env_extras": [{}], + "sampling_params": {"max_tokens": 2, "logprobs": None}, + }, + disable_tqdm=True, + ) + + assert requested_max_tokens == [2] + assert sum(output["loss_masks"][0]) == 2 + assert output["stop_reasons"] == ["length"] + + @pytest.mark.asyncio @patch("skyrl_gym.make") async def test_postprocessed_action_used(mock_make, mock_tokenizer, mock_llm, mock_env, mock_env_cfg, generator_cfg): @@ -738,7 +944,9 @@ def mock_apply_chat_template(messages, **kwargs): async def llm_generate_side_effect(input_batch, model=None): if input_batch.get("sampling_params") is not None: - max_len = input_batch["sampling_params"]["max_generate_length"] + max_len = input_batch["sampling_params"].get( + "max_tokens", input_batch["sampling_params"].get("max_generate_length") + ) else: max_len = generator_cfg.sampling_params.max_generate_length @@ -1164,7 +1372,9 @@ async def llm_generate_side_effect(input_batch, model=None): num = len(input_batch["prompt_token_ids"]) if "prompt_token_ids" in input_batch else len(input_batch["prompts"]) if input_batch.get("sampling_params") is not None: - max_len = input_batch["sampling_params"]["max_generate_length"] + max_len = input_batch["sampling_params"].get( + "max_tokens", input_batch["sampling_params"].get("max_generate_length") + ) else: max_len = cfg.sampling_params.max_generate_length diff --git a/tests/train/generators/test_skyrl_vlm_generator.py b/tests/train/generators/test_skyrl_vlm_generator.py index 675c212a99..d43054e9fb 100644 --- a/tests/train/generators/test_skyrl_vlm_generator.py +++ b/tests/train/generators/test_skyrl_vlm_generator.py @@ -127,6 +127,24 @@ async def mock_generate(input_batch, model=None): return AsyncMock(side_effect=mock_generate) +def _make_budget_capturing_mock_llm(tokenizer, response_text: str, requested_max_tokens: list[int]): + """Create an AsyncMock that captures per-turn max_tokens for VLM generation.""" + + async def mock_generate(input_batch, model=None): + requested_max_tokens.append(input_batch["sampling_params"]["max_tokens"]) + num_prompts = len(input_batch["prompt_token_ids"]) + text_with_eos = response_text + tokenizer.eos_token + ids = tokenizer.encode(text_with_eos, add_special_tokens=False) + return { + "responses": [response_text] * num_prompts, + "stop_reasons": ["stop"] * num_prompts, + "response_logprobs": None, + "response_ids": [ids] * num_prompts, + } + + return AsyncMock(side_effect=mock_generate) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -247,3 +265,36 @@ async def test_vlm_obs_tokens_match_expected(mock_decode): assert ( obs_text in decoded ), f"Obs segment {seg_idx}: expected '{obs_text}' in decoded obs tokens, got '{decoded}'" + + +@pytest.mark.asyncio +@patch("skyrl.train.generators.skyrl_vlm_generator.decode_mm_kwargs") +async def test_vlm_caps_each_request_to_remaining_generate_budget(mock_decode): + """VLM observation tokens should not consume the assistant generation budget.""" + mock_decode.return_value = {"pixel_values": None, "image_grid_thw": None} + + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + response_text = "b" + response_ids = tokenizer.encode(response_text + tokenizer.eos_token, add_special_tokens=False) + max_generate_length = len(response_ids) * 2 + requested_max_tokens = [] + + generator = _build_vlm_generator(tokenizer) + generator.generator_cfg.sampling_params.max_generate_length = max_generate_length + generator.inference_engine_client.render_chat_completion = _make_mock_renderer(tokenizer) + generator.inference_engine_client.generate = _make_budget_capturing_mock_llm( + tokenizer, response_text, requested_max_tokens + ) + + prompt = [[{"role": "user", "content": "a"}]] + input_batch: GeneratorInput = { + "prompts": prompt, + "env_extras": [{"answer": "4"}], + "env_classes": ["cpu_vlm_test_env"], + } + + output: GeneratorOutput = await generator.generate(input_batch, disable_tqdm=True) + + assert requested_max_tokens == [max_generate_length, max_generate_length - len(response_ids)] + assert sum(output["loss_masks"][0]) == max_generate_length + assert output["stop_reasons"][0] == "length" diff --git a/tests/train/test_config.py b/tests/train/test_config.py index 7a8cd86401..9c25e0c33c 100644 --- a/tests/train/test_config.py +++ b/tests/train/test_config.py @@ -17,7 +17,7 @@ build_nested_dataclass, ) from skyrl.train.config.utils import get_legacy_config -from skyrl.train.utils.utils import validate_cfg +from skyrl.train.utils.utils import validate_cfg, validate_generator_cfg from tests.train.util import example_dummy_config @@ -217,6 +217,32 @@ def test_cross_field_defaults(): assert cfg.generator.rope_theta == cfg.trainer.rope_theta +class TestGeneratorMaxModelLenValidation: + """Tests for generator/vLLM context window validation.""" + + def test_validate_generator_cfg_rejects_prompt_over_max_model_len(self): + cfg = example_dummy_config() + cfg.generator.inference_engine.engine_init_kwargs["max_model_len"] = cfg.trainer.max_prompt_length - 1 + + with pytest.raises(ValueError, match="trainer.max_prompt_length"): + validate_generator_cfg(cfg) + + def test_validate_generator_cfg_rejects_input_over_max_model_len(self): + cfg = example_dummy_config() + cfg.generator.max_turns = 2 + cfg.generator.max_input_length = cfg.trainer.max_prompt_length + 1 + cfg.generator.inference_engine.engine_init_kwargs["max_model_len"] = cfg.trainer.max_prompt_length + + with pytest.raises(ValueError, match="generator.max_input_length"): + validate_generator_cfg(cfg) + + def test_validate_generator_cfg_allows_prompt_plus_generation_over_max_model_len(self): + cfg = example_dummy_config() + cfg.generator.inference_engine.engine_init_kwargs["max_model_len"] = cfg.trainer.max_prompt_length + + validate_generator_cfg(cfg) + + class TestMaxSeqLenValidation: """Tests for max_seq_len defaults and validation behavior."""