From 671e50dfea7491aeb31d1a3135e9f3532bdb9694 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 14:38:59 -0700 Subject: [PATCH 001/121] feat: add Fleet Task environment for skyrl-gym Add fleet_task environment that integrates Fleet-hosted tasks with SkyRL via OpenEnv's FleetTaskEnv abstraction layer. Supports multi-turn tool-use and computer-use (multimodal) modalities. - FleetTaskEnv(BaseTextEnv): provisions Fleet env, multi-turn episodes, reward via verifier, partial reward support, hint augmentation - Tool call parser: handles / tag formats with JSON repair for missing closing braces - Multimodal observations: returns image_url content blocks for CUA, compatible with upstream's extract_images_from_conversation() - Per-env metrics aggregation with environment breakdown - Context management integration for long trajectories - Trace upload support for eval telemetry Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/pyproject.toml | 3 + skyrl-gym/skyrl_gym/envs/__init__.py | 5 + .../skyrl_gym/envs/fleet_task/__init__.py | 10 + skyrl-gym/skyrl_gym/envs/fleet_task/env.py | 874 ++++++++++++++++++ .../envs/fleet_task/tool_call_parser.py | 68 ++ 5 files changed, 960 insertions(+) create mode 100644 skyrl-gym/skyrl_gym/envs/fleet_task/__init__.py create mode 100644 skyrl-gym/skyrl_gym/envs/fleet_task/env.py create mode 100644 skyrl-gym/skyrl_gym/envs/fleet_task/tool_call_parser.py diff --git a/skyrl-gym/pyproject.toml b/skyrl-gym/pyproject.toml index b7beb59d9f..82c4078e51 100644 --- a/skyrl-gym/pyproject.toml +++ b/skyrl-gym/pyproject.toml @@ -31,6 +31,9 @@ include = ["skyrl_gym*"] dev = [ "pytest" ] +fleet = [ + "openenv[fleet]", +] [tool.black] line-length = 120 diff --git a/skyrl-gym/skyrl_gym/envs/__init__.py b/skyrl-gym/skyrl_gym/envs/__init__.py index 770b65e1e8..000c9cf23e 100644 --- a/skyrl-gym/skyrl_gym/envs/__init__.py +++ b/skyrl-gym/skyrl_gym/envs/__init__.py @@ -36,3 +36,8 @@ id="searchcode", entry_point="skyrl_gym.envs.searchcode.env:SearchCodeEnv", ) + +register( + id="fleet_task", + entry_point="skyrl_gym.envs.fleet_task.env:FleetTaskEnv", +) diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/__init__.py b/skyrl-gym/skyrl_gym/envs/fleet_task/__init__.py new file mode 100644 index 0000000000..922066c478 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/__init__.py @@ -0,0 +1,10 @@ +"""Fleet Task Environment for SkyRL-Gym. + +Provides a multi-turn tool-use environment backed by Fleet-hosted environments, +using OpenEnv's FleetTaskEnv as the abstraction layer. +""" + +from skyrl_gym.envs.fleet_task.env import FleetTaskEnv +from skyrl_gym.envs.fleet_task.tool_call_parser import parse_tool_call + +__all__ = ["FleetTaskEnv", "parse_tool_call"] diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py new file mode 100644 index 0000000000..7dfc2f07ff --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -0,0 +1,874 @@ +"""Fleet Task Environment for SkyRL-Gym. + +Provides a SkyRL-compatible environment wrapper for Fleet-hosted tasks. +Uses OpenEnv's FleetTaskEnv as the abstraction layer for Fleet environments, +keeping a clean separation between SkyRL's training interface and Fleet's +environment management. + +Multi-modal support: When the task modality is "computer_use", step() returns +multimodal observations in OpenAI format (image_url content blocks). Upstream +SkyRL's generator already handles these via extract_images_from_conversation() +and passes them as multi_modal_data to vLLM — no upstream changes needed. +""" + +import ast +import asyncio +import json +import logging +import os +import re +import time +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from skyrl_gym.envs.base_text_env import ( + BaseTextEnv, + BaseTextEnvStepOutput, + ConversationType, +) +from skyrl_gym.envs.fleet_task.tool_call_parser import parse_tool_call + +# Reduce MCP client log noise +try: + from loguru import logger as loguru_logger + + loguru_logger.disable("mcp") +except ImportError: + pass +logging.getLogger("mcp").setLevel(logging.WARNING) + +logger = logging.getLogger(__name__) + +# Global task cache to avoid reloading JSON for each env instance +_TASK_CACHE: Dict[str, Dict[str, Any]] = {} + + +def load_tasks_from_json(tasks_file: str) -> Dict[str, Any]: + """Load tasks from JSON file with caching. + + Returns a dict mapping task_key -> task_config dict. + """ + if tasks_file not in _TASK_CACHE: + expanded_path = os.path.expanduser(tasks_file) + if not os.path.exists(expanded_path): + raise FileNotFoundError(f"Tasks file not found: {expanded_path}") + + with open(expanded_path, "r") as f: + data = json.load(f) + + # Handle both formats: array or {"tasks": [...]} + if isinstance(data, list): + tasks = data + elif isinstance(data, dict) and "tasks" in data: + tasks = data["tasks"] + else: + raise ValueError( + f"Invalid JSON format in {tasks_file}: expected array or object with 'tasks' key" + ) + + if not tasks: + raise ValueError(f"No tasks found in {tasks_file}") + + # Index by task_key (support both 'key' and 'task_key' fields) + _TASK_CACHE[tasks_file] = { + t.get("key") or t.get("task_key"): t for t in tasks + } + + return _TASK_CACHE[tasks_file] + + +def clear_caches(): + """Clear global caches. Useful for testing.""" + global _TASK_CACHE + _TASK_CACHE = {} + + +class FleetTaskEnv(BaseTextEnv): + """SkyRL environment for Fleet-hosted tasks. + + Uses OpenEnv's FleetTaskEnv as the abstraction layer for Fleet environments. + This provides a clean separation between SkyRL's training interface and + Fleet's environment management. + + Constructor signature follows upstream convention: + __init__(self, env_config=None, extras={}) + + Where: + env_config: Dict or DictConfig from skyrl_gym_config YAML + extras: Per-sample data from the training dataset (task_key, max_turns, etc.) + """ + + _trace_config: Optional[Dict[str, str]] = None + + @classmethod + def set_trace_config(cls, job_id: str, model: str): + """Set trace config for uploading eval traces to Fleet.""" + cls._trace_config = {"job_id": job_id, "model": model} + + @classmethod + def clear_trace_config(cls): + """Clear trace config after eval is done.""" + cls._trace_config = None + + def __init__( + self, + env_config=None, + extras: Dict[str, Any] = {}, + ): + super().__init__() + + if env_config is None: + env_config = {} + + self.extras = extras + self.max_turns = extras.get("max_turns", 50) + + # Task configuration from extras (set by dataset) + self.task_key = extras.get("task_key") + self.tasks_file = ( + env_config.get("tasks_file") if hasattr(env_config, "get") else None + ) or extras.get("tasks_file") + + if not self.task_key: + raise ValueError("task_key must be provided in extras (from dataset)") + if not self.tasks_file: + raise ValueError( + "tasks_file must be provided in env_config or extras" + ) + + # Expand path + self.tasks_file = os.path.expanduser(self.tasks_file) + + # Load task config from JSON + tasks = load_tasks_from_json(self.tasks_file) + self.task_config = tasks.get(self.task_key) + if not self.task_config: + available_keys = list(tasks.keys())[:5] + raise ValueError( + f"Task '{self.task_key}' not found in {self.tasks_file}. " + f"Available keys (first 5): {available_keys}" + ) + + # API key + self.api_key = ( + env_config.get("api_key") if hasattr(env_config, "get") else None + ) or os.environ.get("FLEET_API_KEY") + if not self.api_key: + raise ValueError( + "FLEET_API_KEY must be set in env_config or environment" + ) + + # Logfire telemetry (no-op if LOGFIRE_TOKEN is not set) + logfire_token = os.environ.get("LOGFIRE_TOKEN") + if logfire_token: + try: + from envs.fleet_env import configure_fleet_telemetry + + configure_fleet_telemetry(token=logfire_token) + except ImportError: + pass + + # TTL for Fleet environment instances + self.ttl_seconds = ( + env_config.get("ttl_seconds") if hasattr(env_config, "get") else None + ) + + # Partial reward: use verifier accumulator counts instead of binary 0/1 + self.partial_reward = ( + env_config.get("partial_reward", False) + if hasattr(env_config, "get") + else False + ) + + # Hint config + self.enable_hints = ( + env_config.get("enable_hints", False) + if hasattr(env_config, "get") + else False + ) + + # Environment state (initialized on init()) + self.openenv_task_env = None + self.chat_history: ConversationType = [] + self.turns = 0 + self.tool_calls = 0 + self.tool_errors = 0 + self.last_reward: Optional[float] = None + self.tools: List[Dict[str, Any]] = [] + + # Verifier feedback (captured at close time for hint generation) + self._verifier_stdout: Optional[str] = None + self._verifier_error: Optional[str] = None + self._tool_error_messages: List[str] = [] + + # Context management (uses OpenEnv's ContextManager) + self.enable_context_tools = ( + env_config.get("enable_context_tools", False) + if hasattr(env_config, "get") + else False + ) + self.context_manager = None + if self.enable_context_tools: + try: + from envs.fleet_env import ContextManager + + logger.info( + "Enabling context management tools with " + f"max_output_chars={extras.get('max_output_chars', 10000)}" + ) + self.context_manager = ContextManager( + max_output_chars=extras.get("max_output_chars", 10000) + ) + except ImportError: + logger.warning( + "ContextManager not available, disabling context tools" + ) + + def _normalize_task_config(self) -> Dict[str, Any]: + """Normalize task config to OpenEnv's expected format.""" + config = self.task_config.copy() + + # Map field names if needed + if "key" in config and "task_key" not in config: + config["task_key"] = config["key"] + if "env_id" in config and "env_key" not in config: + config["env_key"] = config["env_id"] + if "version" in config and "env_version" not in config: + config["env_version"] = config["version"] + + return config + + async def init_async( + self, prompt: ConversationType + ) -> Tuple[ConversationType, Dict[str, Any]]: + """Initialize the Fleet environment and return initial observation. + + Creates Fleet environment via OpenEnv's FleetTaskEnv and returns + the task prompt with tool definitions. + """ + from envs.fleet_env import FleetTaskEnv as OpenEnvFleetTaskEnv + + # Close any existing environment + self.close() + + # Create OpenEnv's FleetTaskEnv with normalized config + task_config = self._normalize_task_config() + + try: + self.openenv_task_env = OpenEnvFleetTaskEnv( + task_config=task_config, + api_key=self.api_key, + ttl_seconds=self.ttl_seconds, + max_steps=self.max_turns, + partial_reward=self.partial_reward, + ) + except Exception as e: + raise RuntimeError( + f"Failed to create OpenEnv FleetTaskEnv: {e}" + ) from e + + # Reset episode state (tools are already cached from __init__) + obs = await self.openenv_task_env.reset_async() + + # Reset state + self.turns = 0 + self.tool_calls = 0 + self.tool_errors = 0 + self.last_reward = None + + # Reset context manager if enabled + if self.context_manager: + self.context_manager.reset() + + # Get tools from observation + self.tools = obs.get("tools", []) + + # Add context management tools if enabled + if self.context_manager: + self.tools = self.tools + self.context_manager.get_tools() + if not self.tools: + raise RuntimeError( + f"Task {self.task_key}: no tools found. Fleet env requires tools." + ) + + # Build initial prompt with task instruction + task_prompt = self.task_config.get("prompt", "") + + # Inject hint from previous failed attempt if provided + hint = self.extras.get("hint") + if hint: + task_prompt = ( + f"{task_prompt}\n\nHere is feedback from a previous attempt " + f"to help you:\n{hint}" + ) + + # Build system prompt with tool definitions + tools_json = json.dumps(self.tools, indent=2) + current_date = datetime.now().strftime("%Y-%m-%d") + + # Build environment context section from env_variables + env_context = "" + env_vars = self.task_config.get("env_variables", {}) + if env_vars: + env_lines = [] + if "LOGGED_IN_USER" in env_vars: + env_lines.append( + f"- Logged in user ID: {env_vars['LOGGED_IN_USER']}" + ) + if "LOGGED_IN_NAME" in env_vars: + env_lines.append( + f"- Logged in as: {env_vars['LOGGED_IN_NAME']}" + ) + for key, value in env_vars.items(): + if key not in ( + "LOGGED_IN_USER", + "LOGGED_IN_NAME", + "CURRENT_DATE", + ): + env_lines.append(f"- {key}: {value}") + if env_lines: + env_context = ( + "\n## Environment Context\n" + + "\n".join(env_lines) + + "\n" + ) + + # Add environment-specific hints + env_key = self.task_config.get("env_key") or self.task_config.get( + "env_id" + ) + env_hints = "" + if env_key == "fostgres": + env_hints = ( + "\n## Database Exploration\n" + "Before writing SQL queries, first explore the database schema:\n" + "- List tables: SELECT table_name FROM information_schema.tables " + "WHERE table_schema = 'public'\n" + "- List columns: SELECT column_name, data_type FROM " + "information_schema.columns WHERE table_name = 'your_table'\n" + ) + + system_content = ( + f"You are a helpful agent. Complete the task by calling tools.\n\n" + f"## Current Date\n" + f"Today's date is {current_date}. When dates are mentioned without " + f"a year, assume the current year ({datetime.now().year}) or a " + f"future date.\n" + f"{env_context}{env_hints}\n" + f"## Available Tools\n{tools_json}\n\n" + f"## Tool Call Format\n" + f'{{"name": "tool_name", "arguments": ' + f'{{"param": "value"}}}}\n\n' + f"## Error Handling\n" + f"If a tool call returns an error:\n" + f"- Read the error message carefully\n" + f"- Do NOT repeat the same call with identical arguments\n" + f"- Change your approach: use different parameters, try a different " + f"tool, or break the task into smaller steps\n\n" + f"## Response Format\n" + f"EVERY response MUST end with exactly ONE of:\n" + f"1. A tool call: ... - to perform an action\n" + f"2. Done signal: - ONLY when the task is fully complete\n\n" + f"IMPORTANT: When the task is complete, first output your final " + f"answer with the requested information, THEN say . Do not " + f"just say without providing the answer.\n\n" + f"NEVER respond with just a message. NEVER say \"feel free to ask\" " + f"or offer further help.\n" + f"If the task is complete, provide your answer then say . " + f"Otherwise, make a tool call." + ) + + system_message = {"role": "system", "content": system_content} + user_message = {"role": "user", "content": task_prompt} + self.chat_history = [system_message, user_message] + + metadata = { + "task_key": self.task_key, + "env_key": env_key, + "tools": self.tools, + "modality": self.task_config.get("task_modality", "tool_use"), + } + + return self.chat_history.copy(), metadata + + def init( + self, prompt: ConversationType + ) -> Tuple[ConversationType, Dict[str, Any]]: + """Initialize the Fleet environment (sync wrapper). + + Uses asyncio.run() for sync contexts. For async contexts, the upstream + generator's _run_in_executor_if_available will call this in a thread pool, + where asyncio.run() is safe. + """ + return asyncio.run(self.init_async(prompt)) + + async def step_async(self, action: str) -> BaseTextEnvStepOutput: + """Execute one step in the Fleet environment. + + Parses the action for tool calls, executes via OpenEnv's FleetTaskEnv, + and returns observation. Reward is computed by the verifier on completion. + + For computer_use modality, observations may include multimodal content + (image_url blocks with base64 screenshots). Upstream SkyRL's generator + handles these via extract_images_from_conversation(). + """ + step_start = time.time() + self.turns += 1 + assistant_msg = {"role": "assistant", "content": action} + self.chat_history.append(assistant_msg) + if self.context_manager: + self.context_manager.track_message(assistant_msg) + + max_turns_reached = self.turns >= self.max_turns + + # Check if agent signals completion + agent_done = "" in action.lower() or "[done]" in action.lower() + + # Parse tool call from LLM response + tool_call = parse_tool_call(action) + + tool_result = None + error = None + reward = 0.0 + mcp_time = 0.0 + + # Handle context management tools locally (no MCP call) + if ( + tool_call + and self.context_manager + and self.context_manager.is_context_tool(tool_call["name"]) + ): + tool_result, self.chat_history = self.context_manager.execute_tool( + tool_call["name"], + tool_call.get("arguments", {}), + self.chat_history, + ) + # Execute tool call if present via OpenEnv + elif tool_call and self.openenv_task_env: + self.tool_calls += 1 + openenv_action = { + "tool": tool_call["name"], + "params": tool_call.get("arguments", {}), + "done": agent_done, + } + + try: + mcp_start = time.time() + obs, reward, done, info = ( + await self.openenv_task_env.step_async(openenv_action) + ) + mcp_time = time.time() - mcp_start + tool_result = obs.get("observation") + if "tool_error" in info: + error = info["tool_error"] + + # Truncate long outputs if context management is enabled + if ( + tool_result + and isinstance(tool_result, str) + and self.context_manager + ): + tool_result = self.context_manager.truncate_output( + tool_result + ) + except Exception as e: + mcp_time = time.time() - mcp_start + error = str(e) + elif agent_done and self.openenv_task_env: + # Agent signaled done without tool call + openenv_action = {"done": True} + try: + mcp_start = time.time() + obs, reward, done, info = ( + await self.openenv_task_env.step_async(openenv_action) + ) + mcp_time = time.time() - mcp_start + except Exception as e: + mcp_time = time.time() - mcp_start + error = str(e) + + # Detect error patterns in tool_result + if not error and tool_result: + result_str = ( + str(tool_result) + if not isinstance(tool_result, str) + else tool_result + ) + if result_str.strip().startswith( + "Error:" + ) or result_str.strip().startswith("error:"): + error = result_str + tool_result = None + elif isinstance(tool_result, dict) and tool_result.get("error"): + error = tool_result["error"] + tool_result = None + + episode_done = agent_done or max_turns_reached + + # Upload trace at episode end if trace config is set + if episode_done and FleetTaskEnv._trace_config: + try: + from envs.fleet_env.trace import upload_trace + + inst_id = None + orch = getattr(self.openenv_task_env, "_orch", None) + if orch: + fleet_env = getattr(orch, "_fleet_env", None) + if fleet_env: + inst_id = getattr(fleet_env, "instance_id", None) + await upload_trace( + api_key=self.api_key, + job_id=FleetTaskEnv._trace_config["job_id"], + task_key=self.task_key, + model=FleetTaskEnv._trace_config["model"], + chat_history=self.chat_history, + reward=reward, + instance_id=inst_id, + metadata={ + "env_key": self.task_config.get("env_key"), + "turns": self.turns, + }, + ) + except Exception as e: + logger.warning( + f"Failed to upload trace for {self.task_key}: {e}" + ) + + # Build observation message + if max_turns_reached: + return BaseTextEnvStepOutput( + observations=[], + reward=reward, + done=True, + metadata={ + "done_reason": "max_turns", + "task_key": self.task_key, + }, + ) + + # Build response observation + if error: + self.tool_errors += 1 + self._tool_error_messages.append(str(error)[:500]) + obs_content = f"Error: {error}" + elif tool_result: + # Handle multimodal results (list with image_url blocks) + if isinstance(tool_result, list): + # Multimodal: return as structured content for VL models + new_obs = {"role": "user", "content": tool_result} + self.chat_history.append(new_obs) + if self.context_manager: + self.context_manager.track_message(new_obs) + + step_time = time.time() - step_start + metadata = { + "task_key": self.task_key, + "turn": self.turns, + "tool_call": tool_call, + "error": None, + "done_reason": "agent_done" if agent_done else None, + "step_time": step_time, + "mcp_time": mcp_time, + } + return BaseTextEnvStepOutput( + observations=[new_obs], + reward=reward, + done=episode_done, + metadata=metadata, + ) + elif isinstance(tool_result, dict): + obs_content = ( + f"Tool result:\n{json.dumps(tool_result, indent=2)}" + ) + else: + obs_content = f"Tool result:\n{tool_result}" + elif agent_done: + obs_content = "Task marked as complete." + elif not tool_call: + obs_content = ( + "No tool call found. Use " + '{"name": "...", "arguments": {...}} ' + "format." + ) + else: + obs_content = "Action executed." + + new_obs = {"role": "user", "content": obs_content} + self.chat_history.append(new_obs) + if self.context_manager: + self.context_manager.track_message(new_obs) + + step_time = time.time() - step_start + metadata = { + "task_key": self.task_key, + "turn": self.turns, + "tool_call": tool_call, + "tool_result": ( + tool_result[:200] + if isinstance(tool_result, str) and len(tool_result) > 200 + else tool_result + ), + "error": error, + "done_reason": "agent_done" if agent_done else None, + "step_time": step_time, + "mcp_time": mcp_time, + } + + # If context was modified, return full chat_history so the generator + # can replace its copy (required for stepwise training). + if ( + tool_call + and self.context_manager + and self.context_manager.is_context_tool(tool_call["name"]) + ): + if tool_call["name"] == "manage_context": + metadata["modified_chat_history"] = self.chat_history.copy() + + return BaseTextEnvStepOutput( + observations=[new_obs], + reward=reward, + done=episode_done, + metadata=metadata, + ) + + def step(self, action: str) -> BaseTextEnvStepOutput: + """Execute one step in the Fleet environment (sync wrapper).""" + return asyncio.run(self.step_async(action)) + + def _capture_verifier_feedback(self): + """Capture verifier feedback from OpenEnv before nulling the env.""" + if self.openenv_task_env: + self._verifier_stdout = getattr( + self.openenv_task_env, "verifier_stdout", None + ) + self._verifier_error = getattr( + self.openenv_task_env, "verifier_error", None + ) + self._tool_error_messages = getattr( + self.openenv_task_env, "tool_errors_list", [] + ) + + def close(self): + """Close the Fleet environment and cleanup resources.""" + if self.openenv_task_env: + try: + self.openenv_task_env.close() + if self.openenv_task_env.final_reward is not None: + self.last_reward = self.openenv_task_env.final_reward + self._capture_verifier_feedback() + except Exception as e: + logger.warning(f"Failed to close Fleet environment: {e}") + self.openenv_task_env = None + + async def close_async(self): + """Close the Fleet environment (async version). + + Runs verifier via OpenEnv's close_async() to get actual reward for + orphaned rollouts (context overflow, early termination by SkyRL). + """ + if self.openenv_task_env: + try: + await self.openenv_task_env.close_async() + if self.openenv_task_env.final_reward is not None: + self.last_reward = self.openenv_task_env.final_reward + self._capture_verifier_feedback() + except Exception as e: + logger.warning(f"Failed to close Fleet environment: {e}") + self.openenv_task_env = None + + def get_metrics(self) -> Dict[str, Any]: + """Return environment metrics for this episode.""" + metrics = { + "task_key": self.task_key, + "env_key": self.task_config.get("env_key") + or self.task_config.get("env_id"), + "turns": self.turns, + "tool_calls": self.tool_calls, + "tool_errors": self.tool_errors, + "is_hinted": bool(self.extras.get("hint")), + } + if self.last_reward is not None: + metrics["final_reward"] = self.last_reward + # Include verifier feedback for hint generation + if self._verifier_stdout is not None: + metrics["verifier_stdout"] = self._verifier_stdout + if self._verifier_error is not None: + metrics["verifier_error"] = self._verifier_error + if self._tool_error_messages: + metrics["tool_error_messages"] = self._tool_error_messages + return metrics + + @staticmethod + def build_hint_text( + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], + ) -> str: + """Build hint text from verifier feedback. No LLM call. + + Parses ERROR_ACCUMULATOR / SUCCESS_ACCUMULATOR from verifier stdout + and formats tool errors into structured feedback for the next attempt. + """ + parts = [] + + if verifier_stdout: + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n" + r"<<< SUCCESS_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + if err_match or suc_match: + try: + errors = ( + ast.literal_eval(err_match.group(1)) + if err_match + else [] + ) + successes = ( + ast.literal_eval(suc_match.group(1)) + if suc_match + else [] + ) + except Exception: + errors, successes = [], [] + if successes: + parts.append( + f"Checks passed ({len(successes)}): " + + ", ".join( + str(s)[:100] for s in successes[:5] + ) + ) + if errors: + parts.append( + f"Checks failed ({len(errors)}): " + + ", ".join(str(e)[:100] for e in errors[:5]) + ) + + if verifier_error: + parts.append(f"Verifier: {verifier_error}") + + if tool_error_messages: + unique = list(dict.fromkeys(tool_error_messages))[:5] + parts.append( + "Tool errors: " + "; ".join(e[:200] for e in unique) + ) + + return ( + "\n".join(parts) + if parts + else "The previous attempt failed. Try a different approach." + ) + + @staticmethod + def aggregate_metrics( + metrics: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """Aggregate metrics across episodes with per-env breakdown.""" + if not metrics: + return {} + + env_init_failures: Dict[str, int] = {} + total_init_failures = 0 + + env_data: Dict[str, Dict[str, List[int]]] = {} + for m in metrics: + # Check for init failure metrics + for key, value in m.items(): + if key.startswith("env_init_failed/"): + env_key = key.split("/", 1)[1] + env_init_failures[env_key] = ( + env_init_failures.get(env_key, 0) + int(value) + ) + total_init_failures += int(value) + + env_key = m.get("env_key") + if env_key: + if env_key not in env_data: + env_data[env_key] = { + "turns": [], + "tool_calls": [], + "tool_errors": [], + } + env_data[env_key]["turns"].append(m.get("turns", 0)) + env_data[env_key]["tool_calls"].append( + m.get("tool_calls", 0) + ) + env_data[env_key]["tool_errors"].append( + m.get("tool_errors", 0) + ) + + result: Dict[str, Any] = {} + total_turns = 0 + total_tool_calls = 0 + total_tool_errors = 0 + total_episodes = 0 + + for env_key, data in env_data.items(): + turns_list = data["turns"] + tool_calls_list = data["tool_calls"] + tool_errors_list = data["tool_errors"] + + avg_turns = sum(turns_list) / len(turns_list) + avg_tool_calls = sum(tool_calls_list) / len(tool_calls_list) + avg_tool_errors = sum(tool_errors_list) / len(tool_errors_list) + total_env_turns = sum(turns_list) + total_env_tool_calls = sum(tool_calls_list) + total_env_tool_errors = sum(tool_errors_list) + tool_calls_per_turn = ( + total_env_tool_calls / total_env_turns + if total_env_turns > 0 + else 0 + ) + tool_error_rate = ( + total_env_tool_errors / total_env_tool_calls + if total_env_tool_calls > 0 + else 0 + ) + + result[f"{env_key}/avg_turns"] = avg_turns + result[f"{env_key}/min_turns"] = min(turns_list) + result[f"{env_key}/max_turns"] = max(turns_list) + result[f"{env_key}/avg_tool_calls"] = avg_tool_calls + result[f"{env_key}/tool_calls_per_turn"] = tool_calls_per_turn + result[f"{env_key}/avg_tool_errors"] = avg_tool_errors + result[f"{env_key}/total_tool_errors"] = total_env_tool_errors + result[f"{env_key}/tool_error_rate"] = tool_error_rate + result[f"{env_key}/num_episodes"] = len(turns_list) + + total_turns += total_env_turns + total_tool_calls += total_env_tool_calls + total_tool_errors += total_env_tool_errors + total_episodes += len(turns_list) + + result["avg_turns"] = ( + total_turns / total_episodes if total_episodes > 0 else 0 + ) + result["avg_tool_calls"] = ( + total_tool_calls / total_episodes if total_episodes > 0 else 0 + ) + result["tool_calls_per_turn"] = ( + total_tool_calls / total_turns if total_turns > 0 else 0 + ) + result["avg_tool_errors"] = ( + total_tool_errors / total_episodes if total_episodes > 0 else 0 + ) + result["total_tool_errors"] = total_tool_errors + result["tool_error_rate"] = ( + total_tool_errors / total_tool_calls + if total_tool_calls > 0 + else 0 + ) + result["total_episodes"] = total_episodes + + for env_key, failures in env_init_failures.items(): + result[f"{env_key}/env_init_failed"] = failures + if total_init_failures > 0: + result["total_env_init_failed"] = total_init_failures + + return result diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/tool_call_parser.py b/skyrl-gym/skyrl_gym/envs/fleet_task/tool_call_parser.py new file mode 100644 index 0000000000..bec243a9e4 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/tool_call_parser.py @@ -0,0 +1,68 @@ +"""Tool call parser for LLM-generated tool calls. + +Parses tool calls from various tag-based formats commonly produced by LLMs: +- {"name": "...", "arguments": {...}} +- {"name": "...", "arguments": {...}} + +Handles missing closing tags (e.g., when is the stop string) +and repairs common JSON issues like missing trailing braces. +""" + +import json +import re +from typing import Any, Dict, Optional + + +def _try_parse_json(raw: str) -> Optional[Dict[str, Any]]: + """Try to parse JSON, repairing missing trailing braces if needed.""" + raw = raw.strip() + try: + parsed = json.loads(raw) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + pass + + # Repair: models often drop trailing closing braces on nested JSON. + # Try appending up to 3 closing braces. + for extra in range(1, 4): + try: + parsed = json.loads(raw + "}" * extra) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + continue + + return None + + +def parse_tool_call(action: str) -> Optional[Dict[str, Any]]: + """Parse tool call from LLM response. + + Supports tag-based formats: + - {"name": "...", "arguments": {...}} + - {"name": "...", "arguments": {...}} + + Also handles cases where the closing tag is missing (e.g., when + is used as the stop string and not included in the output). + + Returns: + Dict with "name" and "arguments" keys, or None if no tool call found. + """ + for tag in ["tool_call", "function_call"]: + # First try with closing tag + match = re.search(rf"<{tag}>(.*?)", action, re.DOTALL) + if not match: + # Try without closing tag (for when is the stop string) + match = re.search(rf"<{tag}>(.*?)(?:<\||\Z)", action, re.DOTALL) + if match: + parsed = _try_parse_json(match.group(1)) + if parsed is None: + continue + # Normalize keys + name = parsed.get("name") or parsed.get("tool") + args = parsed.get("arguments") or parsed.get("params", {}) + if name: + return {"name": name, "arguments": args} + + return None From 35d9513680f753db17c737af8e4eb3609957a37b Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 14:47:46 -0700 Subject: [PATCH 002/121] feat: add Fleet training integration with entrypoints, scripts, and configs Port Fleet-specific training infrastructure from fork to fresh SkyRL-v2: Entrypoints: - main_fleet.py: GRPO training on Fleet-hosted envs with S3 checkpoints - main_task_gen.py: Task generation training entrypoint - main_fleet_tinker.py: Tinker-based training with Fleet envs (LoRA, async) Dataset & Checkpoints: - prepare_dataset.py: Convert Fleet task JSON to SkyRL parquet format (stratified split, dedup, env capping, difficulty filtering) - s3_checkpoints.py: Async S3 upload, cross-VM resume, local cleanup - export_tasks.py: CLI to export tasks from Fleet API Training Scripts: - fleet-common-setup.sh: Shared setup (deps, OpenEnv, dataset download) - fleet-common-run.sh: Multi-node Ray cluster + training launch - fleet-35b-run.sh: Qwen3.5-35B config (TP=2, multi-node) - fleet-qwen35-extra-setup.sh: Qwen3.5 deps (transformers 5.3, flash-attn) - fleet-task-gen-run.sh: Task generation config SkyPilot YAML Configs: - openenv-fleet-grpo-qwen3_5-35b.yaml: 2-node H200 training - task-gen-grpo-qwen3_5-9b.yaml: Single-node task gen Also adds fleet_task and task_gen config to skyrl_gym_config/default.yaml. Co-Authored-By: Claude Opus 4.6 --- integrations/__init__.py | 1 + integrations/fleet/__init__.py | 15 + integrations/fleet/entrypoints/__init__.py | 1 + integrations/fleet/entrypoints/main_fleet.py | 94 ++ .../fleet/entrypoints/main_fleet_tinker.py | 939 ++++++++++++++++++ .../fleet/entrypoints/main_task_gen.py | 80 ++ integrations/fleet/export_tasks.py | 71 ++ integrations/fleet/prepare_dataset.py | 564 +++++++++++ integrations/fleet/reward_metrics.py | 253 +++++ integrations/fleet/s3_checkpoints.py | 439 ++++++++ integrations/fleet/task_gen_reward.py | 89 ++ integrations/fleet/tests/__init__.py | 0 integrations/fleet/utils.py | 118 +++ scripts/fleet-35b-run.sh | 90 ++ scripts/fleet-common-run.sh | 313 ++++++ scripts/fleet-common-setup.sh | 130 +++ scripts/fleet-qwen35-extra-setup.sh | 71 ++ scripts/fleet-task-gen-run.sh | 81 ++ .../config/skyrl_gym_config/default.yaml | 17 + tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 69 ++ tasks/task-gen-grpo-qwen3_5-9b.yaml | 59 ++ 21 files changed, 3494 insertions(+) create mode 100644 integrations/__init__.py create mode 100644 integrations/fleet/__init__.py create mode 100644 integrations/fleet/entrypoints/__init__.py create mode 100644 integrations/fleet/entrypoints/main_fleet.py create mode 100644 integrations/fleet/entrypoints/main_fleet_tinker.py create mode 100644 integrations/fleet/entrypoints/main_task_gen.py create mode 100644 integrations/fleet/export_tasks.py create mode 100644 integrations/fleet/prepare_dataset.py create mode 100644 integrations/fleet/reward_metrics.py create mode 100644 integrations/fleet/s3_checkpoints.py create mode 100644 integrations/fleet/task_gen_reward.py create mode 100644 integrations/fleet/tests/__init__.py create mode 100644 integrations/fleet/utils.py create mode 100755 scripts/fleet-35b-run.sh create mode 100755 scripts/fleet-common-run.sh create mode 100755 scripts/fleet-common-setup.sh create mode 100755 scripts/fleet-qwen35-extra-setup.sh create mode 100755 scripts/fleet-task-gen-run.sh create mode 100644 tasks/openenv-fleet-grpo-qwen3_5-35b.yaml create mode 100644 tasks/task-gen-grpo-qwen3_5-9b.yaml diff --git a/integrations/__init__.py b/integrations/__init__.py new file mode 100644 index 0000000000..2c8aa36f02 --- /dev/null +++ b/integrations/__init__.py @@ -0,0 +1 @@ +# Fleet integrations for SkyRL diff --git a/integrations/fleet/__init__.py b/integrations/fleet/__init__.py new file mode 100644 index 0000000000..88f3247b76 --- /dev/null +++ b/integrations/fleet/__init__.py @@ -0,0 +1,15 @@ +# Fleet Task Environment Integration for SkyRL +# +# This module provides a SkyRL-compatible environment wrapper for Fleet-hosted tasks. +# It uses OpenEnv's FleetTaskEnv as the abstraction layer. + +__all__ = ["FleetTaskEnv"] + + +def __getattr__(name: str): + """Lazy import to avoid import errors when dependencies are not installed.""" + if name == "FleetTaskEnv": + from skyrl_gym.envs.fleet_task.env import FleetTaskEnv + + return FleetTaskEnv + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/integrations/fleet/entrypoints/__init__.py b/integrations/fleet/entrypoints/__init__.py new file mode 100644 index 0000000000..3ba7648c80 --- /dev/null +++ b/integrations/fleet/entrypoints/__init__.py @@ -0,0 +1 @@ +# Fleet entrypoints diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py new file mode 100644 index 0000000000..5f8d9c82f3 --- /dev/null +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -0,0 +1,94 @@ +""" +Fleet Task Training Entrypoint for SkyRL. + +Registers the FleetTaskEnv and runs GRPO training on Fleet-hosted environments +with S3 checkpoint management. + +Usage: + python -m integrations.fleet.entrypoints.main_fleet \ + environment.env_class=fleet_task \ + environment.skyrl_gym.fleet_task.tasks_file=/path/to/tasks.json \ + data.train_data=./data/fleet/train.parquet \ + data.val_data=./data/fleet/validation.parquet + +Environment Variables for S3 Checkpoint Management: + AWS_ACCESS_KEY_ID: AWS access key + AWS_SECRET_ACCESS_KEY: AWS secret key + AWS_REGION: AWS region (default: us-east-1) + S3_CHECKPOINT_BUCKET: S3 bucket name (default: skyrl-checkpoints) + RESUME_RUN_NAME: Run name to resume from (downloads checkpoint from S3) +""" + +import asyncio +import logging +import os +from pathlib import Path + +import hydra +import ray +from skyrl_gym.envs import register +from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir +from skyrl.train.utils import validate_cfg +from skyrl.train.utils.utils import initialize_ray + +logger = logging.getLogger(__name__) + + +class FleetPPOExp(BasePPOExp): + """Fleet-specific PPO experiment with S3 checkpoint management.""" + + def run(self): + trainer = self._setup_trainer() + + # Download checkpoint from S3 if RESUME_RUN_NAME is set (cross-VM resume) + resume_run_name = os.environ.get("RESUME_RUN_NAME", "") + if resume_run_name: + try: + from integrations.fleet.s3_checkpoints import download_checkpoint_from_s3 + + ckpt_path = trainer.cfg.trainer.ckpt_path + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + download_checkpoint_from_s3( + ckpt_path=ckpt_path, + run_name=resume_run_name, + project_name=project_name, + model_name=model_name, + ) + except Exception as e: + logger.warning(f"Failed to download checkpoint from S3: {e}") + + # Wrap trainer for checkpoint management (cleanup + S3 upload) + try: + from integrations.fleet.s3_checkpoints import wrap_trainer_with_s3_upload + + trainer = wrap_trainer_with_s3_upload(trainer) + except Exception as e: + logger.warning(f"Failed to setup checkpoint management: {e}") + + asyncio.run(trainer.train()) + + +@ray.remote(num_cpus=1) +def skyrl_entrypoint(cfg: SkyRLTrainConfig): + """Ray remote function that registers Fleet environment and runs training.""" + register( + id="fleet_task", + entry_point="skyrl_gym.envs.fleet_task.env:FleetTaskEnv", + ) + exp = FleetPPOExp(cfg) + exp.run() + + +@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) +def main(cfg: SkyRLTrainConfig) -> None: + """Main entry point for Fleet task training.""" + validate_cfg(cfg) + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/entrypoints/main_fleet_tinker.py b/integrations/fleet/entrypoints/main_fleet_tinker.py new file mode 100644 index 0000000000..a442f42420 --- /dev/null +++ b/integrations/fleet/entrypoints/main_fleet_tinker.py @@ -0,0 +1,939 @@ +""" +Fleet Task Training with Tinker Backend. + +This entrypoint uses Tinker (hosted) for training and inference, +combined with Fleet environments via OpenEnv for rollout collection. + +Usage: + python -m integrations.fleet.entrypoints.main_fleet_tinker \ + --model-name Qwen/Qwen3-VL-30B-A3B-Instruct \ + --tasks-file /path/to/tasks.json \ + --dataset-file /path/to/train.parquet \ + --eval-dataset-file /path/to/validation.parquet + +Environment Variables: + TINKER_API_KEY: Tinker API key for authentication (required) + TINKER_API_URL: Tinker service URL (optional, SDK uses default if not set) + FLEET_API_KEY: Fleet API key for environment access + WANDB_API_KEY: Weights & Biases API key for logging + +Architecture: + 1. Load tasks from JSON file (same format as SkyRL Fleet integration) + 2. For each training step: + a. Save current model weights for sampling + b. Create SamplingClient from Tinker + c. Collect rollouts using FleetTaskEnv (OpenEnv) + Tinker inference + d. Compute GRPO advantages + e. Train using Tinker's forward_backward + optim_step + 3. Checkpoints saved via Tinker API + +Metrics (matching SkyRL): + - reward/avg_pass_at_{n}: Pass@k across all prompts + - reward/variance_per_prompt: Mean within-prompt reward variance (GRPO learning signal) + - reward/{env_key}/pass_at_{n}: Per-environment pass@k + - reward/{env_key}/variance_per_prompt: Per-environment variance (learning signal) + - eval/all/pass_at_{n}: Evaluation pass@k + - eval/{env_key}/pass_at_{n}: Per-environment eval pass@k +""" + +import asyncio +import logging +import os +import random +import time +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from typing import Any, Dict, List, Optional + +import numpy as np +from pydantic import BaseModel +from tqdm import tqdm +import tinker +import torch +import wandb +from tinker import types +from tinker.types.tensor_data import TensorData +from transformers import AutoTokenizer +from datasets import load_dataset +from torch.utils.data import DataLoader + +# Use SkyRL's FleetTaskEnv wrapper (now supports async via init_async/step_async) +from omegaconf import OmegaConf +from skyrl_gym.envs.fleet_task.env import FleetTaskEnv + +# Import SkyRL's overlong filtering for parity +from skyrl.train.generators.utils import apply_overlong_filtering + +# Import shared metrics module for consistent metric calculation with SkyRL trainer +from integrations.fleet.reward_metrics import ( + compute_pass_at_n as _compute_pass_at_n, + compute_reward_metrics, + compute_per_group_metrics, + sanitize_metric_key, +) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger(__name__) +logging.getLogger("httpx").setLevel(logging.WARNING) +logging.getLogger("mcp").setLevel(logging.WARNING) + +# Thread pool for env operations - isolates MCP connections per thread (like SkyRL) +_env_executor: ThreadPoolExecutor = None + + +def _get_env_executor(max_workers: int = 16) -> ThreadPoolExecutor: + global _env_executor + if _env_executor is None: + _env_executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="fleet-env-") + return _env_executor + + +async def _run_in_executor(func, *args): + """Run sync function in thread pool - each thread gets isolated event loop/connections.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(_get_env_executor(), func, *args) + + +class RolloutOutput(BaseModel): + """Output from a single rollout collection.""" + + prompt_ids: List[int] + response_ids: List[int] + logprobs: List[float] + loss_mask: List[int] + reward: float + task_key: str + env_key: str + turns: int + tool_calls: int + tool_errors: int = 0 # Count of tool call errors in this rollout + stop_reason: str + duration: float + # Timing breakdown for WandB + total_gen_time: float = 0.0 # Total Tinker generation time + total_step_time: float = 0.0 # Total MCP/Fleet step time + total_tokens: int = 0 # Total tokens generated + error: Optional[str] = None + + class Config: + frozen = True + + +def set_seed(seed: int): + """Set random seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def normalize_advantages(advantages: List[float]) -> List[float]: + """Normalize advantages to have mean 0 and std 1.""" + if not advantages or len(advantages) == 1: + return advantages + mean = np.mean(advantages) + std = np.std(advantages) + if std < 1e-8: + return [0.0] * len(advantages) + return [(a - mean) / (std + 1e-8) for a in advantages] + + +def compute_advantages_grpo( + rewards: List[float], + group_size: int = None, + normalize: bool = True, +) -> List[float]: + """ + GRPO (Group Relative Policy Optimization) advantage estimation. + + For each group of trajectories from the same prompt, compute advantages + as deviations from the group mean. + """ + rewards = np.array(rewards) + + if group_size is None: + group_size = len(rewards) + + n_groups = len(rewards) // group_size + advantages = [] + + for i in range(n_groups): + start_idx = i * group_size + end_idx = start_idx + group_size + group_rewards = rewards[start_idx:end_idx] + group_mean = group_rewards.mean() + group_advantages = group_rewards - group_mean + advantages.extend(group_advantages.tolist()) + + remaining = len(rewards) % group_size + if remaining > 0: + remaining_rewards = rewards[-remaining:] + remaining_mean = remaining_rewards.mean() + advantages.extend((remaining_rewards - remaining_mean).tolist()) + + if normalize: + advantages = normalize_advantages(advantages) + + return advantages + + +def compute_pass_at_n(rollouts: List[Dict[str, Any]], n_samples_per_prompt: int) -> float: + """ + Compute pass@n metric using the shared metrics module. + + For each unique prompt (task_key), if ANY of the n trajectories has reward > 0, + that counts as a "pass". + + This function is a thin wrapper around the shared compute_pass_at_n for backward + compatibility with the rollout dict format. + """ + rewards = [r.get("reward", 0.0) for r in rollouts] + uids = [r.get("task_key", "unknown") for r in rollouts] + return _compute_pass_at_n(rewards, uids) + + +def compute_per_env_metrics(rollouts: List[Dict[str, Any]], n_samples_per_prompt: int) -> Dict[str, float]: + """ + Compute per-environment metrics using the shared metrics module. + + This function is a thin wrapper around the shared compute_per_group_metrics for + backward compatibility with the rollout dict format. + """ + rewards = [r.get("reward", 0.0) for r in rollouts] + uids = [r.get("task_key", "unknown") for r in rollouts] + env_keys = [r.get("env_key", "unknown") for r in rollouts] + + return compute_per_group_metrics( + rewards=rewards, + uids=uids, + groups=env_keys, + n_samples_per_prompt=n_samples_per_prompt, + prefix="reward", + ) + + +def compute_rollout_metrics( + rollouts: List[Dict[str, Any]], + valid_rollouts: List[Dict[str, Any]], + rewards: List[float], + advantages: List[float], + n_samples_per_prompt: int, +) -> Dict[str, Any]: + """ + Compute all rollout metrics using the shared metrics module. + + Args: + rollouts: All rollouts (including failed ones) + valid_rollouts: Only valid rollouts + rewards: Rewards for valid rollouts + advantages: GRPO advantages for valid rollouts + n_samples_per_prompt: Number of samples per prompt + + Returns: + Dict of metrics for wandb logging + """ + metrics = {} + + # Extract data for shared module + uids = [r.get("task_key", "unknown") for r in valid_rollouts] + env_keys = [r.get("env_key", "unknown") for r in valid_rollouts] + + # Core reward metrics using shared module + core_metrics = compute_reward_metrics(rewards, uids, n_samples_per_prompt) + metrics[f"reward/avg_pass_at_{n_samples_per_prompt}"] = core_metrics[f"pass_at_{n_samples_per_prompt}"] + metrics["reward/variance_per_prompt"] = core_metrics["variance_per_prompt"] + metrics["reward/mean_positive_reward"] = core_metrics["mean_positive_reward"] + + # Advantage metrics (Tinker-specific) + metrics["advantage/mean"] = np.mean(advantages) + metrics["advantage/std"] = np.std(advantages) + metrics["rollouts/valid"] = len(valid_rollouts) + metrics["rollouts/total"] = len(rollouts) + + # Per-environment reward metrics using shared module + per_env_metrics = compute_per_group_metrics( + rewards=rewards, + uids=uids, + groups=env_keys, + n_samples_per_prompt=n_samples_per_prompt, + prefix="reward", + ) + metrics.update(per_env_metrics) + + # Per-environment rollout stats (turns, tool_calls, tool_errors, duration) - Tinker-specific + rollout_stats = defaultdict(list) + for r in valid_rollouts: + env_key = sanitize_metric_key(r.get("env_key", "unknown")) + rollout_stats[f"rollout/{env_key}/turns"].append(r.get("turns", 0)) + rollout_stats[f"rollout/{env_key}/tool_calls"].append(r.get("tool_calls", 0)) + rollout_stats[f"rollout/{env_key}/tool_errors"].append(r.get("tool_errors", 0)) + rollout_stats[f"rollout/{env_key}/duration"].append(r.get("duration", 0.0)) + + for key, values in rollout_stats.items(): + metrics[key] = np.mean(values) + + # Compute tool error rate per environment + env_keys_seen = set() + for r in valid_rollouts: + env_keys_seen.add(sanitize_metric_key(r.get("env_key", "unknown"))) + for env_key in env_keys_seen: + total_calls = sum(rollout_stats[f"rollout/{env_key}/tool_calls"]) + total_errors = sum(rollout_stats[f"rollout/{env_key}/tool_errors"]) + if total_calls > 0: + metrics[f"rollout/{env_key}/tool_error_rate"] = total_errors / total_calls + else: + metrics[f"rollout/{env_key}/tool_error_rate"] = 0.0 + + # Overall rollout duration stats + durations = [r.get("duration", 0.0) for r in valid_rollouts] + metrics["rollout/avg_duration"] = np.mean(durations) + metrics["rollout/max_duration"] = np.max(durations) + metrics["rollout/min_duration"] = np.min(durations) + + return metrics + + +def prepare_training_data( + rollouts: List[Dict[str, Any]], + advantages: List[float], + tokenizer: AutoTokenizer, + max_sequence_length: int, +) -> tuple: + """ + Prepare training data from rollouts (matching SkyRL's generate_batched pattern). + + Applies: + 1. DAPO overlong filtering (zero loss mask if response doesn't end with EOS) + 2. Sequence truncation for max_sequence_length + 3. Builds Tinker Datum objects for training + + Args: + rollouts: List of rollout dicts with prompt_ids, response_ids, logprobs, loss_mask + advantages: GRPO advantages for each rollout + tokenizer: Tokenizer for EOS token ID + max_sequence_length: Maximum sequence length for training + + Returns: + Tuple of (training_datums, truncated_count) + """ + # Apply DAPO overlong filtering (zero out loss mask if response doesn't end with EOS) + all_response_ids = [r.response_ids for r in rollouts] + all_loss_masks = [r.loss_mask for r in rollouts] + filtered_loss_masks = apply_overlong_filtering(all_loss_masks, all_response_ids, tokenizer.eos_token_id) + + training_datums = [] + truncated_count = 0 + + for idx, rollout in enumerate(rollouts): + prompt_ids = rollout.prompt_ids + response_ids = rollout.response_ids + logprobs = rollout.logprobs + loss_mask_data = filtered_loss_masks[idx] + + full_sequence = prompt_ids + response_ids + prompt_len = len(prompt_ids) + + # Truncate sequences exceeding model's max length for Tinker API + if len(full_sequence) > max_sequence_length: + truncated_count += 1 + full_sequence = full_sequence[:max_sequence_length] + response_len = len(full_sequence) - prompt_len + response_ids = response_ids[:response_len] + logprobs = logprobs[:response_len] if logprobs else [] + loss_mask_data = loss_mask_data[:response_len] + + # Target tokens (shifted by 1) + target_tokens = full_sequence[1:] + + # Logprobs (0 for prompt, actual for response) + full_logprobs = [0.0] * prompt_len + logprobs + full_logprobs = full_logprobs[1:] + + # Loss mask (0 for prompt, actual for response) + full_mask = [0] * prompt_len + loss_mask_data + full_mask = full_mask[1:] + + # Advantages (apply only where loss mask is 1) + advantage_value = advantages[idx] + full_advantages = torch.zeros(len(full_sequence)) + for i in range(prompt_len, len(full_sequence)): + if i - 1 < len(full_mask) and full_mask[i - 1] > 0: + full_advantages[i] = advantage_value + full_advantages = full_advantages[1:] + + datum = types.Datum( + model_input=types.ModelInput.from_ints(tokens=full_sequence[:-1]), + loss_fn_inputs={ + "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)), + "logprobs": TensorData.from_torch(torch.tensor(full_logprobs)), + "advantages": TensorData.from_torch(full_advantages), + }, + ) + training_datums.append(datum) + + return training_datums, truncated_count + + +def tokenize_chat(tokenizer: AutoTokenizer, chat_history: List[Dict], add_generation_prompt: bool = True) -> List[int]: + """ + Tokenize chat history and ensure we get a plain list of token IDs. + + apply_chat_template can return different types depending on the tokenizer: + - List[int] for some tokenizers + - BatchEncoding dict with 'input_ids' key for others + + Tinker's ModelInput.from_ints() requires a plain list of integers. + """ + result = tokenizer.apply_chat_template(chat_history, add_generation_prompt=add_generation_prompt, tokenize=True) + # Handle BatchEncoding (dict-like) vs plain list + if hasattr(result, "input_ids"): + return list(result.input_ids) + elif isinstance(result, dict) and "input_ids" in result: + return list(result["input_ids"]) + else: + return list(result) + + +async def collect_fleet_rollout( + task_config: Dict[str, Any], + tasks_file: str, + sampling_client: tinker.SamplingClient, + tokenizer: AutoTokenizer, + max_turns: int = 50, + max_generate_length: int = 2048, + max_input_length: int = 30720, + temperature: float = 1.0, +) -> Dict[str, Any]: + """ + Collect a single trajectory using Fleet environment and Tinker inference. + + Uses SkyRL's FleetTaskEnv wrapper with async methods for environment interaction. + + Args: + max_generate_length: Max tokens per generation step. + max_input_length: Max context length before ending rollout (matching SkyRL). + """ + rollout_start = time.time() + + task_key = task_config.get("task_key") or task_config.get("key") + + # Create SkyRL FleetTaskEnv wrapper + # TTL of 2 hours - some rollouts with many turns can take 30+ minutes + env_config = OmegaConf.create({"tasks_file": tasks_file, "ttl_seconds": 7200}) + extras = {"task_key": task_key, "max_turns": max_turns} + + env = FleetTaskEnv(env_config=env_config, extras=extras) + + try: + # Initialize environment in thread pool - isolates MCP connections + chat_history, metadata = await _run_in_executor(env.init, []) + env_key = metadata.get("env_key", "unknown") + + # Tokenize initial prompt + prompt_ids = tokenize_chat(tokenizer, chat_history, add_generation_prompt=True) + + all_response_ids = [] + all_logprobs = [] + loss_mask = [] + done = False + total_reward = 0.0 + stop_reason = "stop" + # Timing accumulators for WandB + total_gen_time = 0.0 + total_step_time = 0.0 + total_tokens = 0 + + while not done and env.turns < max_turns: + turn_num = env.turns + 1 # 1-indexed for logging + + # Prepare input for Tinker (use env's chat_history) + input_ids = tokenize_chat(tokenizer, env.chat_history, add_generation_prompt=True) + + # Check context length limit (matching SkyRL's skyrl_gym_generator.py:274) + if len(input_ids) > max_input_length: + logger.info( + f"[{task_key}] Turn {turn_num}: context length ({len(input_ids)}) exceeds max ({max_input_length}), ending" + ) + stop_reason = "length" + break + + # Generate with Tinker + gen_start = time.time() + sampling_params = types.SamplingParams( + max_tokens=max_generate_length, + temperature=temperature, + top_p=1.0, + ) + + # Use async sampling to avoid blocking the event loop + result = await sampling_client.sample_async( + prompt=types.ModelInput.from_ints(tokens=input_ids), + num_samples=1, + sampling_params=sampling_params, + ) + gen_time = time.time() - gen_start + total_gen_time += gen_time + + if not result.sequences or len(result.sequences) == 0: + logger.warning(f"[{task_key}] Turn {turn_num}: no sequences returned from Tinker") + break + + sequence = result.sequences[0] + output_ids = sequence.tokens + output_logprobs = sequence.logprobs if sequence.logprobs else [] + + # Decode output + output_text = tokenizer.decode(output_ids, skip_special_tokens=True) + + # Collect trajectory data (assistant response tokens - trainable) + all_response_ids.extend(output_ids) + if output_logprobs: + all_logprobs.extend(output_logprobs) + else: + all_logprobs.extend([0.0] * len(output_ids)) + loss_mask.extend([1] * len(output_ids)) + + # Step environment in thread pool - isolates MCP connections + step_start = time.time() + step_output = await _run_in_executor(env.step, output_text) + step_time = time.time() - step_start + total_step_time += step_time + total_tokens += len(output_ids) + + # Get observation content for tokenization (masked out for loss) + # Note: BaseTextEnvStepOutput is a TypedDict, use dict access + if step_output["observations"]: + obs_content = step_output["observations"][0].get("content", "") + obs_ids = tokenizer.encode(obs_content, add_special_tokens=False) + all_response_ids.extend(obs_ids) + all_logprobs.extend([0.0] * len(obs_ids)) + loss_mask.extend([0] * len(obs_ids)) + + total_reward = step_output["reward"] + done = step_output["done"] + + return RolloutOutput( + prompt_ids=prompt_ids, + response_ids=all_response_ids, + logprobs=all_logprobs, + loss_mask=loss_mask, + reward=total_reward, + task_key=task_key, + env_key=env_key, + turns=env.turns, + tool_calls=env.tool_calls, + tool_errors=env.tool_errors, + stop_reason=stop_reason, + duration=time.time() - rollout_start, + total_gen_time=total_gen_time, + total_step_time=total_step_time, + total_tokens=total_tokens, + ) + + finally: + env.close() + + +async def collect_batch_rollouts( + batch: List[Dict[str, Any]], + tasks_file: str, + sampling_client: tinker.SamplingClient, + tokenizer: AutoTokenizer, + max_turns: int = 50, + max_generate_length: int = 2048, + max_input_length: int = 30720, + n_samples_per_prompt: int = 1, + max_concurrent: int = 8, +) -> List[Dict[str, Any]]: + """Collect rollouts for a batch of tasks with limited concurrency. + + Args: + max_concurrent: Maximum number of concurrent Fleet environment connections. + Now safe to increase since ThreadPoolExecutor isolates connections. + """ + # Semaphore to limit concurrent Fleet environment connections + semaphore = asyncio.Semaphore(max_concurrent) + + async def collect_single_rollout(task_config: Dict[str, Any], index: int) -> tuple: + """Wrapper to collect a single rollout with error handling and concurrency limit.""" + async with semaphore: + rollout_start = time.time() + try: + rollout = await collect_fleet_rollout( + task_config=task_config, + tasks_file=tasks_file, + sampling_client=sampling_client, + tokenizer=tokenizer, + max_turns=max_turns, + max_generate_length=max_generate_length, + max_input_length=max_input_length, + ) + return index, rollout + except Exception as e: + logger.error(f"Failed to collect rollout for {task_config.get('task_key')}: {e}") + return index, RolloutOutput( + prompt_ids=[], + response_ids=[], + logprobs=[], + loss_mask=[], + reward=0.0, + task_key=task_config.get("task_key", "unknown"), + env_key=task_config.get("env_key", "unknown"), + turns=0, + tool_calls=0, + tool_errors=0, + stop_reason="error", + error=str(e), + duration=time.time() - rollout_start, + ) + + # Create all rollout tasks (batch_size * n_samples_per_prompt) + tasks = [] + index = 0 + for task_config in batch: + for _ in range(n_samples_per_prompt): + tasks.append(collect_single_rollout(task_config, index)) + index += 1 + + total = len(tasks) + logger.info(f" Collecting {total} rollouts (max {max_concurrent} concurrent)...") + rollouts = [None] * total + completed = 0 + last_logged = 0 + log_interval = max(1, total // 4) # Log at ~25%, 50%, 75%, 100% + + # Run rollouts with limited concurrency via semaphore + for coro in asyncio.as_completed(tasks): + idx, rollout = await coro + rollouts[idx] = rollout + completed += 1 + + # Log progress at intervals + if completed - last_logged >= log_interval or completed == total: + logger.info(f" Progress: {completed}/{total} rollouts completed") + last_logged = completed + + return rollouts + + +def collate_fn(batch): + """Return batch as-is without tensor collation.""" + return batch + + +async def main( + model_name: str = "Qwen/Qwen3-VL-30B-A3B-Instruct", + tasks_file: str = None, + dataset_file: str = None, + eval_dataset_file: str = None, + batch_size: int = 8, + eval_batch_size: int = 32, + learning_rate: float = 4e-5, + lora_rank: int = 16, + max_steps: int = 200, + max_turns: int = 50, + max_generate_length: int = 2048, + max_input_length: int = 30720, + max_sequence_length: int = 32768, + n_samples_per_prompt: int = 4, + eval_every: int = 20, + seed: int = 42, + wandb_project: str = "fleet-tinker-grpo", + wandb_name: str = None, +): + """ + Main training loop using Tinker for training/inference and Fleet for environments. + """ + set_seed(seed) + + # Setup WandB run name + if wandb_name is None: + wandb_name = f"{model_name.split('/')[-1]}_{datetime.now().strftime('%m%d_%H%M')}" + + # Initialize WandB + wandb.init( + project=wandb_project, + name=wandb_name, + config={ + "model_name": model_name, + "batch_size": batch_size, + "learning_rate": learning_rate, + "lora_rank": lora_rank, + "max_turns": max_turns, + "max_generate_length": max_generate_length, + "max_input_length": max_input_length, + "max_sequence_length": max_sequence_length, + "n_samples_per_prompt": n_samples_per_prompt, + }, + ) + + # Load datasets + train_dataset = load_dataset("parquet", data_files=dataset_file)["train"] + eval_dataset = load_dataset("parquet", data_files=eval_dataset_file)["train"] if eval_dataset_file else None + + logger.info(f"Loaded {len(train_dataset)} training samples") + if eval_dataset: + logger.info(f"Loaded {len(eval_dataset)} eval samples") + + # Setup Tinker + tinker_url = os.environ.get("TINKER_API_URL") + tinker_api_key = os.environ.get("TINKER_API_KEY") + + service_client_kwargs = {} + if tinker_url: + service_client_kwargs["base_url"] = tinker_url + if tinker_api_key: + service_client_kwargs["api_key"] = tinker_api_key + + service_client = tinker.ServiceClient(**service_client_kwargs) + training_client = await service_client.create_lora_training_client_async(base_model=model_name, rank=lora_rank) + + adam_params = types.AdamParams(learning_rate=learning_rate, beta1=0.9, beta2=0.95, eps=1e-8) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + # Create dataloader + def create_dataloader(epoch: int): + return DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=collate_fn, + generator=torch.Generator().manual_seed(seed + epoch), + ) + + steps_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size + current_epoch = 0 + train_dataloader = create_dataloader(current_epoch) + train_iterator = iter(train_dataloader) + + # Training loop + pbar = tqdm(range(max_steps), desc="Training", unit="step") + for step in pbar: + step_start = time.time() + metrics = {"step": step, "epoch": step // steps_per_epoch} + + # Get sampler weights for rollout inference + sampling_path = training_client.save_weights_for_sampler(name=f"step_{step:06d}").result().path + sampling_client = service_client.create_sampling_client(model_path=sampling_path) + + # Get batch + try: + batch = next(train_iterator) + except StopIteration: + current_epoch += 1 + train_dataloader = create_dataloader(current_epoch) + train_iterator = iter(train_dataloader) + batch = next(train_iterator) + + # Collect rollouts + logger.info(f"Step {step}: Collecting rollouts for {len(batch)} tasks...") + rollout_start = time.time() + + rollouts = await collect_batch_rollouts( + batch=batch, + tasks_file=tasks_file, + sampling_client=sampling_client, + tokenizer=tokenizer, + max_turns=max_turns, + max_generate_length=max_generate_length, + max_input_length=max_input_length, + n_samples_per_prompt=n_samples_per_prompt, + ) + + metrics["time/rollout"] = time.time() - rollout_start + + # Filter valid rollouts and log invalid ones + # Note: rollouts are RolloutOutput Pydantic objects - use attribute access + valid_rollouts = [] + invalid_rollouts = [] + for r in rollouts: + if r.response_ids and not r.error: + valid_rollouts.append(r) + else: + invalid_rollouts.append(r) + + if invalid_rollouts: + for r in invalid_rollouts: + task_key = r.task_key + error = r.error or "no response_ids" + stop_reason = r.stop_reason + logger.warning(f"Step {step}: Invalid rollout for {task_key}: {error} (stop_reason={stop_reason})") + metrics["rollouts/invalid"] = len(invalid_rollouts) + + if not valid_rollouts: + logger.warning(f"Step {step}: No valid rollouts, skipping") + continue + + # Compute GRPO advantages + rewards = [r.reward for r in valid_rollouts] + advantages = compute_advantages_grpo(rewards, group_size=n_samples_per_prompt, normalize=True) + + # Compute all rollout metrics (convert to dicts for metrics functions) + rollout_metrics = compute_rollout_metrics( + rollouts=[r.model_dump() for r in rollouts], + valid_rollouts=[r.model_dump() for r in valid_rollouts], + rewards=rewards, + advantages=advantages, + n_samples_per_prompt=n_samples_per_prompt, + ) + metrics.update(rollout_metrics) + + # Compute timing metrics from valid rollouts + gen_times = [r.total_gen_time for r in valid_rollouts] + step_times = [r.total_step_time for r in valid_rollouts] + tokens = [r.total_tokens for r in valid_rollouts] + durations = [r.duration for r in valid_rollouts] + + metrics["time/gen_total"] = sum(gen_times) + metrics["time/gen_mean"] = np.mean(gen_times) + metrics["time/step_total"] = sum(step_times) + metrics["time/step_mean"] = np.mean(step_times) + metrics["time/gen_pct"] = 100 * sum(gen_times) / sum(durations) if sum(durations) > 0 else 0 + metrics["time/step_pct"] = 100 * sum(step_times) / sum(durations) if sum(durations) > 0 else 0 + metrics["throughput/tokens_total"] = sum(tokens) + metrics["throughput/tokens_per_sec_gen"] = sum(tokens) / sum(gen_times) if sum(gen_times) > 0 else 0 + metrics["throughput/tokens_per_sec_effective"] = sum(tokens) / sum(durations) if sum(durations) > 0 else 0 + + # Prepare training data (DAPO filtering + truncation + datum creation) + training_datums, truncated_count = prepare_training_data( + rollouts=valid_rollouts, + advantages=advantages, + tokenizer=tokenizer, + max_sequence_length=max_sequence_length, + ) + + metrics["rollouts/truncated_overlong"] = truncated_count + if truncated_count > 0: + logger.info(f"Step {step}: Truncated {truncated_count} overlong sequences") + + if not training_datums: + logger.warning(f"Step {step}: No valid training sequences after filtering, skipping") + continue + + # Training step + logger.info(f"Step {step}: Training on {len(training_datums)} sequences...") + train_start = time.time() + + fwd_bwd_future = training_client.forward_backward(training_datums, loss_fn="ppo") + optim_step_future = training_client.optim_step(adam_params) + + fwd_bwd_future.result() + optim_step_future.result() + + metrics["time/train"] = time.time() - train_start + metrics["time/total"] = time.time() - step_start + + # Log metrics (commit=True forces immediate sync) + wandb.log(metrics, step=step, commit=True) + pbar.set_postfix( + { + f"pass@{n_samples_per_prompt}": f"{metrics[f'reward/avg_pass_at_{n_samples_per_prompt}']:.3f}", + "reward": f"{metrics['reward/avg_raw_reward']:.3f}", + "time": f"{metrics['time/total']:.1f}s", + } + ) + + # Evaluation + if eval_every > 0 and eval_dataset and step % eval_every == 0: + logger.info(f"Step {step}: Running evaluation...") + eval_dataloader = DataLoader(eval_dataset, batch_size=eval_batch_size, shuffle=False, collate_fn=collate_fn) + + all_eval_rollouts = [] + for eval_batch in eval_dataloader: + eval_rollouts = await collect_batch_rollouts( + batch=eval_batch, + tasks_file=tasks_file, + sampling_client=sampling_client, + tokenizer=tokenizer, + max_turns=max_turns, + max_generate_length=max_generate_length, + max_input_length=max_input_length, + n_samples_per_prompt=1, + ) + all_eval_rollouts.extend([r for r in eval_rollouts if not r.error]) + + if all_eval_rollouts: + eval_rewards = [r.reward for r in all_eval_rollouts] + # Convert to dicts for metrics functions + eval_rollouts_dicts = [r.model_dump() for r in all_eval_rollouts] + eval_pass_at_1 = compute_pass_at_n(eval_rollouts_dicts, 1) + eval_per_env = compute_per_env_metrics(eval_rollouts_dicts, 1) + + eval_metrics = { + "eval/all/pass_at_1": eval_pass_at_1, + "eval/all/mean_positive_reward": ( + np.mean([r for r in eval_rewards if r > 0]) if any(r > 0 for r in eval_rewards) else 0.0 + ), + "eval/num_samples": len(all_eval_rollouts), + } + # Add per-env eval metrics (rename from reward/ to eval/) + for key, value in eval_per_env.items(): + eval_key = key.replace("reward/", "eval/") + eval_metrics[eval_key] = value + + wandb.log(eval_metrics, step=step, commit=True) + logger.info(f"Step {step}: eval pass@1={eval_pass_at_1:.3f}, num_samples={len(all_eval_rollouts)}") + + wandb.finish() + logger.info("Training completed!") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Fleet Task Training with Tinker") + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-VL-30B-A3B-Instruct") + parser.add_argument("--tasks-file", type=str, required=True, help="Path to tasks JSON file") + parser.add_argument("--dataset-file", type=str, required=True, help="Path to training parquet") + parser.add_argument("--eval-dataset-file", type=str, default=None, help="Path to eval parquet") + parser.add_argument("--batch-size", type=int, default=8) + parser.add_argument("--eval-batch-size", type=int, default=32) + parser.add_argument("--learning-rate", type=float, default=4e-5) + parser.add_argument("--lora-rank", type=int, default=16) + parser.add_argument("--max-steps", type=int, default=200) + parser.add_argument("--max-turns", type=int, default=50) + parser.add_argument("--max-generate-length", type=int, default=2048, help="Max tokens per generation") + parser.add_argument("--max-input-length", type=int, default=30720, help="Max context length before ending rollout") + parser.add_argument("--max-sequence-length", type=int, default=32768, help="Max sequence length for training") + parser.add_argument("--n-samples-per-prompt", type=int, default=4) + parser.add_argument("--eval-every", type=int, default=20) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--wandb-project", type=str, default="fleet-tinker-grpo") + parser.add_argument("--wandb-name", type=str, default=None) + parser.add_argument( + "--track-extra-gradient-metrics", + type=bool, + default=False, + help="Track additional gradient metrics (for parity with SkyRL config)", + ) + + args = parser.parse_args() + + asyncio.run( + main( + model_name=args.model_name, + tasks_file=args.tasks_file, + dataset_file=args.dataset_file, + eval_dataset_file=args.eval_dataset_file, + batch_size=args.batch_size, + eval_batch_size=args.eval_batch_size, + learning_rate=args.learning_rate, + lora_rank=args.lora_rank, + max_steps=args.max_steps, + max_turns=args.max_turns, + max_generate_length=args.max_generate_length, + max_input_length=args.max_input_length, + max_sequence_length=args.max_sequence_length, + n_samples_per_prompt=args.n_samples_per_prompt, + eval_every=args.eval_every, + seed=args.seed, + wandb_project=args.wandb_project, + wandb_name=args.wandb_name, + ) + ) diff --git a/integrations/fleet/entrypoints/main_task_gen.py b/integrations/fleet/entrypoints/main_task_gen.py new file mode 100644 index 0000000000..3731bc2327 --- /dev/null +++ b/integrations/fleet/entrypoints/main_task_gen.py @@ -0,0 +1,80 @@ +""" +Task Generation Training Entrypoint for SkyRL. + +Registers the TaskGenEnv and runs GRPO training for task generation +with S3 checkpoint management. + +Usage: + python -m integrations.fleet.entrypoints.main_task_gen \ + environment.env_class=task_gen \ + data.train_data=./data/task_gen/train.parquet \ + data.val_data=./data/task_gen/validation.parquet +""" + +import asyncio +import logging +import os +from pathlib import Path + +import hydra +import ray +from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir +from skyrl.train.utils import validate_cfg +from skyrl.train.utils.utils import initialize_ray + +logger = logging.getLogger(__name__) + + +class FleetPPOExp(BasePPOExp): + """Fleet-specific PPO experiment with S3 checkpoint management.""" + + def run(self): + trainer = self._setup_trainer() + + resume_run_name = os.environ.get("RESUME_RUN_NAME", "") + if resume_run_name: + try: + from integrations.fleet.s3_checkpoints import download_checkpoint_from_s3 + + ckpt_path = trainer.cfg.trainer.ckpt_path + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + download_checkpoint_from_s3( + ckpt_path=ckpt_path, + run_name=resume_run_name, + project_name=project_name, + model_name=model_name, + ) + except Exception as e: + logger.warning(f"Failed to download checkpoint from S3: {e}") + + try: + from integrations.fleet.s3_checkpoints import wrap_trainer_with_s3_upload + + trainer = wrap_trainer_with_s3_upload(trainer) + except Exception as e: + logger.warning(f"Failed to setup checkpoint management: {e}") + + asyncio.run(trainer.train()) + + +@ray.remote(num_cpus=1) +def skyrl_entrypoint(cfg: SkyRLTrainConfig): + """Ray remote function that registers TaskGenEnv and runs training.""" + # task_gen env is registered in skyrl_gym.envs.__init__ (after PR 3) + exp = FleetPPOExp(cfg) + exp.run() + + +@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) +def main(cfg: SkyRLTrainConfig) -> None: + """Main entry point for task generation training.""" + validate_cfg(cfg) + initialize_ray(cfg) + ray.get(skyrl_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/export_tasks.py b/integrations/fleet/export_tasks.py new file mode 100644 index 0000000000..7b1ed956fc --- /dev/null +++ b/integrations/fleet/export_tasks.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""Export tasks from Fleet API to JSON file. + +Usage: + python -m integrations.fleet.export_tasks --output ~/data/fleet/tasks.json + python -m integrations.fleet.export_tasks --output ~/data/fleet/tasks.json --env-key github +""" + +import argparse +import json +import os +import sys + + +def export_tasks(output_file: str, env_key: str | None = None, modality: str = "tool_use"): + """Export tasks from Fleet API to JSON file.""" + try: + from fleet import Fleet + except ImportError: + print("Fleet SDK not available. Install with: pip install fleet-python") + sys.exit(1) + + api_key = os.environ.get("FLEET_API_KEY") + if not api_key: + print("ERROR: FLEET_API_KEY environment variable not set") + sys.exit(1) + + fleet = Fleet(api_key=api_key) + + print(f"Loading tasks from Fleet API (env_key={env_key})...") + tasks = fleet.load_tasks(env_key=env_key) + print(f"Loaded {len(tasks)} tasks") + + # Convert to JSON format + task_dicts = [] + for task in tasks: + task_dicts.append( + { + "key": task.key, + "prompt": task.prompt, + "env_id": task.env_id, + "version": task.version, + "data_id": task.data_id, + "data_version": task.data_version, + "verifier_func": task.verifier_func, + "task_modality": modality, + } + ) + + # Ensure output directory exists + os.makedirs(os.path.dirname(os.path.expanduser(output_file)), exist_ok=True) + + output_path = os.path.expanduser(output_file) + with open(output_path, "w") as f: + json.dump(task_dicts, f, indent=2) + + print(f"Exported {len(task_dicts)} tasks to {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Export Fleet tasks to JSON") + parser.add_argument("--output", "-o", required=True, help="Output JSON file path") + parser.add_argument("--env-key", default=None, help="Filter by environment key") + parser.add_argument("--modality", default="tool_use", help="Task modality") + args = parser.parse_args() + + export_tasks(args.output, args.env_key, args.modality) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/prepare_dataset.py b/integrations/fleet/prepare_dataset.py new file mode 100644 index 0000000000..cdc465b8d7 --- /dev/null +++ b/integrations/fleet/prepare_dataset.py @@ -0,0 +1,564 @@ +""" +Prepare Fleet tasks for SkyRL training. + +Converts Fleet task JSON files to SkyRL parquet dataset format. + +Usage: + python -m integrations.fleet.prepare_dataset \ + --tasks-json /path/to/all_tool_use.json \ + --output-dir ./data/fleet \ + --modality tool_use + +Split Strategy: + - Stratified by environment (each env maintains train/eval ratio) + - Hash-based deterministic assignment (same task always goes to same split) + - 20% eval ratio, capped at 20 samples per env (MAX_EVAL_SAMPLES) + - Minimum 5 eval samples per env (otherwise all go to train) + - Held-out eval envs: instacart (computer_use only) + +v0.3.2 Changes: + - Increased eval_ratio from 10% to 20% to include carlisle/outlook in eval + - Result: 11 envs in eval (was 9), ~183 eval samples (was ~146) + +v0.3.1 Changes: + - Added MAX_ENV_TRAIN_RATIO=0.20 to prevent any single env from dominating + - Hash-based deterministic sampling for reproducibility + +v0.3.0 Changes: + - Increased eval_ratio from 2% to 10% + - Added MAX_EVAL_SAMPLES=30 cap per environment + - MIN_EVAL_SAMPLES stays at 5 + - Result: ticketmaster now gets ~22 eval samples for trace analysis +""" + +import argparse +import hashlib +import json +import os +from collections import defaultdict +from typing import Any, Dict, List, Optional + +from datasets import Dataset + +# Held-out environments for eval only (not used in train) +HELD_OUT_ENVS = { + "tool_use": [], # v0.3: all envs split normally (outlook now included in train) + "computer_use": [], +} + +# Excluded environments (removed from both train and eval) +# v0.3.6: google-maps excluded due to broken MCP server (502 errors, "database is locked") +# v0.4.0: dropbox excluded due to broken env (instance creation timeouts) +EXCLUDED_ENVS = { + "tool_use": ["dropbox"], + "computer_use": ["dropbox"], +} + +# Tasks excluded due to missing CURRENT_DATE in env_variables (v0.4.0) +# These tasks have partial dates (e.g., "January 30th" without year) but their +# tool calls require mm/dd/yy format. Without CURRENT_DATE, the model cannot +# compute the correct year, causing date validation failures. +# See: https://github.com/fleet-ai/SkyRL/pull/246 +TASKS_MISSING_CURRENT_DATE = { + "task_a44hx6crecg4_1769052238469_i7dxxtjvq", # zillow - February 1st + "task_a7rlslof7gdy_1768337837679_8be6pguu3", # zillow - March 11th + "task_axtmgwocana_1768544478249_k2ozcylyf", # zillow - January 21st + "task_b1fxgn0k3yms_1768542773490_ddbhj5bai", # zillow - January 30th + "task_b4v77hb3owof_1768546181946_efsedxv9g", # zillow - February 14th + "task_b5zt6ipf0nbl_1768346335430_i23gknp4t", # zillow - January 15th + "task_bafrpi5qgyzh_1768546181946_2cebmq91r", # zillow - February 14th + "task_bdmnfipwxlqv_1769052238469_4nglwjqfm", # zillow - February 1st + "task_bxqzfjc2dbte_1768337837679_2qvnm9rq7", # zillow - March 11th + "task_c3jwlxmfvbop_1768544478249_efo6hxylr", # zillow - January 21st + "task_c7o0c7ehhv9t_1768542773490_2t9w2l1z5", # zillow - January 30th + "task_ceqj4h9t0ygi_1768346335430_8j1w8w5xp", # zillow - January 15th + "task_cgpxfxp78bvp_1768346335430_6v4n8wlt8", # zillow - January 15th + "task_cgsz56tqjlv6_1768346335430_hqgsjy4wt", # zillow - January 15th + "task_dpv4bpdpz6db_1768542773490_f3g6w8e8g", # zillow - January 30th + "task_f7lgb6fxfwln_1768337837679_d1dxk6ahv", # zillow - March 11th + "task_fl1rq3d2wbj9_1768337837679_d2x4k8p93", # zillow - March 11th + "task_fn1k5mvjx6r1_1768544478249_1nfmnp6r2", # zillow - January 21st + "task_fnh5f0x7hv6w_1768544478249_8wptm6zqp", # zillow - January 21st + "task_g2dwb1rfx69c_1769052238469_bc1y9h9d7", # zillow - February 1st + "task_g3wpj1mcl0lf_1768546181946_59vtqn9fw", # zillow - February 14th +} + +# Minimum number of samples required to create an eval split for an env +MIN_EVAL_SAMPLES = 5 + +# Maximum number of eval samples per environment (v0.3.1: reduced from 30 to 20) +# Ensures small envs get eval traces without blowing up eval set size +MAX_EVAL_SAMPLES = 20 + +# Maximum fraction of training data any single environment can have (v0.3.1) +# Prevents dominant environments from skewing training +MAX_ENV_TRAIN_RATIO = 0.20 + +# Maximum total eval prompts across all environments (v0.3.2) +# With eval_n_samples_per_prompt=3 and 30s per trajectory: +# 96 prompts × 3 samples = 288 trajectories (~8 tasks/env × 12 envs) +MAX_EVAL_PROMPTS = 96 + + +def load_tasks_from_json(json_path: str) -> List[Dict[str, Any]]: + """Load tasks from JSON file (Fleet export format).""" + with open(json_path, "r") as f: + data = json.load(f) + + # Handle both formats: array or {"tasks": [...]} + if isinstance(data, list): + return data + elif isinstance(data, dict) and "tasks" in data: + return data["tasks"] + else: + raise ValueError("Invalid JSON format: expected array or object with 'tasks' key") + + +def hash_to_split(task_key: str, eval_ratio: float = 0.10) -> str: + """Deterministically assign task to train or eval based on hash. + + Uses MD5 hash of task_key to get a deterministic float in [0, 1). + This ensures the same task always goes to the same split. + """ + hash_bytes = hashlib.md5(task_key.encode()).digest() + hash_int = int.from_bytes(hash_bytes[:8], byteorder="big") + hash_float = hash_int / (2**64) + return "eval" if hash_float < eval_ratio else "train" + + +def hash_to_float(task_key: str) -> float: + """Convert task_key to deterministic float in [0, 1) for sampling.""" + hash_bytes = hashlib.md5(task_key.encode()).digest() + hash_int = int.from_bytes(hash_bytes[:8], byteorder="big") + return hash_int / (2**64) + + +def cap_training_distribution( + train_records: List[Dict[str, Any]], + max_env_ratio: float, +) -> tuple[List[Dict[str, Any]], Dict[str, Dict[str, int]]]: + """Cap each environment's contribution to training data. + + Uses hash-based deterministic sampling so the same tasks are always selected. + + Args: + train_records: List of training records with 'data_source' (env_key) and 'task_key' + max_env_ratio: Maximum fraction any single env can contribute (e.g., 0.20 = 20%) + + Returns: + Tuple of (capped_records, cap_stats) where cap_stats shows per-env before/after counts + """ + if max_env_ratio >= 1.0: + return train_records, {} + + total_train = len(train_records) + max_per_env = int(total_train * max_env_ratio) + + # Group by environment + records_by_env: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for record in train_records: + env_key = record.get("data_source", "unknown") + records_by_env[env_key].append(record) + + # Cap each environment + capped_records = [] + cap_stats: Dict[str, Dict[str, int]] = {} + + for env_key, records in records_by_env.items(): + before_count = len(records) + + if before_count <= max_per_env: + # No capping needed + capped_records.extend(records) + cap_stats[env_key] = {"before": before_count, "after": before_count, "capped": False} + else: + # Sort by hash for deterministic selection + records_sorted = sorted(records, key=lambda r: hash_to_float(r.get("task_key", ""))) + selected = records_sorted[:max_per_env] + capped_records.extend(selected) + cap_stats[env_key] = {"before": before_count, "after": max_per_env, "capped": True} + + return capped_records, cap_stats + + +def prepare_fleet_dataset( + tasks_json: str, + output_dir: str, + modality: Optional[str] = "tool_use", + eval_ratio: float = 0.20, # v0.3.2: increased to 20% to include carlisle/outlook in eval + env_filter: Optional[str] = None, + difficulty_filter: Optional[str] = None, # v0.4.0: filter by difficulty (1=easy, 2=medium, 3=hard) + max_tasks: Optional[int] = None, + max_env_ratio: float = MAX_ENV_TRAIN_RATIO, # v0.3.1: cap dominant environments + max_eval_prompts: Optional[int] = MAX_EVAL_PROMPTS, # v0.3.2: cap total eval prompts +): + """ + Convert Fleet tasks JSON to SkyRL parquet dataset. + + Args: + tasks_json: Path to Fleet tasks JSON file + output_dir: Output directory for parquet files + modality: Task modality filter ("tool_use" or "computer_use"), None for all + eval_ratio: Fraction of data for evaluation (default: 0.02) + env_filter: Optional env_key filter (e.g., "github", "booking") + max_tasks: Optional maximum number of tasks to include + max_env_ratio: Maximum fraction any single env can contribute to training (default: 0.20) + """ + # Log applied filters at the start + print("\n=== Dataset Filters ===") + print(f" Source: {tasks_json}") + print(f" Modality: {modality or 'all'}") + print(f" Env filter: {env_filter or 'none'}") + print(f" Difficulty filter: {difficulty_filter or 'all (1,2,3)'}") + print(f" Max tasks: {max_tasks or 'unlimited'}") + print(f" Max env ratio: {max_env_ratio:.0%}") + print(f" Max eval prompts: {max_eval_prompts or 'unlimited'}") + print() + + print(f"Loading tasks from {tasks_json}...") + tasks = load_tasks_from_json(tasks_json) + print(f"Loaded {len(tasks)} tasks") + + # Filter by modality if specified + if modality: + tasks = [t for t in tasks if t.get("task_modality") == modality] + print(f"After modality filter ({modality}): {len(tasks)} tasks") + + # Filter by env_key(s) if specified - supports comma-separated list + if env_filter: + env_list = [e.strip() for e in env_filter.split(",") if e.strip()] + tasks = [t for t in tasks if t.get("env_key") in env_list or t.get("env_id") in env_list] + print(f"After env filter ({env_list}): {len(tasks)} tasks") + + # Filter by difficulty if specified - supports comma-separated list (e.g., "1,2" for easy+medium) + if difficulty_filter: + diff_list = [int(d.strip()) for d in difficulty_filter.split(",") if d.strip()] + tasks = [t for t in tasks if t.get("difficulty") in diff_list] + print(f"After difficulty filter ({diff_list}): {len(tasks)} tasks") + + # Limit tasks if specified + if max_tasks and len(tasks) > max_tasks: + tasks = tasks[:max_tasks] + print(f"Limited to {max_tasks} tasks") + + if not tasks: + print("No tasks remaining after filtering. Exiting.") + return + + # Deduplicate by task_key (keep first occurrence) + seen_task_keys: set = set() + unique_tasks = [] + duplicate_count = 0 + env_duplicate_counts: Dict[str, int] = defaultdict(int) + + for task in tasks: + task_key = task.get("key") or task.get("task_key") + if not task_key: + continue + if task_key in seen_task_keys: + duplicate_count += 1 + env_key = task.get("env_key") or task.get("env_id") or "unknown" + env_duplicate_counts[env_key] += 1 + else: + seen_task_keys.add(task_key) + unique_tasks.append(task) + + if duplicate_count > 0: + print(f"\n⚠️ WARNING: Removed {duplicate_count} duplicate task_keys") + print(" By environment:") + for env, count in sorted(env_duplicate_counts.items(), key=lambda x: -x[1]): + print(f" {env}: {count} duplicates removed") + print() + + tasks = unique_tasks + print(f"After deduplication: {len(tasks)} unique tasks") + + # Get excluded envs for this modality (removed entirely) + excluded_envs = set(EXCLUDED_ENVS.get(modality, [])) + if excluded_envs: + before_count = len(tasks) + tasks = [t for t in tasks if t.get("env_key") not in excluded_envs] + print(f"Excluded environments: {excluded_envs}") + print(f"After excluding envs: {len(tasks)} tasks (removed {before_count - len(tasks)})") + + # Exclude specific tasks missing CURRENT_DATE + if TASKS_MISSING_CURRENT_DATE: + before_count = len(tasks) + tasks = [t for t in tasks if (t.get("key") or t.get("task_key")) not in TASKS_MISSING_CURRENT_DATE] + removed = before_count - len(tasks) + if removed > 0: + print(f"Excluded tasks missing CURRENT_DATE: {removed} tasks") + + # Get held-out envs for this modality + held_out_envs = set(HELD_OUT_ENVS.get(modality, [])) + if held_out_envs: + print(f"Held-out test environments: {held_out_envs}") + + # Group tasks by environment + tasks_by_env: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for task in tasks: + env_key = task.get("env_key") or task.get("env_id") or "unknown" + tasks_by_env[env_key].append(task) + + # Prepare records with stratified split + train_records = [] + eval_records = [] + + # Track per-env counts for summary table + env_split_counts: Dict[str, Dict[str, int]] = {} + + print("\n=== Per-Environment Split ===") + for env_key in sorted(tasks_by_env.keys()): + env_tasks = tasks_by_env[env_key] + + # Check if this env is held out for eval only + if env_key in held_out_envs: + env_eval_count = 0 + for task in env_tasks: + record = _task_to_record(task, env_key) + if record: + eval_records.append(record) + env_eval_count += 1 + env_split_counts[env_key] = {"train": 0, "eval": env_eval_count} + print(f" {env_key}: {len(env_tasks)} -> EVAL only (held-out)") + continue + + # Calculate target eval size: use ratio but cap at MAX_EVAL_SAMPLES + target_eval_size = min(int(len(env_tasks) * eval_ratio), MAX_EVAL_SAMPLES) + + # If not enough samples for eval, put all in train + if target_eval_size < MIN_EVAL_SAMPLES: + env_train_count = 0 + for task in env_tasks: + record = _task_to_record(task, env_key) + if record: + train_records.append(record) + env_train_count += 1 + env_split_counts[env_key] = {"train": env_train_count, "eval": 0} + print(f" {env_key}: {len(env_tasks)} -> all TRAIN (< {MIN_EVAL_SAMPLES} eval samples)") + continue + + # Compute effective eval ratio to achieve target_eval_size (capped at MAX_EVAL_SAMPLES) + effective_eval_ratio = target_eval_size / len(env_tasks) + + # Stratified split using hash with effective ratio + env_train = 0 + env_eval = 0 + for task in env_tasks: + task_key = task.get("key") or task.get("task_key") + record = _task_to_record(task, env_key) + if not record: + continue + + split = hash_to_split(task_key, effective_eval_ratio) + if split == "eval": + eval_records.append(record) + env_eval += 1 + else: + train_records.append(record) + env_train += 1 + + env_split_counts[env_key] = {"train": env_train, "eval": env_eval} + print(f" {env_key}: {len(env_tasks)} -> {env_train} train, {env_eval} eval") + + print(f"\nTotal: {len(train_records)} train, {len(eval_records)} eval") + + # Apply total eval cap (v0.3.2) - stratified sampling across environments + if max_eval_prompts and len(eval_records) > max_eval_prompts: + print(f"\n=== Capping Eval Prompts ({max_eval_prompts} max total) ===") + + # Group by environment + eval_by_env: Dict[str, List[Dict[str, Any]]] = defaultdict(list) + for record in eval_records: + eval_by_env[record.get("data_source", "unknown")].append(record) + + # Take min(8, available) from each env, then distribute remaining quota proportionally + min_per_env = 8 + capped_eval_records = [] + + for env_key, records in eval_by_env.items(): + # Sort by hash for deterministic selection + records.sort(key=lambda r: hash_to_float(r.get("task_key", ""))) + # Take at least min_per_env (or all if fewer available) + take = min(min_per_env, len(records)) + capped_eval_records.extend(records[:take]) + + # If we have budget remaining, distribute round-robin across envs + remaining_budget = max_eval_prompts - len(capped_eval_records) + if remaining_budget > 0: + # Records not yet selected (sorted by hash for determinism) + remaining_by_env = { + env: records[min_per_env:] for env, records in eval_by_env.items() if len(records) > min_per_env + } + + # Round-robin until budget exhausted + env_keys = sorted(remaining_by_env.keys()) + idx = 0 + while remaining_budget > 0 and any(remaining_by_env.values()): + env = env_keys[idx % len(env_keys)] + if remaining_by_env[env]: + capped_eval_records.append(remaining_by_env[env].pop(0)) + remaining_budget -= 1 + idx += 1 + + # Update env_split_counts + for env_key in eval_by_env: + count = sum(1 for r in capped_eval_records if r.get("data_source") == env_key) + if env_key in env_split_counts: + env_split_counts[env_key]["eval"] = count + print(f" {env_key}: {len(eval_by_env[env_key])} -> {count}") + + eval_records = capped_eval_records + print(f"\nAfter capping: {len(eval_records)} eval prompts") + + # Apply per-environment cap to training data (v0.3.1) + if max_env_ratio < 1.0 and train_records: + train_records, cap_stats = cap_training_distribution(train_records, max_env_ratio) + + # Print capping summary + capped_envs = [env for env, stats in cap_stats.items() if stats["capped"]] + if capped_envs: + print(f"\n=== Training Distribution Cap ({max_env_ratio:.0%} max per env) ===") + for env in sorted(capped_envs): + stats = cap_stats[env] + print(f" {env}: {stats['before']} -> {stats['after']} ({stats['before'] - stats['after']} removed)") + print(f"\nAfter capping: {len(train_records)} train") + + # Update env_split_counts with capped values + for env, stats in cap_stats.items(): + if env in env_split_counts: + env_split_counts[env]["train"] = stats["after"] + + # Create datasets + train_dataset = Dataset.from_list(train_records) if train_records else None + eval_dataset = Dataset.from_list(eval_records) if eval_records else None + + # Save to parquet + os.makedirs(output_dir, exist_ok=True) + + if train_dataset: + train_path = os.path.join(output_dir, "train.parquet") + train_dataset.to_parquet(train_path) + print(f"Saved train dataset to {train_path}") + + if eval_dataset: + eval_path = os.path.join(output_dir, "validation.parquet") + eval_dataset.to_parquet(eval_path) + print(f"Saved validation dataset to {eval_path}") + + # Print summary statistics + print("\n=== Dataset Summary ===") + print(f"Train: {len(train_records)}") + print(f"Eval: {len(eval_records)} (includes held-out: {held_out_envs or 'none'})") + + # Print per-environment breakdown table + print("\n=== Per-Environment Breakdown ===") + print(f"{'Environment':<20} {'Train':>8} {'Eval':>8} {'Total':>8}") + print("-" * 48) + for env_key in sorted(env_split_counts.keys()): + counts = env_split_counts[env_key] + total = counts["train"] + counts["eval"] + print(f"{env_key:<20} {counts['train']:>8} {counts['eval']:>8} {total:>8}") + print("-" * 48) + print( + f"{'TOTAL':<20} {len(train_records):>8} {len(eval_records):>8} " f"{len(train_records) + len(eval_records):>8}" + ) + + +def _task_to_record(task: Dict[str, Any], env_key: str) -> Optional[Dict[str, Any]]: + """Convert a task dict to a dataset record.""" + task_key = task.get("key") or task.get("task_key") + prompt = task.get("prompt", "") + + if not task_key or not prompt: + return None + + return { + # Required fields for SkyRL + "prompt": [{"role": "user", "content": prompt}], + "env_class": "fleet_task", # This tells SkyRL to use FleetTaskEnv + # Task identification (passed as env_extras) + "task_key": task_key, + # Data source for per-environment metrics in WandB + "data_source": env_key, + } + + +def main(): + parser = argparse.ArgumentParser(description="Prepare Fleet tasks for SkyRL training") + parser.add_argument( + "--tasks-json", + type=str, + required=True, + help="Path to Fleet tasks JSON file", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./data/fleet", + help="Output directory for parquet files", + ) + parser.add_argument( + "--modality", + type=str, + default="tool_use", + choices=["tool_use", "computer_use", "all"], + help="Task modality filter ('all' for no filter)", + ) + parser.add_argument( + "--eval-ratio", + type=float, + default=0.20, + help="Fraction of data for evaluation (default: 0.20)", + ) + parser.add_argument( + "--env-filter", + type=str, + default=None, + help="Optional env_key filter (e.g., 'github', 'booking')", + ) + parser.add_argument( + "--difficulty-filter", + type=str, + default=None, + help="Optional difficulty filter: 1=easy, 2=medium, 3=hard (e.g., '1,2' for easy+medium)", + ) + parser.add_argument( + "--max-tasks", + type=int, + default=None, + help="Maximum number of tasks to include", + ) + parser.add_argument( + "--max-env-ratio", + type=float, + default=MAX_ENV_TRAIN_RATIO, + help=f"Maximum fraction of training data per environment (default: {MAX_ENV_TRAIN_RATIO})", + ) + parser.add_argument( + "--max-eval-prompts", + type=int, + default=MAX_EVAL_PROMPTS, + help=f"Maximum total eval prompts across all environments (default: {MAX_EVAL_PROMPTS})", + ) + + args = parser.parse_args() + + # Handle 'all' modality + modality = None if args.modality == "all" else args.modality + + prepare_fleet_dataset( + tasks_json=args.tasks_json, + output_dir=args.output_dir, + modality=modality, + eval_ratio=args.eval_ratio, + env_filter=args.env_filter, + difficulty_filter=args.difficulty_filter, + max_tasks=args.max_tasks, + max_env_ratio=args.max_env_ratio, + max_eval_prompts=args.max_eval_prompts, + ) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/reward_metrics.py b/integrations/fleet/reward_metrics.py new file mode 100644 index 0000000000..3978ee1acd --- /dev/null +++ b/integrations/fleet/reward_metrics.py @@ -0,0 +1,253 @@ +"""Unified reward metrics for SkyRL and Tinker. + +This module provides shared metric calculation functions used by both: +- SkyRL trainer (skyrl_train/trainer.py, skyrl_train/utils/trainer_utils.py) +- Tinker integration (integrations/fleet/entrypoints/main_fleet_tinker.py) + +All metrics follow the same naming convention for WandB logging: +- reward/{group}/pass_at_{n} - Pass@n metric for group +- reward/{group}/variance_per_prompt - Mean within-prompt reward variance (GRPO learning signal) +- reward/{group}/signal_ratio - Fraction of prompts with non-zero variance (% with signal) +- reward/{group}/mean_positive_reward - Mean of positive rewards for group + +Rewards can be in two formats: +- Scalar rewards: List[float] - one reward per trajectory +- Token-level rewards: List[List[float]] - per-token rewards per trajectory (summed to scalar) +""" + +from collections import defaultdict +from typing import Any, Dict, List, Union + +import numpy as np + + +def flatten_rewards(rewards: Union[List[float], List[List[float]]]) -> List[float]: + """Flatten rewards to scalar format. + + Handles both scalar rewards (List[float]) and token-level rewards (List[List[float]]). + For token-level rewards, sums each trajectory's rewards into a single scalar. + + Args: + rewards: Either List[float] (scalar per trajectory) or + List[List[float]] (token-level per trajectory) + + Returns: + List[float]: Flattened scalar rewards, one per trajectory + """ + if not rewards: + return [] + + flat_rewards: List[float] = [] + for r in rewards: + if isinstance(r, list): + # Token-level rewards: sum to get trajectory reward + flat_rewards.append(float(sum(r))) + else: + flat_rewards.append(float(r)) + return flat_rewards + + +def sanitize_metric_key(key: str) -> str: + """Sanitize metric key for wandb (replace / with _). + + Args: + key: Raw metric key that may contain slashes + + Returns: + Sanitized key with slashes replaced by underscores + """ + return key.replace("/", "_") + + +def compute_pass_at_n( + rewards: Union[List[float], List[List[float]]], + uids: List[str], +) -> float: + """Compute pass@n: fraction of unique prompts with at least one fully successful rollout. + + For each unique prompt (identified by uid), if ANY of its rollouts achieves a + perfect reward (>= 1.0), that prompt counts as a "pass". This metric measures + how often the model can fully solve a task when given multiple attempts. + Partial rewards (e.g. 0.3 from partial_reward mode) do not count as a pass. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs (one per rollout, same uid = same prompt) + + Returns: + Float between 0.0 and 1.0 representing the fraction of prompts that passed + """ + flat_rewards = flatten_rewards(rewards) + uid_to_rewards: Dict[str, List[float]] = defaultdict(list) + for uid, reward in zip(uids, flat_rewards): + uid_to_rewards[uid].append(reward) + + if not uid_to_rewards: + return 0.0 + + passed = sum(1 for r_list in uid_to_rewards.values() if any(r >= 1.0 for r in r_list)) + return passed / len(uid_to_rewards) + + +def compute_variance_per_prompt( + rewards: Union[List[float], List[List[float]]], + uids: List[str], +) -> float: + """Compute mean within-prompt reward variance (GRPO learning signal). + + For GRPO to learn, there must be variance in rewards within each prompt's rollouts. + If all rollouts for a prompt get the same reward, there's no learning signal. + + This metric computes the variance of rewards for each prompt, then returns the + mean variance across all prompts. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs (one per rollout, same uid = same prompt) + + Returns: + Mean variance across prompts. Higher = more learning signal. + Returns 0.0 if no prompts or all prompts have single rollouts. + """ + flat_rewards = flatten_rewards(rewards) + uid_to_rewards: Dict[str, List[float]] = defaultdict(list) + for uid, reward in zip(uids, flat_rewards): + uid_to_rewards[uid].append(reward) + + if not uid_to_rewards: + return 0.0 + + # Compute variance for each prompt (need at least 2 samples for variance) + variances = [] + for r_list in uid_to_rewards.values(): + if len(r_list) >= 2: + variances.append(float(np.var(r_list))) + + return float(np.mean(variances)) if variances else 0.0 + + +def compute_signal_ratio( + rewards: Union[List[float], List[List[float]]], + uids: List[str], +) -> float: + """Compute fraction of prompts with non-zero variance (GRPO signal ratio). + + This metric shows what percentage of prompts have any learning signal at all. + A prompt has signal if at least one rollout differs from others (variance > 0). + + Unlike variance_per_prompt (which averages variance magnitudes), this metric + counts how many prompts contribute ANY signal, making it easier to interpret: + - 100% = every prompt has at least one differing rollout + - 0% = all prompts have identical rewards across rollouts (no learning possible) + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs (one per rollout, same uid = same prompt) + + Returns: + Float between 0.0 and 1.0 representing fraction of prompts with signal. + Returns 0.0 if no prompts or all prompts have single rollouts. + """ + flat_rewards = flatten_rewards(rewards) + uid_to_rewards: Dict[str, List[float]] = defaultdict(list) + for uid, reward in zip(uids, flat_rewards): + uid_to_rewards[uid].append(reward) + + if not uid_to_rewards: + return 0.0 + + # Count prompts with variance > 0 (need at least 2 samples) + prompts_with_signal = 0 + prompts_total = 0 + for r_list in uid_to_rewards.values(): + if len(r_list) >= 2: + prompts_total += 1 + if np.var(r_list) > 0: + prompts_with_signal += 1 + + return prompts_with_signal / prompts_total if prompts_total > 0 else 0.0 + + +def compute_reward_metrics( + rewards: Union[List[float], List[List[float]]], + uids: List[str], + n_samples_per_prompt: int, +) -> Dict[str, float]: + """Compute core reward metrics. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs for pass@n grouping + n_samples_per_prompt: Number of samples per prompt (used in metric key name) + + Returns: + Dictionary with keys: + - "pass_at_{n}": Pass@n metric + - "variance_per_prompt": Mean within-prompt reward variance (GRPO learning signal) + - "signal_ratio": Fraction of prompts with non-zero variance (% with signal) + - "mean_positive_reward": Mean of positive rewards only + """ + # Flatten rewards once for efficiency (each sub-function would otherwise flatten again) + flat_rewards = flatten_rewards(rewards) + pass_at_n = compute_pass_at_n(flat_rewards, uids) + variance = compute_variance_per_prompt(flat_rewards, uids) + signal_ratio = compute_signal_ratio(flat_rewards, uids) + positive_rewards = [r for r in flat_rewards if r > 0] + mean_positive = float(np.mean(positive_rewards)) if positive_rewards else 0.0 + + return { + f"pass_at_{n_samples_per_prompt}": pass_at_n, + "variance_per_prompt": variance, + "signal_ratio": signal_ratio, + "mean_positive_reward": mean_positive, + } + + +def compute_per_group_metrics( + rewards: Union[List[float], List[List[float]]], + uids: List[str], + groups: List[str], + n_samples_per_prompt: int, + prefix: str = "reward", +) -> Dict[str, float]: + """Compute metrics grouped by a key (env_key, data_source, etc). + + This function computes reward metrics for each group separately, enabling + per-environment analysis in training and evaluation. + + Args: + rewards: List of rewards (one per rollout). Can be scalar (List[float]) + or token-level (List[List[float]]) - will be flattened. + uids: List of unique IDs for pass@n grouping within each group + groups: List of group keys (e.g., env_key or data_source per rollout) + n_samples_per_prompt: Number of samples per prompt (used in metric key name) + prefix: Metric prefix ("reward" for training, "eval" for evaluation) + + Returns: + Dictionary with keys like: + - "{prefix}/{group}/avg_score" + - "{prefix}/{group}/pass_at_{n}" + - "{prefix}/{group}/mean_positive_reward" + """ + # Flatten rewards once before grouping + flat_rewards = flatten_rewards(rewards) + + # Group data by group key + group_data: Dict[str, Dict[str, List[Any]]] = defaultdict(lambda: {"rewards": [], "uids": []}) + for reward, uid, group in zip(flat_rewards, uids, groups): + group_key = group if group is not None else "unknown" + group_data[group_key]["rewards"].append(reward) + group_data[group_key]["uids"].append(uid) + + metrics: Dict[str, float] = {} + for group_key, data in group_data.items(): + sanitized = sanitize_metric_key(group_key) + group_metrics = compute_reward_metrics(data["rewards"], data["uids"], n_samples_per_prompt) + for metric_name, value in group_metrics.items(): + metrics[f"{prefix}/{sanitized}/{metric_name}"] = value + + return metrics diff --git a/integrations/fleet/s3_checkpoints.py b/integrations/fleet/s3_checkpoints.py new file mode 100644 index 0000000000..5db5aebb40 --- /dev/null +++ b/integrations/fleet/s3_checkpoints.py @@ -0,0 +1,439 @@ +""" +S3 Checkpoint Management for SkyRL Training. + +Provides checkpoint upload to S3, download from S3 for resume, and local cleanup. + +Key behavior: +- Cleans up old local checkpoints BEFORE saving new one (prevents disk full) +- Uploads to S3 asynchronously (non-blocking, training continues) +- Downloads checkpoint from S3 before training for cross-VM resume +- Uploads eval results to S3 for persistence + +Usage: + from integrations.fleet.s3_checkpoints import ( + wrap_trainer_with_s3_upload, + download_checkpoint_from_s3, + upload_eval_results_to_s3, + ) + + # Download checkpoint before training (for resume on new VM) + download_checkpoint_from_s3(ckpt_path, run_name) + + trainer = wrap_trainer_with_s3_upload(trainer, bucket="skyrl-checkpoints") + upload_eval_results_to_s3(local_dir, run_name, global_step) + +Environment Variables: + AWS_ACCESS_KEY_ID: AWS access key + AWS_SECRET_ACCESS_KEY: AWS secret key + AWS_REGION: AWS region (default: us-east-1) + S3_CHECKPOINT_BUCKET: S3 bucket for checkpoints (default: skyrl-checkpoints) + S3_TRAJECTORY_BUCKET: S3 bucket for eval trajectories (default: skyrl-trajectories) +""" + +import os +import shutil +import threading +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Optional +import logging + +logger = logging.getLogger(__name__) + + +class S3CheckpointUploader: + """ + Uploads checkpoint directories to S3 asynchronously. + + Uses a background thread pool to avoid blocking training. + Deletes local checkpoints after successful upload. + """ + + def __init__( + self, + bucket: str, + prefix: str, + region: str = "us-east-1", + max_workers: int = 2, + ): + self.bucket = bucket + self.prefix = prefix + self.region = region + self._executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix="s3-upload") + self._pending: set = set() + self._lock = threading.Lock() + + def _upload_sync(self, local_dir: str) -> bool: + """Synchronous upload that runs in thread pool.""" + try: + import boto3 + from botocore.config import Config + from boto3.s3.transfer import TransferConfig + + config = Config( + retries={"max_attempts": 3, "mode": "adaptive"}, + connect_timeout=30, + read_timeout=120, + ) + + s3 = boto3.client("s3", region_name=self.region, config=config) + + local_path = Path(local_dir) + if not local_path.exists(): + logger.warning(f"Checkpoint directory does not exist: {local_dir}") + return False + + checkpoint_name = local_path.name + s3_prefix = f"{self.prefix}/{checkpoint_name}" + + transfer_config = TransferConfig( + multipart_threshold=64 * 1024 * 1024, + multipart_chunksize=64 * 1024 * 1024, + max_concurrency=4, + use_threads=True, + ) + + uploaded_files = 0 + total_size = 0 + + for file_path in local_path.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(local_path) + s3_key = f"{s3_prefix}/{relative_path}" + file_size = file_path.stat().st_size + total_size += file_size + + logger.info(f"Uploading {file_path.name} ({file_size / 1e6:.1f} MB)") + + s3.upload_file(str(file_path), self.bucket, s3_key, Config=transfer_config) + uploaded_files += 1 + + logger.info( + f"Uploaded {checkpoint_name}: {uploaded_files} files, {total_size / 1e9:.2f} GB to s3://{self.bucket}/{s3_prefix}/" + ) + + # Delete local after successful upload to free disk space + logger.info(f"Deleting local checkpoint after S3 upload: {local_dir}") + shutil.rmtree(local_dir) + + return True + + except Exception as e: + logger.error(f"S3 upload failed for {local_dir}: {e}") + return False + finally: + with self._lock: + self._pending.discard(local_dir) + + def upload_async(self, local_dir: str) -> None: + """Queue checkpoint for async upload. Non-blocking.""" + with self._lock: + if local_dir in self._pending: + return + self._pending.add(local_dir) + + logger.info(f"Queuing checkpoint for S3 upload: {local_dir}") + self._executor.submit(self._upload_sync, local_dir) + + def wait_for_uploads(self, timeout: Optional[float] = None) -> None: + """Wait for all pending uploads to complete.""" + self._executor.shutdown(wait=True) + self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="s3-upload") + + +def cleanup_old_local_checkpoints(ckpt_path: str, keep_n: int = 2) -> None: + """ + Delete old local checkpoints, keeping only the most recent N. + + Args: + ckpt_path: Base checkpoint directory + keep_n: Number of recent checkpoints to keep (default: 2 for safety) + """ + ckpt_dir = Path(ckpt_path) + if not ckpt_dir.exists(): + return + + checkpoint_dirs = sorted( + [d for d in ckpt_dir.iterdir() if d.is_dir() and d.name.startswith("global_step_")], + key=lambda x: int(x.name.split("_")[-1]), + reverse=True, + ) + + for old_dir in checkpoint_dirs[keep_n:]: + logger.info(f"Cleaning up old local checkpoint: {old_dir}") + try: + shutil.rmtree(old_dir) + except Exception as e: + logger.warning(f"Failed to delete {old_dir}: {e}") + + +def wrap_trainer_with_s3_upload( + trainer, + bucket: Optional[str] = None, + prefix: Optional[str] = None, + region: Optional[str] = None, +): + """ + Wrap a SkyRL trainer to: + 1. Clean up old checkpoints BEFORE saving (prevents disk full) + 2. Upload to S3 asynchronously AFTER saving (if credentials set) + 3. Delete local checkpoint after successful S3 upload (frees disk) + + Args: + trainer: SkyRL trainer instance + bucket: S3 bucket (default: from S3_CHECKPOINT_BUCKET env var) + prefix: S3 prefix (default: from trainer config) + region: AWS region (default: from AWS_REGION env var) + + Returns: + The trainer (modified in place) + """ + bucket = bucket or os.environ.get("S3_CHECKPOINT_BUCKET", "skyrl-checkpoints") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + # Build prefix from trainer config + if prefix is None: + run_name = getattr(trainer.cfg.trainer, "run_name", None) + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + prefix = f"{project_name}/{model_name}/{run_name}" if run_name else f"{project_name}/{model_name}" + + # Check AWS credentials + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + s3_enabled = bool(aws_key and aws_secret) + + if s3_enabled: + logger.info(f"S3 checkpoint upload ENABLED: s3://{bucket}/{prefix}/") + uploader = S3CheckpointUploader(bucket=bucket, prefix=prefix, region=region) + else: + logger.warning( + "AWS credentials not found. S3 upload DISABLED. " + "Using aggressive local cleanup (keeping only 2 checkpoints). " + "Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to enable S3." + ) + uploader = None + + original_save_checkpoints = trainer.save_checkpoints + ckpt_path = trainer.cfg.trainer.ckpt_path + + def save_checkpoints_with_cleanup(): + """Wrapped save_checkpoints with pre-save cleanup and async S3 upload.""" + # CRITICAL: Clean up old checkpoints BEFORE saving to free disk space + # With S3: keep only 1 (we have S3 backup), allows room for new checkpoint + # Without S3: keep 2 for safety + keep_n = 1 if s3_enabled else 2 + cleanup_old_local_checkpoints(ckpt_path, keep_n=keep_n) + + # Now save the new checkpoint (disk has space) + original_save_checkpoints() + + # Queue async S3 upload (non-blocking) + if s3_enabled and uploader: + global_step = trainer.global_step + checkpoint_dir = os.path.join(ckpt_path, f"global_step_{global_step}") + if os.path.exists(checkpoint_dir): + uploader.upload_async(checkpoint_dir) + + trainer.save_checkpoints = save_checkpoints_with_cleanup + trainer._s3_uploader = uploader + + return trainer + + +def download_checkpoint_from_s3( + ckpt_path: str, + run_name: str, + bucket: Optional[str] = None, + region: Optional[str] = None, + project_name: str = "fleet-task-grpo", + model_name: str = "Qwen3-32B", +) -> bool: + """ + Download the latest checkpoint from S3 for resume on a fresh VM. + + Looks for checkpoint directories under the S3 prefix matching the run_name, + downloads the latest one, and writes latest_ckpt_global_step.txt. + + Args: + ckpt_path: Local checkpoint directory (e.g., ~/ckpts/fleet_tool_use_32b) + run_name: W&B run name used as S3 prefix (e.g., fleet_tool_use_32b_d7167c1c) + bucket: S3 bucket (default: from S3_CHECKPOINT_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + project_name: Project name used in S3 prefix + model_name: Model name used in S3 prefix + + Returns: + True if checkpoint was downloaded, False otherwise + """ + bucket = bucket or os.environ.get("S3_CHECKPOINT_BUCKET", "skyrl-checkpoints") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.info("No AWS credentials, skipping S3 checkpoint download") + return False + + # Check if local checkpoint already exists + latest_file = os.path.join(ckpt_path, "latest_ckpt_global_step.txt") + if os.path.exists(latest_file): + with open(latest_file, "r") as f: + step = f.read().strip() + local_ckpt = os.path.join(ckpt_path, f"global_step_{step}") + if os.path.exists(local_ckpt): + logger.info(f"Local checkpoint already exists at step {step}, skipping S3 download") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config( + retries={"max_attempts": 3, "mode": "adaptive"}, + connect_timeout=30, + read_timeout=120, + ) + s3 = boto3.client("s3", region_name=region, config=config) + + # S3 prefix matches what wrap_trainer_with_s3_upload builds + s3_prefix = f"{project_name}/{model_name}/{run_name}/" + + # List all checkpoint directories in S3 + paginator = s3.get_paginator("list_objects_v2") + checkpoint_steps = set() + for page in paginator.paginate(Bucket=bucket, Prefix=s3_prefix, Delimiter="/"): + for prefix_obj in page.get("CommonPrefixes", []): + dir_name = prefix_obj["Prefix"].rstrip("/").split("/")[-1] + if dir_name.startswith("global_step_"): + try: + step = int(dir_name.split("_")[-1]) + checkpoint_steps.add(step) + except ValueError: + pass + + if not checkpoint_steps: + logger.info(f"No checkpoints found in s3://{bucket}/{s3_prefix}") + return False + + latest_step = max(checkpoint_steps) + s3_ckpt_prefix = f"{s3_prefix}global_step_{latest_step}/" + local_ckpt_dir = os.path.join(ckpt_path, f"global_step_{latest_step}") + + logger.info(f"Downloading checkpoint step {latest_step} from s3://{bucket}/{s3_ckpt_prefix}") + + os.makedirs(local_ckpt_dir, exist_ok=True) + + downloaded_files = 0 + total_size = 0 + for page in paginator.paginate(Bucket=bucket, Prefix=s3_ckpt_prefix): + for obj in page.get("Contents", []): + s3_key = obj["Key"] + relative_path = s3_key[len(s3_ckpt_prefix) :] + if not relative_path: + continue + local_file = os.path.join(local_ckpt_dir, relative_path) + os.makedirs(os.path.dirname(local_file), exist_ok=True) + file_size = obj["Size"] + total_size += file_size + logger.info(f"Downloading {relative_path} ({file_size / 1e6:.1f} MB)") + s3.download_file(bucket, s3_key, local_file) + downloaded_files += 1 + + # Write latest_ckpt_global_step.txt so SkyRL's resume_mode=latest can find it + os.makedirs(ckpt_path, exist_ok=True) + with open(latest_file, "w") as f: + f.write(str(latest_step)) + + logger.info( + f"Downloaded checkpoint: {downloaded_files} files, {total_size / 1e9:.2f} GB " + f"from s3://{bucket}/{s3_ckpt_prefix} to {local_ckpt_dir}" + ) + return True + + except Exception as e: + logger.error(f"Failed to download checkpoint from S3: {e}") + return False + + +def upload_eval_results_to_s3( + local_dir: str, + run_name: str, + global_step: Optional[int] = None, + bucket: Optional[str] = None, + region: Optional[str] = None, + delete_local: bool = False, +) -> bool: + """ + Upload eval results directory to S3. + + Args: + local_dir: Local directory containing eval JSONL files + run_name: Run name for S3 prefix (e.g., "fleet_tool_use_abc123") + global_step: Global step number (for organizing in S3) + bucket: S3 bucket (default: from S3_TRAJECTORY_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + delete_local: If True, delete local files after upload + + Returns: + True if upload succeeded, False otherwise + """ + bucket = bucket or os.environ.get("S3_TRAJECTORY_BUCKET", "skyrl-trajectories") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + # Check AWS credentials + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.warning("AWS credentials not found. Skipping S3 upload for eval results.") + return False + + local_path = Path(local_dir) + if not local_path.exists(): + logger.warning(f"Eval directory does not exist: {local_dir}") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config( + retries={"max_attempts": 3, "mode": "adaptive"}, + connect_timeout=30, + read_timeout=60, + ) + + s3 = boto3.client("s3", region_name=region, config=config) + + # Build S3 prefix: evals/{run_name}/global_step_{N}/ + step_suffix = f"global_step_{global_step}" if global_step is not None else "eval_only" + s3_prefix = f"evals/{run_name}/{step_suffix}" + + uploaded_files = 0 + total_size = 0 + + for file_path in local_path.rglob("*"): + if file_path.is_file(): + relative_path = file_path.relative_to(local_path) + s3_key = f"{s3_prefix}/{relative_path}" + file_size = file_path.stat().st_size + total_size += file_size + + s3.upload_file(str(file_path), bucket, s3_key) + uploaded_files += 1 + + logger.info( + f"Uploaded eval results: {uploaded_files} files, {total_size / 1e6:.2f} MB " + f"to s3://{bucket}/{s3_prefix}/" + ) + + if delete_local: + shutil.rmtree(local_dir) + logger.info(f"Deleted local eval directory: {local_dir}") + + return True + + except Exception as e: + logger.error(f"S3 upload failed for eval results {local_dir}: {e}") + return False diff --git a/integrations/fleet/task_gen_reward.py b/integrations/fleet/task_gen_reward.py new file mode 100644 index 0000000000..80bdf5a886 --- /dev/null +++ b/integrations/fleet/task_gen_reward.py @@ -0,0 +1,89 @@ +""" +Reward functions for task generation RL. + +Computes: + R(task) = llm_validity * (alpha * var(raw_scores) + (p_hint - p_raw)) + +Components: + - var(raw_scores): Variance of k raw (no-hint) evaluator rollouts. + Measures difficulty calibration — maximized at p_raw ≈ 0.5 + (Bernoulli variance = 0.25). Tasks at the evaluator's frontier. + - p_hint - p_raw: Hint gap — mean(hinted) minus mean(raw). + Positive when hints help, meaning the task is hard but solvable. + Captures learnability beyond current capability. + - llm_validity: LLM-as-a-judge gate (0/1). Kills reward for broken tasks. + - alpha: Weight balancing variance (frontier difficulty) vs hint gap (learnability). Default 1.0 (equal weight). +""" + +from typing import Dict, List + + +def compute_variance(scores: List[float]) -> float: + """Compute variance of binary rollout scores. + + Args: + scores: List of binary (0/1) rollout outcomes. + + Returns: + Variance in [0, 0.25]. Zero when all same, max at p=0.5. + """ + if len(scores) < 2: + return 0.0 + mean = sum(scores) / len(scores) + return sum((s - mean) ** 2 for s in scores) / len(scores) + + +def compute_hint_gap(raw_scores: List[float], hinted_scores: List[float]) -> float: + """Compute hint gap: mean(hinted) - mean(raw). + + Positive when hints help the evaluator solve the task. + Zero or negative when hints don't help (task too easy or too hard). + + Args: + raw_scores: Scores from evaluator rollouts without hints. + hinted_scores: Scores from evaluator rollouts with hints. + + Returns: + Hint gap in [-1, 1]. + """ + if not raw_scores or not hinted_scores: + return 0.0 + p_raw = sum(raw_scores) / len(raw_scores) + p_hint = sum(hinted_scores) / len(hinted_scores) + return p_hint - p_raw + + +def compute_task_reward( + raw_scores: List[float], + hinted_scores: List[float], + validity: float = 1.0, + alpha: float = 1.0, +) -> Dict[str, float]: + """Compute the full task generation reward. + + R = validity * (alpha * var(raw) + (p_hint - p_raw)) + + Args: + raw_scores: Scores from k evaluator rollouts without hints. + hinted_scores: Scores from k evaluator rollouts with hints. + validity: LLM-as-a-judge gate (0.0 or 1.0). + alpha: Weight for variance term. + + Returns: + Dict with all reward components and total. + """ + p_raw = sum(raw_scores) / len(raw_scores) if raw_scores else 0.0 + p_hint = sum(hinted_scores) / len(hinted_scores) if hinted_scores else 0.0 + var_raw = compute_variance(raw_scores) + hint_gap = p_hint - p_raw + total = validity * (alpha * var_raw + hint_gap) + + return { + "validity": validity, + "p_raw": p_raw, + "p_hint": p_hint, + "var_raw": var_raw, + "hint_gap": hint_gap, + "alpha": alpha, + "total": total, + } diff --git a/integrations/fleet/tests/__init__.py b/integrations/fleet/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/fleet/utils.py b/integrations/fleet/utils.py new file mode 100644 index 0000000000..caa64eaab0 --- /dev/null +++ b/integrations/fleet/utils.py @@ -0,0 +1,118 @@ +""" +Utility functions for Fleet task training with Tinker. + +These functions handle sequence truncation and loss mask filtering, +matching SkyRL's skyrl_gym_generator patterns. +""" + +from typing import List, Tuple + + +def truncate_sequence( + prompt_ids: List[int], + response_ids: List[int], + max_sequence_length: int, +) -> Tuple[List[int], List[int], int]: + """ + Truncate a sequence to fit within max_sequence_length. + + The prompt is preserved fully; only the response is truncated. + + Args: + prompt_ids: Token IDs for the prompt. + response_ids: Token IDs for the response. + max_sequence_length: Maximum total sequence length. + + Returns: + Tuple of (full_sequence, truncated_response_ids, response_len). + """ + full_sequence = prompt_ids + response_ids + prompt_len = len(prompt_ids) + + if len(full_sequence) > max_sequence_length: + full_sequence = full_sequence[:max_sequence_length] + response_len = len(full_sequence) - prompt_len + truncated_response_ids = response_ids[:response_len] + else: + response_len = len(response_ids) + truncated_response_ids = response_ids + + return full_sequence, truncated_response_ids, response_len + + +def truncate_auxiliary_data( + data: List, + response_len: int, +) -> List: + """ + Truncate auxiliary data (logprobs, loss_mask) to match truncated response length. + + Args: + data: List of values corresponding to response tokens. + response_len: Target length after truncation. + + Returns: + Truncated list. + """ + if len(data) > response_len: + return data[:response_len] + return data + + +def apply_overlong_filtering_simple( + loss_masks: List[List[int]], + response_ids: List[List[int]], + eos_token_id: int, +) -> List[List[int]]: + """ + Apply DAPO overlong filtering: zero out loss mask for responses not ending with EOS. + + This is a simplified version for testing - the actual SkyRL function is in + skyrl_train.generators.utils.apply_overlong_filtering. + + Args: + loss_masks: List of loss masks for each response. + response_ids: List of response token IDs for each response. + eos_token_id: The EOS token ID. + + Returns: + Filtered loss masks (zeroed if response doesn't end with EOS). + """ + filtered = [] + for mask, response in zip(loss_masks, response_ids): + # Empty response or doesn't end with EOS -> zero out mask + if not response or response[-1] != eos_token_id: + filtered.append([0] * len(mask)) + else: + filtered.append(list(mask)) + return filtered + + +def prepare_training_sequence( + prompt_ids: List[int], + response_ids: List[int], + logprobs: List[float], + loss_mask: List[int], + max_sequence_length: int, +) -> Tuple[List[int], List[float], List[int], bool]: + """ + Prepare a training sequence with truncation if needed. + + Args: + prompt_ids: Token IDs for the prompt. + response_ids: Token IDs for the response. + logprobs: Log probabilities for response tokens. + loss_mask: Loss mask for response tokens. + max_sequence_length: Maximum total sequence length. + + Returns: + Tuple of (full_sequence, truncated_logprobs, truncated_loss_mask, was_truncated). + """ + full_sequence, truncated_response, response_len = truncate_sequence(prompt_ids, response_ids, max_sequence_length) + + was_truncated = len(prompt_ids) + len(response_ids) > max_sequence_length + + truncated_logprobs = truncate_auxiliary_data(logprobs, response_len) + truncated_loss_mask = truncate_auxiliary_data(loss_mask, response_len) + + return full_sequence, truncated_logprobs, truncated_loss_mask, was_truncated diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh new file mode 100755 index 0000000000..b183e317e0 --- /dev/null +++ b/scripts/fleet-35b-run.sh @@ -0,0 +1,90 @@ +#!/usr/bin/env bash +# Single source of truth for Qwen3.5-35B-A3B GRPO training config. +# Called by the SkyPilot YAML and by fleet-research run.sh. +# +# Required env vars: FLEET_API_KEY, WANDB_API_KEY +# Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) +set -euo pipefail +cd "$(dirname "$0")/../.." # cd to SkyRL root + +# Defaults for vars normally set by SkyPilot YAML envs block +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export DATA_VERSION="${DATA_VERSION:-v55}" +export MODALITY="${MODALITY:-tool_use}" +export NUM_EPOCHS="${NUM_EPOCHS:-20}" +export MAX_TURNS="${MAX_TURNS:-50}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-8}" +export ENV_KEYS="${ENV_KEYS:-}" +export DIFFICULTY="${DIFFICULTY:-}" +export RUN_ID="${RUN_ID:-}" +export MAX_TASKS="${MAX_TASKS:-}" +export RESUME_RUN_NAME="${RESUME_RUN_NAME:-}" +export AWS_REGION="${AWS_REGION:-us-east-1}" +export S3_DATASET_BUCKET="${S3_DATASET_BUCKET:-fleet-internal-datasets}" +export S3_CHECKPOINT_BUCKET="${S3_CHECKPOINT_BUCKET:-skyrl-checkpoints}" +export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" + +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit \ + --nccl-heartbeat 1800 -- \ + environment.skyrl_gym.fleet_task.ttl_seconds=900 \ + environment.skyrl_gym.fleet_task.partial_reward=true \ + environment.skyrl_gym.fleet_task.enable_hints=true \ + environment.skyrl_gym.fleet_task.n_hint_samples=2 \ + trainer.algorithm.advantage_estimator=grpo \ + trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ + trainer.flash_attn=true \ + trainer.loss_chunk_size=4096 \ + trainer.use_sample_packing=false \ + +generator.chat_template_kwargs='{enable_thinking:true}' \ + generator.inference_engine_tensor_parallel_size=2 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=8 \ + trainer.eval_before_train=false \ + trainer.eval_interval=20 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=16 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=1 \ + trainer.policy_mini_batch_size=16 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=10 \ + trainer.max_ckpts_to_keep=1 \ + trainer.max_prompt_length=2048 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.9 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=[""]' \ + 'generator.eval_sampling_params.stop=[""]' \ + trainer.policy.optimizer_config.lr=5.0e-7 \ + trainer.algorithm.use_kl_loss=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=8 \ + generator.eval_n_samples_per_prompt=3 \ + generator.enforce_eager=false \ + generator.gpu_memory_utilization=0.65 \ + generator.inject_context_status=true \ + generator.context_warning_threshold=0.90 \ + trainer.logger="$LOGGER" \ + trainer.project_name="fleet-task-grpo" \ + trainer.run_name="fleet_qwen35_35b_${MODALITY}_${RUN_ID:-$(head -c 4 /dev/urandom | xxd -p)}" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/fleet_qwen35_35b_${MODALITY}" \ + trainer.export_path="$HOME/exports" \ + trainer.dump_data_batch=true \ + "$@" diff --git a/scripts/fleet-common-run.sh b/scripts/fleet-common-run.sh new file mode 100755 index 0000000000..2a4571c563 --- /dev/null +++ b/scripts/fleet-common-run.sh @@ -0,0 +1,313 @@ +#!/usr/bin/env bash +# Fleet shared run: Ray cluster setup (multi-node aware) + training launch +# +# Usage (from SkyPilot YAML run block): +# bash skyrl-train/scripts/fleet-common-run.sh \ +# --use-python-direct --cuda-env "$HOME/.cuda_env" \ +# --set-ulimit --no-pytorch-alloc-conf -- \ +# trainer.policy.model.path="Qwen/Qwen3.5-9B" \ +# trainer.epochs=20 ... +# +# Multi-node: +# Rank 0 (head): starts Ray head, launches training +# Rank >0 (workers): joins Ray cluster, sleeps +# +# Required env vars: WANDB_API_KEY, MODALITY, INFERENCE_BACKEND, +# SKYPILOT_NUM_GPUS_PER_NODE, SKYPILOT_NODE_IPS +# Optional env vars: SKYPILOT_NUM_NODES, SKYPILOT_NODE_RANK +set -euo pipefail + +# Defaults +DATA_ROOT="" +CKPT_ROOT="" +USE_PYTHON_DIRECT=false +CUDA_ENV="" +SET_ULIMIT=false +NO_PYTORCH_ALLOC_CONF=false +NCCL_HEARTBEAT="" +ENTRYPOINT="integrations.fleet.entrypoints.main_fleet" +ENV_CLASS="fleet_task" +DATA_DIR_NAME="" +HYDRA_OVERRIDES=() + +# Parse args +while [[ $# -gt 0 ]]; do + case "$1" in + --data-root) DATA_ROOT="$2"; shift 2 ;; + --ckpt-root) CKPT_ROOT="$2"; shift 2 ;; + --use-python-direct) USE_PYTHON_DIRECT=true; shift ;; + --cuda-env) CUDA_ENV="$2"; shift 2 ;; + --set-ulimit) SET_ULIMIT=true; shift ;; + --no-pytorch-alloc-conf) NO_PYTORCH_ALLOC_CONF=true; shift ;; + --nccl-heartbeat) NCCL_HEARTBEAT="$2"; shift 2 ;; + --entrypoint) ENTRYPOINT="$2"; shift 2 ;; + --env-class) ENV_CLASS="$2"; shift 2 ;; + --data-dir-name) DATA_DIR_NAME="$2"; shift 2 ;; + --) shift; HYDRA_OVERRIDES=("$@"); break ;; + *) echo "ERROR: Unknown arg: $1"; exit 1 ;; + esac +done + +# Auto-detect data/ckpt root: /workspace if writable (RunPod), else $HOME (GCP, Lambda, etc.) +if [ -z "$DATA_ROOT" ]; then + if [ -d "/workspace" ] && [ -w "/workspace" ]; then + DATA_ROOT="/workspace" + else + DATA_ROOT="$HOME" + fi +fi +if [ -z "$CKPT_ROOT" ]; then + CKPT_ROOT="$DATA_ROOT" +fi +DATA_DIR_NAME="${DATA_DIR_NAME:-$MODALITY}" + +echo "=== Fleet Common Run ===" +echo "Entrypoint: $ENTRYPOINT" +echo "Env class: $ENV_CLASS" +echo "Data root: $DATA_ROOT" +echo "Data dir name: $DATA_DIR_NAME" +echo "Ckpt root: $CKPT_ROOT" + +# Activate venv from repo root (upstream SkyRL layout) +source .venv/bin/activate + +# --- Optional settings --- +if [ "$SET_ULIMIT" = true ]; then + # Set open files limit. Try 1M first, fall back to hard limit if too high. + ulimit -n 1048576 2>/dev/null || ulimit -n "$(ulimit -Hn)" 2>/dev/null || true +fi + +# vLLM TP>1 uses pidfd_getfd for CUDA IPC weight sync between Ray workers. +# This requires ptrace access, which is blocked by default (ptrace_scope=1). +sudo sysctl -w kernel.yama.ptrace_scope=0 2>/dev/null || true + +if [ -n "$CUDA_ENV" ]; then + source "$CUDA_ENV" 2>/dev/null || true +fi + +if [ "$NO_PYTORCH_ALLOC_CONF" = false ]; then + export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True +fi + +if [ -n "$NCCL_HEARTBEAT" ]; then + export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC="$NCCL_HEARTBEAT" +fi + +TMP_DIR="${CKPT_ROOT}/skyrl-tmp" +mkdir -p "$TMP_DIR" +export TMPDIR="$TMP_DIR" + +TASKS_FILE="${DATA_ROOT}/data/fleet/tasks_${MODALITY}.json" +DATA_DIR="${DATA_ROOT}/data/fleet/${DATA_DIR_NAME}" + +# --- System diagnostics --- +echo "=== System Diagnostics ===" +free -h +nvidia-smi --query-gpu=name,driver_version,memory.total,memory.free --format=csv 2>/dev/null || true +echo "--- /dev/shm ---" +df -h /dev/shm 2>/dev/null || echo "/dev/shm not mounted" +ls -la /dev/shm/ 2>/dev/null | head -5 || true +echo "--- GPU Topology ---" +nvidia-smi topo -m 2>/dev/null || true +echo "--- cgroup memory limits ---" +cat /sys/fs/cgroup/memory.max 2>/dev/null || cat /sys/fs/cgroup/memory/memory.limit_in_bytes 2>/dev/null || echo "No cgroup memory limit found" +echo "--- ulimits ---" +ulimit -a 2>/dev/null || true +echo "--- NCCL env vars ---" +env | grep -i NCCL || echo "No NCCL env vars set" +echo "--- kernel overcommit ---" +cat /proc/sys/vm/overcommit_memory 2>/dev/null || true +echo "=== End Diagnostics ===" + +# --- wandb login --- +python3 -c "import wandb; wandb.login(relogin=True, key='$WANDB_API_KEY')" + +# --- Fabric Manager check (NVSwitch GPUs: B200, H200 SXM) --- +# On non-GCP clouds (RunPod, Lambda, etc.), Fabric Manager is required for NVLink +# P2P on NVSwitch systems. Without it, dist.broadcast() in FSDP causes SIGKILL. +# +# On GCP, NVSwitch is managed at the HOST level — the guest VM does not have +# NVSwitch devices, so FM reports "NV_WARN_NOTHING_TO_DO" and cannot start. +# This is EXPECTED. NVLink P2P works through GCP's host-managed fabric without FM. +# GCP also provides a custom NCCL shim (gIB) that manages all NCCL configuration. +# Do NOT set NCCL_P2P_DISABLE or NCCL_NVLS_ENABLE on GCP with RDMA — +# the shim's "Guest Config Checker" expects these to be unset. +# NCCL_CUMEM_ENABLE=0 is set below for GCP WITHOUT RDMA to disable multicast. +ON_GCP=false +if [ -d "/usr/local/gib" ]; then + ON_GCP=true +elif [ -f "/sys/class/dmi/id/product_name" ] && grep -qi "google" /sys/class/dmi/id/product_name 2>/dev/null; then + ON_GCP=true +fi + +FM_STATUS=$(systemctl is-active nvidia-fabricmanager 2>/dev/null || echo "unknown") +echo "Fabric Manager status: $FM_STATUS" +echo "On GCP: $ON_GCP" + +if [ "$ON_GCP" = true ]; then + echo "GCP detected — skipping Fabric Manager restart (host manages NVSwitch)" + + # GCP's deep learning images install /etc/profile.d/nccl_env.sh which auto-sources + # /usr/local/gib/scripts/set_nccl_env.sh and adds /usr/local/gib/lib64 to LD_LIBRARY_PATH. + # This sets NCCL_NET=gIB, forcing the gIB network plugin for RDMA/InfiniBand. + # + # Problem: gIB requires RDMA hardware (ConnectX NICs + multiple GPUDirect VPC networks). + # SkyPilot provisions VMs with a single management NIC — no RDMA networking. + # When NCCL_NET=gIB is forced but gIB can't init, NCCL fails with + # "Failed to initialize any NET plugin" → SIGKILL during dist.broadcast(). + # + # Fix: check for RDMA devices. If absent, strip gIB so NCCL falls back to + # NVLink P2P for intra-node communication. Multi-node uses GKE with RDMA. + if [ -d "/sys/class/infiniband" ] && [ "$(ls /sys/class/infiniband/ 2>/dev/null)" ]; then + echo "RDMA devices found — keeping gIB for GPUDirect RDMA" + else + echo "No RDMA devices — disabling gIB" + # Remove gIB from LD_LIBRARY_PATH (set by /etc/profile.d/nccl_env.sh) + export LD_LIBRARY_PATH=$(echo "${LD_LIBRARY_PATH:-}" | sed 's|/usr/local/gib/lib64:||g; s|:/usr/local/gib/lib64||g; s|/usr/local/gib/lib64||g') + # Unset NCCL_NET=gIB so NCCL can fall back to NVLink P2P + unset NCCL_NET + # Clear gIB-specific vars set by set_nccl_env.sh + unset NCCL_CROSS_NIC NCCL_NET_GDR_LEVEL NCCL_P2P_NET_CHUNKSIZE NCCL_NVLS_CHUNKSIZE + unset NCCL_IB_ADAPTIVE_ROUTING NCCL_IB_QPS_PER_CONNECTION NCCL_IB_TC NCCL_IB_FIFO_TC + unset NCCL_TUNER_CONFIG_PATH + # Disable CUDA multicast (requires NVSwitch fabric manager for GPU multicast + # team setup). Without this, vLLM TP>1 hangs on CUDASymmetricMemory init. + export NCCL_CUMEM_ENABLE=0 + echo "Cleared gIB NCCL env vars. Using NVLink P2P (intra-node)." + fi + echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" + echo "NCCL vars:" + env | grep -i NCCL || echo " (none)" + + # Ensure /dev/shm is large enough for NCCL IPC (some GCP images have small default) + SHM_SIZE=$(df --output=size /dev/shm 2>/dev/null | tail -1 | tr -d ' ') + echo "Current /dev/shm size: ${SHM_SIZE}K" + if [ -n "$SHM_SIZE" ] && [ "$SHM_SIZE" -lt 16777216 ]; then + echo "WARNING: /dev/shm is only ${SHM_SIZE}K — remounting to 16G for NCCL" + sudo mount -o remount,size=16G /dev/shm 2>&1 || echo "Failed to remount /dev/shm" + df -h /dev/shm + fi +elif [ "$FM_STATUS" != "active" ]; then + echo "WARNING: Fabric Manager not active. Attempting restart..." + sudo nvidia-smi -pm 1 2>&1 || true + sudo systemctl stop nvidia-fabricmanager 2>&1 || true + sleep 1 + sudo systemctl start nvidia-fabricmanager 2>&1 || true + sleep 5 + FM_STATUS=$(systemctl is-active nvidia-fabricmanager 2>/dev/null || echo "unknown") + echo "Fabric Manager status after restart: $FM_STATUS" + if [ "$FM_STATUS" != "active" ]; then + echo "=== WARNING: Fabric Manager failed to start ===" + echo "Training may fail if this system has NVSwitch GPUs." + sudo journalctl -u nvidia-fabricmanager --no-pager -n 10 2>&1 || true + fi +fi + +# --- Ray cluster setup (multi-node aware) --- +export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook +export RAY_object_store_memory=10000000000 +# Disable Ray's memory monitor to prevent spurious worker kills +export RAY_DISABLE_MEMORY_MONITOR=1 +# NOTE: On GCP VMs without RDMA, gIB NCCL vars are stripped above. +# On GKE with RDMA, gIB is preserved for inter-node GPUDirect. + +read -r head_ip _ <<< "$SKYPILOT_NODE_IPS" + +wait_for_ray() { + local address=$1 + for _ in $(seq 1 24); do + if ray status --address "$address" >/dev/null 2>&1; then + return 0 + fi + sleep 5 + done + echo "ERROR: Ray cluster at $address failed to become ready" >&2 + return 1 +} + +if [ "${SKYPILOT_NODE_RANK:-0}" = "0" ]; then + # === Head node: start Ray head + launch training === + if ! ray status --address 127.0.0.1:6479 >/dev/null 2>&1; then + ray start --head --disable-usage-stats --port 6479 --object-store-memory=10000000000 + fi + wait_for_ray 127.0.0.1:6479 + + TOTAL_GPUS=$((SKYPILOT_NUM_GPUS_PER_NODE * ${SKYPILOT_NUM_NODES:-1})) + export TOTAL_GPUS + # NUM_INFERENCE_ENGINES can be overridden via env var for TP>1 (engines = GPUs / TP) + NUM_INFERENCE_ENGINES=${NUM_INFERENCE_ENGINES:-$TOTAL_GPUS} + echo "=== Head node: $TOTAL_GPUS GPUs across ${SKYPILOT_NUM_NODES:-1} node(s), $NUM_INFERENCE_ENGINES inference engines ===" + + # Build training command + CMD_ARGS=() + if [ "$USE_PYTHON_DIRECT" = true ]; then + CMD_ARGS=(python -m "$ENTRYPOINT") + else + CMD_ARGS=(uv run --isolated --extra "$INFERENCE_BACKEND" -m "$ENTRYPOINT") + fi + + # Common hydra overrides (data paths, placement, strategy, checkpoints) + CMD_ARGS+=( + "data.train_data=['${DATA_DIR}/train.parquet']" + "data.val_data=['${DATA_DIR}/validation.parquet']" + "environment.env_class=$ENV_CLASS" + ) + + # fleet_task-specific: pass tasks_file path + if [ "$ENV_CLASS" = "fleet_task" ]; then + CMD_ARGS+=("environment.skyrl_gym.fleet_task.tasks_file=$TASKS_FILE") + fi + + CMD_ARGS+=( + trainer.placement.colocate_all=true + trainer.strategy=fsdp2 + "trainer.placement.policy_num_gpus_per_node=$SKYPILOT_NUM_GPUS_PER_NODE" + "trainer.placement.ref_num_gpus_per_node=$SKYPILOT_NUM_GPUS_PER_NODE" + "trainer.placement.policy_num_nodes=${SKYPILOT_NUM_NODES:-1}" + "trainer.placement.ref_num_nodes=${SKYPILOT_NUM_NODES:-1}" + "generator.num_inference_engines=$NUM_INFERENCE_ENGINES" + "trainer.ckpt_path=${CKPT_ROOT}/ckpts" + "trainer.export_path=${CKPT_ROOT}/exports" + ) + + # Append model-specific hydra overrides (passed after --) + if [ ${#HYDRA_OVERRIDES[@]} -gt 0 ]; then + CMD_ARGS+=("${HYDRA_OVERRIDES[@]}") + fi + + export HYDRA_FULL_ERROR=1 + echo "=== Launching Training ===" + set +e + "${CMD_ARGS[@]}" + EXIT_CODE=$? + set -e + + if [ $EXIT_CODE -ne 0 ]; then + echo "=== Training failed (exit code $EXIT_CODE) ===" + echo "--- dmesg (last 50 lines, unfiltered) ---" + sudo dmesg -T 2>/dev/null | tail -50 || true + echo "--- dmesg (OOM/kill/segfault) ---" + sudo dmesg -T 2>/dev/null | grep -iE "oom|kill|out of memory|segfault|sigsegv|general protection|cgroup" | tail -20 || true + echo "--- memory ---" + free -h + echo "--- GPU memory ---" + nvidia-smi --query-gpu=memory.used,memory.free --format=csv 2>/dev/null || true + echo "--- /dev/shm after crash ---" + df -h /dev/shm 2>/dev/null || true + echo "--- cgroup memory events ---" + cat /sys/fs/cgroup/memory.events 2>/dev/null || cat /sys/fs/cgroup/memory/memory.oom_control 2>/dev/null || true + echo "--- Ray worker logs (last errors) ---" + grep -r "SIGKILL\|SIGABRT\|SIGSEGV\|SYSTEM_ERROR\|RuntimeError\|NCCL" /tmp/ray/session_latest/logs/ 2>/dev/null | tail -30 || true + exit $EXIT_CODE + fi + +else + # === Worker node: join Ray cluster and wait === + echo "=== Worker node (rank ${SKYPILOT_NODE_RANK}), joining Ray cluster at $head_ip:6479 ===" + if ! ray status --address "$head_ip:6479" >/dev/null 2>&1; then + ray start --address "$head_ip:6479" --disable-usage-stats + fi + wait_for_ray "$head_ip:6479" + echo "Worker node joined. Sleeping..." + sleep infinity +fi diff --git a/scripts/fleet-common-setup.sh b/scripts/fleet-common-setup.sh new file mode 100755 index 0000000000..aaeea4b908 --- /dev/null +++ b/scripts/fleet-common-setup.sh @@ -0,0 +1,130 @@ +#!/usr/bin/env bash +# Fleet shared setup: env validation, venv, dependencies, OpenEnv, dataset download +# +# Usage (from SkyPilot YAML setup block): +# bash skyrl-train/scripts/fleet-common-setup.sh \ +# --openenv-branch deniz/fleet_client \ +# --extra-setup skyrl-train/scripts/fleet-qwen35-extra-setup.sh +# +# Required env vars: FLEET_API_KEY, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, +# MODALITY, DATA_VERSION, S3_DATASET_BUCKET +# Optional env vars: ENV_KEYS, DIFFICULTY +set -euo pipefail + +# Defaults +OPENENV_BRANCH="deniz/fleet_client" +EXTRA_SETUP="" +DATA_ROOT="" +SKIP_UV_ISOLATED=false +EXTRA_PIP="" +SKIP_PREPARE=false + +# Parse args +while [[ $# -gt 0 ]]; do + case "$1" in + --openenv-branch) OPENENV_BRANCH="$2"; shift 2 ;; + --extra-setup) EXTRA_SETUP="$2"; shift 2 ;; + --data-root) DATA_ROOT="$2"; shift 2 ;; + --skip-uv-isolated) SKIP_UV_ISOLATED=true; shift ;; + --extra-pip) EXTRA_PIP="$2"; shift 2 ;; + --skip-prepare) SKIP_PREPARE=true; shift ;; + *) echo "ERROR: Unknown arg: $1"; exit 1 ;; + esac +done + +# Auto-detect data root: /workspace if writable (RunPod), else $HOME (GCP, Lambda, etc.) +if [ -z "$DATA_ROOT" ]; then + if [ -d "/workspace" ] && [ -w "/workspace" ]; then + DATA_ROOT="/workspace" + else + DATA_ROOT="$HOME" + fi +fi + +# Resolve extra-setup path to absolute before cd (it's relative to repo root) +if [ -n "$EXTRA_SETUP" ]; then + EXTRA_SETUP="$(cd "$(dirname "$EXTRA_SETUP")" && pwd)/$(basename "$EXTRA_SETUP")" +fi + +# In upstream SkyRL, training packages live at repo root (skyrl/, skyrl-gym/, integrations/) +# No need to cd into skyrl-train/ — the venv and dependencies are at root level + +echo "=== Fleet Common Setup ===" +echo "OpenEnv branch: $OPENENV_BRANCH" +echo "Data root: $DATA_ROOT" +echo "Extra setup: ${EXTRA_SETUP:-none}" + +# --- Environment validation --- +echo "Validating environment variables..." +if [ -z "${FLEET_API_KEY:-}" ]; then + echo "ERROR: FLEET_API_KEY is required"; exit 1 +fi +if [ -z "${AWS_ACCESS_KEY_ID:-}" ] || [ -z "${AWS_SECRET_ACCESS_KEY:-}" ]; then + echo "ERROR: AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are required for S3 dataset download"; exit 1 +fi +if [ "${MODALITY:-}" != "tool_use" ] && [ "${MODALITY:-}" != "computer_use" ]; then + echo "ERROR: MODALITY must be 'tool_use' or 'computer_use', got: ${MODALITY:-unset}"; exit 1 +fi +echo "Environment validation passed" + +# --- Fix Ray binary permissions (some cloud images strip +x) --- +for f in .venv/bin/ray .venv/lib/python*/site-packages/ray/core/src/ray/raylet/raylet; do + [ -f "$f" ] && chmod +x "$f" 2>/dev/null || true +done + +# --- System dependencies (GCP images may lack build tools) --- +if ! command -v c++ &>/dev/null; then + echo "Installing build-essential (c++ compiler required for causal-conv1d)..." + sudo apt-get update -qq && sudo apt-get install -y --no-install-recommends build-essential +fi + +# --- Python environment --- +if [ -d ".venv" ]; then + echo "Virtual environment already exists, reusing" +else + uv venv --python 3.12 --seed +fi +source .venv/bin/activate +# vLLM 0.17.0 has native Qwen3.5 support (GDN via torch.ops.vllm.gdn_attention_core), +# FlashAttention 4, and PyTorch 2.10.0 +uv sync --extra vllm +uv pip install wandb boto3 awscli +# Pin fleet-python<=0.2.119: 0.2.120+ has async BaseWrapper bug (missing jwt/team_id params) +uv pip install "litellm>=1.75.5" "fleet-python<=0.2.119" logfire "mcp>=1.0.0" + +# --- Extra pip packages (installed before extra-setup to avoid dependency downgrades) --- +if [ -n "$EXTRA_PIP" ]; then + echo "Installing extra pip packages: $EXTRA_PIP" + uv pip install $EXTRA_PIP +fi + +# --- Extra setup hook (model-specific dependencies) --- +if [ -n "$EXTRA_SETUP" ]; then + echo "Running extra setup: $EXTRA_SETUP" + source "$EXTRA_SETUP" +fi + +# --- OpenEnv (force reinstall for latest changes) --- +uv pip install --force-reinstall --no-cache-dir --no-deps "git+https://github.com/fleet-ai/OpenEnv.git@${OPENENV_BRANCH}" + +# --- Dataset download --- +mkdir -p "${DATA_ROOT}/data/fleet" +TASKS_FILE="${DATA_ROOT}/data/fleet/tasks_${MODALITY}.json" +S3_PATH="s3://${S3_DATASET_BUCKET}/${DATA_VERSION}/openenv/all_${MODALITY}.json" +echo "Downloading dataset from $S3_PATH..." +aws s3 cp "$S3_PATH" "$TASKS_FILE" +TASK_COUNT=$(python3 -c "import json; print(len(json.load(open('$TASKS_FILE'))['tasks']))") +echo "Downloaded $TASK_COUNT tasks for modality: $MODALITY" + +# --- Prepare dataset (parquet files) --- +if [ "$SKIP_PREPARE" = true ]; then + echo "Skipping prepare_dataset (--skip-prepare). Caller handles preparation." +else + DATA_DIR="${DATA_ROOT}/data/fleet/${MODALITY}" + PREPARE_CMD="python -m integrations.fleet.prepare_dataset --tasks-json $TASKS_FILE --output-dir $DATA_DIR --modality $MODALITY" + [ -n "${ENV_KEYS:-}" ] && PREPARE_CMD="$PREPARE_CMD --env-filter $ENV_KEYS" + [ -n "${DIFFICULTY:-}" ] && PREPARE_CMD="$PREPARE_CMD --difficulty-filter $DIFFICULTY" + eval "$PREPARE_CMD" +fi + +echo "=== Fleet Common Setup Complete ===" diff --git a/scripts/fleet-qwen35-extra-setup.sh b/scripts/fleet-qwen35-extra-setup.sh new file mode 100755 index 0000000000..03367b8f64 --- /dev/null +++ b/scripts/fleet-qwen35-extra-setup.sh @@ -0,0 +1,71 @@ +#!/usr/bin/env bash +# Qwen3.5-specific dependencies (sourced by fleet-common-setup.sh via --extra-setup) +# +# Installs: transformers 5.3.0, flash-attn 2.8.3 wheel, CUDA toolkit (nvcc), causal-conv1d +# Writes: $HOME/.cuda_env (sourced at run time for FlashInfer JIT) + +# Upgrade transformers to 5.3.0 for Qwen3.5-MoE (model_type=qwen3_5_moe). +# - Qwen3.5 launched Feb 2026; all 4.x releases predate it. +# - 5.1.0 doesn't register qwen3_5_moe in AUTO_CONFIG_MAPPING. +# - 5.3.0 is the first stable release with full qwen3_5_moe support. +# - Do NOT install from git main (renamed layer_type_validation, breaks vLLM 0.17). +uv pip install -U "transformers==5.3.0" + +# flash-attn 2.8.3 prebuilt wheel for torch 2.10 + CUDA 12 (training forward/backward) +uv pip install "https://github.com/lesj0610/flash-attention/releases/download/v2.8.3-cu12-torch2.10-cp312/flash_attn-2.8.3%2Bcu12torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl" + +python -c "import torch; import torchvision; print(f'torch={torch.__version__}, torchvision={torchvision.__version__}')" + +# --- CUDA toolkit for FlashInfer JIT (GatedDeltaNet kernels) --- +# pip CUDA packages are incomplete (missing nv/target headers); use NVIDIA apt repo instead +CUDA_HOME="" +for d in /usr/local/cuda /usr/local/cuda-12.8 /usr/local/cuda-12.6 /usr/local/cuda-12.4; do + if [ -x "$d/bin/nvcc" ]; then + CUDA_HOME="$d" + break + fi +done +if [ -z "$CUDA_HOME" ] && command -v nvcc &>/dev/null; then + NVCC_PATH=$(command -v nvcc) + CUDA_HOME=$(dirname "$(dirname "$NVCC_PATH")") +fi +if [ -z "$CUDA_HOME" ]; then + echo "nvcc not found on system. Installing CUDA toolkit from NVIDIA apt repo..." + sudo apt-get update -qq + UBUNTU_VER=$(lsb_release -rs 2>/dev/null | tr -d '.' || echo "2204") + KEYRING_URL="https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${UBUNTU_VER}/x86_64/cuda-keyring_1.1-1_all.deb" + echo "Installing CUDA keyring from $KEYRING_URL" + wget -qO /tmp/cuda-keyring.deb "$KEYRING_URL" 2>&1 || curl -sLo /tmp/cuda-keyring.deb "$KEYRING_URL" + file /tmp/cuda-keyring.deb + sudo dpkg -i /tmp/cuda-keyring.deb + sudo apt-get update -qq + sudo apt-get install -y --no-install-recommends cuda-nvcc-12-8 libcublas-dev-12-8 cuda-nvrtc-dev-12-8 + CUDA_HOME="/usr/local/cuda-12.8" +fi +export CUDA_HOME +export PATH="$CUDA_HOME/bin:$PATH" +echo "CUDA_HOME=$CUDA_HOME" +"$CUDA_HOME/bin/nvcc" --version + +# Write cuda_env for run phase (fleet-common-run.sh sources this via --cuda-env) +echo "export CUDA_HOME=$CUDA_HOME" > "$HOME/.cuda_env" +echo "export PATH=$CUDA_HOME/bin:\$PATH" >> "$HOME/.cuda_env" + +# causal-conv1d: required for GatedDeltaNet fast CUDA kernels in Qwen3.5-MoE. +# Without it, fla-core falls back to a naive PyTorch implementation that crashes +# with cudaErrorIllegalAddress on multi-node FSDP2 (Xid 31 MMU fault). +# Must be built from source (needs nvcc + g++) — install AFTER CUDA toolkit setup. +uv pip install "causal-conv1d>=1.6.0" +python -c "import causal_conv1d; print(f'causal-conv1d OK: {causal_conv1d.__version__}')" + +# Verify pinned packages survived dependency resolution +python -c "import transformers; assert transformers.__version__ == '5.3.0', f'Expected 5.3.0 got {transformers.__version__}'" +# Ensure torch 2.10.0 — uv pip install can downgrade it during transitive resolution +TORCH_VER=$(python -c "import torch; print(torch.__version__)") +echo "torch version after setup: $TORCH_VER" +if [[ "$TORCH_VER" != 2.10.0* ]]; then + echo "WARNING: torch was downgraded to $TORCH_VER, reinstalling 2.10.0+cu128" + pip install --force-reinstall --no-deps torch==2.10.0 --index-url https://download.pytorch.org/whl/cu128 +fi +python -c "import torch; assert torch.__version__.startswith('2.10.0'), f'Expected 2.10.0 got {torch.__version__}'" +python -c "import torch; import flash_attn_2_cuda; print('flash_attn CUDA extension OK')" diff --git a/scripts/fleet-task-gen-run.sh b/scripts/fleet-task-gen-run.sh new file mode 100755 index 0000000000..e066ef2d66 --- /dev/null +++ b/scripts/fleet-task-gen-run.sh @@ -0,0 +1,81 @@ +#!/usr/bin/env bash +# Task-gen specific run: calls common run with task-gen entrypoint and hydra overrides +# +# Usage (from SkyPilot YAML run block): +# bash skyrl-train/scripts/fleet-task-gen-run.sh +# +# Required env vars: WANDB_API_KEY, MODALITY, INFERENCE_BACKEND, LOGGER, +# MAX_TURNS, MAX_INPUT_LENGTH, MAX_GENERATE_LENGTH, NUM_EPOCHS, +# JUDGE_MODEL, K_ROLLOUTS, ALPHA, MAX_EVAL_STEPS +# SkyPilot env vars: SKYPILOT_NUM_GPUS_PER_NODE, SKYPILOT_NODE_IPS +set -euo pipefail + +# Export RUN_NAME so task_gen_env can tag rollout dumps +# Always use random hex suffix for unique run names +export RUN_NAME="task_gen_$(python3 -c 'import os; print(os.urandom(4).hex())')" + +# Task-gen GRPO training via shared run script +# --entrypoint: task-gen entrypoint (not main_fleet) +# --env-class: task_gen environment (not fleet_task) +# --data-dir-name: parquet files are in data/fleet/task_gen/ (not data/fleet/tool_use/) +# TP=1: N engines × 1 GPU each (Qwen3.5-9B fits in single H200) +# num_inference_engines auto-detected from SKYPILOT_NUM_GPUS_PER_NODE by fleet-common-run.sh +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf \ + --entrypoint integrations.fleet.entrypoints.main_task_gen \ + --env-class task_gen \ + --data-dir-name task_gen -- \ + trainer.algorithm.advantage_estimator="grpo" \ + trainer.policy.model.path="Qwen/Qwen3.5-9B" \ + trainer.flash_attn=true \ + trainer.use_sample_packing=false \ + generator.inference_engine_tensor_parallel_size=1 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=12 \ + trainer.eval_before_train=false \ + trainer.eval_interval=20 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=12 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=1 \ + trainer.policy_mini_batch_size=12 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.loss_chunk_size=4096 \ + trainer.ckpt_interval=10 \ + trainer.max_prompt_length=4096 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.95 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=["", ""]' \ + generator.eval_sampling_params.temperature=0.95 \ + generator.eval_sampling_params.top_p=0.95 \ + 'generator.eval_sampling_params.stop=["", ""]' \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.algorithm.use_kl_loss=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.trajectory_timeout_seconds=1800 \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=8 \ + generator.eval_n_samples_per_prompt=3 \ + generator.gpu_memory_utilization=0.75 \ + trainer.logger="$LOGGER" \ + trainer.project_name="task-gen-grpo" \ + trainer.run_name="$RUN_NAME" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/task_gen" \ + trainer.dump_data_batch=true \ + ++environment.skyrl_gym.task_gen.max_turns=$MAX_TURNS \ + ++environment.skyrl_gym.task_gen.judge_model="$JUDGE_MODEL" \ + ++environment.skyrl_gym.task_gen.k_rollouts=$K_ROLLOUTS \ + ++environment.skyrl_gym.task_gen.alpha=$ALPHA \ + ++environment.skyrl_gym.task_gen.max_eval_steps=$MAX_EVAL_STEPS \ + ++environment.skyrl_gym.task_gen.evaluator_model="${EVALUATOR_MODEL:-anthropic/claude-sonnet-4.5}" \ + ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 diff --git a/skyrl/train/config/skyrl_gym_config/default.yaml b/skyrl/train/config/skyrl_gym_config/default.yaml index a94985f1e5..87541fd25c 100644 --- a/skyrl/train/config/skyrl_gym_config/default.yaml +++ b/skyrl/train/config/skyrl_gym_config/default.yaml @@ -14,3 +14,20 @@ search: search_url: "http://127.0.0.1:8000/retrieve" topk: 3 timeout: 30 + +fleet_task: + tasks_file: null + api_key: null + ttl_seconds: null + partial_reward: false + enable_hints: false + enable_context_tools: false + +task_gen: + max_turns: 10 + judge_model: "anthropic/claude-sonnet-4.5" + evaluator_model: "anthropic/claude-sonnet-4.5" + k_rollouts: 4 + eval_k_rollouts: 8 + alpha: 1.0 + max_eval_steps: 20 diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml new file mode 100644 index 0000000000..7acbbb06fa --- /dev/null +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -0,0 +1,69 @@ +# Fleet Task GRPO Training via SkyPilot - Qwen3.5-35B-A3B (MoE, Multi-Node) +# Usage: sky launch tasks/openenv-fleet-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# +# MoE: 35B total, 3B active (256 experts, 9 active/token). GatedDeltaNet architecture. +# 262K native context. All 35B params in memory (~70GB fp16), optimizer ~140GB, gradients ~70GB. +# +# Multi-node (2-node default, 16 GPUs total): 8x H200 per node +# NOTE: Requires vLLM >= 0.17.0 for native Qwen3.5/GDN support + FlashAttention 4 + +name: fleet-task-grpo-qwen3-5-35b + +resources: + disk_size: 750 + memory: 1500+ + ports: 6479 + any_of: + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + - accelerators: H200:8 + cloud: kubernetes + network_tier: best + use_spot: true + - accelerators: H200:8 + cloud: runpod + - accelerators: H200:8 + cloud: lambda + - accelerators: H200:8 + cloud: nebius + +num_nodes: 2 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: main + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v55" + ENV_KEYS: "" + DIFFICULTY: "" + MODALITY: "tool_use" + MAX_TURNS: 50 + MAX_INPUT_LENGTH: 96000 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 20 + RUN_ID: "" + MAX_TASKS: "" + RESUME_RUN_NAME: "" + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + # TP=2 -> 8 engines (each uses 2 GPUs) to match 16 policy GPUs with colocate_all + NUM_INFERENCE_ENGINES: 8 + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh + +run: | + bash scripts/fleet-35b-run.sh diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml new file mode 100644 index 0000000000..bc1abdd2c5 --- /dev/null +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -0,0 +1,59 @@ +# Task Generation GRPO Training via SkyPilot - Qwen3.5-9B +# Usage: sky launch tasks/task-gen-grpo-qwen3_5-9b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# +# Qwen3.5-9B: MoE with ~1B active params. Fits on single H200 GPU (TP=1). +# 8 inference engines on 8x H200 node. + +name: task-gen-grpo-qwen3-5-9b + +resources: + disk_size: 500 + memory: 800+ + ports: 6479 + any_of: + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + - accelerators: H200:8 + cloud: runpod + - accelerators: H200:8 + cloud: lambda + +num_nodes: 1 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: main + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v55" + MODALITY: "tool_use" + MAX_TURNS: 10 + MAX_INPUT_LENGTH: 30720 + MAX_GENERATE_LENGTH: 2048 + NUM_EPOCHS: 20 + JUDGE_MODEL: "anthropic/claude-sonnet-4.5" + EVALUATOR_MODEL: "anthropic/claude-sonnet-4.5" + K_ROLLOUTS: 4 + ALPHA: "1.0" + MAX_EVAL_STEPS: 20 + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh \ + --skip-prepare + +run: | + bash scripts/fleet-task-gen-run.sh From ae7934e9ad47dc61b191f013762ae3221a2e0542 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 14:50:59 -0700 Subject: [PATCH 003/121] Add task generation environment for skyrl-gym Port the task generation environment from fleet-ai/SkyRL that enables RL-based training of task-generating models. The environment supports multi-turn task generation where the model generates (prompt, verifier) pairs that are evaluated via Fleet harness rollouts. Key components: - TaskGenEnv(BaseTextEnv): Multi-turn env with tool-based DB exploration, task generation, and reward computation via variance + hint gap - VerifierSandbox: AST-based static analysis for generated verifier code safety (blocked imports/builtins, complexity bounds, signature checks) - Tool call parser: Handles / tag formats Reward formula: R = gate * (base_quality + alpha * var(raw_scores) + hint_gap) Depends on PR #2 (fleet/training) for integrations.fleet.task_gen_reward. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/__init__.py | 5 + skyrl-gym/skyrl_gym/envs/task_gen/__init__.py | 5 + .../skyrl_gym/envs/task_gen/task_gen_env.py | 1344 +++++++++++++++++ .../envs/task_gen/tool_call_parser.py | 67 + .../envs/task_gen/verifier_sandbox.py | 300 ++++ 5 files changed, 1721 insertions(+) create mode 100644 skyrl-gym/skyrl_gym/envs/task_gen/__init__.py create mode 100644 skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py create mode 100644 skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py create mode 100644 skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py diff --git a/skyrl-gym/skyrl_gym/envs/__init__.py b/skyrl-gym/skyrl_gym/envs/__init__.py index 770b65e1e8..5258537907 100644 --- a/skyrl-gym/skyrl_gym/envs/__init__.py +++ b/skyrl-gym/skyrl_gym/envs/__init__.py @@ -36,3 +36,8 @@ id="searchcode", entry_point="skyrl_gym.envs.searchcode.env:SearchCodeEnv", ) + +register( + id="task_gen", + entry_point="skyrl_gym.envs.task_gen.task_gen_env:TaskGenEnv", +) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/__init__.py b/skyrl-gym/skyrl_gym/envs/task_gen/__init__.py new file mode 100644 index 0000000000..b5c5a7e88c --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/__init__.py @@ -0,0 +1,5 @@ +from skyrl_gym.envs.task_gen.task_gen_env import TaskGenEnv +from skyrl_gym.envs.task_gen.tool_call_parser import parse_tool_call +from skyrl_gym.envs.task_gen.verifier_sandbox import VerifierSandbox + +__all__ = ["TaskGenEnv", "VerifierSandbox", "parse_tool_call"] diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py new file mode 100644 index 0000000000..5e3dea83c7 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -0,0 +1,1344 @@ +""" +Task Generation Environment for SkyRL. + +Multi-turn BaseTextEnv where the LLM can explore the seed database via +``describe_db`` / ``query_db`` meta-tools before generating a task. + +When ``max_turns > 1`` (the default), the model explores the DB first +and then produces a ```` block. When ``max_turns == 1`` it +behaves identically to the original single-turn variant. + +Reward: + + R(task) = base_quality + llm_validity * (alpha * var(raw_scores) + (p_hint - p_raw)) + + base_quality: Small reward for passing sandbox+judge (default 0.1) + llm_validity: Binary 0/1 from LLM-as-a-judge (is the task well-formed?) + var(raw_scores): Variance of k raw evaluator rollouts (difficulty calibration) + p_hint - p_raw: Hint gap — solvable with hints but not without (learnability) + alpha: Weight balancing variance vs hint gap (default 0.5) +""" + +import ast +import asyncio +import json +import logging +import os +import re +import time +import uuid +from typing import Any, Dict, List, Optional, Tuple + +from omegaconf import DictConfig + +from skyrl_gym.envs.base_text_env import ( + BaseTextEnv, + BaseTextEnvStepOutput, + ConversationType, +) +from skyrl_gym.envs.task_gen.tool_call_parser import parse_tool_call +from skyrl_gym.envs.task_gen.verifier_sandbox import ( + VerifierSandbox, + parse_task_output, +) + +logger = logging.getLogger(__name__) + +# Meta-tools the model can call to explore the seed database. +_META_TOOLS = {"describe_db", "query_db"} + +# All callable tools = meta-tools + any MCP env tools discovered at init time. +# Populated per-instance in init_async(). + + +class TaskGenEnv(BaseTextEnv): + """Environment for RL-based task generation. + + The LLM generates (prompt, verifier) pairs for Fleet environments. + Supports multi-turn: the model can explore the seed DB via ``describe_db`` + and ``query_db`` meta-tools before outputting a ```` block. + + Reward = llm_validity * (alpha * var(raw_scores) + (p_hint - p_raw)) + + Evaluation uses Fleet harness jobs (POST /v1/jobs) to run an LLM agent + against the generated task, rather than a stub evaluator. + + Constructor args (via extras, from dataset): + env_key, env_version, data_key, data_version + env_tools, env_tools_schema, env_variable_keys + + Constructor args (via env_config, from Hydra): + max_turns: Max turns before forced termination (default 10) + judge_model: Model ID for LLM-as-a-judge gate + k_rollouts: Number of rollouts per condition (raw/hinted, default 4) + max_eval_steps: Max agent steps per evaluator rollout (default 30) + evaluator_model: Fleet harness model for task evaluation (default anthropic/claude-sonnet-4.5) + base_quality_reward: Small reward for passing sandbox+judge (default 0.1). + Prevents GRPO zero-signal deadlock when all harness evals fail. + """ + + def __init__( + self, + env_config: DictConfig, + extras: Dict[str, Any] = {}, + ): + super().__init__() + + # Configurable multi-turn (default 10; set to 1 for single-turn) + self.max_turns = int(env_config.get("max_turns", 10)) if env_config else 10 + + # Fleet orchestrator for DB exploration (set in init_async) + self.orch = None + # MCP tools client for calling env tools (set in init_async) + self.mcp_tools = None + # Set of all callable tool names (meta-tools + MCP tools) + self.callable_tools = set(_META_TOOLS) + # Exploration sequence tracking (reset in init_async) + self.called_describe_db = False + self.called_query_db = False + + # Environment context from dataset (extras) + self.env_key = extras.get("env_key", "unknown") + self.env_version = extras.get("env_version", "") + self.data_key = extras.get("data_key", "") + self.data_version = extras.get("data_version", "") + + # Parse env_tools_schema (full tool schemas for prompt building) + env_tools_schema_raw = extras.get("env_tools_schema", "[]") + if isinstance(env_tools_schema_raw, str): + try: + self.env_tools_schema: List[Dict[str, Any]] = json.loads(env_tools_schema_raw) + except json.JSONDecodeError: + self.env_tools_schema: List[Dict[str, Any]] = [] + else: + self.env_tools_schema: List[Dict[str, Any]] = env_tools_schema_raw or [] + + # Parse env_tools (tool name list for sandbox validation) + env_tools_raw = extras.get("env_tools", []) + if isinstance(env_tools_raw, str): + try: + self.env_tools: List[str] = json.loads(env_tools_raw) + except json.JSONDecodeError: + self.env_tools: List[str] = [] + else: + self.env_tools: List[str] = env_tools_raw or [] + + # If env_tools is empty but we have schemas, extract names from schemas + if not self.env_tools and self.env_tools_schema: + self.env_tools = [ + t["function"]["name"] for t in self.env_tools_schema if "function" in t and "name" in t["function"] + ] + + # Parse env_variable_keys (available context variables for this env) + env_var_keys_raw = extras.get("env_variable_keys", "[]") + if isinstance(env_var_keys_raw, str): + try: + self.env_variable_keys: List[str] = json.loads(env_var_keys_raw) + except json.JSONDecodeError: + self.env_variable_keys: List[str] = [] + else: + self.env_variable_keys: List[str] = env_var_keys_raw or [] + + # Parse env_variables (actual values for harness evaluation) + env_vars_raw = extras.get("env_variables", "{}") + if isinstance(env_vars_raw, str): + try: + self.env_variables: Dict[str, Any] = json.loads(env_vars_raw) + except json.JSONDecodeError: + self.env_variables: Dict[str, Any] = {} + else: + self.env_variables: Dict[str, Any] = env_vars_raw or {} + + # Parse env_schema (compact DB schema: table→columns) + self.env_schema: str = extras.get("env_schema", "") or "" + + # Verifier sandbox — filters out CUA-only tool "computer" from available tools + api_tools = set(self.env_tools) - {"computer"} if self.env_tools else None + self.sandbox = VerifierSandbox(available_tools=api_tools if api_tools else None) + + # Judge config (from Hydra env_config) + self.judge_model = str(env_config.get("judge_model", "")) if env_config else "" + + # Evaluator config (from Hydra env_config) + self.k_rollouts = int(env_config.get("k_rollouts", 4)) if env_config else 4 + self.max_eval_steps = int(env_config.get("max_eval_steps", 30)) if env_config else 30 + self.evaluator_model = ( + str(env_config.get("evaluator_model", "anthropic/claude-sonnet-4.5")) + if env_config + else "anthropic/claude-sonnet-4.5" + ) + + # API keys from environment variables (set by SkyPilot YAML) + self.openrouter_api_key = os.environ.get("OPENROUTER_API_KEY", "") + self.fleet_api_key = os.environ.get("FLEET_API_KEY", "") + + # Eval mode: k=8 raw only (no hints); Train mode: k with hints + self.is_eval = extras.get("training_phase") == "eval" + self.eval_k_rollouts = int(env_config.get("eval_k_rollouts", 8)) if env_config else 8 + + # Lazy-init Fleet SDK client for harness evaluation + self._fleet_client = None + + # Rollout dump directory (full prompt/verifier/scores per eval) + self._rollout_dir = os.environ.get("ROLLOUT_DIR", "/workspace/rollouts") + os.makedirs(self._rollout_dir, exist_ok=True) + + # Base quality reward for tasks passing sandbox + judge gate. + # Provides GRPO gradient signal even when all harness evals return 0. + self.base_quality_reward = float(env_config.get("base_quality_reward", 0.1)) if env_config else 0.1 + + logger.info( + f"TaskGenEnv: env={self.env_key}, max_turns={self.max_turns}, " + f"judge={self.judge_model or 'none'}, " + f"tools={len(self.env_tools)}, k={self.k_rollouts}, eval_k={self.eval_k_rollouts}, " + f"evaluator={self.evaluator_model}, is_eval={self.is_eval}, " + f"base_quality={self.base_quality_reward}" + ) + + def _format_tool_schema(self, tool: Dict[str, Any]) -> str: + """Format a single tool schema for the system prompt.""" + func = tool.get("function", {}) + name = func.get("name", "unknown") + desc = func.get("description", "") + params = func.get("parameters", {}) + properties = params.get("properties", {}) + required = set(params.get("required", [])) + + lines = [f"**{name}**: {desc}"] + if properties: + lines.append(" Parameters:") + for pname, pschema in properties.items(): + ptype = pschema.get("type", "any") + pdesc = pschema.get("description", "") + req_marker = " (required)" if pname in required else "" + lines.append(f" - {pname} ({ptype}{req_marker}): {pdesc}") + + return "\n".join(lines) + + def _build_system_prompt(self) -> str: + """Build the system prompt with environment context and priors.""" + parts = [] + + parts.append(f'You are a task designer for the "{self.env_key}" environment.') + + # --- Date context (critical for date-sensitive environments) --- + current_date = self.env_variables.get("CURRENT_DATE", "") + if current_date: + parts.append( + f"\n**IMPORTANT — Current Date: {current_date}**\n" + f"The environment's current date is {current_date}. " + "All dates in generated tasks MUST be on or after this date. " + "Do NOT use past dates — the environment will reject them " + "(e.g., check-in dates, event dates, appointment dates must be in the future)." + ) + + # --- A. Environment context (from tool discovery) --- + parts.append(f"\n## Environment: {self.env_key}") + parts.append("\n### Available Tools") + + # Filter out CUA-only "computer" tool — task-gen is for tool-use APIs + api_schemas = [t for t in self.env_tools_schema if t.get("function", {}).get("name") != "computer"] + api_tool_names = [t for t in self.env_tools if t != "computer"] + + if api_schemas: + # Compact format: name + description only (no parameter schemas) + # Full schemas make the prompt too long for envs with many tools + for tool in api_schemas: + func = tool.get("function", {}) + name = func.get("name", "unknown") + desc = func.get("description", "") + parts.append(f"- **{name}**: {desc}") + elif api_tool_names: + parts.append("\n".join(f"- {t}" for t in api_tool_names)) + else: + parts.append("No tools discovered for this environment.") + + # Environment variables (user context available at task runtime) + if self.env_variables: + parts.append("\n### Environment Variables (embed as constants)") + parts.append( + "These variables describe the user/session context. " + "**Embed them directly as string constants** in your verifier code. " + "Do NOT use `env.env_variables` — it is not available at verifier runtime." + ) + for var_key, var_val in self.env_variables.items(): + parts.append(f'- `{var_key}` = `"{var_val}"`') + parts.append( + "\nExample usage in verifier:\n" + "```python\n" + f'LOGGED_IN_USER = "{self.env_variables.get("LOGGED_IN_USER", "user@example.com")}"\n' + f'# Use as: rows = current.table("users").eq("email", LOGGED_IN_USER).all()\n' + "```" + ) + elif self.env_variable_keys: + parts.append("\n### Environment Variables") + parts.append( + "These variables parameterize each environment instance. " + "Look up values from the database instead of using env.env_variables." + ) + for var_key in self.env_variable_keys: + parts.append(f"- `{var_key}`") + + # Database schema (table names and columns) + if self.env_schema: + parts.append("\n### Database Schema") + parts.append( + "Use these exact table and column names in verifiers " + '(e.g., `current.table("bookings").eq("guest_email", val).all()`):' + ) + parts.append(f"```\n{self.env_schema}\n```") + + # --- B. Priors (concise, static, same for all envs) --- + # Date awareness guidance (prevents past-date failures in booking/ticketmaster) + if current_date: + date_guidance = ( + f"### Date Awareness\n" + f"The environment's current date is **{current_date}**. " + f"ALL dates in your task MUST be on or after {current_date}. " + "Tasks with past dates will always fail because the environment " + "rejects them (e.g., 'checkIn date cannot be in the past'). " + "Use `query_db` to check what date ranges exist in the data, " + "and always generate future dates." + ) + else: + date_guidance = ( + "### Date Awareness\n" + "If the environment works with dates, verify what date ranges " + "are valid before generating tasks. Use `query_db` to check." + ) + + # NOTE: env.env_variables is NOT available at verifier runtime (Fleet harness bug). + # Model is instructed to embed env var values as constants instead. + + parts.append( + f""" +## Verifier Guidelines + +The verifier checks whether the agent completed the task by inspecting database state changes. + +Signature: `def validate_task(env: Environment, final_answer: str | None = None) -> int` + +**IMPORTANT**: The function MUST be named `validate_task` and return `TASK_FAILED_SCORE` (0) or `TASK_SUCCESSFUL_SCORE` (1). + +### Verifier API +```python +env.instance.load() # Load current state (call first) +seed = env.db("seed") # Original DB before agent acted +current = env.db("current") # Current DB after agent acted + +# Query tables — ALL results are Python dicts, use row["column"] NOT row.column: +rows = current.table("table_name").eq("column", value).all() # -> List[dict] +row = current.table("table_name").eq("column", value).first() # -> dict or None +rows = current.table("table_name").neq("column", value).all() # -> List[dict] +count = current.table("table_name").eq("column", value).count() # -> int +rows = current.table("table_name").select("col1", "col2").all() # -> List[dict] +# Access fields: row["id"], row["name"], row["email"] — NEVER row.id or row.name +# Only methods: .table(), .eq(), .neq(), .select(), .all(), .first(), .count() +# NO .like(), .gt(), .lt(), .contains(), .in_() — use Python filtering instead + +# Compare seed vs current to detect NEW entries: +def find_new_entries(seed, current, table_name, id_field="id", filter_conditions=None): + before_query = seed.table(table_name) + after_query = current.table(table_name) + if filter_conditions: + for key, value in filter_conditions.items(): + before_query = before_query.eq(key, value) + after_query = after_query.eq(key, value) + before_ids = {{entry[id_field] for entry in before_query.select(id_field).all()}} + return [e for e in after_query.all() if e[id_field] not in before_ids] +``` + +### Error Tracking (REQUIRED) +Every verifier MUST track errors and successes using accumulator lists, and print them +before returning. This enables automated feedback for hint-based evaluation. + +```python +error_accumulator = [] +success_accumulator = [] + +# ... check conditions ... +if condition_met: + success_accumulator.append("[C] Booking was created") +else: + error_accumulator.append("[X] Expected booking not found") + +# ALWAYS print accumulators before returning: +if error_accumulator: + print(">>> ERROR_ACCUMULATOR >>>") + print(error_accumulator) + print("<<< ERROR_ACCUMULATOR <<<") +if success_accumulator: + print(">>> SUCCESS_ACCUMULATOR >>>") + print(success_accumulator) + print("<<< SUCCESS_ACCUMULATOR <<<") +``` + +### Verifier Template (follow this structure) +```python +def validate_task(env: Environment, final_answer: str | None = None) -> int: + error_accumulator = [] + success_accumulator = [] + env.instance.load() + seed = env.db("seed") + current = env.db("current") + + def find_new_entries(table_name, id_field="id", filter_conditions=None): + before_query = seed.table(table_name) + after_query = current.table(table_name) + if filter_conditions: + for key, value in filter_conditions.items(): + before_query = before_query.eq(key, value) + after_query = after_query.eq(key, value) + before_ids = set(entry[id_field] for entry in before_query.select(id_field).all()) + return [e for e in after_query.all() if e[id_field] not in before_ids] + + # Check conditions... + # On early failure: + if critical_failure: + error_accumulator.append("[X] Critical check failed") + print(">>> ERROR_ACCUMULATOR >>>") + print(error_accumulator) + print("<<< ERROR_ACCUMULATOR <<<") + return TASK_FAILED_SCORE + + # Final result: + if error_accumulator: + print(">>> ERROR_ACCUMULATOR >>>") + print(error_accumulator) + print("<<< ERROR_ACCUMULATOR <<<") + return TASK_FAILED_SCORE + print(">>> SUCCESS_ACCUMULATOR >>>") + print(success_accumulator) + print("<<< SUCCESS_ACCUMULATOR <<<") + return TASK_SUCCESSFUL_SCORE +``` + +### Rules +- **NEVER hardcode database IDs** (user_id, hotel_id, etc.) — always query the DB to find them +- **NEVER use `env.env_variables`** — it is not available at runtime. Embed env var values as string constants at the top of your verifier (e.g., `LOGGED_IN_USER = "riley3318"`) +- **DB rows are dicts** — use `row["id"]`, `row["name"]`, NOT `row.id`, `row.name`. Using dot notation will crash with `AttributeError: 'dict' object has no attribute 'id'` +- **Only use supported query methods**: `.eq()`, `.neq()`, `.select()`, `.all()`, `.first()`, `.count()`. NO `.like()`, `.gt()`, `.lt()`, `.order()`, `.limit()`, `.contains()`, `.in_()` — filter and sort in Python instead (e.g., `sorted([r for r in rows if r["score"] > 8.0], key=lambda r: r["score"], reverse=True)[:5]`) +- **`.eq()` takes exactly 2 args**: `.eq(column, value)`. NO operator arg like `.eq("rating", ">", 8)` — use Python: `[r for r in rows if r["rating"] > 8]` +- **Use timezone-tolerant comparisons** for datetimes — the DB may store `"2025-08-08T14:00:00Z"` while you expect `"2025-08-08T14:00:00"`. Use `.startswith()` or strip the trailing `"Z"` before comparing +- **If you use `.select()`, only access the selected columns** — accessing other columns raises `KeyError`. Prefer `.all()` without `.select()` unless you specifically need to limit columns +- **Define `find_new_entries` inside your verifier function** — it is NOT a built-in. Copy it from the template above into your `validate_task()` function body. Do NOT call `find_new_entries()` without defining it first +- **List comprehensions produce tuples if you use tuple syntax** — `[(a, b) for ...]` creates tuples, not dicts. If you need dict-like access later, keep the original dicts: `[row for row in rows if condition]` +- **NEVER hardcode expected values the agent must create** — e.g., don't check for a specific phone number or email the agent would need to invent. Instead, check that the field was changed from its original value: `current_val != seed_val` +- Look up the logged-in user by name/email from the users table, don't assume an ID +- Compare `seed` (before) vs `current` (after) to detect what the agent did +- Must return `TASK_FAILED_SCORE` on a fresh environment (before agent acts) +- Use `final_answer` for tasks that require the agent to report a value +- Reference actual tool names from this environment + +## Task Design Guidelines + +Design tasks that maximize learnability: an ideal task is one that a capable agent can solve with effort, but not trivially. Tasks that are too easy (always solved) or too hard (never solved) produce no learning signal. + +{date_guidance} + +### Realism +Write prompts as a real user would — natural language, concrete parameters, plausible intent. The task should sound like something a person would actually ask, not a test case. + +BAD: "Call get_user with id=5, then call update_user to set email to test@example.com" +GOOD: "Update the email address for Jamie Chen to jamie.chen@newdomain.com" + +### Avoiding Underspecification +A prompt is underspecified when multiple valid solutions exist but the verifier only accepts one. This creates false negatives — the agent solves the task correctly but gets reward 0. + +BAD prompt: "Find a designer in Mexico" (3 designers exist, verifier checks for one specific one) +FIX option 1: Make the prompt specific: "Find the designer in Mexico City who joined after 2023" +FIX option 2: Make the verifier accept all valid answers: check that ANY designer in Mexico is returned + +Use `describe_db`/`query_db` to check the actual data before writing the prompt. If a query returns multiple rows, either narrow the prompt or widen the verifier. Always verify your assumptions by querying — don't guess. You MUST call all three of `describe_db`, `query_db`, and at least one environment API tool before writing the task — your task will be rejected otherwise. + +### Avoiding Overspecification +A prompt is overspecified when it dictates HOW to accomplish the task rather than WHAT outcome is needed. This makes the task trivially easy (no learning signal) and doesn't test real problem-solving. + +BAD: "First call list_tables, then call get_bookings with check_in_date='2024-03-15', then count the results and call submit_answer with the count" +GOOD: "How many bookings have a check-in date of March 15, 2024?" + +The prompt should specify the desired outcome. The agent should figure out which tools to use and in what order. + +### Complexity +Aim for tasks solvable in 2-8 tool calls. Tasks requiring 1 tool call are too easy (no signal). Tasks requiring 15+ calls are too hard (agent gives up). The sweet spot is 3-6 calls with some reasoning required. + +### Diversity +Vary tasks across multiple dimensions: +- Operations: reads (lookup, search, aggregate) AND writes (create, update, delete) +- Complexity: simple (2-3 tool calls) through moderate (4-8 tool calls with dependencies) +- Reasoning: some tasks need multi-step logic (find X, use X to look up Y, modify Y based on Z) +- Data entities: use different tables, columns, and relationships in the schema + +### Verifier-Prompt Consistency +The verifier must check exactly what the prompt asks — no more, no less. Before writing, verify: +1. Is there exactly one correct outcome for this prompt? (If not, widen the verifier or narrow the prompt) +2. Does the verifier return 0.0 on a fresh environment? (It must — the agent hasn't acted yet) +3. Does the verifier avoid hardcoded values? (Query the DB instead) +4. Could a different valid approach fool the verifier? (If so, fix the verifier to accept it)""" + ) + + # --- C. Exploration tools (multi-turn only) --- + if self.max_turns > 1: + parts.append( + """ +## Exploration Tools + +Before generating a task, explore the environment to understand the actual data and API behavior. + +### Database Tools +{"name": "describe_db", "arguments": {}} +Returns the full schema: table names, columns, types. + +{"name": "query_db", "arguments": {"sql": "SELECT * FROM table_name LIMIT 5"}} +Runs a read-only SQL query against the seed database. + +### Environment Tools +You MUST call at least one of the environment's API tools listed above to understand their input/output formats. + +**REQUIRED before generating a task:** You must call ALL THREE of: (1) `describe_db`, (2) `query_db`, and (3) at least one environment API tool. Your task will be rejected if any are missing. + +{"name": "tool_name", "arguments": {"param": "value"}} +Calls the tool and returns its result. Use this to understand input/output formats. + +### Workflow +1. **Explore**: Call `describe_db` to see all tables and columns. +2. **Inspect data**: Call `query_db` with SELECT queries to inspect real data (values, ranges, row counts, patterns). +3. **Try tools**: Call at least one environment API tool to understand its behavior, input/output formats, and edge cases. +4. **Draft a task idea**: Think about what prompt + verifier you could write based on the data you've seen. +5. **Validate your draft**: Before outputting the task, run `query_db` to verify your assumptions: + - Does the data your prompt references actually exist? (e.g., "Update Jamie's email" — is there a Jamie?) + - Will the verifier return 0.0 on a fresh DB? (Check seed state) + - Are there edge cases? (e.g., multiple matches, null values, empty tables) +6. **Iterate**: If your queries reveal problems (wrong assumptions, ambiguous data, too many/few matches), revise your task idea and verify again. Do NOT output the task until you've confirmed the data supports it. +7. **Output**: Only when confident, output the final task in the format below.""" + ) + + # --- D. Few-shot examples removed --- + # Few-shot examples were removed because they anchored the model to + # generate near-copies of the examples (especially booking/wishlist tasks), + # causing mode collapse and zero reward signal. The verifier template + + # guidelines above provide enough structure for the model to generate + # diverse tasks from the actual DB schema and tools. + + # --- E. Output format --- + parts.append( + """ +## Output Format + +Generate exactly ONE task. Output it in this format: + + + +[Natural language task instruction for the agent. Be specific about what needs to be done.] + + +[Python function: def validate_task(env, final_answer=None) -> int] + +""" + ) + + return "\n".join(parts) + + def _judge_task(self, prompt: str, verifier: str) -> float: + """LLM-as-a-judge gate: returns 0.0 (invalid) or 1.0 (valid). + + Uses a model to check if the generated (prompt, verifier) pair + is valid and coherent. This is the binary gate in the reward formula. + """ + if not self.judge_model or not self.openrouter_api_key: + return 1.0 # No judge configured, pass through + + # Build concise tool list for context + tool_names = [t for t in self.env_tools if t != "computer"] + tools_str = ", ".join(tool_names[:20]) if tool_names else "none discovered" + + judge_prompt = ( + f'Evaluate this task for the "{self.env_key}" environment.\n\n' + f"Available tools: {tools_str}\n\n" + f"Task prompt:\n{prompt}\n\n" + f"Verifier code:\n```python\n{verifier}\n```\n\n" + "A valid task must:\n" + "1. Have a clear, specific prompt describing what an agent should do\n" + "2. Have a verifier that checks the correct outcome via the DB API " + '(env.db("seed"), env.db("current"), .table().eq().all())\n' + "3. The verifier must check what the prompt actually asks\n" + "4. The prompt must not leak the answer or expected values\n" + "5. The verifier must return 0.0 on a fresh env (before agent acts)\n\n" + "Answer with exactly one word: VALID or INVALID" + ) + + try: + import litellm + + response = litellm.completion( + model=f"openrouter/{self.judge_model}", + messages=[{"role": "user", "content": judge_prompt}], + temperature=0, + max_tokens=10, + api_key=self.openrouter_api_key, + ) + answer = response.choices[0].message.content.strip().upper() + is_valid = "VALID" in answer and "INVALID" not in answer + logger.info(f"LLM judge [{self.env_key}]: {answer} -> {'VALID' if is_valid else 'INVALID'}") + return 1.0 if is_valid else 0.0 + except Exception as e: + logger.warning(f"LLM judge failed, defaulting to valid: {e}") + return 1.0 + + @staticmethod + def _build_hint_text( + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], + ) -> str: + """Build hint text from verifier feedback. No LLM call. + + Parses ERROR_ACCUMULATOR / SUCCESS_ACCUMULATOR from verifier stdout + and formats tool errors into structured feedback for hinted rollouts. + """ + parts: List[str] = [] + + if verifier_stdout: + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n<<< SUCCESS_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + if err_match or suc_match: + try: + errors = ast.literal_eval(err_match.group(1)) if err_match else [] + successes = ast.literal_eval(suc_match.group(1)) if suc_match else [] + except Exception: + errors, successes = [], [] + if successes: + parts.append(f"Checks passed ({len(successes)}): " + ", ".join(str(s)[:100] for s in successes[:5])) + if errors: + parts.append(f"Checks failed ({len(errors)}): " + ", ".join(str(e)[:100] for e in errors[:5])) + + if verifier_error: + parts.append(f"Verifier: {verifier_error}") + + if tool_error_messages: + unique = list(dict.fromkeys(tool_error_messages))[:5] + parts.append("Tool errors: " + "; ".join(e[:200] for e in unique)) + + return "\n".join(parts) if parts else "The previous attempt failed. Try a different approach." + + def _get_fleet_client(self): + """Lazy-init Fleet SDK client.""" + if self._fleet_client is None: + from fleet import Fleet + + self._fleet_client = Fleet(api_key=self.fleet_api_key) + return self._fleet_client + + async def _poll_job(self, fleet, job_id: str, poll_interval: int = 10, timeout: int = 600) -> str: + """Poll Fleet job until completion or timeout. + + Returns: + Final job status string. + """ + start = time.time() + while time.time() - start < timeout: + try: + job = fleet.get_job(job_id) + status = job.status + if status in ("completed", "cancelled", "errored"): + return status + except Exception as e: + logger.warning(f"Error polling job {job_id}: {e}") + await asyncio.sleep(poll_interval) + + logger.error(f"Job {job_id} timed out after {timeout}s") + return "timeout" + + def _query_supabase_scores(self, job_id: str) -> Dict[str, float]: + """Query Supabase for session verifier scores as fallback. + + When Fleet backend doesn't populate verifier_execution FK (regression + since 2026-03-23), the score is still available in session metadata. + + Returns: + Dict mapping session_id -> verifier_score. + """ + supabase_url = os.environ.get("SUPABASE_URL", "") + supabase_key = os.environ.get("SUPABASE_KEY", "") + if not supabase_url or not supabase_key: + return {} + try: + import httpx + + resp = httpx.get( + f"{supabase_url}/rest/v1/sessions", + params={"job_id": f"eq.{job_id}", "select": "id,metadata"}, + headers={ + "apikey": supabase_key, + "Authorization": f"Bearer {supabase_key}", + }, + timeout=10, + ) + if resp.status_code != 200: + logger.warning(f"Supabase query failed: {resp.status_code}") + return {} + scores = {} + for row in resp.json(): + meta = row.get("metadata") or {} + sid = row.get("id") + v_score = meta.get("verifier_score") + if sid and v_score is not None: + scores[sid] = float(v_score) + return scores + except Exception as e: + logger.warning(f"Supabase fallback failed: {e}") + return {} + + def _extract_job_results(self, fleet, job_id: str) -> List[Tuple[float, Optional[str], Optional[str]]]: + """Extract (score, verifier_stdout, verifier_error) from completed job sessions. + + Primary path: read from session.verifier_execution (Fleet SDK). + Fallback: query Supabase for metadata.verifier_score when VE is null + (Fleet backend regression since 2026-03-23 stopped populating VE FK). + + Returns: + List of (score, stdout, error) tuples per session. + """ + results: List[Tuple[float, Optional[str], Optional[str]]] = [] + sessions_response = fleet.list_job_sessions(job_id) + + # Check if any session has verifier_execution populated + all_ve_null = all(s.verifier_execution is None for tg in sessions_response.tasks for s in tg.sessions) + + # Fallback: query Supabase only when needed + supabase_scores: Dict[str, float] = {} + if all_ve_null: + supabase_scores = self._query_supabase_scores(job_id) + if supabase_scores: + logger.info(f"[{job_id[:8]}] Using Supabase fallback for {len(supabase_scores)} session scores") + + for task_group in sessions_response.tasks: + for session in task_group.sessions: + score = 0.0 + stdout = None + error = None + if session.verifier_execution: + if session.verifier_execution.score is not None: + score = float(session.verifier_execution.score) + elif session.verifier_execution.success: + score = 1.0 + stdout = getattr(session.verifier_execution, "stdout", None) + # Capture error from verifier crashes — error is nested in result.error + ve_result = getattr(session.verifier_execution, "result", None) + if ve_result: + ve_error = ( + ve_result.get("error") if isinstance(ve_result, dict) else getattr(ve_result, "error", None) + ) + if ve_error: + error = ( + ve_error.get("message", "") + if isinstance(ve_error, dict) + else getattr(ve_error, "message", "") + ) + traceback_str = ( + ve_error.get("traceback", "") + if isinstance(ve_error, dict) + else getattr(ve_error, "traceback", "") + ) + if traceback_str: + # Extract just the last line of traceback (the actual error) + error = traceback_str.strip().split("\n")[-1] if traceback_str else error + elif session.session_id in supabase_scores: + # Fallback: use Supabase metadata.verifier_score + score = supabase_scores[session.session_id] + results.append((score, stdout, error)) + return results + + async def _run_harness_job( + self, prompt: str, verifier: str, k: int + ) -> List[Tuple[float, Optional[str], Optional[str]]]: + """Run a single Fleet harness job and return per-session results + job ID. + + 1. Import task to Fleet + 2. Create harness job with pass_k=k + 3. Poll until completion + 4. Extract results + + Returns: + Tuple of (job_id, results) where results is a list of + (score, verifier_stdout, verifier_error) tuples. + job_id is None on failure. + """ + from fleet.tasks import Task + + fleet = self._get_fleet_client() + task_key = f"taskgen_{uuid.uuid4().hex[:12]}" + + task = Task( + key=task_key, + prompt=prompt, + env_id=self.env_key, + version=self.env_version or None, + verifier_func=verifier, + data_id=self.data_key or None, + data_version=self.data_version or None, + env_variables=self.env_variables or {}, + ) + + import_response = fleet.import_single_task(task) + if import_response is None: + logger.error(f"[{task_key}] Failed to import task to Fleet") + return (None, [(0.0, None, None)] * k) + + job_response = fleet.create_job( + models=[self.evaluator_model], + task_keys=[task_key], + pass_k=k, + max_steps=self.max_eval_steps, + mode="tool-use", + name=f"taskgen-eval-{task_key}", + ) + job_id = job_response.job_id + logger.info(f"[{task_key}] Harness job created: {job_id} (model={self.evaluator_model}, k={k})") + + status = await self._poll_job(fleet, job_id) + if status != "completed": + logger.warning(f"[{task_key}] Job {job_id} ended with status: {status}") + return (job_id, [(0.0, None, None)] * k) + + return (job_id, self._extract_job_results(fleet, job_id)) + + async def _evaluate_task(self, prompt: str, verifier: str) -> Dict[str, float]: + """Run hint-based evaluation via Fleet harness jobs. + + 1. Raw job: k rollouts without hints + 2. Build hint from first failing session's verifier stdout + 3. Hinted job: k rollouts with hint appended to prompt + 4. Compute reward via compute_task_reward() + + Returns: + Reward breakdown dict from compute_task_reward. + """ + from integrations.fleet.task_gen_reward import compute_task_reward + + zero_result = compute_task_reward([], [], validity=1.0) + + if not self.fleet_api_key: + return zero_result + + task_id = f"taskgen_{uuid.uuid4().hex[:8]}" + start = time.time() + + try: + # Eval mode: k=8 raw only (no hints) for pass rate measurement + # Train mode: k raw + k hinted for hint_gap signal + eval_k = self.eval_k_rollouts if self.is_eval else self.k_rollouts + + # 1. Raw job: k rollouts without hints + raw_job_id, raw_results = await self._run_harness_job(prompt, verifier, k=eval_k) + raw_scores = [r[0] for r in raw_results] + + if self.is_eval: + # Eval: no hints, reward = alpha * var_raw (hint_gap=0) + hinted_scores = [] + hinted_job_id = None + hint_text = "" + result = compute_task_reward(raw_scores, raw_scores, validity=1.0) + else: + # 2. Build hint from first failing session's stdout/error + hint_stdout = None + hint_error = None + for score, stdout, error in raw_results: + if score < 1.0: + if stdout: + hint_stdout = stdout + if error: + hint_error = error + if hint_stdout or hint_error: + break + hint_text = self._build_hint_text(hint_stdout, hint_error, None) + + # Fallback: if hint is generic (no VE stdout due to backend regression), + # use the verifier source code as the hint. This tells the hinted agent + # exactly what checks to satisfy, creating hint_gap signal. + if hint_text == "The previous attempt failed. Try a different approach.": + # Truncate verifier to avoid blowing up prompt length + verifier_hint = verifier[:2000] + hint_text = ( + "Here is the verification function that will be used to check your work. " + "Make sure your actions satisfy all the checks:\n\n" + f"```python\n{verifier_hint}\n```" + ) + + # 3. Hinted job: k rollouts with hint + hinted_prompt = f"{prompt}\n\nHere is feedback from a previous attempt to help you:\n{hint_text}" + hinted_job_id, hinted_results = await self._run_harness_job(hinted_prompt, verifier, k=self.k_rollouts) + hinted_scores = [r[0] for r in hinted_results] + + # 4. Compute reward + result = compute_task_reward(raw_scores, hinted_scores, validity=1.0) + + duration = time.time() - start + + # --- Iron-clad eval logging --- + # Truncate prompt/verifier for log readability + prompt_log = prompt[:300].replace("\n", " ") + verifier_log = verifier[:200].replace("\n", " ") + hint_log = hint_text[:200].replace("\n", " ") + logger.info( + f"[{task_id}] EVAL | " + f"raw_job={raw_job_id} hinted_job={hinted_job_id} | " + f"raw={raw_scores} hinted={hinted_scores} | " + f"var={result['var_raw']:.4f} gap={result['hint_gap']:.4f} total={result['total']:.4f} | " + f"time={duration:.0f}s | " + f"prompt={prompt_log} | " + f"verifier={verifier_log} | " + f"hint={hint_log}" + ) + + # Save full rollout to local JSONL + self._save_rollout( + task_id=task_id, + env_key=self.env_key, + data_key=self.data_key, + prompt=prompt, + verifier=verifier, + hint=hint_text, + raw_scores=raw_scores, + hinted_scores=hinted_scores, + raw_job_id=raw_job_id, + hinted_job_id=hinted_job_id, + result=result, + duration=duration, + ) + + return result + + except Exception as e: + logger.error(f"[{task_id}] Evaluation failed: {e}") + return zero_result + + def _save_rollout( + self, + task_id, + env_key, + data_key, + prompt, + verifier, + hint, + raw_scores, + hinted_scores, + raw_job_id, + hinted_job_id, + result, + duration, + ): + """Append full rollout data to a local JSONL file.""" + try: + run_name = os.environ.get("RUN_NAME", "unknown") + path = os.path.join(self._rollout_dir, f"{run_name}.jsonl") + record = { + "task_id": task_id, + "env_key": env_key, + "data_key": data_key, + "prompt": prompt, + "verifier": verifier, + "hint": hint, + "raw_scores": raw_scores, + "hinted_scores": hinted_scores, + "raw_job_id": raw_job_id, + "hinted_job_id": hinted_job_id, + "var_raw": result["var_raw"], + "hint_gap": result["hint_gap"], + "total": result["total"], + "duration": duration, + "timestamp": time.time(), + } + with open(path, "a") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + except Exception as e: + logger.warning(f"[{task_id}] Failed to save rollout: {e}") + + async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: + """Evaluate a generated task through the full pipeline. + + Pipeline: + 1. Parse output -> fail = reward 0 + 2. Sandbox validation -> fail = reward 0 + 3. LLM-as-a-judge -> gate (0/1), fail = reward 0 + 4. Hint-based evaluation via Fleet harness (k raw + k hinted rollouts) + 5. R = base_quality + judge_gate * compute_task_reward(raw, hinted) + + base_quality (default 0.1) rewards structural validity (sandbox+judge pass), + providing GRPO gradient signal even when harness evals return all zeros. + """ + metadata: Dict[str, Any] = {"env_key": self.env_key, "turn": self.turns} + + # 1. Parse + parsed = parse_task_output(action) + if parsed is None: + metadata["error"] = "parse_failed" + metadata["reward_breakdown"] = {"total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + prompt = parsed["prompt"] + verifier = parsed["verifier"] + metadata["generated_prompt"] = prompt + metadata["generated_verifier"] = verifier + + # 2. Sandbox validation + validation = self.sandbox.validate(verifier, prompt) + metadata["validation"] = { + "valid": validation.valid, + "passed": validation.checks_passed, + "failed": validation.checks_failed, + "error": validation.error, + } + if not validation.valid: + metadata["reward_breakdown"] = {"sandbox": 0.0, "total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + # 3. LLM-as-a-judge gate + judge_gate = self._judge_task(prompt, verifier) + metadata["judge_gate"] = judge_gate + + if judge_gate == 0.0: + metadata["reward_breakdown"] = {"sandbox": 1.0, "judge": 0.0, "total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + # 4. Hint-based evaluation via Fleet harness + eval_result = await self._evaluate_task(prompt, verifier) + + # 5. R = base_quality + eval_signal + # base_quality: small reward for passing sandbox+judge (structural validity) + # eval_signal: judge_gate * compute_task_reward (harness-based quality) + # This prevents GRPO zero-signal deadlock when all harness evals fail. + base_quality = self.base_quality_reward + eval_signal = judge_gate * eval_result["total"] + reward = base_quality + eval_signal + + metadata["reward_breakdown"] = { + "sandbox": 1.0, + "judge": judge_gate, + "base_quality": base_quality, + "eval_signal": eval_signal, + **eval_result, + "total": reward, + } + + return BaseTextEnvStepOutput(observations=[], reward=reward, done=True, metadata=metadata) + + def step(self, action: str) -> BaseTextEnvStepOutput: + """Sync wrapper for step_async.""" + return asyncio.run(self.step_async(action)) + + async def step_async(self, action: str) -> BaseTextEnvStepOutput: + """Execute one step — tool call, task generation, or nudge. + + Multi-turn flow: + 1. block detected → evaluation pipeline (done=True) + 2. detected → execute describe_db/query_db (done=False) + 3. Neither → nudge observation (done=False) + 4. max_turns reached → done=True, reward=0 + """ + self.turns += 1 + max_turns_reached = self.turns >= self.max_turns + + # 1. Check for block → evaluation pipeline + if "" in action: + # Gate: require describe_db + query_db + at least one env tool call + # before generating a task (unless single-turn or out of turns) + if self.max_turns > 1 and not max_turns_reached: + missing = [] + if not self.called_describe_db: + missing.append("`describe_db` (to see the schema)") + if not self.called_query_db: + missing.append("`query_db` (to inspect actual data)") + if self.mcp_tool_calls < 1: + missing.append("at least one environment API tool (to understand input/output formats)") + if missing: + observation = { + "role": "user", + "content": ( + "You must explore the environment before generating a task. " + "You still need to call: " + + "; ".join(missing) + + ". NEVER hardcode database IDs — always query to find them first." + ), + } + return BaseTextEnvStepOutput( + observations=[observation], + reward=0.0, + done=False, + metadata={ + "env_key": self.env_key, + "turn": self.turns, + "rejected": "no_exploration", + }, + ) + return await self._handle_task_generation(action) + + # 2. Check for tool call → execute via Fleet orchestrator or MCP + # Enforce exploration sequence: describe_db → query_db → env tool + tool_call = parse_tool_call(action) + if tool_call and tool_call["name"] in self.callable_tools: + if self.max_turns > 1 and not max_turns_reached: + name = tool_call["name"] + if name == "query_db" and not self.called_describe_db: + return BaseTextEnvStepOutput( + observations=[ + { + "role": "user", + "content": "Call `describe_db` first to see the schema before querying data.", + } + ], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns, "rejected": "sequence_violation"}, + ) + if name not in _META_TOOLS and not self.called_query_db: + return BaseTextEnvStepOutput( + observations=[ + { + "role": "user", + "content": ( + "Call `describe_db` and `query_db` first to understand the schema and data " + "before calling environment tools." + ), + } + ], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns, "rejected": "sequence_violation"}, + ) + + if tool_call["name"] in _META_TOOLS: + self.meta_tool_calls += 1 + if tool_call["name"] == "describe_db": + self.called_describe_db = True + elif tool_call["name"] == "query_db": + self.called_query_db = True + obs_content = await self._execute_meta_tool(tool_call) + else: + self.mcp_tool_calls += 1 + obs_content = await self._execute_mcp_tool(tool_call) + + if max_turns_reached: + return BaseTextEnvStepOutput( + observations=[], + reward=0.0, + done=True, + metadata={"env_key": self.env_key, "turn": self.turns, "done_reason": "max_turns"}, + ) + + observation = {"role": "user", "content": obs_content} + return BaseTextEnvStepOutput( + observations=[observation], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns, "tool_call": tool_call}, + ) + + # 3. Neither task nor tool call → nudge + if max_turns_reached: + return BaseTextEnvStepOutput( + observations=[], + reward=0.0, + done=True, + metadata={ + "env_key": self.env_key, + "turn": self.turns, + "done_reason": "max_turns", + }, + ) + + remaining = self.max_turns - self.turns + if self.max_turns == 1: + nudge = "No block found. Output your task in ... format." + elif remaining <= 2: + nudge = ( + f"You have {remaining} turn(s) left. Output your block NOW or you will " + "get reward 0. Stop exploring and generate the task." + ) + else: + nudge = "Use to explore the database or call environment tools, then generate a block." + observation = {"role": "user", "content": nudge} + return BaseTextEnvStepOutput( + observations=[observation], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns}, + ) + + async def _execute_meta_tool(self, tool_call: Dict[str, Any]) -> str: + """Execute a describe_db or query_db meta-tool call via the Fleet orchestrator.""" + name = tool_call["name"] + args = tool_call.get("arguments", {}) + + if self.orch is None: + return "Error: Fleet environment not provisioned. Generate a directly." + + try: + if name == "describe_db": + result = await self.orch.describe_db_async(db_name=args.get("db_name", "seed")) + elif name == "query_db": + sql = args.get("sql", "") + if not sql: + return "Error: query_db requires a 'sql' argument." + result = await self.orch.query_db_async(sql=sql, db_name=args.get("db_name", "seed")) + else: + return f"Error: Unknown meta-tool '{name}'." + + if isinstance(result, dict): + return f"Tool result:\n{json.dumps(result, indent=2, default=str)}" + return f"Tool result:\n{result}" + except Exception as e: + return f"Error: {e}" + + async def _execute_mcp_tool(self, tool_call: Dict[str, Any]) -> str: + """Execute an MCP tool call via FleetMCPTools.""" + name = tool_call["name"] + args = tool_call.get("arguments", {}) + + if self.mcp_tools is None: + return "Error: MCP tools not available. Use describe_db/query_db or generate a ." + + try: + result = await self.mcp_tools.call_tool(name, args) + if isinstance(result, dict): + return f"Tool result:\n{json.dumps(result, indent=2, default=str)}" + return f"Tool result:\n{result}" + except Exception as e: + return f"Error calling {name}: {e}" + + async def init_async(self, prompt: ConversationType) -> Tuple[ConversationType, Dict[str, Any]]: + """Initialize the environment, optionally provisioning a Fleet env for DB exploration. + + When ``max_turns > 1``, provisions a Fleet environment via + ``FleetEnvClient.from_fleet_async`` so the model can call + ``describe_db`` / ``query_db`` during exploration turns. + Falls back to single-turn if provisioning fails. + """ + self.turns = 0 + self.meta_tool_calls = 0 + self.mcp_tool_calls = 0 + self.called_describe_db = False + self.called_query_db = False + self.orch = None + self.mcp_tools = None + self.callable_tools = set(_META_TOOLS) + + # Provision Fleet env for multi-turn exploration (DB + MCP tools) + if self.max_turns > 1 and self.fleet_api_key and self.data_key: + try: + from envs.fleet_env import FleetEnvClient + + self.orch, self.mcp_tools = await FleetEnvClient.from_fleet_async( + api_key=self.fleet_api_key, + env_key=self.env_key, + data_key=self.data_key, + data_version=self.data_version, + image_type="standard", + ttl_seconds=900, + ) + # Load instance resources so db("seed") works + # instance.load() is async — must await directly, not via to_thread + await self.orch._fleet_env.instance.load() + logger.info(f"TaskGenEnv [{self.env_key}]: Fleet env provisioned for DB + tool exploration") + + # Discover MCP tools so the model can call them + if self.mcp_tools: + try: + tools_action = await self.mcp_tools.list_tools() + mcp_tools = [ + t for t in tools_action.tools if "function" in t and t["function"].get("name") != "computer" + ] + mcp_tool_names = {t["function"]["name"] for t in mcp_tools} + self.callable_tools = set(_META_TOOLS) | mcp_tool_names + # Update tool schemas for system prompt if dataset didn't have them + if not self.env_tools_schema: + self.env_tools_schema = mcp_tools + self.env_tools = [t["function"]["name"] for t in mcp_tools] + logger.info(f"TaskGenEnv [{self.env_key}]: {len(mcp_tool_names)} MCP tools available") + except Exception as e: + logger.warning(f"TaskGenEnv [{self.env_key}]: Failed to list MCP tools: {e}") + except Exception as e: + logger.warning( + f"TaskGenEnv [{self.env_key}]: Fleet provisioning failed, " f"falling back to single-turn: {e}" + ) + self.max_turns = 1 + + system_prompt = self._build_system_prompt() + + user_content = ( + f"Generate a task for the {self.env_key} environment. " + "First explore the database to understand the data, then draft a prompt and verifier. " + "Before outputting, query the DB to verify your assumptions are correct — " + "iterate on your draft until you're confident the data supports it." + if self.max_turns > 1 + else f"Generate a task for the {self.env_key} environment." + ) + + conversation = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + metadata = { + "env_key": self.env_key, + "env_version": self.env_version, + "num_tools": len(self.env_tools), + "multi_turn": self.max_turns > 1, + } + + return conversation, metadata + + def init(self, prompt: ConversationType) -> Tuple[ConversationType, Dict[str, Any]]: + """Sync wrapper for init_async.""" + return asyncio.run(self.init_async(prompt)) + + def close(self): + """Close the Fleet orchestrator if provisioned.""" + if self.orch is not None: + try: + self.orch.close() + except Exception: + pass + self.orch = None + + async def close_async(self): + """Async close — release Fleet orchestrator resources.""" + if self.orch is not None: + try: + await self.orch.close_async() + except Exception: + pass + self.orch = None + + def get_metrics(self) -> Dict[str, Any]: + """Return per-episode metrics.""" + return { + "env_key": self.env_key, + "turns": self.turns, + } + + @staticmethod + def aggregate_metrics(metrics: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate metrics across episodes.""" + if not metrics: + return {} + + # Group by env_key + env_counts: Dict[str, int] = {} + for m in metrics: + env_key = m.get("env_key", "unknown") + env_counts[env_key] = env_counts.get(env_key, 0) + 1 + + result = {"total_episodes": len(metrics)} + for env_key, count in env_counts.items(): + result[f"{env_key}/episodes"] = count + + return result diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py b/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py new file mode 100644 index 0000000000..f328507d18 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py @@ -0,0 +1,67 @@ +""" +Tool call parser for task generation environment. + +Parses and tagged JSON from LLM responses. +Copied from skyrl-train/integrations/fleet/env.py to avoid cross-package imports. +""" + +import json +import re +from typing import Any, Dict, Optional + + +def _try_parse_json(raw: str) -> Optional[Dict[str, Any]]: + """Try to parse JSON, repairing missing trailing braces if needed.""" + raw = raw.strip() + try: + parsed = json.loads(raw) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + pass + + # Repair: models often drop trailing closing braces on nested JSON. + # Try appending up to 3 closing braces. + for extra in range(1, 4): + try: + parsed = json.loads(raw + "}" * extra) + if isinstance(parsed, dict): + return parsed + except (json.JSONDecodeError, ValueError): + continue + + return None + + +def parse_tool_call(action: str) -> Optional[Dict[str, Any]]: + """ + Parse tool call from LLM response. + + Supports tag-based formats: + - {"name": "...", "arguments": {...}} + - {"name": "...", "arguments": {...}} + + Also handles cases where the closing tag is missing (e.g., when + is used as the stop string and not included in the output). + + Returns dict with "name" and "arguments" keys, or None if not found. + """ + # Try common tag formats + for tag in ["tool_call", "function_call"]: + # First try with closing tag + match = re.search(rf"<{tag}>(.*?)", action, re.DOTALL) + if not match: + # Try without closing tag (for when is the stop string) + # Match from opening tag to end of string or next special token + match = re.search(rf"<{tag}>(.*?)(?:<\||\Z)", action, re.DOTALL) + if match: + parsed = _try_parse_json(match.group(1)) + if parsed is None: + continue + # Normalize keys + name = parsed.get("name") or parsed.get("tool") + args = parsed.get("arguments") or parsed.get("params", {}) + if name: + return {"name": name, "arguments": args} + + return None diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py new file mode 100644 index 0000000000..5da7cd3093 --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py @@ -0,0 +1,300 @@ +""" +Verifier sandbox for task generation. + +Validates generated verifier code via AST analysis and safe execution checks. +Used as the validity gate in the task generation reward: + R(task) = validity * (variance + alpha * separation) + +If validity returns 0, the entire reward is zeroed out. +""" + +import ast +import re +from dataclasses import dataclass, field +from typing import List, Optional, Set + + +@dataclass +class ValidationResult: + """Result of verifier validation.""" + + valid: bool + checks_passed: List[str] = field(default_factory=list) + checks_failed: List[str] = field(default_factory=list) + error: Optional[str] = None + + @property + def score(self) -> float: + """Return 1.0 if valid, 0.0 otherwise (multiplicative gate).""" + return 1.0 if self.valid else 0.0 + + +# Disallowed modules/builtins in verifier code +BLOCKED_IMPORTS = { + "os", + "sys", + "subprocess", + "shutil", + "pathlib", + "socket", + "http", + "urllib", + "requests", + "importlib", + "ctypes", + "signal", + "multiprocessing", + "threading", + "pickle", + "shelve", + "tempfile", + "glob", + "io", +} + +BLOCKED_BUILTINS = { + "exec", + "eval", + "compile", + "__import__", + "open", + "input", + "breakpoint", + "exit", + "quit", +} + +# Min/max AST node count for verifier complexity +MIN_AST_NODES = 5 # reject trivial verifiers like `return 1.0` +MAX_AST_NODES = 500 # reject overly complex verifiers + + +class VerifierSandbox: + """Validates and sandboxes generated verifier code. + + Performs static analysis to catch common issues before any execution: + 1. Python syntax validity (AST parsing) + 2. Function signature check (must be `async def verify(env, ...)`) + 3. Complexity bounds (not trivial, not overly complex) + 4. No dangerous imports or builtins + 5. Must reference env parameter (actually uses the environment) + 6. Prompt-verifier alignment (optional, LLM-based) + """ + + def __init__(self, available_tools: Optional[Set[str]] = None): + """ + Args: + available_tools: Set of tool names available in the target environment. + If provided, checks that verifier references at least one real tool. + """ + self.available_tools = available_tools or set() + + def validate( + self, + verifier_code: str, + prompt: Optional[str] = None, + ) -> ValidationResult: + """Run all validation checks on verifier code. + + Args: + verifier_code: The generated verifier Python code. + prompt: The associated task prompt (for alignment checks). + + Returns: + ValidationResult with pass/fail and details. + """ + result = ValidationResult(valid=True) + + # 1. Parse as valid Python + tree = self._check_syntax(verifier_code, result) + if tree is None: + result.valid = False + return result + + # 2. Check function signature + self._check_signature(tree, result) + + # 3. Check complexity bounds + self._check_complexity(tree, result) + + # 4. Check for dangerous imports/builtins + self._check_safety(tree, result) + + # 5. Check env usage + self._check_env_usage(tree, result) + + # 6. Check for hardcoded return values + self._check_hardcoded_returns(tree, result) + + # 7. Check prompt length bounds (if prompt provided) + if prompt is not None: + self._check_prompt_bounds(prompt, result) + + # Any failed check -> invalid + if result.checks_failed: + result.valid = False + + return result + + def _check_syntax(self, code: str, result: ValidationResult) -> Optional[ast.AST]: + """Check that verifier code is valid Python.""" + try: + tree = ast.parse(code) + result.checks_passed.append("syntax") + return tree + except SyntaxError as e: + result.checks_failed.append("syntax") + result.error = f"SyntaxError: {e}" + return None + + def _check_signature(self, tree: ast.AST, result: ValidationResult): + """Check that verifier defines a valid function with env parameter. + + Accepts both `verify(env, ...)` and `validate_task(env, ...)` names, + both sync and async. + """ + valid_names = {"verify", "validate_task"} + for node in ast.walk(tree): + if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef)): + if node.name in valid_names: + args = node.args + arg_names = [a.arg for a in args.args] + if "env" in arg_names: + result.checks_passed.append("signature") + return + else: + result.checks_failed.append("signature") + result.error = f"{node.name}() must have 'env' parameter, got: {arg_names}" + return + + result.checks_failed.append("signature") + result.error = "No verify(env, ...) or validate_task(env, ...) function found" + + def _check_complexity(self, tree: ast.AST, result: ValidationResult): + """Check AST node count is within bounds.""" + node_count = sum(1 for _ in ast.walk(tree)) + + if node_count < MIN_AST_NODES: + result.checks_failed.append("complexity_min") + result.error = f"Verifier too simple ({node_count} nodes < {MIN_AST_NODES})" + elif node_count > MAX_AST_NODES: + result.checks_failed.append("complexity_max") + result.error = f"Verifier too complex ({node_count} nodes > {MAX_AST_NODES})" + else: + result.checks_passed.append("complexity") + + def _check_safety(self, tree: ast.AST, result: ValidationResult): + """Check for dangerous imports and builtin calls.""" + for node in ast.walk(tree): + # Check imports + if isinstance(node, ast.Import): + for alias in node.names: + module = alias.name.split(".")[0] + if module in BLOCKED_IMPORTS: + result.checks_failed.append("safety_import") + result.error = f"Blocked import: {alias.name}" + return + + elif isinstance(node, ast.ImportFrom): + if node.module: + module = node.module.split(".")[0] + if module in BLOCKED_IMPORTS: + result.checks_failed.append("safety_import") + result.error = f"Blocked import from: {node.module}" + return + + # Check dangerous builtin calls + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name): + if node.func.id in BLOCKED_BUILTINS: + result.checks_failed.append("safety_builtin") + result.error = f"Blocked builtin call: {node.func.id}" + return + + result.checks_passed.append("safety") + + def _check_env_usage(self, tree: ast.AST, result: ValidationResult): + """Check that the verifier actually uses the env parameter.""" + # Look for attribute access on 'env' (e.g., env.list_issues, env.get_data) + # or 'env' passed as argument to await expressions + env_used = False + for node in ast.walk(tree): + if isinstance(node, ast.Attribute): + if isinstance(node.value, ast.Name) and node.value.id == "env": + env_used = True + break + elif isinstance(node, ast.Call): + if isinstance(node.func, ast.Name) and node.func.id == "env": + env_used = True + break + + if env_used: + result.checks_passed.append("env_usage") + else: + result.checks_failed.append("env_usage") + result.error = "Verifier does not use 'env' parameter" + + def _check_hardcoded_returns(self, tree: ast.AST, result: ValidationResult): + """Check that verifier isn't just `return 1.0` or `return 0.0`.""" + valid_names = {"verify", "validate_task"} + verify_func = None + for node in ast.walk(tree): + if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef)): + if node.name in valid_names: + verify_func = node + break + + if verify_func is None: + return # Already caught by signature check + + # Check if all return statements are constant + returns = [n for n in ast.walk(verify_func) if isinstance(n, ast.Return)] + if not returns: + result.checks_failed.append("hardcoded_return") + result.error = "Verifier has no return statements" + return + + all_constant = all(isinstance(r.value, ast.Constant) for r in returns if r.value is not None) + + if all_constant and len(returns) == 1: + result.checks_failed.append("hardcoded_return") + result.error = "Verifier always returns a constant value" + else: + result.checks_passed.append("return_logic") + + def _check_prompt_bounds(self, prompt: str, result: ValidationResult): + """Check that prompt is within reasonable length bounds.""" + word_count = len(prompt.split()) + + if word_count < 5: + result.checks_failed.append("prompt_length") + result.error = f"Prompt too short ({word_count} words < 5)" + elif word_count > 500: + result.checks_failed.append("prompt_length") + result.error = f"Prompt too long ({word_count} words > 500)" + else: + result.checks_passed.append("prompt_length") + + +def parse_task_output(action: str) -> Optional[dict]: + """Parse LLM output to extract task prompt and verifier code. + + Expected format: + + ... + ... + + + Returns: + Dict with 'prompt' and 'verifier' keys, or None if parsing fails. + """ + prompt_match = re.search(r"(.*?)", action, re.DOTALL) + verifier_match = re.search(r"(.*?)", action, re.DOTALL) + + if not prompt_match or not verifier_match: + return None + + return { + "prompt": prompt_match.group(1).strip(), + "verifier": verifier_match.group(1).strip(), + } From 91776e10fa6edeaa0ef4ca5cb6260d0f4334be3b Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 14:57:53 -0700 Subject: [PATCH 004/121] Add hint augmentation support for Fleet task training When all raw rollout samples for a prompt score 0, hint augmentation generates additional rollouts with verifier feedback injected into the prompt. This rescues GRPO signal for otherwise dead prompts. Key components: - _run_hint_augmentation() in SkyRLGymGenerator: groups outputs by instance_id, identifies failing prompts, builds hint text from verifier ERROR/SUCCESS_ACCUMULATOR, launches hinted rollouts - RLTF-SD: replaces hinted prompt_ids with original unhinted prompt_ids so the model learns to produce hint-quality outputs from the original prompt alone (grad log pi(y_hint | x_0) not grad log pi(y_hint | x_0 + hint)) - First-turn baseline in compute_grpo_outcome_advantage: when is_hinted is present, computes group mean/std from raw samples only, preventing hinted samples from contaminating the GRPO baseline - Metrics: hint/total_hinted_rollouts, hint/hint_success_rate, hint/prompts_hinted, hint/signal_rescued Config: enable_hints, hint_reward_threshold, n_hint_samples in fleet_task section of skyrl_gym_config. Only runs during training (not eval), only for non-step-wise trajectories, and only when fleet_task.enable_hints=true. Depends on PR #1 (fleet/task-env) for FleetTaskEnv.build_hint_text(). Co-Authored-By: Claude Opus 4.6 --- skyrl/backends/skyrl_train/utils/ppo_utils.py | 58 ++++-- .../config/skyrl_gym_config/default.yaml | 10 + skyrl/train/generators/base.py | 2 + skyrl/train/generators/skyrl_gym_generator.py | 193 ++++++++++++++++++ skyrl/train/trainer.py | 14 +- 5 files changed, 264 insertions(+), 13 deletions(-) diff --git a/skyrl/backends/skyrl_train/utils/ppo_utils.py b/skyrl/backends/skyrl_train/utils/ppo_utils.py index 189fecccc0..fe48a8dded 100644 --- a/skyrl/backends/skyrl_train/utils/ppo_utils.py +++ b/skyrl/backends/skyrl_train/utils/ppo_utils.py @@ -1192,17 +1192,24 @@ def compute_grpo_outcome_advantage( index: np.ndarray, epsilon: float = 1e-6, grpo_norm_by_std: bool = True, + is_hinted: Optional[np.ndarray] = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response). + When ``is_hinted`` is provided, uses a **first-turn baseline**: the group mean and std + are computed from raw (unhinted) samples only. All samples (raw + hinted) are then + centered using this raw-only baseline. This prevents hinted samples from contaminating + the baseline for raw samples (RLTF-SD paper, Section 3.2). + Expects: - token_level_rewards: Float[torch.Tensor, "batch_size seqlen"] - response_mask: Float[torch.Tensor, "batch_size seqlen"] - index: np.ndarray (batch_size) - epsilon: float - grpo_norm_by_std: bool + - is_hinted: Optional[np.ndarray] bool array (batch_size), True for hinted samples Returns: - advantages: Float[torch.Tensor, "batch_size seqlen"] @@ -1211,23 +1218,50 @@ def compute_grpo_outcome_advantage( # this assumes response-level rewards scores = token_level_rewards.sum(dim=-1) - id2score = defaultdict(list) id2mean = {} id2std = {} + use_first_turn_baseline = is_hinted is not None and np.any(is_hinted) + with torch.no_grad(): bsz = scores.shape[0] - for i in range(bsz): - id2score[index[i]].append(scores[i]) - for idx in id2score: - if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) - elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) - else: - raise ValueError(f"no score in prompt index: {idx}") + + if use_first_turn_baseline: + # First-turn baseline: compute mean/std from raw (unhinted) samples only + id2raw_scores = defaultdict(list) + for i in range(bsz): + if not is_hinted[i]: + id2raw_scores[index[i]].append(scores[i]) + + for idx in id2raw_scores: + raw = id2raw_scores[idx] + if len(raw) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + else: + id2mean[idx] = torch.mean(torch.tensor(raw)) + id2std[idx] = torch.std(torch.tensor([raw])) + + # For groups with only hinted samples (no raw), use 0 mean / 1 std + for i in range(bsz): + if index[i] not in id2mean: + id2mean[index[i]] = torch.tensor(0.0) + id2std[index[i]] = torch.tensor(1.0) + else: + # Standard GRPO: compute mean/std from all samples + id2score = defaultdict(list) + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): if grpo_norm_by_std: scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) diff --git a/skyrl/train/config/skyrl_gym_config/default.yaml b/skyrl/train/config/skyrl_gym_config/default.yaml index a94985f1e5..e15732ae5d 100644 --- a/skyrl/train/config/skyrl_gym_config/default.yaml +++ b/skyrl/train/config/skyrl_gym_config/default.yaml @@ -14,3 +14,13 @@ search: search_url: "http://127.0.0.1:8000/retrieve" topk: 3 timeout: 30 + +fleet_task: + tasks_file: null + api_key: null + ttl_seconds: 900 + partial_reward: false + enable_hints: false + hint_reward_threshold: 0.0 + n_hint_samples: 2 + enable_context_tools: false diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index c2456b974f..da04c0ffdd 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -43,6 +43,8 @@ class GeneratorOutput(TypedDict): rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] # Applicable only for step-wise training is_last_step: Optional[List[bool]] + # Hint augmentation: True for samples generated with hint feedback + is_hinted: Optional[List[bool]] class MetricsOutput(TypedDict): diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 446ad0e572..2592e81b24 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -7,6 +7,7 @@ import asyncio import copy +from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -629,6 +630,127 @@ def _update_chat_history( chat_history += new_obs return chat_history + async def _run_hint_augmentation( + self, + all_outputs: List[TrajectoryOutput], + prompts: List[ConversationType], + env_classes: List[str], + env_extras: List[Dict[str, Any]], + trajectory_ids: List[TrajectoryID], + max_tokens: int, + max_input_length: int, + sampling_params: Optional[Dict[str, Any]], + hint_cfg, + ) -> Tuple[List[TrajectoryOutput], List[TrajectoryID], List[str]]: + """Run hinted rollouts for prompts where all raw samples failed. + + Groups raw outputs by instance_id, identifies groups where max_reward < threshold, + builds hint text from verifier feedback, and launches additional hinted rollouts. + + Uses RLTF-SD: hinted rollout prompt_ids are replaced with the original unhinted + prompt_ids so the model learns to produce hint-quality outputs conditioned on the + original prompt alone: grad log pi(y_hint | x_0) instead of grad log pi(y_hint | x_0 + hint). + + Returns: + Tuple of (hinted_outputs, hinted_trajectory_ids, hinted_env_classes) + """ + from skyrl_gym.envs.fleet_task.env import FleetTaskEnv + + # 1. Group outputs by instance_id + groups: Dict[str, List[Tuple[int, TrajectoryOutput]]] = defaultdict(list) + for i, output in enumerate(all_outputs): + iid = trajectory_ids[i].instance_id + groups[iid].append((i, output)) + + # 2. Identify prompts needing hints + hint_tasks = [] + hint_tids = [] + hint_envs = [] + orig_prompt_ids = [] # unhinted prompt_ids for RLTF-SD + prompts_hinted = 0 + hint_reward_threshold = hint_cfg.get("hint_reward_threshold", 0.0) if hasattr(hint_cfg, "get") else 0.0 + + for iid, items in groups.items(): + rewards = [] + for _, output in items: + r = output.reward + rewards.append(r if isinstance(r, (int, float)) else sum(r)) + max_reward = max(rewards) + + if max_reward > hint_reward_threshold: + continue # at least one raw sample has signal + + # Find best raw rollout (highest partial reward) for feedback + best_idx = max(range(len(items)), key=lambda j: rewards[j]) + best_orig_idx, best_output = items[best_idx] + metrics = best_output.env_metrics + + # Build hint from verifier feedback + hint_text = FleetTaskEnv.build_hint_text( + verifier_stdout=metrics.get("verifier_stdout"), + verifier_error=metrics.get("verifier_error"), + tool_error_messages=metrics.get("tool_error_messages"), + ) + if not hint_text: + continue + + logger.info( + f"Hint for instance {iid} (best_reward={rewards[best_idx]:.3f}, " + f"verifier_stdout={bool(metrics.get('verifier_stdout'))}, " + f"verifier_error={bool(metrics.get('verifier_error'))}):\n{hint_text}" + ) + prompts_hinted += 1 + + # Create hinted agent_loop tasks (new env instances) + base_rep_id = max(item[0] for item in items) + 1 + n_hint = hint_cfg.get("n_hint_samples", 2) if hasattr(hint_cfg, "get") else 2 + for h in range(n_hint): + hinted_extras = dict(env_extras[best_orig_idx]) + hinted_extras["hint"] = hint_text + hinted_extras["is_hinted"] = True + tid = TrajectoryID(instance_id=iid, repetition_id=base_rep_id + h) + hint_tasks.append( + self.agent_loop( + prompts[best_orig_idx], + env_classes[best_orig_idx], + hinted_extras, + max_tokens, + max_input_length, + sampling_params=sampling_params, + trajectory_id=tid, + ) + ) + hint_tids.append(tid) + hint_envs.append(env_classes[best_orig_idx]) + orig_prompt_ids.append(best_output.prompt_ids) + + # 3. Run all hinted rollouts in parallel + if hint_tasks: + logger.info( + f"Hint augmentation: {prompts_hinted} prompts need hints, " + f"launching {len(hint_tasks)} hinted rollouts" + ) + hint_outputs = await tqdm.gather( + *hint_tasks, + desc="Hinted Rollouts", + miniters=1, + mininterval=5, + ) + # RLTF-SD: strip hint from training prompt. Replace hinted prompt_ids + # with the original unhinted prompt_ids so the model learns to produce + # hint-quality outputs conditioned on the original prompt alone. + hint_outputs = list(hint_outputs) + for i, output in enumerate(hint_outputs): + hinted_len = len(output.prompt_ids) + output.prompt_ids = orig_prompt_ids[i] + logger.debug( + f"RLTF-SD: replaced hinted prompt ({hinted_len} tokens) " + f"with original prompt ({len(output.prompt_ids)} tokens)" + ) + return hint_outputs, hint_tids, hint_envs + + return [], [], [] + async def generate_batched( self, prompts: List[ConversationType], @@ -771,6 +893,47 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False disable=disable_tqdm, ) + # --- Hint augmentation: rescue GRPO signal on dead prompts --- + # Only during training; eval should not run hints. + n_raw = len(all_outputs) + batch_metadata = input_batch.get("batch_metadata") + is_training = batch_metadata is not None and batch_metadata.training_phase == "train" + hint_cfg = getattr(self.skyrl_gym_cfg, "fleet_task", None) + enable_hints = hint_cfg is not None and (hint_cfg.get("enable_hints", False) if hasattr(hint_cfg, "get") else False) + if ( + enable_hints + and not self.generator_cfg.step_wise_trajectories + and trajectory_ids is not None + and is_training + ): + hint_outputs, hint_tids, hint_env_classes = await self._run_hint_augmentation( + all_outputs=list(all_outputs), + prompts=prompts, + env_classes=env_classes, + env_extras=env_extras, + trajectory_ids=trajectory_ids, + max_tokens=max_tokens, + max_input_length=max_input_length, + sampling_params=sampling_params, + hint_cfg=hint_cfg, + ) + if hint_outputs: + all_outputs = list(all_outputs) + hint_outputs + # Extend in-place so input_batch references are updated (trainer reads these) + trajectory_ids.extend(hint_tids) + env_classes.extend(hint_env_classes) + # Also extend prompts and env_extras arrays to stay aligned + for tid in hint_tids: + # Find original prompt index for this instance_id + for orig_i, orig_tid in enumerate(input_batch.get("trajectory_ids", [])): + if orig_tid.instance_id == tid.instance_id: + prompts.append(prompts[orig_i]) + env_extras.append(env_extras[orig_i]) + break + + # Build is_hinted array: raw samples are False, hint-augmented samples are True + is_hinted = [False] * n_raw + [True] * (len(all_outputs) - n_raw) + if self.generator_cfg.step_wise_trajectories: responses = [] rewards = [] @@ -827,6 +990,35 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) + # Log hint augmentation metrics + hinted_metrics = [m for m in env_metrics if isinstance(m, dict) and m.get("is_hinted")] + if hinted_metrics: + n_hinted = len(hinted_metrics) + hinted_rewards = [] + for m in hinted_metrics: + r = m.get("final_reward", 0.0) + hinted_rewards.append(r if r is not None else 0.0) + n_success = sum(1 for r in hinted_rewards if r > 0) + rollout_metrics["hint/total_hinted_rollouts"] = n_hinted + rollout_metrics["hint/hint_success_rate"] = n_success / n_hinted if n_hinted > 0 else 0.0 + # Count unique prompts that were hinted (by instance_id) + hinted_iids = set() + for i, m in enumerate(env_metrics): + if m.get("is_hinted") and trajectory_ids is not None and i < len(trajectory_ids): + hinted_iids.add(trajectory_ids[i].instance_id) + rollout_metrics["hint/prompts_hinted"] = len(hinted_iids) + # Signal rescued: prompts where at least 1 hinted sample scored > 0 + rescued = 0 + for iid in hinted_iids: + for j, m in enumerate(env_metrics): + if m.get("is_hinted") and trajectory_ids is not None and j < len(trajectory_ids): + if trajectory_ids[j].instance_id == iid: + r = m.get("final_reward", 0.0) + if r is not None and r > 0: + rescued += 1 + break + rollout_metrics["hint/signal_rescued"] = rescued / len(hinted_iids) if hinted_iids else 0.0 + if self.generator_cfg.zero_reward_on_non_stop: # set reward to 0 if the stop reason is not "stop" rewards = self._zero_reward_if_not_stop(rewards, stop_reasons) @@ -846,6 +1038,7 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False "trajectory_ids": out_trajectory_ids, "rollout_expert_indices": rollout_expert_indices, "is_last_step": is_last_step, + "is_hinted": is_hinted, } return generator_output diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index e22312c56e..a7d6629155 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -659,6 +659,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis }, ) training_input.metadata = {"uids": uids} + # Track which samples are hint-augmented for first-turn baseline + if generator_output.get("is_hinted") is not None: + training_input.metadata["is_hinted"] = generator_output["is_hinted"] # padded response length training_input.metadata["response_length"] = response_masks_tensor.shape[1] batch_num_seq, batch_padded_seq_len = sequences_tensor.shape @@ -798,10 +801,15 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn """ token_level_rewards = data["rewards"] + # Convert is_hinted metadata to numpy array for advantage computation + is_hinted_list = data.metadata.get("is_hinted") + is_hinted = np.array(is_hinted_list) if is_hinted_list is not None else None + if self.cfg.generator.step_wise_trajectories: is_last_step = data["is_last_step"].bool() index = np.array(data.metadata["uids"]) values = data["values"] + last_step_is_hinted = is_hinted[is_last_step.cpu().numpy()] if is_hinted is not None else None # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator # NOTE(Charlie): so we ignore per-step rewards in step-wise training. last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns( @@ -814,6 +822,7 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn gamma=self.cfg.trainer.algorithm.gamma, lambd=self.cfg.trainer.algorithm.lambd, grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std, + is_hinted=last_step_is_hinted, ) # Broadcast each trajectory's advantage and return to all steps of each trajectory. traj_ids = ( @@ -836,6 +845,7 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn gamma=self.cfg.trainer.algorithm.gamma, lambd=self.cfg.trainer.algorithm.lambd, grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std, + is_hinted=is_hinted, ) data["returns"] = returns data["advantages"] = advantages @@ -920,8 +930,10 @@ def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch: new_training_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + [ f"pad{i}" for i in range(pad_size) ] + if "is_hinted" in training_input.metadata: + new_training_input.metadata["is_hinted"] = training_input.metadata["is_hinted"] + [False] * pad_size for key, value in training_input.metadata.items(): - if key not in ["uids", "trajectory_ids"]: + if key not in ["uids", "trajectory_ids", "is_hinted"]: new_training_input.metadata[key] = copy.deepcopy(value) return new_training_input From ff094d9236fa05136db84f69450c45267d0f4a2b Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 23:11:02 -0700 Subject: [PATCH 005/121] Add VL/CUA multimodal support ported from SkyRL PR #288 Port vision-language model support from SkyRL v1 (feat/vl-support-clean) to SkyRL-v2's architecture: - Generator: VL-aware chat template, image accumulation across turns, multi_modal_data construction for vLLM - Engine pipeline: thread multi_modal_data through preprocess/generate in both sync and async vLLM engines - Fleet env: Qwen coordinate adaptation ([0,1000] <-> pixel), initial screenshot capture, computer_use browser hints, done signal detection - Utilities: image extraction, base64 decode, processor loading, VL chat template with proper vision token expansion - New VL run script and SkyPilot YAML for CUA training - Update existing YAMLs to use fleet/all branch Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 92 +++++++++ skyrl-gym/skyrl_gym/envs/fleet_task/env.py | 127 +++++++++++- .../skyrl_train/inference_engines/base.py | 30 ++- .../inference_engine_client.py | 3 + .../inference_engines/vllm/vllm_engine.py | 41 +++- skyrl/train/entrypoints/main_base.py | 1 + skyrl/train/generators/skyrl_gym_generator.py | 85 +++++++- skyrl/train/generators/utils.py | 184 +++++++++++++++++- tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 62 ++++++ tasks/task-gen-grpo-qwen3_5-9b.yaml | 2 +- 11 files changed, 604 insertions(+), 25 deletions(-) create mode 100755 scripts/fleet-vl-run.sh create mode 100644 tasks/openenv-fleet-grpo-vl.yaml diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh new file mode 100755 index 0000000000..5aaf456562 --- /dev/null +++ b/scripts/fleet-vl-run.sh @@ -0,0 +1,92 @@ +#!/usr/bin/env bash +# VL/CUA (Vision-Language / Computer Use Agent) GRPO training config. +# Called by the SkyPilot YAML and by fleet-research run.sh. +# +# Based on working config from SkyRL PR #288 (feat/vl-support-clean), +# adapted to SkyRL-v2's fleet-common-run.sh pattern. +# +# Model: Qwen/Qwen3.5-9B (9B params, natively multimodal, GatedDeltaNet) +# TP=1 (single GPU per engine, 8 engines on 8x H200) +# Modality: computer_use (screenshots + coordinate normalization) +# +# Required env vars: FLEET_API_KEY, WANDB_API_KEY +# Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) +set -euo pipefail +cd "$(dirname "$0")/../.." # cd to SkyRL root + +# Defaults for vars normally set by SkyPilot YAML envs block +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export DATA_VERSION="${DATA_VERSION:-v52}" +export MODALITY="${MODALITY:-computer_use}" +export NUM_EPOCHS="${NUM_EPOCHS:-10}" +export MAX_TURNS="${MAX_TURNS:-50}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-131072}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export ENV_KEYS="${ENV_KEYS:-}" +export DIFFICULTY="${DIFFICULTY:-}" +export RUN_ID="${RUN_ID:-}" +export MAX_TASKS="${MAX_TASKS:-}" +export RESUME_RUN_NAME="${RESUME_RUN_NAME:-}" +export AWS_REGION="${AWS_REGION:-us-east-1}" +export S3_DATASET_BUCKET="${S3_DATASET_BUCKET:-fleet-internal-datasets}" +export S3_CHECKPOINT_BUCKET="${S3_CHECKPOINT_BUCKET:-skyrl-checkpoints}" +export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" + +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf -- \ + environment.skyrl_gym.fleet_task.ttl_seconds=1800 \ + environment.skyrl_gym.fleet_task.partial_reward=false \ + environment.skyrl_gym.fleet_task.enable_hints=false \ + trainer.algorithm.advantage_estimator=grpo \ + trainer.policy.model.path="Qwen/Qwen3.5-9B" \ + trainer.flash_attn=true \ + trainer.loss_chunk_size=4096 \ + trainer.use_sample_packing=false \ + trainer.algorithm.loss_reduction="sequence_mean" \ + +generator.chat_template_kwargs='{enable_thinking:true}' \ + +generator.engine_init_kwargs.mm_processor_cache_gb=0 \ + generator.inference_engine_tensor_parallel_size=1 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=12 \ + trainer.eval_before_train=true \ + trainer.eval_interval=10 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=16 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=1 \ + trainer.policy_mini_batch_size=16 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=10 \ + trainer.max_prompt_length=2048 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.9 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=[""]' \ + 'generator.eval_sampling_params.stop=[""]' \ + trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.algorithm.use_kl_loss=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=4 \ + generator.eval_n_samples_per_prompt=3 \ + generator.gpu_memory_utilization=0.80 \ + trainer.logger="$LOGGER" \ + trainer.project_name="fleet-browser-use-grpo" \ + trainer.run_name="fleet_qwen35_${MODALITY}_${RUN_ID:-$(head -c 4 /dev/urandom | xxd -p)}" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/fleet_qwen35_${MODALITY}" \ + trainer.export_path="$HOME/exports" \ + trainer.dump_data_batch=true \ + "$@" diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py index 7dfc2f07ff..4d9295ce90 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -224,6 +224,84 @@ def __init__( "ContextManager not available, disabling context tools" ) + def _adapt_computer_tool_for_qwen(self): + """Adapt computer tool description for Qwen VL's [0, 1000] coordinate space. + + Qwen3-VL/3.5 output coordinates in a normalized [0, 1000] grid regardless + of screen resolution. This rewrites tool descriptions to match, and + _convert_qwen_coordinates() converts back to pixels before MCP execution. + """ + for tool in self.tools: + func = tool.get("function", {}) + if func.get("name") != "computer": + continue + + desc = func.get("description", "") + + # Parse actual screen dimensions + res_match = re.search(r"Screen resolution:\s*(\d+)x(\d+)", desc) + if res_match: + self.screen_width = int(res_match.group(1)) + self.screen_height = int(res_match.group(2)) + else: + self.screen_width = 1366 + self.screen_height = 768 + + w, h = self.screen_width, self.screen_height + + # Rewrite description for Qwen's [0, 1000] coordinate space + desc = re.sub( + r"Screen resolution:\s*\d+x\d+\s*pixels\s*(\([^)]*\))?", + "Screen resolution: 1000x1000", + desc, + ) + desc = re.sub( + r"\(0, 0\) is top-left,\s*\(\d+, \d+\) is bottom-right", + "(0, 0) is top-left, (999, 999) is bottom-right", + desc, + ) + desc = re.sub( + r"valid range: x=0-\d+, y=0-\d+", + "valid range: x=0-999, y=0-999", + desc, + ) + desc = re.sub( + r"JPEG format at \d+x\d+", + "JPEG format at 1000x1000", + desc, + ) + func["description"] = desc + + logger.info( + f"Adapted computer tool for Qwen VL: actual_screen={w}x{h}, " + f"model coordinate space=[0, 1000]" + ) + break + + def _convert_qwen_coordinates(self, tool_call: Dict[str, Any]): + """Convert Qwen's [0, 1000] normalized coordinates to pixel coordinates. + + Modifies tool_call arguments in-place. + """ + if not getattr(self, "screen_width", None) or not getattr( + self, "screen_height", None + ): + return + args = tool_call.get("arguments", {}) + if not args or tool_call.get("name") != "computer": + return + for field in ("coordinate", "start_coordinate"): + coords = args.get(field) + if ( + coords + and isinstance(coords, (list, tuple)) + and len(coords) == 2 + ): + args[field] = [ + int(coords[0] / 1000 * self.screen_width), + int(coords[1] / 1000 * self.screen_height), + ] + def _normalize_task_config(self) -> Dict[str, Any]: """Normalize task config to OpenEnv's expected format.""" config = self.task_config.copy() @@ -291,6 +369,11 @@ async def init_async( f"Task {self.task_key}: no tools found. Fleet env requires tools." ) + # VL: adapt computer tool for Qwen's normalized coordinate space + modality = self.task_config.get("task_modality", "tool_use") + if modality == "computer_use": + self._adapt_computer_tool_for_qwen() + # Build initial prompt with task instruction task_prompt = self.task_config.get("prompt", "") @@ -348,13 +431,29 @@ async def init_async( "information_schema.columns WHERE table_name = 'your_table'\n" ) + # Computer-use hints for VL models + computer_use_hints = "" + if modality == "computer_use": + computer_use_hints = ( + "\n## Browser Interaction Strategy\n" + "You are controlling a web browser via screenshots. Follow this loop:\n" + "1. **Act**: Perform ONE action (click, type, scroll, etc.)\n" + "2. **Observe**: Take a screenshot to see the result\n" + "3. **Think**: Analyze what happened and decide the next action\n\n" + "Tips:\n" + "- Always take a screenshot after each action to verify the result\n" + "- Click on elements by their visual position in the screenshot\n" + "- If an element is not visible, scroll to find it\n" + "- Use keyboard shortcuts when appropriate (Ctrl+A, Ctrl+C, etc.)\n" + ) + system_content = ( f"You are a helpful agent. Complete the task by calling tools.\n\n" f"## Current Date\n" f"Today's date is {current_date}. When dates are mentioned without " f"a year, assume the current year ({datetime.now().year}) or a " f"future date.\n" - f"{env_context}{env_hints}\n" + f"{env_context}{env_hints}{computer_use_hints}\n" f"## Available Tools\n{tools_json}\n\n" f"## Tool Call Format\n" f'{{"name": "tool_name", "arguments": ' @@ -379,7 +478,18 @@ async def init_async( ) system_message = {"role": "system", "content": system_content} - user_message = {"role": "user", "content": task_prompt} + + # VL: include initial screenshot in multimodal user message + initial_screenshot = obs.get("initial_screenshot") + if initial_screenshot and isinstance(initial_screenshot, list): + user_content = [{"type": "text", "text": task_prompt}] + for item in initial_screenshot: + if isinstance(item, dict) and item.get("type") == "image_url": + user_content.append(item) + user_message = {"role": "user", "content": user_content} + else: + user_message = {"role": "user", "content": task_prompt} + self.chat_history = [system_message, user_message] metadata = { @@ -432,6 +542,19 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: reward = 0.0 mcp_time = 0.0 + # VL: catch done signal wrapped in a computer tool call + if ( + not agent_done + and tool_call + and tool_call.get("arguments", {}).get("action") == "done" + ): + agent_done = True + tool_call = None + + # VL: convert Qwen [0,1000] coordinates to pixel coordinates + if tool_call and getattr(self, "screen_width", None): + self._convert_qwen_coordinates(tool_call) + # Handle context management tools locally (no MCP call) if ( tool_call diff --git a/skyrl/backends/skyrl_train/inference_engines/base.py b/skyrl/backends/skyrl_train/inference_engines/base.py index c3040fc1a2..7d687926d8 100644 --- a/skyrl/backends/skyrl_train/inference_engines/base.py +++ b/skyrl/backends/skyrl_train/inference_engines/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Hashable, List, Optional, TypedDict +from typing import TYPE_CHECKING, Any, Dict, Hashable, List, Optional, TypedDict, Union if TYPE_CHECKING: from skyrl.backends.skyrl_train.weight_sync import WeightUpdateRequest @@ -7,7 +7,30 @@ WeightSyncInitInfo, ) -MessageType = Dict[str, str] + +# --- Multimodal Message Types (OpenAI-compatible) --- +class TextContent(TypedDict): + type: str # "text" + text: str + + +class ImageUrlContent(TypedDict): + url: str # "data:image/png;base64,..." or URL + + +class ImageContent(TypedDict): + type: str # "image_url" + image_url: ImageUrlContent + + +ContentType = Union[str, List[Union[TextContent, ImageContent]]] + + +class MessageType(TypedDict): + role: str + content: ContentType + + ConversationType = List[MessageType] @@ -17,6 +40,9 @@ class InferenceEngineInput(TypedDict): prompt_token_ids: Optional[List[List[int]]] sampling_params: Optional[Dict[str, Any]] session_ids: Optional[List[Hashable]] + # Multimodal data for VL models. Each element corresponds to a prompt. + # Format: {"image": [PIL.Image, ...]} or {"image": ["base64_string", ...]} + multi_modal_data: Optional[List[Optional[Dict[str, Any]]]] class InferenceEngineOutput(TypedDict): diff --git a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py index 5f11761b91..271110156d 100644 --- a/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py +++ b/skyrl/backends/skyrl_train/inference_engines/inference_engine_client.py @@ -95,6 +95,7 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu prompt_token_ids = input_batch.get("prompt_token_ids") session_ids = input_batch.get("session_ids") sampling_params = input_batch.get("sampling_params") + multi_modal_data = input_batch.get("multi_modal_data") if (prompts is None and prompt_token_ids is None) or (prompts is not None and prompt_token_ids is not None): raise ValueError("Either `prompts` or `prompt_token_ids` must be provided, but not both.") @@ -122,9 +123,11 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu for engine_idx, prompt_ids in engine_idx_to_prompt_ids.items(): # index prompt_token_ids with prompt_ids cur_prompt_token_ids = [prompt_token_ids[i] for i in prompt_ids] + cur_mm_data = [multi_modal_data[i] for i in prompt_ids] if multi_modal_data else None engine_input = InferenceEngineInput( prompt_token_ids=cur_prompt_token_ids, sampling_params=sampling_params, + multi_modal_data=cur_mm_data, ) tasks.append(asyncio.create_task(self.engines[engine_idx].generate(engine_input))) indices_list.append(prompt_ids) diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index 6d827ac327..fd28948fee 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -135,6 +135,7 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): prompts = input_batch.get("prompts") prompt_token_ids = input_batch.get("prompt_token_ids") request_sampling_params = input_batch.get("sampling_params") + multi_modal_data = input_batch.get("multi_modal_data") assert ( prompts is None and prompt_token_ids is not None @@ -144,7 +145,7 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): SamplingParams(**request_sampling_params) if request_sampling_params is not None else SamplingParams() ) - return prompt_token_ids, sampling_params + return prompt_token_ids, sampling_params, multi_modal_data def _postprocess_outputs(self, outputs): """Common output processing logic.""" @@ -247,7 +248,7 @@ def _create_engine(self, *args, **kwargs): return vllm.LLM(*args, **kwargs) async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: - prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) # Check if LoRA is enabled and create LoRA requests lora_requests = None @@ -261,9 +262,18 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path") ] * batch_size + # Build prompts with multimodal data for VL models + prompts = [] + for i, token_ids in enumerate(prompt_token_ids): + mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None + if mm_data: + prompts.append({"prompt_token_ids": token_ids, "multi_modal_data": mm_data}) + else: + prompts.append(TokensPrompt(prompt_token_ids=token_ids)) + outputs = await asyncio.to_thread( self.llm.generate, - prompts=[TokensPrompt(prompt_token_ids=r) for r in prompt_token_ids], + prompts=prompts, sampling_params=sampling_params, lora_request=lora_requests, ) @@ -460,7 +470,13 @@ async def _load_lora_from_disk(self, lora_path: str): result = await self.llm.add_lora(lora_request) return result - async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_params: SamplingParams): + async def _collect_outputs( + self, + prompt_token_ids, + request_id: str, + sampling_params: SamplingParams, + multi_modal_data: Optional[Dict[str, Any]] = None, + ): """Collect outputs for a single prompt.""" # Check if LoRA is enabled and create LoRA request final_output = None @@ -475,8 +491,16 @@ async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_par lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path" ) + # Build prompt with multimodal data for VL models + if multi_modal_data: + num_images = len(multi_modal_data.get("image", [])) + logger.info(f"VL generate: {num_images} images, {len(prompt_token_ids)} input tokens") + prompt = {"prompt_token_ids": prompt_token_ids, "multi_modal_data": multi_modal_data} + else: + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + async for request_output in self.llm.generate( - prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), + prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, @@ -487,14 +511,15 @@ async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_par async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: """Generate responses using vLLM's async engine.""" - prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) tasks = [] - for prompt in prompt_token_ids: + for i, prompt in enumerate(prompt_token_ids): # Schedule the collection of outputs for each prompt. # Avoid duplicate request_ids request_id = str(uuid4().hex) - task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params)) + mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None + task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params, mm_data)) tasks.append(task) outputs = await asyncio.gather(*tasks) diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index 9e2fe6e3da..aa81fd4ac3 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -230,6 +230,7 @@ def get_generator(self, cfg, tokenizer, inference_engine_client): skyrl_gym_cfg=cfg.environment.skyrl_gym, inference_engine_client=inference_engine_client, tokenizer=tokenizer, + model_name=cfg.trainer.policy.model.path, ) def get_trainer( diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 2592e81b24..a44f55a02d 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -32,10 +32,14 @@ TrajectoryID, ) from skyrl.train.generators.utils import ( + apply_chat_template_with_images, apply_overlong_filtering, + extract_images_from_conversation, get_custom_chat_template, get_generation_prompt_ids, get_rollout_metrics, + is_multimodal_conversation, + try_load_processor, ) from skyrl_gym.envs.base_text_env import BaseTextEnvStepOutput @@ -52,6 +56,7 @@ class TrajectoryOutput: rollout_logprobs: Optional[List[float]] env_metrics: Dict[str, Any] rollout_expert_indices: Optional[List[List[List[int]]]] = None + multi_modal_data: Optional[Dict[str, Any]] = None @dataclass @@ -70,6 +75,7 @@ class AgentLoopState: response_end_idx: Optional[int] done: bool rollout_expert_indices: Optional[List[List[List[int]]]] = None + accumulated_images: Optional[List[Any]] = None @dataclass @@ -140,17 +146,22 @@ def __init__( skyrl_gym_cfg: SkyRLGymConfig, inference_engine_client: InferenceEngineClient, tokenizer, + model_name: str = "", ): """ Args: generator_cfg: GeneratorConfig object containing the generator configuration inference_engine_client: InferenceEngineClient object for interacting with the inference engines tokenizer: tokenizer object for encoding and decoding text + model_name: HuggingFace model name (used for VL processor detection) """ self.generator_cfg = generator_cfg self.skyrl_gym_cfg = skyrl_gym_cfg self.inference_engine_client = inference_engine_client self.tokenizer = tokenizer + self.model_name = model_name + self.processor = try_load_processor(model_name) if model_name else None + self.is_vl_model = self.processor is not None self.max_turns = generator_cfg.max_turns self.batched = generator_cfg.batched self.use_conversation_multi_turn = generator_cfg.use_conversation_multi_turn @@ -212,6 +223,28 @@ def _validate_cfg(self, generator_cfg: GeneratorConfig): if not self.use_conversation_multi_turn: raise ValueError("`step_wise_trajectories` doesn't support `use_conversation_multi_turn=False`") + def _apply_chat_template( + self, + conversation: ConversationType, + add_generation_prompt: bool = True, + **kwargs, + ) -> List[int]: + """Apply chat template, routing to VL processor for multimodal conversations.""" + if self.is_vl_model and is_multimodal_conversation(conversation): + return apply_chat_template_with_images( + self.processor, + conversation, + add_generation_prompt=add_generation_prompt, + **kwargs, + ) + return self.tokenizer.apply_chat_template( + conversation, + add_generation_prompt=add_generation_prompt, + tokenize=True, + return_dict=False, + **self.generator_cfg.chat_template_kwargs, + ) + async def _run_in_executor_if_available(self, func, *args, **kwargs): if (executor := self.env_executor) is not None: loop = asyncio.get_running_loop() @@ -275,16 +308,31 @@ async def agent_loop( # init() returns the first prompt to be given to the model, and optional metadata dict chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) initial_chat_history_length = len(chat_history) - initial_input_ids = self.tokenizer.apply_chat_template( - chat_history, - # If retokenize_chat_history==True, avoid including the generation prompt in both the - # prompt_ids and response_ids due to how `response_encodings["input_ids"]` works. - add_generation_prompt=not retokenize_chat_history, - chat_template=self.custom_chat_template if retokenize_chat_history else None, - tokenize=True, - return_dict=False, - **self.generator_cfg.chat_template_kwargs, + + # VL: extract images from initial prompt for multimodal models + initial_images = ( + extract_images_from_conversation(chat_history) + if self.is_vl_model and is_multimodal_conversation(chat_history) + else [] ) + if self.is_vl_model and initial_images: + logger.info(f"Session {session_id}: VL model, extracted {len(initial_images)} initial images") + + # Tokenize initial prompt (VL-aware for multimodal content) + if self.is_vl_model and is_multimodal_conversation(chat_history): + initial_input_ids = self._apply_chat_template( + chat_history, + add_generation_prompt=not retokenize_chat_history, + ) + else: + initial_input_ids = self.tokenizer.apply_chat_template( + chat_history, + add_generation_prompt=not retokenize_chat_history, + chat_template=self.custom_chat_template if retokenize_chat_history else None, + tokenize=True, + return_dict=False, + **self.generator_cfg.chat_template_kwargs, + ) initial_prompt_length = len(initial_input_ids) loss_mask = [] # this excludes the prompt @@ -311,6 +359,7 @@ async def agent_loop( rollout_logprobs=[] if get_logprobs else None, response_end_idx=None, done=False, + accumulated_images=initial_images if initial_images else None, ) while not agent_loop_state.done: @@ -333,8 +382,16 @@ async def agent_loop( agent_loop_state.loss_mask = [] agent_loop_state.rollout_logprobs = None + # VL: build multimodal data for engine input + mm_data = None + if agent_loop_state.accumulated_images: + mm_data = [{"image": agent_loop_state.accumulated_images}] + engine_input = InferenceEngineInput( - prompt_token_ids=[agent_loop_state.input_ids], session_ids=[session_id], sampling_params=sampling_params + prompt_token_ids=[agent_loop_state.input_ids], + session_ids=[session_id], + sampling_params=sampling_params, + multi_modal_data=mm_data, ) engine_output = await self.inference_engine_client.generate(engine_input) output = engine_output["responses"][0] @@ -385,6 +442,14 @@ async def agent_loop( obs_ids = self.get_obs_ids_from_obs(new_obs, agent_loop_state.done) + # VL: accumulate images from observations + if new_obs and is_multimodal_conversation(new_obs): + new_images = extract_images_from_conversation(new_obs) + if new_images: + if agent_loop_state.accumulated_images is None: + agent_loop_state.accumulated_images = [] + agent_loop_state.accumulated_images.extend(new_images) + # final turn output containing generated response and environment observations turn_output = TurnOutput( output=output, diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index 39410b9d6b..5dba17a259 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -1,7 +1,7 @@ import copy import os from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -574,3 +574,185 @@ def get_response_ids_and_loss_mask_from_messages( assert len(rollout_logprobs) == len(response_ids) if rollout_logprobs is not None else True return response_ids, loss_mask, rollout_logprobs + + +# --- Multimodal/VL Utilities --- + + +def is_multimodal_message(message: Dict[str, Any]) -> bool: + """Check if a message contains multimodal content (images).""" + content = message.get("content") + if isinstance(content, str): + return False + if isinstance(content, list): + return any(isinstance(item, dict) and item.get("type") == "image_url" for item in content) + return False + + +def is_multimodal_conversation(conversation: ConversationType) -> bool: + """Check if any message in a conversation contains multimodal content.""" + return any(is_multimodal_message(msg) for msg in conversation) + + +def extract_images_from_conversation(conversation: ConversationType) -> List[Any]: + """Extract all images from a conversation in order. + + Supports base64 data URLs, HTTP URLs, and local file paths. + Returns a list of PIL Images or image URL strings. + """ + images = [] + for message in conversation: + content = message.get("content") + if not isinstance(content, list): + continue + for item in content: + if not isinstance(item, dict) or item.get("type") != "image_url": + continue + image_url_data = item.get("image_url", {}) + url = image_url_data.get("url", "") + if url.startswith("data:image"): + images.append(decode_base64_image(url)) + elif url.startswith(("http://", "https://")): + images.append(url) + elif url: + images.append(load_image_from_path(url)) + return images + + +def decode_base64_image(data_url: str) -> "Image.Image": + """Decode a base64 image from a data URL.""" + import base64 + import io + + try: + from PIL import Image + except ImportError: + raise ImportError("PIL/Pillow is required for multimodal support. Install with: pip install pillow") + + if "," in data_url: + base64_data = data_url.split(",", 1)[1] + else: + base64_data = data_url + + image_bytes = base64.b64decode(base64_data) + return Image.open(io.BytesIO(image_bytes)) + + +def load_image_from_path(path: str) -> "Image.Image": + """Load an image from a file path.""" + try: + from PIL import Image + except ImportError: + raise ImportError("PIL/Pillow is required for multimodal support. Install with: pip install pillow") + + return Image.open(path) + + +def get_text_from_multimodal_content(content: Any) -> str: + """Extract text from multimodal content, ignoring images.""" + if isinstance(content, str): + return content + if isinstance(content, list): + texts = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + texts.append(item.get("text", "")) + return " ".join(texts) + return "" + + +def convert_to_text_only_conversation(conversation: ConversationType) -> ConversationType: + """Convert multimodal conversation to processor format (image_url -> image type).""" + text_only = [] + for message in conversation: + content = message.get("content") + if isinstance(content, str): + text_only.append(message) + elif isinstance(content, list): + new_content = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "image_url": + new_content.append({"type": "image"}) + else: + new_content.append(item) + else: + new_content.append(item) + text_only.append({"role": message["role"], "content": new_content}) + else: + text_only.append(message) + return text_only + + +def try_load_processor(model_name: str) -> Optional[Any]: + """Try to load a HuggingFace processor for VL models. Returns None for text-only models.""" + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + if hasattr(processor, "image_processor") or "VL" in type(processor).__name__: + logger.info(f"Loaded VL processor for model: {model_name}") + return processor + return None + except Exception as e: + logger.debug(f"No processor available for {model_name}: {e}") + return None + + +def apply_chat_template_with_images( + processor_or_tokenizer, + conversation: ConversationType, + add_generation_prompt: bool = True, + chat_template: Optional[str] = None, + **kwargs, +) -> List[int]: + """Apply chat template handling both text-only and multimodal conversations. + + For VL models (with processor), extracts images and uses the processor to get + correctly-sized token IDs (vision tokens expand based on image dimensions). + """ + has_processor = hasattr(processor_or_tokenizer, "image_processor") or hasattr( + processor_or_tokenizer, "tokenizer" + ) + + if has_processor and is_multimodal_conversation(conversation): + processor = processor_or_tokenizer + converted = convert_to_text_only_conversation(conversation) + images = extract_images_from_conversation(conversation) + + text = processor.apply_chat_template( + converted, + tokenize=False, + add_generation_prompt=add_generation_prompt, + **kwargs, + ) + + if images: + inputs = processor(text=text, images=images, return_tensors="pt") + token_ids = inputs.input_ids[0].tolist() + else: + tokenizer = getattr(processor, "tokenizer", processor) + token_ids = tokenizer.encode(text, add_special_tokens=False) + return token_ids + else: + tokenizer = getattr(processor_or_tokenizer, "tokenizer", processor_or_tokenizer) + + if is_multimodal_conversation(conversation): + text_conversation = [] + for msg in conversation: + content = msg.get("content") + if isinstance(content, list): + text = get_text_from_multimodal_content(content) + text_conversation.append({"role": msg["role"], "content": text}) + else: + text_conversation.append(msg) + conversation = text_conversation + + return tokenizer.apply_chat_template( + conversation, + add_generation_prompt=add_generation_prompt, + tokenize=True, + chat_template=chat_template, + return_dict=False, + **kwargs, + ) diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml index 7acbbb06fa..c2039685ed 100644 --- a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -33,7 +33,7 @@ num_nodes: 2 workdir: url: https://github.com/fleet-ai/SkyRL-v2.git - ref: main + ref: fleet/all envs: WANDB_API_KEY: "" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml new file mode 100644 index 0000000000..af114f7bbc --- /dev/null +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -0,0 +1,62 @@ +# Fleet VL/CUA GRPO Training via SkyPilot - Qwen3.5-9B (Vision-Language) +# Usage: sky launch tasks/openenv-fleet-grpo-vl.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# +# VL (Vision-Language) training for computer_use environments with screenshots. +# Based on working config from SkyRL PR #288 (feat/vl-support-clean). +# +# Model: Qwen/Qwen3.5-9B (9B params, natively multimodal, GatedDeltaNet) +# GPUs: 8x H200 (single node, TP=1) +# +# NOTE: Requires vLLM >= 0.17.0 for native Qwen3.5/GDN support + FlashAttention 4 + +name: fleet-vl-grpo-qwen3-5-9b + +resources: + disk_size: 750 + ports: 6479 + any_of: + - accelerators: H200:8 + cloud: runpod + - accelerators: H200:8 + cloud: lambda + - accelerators: H200:8 + cloud: nebius + - accelerators: H200-SXM:8 + cloud: vast + +num_nodes: 1 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: fleet/all + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v52" + ENV_KEYS: "" + DIFFICULTY: "" + MODALITY: "computer_use" + MAX_TURNS: 50 + MAX_INPUT_LENGTH: 131072 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 10 + RUN_ID: "" + MAX_TASKS: "" + RESUME_RUN_NAME: "" + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh + +run: | + bash scripts/fleet-vl-run.sh diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index bc1abdd2c5..36c5c5a73b 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -24,7 +24,7 @@ num_nodes: 1 workdir: url: https://github.com/fleet-ai/SkyRL-v2.git - ref: main + ref: fleet/all envs: WANDB_API_KEY: "" From 1af0a8e5d2561e1a5da4ca98e4c09028b6f12b10 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 23:16:35 -0700 Subject: [PATCH 006/121] Add GCP spot H200 option for VL YAML RunPod/Lambda/Nebius/Vast were all out of H200 capacity. Add GCP spot with proper NVIDIA 570 driver image. Co-Authored-By: Claude Opus 4.6 --- tasks/openenv-fleet-grpo-vl.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index af114f7bbc..04823bb0ea 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -15,6 +15,10 @@ resources: disk_size: 750 ports: 6479 any_of: + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 - accelerators: H200:8 cloud: runpod - accelerators: H200:8 From 9660f8c9a143f9011eef96e481d9378e5dd9084d Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 23:25:16 -0700 Subject: [PATCH 007/121] Fix setup: use fsdp extra instead of non-existent vllm extra SkyRL-v2 pyproject.toml defines 'fsdp' extra (includes vllm, flash-attn, torch, flashinfer) but not a standalone 'vllm' extra. The old SkyRL had 'vllm' as a separate extra. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-common-setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-common-setup.sh b/scripts/fleet-common-setup.sh index aaeea4b908..baf0290cd5 100755 --- a/scripts/fleet-common-setup.sh +++ b/scripts/fleet-common-setup.sh @@ -87,7 +87,7 @@ fi source .venv/bin/activate # vLLM 0.17.0 has native Qwen3.5 support (GDN via torch.ops.vllm.gdn_attention_core), # FlashAttention 4, and PyTorch 2.10.0 -uv sync --extra vllm +uv sync --extra fsdp uv pip install wandb boto3 awscli # Pin fleet-python<=0.2.119: 0.2.120+ has async BaseWrapper bug (missing jwt/team_id params) uv pip install "litellm>=1.75.5" "fleet-python<=0.2.119" logfire "mcp>=1.0.0" From ce0cef4bb4d64aea7f5e7f86a17687308dd5560d Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 23:33:53 -0700 Subject: [PATCH 008/121] Fix causal-conv1d build: use pip instead of uv for CUDA extension uv pip install silently fails to build causal-conv1d CUDA extension (reports "Checked 1 package" but module is not importable). Use pip with --no-build-isolation to ensure it finds torch from the venv. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-qwen35-extra-setup.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/fleet-qwen35-extra-setup.sh b/scripts/fleet-qwen35-extra-setup.sh index 03367b8f64..66e77038fc 100755 --- a/scripts/fleet-qwen35-extra-setup.sh +++ b/scripts/fleet-qwen35-extra-setup.sh @@ -55,7 +55,9 @@ echo "export PATH=$CUDA_HOME/bin:\$PATH" >> "$HOME/.cuda_env" # Without it, fla-core falls back to a naive PyTorch implementation that crashes # with cudaErrorIllegalAddress on multi-node FSDP2 (Xid 31 MMU fault). # Must be built from source (needs nvcc + g++) — install AFTER CUDA toolkit setup. -uv pip install "causal-conv1d>=1.6.0" +# Build from source with --no-build-isolation so it finds torch from the venv. +# uv pip install can silently fail on CUDA extensions; use pip directly. +pip install --no-cache-dir --no-build-isolation "causal-conv1d>=1.6.0" python -c "import causal_conv1d; print(f'causal-conv1d OK: {causal_conv1d.__version__}')" # Verify pinned packages survived dependency resolution From 8b16d85e2602efcfe6b69ab4d51891056c23d0f1 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 28 Mar 2026 23:54:00 -0700 Subject: [PATCH 009/121] Fix cd path in run scripts for SkyRL-v2 repo layout In SkyRL-v2, scripts/ is directly under repo root (not nested under skyrl-train/). Changed cd from "../.." to ".." so the run scripts correctly resolve the repo root directory. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- scripts/fleet-vl-run.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index b183e317e0..37a33e4a70 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -5,7 +5,7 @@ # Required env vars: FLEET_API_KEY, WANDB_API_KEY # Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) set -euo pipefail -cd "$(dirname "$0")/../.." # cd to SkyRL root +cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo root) # Defaults for vars normally set by SkyPilot YAML envs block export LOGGER="${LOGGER:-wandb}" diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index 5aaf456562..f610fac128 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -12,7 +12,7 @@ # Required env vars: FLEET_API_KEY, WANDB_API_KEY # Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) set -euo pipefail -cd "$(dirname "$0")/../.." # cd to SkyRL root +cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo root) # Defaults for vars normally set by SkyPilot YAML envs block export LOGGER="${LOGGER:-wandb}" From 282f5fe1c9a73b68e70ceb51004e203f74baf905 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 00:10:03 -0700 Subject: [PATCH 010/121] Add missing config fields for Fleet training overrides TrainerConfig: loss_chunk_size, use_hybrid_env_sampling, min_samples_per_env GeneratorConfig: inject_context_status, context_warning_threshold, trajectory_timeout_seconds SkyRL-v2's strict Hydra config rejects unknown keys (no + prefix), so these must be defined in the dataclass and YAML defaults. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/config/config.py | 12 ++++++++++++ skyrl/train/config/ppo_base_config.yaml | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 310f5da196..06a5334a5d 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -517,6 +517,12 @@ class GeneratorConfig(BaseConfig): """Can differ from the trainer's ``rope_scaling``, useful for thinking models.""" rope_theta: Optional[float] = None step_wise_trajectories: bool = False + inject_context_status: bool = False + """Inject context length status into the conversation.""" + context_warning_threshold: float = 0.90 + """Threshold for context length warning (fraction of max_input_length).""" + trajectory_timeout_seconds: Optional[int] = None + """Timeout in seconds for each trajectory rollout.""" def __post_init__(self): @@ -609,6 +615,12 @@ class TrainerConfig(BaseConfig): dump_eval_results: bool = True rope_scaling: Optional[Dict[str, Any]] = None rope_theta: Optional[float] = None + loss_chunk_size: Optional[int] = None + """Chunk size for loss computation to reduce memory usage.""" + use_hybrid_env_sampling: bool = False + """Enable hybrid environment sampling for multi-env training.""" + min_samples_per_env: int = 1 + """Minimum number of samples per environment in each batch.""" def __post_init__(self): # ref model defaults to the policy model diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index f2a52006e3..284dfbcb2b 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -255,6 +255,9 @@ trainer: # YaRN: rope_scaling: null rope_theta: null + loss_chunk_size: null + use_hybrid_env_sampling: false + min_samples_per_env: 1 # rope_scaling: # rope_type: yarn # factor: 1.0 @@ -381,6 +384,9 @@ generator: rope_theta: ${trainer.rope_theta} step_wise_trajectories: false + inject_context_status: false + context_warning_threshold: 0.90 + trajectory_timeout_seconds: null environment: env_class: "gsm8k" From d1342f460e57c2c63c46123b76849fecccdcc17e Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 00:20:07 -0700 Subject: [PATCH 011/121] Apply legacy config translation in Hydra entrypoints The fleet entrypoints use @hydra.main which loads the legacy YAML directly, but validate_cfg expects generator.inference_engine.* (the new structured format). Apply translate_legacy_config to convert flat generator.* keys before validation. Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/entrypoints/main_fleet.py | 9 +++++++++ integrations/fleet/entrypoints/main_task_gen.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py index 5f8d9c82f3..d8faa3c35f 100644 --- a/integrations/fleet/entrypoints/main_fleet.py +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -26,8 +26,10 @@ import hydra import ray +from omegaconf import OmegaConf from skyrl_gym.envs import register from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.config.legacy import is_legacy_config, translate_legacy_config from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir from skyrl.train.utils import validate_cfg from skyrl.train.utils.utils import initialize_ray @@ -85,6 +87,13 @@ def skyrl_entrypoint(cfg: SkyRLTrainConfig): @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: SkyRLTrainConfig) -> None: """Main entry point for Fleet task training.""" + # Hydra loads the legacy YAML with flat generator.* keys. + # Translate to the new generator.inference_engine.* structure + # before validate_cfg accesses those fields. + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + if is_legacy_config(cfg_dict): + cfg_dict = translate_legacy_config(cfg_dict) + cfg = OmegaConf.create(cfg_dict) validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) diff --git a/integrations/fleet/entrypoints/main_task_gen.py b/integrations/fleet/entrypoints/main_task_gen.py index 3731bc2327..fbf548f211 100644 --- a/integrations/fleet/entrypoints/main_task_gen.py +++ b/integrations/fleet/entrypoints/main_task_gen.py @@ -18,7 +18,9 @@ import hydra import ray +from omegaconf import OmegaConf from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.config.legacy import is_legacy_config, translate_legacy_config from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir from skyrl.train.utils import validate_cfg from skyrl.train.utils.utils import initialize_ray @@ -71,6 +73,13 @@ def skyrl_entrypoint(cfg: SkyRLTrainConfig): @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: SkyRLTrainConfig) -> None: """Main entry point for task generation training.""" + # Hydra loads the legacy YAML with flat generator.* keys. + # Translate to the new generator.inference_engine.* structure + # before validate_cfg accesses those fields. + cfg_dict = OmegaConf.to_container(cfg, resolve=True) + if is_legacy_config(cfg_dict): + cfg_dict = translate_legacy_config(cfg_dict) + cfg = OmegaConf.create(cfg_dict) validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) From 1f196e262119c4cc4ec6df17b4dd78d2be5ef3e6 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 00:28:04 -0700 Subject: [PATCH 012/121] Add inference_engine defaults to YAML for Hydra entrypoints The legacy YAML has flat generator.* keys (e.g. generator.backend) but validate_cfg expects generator.inference_engine.* with all fields including distributed_executor_backend. Add the full inference_engine section with defaults so all fields are present after Hydra loads the config and translate_legacy_config moves CLI overrides into it. Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/entrypoints/main_fleet.py | 7 ++-- .../fleet/entrypoints/main_task_gen.py | 7 ++-- skyrl/train/config/ppo_base_config.yaml | 36 +++++++++++++++++++ 3 files changed, 44 insertions(+), 6 deletions(-) diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py index d8faa3c35f..af5e4edca7 100644 --- a/integrations/fleet/entrypoints/main_fleet.py +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -87,9 +87,10 @@ def skyrl_entrypoint(cfg: SkyRLTrainConfig): @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: SkyRLTrainConfig) -> None: """Main entry point for Fleet task training.""" - # Hydra loads the legacy YAML with flat generator.* keys. - # Translate to the new generator.inference_engine.* structure - # before validate_cfg accesses those fields. + # Hydra loads the legacy YAML with flat generator.* keys (e.g. generator.backend). + # validate_cfg expects the new generator.inference_engine.* structure. + # Convert to dict, apply legacy translation, then build the full typed config + # so all dataclass defaults (like distributed_executor_backend) are populated. cfg_dict = OmegaConf.to_container(cfg, resolve=True) if is_legacy_config(cfg_dict): cfg_dict = translate_legacy_config(cfg_dict) diff --git a/integrations/fleet/entrypoints/main_task_gen.py b/integrations/fleet/entrypoints/main_task_gen.py index fbf548f211..4f9a918bb4 100644 --- a/integrations/fleet/entrypoints/main_task_gen.py +++ b/integrations/fleet/entrypoints/main_task_gen.py @@ -73,9 +73,10 @@ def skyrl_entrypoint(cfg: SkyRLTrainConfig): @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: SkyRLTrainConfig) -> None: """Main entry point for task generation training.""" - # Hydra loads the legacy YAML with flat generator.* keys. - # Translate to the new generator.inference_engine.* structure - # before validate_cfg accesses those fields. + # Hydra loads the legacy YAML with flat generator.* keys (e.g. generator.backend). + # validate_cfg expects the new generator.inference_engine.* structure. + # Convert to dict, apply legacy translation, then build the full typed config + # so all dataclass defaults (like distributed_executor_backend) are populated. cfg_dict = OmegaConf.to_container(cfg, resolve=True) if is_legacy_config(cfg_dict): cfg_dict = translate_legacy_config(cfg_dict) diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index 284dfbcb2b..59d3df958d 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -265,6 +265,42 @@ trainer: generator: + # New structured inference_engine config (provides defaults for validate_cfg). + # Legacy flat fields below are translated into this section at runtime. + inference_engine: + model_dtype: "bfloat16" + run_engines_locally: true + num_engines: 1 + backend: "vllm" + weight_sync_backend: "nccl" + weight_transfer_threshold_cuda_ipc_GB: 1.0 + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + expert_parallel_size: 1 + data_parallel_size: 1 + async_engine: true + vllm_v1_disable_multiproc: true + enable_prefix_caching: true + enable_chunked_prefill: true + enable_return_routed_experts: false + max_num_batched_tokens: 8192 + enforce_eager: true + fully_sharded_loras: false + enable_ray_prometheus_stats: false + gpu_memory_utilization: 0.8 + max_num_seqs: 1024 + remote_urls: ["127.0.0.1:8001"] + enable_http_endpoint: false + http_endpoint_host: "127.0.0.1" + http_endpoint_port: 8000 + served_model_name: null + distributed_executor_backend: "ray" + engine_init_kwargs: {} + override_existing_update_group: "auto" + external_proxy_url: null + external_server_urls: null + + # ---- Legacy flat fields (kept for backward compat; translated at runtime) ---- model_name: ${trainer.policy.model.path} model_dtype: "bfloat16" # should match dtype for inference engine run_engines_locally: true From f223bba479dc7727f8de62f45fe033c60fd75384 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 00:52:07 -0700 Subject: [PATCH 013/121] Fix fleet_task double registration and task-gen data path - Remove explicit fleet_task register() call from main_fleet.py since skyrl_gym.envs.__init__ already auto-registers it - Remove --data-dir-name task_gen from task-gen run script so it uses the default MODALITY-based path (matching setup's download path) Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/entrypoints/main_fleet.py | 8 ++------ scripts/fleet-task-gen-run.sh | 3 +-- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py index af5e4edca7..cb08eebbe9 100644 --- a/integrations/fleet/entrypoints/main_fleet.py +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -27,7 +27,6 @@ import hydra import ray from omegaconf import OmegaConf -from skyrl_gym.envs import register from skyrl.train.config import SkyRLTrainConfig from skyrl.train.config.legacy import is_legacy_config, translate_legacy_config from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir @@ -75,11 +74,8 @@ def run(self): @ray.remote(num_cpus=1) def skyrl_entrypoint(cfg: SkyRLTrainConfig): - """Ray remote function that registers Fleet environment and runs training.""" - register( - id="fleet_task", - entry_point="skyrl_gym.envs.fleet_task.env:FleetTaskEnv", - ) + """Ray remote function that runs Fleet training.""" + # fleet_task env is auto-registered by skyrl_gym.envs.__init__ exp = FleetPPOExp(cfg) exp.run() diff --git a/scripts/fleet-task-gen-run.sh b/scripts/fleet-task-gen-run.sh index e066ef2d66..e121499f00 100755 --- a/scripts/fleet-task-gen-run.sh +++ b/scripts/fleet-task-gen-run.sh @@ -24,8 +24,7 @@ bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ --set-ulimit --no-pytorch-alloc-conf \ --entrypoint integrations.fleet.entrypoints.main_task_gen \ - --env-class task_gen \ - --data-dir-name task_gen -- \ + --env-class task_gen -- \ trainer.algorithm.advantage_estimator="grpo" \ trainer.policy.model.path="Qwen/Qwen3.5-9B" \ trainer.flash_attn=true \ From 4b6c15e64ab68d8d4c6789cc8647061a824e5419 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 01:13:50 -0700 Subject: [PATCH 014/121] Fix legacy config sync, registration, and task-gen data path - Replace OmegaConf.create() approach (loses dataclass type info) with in-place sync of flat generator.* CLI overrides into the structured generator.inference_engine section. This preserves the Hydra DictConfig and avoids TypeError on dataclasses.asdict(). - Remove --skip-prepare from task-gen YAML so parquet files are generated - Remove duplicate fleet_task registration (auto-registered by __init__) Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/entrypoints/main_fleet.py | 43 ++++++++++++++----- .../fleet/entrypoints/main_task_gen.py | 13 ++---- tasks/task-gen-grpo-qwen3_5-9b.yaml | 3 +- 3 files changed, 37 insertions(+), 22 deletions(-) diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py index cb08eebbe9..cb7c94ce24 100644 --- a/integrations/fleet/entrypoints/main_fleet.py +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -26,9 +26,9 @@ import hydra import ray -from omegaconf import OmegaConf +from omegaconf import OmegaConf, open_dict from skyrl.train.config import SkyRLTrainConfig -from skyrl.train.config.legacy import is_legacy_config, translate_legacy_config +from skyrl.train.config.legacy import GENERATOR_TO_INFERENCE_ENGINE_FIELDS from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir from skyrl.train.utils import validate_cfg from skyrl.train.utils.utils import initialize_ray @@ -36,6 +36,32 @@ logger = logging.getLogger(__name__) +def _sync_legacy_generator_to_inference_engine(cfg): + """Sync flat legacy generator.* CLI overrides into generator.inference_engine.*. + + The YAML has both flat legacy keys (generator.backend, etc.) and the structured + generator.inference_engine section. CLI args override the flat keys, but + validate_cfg reads from the structured section. This function copies the flat + values into inference_engine so both stay in sync. + """ + gen = cfg.generator + with open_dict(gen): + if not OmegaConf.is_missing(gen, "inference_engine"): + ie = gen.inference_engine + for old_field, new_field in GENERATOR_TO_INFERENCE_ENGINE_FIELDS.items(): + if OmegaConf.is_missing(gen, old_field): + continue + try: + value = getattr(gen, old_field) + except Exception: + continue + target_field = new_field if new_field else old_field + try: + setattr(ie, target_field, value) + except Exception: + pass + + class FleetPPOExp(BasePPOExp): """Fleet-specific PPO experiment with S3 checkpoint management.""" @@ -83,14 +109,11 @@ def skyrl_entrypoint(cfg: SkyRLTrainConfig): @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: SkyRLTrainConfig) -> None: """Main entry point for Fleet task training.""" - # Hydra loads the legacy YAML with flat generator.* keys (e.g. generator.backend). - # validate_cfg expects the new generator.inference_engine.* structure. - # Convert to dict, apply legacy translation, then build the full typed config - # so all dataclass defaults (like distributed_executor_backend) are populated. - cfg_dict = OmegaConf.to_container(cfg, resolve=True) - if is_legacy_config(cfg_dict): - cfg_dict = translate_legacy_config(cfg_dict) - cfg = OmegaConf.create(cfg_dict) + # Hydra loads the legacy YAML with flat generator.* keys (e.g. generator.backend) + # that are also overridden by CLI args. The YAML also has a structured + # generator.inference_engine section with defaults. Sync flat CLI overrides + # into the structured section so validate_cfg sees the right values. + _sync_legacy_generator_to_inference_engine(cfg) validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) diff --git a/integrations/fleet/entrypoints/main_task_gen.py b/integrations/fleet/entrypoints/main_task_gen.py index 4f9a918bb4..adaa16a6a0 100644 --- a/integrations/fleet/entrypoints/main_task_gen.py +++ b/integrations/fleet/entrypoints/main_task_gen.py @@ -18,9 +18,7 @@ import hydra import ray -from omegaconf import OmegaConf from skyrl.train.config import SkyRLTrainConfig -from skyrl.train.config.legacy import is_legacy_config, translate_legacy_config from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir from skyrl.train.utils import validate_cfg from skyrl.train.utils.utils import initialize_ray @@ -73,14 +71,9 @@ def skyrl_entrypoint(cfg: SkyRLTrainConfig): @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: SkyRLTrainConfig) -> None: """Main entry point for task generation training.""" - # Hydra loads the legacy YAML with flat generator.* keys (e.g. generator.backend). - # validate_cfg expects the new generator.inference_engine.* structure. - # Convert to dict, apply legacy translation, then build the full typed config - # so all dataclass defaults (like distributed_executor_backend) are populated. - cfg_dict = OmegaConf.to_container(cfg, resolve=True) - if is_legacy_config(cfg_dict): - cfg_dict = translate_legacy_config(cfg_dict) - cfg = OmegaConf.create(cfg_dict) + # Import and call the shared sync function from main_fleet + from integrations.fleet.entrypoints.main_fleet import _sync_legacy_generator_to_inference_engine + _sync_legacy_generator_to_inference_engine(cfg) validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index 36c5c5a73b..d25a59bbbd 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -52,8 +52,7 @@ envs: setup: | bash scripts/fleet-common-setup.sh \ --openenv-branch deniz/fleet_client \ - --extra-setup scripts/fleet-qwen35-extra-setup.sh \ - --skip-prepare + --extra-setup scripts/fleet-qwen35-extra-setup.sh run: | bash scripts/fleet-task-gen-run.sh From 0dc3a177a45b151165f0d1412e98e4843cc82439 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 01:37:18 -0700 Subject: [PATCH 015/121] Handle OmegaConf DictConfig in get_config_as_yaml_str Hydra entrypoints pass DictConfig (not dataclass instances), so dataclasses.asdict() fails. Fall back to OmegaConf.to_yaml() for DictConfig objects. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/config/config.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 06a5334a5d..8e43ea7f4d 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -885,5 +885,14 @@ def get_config_as_dict(cfg: Union[dict, BaseConfig]) -> dict: return asdict(cfg) -def get_config_as_yaml_str(cfg: BaseConfig) -> str: - return yaml.dump(asdict(cfg)) +def get_config_as_yaml_str(cfg) -> str: + if dataclasses.is_dataclass(cfg) and not isinstance(cfg, type): + return yaml.dump(asdict(cfg)) + # Handle OmegaConf DictConfig (from Hydra entrypoints) + try: + from omegaconf import OmegaConf + if OmegaConf.is_config(cfg): + return OmegaConf.to_yaml(cfg, resolve=True) + except ImportError: + pass + return str(cfg) From 8393415a3e1d21919d909e25adf616c91d910410 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 01:46:42 -0700 Subject: [PATCH 016/121] Replace @hydra.main with from_cli_overrides in Fleet entrypoints Hydra's @hydra.main produces DictConfig objects, but the codebase expects typed dataclass instances (asdict(), attribute access, etc.). Switch Fleet entrypoints to use SkyRLTrainConfig.from_cli_overrides() which produces proper typed dataclasses via the legacy config translation path. - Add fleet_task/task_gen as Optional[Dict] fields on SkyRLGymConfig - Strip ++/+ Hydra prefixes from CLI args before from_cli_overrides - Remove _sync_legacy_generator_to_inference_engine (legacy path handles it) Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/entrypoints/main_fleet.py | 55 ++++++++----------- .../fleet/entrypoints/main_task_gen.py | 16 +++--- skyrl/train/config/config.py | 2 + 3 files changed, 32 insertions(+), 41 deletions(-) diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py index cb7c94ce24..a2a1c1b2ac 100644 --- a/integrations/fleet/entrypoints/main_fleet.py +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -22,44 +22,34 @@ import asyncio import logging import os +import sys from pathlib import Path -import hydra import ray -from omegaconf import OmegaConf, open_dict from skyrl.train.config import SkyRLTrainConfig -from skyrl.train.config.legacy import GENERATOR_TO_INFERENCE_ENGINE_FIELDS -from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir +from skyrl.train.entrypoints.main_base import BasePPOExp from skyrl.train.utils import validate_cfg from skyrl.train.utils.utils import initialize_ray logger = logging.getLogger(__name__) -def _sync_legacy_generator_to_inference_engine(cfg): - """Sync flat legacy generator.* CLI overrides into generator.inference_engine.*. +def _strip_hydra_prefixes(args: list[str]) -> list[str]: + """Strip Hydra ++ and + prefixes from CLI args. - The YAML has both flat legacy keys (generator.backend, etc.) and the structured - generator.inference_engine section. CLI args override the flat keys, but - validate_cfg reads from the structured section. This function copies the flat - values into inference_engine so both stay in sync. + from_cli_overrides rejects +/++ prefixed args, but our run scripts use + them for environment-specific config (e.g. ++environment.skyrl_gym.task_gen.*). + Since these fields now exist in the dataclass, we can strip the prefix. """ - gen = cfg.generator - with open_dict(gen): - if not OmegaConf.is_missing(gen, "inference_engine"): - ie = gen.inference_engine - for old_field, new_field in GENERATOR_TO_INFERENCE_ENGINE_FIELDS.items(): - if OmegaConf.is_missing(gen, old_field): - continue - try: - value = getattr(gen, old_field) - except Exception: - continue - target_field = new_field if new_field else old_field - try: - setattr(ie, target_field, value) - except Exception: - pass + cleaned = [] + for arg in args: + if arg.startswith("++"): + cleaned.append(arg[2:]) + elif arg.startswith("+"): + cleaned.append(arg[1:]) + else: + cleaned.append(arg) + return cleaned class FleetPPOExp(BasePPOExp): @@ -106,14 +96,13 @@ def skyrl_entrypoint(cfg: SkyRLTrainConfig): exp.run() -@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) -def main(cfg: SkyRLTrainConfig) -> None: +def main() -> None: """Main entry point for Fleet task training.""" - # Hydra loads the legacy YAML with flat generator.* keys (e.g. generator.backend) - # that are also overridden by CLI args. The YAML also has a structured - # generator.inference_engine section with defaults. Sync flat CLI overrides - # into the structured section so validate_cfg sees the right values. - _sync_legacy_generator_to_inference_engine(cfg) + # Strip ++/+ prefixes from CLI args (used for env-specific config keys + # that now have proper dataclass fields) + args = _strip_hydra_prefixes(sys.argv[1:]) + # Build typed dataclass config (handles legacy flat→nested translation) + cfg = SkyRLTrainConfig.from_cli_overrides(args) validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) diff --git a/integrations/fleet/entrypoints/main_task_gen.py b/integrations/fleet/entrypoints/main_task_gen.py index adaa16a6a0..1c19fabd63 100644 --- a/integrations/fleet/entrypoints/main_task_gen.py +++ b/integrations/fleet/entrypoints/main_task_gen.py @@ -14,12 +14,12 @@ import asyncio import logging import os +import sys from pathlib import Path -import hydra import ray from skyrl.train.config import SkyRLTrainConfig -from skyrl.train.entrypoints.main_base import BasePPOExp, config_dir +from skyrl.train.entrypoints.main_base import BasePPOExp from skyrl.train.utils import validate_cfg from skyrl.train.utils.utils import initialize_ray @@ -63,17 +63,17 @@ def run(self): @ray.remote(num_cpus=1) def skyrl_entrypoint(cfg: SkyRLTrainConfig): """Ray remote function that registers TaskGenEnv and runs training.""" - # task_gen env is registered in skyrl_gym.envs.__init__ (after PR 3) + # task_gen env is registered in skyrl_gym.envs.__init__ exp = FleetPPOExp(cfg) exp.run() -@hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) -def main(cfg: SkyRLTrainConfig) -> None: +def main() -> None: """Main entry point for task generation training.""" - # Import and call the shared sync function from main_fleet - from integrations.fleet.entrypoints.main_fleet import _sync_legacy_generator_to_inference_engine - _sync_legacy_generator_to_inference_engine(cfg) + from integrations.fleet.entrypoints.main_fleet import _strip_hydra_prefixes + + args = _strip_hydra_prefixes(sys.argv[1:]) + cfg = SkyRLTrainConfig.from_cli_overrides(args) validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 8e43ea7f4d..09c5dabd62 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -550,6 +550,8 @@ class SkyRLGymConfig(BaseConfig): text2sql: Text2SQLEnvConfig = field(default_factory=Text2SQLEnvConfig) llm_as_a_judge: GSM8kLLMJudgeEnvConfig = field(default_factory=GSM8kLLMJudgeEnvConfig) search: SearchEnvConfig = field(default_factory=SearchEnvConfig) + fleet_task: Optional[Dict[str, Any]] = None + task_gen: Optional[Dict[str, Any]] = None @dataclass From 061fcdcd9b05c846f225b0bc270bd9001b42c712 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 01:54:00 -0700 Subject: [PATCH 017/121] Upgrade accelerate in extra-setup to fix _is_hf_initialized TypeError accelerate 1.12.0 passes param.__dict__ (which includes transformers 5.3.0's _is_hf_initialized flag) to Parameter.__new__() during init_empty_weights. PyTorch 2.10.0 rejects this unknown kwarg. Newer accelerate filters it. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-qwen35-extra-setup.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/fleet-qwen35-extra-setup.sh b/scripts/fleet-qwen35-extra-setup.sh index 66e77038fc..c1f9a52c7e 100755 --- a/scripts/fleet-qwen35-extra-setup.sh +++ b/scripts/fleet-qwen35-extra-setup.sh @@ -9,6 +9,10 @@ # - 5.1.0 doesn't register qwen3_5_moe in AUTO_CONFIG_MAPPING. # - 5.3.0 is the first stable release with full qwen3_5_moe support. # - Do NOT install from git main (renamed layer_type_validation, breaks vLLM 0.17). +# Upgrade accelerate first — 1.12.0 (from uv.lock) passes param.__dict__ kwargs +# (including _is_hf_initialized from transformers 5.x) to Parameter.__new__() which +# torch 2.10 rejects. Newer accelerate filters these kwargs properly. +uv pip install -U accelerate uv pip install -U "transformers==5.3.0" # flash-attn 2.8.3 prebuilt wheel for torch 2.10 + CUDA 12 (training forward/backward) From d0112329ac06d8ee28a09dacff5833be17c6a2a5 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 02:03:25 -0700 Subject: [PATCH 018/121] Fix accelerate install: use --no-deps to avoid torch re-resolution uv pip install -U accelerate pulls newer torch with CUDA 13.0, breaking torchvision (CUDA 12.8). Use pip install --no-deps instead to upgrade only accelerate without re-resolving transitive dependencies. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-qwen35-extra-setup.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-qwen35-extra-setup.sh b/scripts/fleet-qwen35-extra-setup.sh index c1f9a52c7e..4a5be02e43 100755 --- a/scripts/fleet-qwen35-extra-setup.sh +++ b/scripts/fleet-qwen35-extra-setup.sh @@ -9,10 +9,11 @@ # - 5.1.0 doesn't register qwen3_5_moe in AUTO_CONFIG_MAPPING. # - 5.3.0 is the first stable release with full qwen3_5_moe support. # - Do NOT install from git main (renamed layer_type_validation, breaks vLLM 0.17). -# Upgrade accelerate first — 1.12.0 (from uv.lock) passes param.__dict__ kwargs +# Upgrade accelerate — 1.12.0 (from uv.lock) passes param.__dict__ kwargs # (including _is_hf_initialized from transformers 5.x) to Parameter.__new__() which # torch 2.10 rejects. Newer accelerate filters these kwargs properly. -uv pip install -U accelerate +# Use --no-deps to avoid re-resolving torch (uv would pull CUDA 13.0 torch, breaking torchvision). +pip install --no-deps "accelerate>=1.14" uv pip install -U "transformers==5.3.0" # flash-attn 2.8.3 prebuilt wheel for torch 2.10 + CUDA 12 (training forward/backward) From baa37a5f37f4ea05214c52dbd6dca7fc8400c5e8 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 02:16:26 -0700 Subject: [PATCH 019/121] Patch Parameter.__new__ to fix _is_hf_initialized TypeError accelerate's init_empty_weights passes param.__dict__ to Parameter() which includes _is_hf_initialized (set by transformers 5.x). torch 2.10 rejects this unknown kwarg. Patch Parameter.__new__ in fsdp_utils.py to filter it out. Revert accelerate upgrade attempt (latest is 1.13.0, still has the same issue). Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-qwen35-extra-setup.sh | 5 ----- .../backends/skyrl_train/distributed/fsdp_utils.py | 14 ++++++++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/scripts/fleet-qwen35-extra-setup.sh b/scripts/fleet-qwen35-extra-setup.sh index 4a5be02e43..66e77038fc 100755 --- a/scripts/fleet-qwen35-extra-setup.sh +++ b/scripts/fleet-qwen35-extra-setup.sh @@ -9,11 +9,6 @@ # - 5.1.0 doesn't register qwen3_5_moe in AUTO_CONFIG_MAPPING. # - 5.3.0 is the first stable release with full qwen3_5_moe support. # - Do NOT install from git main (renamed layer_type_validation, breaks vLLM 0.17). -# Upgrade accelerate — 1.12.0 (from uv.lock) passes param.__dict__ kwargs -# (including _is_hf_initialized from transformers 5.x) to Parameter.__new__() which -# torch 2.10 rejects. Newer accelerate filters these kwargs properly. -# Use --no-deps to avoid re-resolving torch (uv would pull CUDA 13.0 torch, breaking torchvision). -pip install --no-deps "accelerate>=1.14" uv pip install -U "transformers==5.3.0" # flash-attn 2.8.3 prebuilt wheel for torch 2.10 + CUDA 12 (training forward/backward) diff --git a/skyrl/backends/skyrl_train/distributed/fsdp_utils.py b/skyrl/backends/skyrl_train/distributed/fsdp_utils.py index 449cf0266e..c192632a88 100644 --- a/skyrl/backends/skyrl_train/distributed/fsdp_utils.py +++ b/skyrl/backends/skyrl_train/distributed/fsdp_utils.py @@ -23,6 +23,20 @@ import torch import torch.distributed as dist import torch.nn as nn + +# Patch torch.nn.Parameter.__new__ to accept and ignore _is_hf_initialized. +# accelerate's init_empty_weights passes param.__dict__ (which includes +# _is_hf_initialized set by transformers 5.x) to Parameter(), but torch 2.10 +# rejects unknown kwargs. This patch filters them out. +_orig_param_new = torch.nn.Parameter.__new__ + + +def _patched_param_new(cls, *args, **kwargs): + kwargs.pop("_is_hf_initialized", None) + return _orig_param_new(cls, *args, **kwargs) + + +torch.nn.Parameter.__new__ = _patched_param_new from omegaconf import DictConfig from packaging import version from peft.utils.save_and_load import get_peft_model_state_dict From acafea3cd4b01afc958010d5963ccded83216230 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 03:08:46 -0700 Subject: [PATCH 020/121] Fix config parsing and per-record env_class for fleet training - config.py: Always use legacy config path in from_cli_overrides to ensure flat keys (generator.backend etc.) are properly translated via translate_legacy_config. Fixes VL/35B ValueError on GeneratorConfig. - prepare_dataset.py: Add --env-class CLI arg (fleet_task|task_gen) to set per-record env_class in parquet data. Previously hardcoded to fleet_task, causing task_gen training to create FleetTaskEnv (requires tasks_file). - fleet-common-setup.sh: Accept --env-class and pass to prepare_dataset. - task-gen YAML: Pass --env-class task_gen in setup block. Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/prepare_dataset.py | 20 +++++++++---- scripts/fleet-common-setup.sh | 4 ++- skyrl/train/config/config.py | 42 +++++++++++++++------------ tasks/task-gen-grpo-qwen3_5-9b.yaml | 3 +- 4 files changed, 44 insertions(+), 25 deletions(-) diff --git a/integrations/fleet/prepare_dataset.py b/integrations/fleet/prepare_dataset.py index cdc465b8d7..67f02199a9 100644 --- a/integrations/fleet/prepare_dataset.py +++ b/integrations/fleet/prepare_dataset.py @@ -191,6 +191,7 @@ def prepare_fleet_dataset( max_tasks: Optional[int] = None, max_env_ratio: float = MAX_ENV_TRAIN_RATIO, # v0.3.1: cap dominant environments max_eval_prompts: Optional[int] = MAX_EVAL_PROMPTS, # v0.3.2: cap total eval prompts + env_class: str = "fleet_task", # SkyRL env_class per record (fleet_task or task_gen) ): """ Convert Fleet tasks JSON to SkyRL parquet dataset. @@ -203,6 +204,7 @@ def prepare_fleet_dataset( env_filter: Optional env_key filter (e.g., "github", "booking") max_tasks: Optional maximum number of tasks to include max_env_ratio: Maximum fraction any single env can contribute to training (default: 0.20) + env_class: SkyRL env_class per record (default: "fleet_task", use "task_gen" for task generation) """ # Log applied filters at the start print("\n=== Dataset Filters ===") @@ -315,7 +317,7 @@ def prepare_fleet_dataset( if env_key in held_out_envs: env_eval_count = 0 for task in env_tasks: - record = _task_to_record(task, env_key) + record = _task_to_record(task, env_key, env_class=env_class) if record: eval_records.append(record) env_eval_count += 1 @@ -330,7 +332,7 @@ def prepare_fleet_dataset( if target_eval_size < MIN_EVAL_SAMPLES: env_train_count = 0 for task in env_tasks: - record = _task_to_record(task, env_key) + record = _task_to_record(task, env_key, env_class=env_class) if record: train_records.append(record) env_train_count += 1 @@ -346,7 +348,7 @@ def prepare_fleet_dataset( env_eval = 0 for task in env_tasks: task_key = task.get("key") or task.get("task_key") - record = _task_to_record(task, env_key) + record = _task_to_record(task, env_key, env_class=env_class) if not record: continue @@ -465,7 +467,7 @@ def prepare_fleet_dataset( ) -def _task_to_record(task: Dict[str, Any], env_key: str) -> Optional[Dict[str, Any]]: +def _task_to_record(task: Dict[str, Any], env_key: str, env_class: str = "fleet_task") -> Optional[Dict[str, Any]]: """Convert a task dict to a dataset record.""" task_key = task.get("key") or task.get("task_key") prompt = task.get("prompt", "") @@ -476,7 +478,7 @@ def _task_to_record(task: Dict[str, Any], env_key: str) -> Optional[Dict[str, An return { # Required fields for SkyRL "prompt": [{"role": "user", "content": prompt}], - "env_class": "fleet_task", # This tells SkyRL to use FleetTaskEnv + "env_class": env_class, # Task identification (passed as env_extras) "task_key": task_key, # Data source for per-environment metrics in WandB @@ -541,6 +543,13 @@ def main(): default=MAX_EVAL_PROMPTS, help=f"Maximum total eval prompts across all environments (default: {MAX_EVAL_PROMPTS})", ) + parser.add_argument( + "--env-class", + type=str, + default="fleet_task", + choices=["fleet_task", "task_gen"], + help="SkyRL env_class per record (default: fleet_task, use task_gen for task generation)", + ) args = parser.parse_args() @@ -557,6 +566,7 @@ def main(): max_tasks=args.max_tasks, max_env_ratio=args.max_env_ratio, max_eval_prompts=args.max_eval_prompts, + env_class=args.env_class, ) diff --git a/scripts/fleet-common-setup.sh b/scripts/fleet-common-setup.sh index baf0290cd5..37fb55ecc8 100755 --- a/scripts/fleet-common-setup.sh +++ b/scripts/fleet-common-setup.sh @@ -18,6 +18,7 @@ DATA_ROOT="" SKIP_UV_ISOLATED=false EXTRA_PIP="" SKIP_PREPARE=false +ENV_CLASS="fleet_task" # Parse args while [[ $# -gt 0 ]]; do @@ -28,6 +29,7 @@ while [[ $# -gt 0 ]]; do --skip-uv-isolated) SKIP_UV_ISOLATED=true; shift ;; --extra-pip) EXTRA_PIP="$2"; shift 2 ;; --skip-prepare) SKIP_PREPARE=true; shift ;; + --env-class) ENV_CLASS="$2"; shift 2 ;; *) echo "ERROR: Unknown arg: $1"; exit 1 ;; esac done @@ -121,7 +123,7 @@ if [ "$SKIP_PREPARE" = true ]; then echo "Skipping prepare_dataset (--skip-prepare). Caller handles preparation." else DATA_DIR="${DATA_ROOT}/data/fleet/${MODALITY}" - PREPARE_CMD="python -m integrations.fleet.prepare_dataset --tasks-json $TASKS_FILE --output-dir $DATA_DIR --modality $MODALITY" + PREPARE_CMD="python -m integrations.fleet.prepare_dataset --tasks-json $TASKS_FILE --output-dir $DATA_DIR --modality $MODALITY --env-class $ENV_CLASS" [ -n "${ENV_KEYS:-}" ] && PREPARE_CMD="$PREPARE_CMD --env-filter $ENV_KEYS" [ -n "${DIFFICULTY:-}" ] && PREPARE_CMD="$PREPARE_CMD --difficulty-filter $DIFFICULTY" eval "$PREPARE_CMD" diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 09c5dabd62..da4f101b6b 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -800,25 +800,31 @@ def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SkyRLTrainConfig": ) overrides = OmegaConf.from_cli(args) - # Try new format first + # Always load base config and merge overrides. + # Our run scripts use legacy flat keys (e.g. generator.backend) that + # need translation to the new nested format (generator.inference_engine.backend). + # The direct from_dict_config path only works for fully-qualified new-format keys. try: - return cls.from_dict_config(overrides) - except ValueError: - # Fall back to legacy format: load base YAML, merge overrides, translate - try: - base_cfg = get_legacy_config() - merged = OmegaConf.merge(base_cfg, overrides) - merged_dict = OmegaConf.to_container(merged, resolve=True) - - if is_legacy_config(merged_dict): - warn_legacy_config() - translated = translate_legacy_config(merged_dict) - return build_nested_dataclass(cls, translated) - except Exception: - pass # Legacy translation failed, re-raise original error - - # Re-raise original error if not a legacy config issue - raise + base_cfg = get_legacy_config() + except Exception: + # Hydra compose can fail (e.g., GlobalHydra already initialized). + # Fall back to loading YAML directly without Hydra defaults resolution. + import yaml + config_yaml = Path(__file__).parent / "ppo_base_config.yaml" + with open(config_yaml) as f: + raw_yaml = yaml.safe_load(f) + # Remove Hydra defaults key (not needed for direct loading) + raw_yaml.pop("defaults", None) + base_cfg = OmegaConf.create(raw_yaml) + + merged = OmegaConf.merge(base_cfg, overrides) + merged_dict = OmegaConf.to_container(merged, resolve=True) + + if is_legacy_config(merged_dict): + warn_legacy_config() + merged_dict = translate_legacy_config(merged_dict) + + return build_nested_dataclass(cls, merged_dict) def make_config( diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index d25a59bbbd..44b5a50831 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -52,7 +52,8 @@ envs: setup: | bash scripts/fleet-common-setup.sh \ --openenv-branch deniz/fleet_client \ - --extra-setup scripts/fleet-qwen35-extra-setup.sh + --extra-setup scripts/fleet-qwen35-extra-setup.sh \ + --env-class task_gen run: | bash scripts/fleet-task-gen-run.sh From 31e96c1395c058ebe34fde55348ba6acc90aedae Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 03:32:05 -0700 Subject: [PATCH 021/121] Fix task_gen rollout dir and OmegaConf struct flag for config overrides - task_gen_env.py: Default ROLLOUT_DIR to ~/rollouts instead of /workspace/rollouts. /workspace doesn't exist on GCP (only RunPod), causing PermissionError. - config.py: Disable OmegaConf struct flag on base config before merging CLI overrides. Empty dicts in YAML (like chat_template_kwargs: {}) are loaded as closed structs, rejecting new keys during merge. - config.py: Add try/except around asdict() in get_config_as_yaml_str to handle edge cases where asdict fails on Ray-serialized dataclasses. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py | 3 ++- skyrl/train/config/config.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 5e3dea83c7..f184ca9478 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -180,7 +180,8 @@ def __init__( self._fleet_client = None # Rollout dump directory (full prompt/verifier/scores per eval) - self._rollout_dir = os.environ.get("ROLLOUT_DIR", "/workspace/rollouts") + default_rollout_dir = os.path.join(os.path.expanduser("~"), "rollouts") + self._rollout_dir = os.environ.get("ROLLOUT_DIR", default_rollout_dir) os.makedirs(self._rollout_dir, exist_ok=True) # Base quality reward for tasks passing sandbox + judge gate. diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index da4f101b6b..ceb665ed1e 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -817,6 +817,11 @@ def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SkyRLTrainConfig": raw_yaml.pop("defaults", None) base_cfg = OmegaConf.create(raw_yaml) + # Disable struct flag so overrides can add new keys to dict-typed fields + # (e.g., chat_template_kwargs={enable_thinking:true}). + # OmegaConf loads empty dicts from YAML as closed structs by default. + OmegaConf.set_struct(base_cfg, False) + merged = OmegaConf.merge(base_cfg, overrides) merged_dict = OmegaConf.to_container(merged, resolve=True) @@ -895,7 +900,11 @@ def get_config_as_dict(cfg: Union[dict, BaseConfig]) -> dict: def get_config_as_yaml_str(cfg) -> str: if dataclasses.is_dataclass(cfg) and not isinstance(cfg, type): - return yaml.dump(asdict(cfg)) + try: + return yaml.dump(asdict(cfg)) + except TypeError: + # asdict can fail in some Ray serialization edge cases; fall back to str + return str(cfg) # Handle OmegaConf DictConfig (from Hydra entrypoints) try: from omegaconf import OmegaConf From e0bd62c74ab00ce7da091fbf42627f55536f4f31 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 03:53:32 -0700 Subject: [PATCH 022/121] Export FLEET_API_KEY to Ray runtime env and improve task import error logging FLEET_API_KEY was not being propagated to Ray workers via runtime_env, causing task_gen's import_single_task to fail with empty API key. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py | 8 ++++++-- skyrl/train/utils/utils.py | 5 +++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index f184ca9478..42e1d1a839 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -789,9 +789,13 @@ async def _run_harness_job( env_variables=self.env_variables or {}, ) - import_response = fleet.import_single_task(task) + try: + import_response = fleet.import_single_task(task) + except Exception as e: + logger.error(f"[{task_key}] Failed to import task to Fleet: {e}") + return (None, [(0.0, None, None)] * k) if import_response is None: - logger.error(f"[{task_key}] Failed to import task to Fleet") + logger.error(f"[{task_key}] Failed to import task to Fleet (returned None, api_key set: {bool(self.fleet_api_key)})") return (None, [(0.0, None, None)] * k) job_response = fleet.create_job( diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index 1fa7b1d47b..ffa0f63d14 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -638,6 +638,11 @@ def prepare_runtime_environment(cfg: SkyRLTrainConfig) -> dict[str, str]: logger.info("Exporting mlflow tracking token to ray runtime env") env_vars["MLFLOW_TRACKING_TOKEN"] = os.environ["MLFLOW_TRACKING_TOKEN"] + # Fleet env vars needed by fleet_task and task_gen environments + for var_name in ["FLEET_API_KEY", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION"]: + if value := os.environ.get(var_name): + env_vars[var_name] = value + # NOTE(charlie): these are for Harbor. We should remove these once we have a sustainable way to handle these environment vars. for var_name in ["DAYTONA_API_KEY", "MODAL_TOKEN_ID", "MODAL_TOKEN_SECRET"]: if value := os.environ.get(var_name): From bfd547a149c807c42ecaf901aecae3ba900bb712 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 05:35:56 -0700 Subject: [PATCH 023/121] fix(task_gen): use data_source as fallback for env_key in extras The dataset prepare step stores the environment name as 'data_source' column, but TaskGenEnv.__init__ only looked for 'env_key'. This caused all import_single_task calls to use env_id='unknown', which fails with "Environment 'unknown' not found" from Fleet API. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 42e1d1a839..5526746fdf 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -98,7 +98,7 @@ def __init__( self.called_query_db = False # Environment context from dataset (extras) - self.env_key = extras.get("env_key", "unknown") + self.env_key = extras.get("env_key") or extras.get("data_source", "unknown") self.env_version = extras.get("env_version", "") self.data_key = extras.get("data_key", "") self.data_version = extras.get("data_version", "") From d5620dc49651c70c42eac785a11041fc80d66efb Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 05:52:44 -0700 Subject: [PATCH 024/121] fix(35b): add --no-pytorch-alloc-conf to prevent vLLM CuMem crash expandable_segments:True in PYTORCH_CUDA_ALLOC_CONF is incompatible with vLLM's CuMemAllocator, causing AssertionError during model load. The 9B script already had this flag; the 35B was missing it. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 37a33e4a70..2aba162b79 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -32,7 +32,7 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ - --set-ulimit \ + --set-ulimit --no-pytorch-alloc-conf \ --nccl-heartbeat 1800 -- \ environment.skyrl_gym.fleet_task.ttl_seconds=900 \ environment.skyrl_gym.fleet_task.partial_reward=true \ From 205252fbeb6e28a6df25ddab9745541edf640feb Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 06:21:40 -0700 Subject: [PATCH 025/121] fix: sanitize multimodal content for text-only chat templates Fleet env returns list-format content (from OpenEnv multimodal observations) that text-only templates like Qwen3.5-35B-A3B can't handle. This converts list content (strings or image_url dicts) to plain text before applying the chat template, preventing jinja2 TemplateError on non-VL models. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/generators/skyrl_gym_generator.py | 39 ++++++++++++++++++- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index a44f55a02d..aeef02b6a9 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -628,6 +628,37 @@ def _build_per_token_rewards( reward_out = token_level_rewards return reward_out + @staticmethod + def _sanitize_messages_for_template(messages: ConversationType) -> ConversationType: + """Ensure message content is compatible with the model's chat template. + + Converts list-format content (multimodal observations from fleet env) to + plain text. This handles two cases from OpenEnv: + 1. List of strings (multiple text results): joined into one string + 2. List of dicts (image_url / text blocks): text extracted, images replaced + """ + sanitized = [] + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + text_parts = [] + for item in content: + if isinstance(item, str): + text_parts.append(item) + elif isinstance(item, dict): + if "text" in item: + text_parts.append(item["text"]) + elif "image_url" in item or "image" in item: + text_parts.append("[image]") + else: + text_parts.append(str(item)) + else: + text_parts.append(str(item)) + sanitized.append({**msg, "content": "\n".join(text_parts)}) + else: + sanitized.append(msg) + return sanitized + def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List[int]: """ Returns observation token ids from observation messages for a turn. @@ -644,10 +675,13 @@ def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List # 2. apply chat template for observations, also generate generation prompt for next turn obs_ids_to_add = [] if len(new_obs) > 0: + # Sanitize list-format content (multimodal) to plain text for + # compatibility with text-only chat templates (e.g. Qwen3.5-35B-A3B) + safe_obs = self._sanitize_messages_for_template(new_obs) # For Qwen, this will generate `\n<|user|>Some observation<|im_end|>\n`. Note that the # first `\n` is generated since we stripped it in ``base_conversation_token_ids``. obs_ids_to_add = self.tokenizer.apply_chat_template( - [*self.base_conversation, *new_obs], + [*self.base_conversation, *safe_obs], add_generation_prompt=not is_done, tokenize=True, return_dict=False, @@ -660,7 +694,8 @@ def get_obs_ids_from_obs(self, new_obs: ConversationType, is_done: bool) -> List # no generation prompt is added in this case obs_ids_to_add = [] if len(new_obs) > 0: - for obs in new_obs: + safe_obs = self._sanitize_messages_for_template(new_obs) + for obs in safe_obs: obs_tokens = self.tokenizer.encode(obs["content"], add_special_tokens=False) obs_ids_to_add.extend(obs_tokens) return obs_ids_to_add From 3d4ad77d24509f0a945c50d5e77ccb0f7eb80aee Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 06:49:05 -0700 Subject: [PATCH 026/121] fix: update uids after hint augmentation extends trajectory_ids Hint augmentation extends trajectory_ids in generator_input in-place, but the separate uids variable in the trainer was never updated. This caused IndexError in postprocess_generator_output when uids had fewer entries than rewards (128 raw + N hinted rewards vs 128 uids). Co-Authored-By: Claude Opus 4.6 --- skyrl/train/trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index a7d6629155..735628ec3f 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -231,6 +231,10 @@ async def train(self): # NOTE: We use instance_ids from `trajectory_ids` here instead of re-using `uids` # this is because in step-wise training, len(uids) != len(generator_output["response_ids"]) uids = [trajectory_id.instance_id for trajectory_id in generator_output["trajectory_ids"]] + elif "trajectory_ids" in generator_input and generator_input["trajectory_ids"] is not None: + # Hint augmentation may extend trajectory_ids in-place during generate(). + # Re-derive uids to stay aligned with rewards/responses. + uids = [tid.instance_id for tid in generator_input["trajectory_ids"]] # dynamic sampling if self.cfg.trainer.algorithm.dynamic_sampling.type is not None: From 0f921a7a1f39a686092f6e37473f5d861635a9c5 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 07:15:37 -0700 Subject: [PATCH 027/121] fix: catch env.init failures in agent_loop to prevent training crash When a Fleet environment fails to provision (e.g., list_tools timeout), return a zero-reward trajectory instead of propagating the exception through tqdm.gather and crashing the entire training step. This makes training resilient to transient Fleet API / MCP failures. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/generators/skyrl_gym_generator.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index aeef02b6a9..5ea13ccaed 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -306,7 +306,25 @@ async def agent_loop( chat_history = copy.deepcopy(prompt) # init() returns the first prompt to be given to the model, and optional metadata dict - chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) + try: + chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) + except Exception as e: + logger.warning(f"Session {session_id}: env.init failed ({type(e).__name__}: {e}), returning zero-reward trajectory") + # Return a minimal failed trajectory so training can continue + dummy_ids = self.tokenizer.apply_chat_template( + chat_history, add_generation_prompt=False, tokenize=True, return_dict=False, + **self.generator_cfg.chat_template_kwargs, + ) + eos_id = self.tokenizer.eos_token_id + return TrajectoryOutput( + response_ids=[eos_id] if eos_id is not None else [0], + reward=0.0, + stop_reason="env_init_error", + loss_mask=[0], + prompt_ids=dummy_ids, + rollout_logprobs=None, + env_metrics={"env_init_error": str(e), "final_reward": 0.0}, + ) initial_chat_history_length = len(chat_history) # VL: extract images from initial prompt for multimodal models From 35de0bbaae3147993cc5d97d091d04d236ad747d Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 07:18:55 -0700 Subject: [PATCH 028/121] fix: hardcode flash_attn=false in all fleet run scripts + add "$@" to task-gen flash-attn 2.8.3 on GCP H200 causes Xid 31 FAULT_PDE crashes during FSDP2 ref model forward. Hardcoding false in all scripts ensures the fix is applied regardless of CLI overrides. Also adds missing "$@" passthrough in fleet-task-gen-run.sh so CLI overrides are properly forwarded (was silently dropping them). Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- scripts/fleet-task-gen-run.sh | 5 +++-- scripts/fleet-vl-run.sh | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 2aba162b79..5daac94c46 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -40,7 +40,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.n_hint_samples=2 \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ - trainer.flash_attn=true \ + trainer.flash_attn=false \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ diff --git a/scripts/fleet-task-gen-run.sh b/scripts/fleet-task-gen-run.sh index e121499f00..1f2d4f902b 100755 --- a/scripts/fleet-task-gen-run.sh +++ b/scripts/fleet-task-gen-run.sh @@ -27,7 +27,7 @@ bash scripts/fleet-common-run.sh \ --env-class task_gen -- \ trainer.algorithm.advantage_estimator="grpo" \ trainer.policy.model.path="Qwen/Qwen3.5-9B" \ - trainer.flash_attn=true \ + trainer.flash_attn=false \ trainer.use_sample_packing=false \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ @@ -77,4 +77,5 @@ bash scripts/fleet-common-run.sh \ ++environment.skyrl_gym.task_gen.alpha=$ALPHA \ ++environment.skyrl_gym.task_gen.max_eval_steps=$MAX_EVAL_STEPS \ ++environment.skyrl_gym.task_gen.evaluator_model="${EVALUATOR_MODEL:-anthropic/claude-sonnet-4.5}" \ - ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 + ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 \ + "$@" diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index f610fac128..54da097cd2 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -44,7 +44,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.enable_hints=false \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-9B" \ - trainer.flash_attn=true \ + trainer.flash_attn=false \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ trainer.algorithm.loss_reduction="sequence_mean" \ From 763f056dbce4a95c9750279c9c5d6459f7b6792b Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 08:26:58 -0700 Subject: [PATCH 029/121] fix: use [0.0] rollout_logprobs in env_init_error fallback (not None) The validator at trainer_utils.py:648 checks len(rollout_logprobs[i]) for each sample. When env.init fails, the fallback TrajectoryOutput had rollout_logprobs=None, causing TypeError: object of type 'NoneType' has no len(). Match the single-token response_ids with [0.0] logprobs. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/generators/skyrl_gym_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 5ea13ccaed..03c6a3c28d 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -322,7 +322,7 @@ async def agent_loop( stop_reason="env_init_error", loss_mask=[0], prompt_ids=dummy_ids, - rollout_logprobs=None, + rollout_logprobs=[0.0], env_metrics={"env_init_error": str(e), "final_reward": 0.0}, ) initial_chat_history_length = len(chat_history) From 244f40d62415935f2c1b23c5b5ed05b8a3b817ea Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 09:42:25 -0700 Subject: [PATCH 030/121] fix: match reward format in env_init_error fallback Normal multi-turn trajectories return per-token List[float] rewards (from _build_per_token_rewards), but the env.init error handler was returning a plain float 0.0. The validate_generator_output check at trainer_utils.py:660 rejects mixed types (some List[float], some float). Fix: use [0.0] when not using custom_chat_template (per-token format), 0.0 when using custom_chat_template (scalar format). Co-Authored-By: Claude Opus 4.6 --- skyrl/train/generators/skyrl_gym_generator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 03c6a3c28d..d891ff1dc0 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -316,9 +316,11 @@ async def agent_loop( **self.generator_cfg.chat_template_kwargs, ) eos_id = self.tokenizer.eos_token_id + # Match reward format: custom_chat_template uses float, otherwise per-token List[float] + reward_val = 0.0 if self.custom_chat_template else [0.0] return TrajectoryOutput( response_ids=[eos_id] if eos_id is not None else [0], - reward=0.0, + reward=reward_val, stop_reason="env_init_error", loss_mask=[0], prompt_ids=dummy_ids, From 79700ee3b9353dcd5c4a50cf4e7d904f80d558f9 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 12:00:57 -0700 Subject: [PATCH 031/121] add per-env task-gen launcher script Ported from original SkyRL fork. Launches 8 separate SkyPilot clusters (one per Fleet environment) with computed NUM_EPOCHS targeting ~40 training steps per env. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-task-gen-launch-per-env.sh | 35 ++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100755 scripts/fleet-task-gen-launch-per-env.sh diff --git a/scripts/fleet-task-gen-launch-per-env.sh b/scripts/fleet-task-gen-launch-per-env.sh new file mode 100755 index 0000000000..ee1e21c39d --- /dev/null +++ b/scripts/fleet-task-gen-launch-per-env.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash +# Launch per-env task-gen experiments — one SkyPilot cluster per environment +# Targets ~40 training steps per env by computing NUM_EPOCHS from seed counts. +set -euo pipefail + +YAML="tasks/task-gen-grpo-qwen3_5-9b.yaml" +EVAL_RATIO="0.05" +TARGET_STEPS=40 +BATCH_SIZE=12 + +# Seed counts per env from v55 dataset (after EVAL_RATIO=0.05 split) +declare -A SEEDS=( + [booking]=539 [budget]=567 [carlisle]=336 [outlook]=181 + [reddit]=505 [rops-mail]=44 [ticketmaster]=212 [zillow]=106 +) + +for env in "${!SEEDS[@]}"; do + seeds=${SEEDS[$env]} + steps_per_epoch=$(( (seeds + BATCH_SIZE - 1) / BATCH_SIZE )) + num_epochs=$(( (TARGET_STEPS + steps_per_epoch - 1) / steps_per_epoch )) + total_steps=$(( steps_per_epoch * num_epochs )) + + echo "Launching task-gen-${env}: ${seeds} seeds, ${steps_per_epoch} steps/epoch, ${num_epochs} epochs (${total_steps} steps)" + sky launch -c "task-gen-${env}" "$YAML" \ + --env ENV_KEYS="$env" \ + --env EVAL_RATIO="$EVAL_RATIO" \ + --env NUM_EPOCHS="$num_epochs" \ + --env FLEET_API_KEY="$FLEET_API_KEY" \ + --env WANDB_API_KEY="$WANDB_API_KEY" \ + --env AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \ + --env AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ + --yes --async +done + +echo "All 8 clusters launched. Monitor with: sky status" From c1ceec016dfdf19db5accea7210a6b512c049b64 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 13:16:16 -0700 Subject: [PATCH 032/121] fix: handle per-trajectory exceptions in generate() instead of crashing batch Replace tqdm.gather() with asyncio.wait() + per-task exception handling, matching the old SkyRL fork's resilient pattern. Previously, a single trajectory failure (e.g., None screenshot URL) would crash the entire training step. Now failed trajectories get zero-reward outputs and training continues. Also adds null guard in extract_images_from_conversation for defense-in-depth. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/generators/skyrl_gym_generator.py | 71 +++++++++++++++---- skyrl/train/generators/utils.py | 2 +- 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index d891ff1dc0..b1917bf453 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -252,6 +252,29 @@ async def _run_in_executor_if_available(self, func, *args, **kwargs): else: return func(*args, **kwargs) + def _make_zero_reward_output( + self, + prompt: ConversationType, + zero_reward: Union[float, list], + is_step_wise: bool, + stop_reason: str = "trajectory_error", + env_metrics: Optional[Dict[str, Any]] = None, + ) -> Union[TrajectoryOutput, StepWiseOutput]: + """Create a zero-reward output for failed/cancelled trajectories.""" + prompt_ids = self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, return_dict=False) + output = TrajectoryOutput( + response_ids=[self.tokenizer.eos_token_id], + reward=zero_reward, + stop_reason=stop_reason, + loss_mask=[0], + prompt_ids=prompt_ids, + rollout_logprobs=[0.0], + env_metrics=env_metrics or {stop_reason: 1.0}, + ) + if is_step_wise: + return StepWiseOutput(step_outputs=[output]) + return output + async def agent_loop( self, prompt: ConversationType, @@ -991,27 +1014,47 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False return await self.generate_batched(prompts, env_classes, env_extras, max_tokens, sampling_params) # Async agent loop to generate trajectories in parallel. - tasks = [] + # Use asyncio.wait() instead of gather() so individual trajectory failures + # don't crash the entire batch — failed trajectories get zero-reward outputs. + is_step_wise = self.generator_cfg.step_wise_trajectories + zero_reward = 0.0 if self.custom_chat_template else [0.0] + + async_tasks = [] for i in range(len(prompts)): - tasks.append( - self.agent_loop( - prompts[i], - env_classes[i], - env_extras[i], - max_tokens, - max_input_length, - sampling_params=sampling_params, - trajectory_id=trajectory_ids[i] if trajectory_ids is not None else None, - ) + coro = self.agent_loop( + prompts[i], + env_classes[i], + env_extras[i], + max_tokens, + max_input_length, + sampling_params=sampling_params, + trajectory_id=trajectory_ids[i] if trajectory_ids is not None else None, ) + async_tasks.append(asyncio.ensure_future(coro)) - all_outputs = await tqdm.gather( - *tasks, + task_to_idx = {id(t): i for i, t in enumerate(async_tasks)} + + pbar = tqdm( + total=len(async_tasks), desc="Generating Trajectories", - miniters=max(1, len(tasks) // 10), + miniters=max(1, len(async_tasks) // 10), mininterval=5, disable=disable_tqdm, ) + for t in async_tasks: + t.add_done_callback(lambda _: pbar.update(1)) + + done, pending = await asyncio.wait(async_tasks) + pbar.close() + + all_outputs: list = [None] * len(async_tasks) + for t in done: + idx = task_to_idx[id(t)] + if t.exception() is not None: + logger.error(f"Trajectory {idx} raised exception: {t.exception()}") + all_outputs[idx] = self._make_zero_reward_output(prompts[idx], zero_reward, is_step_wise) + else: + all_outputs[idx] = t.result() # --- Hint augmentation: rescue GRPO signal on dead prompts --- # Only during training; eval should not run hints. diff --git a/skyrl/train/generators/utils.py b/skyrl/train/generators/utils.py index 5dba17a259..a7bf440702 100644 --- a/skyrl/train/generators/utils.py +++ b/skyrl/train/generators/utils.py @@ -609,7 +609,7 @@ def extract_images_from_conversation(conversation: ConversationType) -> List[Any if not isinstance(item, dict) or item.get("type") != "image_url": continue image_url_data = item.get("image_url", {}) - url = image_url_data.get("url", "") + url = image_url_data.get("url") or "" if url.startswith("data:image"): images.append(decode_base64_image(url)) elif url.startswith(("http://", "https://")): From 9d51e335aa6a622b44906d1681b1d7f2159f2f11 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 13:40:22 -0700 Subject: [PATCH 033/121] fix: multi-node FSDP2 stability + hint batch size for 35B training Ports fixes from old fork (PR #328, #333) plus a new fix for hint augmentation batch sizing in stage_chunks. 1. Synchronous ref offload/backload with barrier (fsdp_worker.py) - Prevents cudaErrorIllegalAddress across nodes with no shared CUDA context 2. empty_cache before backward (worker.py, 3 sites) - Prevents OOM from CUDA memory fragmentation on 35B 3. Enable expandable_segments for 35B (fleet-35b-run.sh) - Remove --no-pytorch-alloc-conf so Triton autotuning doesn't fragment segments 4. Dynamic mini_batch_size for hint augmentation (dispatch.py) - Adjusts mini_batch_size to divide variable batch sizes (e.g. 160 -> 80x2) - All samples including hints are trained on (no silent dropping) Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- .../skyrl_train/distributed/dispatch.py | 22 ++++++++++++++++--- .../skyrl_train/workers/fsdp/fsdp_worker.py | 11 ++++++++-- skyrl/backends/skyrl_train/workers/worker.py | 3 +++ 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 5daac94c46..1876ad3c25 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -32,7 +32,7 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ - --set-ulimit --no-pytorch-alloc-conf \ + --set-ulimit \ --nccl-heartbeat 1800 -- \ environment.skyrl_gym.fleet_task.ttl_seconds=900 \ environment.skyrl_gym.fleet_task.partial_reward=true \ diff --git a/skyrl/backends/skyrl_train/distributed/dispatch.py b/skyrl/backends/skyrl_train/distributed/dispatch.py index 18e56666d1..eea4d6a479 100644 --- a/skyrl/backends/skyrl_train/distributed/dispatch.py +++ b/skyrl/backends/skyrl_train/distributed/dispatch.py @@ -1,10 +1,13 @@ """Defines dispatch and collect logic for distributed training""" import asyncio +import logging from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Type +logger = logging.getLogger(__name__) + import ray from ray import ObjectRef from ray.actor import ActorHandle @@ -187,9 +190,22 @@ def stage_chunks( List of per-mini-batch chunk ref lists. ``result[i][dp_rank]`` is the ObjectRef for mini-batch *i*, DP rank *dp_rank*. """ - assert ( - len(data) % mini_batch_size == 0 - ), f"data batch size must be divisible by mini_batch_size, got {len(data)} and {mini_batch_size}" + # Hint augmentation can produce variable batch sizes that don't evenly + # divide the configured mini_batch_size. Rather than dropping samples + # (which wastes the expensive hint rollouts), reduce mini_batch_size to + # the largest value that divides both len(data) and dp_size. + if len(data) % mini_batch_size != 0: + original_mbs = mini_batch_size + # Step down by dp_size to stay dp-divisible + while mini_batch_size > 0 and (len(data) % mini_batch_size != 0 or mini_batch_size % dp_size != 0): + mini_batch_size -= dp_size + if mini_batch_size <= 0: + mini_batch_size = dp_size + logger.info( + f"Adjusted mini_batch_size from {original_mbs} to {mini_batch_size} " + f"to evenly divide batch of {len(data)} samples " + f"({len(data) // mini_batch_size} mini-batches)." + ) assert ( mini_batch_size % dp_size == 0 ), f"mini_batch_size must be divisible by dp_size, got {mini_batch_size} and {dp_size}" diff --git a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py index a0a1990f5d..779545d90e 100644 --- a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py @@ -395,10 +395,17 @@ def forward( class FSDPRefWorkerBase(RefWorkerBase): def offload_to_cpu(self, pin_memory=True, non_blocking=True, **kwargs): self._set_numa_affinity(torch.distributed.get_rank() % torch.cuda.device_count()) - self.strategy.offload_to_cpu(self.model, None, pin_memory, non_blocking) + # Force synchronous transfers + barrier to prevent cudaErrorIllegalAddress + # when policy workers access GPU memory that ref workers are still offloading + # across nodes (no shared CUDA context in multi-node). + self.strategy.offload_to_cpu(self.model, None, pin_memory, non_blocking=False) + if torch.distributed.is_initialized(): + torch.distributed.barrier() def backload_to_gpu(self, non_blocking=True, **kwargs): - self.strategy.backload_to_gpu(self.model, None, non_blocking) + self.strategy.backload_to_gpu(self.model, None, non_blocking=False) + if torch.distributed.is_initialized(): + torch.distributed.barrier() def init_model(self, model_path): assert self.cfg.strategy in ("fsdp", "fsdp2") diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index a88104a975..2609bf9347 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -805,6 +805,7 @@ def _forward_backward_micro( # SFT path: skip KL/entropy terms, return per-token outputs for Tinker API if resolved_loss_name == "cross_entropy": loss = policy_loss + torch.cuda.empty_cache() # defrag before backward to prevent OOM on 35B self.strategy.backward(loss, self.model, self.optimizer) # Compute elementwise loss for Tinker API (per-token NLL) @@ -870,6 +871,7 @@ def _forward_backward_micro( kl_loss_term = kl_loss * self.cfg.algorithm.kl_loss_coef loss = policy_loss + kl_loss_term - entropy_loss_term + torch.cuda.empty_cache() # defrag before backward to prevent OOM on 35B self.strategy.backward(loss, self.model, self.optimizer) # Build per-sequence loss_fn_outputs with logprobs. @@ -1100,6 +1102,7 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: loss_mask=loss_mask, ) # NO loss scaling here - gradient scaling happens at optim_step + torch.cuda.empty_cache() # defrag before backward to prevent OOM on 35B self.strategy.backward(loss, self.model, self.optimizer) status = { From 09d2b43b77754e910b43aa6384cfe63811813a81 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 13:42:04 -0700 Subject: [PATCH 034/121] docs: add CLAUDE.md and fleet changelog for multi-node fixes Documents the 4 fixes ported from old fork (PR #328, #333) plus the new dynamic mini_batch_size fix for hint augmentation. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 32 ++++++++++++++++ integrations/fleet/CHANGELOG.md | 65 +++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+) create mode 100644 CLAUDE.md create mode 100644 integrations/fleet/CHANGELOG.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..66e366c4a2 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,32 @@ +# SkyRL-v2 (fleet-ai/SkyRL-v2) + +Fork of SkyRL with Fleet-specific optimizations for multi-node FSDP2 training at scale. + +## Fleet Integration + +Fleet-specific changes, fixes, and context are documented in: +- **[integrations/fleet/CHANGELOG.md](integrations/fleet/CHANGELOG.md)** — detailed changelog with root causes and fixes + +Always consult the changelog before modifying Fleet training paths (`fsdp_worker.py`, `worker.py`, `dispatch.py`, `fleet-*.sh`). + +## Key Differences from Upstream SkyRL + +1. **Multi-node FSDP2 stability**: Synchronous ref model offload/backload with `torch.distributed.barrier()` in `fsdp_worker.py`. Required because cross-node colocated training has no shared CUDA context. + +2. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation on large models with tight GPU memory margins. + +3. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. + +4. **`expandable_segments` for 35B**: `fleet-35b-run.sh` does NOT pass `--no-pytorch-alloc-conf`, so `fleet-common-run.sh` enables `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. 9B runs on RunPod still use `--no-pytorch-alloc-conf` (no `CAP_SYS_PTRACE`). + +## Training Scripts + +- `scripts/fleet-common-run.sh` — shared infra (Ray, NCCL, gIB detection, deps). Used by all runs. +- `scripts/fleet-35b-run.sh` — Qwen3.5-35B config. Calls `fleet-common-run.sh`. +- `scripts/fleet-9b-run.sh` — Qwen3.5-9B config. Calls `fleet-common-run.sh`. + +All training flags live in these scripts. Never duplicate flags in SkyPilot YAMLs or fleet-research scripts. + +## Branch + +Primary development branch: `fleet/all` diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md new file mode 100644 index 0000000000..77e719e587 --- /dev/null +++ b/integrations/fleet/CHANGELOG.md @@ -0,0 +1,65 @@ +# Fleet Integration Changelog + +## 2026-03-29: Multi-node FSDP stability + hint batch size fix + +Ported from fleet-ai/SkyRL PR #328 and PR #333, plus a new fix for hint augmentation batch sizing. + +### Problem + +2-node (16 GPU) Qwen3.5-35B training on GCP H200 crashed with: +1. `cudaErrorIllegalAddress` segfaults during FSDP ref model offload/backload +2. OOM during backward pass from CUDA memory fragmentation +3. `AssertionError: data batch size must be divisible by mini_batch_size, got 160 and 128` when hints are enabled + +### Root causes and fixes + +#### 1. Synchronous ref offload + barrier (`fsdp_worker.py`) + +**Where:** `FSDPRefWorkerBase.offload_to_cpu()` and `backload_to_gpu()` + +**Problem:** With colocated models, the trainer cycles: ref on GPU -> ref offload to CPU -> policy on GPU. With `non_blocking=True`, the CPU<-GPU transfer is *queued* but returns immediately. On a single node, CUDA stream ordering serializes this naturally. Across nodes, there's no shared CUDA context -- node 0's policy worker can start touching GPU memory while node 1's ref worker is still mid-transfer. Result: `cudaErrorIllegalAddress`. + +**Fix:** `non_blocking=False` (wait for transfer) + `torch.distributed.barrier()` (all ranks synchronize). Guarantees every GPU finishes offloading before any policy worker starts backloading. + +**Why upstream SkyRL doesn't need this:** Designed for single-node where all workers share the same CUDA context and stream ordering prevents races. + +#### 2. empty_cache before backward (`worker.py`) + +**Where:** `PolicyWorkerBase._forward_backward_micro()` (both SFT and RL paths) and `CriticWorkerBase._forward_backward_micro()` + +**Problem:** After the forward pass, freed intermediate tensors stay in PyTorch's CUDA cache as scattered blocks. The backward pass needs large contiguous allocations for gradients. On the 35B model with tight GPU memory margins, the fragmented cache can't satisfy these allocations -> OOM, even though total free memory is sufficient. + +**Fix:** `torch.cuda.empty_cache()` before `strategy.backward()`. Returns cached blocks to CUDA which coalesces them into contiguous allocations. + +**Why upstream SkyRL doesn't need this:** Targets smaller models (8B) with enough GPU headroom that fragmentation doesn't matter. + +#### 3. Enable expandable_segments for 35B (`fleet-35b-run.sh`) + +**Where:** Removed `--no-pytorch-alloc-conf` flag from `fleet-35b-run.sh` + +**Problem:** Without `expandable_segments:True`, PyTorch allocates fixed-size CUDA segments. Triton autotuning (FlashAttention, MoE kernels) allocates many trial buffers then frees them, leaving a fragmented segment map. Subsequent large allocations fail even though total free memory is sufficient. + +**Fix:** Remove `--no-pytorch-alloc-conf` so `fleet-common-run.sh` sets `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. GCP has writable `ptrace_scope=0` so vLLM's CuMemAllocator and expandable_segments coexist fine. + +**Note:** 9B task-gen runs on RunPod still use `--no-pytorch-alloc-conf` because RunPod containers lack `CAP_SYS_PTRACE` and expandable_segments uses cuMem APIs that need `pidfd_getfd`. + +#### 4. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) + +**Where:** `MeshDispatch.stage_chunks()` + +**Problem:** `mini_batch_size` is computed as `policy_mini_batch_size * n_samples_per_prompt` (e.g., 16 * 8 = 128). But hint augmentation appends extra samples: 16 prompts * 2 hints = 32 additional, total batch = 160. The `stage_chunks` method asserted `160 % 128 == 0` -> crash. + +The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) silently dropped the 32 hint samples -- no crash, but hint training was wasted. + +**Fix:** When batch size isn't divisible by mini_batch_size, step down mini_batch_size (by `dp_size` increments to stay DP-divisible) until it divides evenly. For 160 samples with dp_size=16: adjusts from 128 -> 80, giving 2 mini-batches of 80. All 160 samples (including hints) are trained on. + +**Why upstream SkyRL doesn't have this:** Upstream uses a simple `for` loop with `//` division (no `stage_chunks` optimization). The `stage_chunks` pre-staging is a SkyRL-v2 optimization that added a strict assert the old code path never had. + +### Files changed + +| File | Change | +|------|--------| +| `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Synchronous ref offload + barrier | +| `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | +| `scripts/fleet-35b-run.sh` | Remove `--no-pytorch-alloc-conf` | +| `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | From b98b24089a5f84d9b621430bb41c15da2211aa66 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 13:46:45 -0700 Subject: [PATCH 035/121] Add training trajectory logging + S3 upload - Add env_metrics field to GeneratorOutput and pass through from generator - Add dump_training_trajectories() to trainer_utils for per-step JSONL export - Wire dump + S3 upload in trainer (after generate, before postprocess) - Upload reward rollouts to S3 at checkpoint intervals - Rename ROLLOUT_DIR -> REWARD_ROLLOUT_DIR in task_gen_env - Rename wandb project to fleet-task-gen - Enable dump_training_trajectories in fleet-common-run.sh Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/s3_checkpoints.py | 104 ++++++++++++++++++ scripts/fleet-common-run.sh | 1 + scripts/fleet-task-gen-run.sh | 2 +- .../skyrl_gym/envs/task_gen/task_gen_env.py | 4 +- skyrl/train/config/ppo_base_config.yaml | 1 + skyrl/train/generators/base.py | 1 + skyrl/train/generators/skyrl_gym_generator.py | 2 + skyrl/train/trainer.py | 31 ++++++ skyrl/train/utils/trainer_utils.py | 60 ++++++++++ 9 files changed, 203 insertions(+), 3 deletions(-) diff --git a/integrations/fleet/s3_checkpoints.py b/integrations/fleet/s3_checkpoints.py index 5db5aebb40..441da3e002 100644 --- a/integrations/fleet/s3_checkpoints.py +++ b/integrations/fleet/s3_checkpoints.py @@ -437,3 +437,107 @@ def upload_eval_results_to_s3( except Exception as e: logger.error(f"S3 upload failed for eval results {local_dir}: {e}") return False + + +def upload_training_trajectories_to_s3( + local_path: str, + run_name: str, + global_step: int, + bucket: Optional[str] = None, + region: Optional[str] = None, +) -> bool: + """Upload a single training trajectory JSONL file to S3. + + Args: + local_path: Path to the JSONL file + run_name: Run name for S3 prefix + global_step: Global step number + bucket: S3 bucket (default: from S3_TRAJECTORY_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + + Returns: + True if upload succeeded + """ + bucket = bucket or os.environ.get("S3_TRAJECTORY_BUCKET", "skyrl-trajectories") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.warning("AWS credentials not found. Skipping training trajectory upload.") + return False + + if not os.path.exists(local_path): + logger.warning(f"Trajectory file does not exist: {local_path}") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config(retries={"max_attempts": 3, "mode": "adaptive"}) + s3 = boto3.client("s3", region_name=region, config=config) + + s3_key = f"rollouts/{run_name}/global_step_{global_step}.jsonl" + s3.upload_file(local_path, bucket, s3_key) + logger.info(f"Uploaded training trajectories to s3://{bucket}/{s3_key}") + return True + + except Exception as e: + logger.error(f"S3 upload failed for training trajectories: {e}") + return False + + +def upload_reward_rollouts_to_s3( + rollout_dir: str, + run_name: str, + bucket: Optional[str] = None, + region: Optional[str] = None, +) -> bool: + """Upload reward rollout files to S3. + + Args: + rollout_dir: Local directory containing reward rollout JSONL files + run_name: Run name for S3 prefix + bucket: S3 bucket (default: from S3_TRAJECTORY_BUCKET env var) + region: AWS region (default: from AWS_REGION env var) + + Returns: + True if upload succeeded + """ + bucket = bucket or os.environ.get("S3_TRAJECTORY_BUCKET", "skyrl-trajectories") + region = region or os.environ.get("AWS_REGION", "us-east-1") + + aws_key = os.environ.get("AWS_ACCESS_KEY_ID") + aws_secret = os.environ.get("AWS_SECRET_ACCESS_KEY") + if not (aws_key and aws_secret): + logger.warning("AWS credentials not found. Skipping reward rollout upload.") + return False + + rollout_path = Path(rollout_dir) + if not rollout_path.exists(): + logger.info(f"No reward rollout directory at {rollout_dir}, skipping upload.") + return False + + try: + import boto3 + from botocore.config import Config + + config = Config(retries={"max_attempts": 3, "mode": "adaptive"}) + s3 = boto3.client("s3", region_name=region, config=config) + + uploaded = 0 + for file_path in rollout_path.rglob("*"): + if file_path.is_file(): + relative = file_path.relative_to(rollout_path) + s3_key = f"reward_rollouts/{run_name}/{relative}" + s3.upload_file(str(file_path), bucket, s3_key) + uploaded += 1 + + if uploaded: + logger.info(f"Uploaded {uploaded} reward rollout files to s3://{bucket}/reward_rollouts/{run_name}/") + return True + + except Exception as e: + logger.error(f"S3 upload failed for reward rollouts: {e}") + return False diff --git a/scripts/fleet-common-run.sh b/scripts/fleet-common-run.sh index 2a4571c563..7e12cb7e73 100755 --- a/scripts/fleet-common-run.sh +++ b/scripts/fleet-common-run.sh @@ -268,6 +268,7 @@ if [ "${SKYPILOT_NODE_RANK:-0}" = "0" ]; then "generator.num_inference_engines=$NUM_INFERENCE_ENGINES" "trainer.ckpt_path=${CKPT_ROOT}/ckpts" "trainer.export_path=${CKPT_ROOT}/exports" + trainer.dump_training_trajectories=true ) # Append model-specific hydra overrides (passed after --) diff --git a/scripts/fleet-task-gen-run.sh b/scripts/fleet-task-gen-run.sh index 1f2d4f902b..b23ac3d3d6 100755 --- a/scripts/fleet-task-gen-run.sh +++ b/scripts/fleet-task-gen-run.sh @@ -66,7 +66,7 @@ bash scripts/fleet-common-run.sh \ generator.eval_n_samples_per_prompt=3 \ generator.gpu_memory_utilization=0.75 \ trainer.logger="$LOGGER" \ - trainer.project_name="task-gen-grpo" \ + trainer.project_name="fleet-task-gen" \ trainer.run_name="$RUN_NAME" \ trainer.resume_mode=latest \ trainer.ckpt_path="$HOME/ckpts/task_gen" \ diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 5526746fdf..211093316b 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -180,8 +180,8 @@ def __init__( self._fleet_client = None # Rollout dump directory (full prompt/verifier/scores per eval) - default_rollout_dir = os.path.join(os.path.expanduser("~"), "rollouts") - self._rollout_dir = os.environ.get("ROLLOUT_DIR", default_rollout_dir) + default_rollout_dir = os.path.join(os.path.expanduser("~"), "reward_rollouts") + self._rollout_dir = os.environ.get("REWARD_ROLLOUT_DIR", default_rollout_dir) os.makedirs(self._rollout_dir, exist_ok=True) # Base quality reward for tasks passing sandbox + judge gate. diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index 59d3df958d..854954de17 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -250,6 +250,7 @@ trainer: run_name: "test_run" logger: "wandb" dump_data_batch: false + dump_training_trajectories: false dump_eval_results: true # YaRN: diff --git a/skyrl/train/generators/base.py b/skyrl/train/generators/base.py index da04c0ffdd..b41d64d36a 100644 --- a/skyrl/train/generators/base.py +++ b/skyrl/train/generators/base.py @@ -41,6 +41,7 @@ class GeneratorOutput(TypedDict): rollout_logprobs: Optional[List[List[float]]] trajectory_ids: Optional[List[TrajectoryID]] rollout_expert_indices: Optional[List[List[List[List[int]]]]] # [batch_size, seq_len, layer_num, topk] + env_metrics: Optional[List[Dict[str, Any]]] # Applicable only for step-wise training is_last_step: Optional[List[bool]] # Hint augmentation: True for samples generated with hint feedback diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index b1917bf453..8efd53d2bf 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -985,6 +985,7 @@ async def generate_batched( "rollout_metrics": rollout_metrics, "rollout_logprobs": truncated_logprobs, "rollout_expert_indices": truncated_indices, + "env_metrics": env_metrics, } return generator_output @@ -1200,6 +1201,7 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False "rollout_logprobs": rollout_logprobs, "trajectory_ids": out_trajectory_ids, "rollout_expert_indices": rollout_expert_indices, + "env_metrics": env_metrics, "is_last_step": is_last_step, "is_hinted": is_hinted, } diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 735628ec3f..0c0b127d1d 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -70,6 +70,7 @@ ResumeMode, build_dataloader, cleanup_old_checkpoints, + dump_training_trajectories, extract_step_from_path, run_on_each_node, validate_consistency_for_latest_checkpoint, @@ -248,6 +249,26 @@ async def train(self): # if we are not continuing sampling, we sleep the inference engine await self.inference_engine_client.sleep() + # 1.1.5 dump training trajectories + if self.cfg.trainer.dump_training_trajectories: + with Timer("dump_training_trajectories", self.all_timings): + traj_file = dump_training_trajectories( + dump_dir=self.cfg.trainer.export_path, + tokenizer=self.tokenizer, + generator_output=generator_output, + env_extras=generator_input.get("env_extras", []), + global_step=self.global_step, + ) + try: + from integrations.fleet.s3_checkpoints import upload_training_trajectories_to_s3 + upload_training_trajectories_to_s3( + local_path=traj_file, + run_name=self.cfg.trainer.run_name, + global_step=self.global_step, + ) + except Exception as e: + logger.warning(f"Failed to upload training trajectories to S3: {e}") + # 1.2 postprocess rewards with Timer("postprocess_generator_output", self.all_timings): generator_output = self.postprocess_generator_output(generator_output, uids) @@ -299,6 +320,16 @@ async def train(self): if self.cfg.trainer.ckpt_interval > 0 and self.global_step % self.cfg.trainer.ckpt_interval == 0: with Timer("save_checkpoints", self.all_timings): self.save_checkpoints() + if self.cfg.trainer.dump_training_trajectories: + try: + from integrations.fleet.s3_checkpoints import upload_reward_rollouts_to_s3 + reward_rollout_dir = os.environ.get("REWARD_ROLLOUT_DIR", "/workspace/reward_rollouts") + upload_reward_rollouts_to_s3( + rollout_dir=reward_rollout_dir, + run_name=self.cfg.trainer.run_name, + ) + except Exception as e: + logger.warning(f"Failed to upload reward rollouts to S3: {e}") if ( self.cfg.trainer.hf_save_interval > 0 and self.global_step % self.cfg.trainer.hf_save_interval == 0 diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 7fe3e53fb5..aaa33fba7e 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -1,5 +1,6 @@ import json import os +import time from collections import defaultdict from enum import Enum from pathlib import Path @@ -291,6 +292,61 @@ def dump_per_dataset_eval_results( logger.info(f"Dumped aggregated eval metrics to {aggregated_filename}") +def dump_training_trajectories( + dump_dir: str, + tokenizer: AutoTokenizer, + generator_output: GeneratorOutput, + env_extras: List[Dict[str, Any]], + global_step: int, +) -> str: + """Dump training trajectories to a JSONL file for analysis. + + Each line contains: step, env_key, data_source, stop_reason, reward, turns, tokens, prompt, text, timestamp. + """ + traj_dir = Path(dump_dir) / "dumped_trajectories" + traj_dir.mkdir(parents=True, exist_ok=True) + filename = traj_dir / f"global_step_{global_step}.jsonl" + + env_metrics_list = generator_output.get("env_metrics") or [] + rewards_list = generator_output["rewards"] + stop_reasons = generator_output.get("stop_reasons") or [] + ts = time.time() + + with open(filename, "w") as f: + for i in range(len(generator_output["response_ids"])): + env_m = env_metrics_list[i] if i < len(env_metrics_list) and env_metrics_list[i] else {} + env_key = env_m.get("env_key", "unknown") + turns = env_m.get("turns", env_m.get("num_turns", 0)) + extras = env_extras[i] if i < len(env_extras) else {} + data_source = extras.get("data_source", "unknown") if isinstance(extras, dict) else "unknown" + + reward = rewards_list[i] + if isinstance(reward, list): + reward = float(sum(reward)) + else: + reward = float(reward) + + stop_reason = stop_reasons[i] if i < len(stop_reasons) else "unknown" + tokens = len(generator_output["response_ids"][i]) + + entry = { + "step": global_step, + "env_key": env_key, + "data_source": data_source, + "stop_reason": stop_reason, + "reward": reward, + "turns": turns, + "tokens": tokens, + "prompt": tokenizer.decode(generator_output["prompt_token_ids"][i]), + "text": tokenizer.decode(generator_output["response_ids"][i]), + "timestamp": ts, + } + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + logger.info(f"Dumped {len(generator_output['response_ids'])} training trajectories to {filename}") + return str(filename) + + class DynamicSamplingState(TypedDict, total=False): """Schema for dynamic sampling state dictionary. @@ -565,6 +621,10 @@ def filter_generator_output(output: GeneratorOutput, kept_indices: List[int]) -> if output.get("stop_reasons"): filtered["stop_reasons"] = [output["stop_reasons"][i] for i in kept_indices] + filtered["env_metrics"] = ( + [output["env_metrics"][i] for i in kept_indices] if output.get("env_metrics") else None + ) + return filtered From 56cd76a2cb1cd9f1d737d5db84497d03ca61ce07 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 14:29:54 -0700 Subject: [PATCH 036/121] fix: add dump_training_trajectories to TrainerConfig dataclass The field was added to ppo_base_config.yaml but missed in the Python dataclass, causing hydra validation to reject it at startup. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/config/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index ceb665ed1e..99c198744e 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -615,6 +615,7 @@ class TrainerConfig(BaseConfig): logger: str = "wandb" dump_data_batch: bool = False dump_eval_results: bool = True + dump_training_trajectories: bool = False rope_scaling: Optional[Dict[str, Any]] = None rope_theta: Optional[float] = None loss_chunk_size: Optional[int] = None From 22d6acec6d6ad8f25ee0c9cb343100ec83e1bb18 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 14:51:35 -0700 Subject: [PATCH 037/121] fix: match system prompt with old fork for VL parity Two parity gaps vs feat/qwen3.5-9b branch: 1. Tool Call Format: old fork lists actual tool names ("Use the tools listed above by name (computer)") while new fork had a generic "tool_name" placeholder that the model copied verbatim, causing "Tool 'tool_name' not found" errors. 2. Computer-use hints: old fork has specific guidance (never repeat actions, use wait() only once, recovery strategies) while new fork had generic tips. Ported the battle-tested hints from old fork. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/fleet_task/env.py | 28 +++++++++++++++------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py index 4d9295ce90..86b52de654 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -436,17 +436,25 @@ async def init_async( if modality == "computer_use": computer_use_hints = ( "\n## Browser Interaction Strategy\n" - "You are controlling a web browser via screenshots. Follow this loop:\n" + "You are controlling a web browser via screenshots. Follow this loop:\n\n" "1. **Act**: Perform ONE action (click, type, scroll, etc.)\n" "2. **Observe**: Take a screenshot to see the result\n" - "3. **Think**: Analyze what happened and decide the next action\n\n" - "Tips:\n" - "- Always take a screenshot after each action to verify the result\n" - "- Click on elements by their visual position in the screenshot\n" - "- If an element is not visible, scroll to find it\n" - "- Use keyboard shortcuts when appropriate (Ctrl+A, Ctrl+C, etc.)\n" + "3. **Adapt**: If the screen hasn't changed, try a DIFFERENT action\n\n" + "Key rules:\n" + "- After clicking or typing, ALWAYS take a screenshot next to see what happened\n" + "- NEVER repeat the same action more than twice. If it didn't work, try something different:\n" + " - Can't find an element by scrolling? Use the search bar or navigation menu instead\n" + " - Page not loading after a click? Try refreshing with key(\"F5\") or clicking a different element\n" + " - Form not submitting? Check if required fields are missing\n" + "- Use wait() only ONCE after a page navigation, then screenshot to check. Do not wait repeatedly\n" + "- When the task is fully complete, say . Do not keep clicking after finishing\n" ) + tool_names = [ + t["function"]["name"] for t in self.tools if "function" in t + ] + tool_names_str = ", ".join(tool_names) + system_content = ( f"You are a helpful agent. Complete the task by calling tools.\n\n" f"## Current Date\n" @@ -456,8 +464,10 @@ async def init_async( f"{env_context}{env_hints}{computer_use_hints}\n" f"## Available Tools\n{tools_json}\n\n" f"## Tool Call Format\n" - f'{{"name": "tool_name", "arguments": ' - f'{{"param": "value"}}}}\n\n' + f"Use the tools listed above by name ({tool_names_str}). " + f"Format each call as:\n" + f'{{"name": "", "arguments": ' + f"{{...}}}}\n\n" f"## Error Handling\n" f"If a tool call returns an error:\n" f"- Read the error message carefully\n" From d40c5f7bb6959bb9edf4420bf5d70f392306ed1c Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 14:51:35 -0700 Subject: [PATCH 038/121] fix: match system prompt with old fork for VL parity Two parity gaps vs feat/qwen3.5-9b branch: 1. Tool Call Format: old fork lists actual tool names ("Use the tools listed above by name (computer)") while new fork had a generic "tool_name" placeholder that the model copied verbatim, causing "Tool 'tool_name' not found" errors. 2. Computer-use hints: old fork has specific guidance (never repeat actions, use wait() only once, recovery strategies) while new fork had generic tips. Ported the battle-tested hints from old fork. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/fleet_task/env.py | 28 +++++++++++++++------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py index 4d9295ce90..86b52de654 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -436,17 +436,25 @@ async def init_async( if modality == "computer_use": computer_use_hints = ( "\n## Browser Interaction Strategy\n" - "You are controlling a web browser via screenshots. Follow this loop:\n" + "You are controlling a web browser via screenshots. Follow this loop:\n\n" "1. **Act**: Perform ONE action (click, type, scroll, etc.)\n" "2. **Observe**: Take a screenshot to see the result\n" - "3. **Think**: Analyze what happened and decide the next action\n\n" - "Tips:\n" - "- Always take a screenshot after each action to verify the result\n" - "- Click on elements by their visual position in the screenshot\n" - "- If an element is not visible, scroll to find it\n" - "- Use keyboard shortcuts when appropriate (Ctrl+A, Ctrl+C, etc.)\n" + "3. **Adapt**: If the screen hasn't changed, try a DIFFERENT action\n\n" + "Key rules:\n" + "- After clicking or typing, ALWAYS take a screenshot next to see what happened\n" + "- NEVER repeat the same action more than twice. If it didn't work, try something different:\n" + " - Can't find an element by scrolling? Use the search bar or navigation menu instead\n" + " - Page not loading after a click? Try refreshing with key(\"F5\") or clicking a different element\n" + " - Form not submitting? Check if required fields are missing\n" + "- Use wait() only ONCE after a page navigation, then screenshot to check. Do not wait repeatedly\n" + "- When the task is fully complete, say . Do not keep clicking after finishing\n" ) + tool_names = [ + t["function"]["name"] for t in self.tools if "function" in t + ] + tool_names_str = ", ".join(tool_names) + system_content = ( f"You are a helpful agent. Complete the task by calling tools.\n\n" f"## Current Date\n" @@ -456,8 +464,10 @@ async def init_async( f"{env_context}{env_hints}{computer_use_hints}\n" f"## Available Tools\n{tools_json}\n\n" f"## Tool Call Format\n" - f'{{"name": "tool_name", "arguments": ' - f'{{"param": "value"}}}}\n\n' + f"Use the tools listed above by name ({tool_names_str}). " + f"Format each call as:\n" + f'{{"name": "", "arguments": ' + f"{{...}}}}\n\n" f"## Error Handling\n" f"If a tool call returns an error:\n" f"- Read the error message carefully\n" From 0f023edf2296b9d37f9f34f8bfa03afe306d8001 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 15:09:19 -0700 Subject: [PATCH 039/121] Add data.env_filter for per-env dataset filtering at training time TASK_GEN_ENV_CLASSES env var was being passed but never read. This adds runtime filtering by data_source column in the parquet, so per-env runs (e.g. outlook-only) only train on that env's seeds. - DataConfig.env_filter: comma-separated data_source values - PromptDataset: filters dataset after load, before prompt length filter - fleet-task-gen-run.sh: reads TASK_GEN_ENV_CLASSES and passes data.env_filter Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-task-gen-run.sh | 9 +++++++++ skyrl/train/config/config.py | 2 ++ skyrl/train/config/ppo_base_config.yaml | 1 + skyrl/train/dataset/dataset.py | 13 +++++++++++++ skyrl/train/entrypoints/main_base.py | 2 ++ 5 files changed, 27 insertions(+) diff --git a/scripts/fleet-task-gen-run.sh b/scripts/fleet-task-gen-run.sh index b23ac3d3d6..e885b380d6 100755 --- a/scripts/fleet-task-gen-run.sh +++ b/scripts/fleet-task-gen-run.sh @@ -14,6 +14,14 @@ set -euo pipefail # Always use random hex suffix for unique run names export RUN_NAME="task_gen_$(python3 -c 'import os; print(os.urandom(4).hex())')" +# Optional: per-env dataset filtering via TASK_GEN_ENV_CLASSES env var +# e.g. TASK_GEN_ENV_CLASSES="outlook" or TASK_GEN_ENV_CLASSES="outlook,booking" +ENV_FILTER_ARGS=() +if [ -n "${TASK_GEN_ENV_CLASSES:-}" ]; then + echo "=== env_filter: $TASK_GEN_ENV_CLASSES ===" + ENV_FILTER_ARGS+=("data.env_filter=$TASK_GEN_ENV_CLASSES") +fi + # Task-gen GRPO training via shared run script # --entrypoint: task-gen entrypoint (not main_fleet) # --env-class: task_gen environment (not fleet_task) @@ -78,4 +86,5 @@ bash scripts/fleet-common-run.sh \ ++environment.skyrl_gym.task_gen.max_eval_steps=$MAX_EVAL_STEPS \ ++environment.skyrl_gym.task_gen.evaluator_model="${EVALUATOR_MODEL:-anthropic/claude-sonnet-4.5}" \ ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 \ + "${ENV_FILTER_ARGS[@]}" \ "$@" diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 99c198744e..c04c9634e5 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -41,6 +41,8 @@ def from_dict_config(cls, cfg: DictConfig) -> "BaseConfig": class DataConfig(BaseConfig): train_data: List[str] = field(default_factory=lambda: [os.path.expanduser("~/data/gsm8k/train.parquet")]) val_data: List[str] = field(default_factory=lambda: [os.path.expanduser("~/data/gsm8k/validation.parquet")]) + env_filter: Optional[str] = None + """Comma-separated list of data_source values to include (e.g. 'outlook,github'). None = no filtering.""" # --------------------------------------------------------------------------- diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index 854954de17..a32f40e3f4 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -11,6 +11,7 @@ defaults: data: train_data: ["${oc.env:HOME}/data/gsm8k/train.parquet"] val_data: ["${oc.env:HOME}/data/gsm8k/validation.parquet"] + env_filter: null trainer: placement: diff --git a/skyrl/train/dataset/dataset.py b/skyrl/train/dataset/dataset.py index 82383fd9b9..f49e09f448 100644 --- a/skyrl/train/dataset/dataset.py +++ b/skyrl/train/dataset/dataset.py @@ -15,12 +15,14 @@ def __init__( num_workers: int = 8, prompt_key: str = "prompt", env_class_key: str = "env_class", + env_filter: str | None = None, ): self.tokenizer = tokenizer self.max_prompt_length = max_prompt_length self.prompt_key = prompt_key self.env_class_key = env_class_key self.num_workers = num_workers + self.env_filter = env_filter self.datasets = datasets if isinstance(self.datasets, str): @@ -55,6 +57,17 @@ def _read_files_and_tokenize(self): logger.info(f"Total dataset size: {len(self.dataframe)}") + # Filter by data_source if env_filter is set + if self.env_filter: + allowed = {e.strip() for e in self.env_filter.split(",") if e.strip()} + before = len(self.dataframe) + self.dataframe = self.dataframe.filter( + lambda row: row.get("data_source", "") in allowed, + num_proc=self.num_workers, + desc=f"Filtering by env_filter ({allowed})", + ) + logger.info(f"env_filter={allowed}: {before} -> {len(self.dataframe)} rows") + # filter out too long prompts tokenizer = self.tokenizer prompt_key = self.prompt_key diff --git a/skyrl/train/entrypoints/main_base.py b/skyrl/train/entrypoints/main_base.py index aa81fd4ac3..4de2e4dd16 100644 --- a/skyrl/train/entrypoints/main_base.py +++ b/skyrl/train/entrypoints/main_base.py @@ -168,6 +168,7 @@ def get_train_dataset(self): tokenizer=self.tokenizer, max_prompt_length=self.cfg.trainer.max_prompt_length, num_workers=8, + env_filter=getattr(self.cfg.data, "env_filter", None), ) # make sure the dataset is large enough to train on assert ( @@ -187,6 +188,7 @@ def get_eval_dataset(self): tokenizer=self.tokenizer, max_prompt_length=self.cfg.trainer.max_prompt_length, num_workers=8, + env_filter=getattr(self.cfg.data, "env_filter", None), ) return prompts_dataset return None From 9cf2d5dbff54c9f05d9376f0fc732b7cfcf665bc Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 15:42:33 -0700 Subject: [PATCH 040/121] fix: re-add --no-pytorch-alloc-conf for vLLM 0.18.0 CuMemAllocator compat vLLM 0.18.0's CuMemAllocator uses cuMemCreate/cuMemMap for its memory pool. PyTorch's expandable_segments:True also uses cuMemCreate/cuMemMap. Two independent cuMem allocators in the same process conflict, causing AssertionError at vLLM engine init. Anti-fragmentation is handled by empty_cache() before backward (9d51e33). Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 1876ad3c25..5daac94c46 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -32,7 +32,7 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ - --set-ulimit \ + --set-ulimit --no-pytorch-alloc-conf \ --nccl-heartbeat 1800 -- \ environment.skyrl_gym.fleet_task.ttl_seconds=900 \ environment.skyrl_gym.fleet_task.partial_reward=true \ From 95450e20177f1a6434d8d4308dfd5c8279f3cf6c Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 15:42:48 -0700 Subject: [PATCH 041/121] docs: update changelog + CLAUDE.md for vLLM 0.18.0 CuMemAllocator fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Section 3 in CHANGELOG.md was incorrect — it said we removed --no-pytorch-alloc-conf, but vLLM 0.18.0's CuMemAllocator conflicts with expandable_segments. Corrected to document that we keep the flag. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 2 +- integrations/fleet/CHANGELOG.md | 12 ++++++------ scripts/fleet-35b-run.sh | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 66e366c4a2..01159d37f6 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,7 +17,7 @@ Always consult the changelog before modifying Fleet training paths (`fsdp_worker 3. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. -4. **`expandable_segments` for 35B**: `fleet-35b-run.sh` does NOT pass `--no-pytorch-alloc-conf`, so `fleet-common-run.sh` enables `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. 9B runs on RunPod still use `--no-pytorch-alloc-conf` (no `CAP_SYS_PTRACE`). +4. **No `expandable_segments` with vLLM 0.18.0**: `fleet-35b-run.sh` passes `--no-pytorch-alloc-conf` because vLLM 0.18.0's `CuMemAllocator` uses `cuMemCreate`/`cuMemMap` and conflicts with PyTorch's `expandable_segments:True` (also cuMem-based). Anti-fragmentation is handled by `empty_cache()` before backward (fix #2). Old SkyRL (vLLM 0.17.0, `cudaMalloc`) doesn't have this conflict. ## Training Scripts diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index 77e719e587..16108c20ec 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -33,15 +33,15 @@ Ported from fleet-ai/SkyRL PR #328 and PR #333, plus a new fix for hint augmenta **Why upstream SkyRL doesn't need this:** Targets smaller models (8B) with enough GPU headroom that fragmentation doesn't matter. -#### 3. Enable expandable_segments for 35B (`fleet-35b-run.sh`) +#### 3. Keep `--no-pytorch-alloc-conf` for vLLM 0.18.0 compatibility (`fleet-35b-run.sh`) -**Where:** Removed `--no-pytorch-alloc-conf` flag from `fleet-35b-run.sh` +**Where:** `fleet-35b-run.sh` retains `--no-pytorch-alloc-conf` flag. -**Problem:** Without `expandable_segments:True`, PyTorch allocates fixed-size CUDA segments. Triton autotuning (FlashAttention, MoE kernels) allocates many trial buffers then frees them, leaving a fragmented segment map. Subsequent large allocations fail even though total free memory is sufficient. +**Problem:** SkyRL-v2 uses vLLM 0.18.0 which introduced `CuMemAllocator` — a custom CUDA memory allocator that uses `cuMemCreate`/`cuMemMap` (virtual memory management APIs) for its memory pool. PyTorch's `expandable_segments:True` (set by `fleet-common-run.sh` when `--no-pytorch-alloc-conf` is absent) also uses `cuMemCreate`/`cuMemMap`. Two independent cuMem-based allocators in the same process maintain conflicting bookkeeping of the virtual address space, causing `AssertionError: Expandable segments are not compatible with memory pool` at vLLM engine init. -**Fix:** Remove `--no-pytorch-alloc-conf` so `fleet-common-run.sh` sets `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. GCP has writable `ptrace_scope=0` so vLLM's CuMemAllocator and expandable_segments coexist fine. +**Why this wasn't an issue in the old SkyRL fork:** Old SkyRL uses vLLM 0.17.0 which uses standard `cudaMalloc`/`cudaFree` — no cuMem APIs, no conflict with `expandable_segments`. -**Note:** 9B task-gen runs on RunPod still use `--no-pytorch-alloc-conf` because RunPod containers lack `CAP_SYS_PTRACE` and expandable_segments uses cuMem APIs that need `pidfd_getfd`. +**Fix:** Keep `--no-pytorch-alloc-conf` so `expandable_segments` is never set. CUDA memory fragmentation (the problem `expandable_segments` would solve) is instead mitigated by the `empty_cache()` calls added in fix #2 above, which defragment the PyTorch allocator cache before each backward pass. #### 4. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) @@ -61,5 +61,5 @@ The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) s |------|--------| | `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Synchronous ref offload + barrier | | `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | -| `scripts/fleet-35b-run.sh` | Remove `--no-pytorch-alloc-conf` | +| `scripts/fleet-35b-run.sh` | Keep `--no-pytorch-alloc-conf` (vLLM 0.18.0 CuMemAllocator compat) | | `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 1876ad3c25..5daac94c46 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -32,7 +32,7 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ - --set-ulimit \ + --set-ulimit --no-pytorch-alloc-conf \ --nccl-heartbeat 1800 -- \ environment.skyrl_gym.fleet_task.ttl_seconds=900 \ environment.skyrl_gym.fleet_task.partial_reward=true \ From 90af7d9fd009bf74e7f82191f87fe1386e4725d8 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 16:07:03 -0700 Subject: [PATCH 042/121] Pass data_key, data_version, env_version, env_variables through to parquet TaskGenEnv.init_async() needs data_key to provision the Fleet orchestrator for DB exploration (describe_db, query_db). Without it, self.orch is always None and the model gets stuck in an impossible exploration loop where the env demands tool calls but can't serve them. Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/prepare_dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/integrations/fleet/prepare_dataset.py b/integrations/fleet/prepare_dataset.py index 67f02199a9..73d1ffc84e 100644 --- a/integrations/fleet/prepare_dataset.py +++ b/integrations/fleet/prepare_dataset.py @@ -475,7 +475,7 @@ def _task_to_record(task: Dict[str, Any], env_key: str, env_class: str = "fleet_ if not task_key or not prompt: return None - return { + record = { # Required fields for SkyRL "prompt": [{"role": "user", "content": prompt}], "env_class": env_class, @@ -483,7 +483,13 @@ def _task_to_record(task: Dict[str, Any], env_key: str, env_class: str = "fleet_ "task_key": task_key, # Data source for per-environment metrics in WandB "data_source": env_key, + # Environment/data fields needed by TaskGenEnv for orchestrator provisioning + "data_key": task.get("data_key") or "", + "data_version": task.get("data_version") or "", + "env_version": task.get("env_version") or "", + "env_variables": json.dumps(task.get("env_variables") or {}), } + return record def main(): From 5d7c8782f703af0e6fcaa3eb477226fe902cdad6 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 16:39:11 -0700 Subject: [PATCH 043/121] chore: rename wandb project to fleet-tool-use-grpo Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 5daac94c46..f11913155a 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -81,7 +81,7 @@ bash scripts/fleet-common-run.sh \ generator.inject_context_status=true \ generator.context_warning_threshold=0.90 \ trainer.logger="$LOGGER" \ - trainer.project_name="fleet-task-grpo" \ + trainer.project_name="fleet-tool-use-grpo" \ trainer.run_name="fleet_qwen35_35b_${MODALITY}_${RUN_ID:-$(head -c 4 /dev/urandom | xxd -p)}" \ trainer.resume_mode=latest \ trainer.ckpt_path="$HOME/ckpts/fleet_qwen35_35b_${MODALITY}" \ From ee19cd99f917d864591be9737b747287149524eb Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 16:39:16 -0700 Subject: [PATCH 044/121] chore: rename wandb project to fleet-tool-use-grpo Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 5daac94c46..f11913155a 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -81,7 +81,7 @@ bash scripts/fleet-common-run.sh \ generator.inject_context_status=true \ generator.context_warning_threshold=0.90 \ trainer.logger="$LOGGER" \ - trainer.project_name="fleet-task-grpo" \ + trainer.project_name="fleet-tool-use-grpo" \ trainer.run_name="fleet_qwen35_35b_${MODALITY}_${RUN_ID:-$(head -c 4 /dev/urandom | xxd -p)}" \ trainer.resume_mode=latest \ trainer.ckpt_path="$HOME/ckpts/fleet_qwen35_35b_${MODALITY}" \ From 89e8799b359b93ebc6a6ef1cc5be1709b8c26710 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 16:57:35 -0700 Subject: [PATCH 045/121] fix: enable flash_attn for 35B training (OOM without it) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Without flash attention, FSDP backward materializes full N×N attention matrices for 97K-length sequences, causing OOM. Old SkyRL fork has flash_attn=true; this was incorrectly set to false in SkyRL-v2. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index f11913155a..180415420a 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -40,7 +40,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.n_hint_samples=2 \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ - trainer.flash_attn=false \ + trainer.flash_attn=true \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ From b115b31efad7d086532fd4b70d5d88041f778541 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 16:57:43 -0700 Subject: [PATCH 046/121] fix: enable flash_attn for 35B training (OOM without it) Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index f11913155a..180415420a 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -40,7 +40,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.n_hint_samples=2 \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ - trainer.flash_attn=false \ + trainer.flash_attn=true \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ From 2db6c69038db15d8ce152985ea2f22f9d9e07bd5 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 17:45:54 -0700 Subject: [PATCH 047/121] Pass representative env_variables + env_variable_keys per-env to parquet Many tasks in the Fleet JSON lack env_variables (e.g., budget: 477/597, hubspot: 70/70 missing). The original fork collects representative env_variables from the first task that has them and applies to all tasks in that env. This ensures the TaskGenEnv system prompt includes LOGGED_IN_USER, CURRENT_DATE, etc. for task generation guidance. Also adds env_variable_keys column (union of all var keys per env). Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/prepare_dataset.py | 63 ++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/integrations/fleet/prepare_dataset.py b/integrations/fleet/prepare_dataset.py index 73d1ffc84e..c4a5a937bd 100644 --- a/integrations/fleet/prepare_dataset.py +++ b/integrations/fleet/prepare_dataset.py @@ -302,6 +302,31 @@ def prepare_fleet_dataset( env_key = task.get("env_key") or task.get("env_id") or "unknown" tasks_by_env[env_key].append(task) + # Collect per-env metadata: representative env_variables and env_variable_keys + # (mirrors original SkyRL fork's _collect_env_metadata) + env_metadata: Dict[str, Dict[str, Any]] = {} + for env_key, env_tasks_list in tasks_by_env.items(): + all_var_keys: set = set() + representative_env_vars: Dict[str, Any] = {} + for t in env_tasks_list: + env_vars = t.get("env_variables") or {} + if isinstance(env_vars, str): + try: + env_vars = json.loads(env_vars) + except json.JSONDecodeError: + env_vars = {} + all_var_keys.update(env_vars.keys()) + if not representative_env_vars and env_vars: + representative_env_vars = dict(env_vars) + env_metadata[env_key] = { + "env_variable_keys": sorted(all_var_keys), + "env_variables": representative_env_vars, + } + print("\nEnvironment metadata:") + for ek in sorted(env_metadata): + meta = env_metadata[ek] + print(f" {ek}: env_vars={meta['env_variable_keys']}") + # Prepare records with stratified split train_records = [] eval_records = [] @@ -317,7 +342,7 @@ def prepare_fleet_dataset( if env_key in held_out_envs: env_eval_count = 0 for task in env_tasks: - record = _task_to_record(task, env_key, env_class=env_class) + record = _task_to_record(task, env_key, env_class=env_class, env_meta=env_metadata.get(env_key)) if record: eval_records.append(record) env_eval_count += 1 @@ -332,7 +357,7 @@ def prepare_fleet_dataset( if target_eval_size < MIN_EVAL_SAMPLES: env_train_count = 0 for task in env_tasks: - record = _task_to_record(task, env_key, env_class=env_class) + record = _task_to_record(task, env_key, env_class=env_class, env_meta=env_metadata.get(env_key)) if record: train_records.append(record) env_train_count += 1 @@ -348,7 +373,7 @@ def prepare_fleet_dataset( env_eval = 0 for task in env_tasks: task_key = task.get("key") or task.get("task_key") - record = _task_to_record(task, env_key, env_class=env_class) + record = _task_to_record(task, env_key, env_class=env_class, env_meta=env_metadata.get(env_key)) if not record: continue @@ -467,14 +492,39 @@ def prepare_fleet_dataset( ) -def _task_to_record(task: Dict[str, Any], env_key: str, env_class: str = "fleet_task") -> Optional[Dict[str, Any]]: - """Convert a task dict to a dataset record.""" +def _task_to_record( + task: Dict[str, Any], + env_key: str, + env_class: str = "fleet_task", + env_meta: Optional[Dict[str, Any]] = None, +) -> Optional[Dict[str, Any]]: + """Convert a task dict to a dataset record. + + Args: + task: Task dict from Fleet JSON + env_key: Environment identifier + env_class: SkyRL env class (fleet_task or task_gen) + env_meta: Per-env metadata with representative env_variables and env_variable_keys + """ task_key = task.get("key") or task.get("task_key") prompt = task.get("prompt", "") if not task_key or not prompt: return None + # Use per-task env_variables if available, otherwise fall back to + # representative per-env values (some tasks lack env_variables) + task_env_vars = task.get("env_variables") or {} + if isinstance(task_env_vars, str): + try: + task_env_vars = json.loads(task_env_vars) + except json.JSONDecodeError: + task_env_vars = {} + if not task_env_vars and env_meta: + task_env_vars = env_meta.get("env_variables", {}) + + env_var_keys = (env_meta or {}).get("env_variable_keys", []) + record = { # Required fields for SkyRL "prompt": [{"role": "user", "content": prompt}], @@ -487,7 +537,8 @@ def _task_to_record(task: Dict[str, Any], env_key: str, env_class: str = "fleet_ "data_key": task.get("data_key") or "", "data_version": task.get("data_version") or "", "env_version": task.get("env_version") or "", - "env_variables": json.dumps(task.get("env_variables") or {}), + "env_variables": json.dumps(task_env_vars), + "env_variable_keys": json.dumps(env_var_keys), } return record From 42fe8fd4ae070a8f1da268e8712eb2fe0de39f67 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 17:58:16 -0700 Subject: [PATCH 048/121] feat(tinker): add stop-sequences, top-p, loss-fn args and fix avg_raw_reward metric Add missing CLI arguments to main_fleet_tinker.py: - --stop-sequences: JSON list of stop sequences (e.g. '[""]') - --top-p: nucleus sampling parameter - --temperature: sampling temperature - --loss-fn: configurable loss function for forward_backward (was hardcoded "ppo") Wire these through collect_fleet_rollout() and collect_batch_rollouts(). Fix KeyError on metrics['reward/avg_raw_reward'] by adding it to compute_rollout_metrics(). Add fleet-tinker-tool-use-run.sh launch script mirroring fleet-35b-run.sh config for the Tinker backend. Co-Authored-By: Claude Opus 4.6 --- .../fleet/entrypoints/main_fleet_tinker.py | 63 +++++++++++++++++-- scripts/fleet-tinker-tool-use-run.sh | 34 ++++++++++ 2 files changed, 91 insertions(+), 6 deletions(-) create mode 100755 scripts/fleet-tinker-tool-use-run.sh diff --git a/integrations/fleet/entrypoints/main_fleet_tinker.py b/integrations/fleet/entrypoints/main_fleet_tinker.py index a442f42420..d7f0464315 100644 --- a/integrations/fleet/entrypoints/main_fleet_tinker.py +++ b/integrations/fleet/entrypoints/main_fleet_tinker.py @@ -247,6 +247,7 @@ def compute_rollout_metrics( # Core reward metrics using shared module core_metrics = compute_reward_metrics(rewards, uids, n_samples_per_prompt) metrics[f"reward/avg_pass_at_{n_samples_per_prompt}"] = core_metrics[f"pass_at_{n_samples_per_prompt}"] + metrics["reward/avg_raw_reward"] = np.mean(rewards) metrics["reward/variance_per_prompt"] = core_metrics["variance_per_prompt"] metrics["reward/mean_positive_reward"] = core_metrics["mean_positive_reward"] @@ -409,6 +410,8 @@ async def collect_fleet_rollout( max_generate_length: int = 2048, max_input_length: int = 30720, temperature: float = 1.0, + top_p: float = 1.0, + stop_sequences: List[str] = None, ) -> Dict[str, Any]: """ Collect a single trajectory using Fleet environment and Tinker inference. @@ -465,11 +468,14 @@ async def collect_fleet_rollout( # Generate with Tinker gen_start = time.time() - sampling_params = types.SamplingParams( - max_tokens=max_generate_length, - temperature=temperature, - top_p=1.0, - ) + sampling_params_kwargs = { + "max_tokens": max_generate_length, + "temperature": temperature, + "top_p": top_p, + } + if stop_sequences: + sampling_params_kwargs["stop"] = stop_sequences + sampling_params = types.SamplingParams(**sampling_params_kwargs) # Use async sampling to avoid blocking the event loop result = await sampling_client.sample_async( @@ -550,6 +556,9 @@ async def collect_batch_rollouts( max_input_length: int = 30720, n_samples_per_prompt: int = 1, max_concurrent: int = 8, + temperature: float = 1.0, + top_p: float = 1.0, + stop_sequences: List[str] = None, ) -> List[Dict[str, Any]]: """Collect rollouts for a batch of tasks with limited concurrency. @@ -573,6 +582,9 @@ async def collect_single_rollout(task_config: Dict[str, Any], index: int) -> tup max_turns=max_turns, max_generate_length=max_generate_length, max_input_length=max_input_length, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, ) return index, rollout except Exception as e: @@ -646,6 +658,10 @@ async def main( seed: int = 42, wandb_project: str = "fleet-tinker-grpo", wandb_name: str = None, + temperature: float = 1.0, + top_p: float = 1.0, + stop_sequences: List[str] = None, + loss_fn: str = "ppo", ): """ Main training loop using Tinker for training/inference and Fleet for environments. @@ -657,6 +673,9 @@ async def main( wandb_name = f"{model_name.split('/')[-1]}_{datetime.now().strftime('%m%d_%H%M')}" # Initialize WandB + if stop_sequences is None: + stop_sequences = [] + wandb.init( project=wandb_project, name=wandb_name, @@ -670,6 +689,10 @@ async def main( "max_input_length": max_input_length, "max_sequence_length": max_sequence_length, "n_samples_per_prompt": n_samples_per_prompt, + "temperature": temperature, + "top_p": top_p, + "stop_sequences": stop_sequences, + "loss_fn": loss_fn, }, ) @@ -744,6 +767,9 @@ def create_dataloader(epoch: int): max_generate_length=max_generate_length, max_input_length=max_input_length, n_samples_per_prompt=n_samples_per_prompt, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, ) metrics["time/rollout"] = time.time() - rollout_start @@ -820,7 +846,7 @@ def create_dataloader(epoch: int): logger.info(f"Step {step}: Training on {len(training_datums)} sequences...") train_start = time.time() - fwd_bwd_future = training_client.forward_backward(training_datums, loss_fn="ppo") + fwd_bwd_future = training_client.forward_backward(training_datums, loss_fn=loss_fn) optim_step_future = training_client.optim_step(adam_params) fwd_bwd_future.result() @@ -855,6 +881,9 @@ def create_dataloader(epoch: int): max_generate_length=max_generate_length, max_input_length=max_input_length, n_samples_per_prompt=1, + temperature=temperature, + top_p=top_p, + stop_sequences=stop_sequences, ) all_eval_rollouts.extend([r for r in eval_rollouts if not r.error]) @@ -912,9 +941,27 @@ def create_dataloader(epoch: int): default=False, help="Track additional gradient metrics (for parity with SkyRL config)", ) + parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature") + parser.add_argument("--top-p", type=float, default=1.0, help="Top-p (nucleus) sampling") + parser.add_argument( + "--stop-sequences", + type=str, + default="[]", + help="JSON list of stop sequences (e.g. '[\"\"]')", + ) + parser.add_argument( + "--loss-fn", + type=str, + default="ppo", + help="Loss function for Tinker forward_backward (e.g. ppo, grpo)", + ) args = parser.parse_args() + import json as _json + + stop_sequences = _json.loads(args.stop_sequences) + asyncio.run( main( model_name=args.model_name, @@ -935,5 +982,9 @@ def create_dataloader(epoch: int): seed=args.seed, wandb_project=args.wandb_project, wandb_name=args.wandb_name, + temperature=args.temperature, + top_p=args.top_p, + stop_sequences=stop_sequences, + loss_fn=args.loss_fn, ) ) diff --git a/scripts/fleet-tinker-tool-use-run.sh b/scripts/fleet-tinker-tool-use-run.sh new file mode 100755 index 0000000000..380a67ff44 --- /dev/null +++ b/scripts/fleet-tinker-tool-use-run.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +# Launch Fleet tool-use training via Tinker hosted service. +# Mirrors fleet-35b-run.sh config but uses the Tinker backend. +# +# Required env vars: TINKER_API_KEY, FLEET_API_KEY, WANDB_API_KEY +# Optional: TINKER_API_URL (SDK uses default if not set) +set -euo pipefail + +export TINKER_API_KEY="${TINKER_API_KEY:?Set TINKER_API_KEY}" +export TINKER_API_URL="${TINKER_API_URL:-}" +export FLEET_API_KEY="${FLEET_API_KEY:?Set FLEET_API_KEY}" +export WANDB_API_KEY="${WANDB_API_KEY:?Set WANDB_API_KEY}" + +cd "$(dirname "$0")/.." # cd to SkyRL-v2 root + +python -m integrations.fleet.entrypoints.main_fleet_tinker \ + --model-name Qwen/Qwen3.5-35B-A3B \ + --tasks-file "${TASKS_FILE:?Set TASKS_FILE}" \ + --dataset-file "${DATASET_FILE:?Set DATASET_FILE}" \ + --batch-size 16 \ + --learning-rate 5.0e-7 \ + --lora-rank 16 \ + --max-steps 200 \ + --max-turns 50 \ + --max-generate-length 4096 \ + --max-input-length 96000 \ + --n-samples-per-prompt 8 \ + --eval-every 20 \ + --temperature 0.9 \ + --top-p 0.95 \ + --stop-sequences '[""]' \ + --loss-fn ppo \ + --wandb-project fleet-tinker-grpo \ + "$@" From 4a1b65d0bbc818889c6d0ac5c111c64e9fa295bb Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 18:03:00 -0700 Subject: [PATCH 049/121] fix: use async env methods to prevent event loop isolation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The generator was calling env.step()/close()/init() via _run_in_executor_if_available() which runs sync methods in a thread pool. The sync methods use asyncio.run() creating NEW event loops, breaking MCP transports bound to the main event loop (→ TCPTransport closed, 100% verifier failure rate). Add _env_init/_env_step/_env_close that prefer async variants when available (init_async/step_async/close_async), keeping everything in the main event loop. Falls back to executor for envs without async methods. Co-Authored-By: Claude Opus 4.6 --- skyrl/train/generators/skyrl_gym_generator.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 8efd53d2bf..56602b5f9a 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -252,6 +252,24 @@ async def _run_in_executor_if_available(self, func, *args, **kwargs): else: return func(*args, **kwargs) + async def _env_init(self, env, *args, **kwargs): + """Call env.init, using async path if available to avoid event loop isolation.""" + if hasattr(env, "init_async"): + return await env.init_async(*args, **kwargs) + return await self._run_in_executor_if_available(env.init, *args, **kwargs) + + async def _env_step(self, env, action): + """Call env.step, using async path if available to avoid event loop isolation.""" + if hasattr(env, "step_async"): + return await env.step_async(action) + return await self._run_in_executor_if_available(env.step, action) + + async def _env_close(self, env): + """Call env.close, using async path if available to avoid event loop isolation.""" + if hasattr(env, "close_async"): + return await env.close_async() + return await self._run_in_executor_if_available(env.close) + def _make_zero_reward_output( self, prompt: ConversationType, @@ -330,7 +348,7 @@ async def agent_loop( # init() returns the first prompt to be given to the model, and optional metadata dict try: - chat_history, _ = await self._run_in_executor_if_available(env.init, chat_history) + chat_history, _ = await self._env_init(env, chat_history) except Exception as e: logger.warning(f"Session {session_id}: env.init failed ({type(e).__name__}: {e}), returning zero-reward trajectory") # Return a minimal failed trajectory so training can continue @@ -468,7 +486,7 @@ async def agent_loop( added_eos = True # 2. Environment step - env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) + env_step_output: BaseTextEnvStepOutput = await self._env_step(env, output) new_obs = env_step_output["observations"] step_reward: float = env_step_output["reward"] agent_loop_state.done = env_step_output["done"] @@ -550,7 +568,7 @@ async def agent_loop( # Get environment-specific metrics after the episode is done env_metrics = env.get_metrics() # Close the environment - await self._run_in_executor_if_available(env.close) + await self._env_close(env) prompt_ids = agent_loop_state.input_ids[:initial_prompt_length] rollout_logprobs = None @@ -921,7 +939,7 @@ async def generate_batched( env_extra["max_turns"] = self.max_turns env_config = getattr(self.skyrl_gym_cfg, env_class, dict()) env = skyrl_gym.make(env_class, env_config=env_config, extras=env_extra) - init_prompt, _ = await self._run_in_executor_if_available(env.init, prompt) + init_prompt, _ = await self._env_init(env, prompt) init_prompts.append(init_prompt) envs.append(env) @@ -949,7 +967,7 @@ async def generate_batched( for i, (output, response, env, env_class) in enumerate(zip(outputs, responses, envs, env_classes)): # step on environment and compute reward - env_step_output: BaseTextEnvStepOutput = await self._run_in_executor_if_available(env.step, output) + env_step_output: BaseTextEnvStepOutput = await self._env_step(env, output) reward = env_step_output["reward"] rewards.append(reward) @@ -968,7 +986,7 @@ async def generate_batched( # Get environment-specific metrics env_metrics.append(env.get_metrics()) # Close the environment - await self._run_in_executor_if_available(env.close) + await self._env_close(env) rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) From 3a69f08da8b8322189eea374e288cb1da66b5417 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 18:24:09 -0700 Subject: [PATCH 050/121] Port chunked lm_head forward + rewrite CHANGELOG as coherent document Code changes (same as fleet/training): - Port loss_chunk_size to HFModelWrapper: identity lm_head trick returns hidden states (B,S,8192) instead of logits (B,S,131072), then computes logits in 4096-token chunks with gradient checkpointing - Pass loss_chunk_size from fsdp_worker.py for policy and ref model - Revert flash_attn back to false (GatedDeltaNet Xid 31 in FSDP2) Docs: - Rewrite CHANGELOG.md as clean coherent document with all 6 fixes - Update CLAUDE.md to match Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 12 +- integrations/fleet/CHANGELOG.md | 70 ++++-- scripts/fleet-35b-run.sh | 2 +- .../skyrl_train/workers/fsdp/fsdp_worker.py | 2 + .../skyrl_train/workers/model_wrapper.py | 213 ++++++++++++++++-- 5 files changed, 247 insertions(+), 52 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 01159d37f6..520b5028e3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -7,17 +7,21 @@ Fork of SkyRL with Fleet-specific optimizations for multi-node FSDP2 training at Fleet-specific changes, fixes, and context are documented in: - **[integrations/fleet/CHANGELOG.md](integrations/fleet/CHANGELOG.md)** — detailed changelog with root causes and fixes -Always consult the changelog before modifying Fleet training paths (`fsdp_worker.py`, `worker.py`, `dispatch.py`, `fleet-*.sh`). +Always consult the changelog before modifying Fleet training paths (`fsdp_worker.py`, `worker.py`, `model_wrapper.py`, `dispatch.py`, `fleet-*.sh`). ## Key Differences from Upstream SkyRL 1. **Multi-node FSDP2 stability**: Synchronous ref model offload/backload with `torch.distributed.barrier()` in `fsdp_worker.py`. Required because cross-node colocated training has no shared CUDA context. -2. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation on large models with tight GPU memory margins. +2. **Chunked lm_head forward**: `model_wrapper.py` has `loss_chunk_size` support ported from the old fork. Avoids materializing full `(B, S, vocab_size)` logits — critical for 35B with 131K vocab at 97K sequence length. Without it, OOM/Xid 31 during training forward. -3. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. +3. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation. Especially important because `expandable_segments` can't be used (see #5). -4. **No `expandable_segments` with vLLM 0.18.0**: `fleet-35b-run.sh` passes `--no-pytorch-alloc-conf` because vLLM 0.18.0's `CuMemAllocator` uses `cuMemCreate`/`cuMemMap` and conflicts with PyTorch's `expandable_segments:True` (also cuMem-based). Anti-fragmentation is handled by `empty_cache()` before backward (fix #2). Old SkyRL (vLLM 0.17.0, `cudaMalloc`) doesn't have this conflict. +4. **`flash_attn=false` for GatedDeltaNet**: `fleet-35b-run.sh` uses SDPA, not flash_attention_2. Qwen3.5-35B's GatedDeltaNet linear attention layers crash with flash_attn in multi-node FSDP2. Memory savings come from chunked lm_head (#2), not flash attention. + +5. **No `expandable_segments` with vLLM 0.18.0**: `fleet-35b-run.sh` passes `--no-pytorch-alloc-conf` because vLLM 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`) conflicts with PyTorch's `expandable_segments:True`. Old SkyRL uses vLLM 0.17.0 (`cudaMalloc`) which has no conflict. + +6. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. ## Training Scripts diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index 16108c20ec..37f1398953 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -1,15 +1,16 @@ # Fleet Integration Changelog -## 2026-03-29: Multi-node FSDP stability + hint batch size fix +## 2026-03-29: Multi-node 35B training parity with old SkyRL fork -Ported from fleet-ai/SkyRL PR #328 and PR #333, plus a new fix for hint augmentation batch sizing. +Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fleet-ai/SkyRL PR #328 and PR #333, plus new fixes for SkyRL-v2-specific issues. -### Problem +### Problems -2-node (16 GPU) Qwen3.5-35B training on GCP H200 crashed with: -1. `cudaErrorIllegalAddress` segfaults during FSDP ref model offload/backload -2. OOM during backward pass from CUDA memory fragmentation -3. `AssertionError: data batch size must be divisible by mini_batch_size, got 160 and 128` when hints are enabled +2-node training crashed with: +1. `cudaErrorIllegalAddress` during FSDP ref model offload/backload (multi-node race) +2. OOM / Xid 31 FAULT_PDE during policy training forward+backward (missing chunked lm_head) +3. `AssertionError: Expandable segments are not compatible with memory pool` at vLLM init (vLLM 0.18.0 vs expandable_segments) +4. `AssertionError: data batch size must be divisible by mini_batch_size, got 160 and 128` (hint augmentation) ### Root causes and fixes @@ -17,41 +18,63 @@ Ported from fleet-ai/SkyRL PR #328 and PR #333, plus a new fix for hint augmenta **Where:** `FSDPRefWorkerBase.offload_to_cpu()` and `backload_to_gpu()` -**Problem:** With colocated models, the trainer cycles: ref on GPU -> ref offload to CPU -> policy on GPU. With `non_blocking=True`, the CPU<-GPU transfer is *queued* but returns immediately. On a single node, CUDA stream ordering serializes this naturally. Across nodes, there's no shared CUDA context -- node 0's policy worker can start touching GPU memory while node 1's ref worker is still mid-transfer. Result: `cudaErrorIllegalAddress`. +**Problem:** With colocated models, the trainer cycles: ref on GPU → ref offload to CPU → policy on GPU. With `non_blocking=True`, the CPU←GPU transfer is *queued* but returns immediately. On a single node, CUDA stream ordering serializes this naturally. Across nodes, there's no shared CUDA context — node 0's policy worker can start touching GPU memory while node 1's ref worker is still mid-transfer. Result: `cudaErrorIllegalAddress`. **Fix:** `non_blocking=False` (wait for transfer) + `torch.distributed.barrier()` (all ranks synchronize). Guarantees every GPU finishes offloading before any policy worker starts backloading. -**Why upstream SkyRL doesn't need this:** Designed for single-node where all workers share the same CUDA context and stream ordering prevents races. +**Why the old fork doesn't need this:** Designed for single-node where all workers share the same CUDA context and stream ordering prevents races. -#### 2. empty_cache before backward (`worker.py`) +#### 2. Port chunked lm_head forward (`model_wrapper.py`, `fsdp_worker.py`) + +**Where:** `HFModelWrapper.forward()` and `HFModelWrapper._chunked_lm_head_forward()` + +**Problem:** SkyRL-v2's `HFModelWrapper` was missing `loss_chunk_size` support entirely — the parameter existed in config but was never passed through `fsdp_worker.py` to the model wrapper. Without it, the model materializes the full `(B, S, 131072)` logits tensor during forward pass (~10 GB for 97K-length sequences on Qwen3.5-35B with vocab_size=131072). This consumed so much GPU memory that the subsequent training forward pass (with gradients enabled) hit OOM or Xid 31 FAULT_PDE when FSDP tried to unshard parameters. + +**Fix:** Ported the chunked lm_head implementation from the old fork: +- Added `loss_chunk_size` parameter to `HFModelWrapper.__init__` +- Pass `loss_chunk_size` from `fsdp_worker.py` for both policy and ref model init +- During forward, replace `lm_head` with an identity module so the model returns hidden states `(B, S, 8192)` instead of logits `(B, S, 131072)` — 16x smaller +- Compute logits in chunks of 4096 tokens with gradient checkpointing, never materializing full logits + +**Why the old fork doesn't have this problem:** It already has `loss_chunk_size` support and passes it correctly. + +#### 3. `empty_cache` before backward (`worker.py`) **Where:** `PolicyWorkerBase._forward_backward_micro()` (both SFT and RL paths) and `CriticWorkerBase._forward_backward_micro()` -**Problem:** After the forward pass, freed intermediate tensors stay in PyTorch's CUDA cache as scattered blocks. The backward pass needs large contiguous allocations for gradients. On the 35B model with tight GPU memory margins, the fragmented cache can't satisfy these allocations -> OOM, even though total free memory is sufficient. +**Problem:** After the forward pass, freed intermediate tensors stay in PyTorch's CUDA cache as scattered blocks. The backward pass needs large contiguous allocations for gradients. On the 35B model with tight GPU memory margins, the fragmented cache can't satisfy these allocations → OOM, even though total free memory is sufficient. -**Fix:** `torch.cuda.empty_cache()` before `strategy.backward()`. Returns cached blocks to CUDA which coalesces them into contiguous allocations. +**Fix:** `torch.cuda.empty_cache()` before `strategy.backward()`. Returns cached blocks to CUDA which coalesces them into contiguous allocations. This is especially important because `expandable_segments:True` cannot be used (see fix #4). -**Why upstream SkyRL doesn't need this:** Targets smaller models (8B) with enough GPU headroom that fragmentation doesn't matter. +**Why the old fork doesn't need this:** Targets smaller models (8B) with enough GPU headroom that fragmentation doesn't matter. -#### 3. Keep `--no-pytorch-alloc-conf` for vLLM 0.18.0 compatibility (`fleet-35b-run.sh`) +#### 4. Keep `--no-pytorch-alloc-conf` for vLLM 0.18.0 (`fleet-35b-run.sh`) **Where:** `fleet-35b-run.sh` retains `--no-pytorch-alloc-conf` flag. -**Problem:** SkyRL-v2 uses vLLM 0.18.0 which introduced `CuMemAllocator` — a custom CUDA memory allocator that uses `cuMemCreate`/`cuMemMap` (virtual memory management APIs) for its memory pool. PyTorch's `expandable_segments:True` (set by `fleet-common-run.sh` when `--no-pytorch-alloc-conf` is absent) also uses `cuMemCreate`/`cuMemMap`. Two independent cuMem-based allocators in the same process maintain conflicting bookkeeping of the virtual address space, causing `AssertionError: Expandable segments are not compatible with memory pool` at vLLM engine init. +**Problem:** SkyRL-v2 uses vLLM 0.18.0 which introduced `CuMemAllocator` — a custom CUDA memory allocator using `cuMemCreate`/`cuMemMap` for its memory pool. PyTorch's `expandable_segments:True` (set by `fleet-common-run.sh` when `--no-pytorch-alloc-conf` is absent) also uses `cuMemCreate`/`cuMemMap`. Two independent cuMem-based allocators in the same process conflict → `AssertionError: Expandable segments are not compatible with memory pool` at vLLM engine init. + +**Fix:** Keep `--no-pytorch-alloc-conf` so `expandable_segments` is never set. Anti-fragmentation is handled by `empty_cache()` (fix #3) and chunked lm_head (fix #2). + +**Why the old fork doesn't have this:** Old SkyRL uses vLLM 0.17.0 (`cudaMalloc`/`cudaFree`, no cuMem APIs, no conflict). + +#### 5. `flash_attn=false` for Qwen3.5-35B GatedDeltaNet (`fleet-35b-run.sh`) + +**Where:** `fleet-35b-run.sh` sets `trainer.flash_attn=false`. -**Why this wasn't an issue in the old SkyRL fork:** Old SkyRL uses vLLM 0.17.0 which uses standard `cudaMalloc`/`cudaFree` — no cuMem APIs, no conflict with `expandable_segments`. +**Problem:** Qwen3.5-35B uses GatedDeltaNet architecture which alternates softmax attention layers with linear attention layers (`torch_chunk_gated_delta_rule`). Setting `flash_attn=true` → `attn_implementation="flash_attention_2"` causes Xid 31 FAULT_PDE during the GDN linear attention layers' tensor allocation in multi-node FSDP2 training. -**Fix:** Keep `--no-pytorch-alloc-conf` so `expandable_segments` is never set. CUDA memory fragmentation (the problem `expandable_segments` would solve) is instead mitigated by the `empty_cache()` calls added in fix #2 above, which defragment the PyTorch allocator cache before each backward pass. +**Fix:** Use `flash_attn=false` (SDPA). The old fork has `flash_attn=true` but the real memory savings there come from chunked lm_head (fix #2), not flash attention itself. With chunked lm_head ported, SDPA provides sufficient memory headroom. -#### 4. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) +#### 6. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) **Where:** `MeshDispatch.stage_chunks()` -**Problem:** `mini_batch_size` is computed as `policy_mini_batch_size * n_samples_per_prompt` (e.g., 16 * 8 = 128). But hint augmentation appends extra samples: 16 prompts * 2 hints = 32 additional, total batch = 160. The `stage_chunks` method asserted `160 % 128 == 0` -> crash. +**Problem:** `mini_batch_size` is computed as `policy_mini_batch_size * n_samples_per_prompt` (e.g., 16 × 8 = 128). But hint augmentation appends extra samples: 16 prompts × 2 hints = 32 additional, total batch = 160. The `stage_chunks` method asserted `160 % 128 == 0` → crash. -The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) silently dropped the 32 hint samples -- no crash, but hint training was wasted. +The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) silently dropped the 32 hint samples — no crash, but hint training was wasted. -**Fix:** When batch size isn't divisible by mini_batch_size, step down mini_batch_size (by `dp_size` increments to stay DP-divisible) until it divides evenly. For 160 samples with dp_size=16: adjusts from 128 -> 80, giving 2 mini-batches of 80. All 160 samples (including hints) are trained on. +**Fix:** When batch size isn't divisible by mini_batch_size, step down mini_batch_size (by `dp_size` increments to stay DP-divisible) until it divides evenly. For 160 samples with dp_size=16: adjusts from 128 → 80, giving 2 mini-batches of 80. All 160 samples (including hints) are trained on. **Why upstream SkyRL doesn't have this:** Upstream uses a simple `for` loop with `//` division (no `stage_chunks` optimization). The `stage_chunks` pre-staging is a SkyRL-v2 optimization that added a strict assert the old code path never had. @@ -59,7 +82,8 @@ The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) s | File | Change | |------|--------| -| `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Synchronous ref offload + barrier | +| `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Port chunked lm_head forward (loss_chunk_size) | +| `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Pass loss_chunk_size to HFModelWrapper; synchronous ref offload + barrier | | `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | -| `scripts/fleet-35b-run.sh` | Keep `--no-pytorch-alloc-conf` (vLLM 0.18.0 CuMemAllocator compat) | +| `scripts/fleet-35b-run.sh` | flash_attn=false, --no-pytorch-alloc-conf, wandb project rename | | `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 180415420a..f11913155a 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -40,7 +40,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.n_hint_samples=2 \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ - trainer.flash_attn=true \ + trainer.flash_attn=false \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ diff --git a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py index 779545d90e..d66ccbc7b7 100644 --- a/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py @@ -188,6 +188,7 @@ def init_model(self, model_path, num_training_steps: int = None): rope_scaling=get_rope_scaling_config(self.cfg), rope_theta=get_rope_theta_config(self.cfg), model_config_kwargs=self.cfg.policy.model_config_kwargs, + loss_chunk_size=self.cfg.loss_chunk_size, ) # in-place patch self._seq_parallel_monkey_patch(model=wrapped_model.model) @@ -433,6 +434,7 @@ def init_model(self, model_path): rope_scaling=get_rope_scaling_config(self.cfg), rope_theta=get_rope_theta_config(self.cfg), model_config_kwargs=self.cfg.ref.model_config_kwargs, + loss_chunk_size=self.cfg.loss_chunk_size, ) self._seq_parallel_monkey_patch(model=wrapped_model.model) diff --git a/skyrl/backends/skyrl_train/workers/model_wrapper.py b/skyrl/backends/skyrl_train/workers/model_wrapper.py index 3eb45f80a7..5bd10112a4 100644 --- a/skyrl/backends/skyrl_train/workers/model_wrapper.py +++ b/skyrl/backends/skyrl_train/workers/model_wrapper.py @@ -8,7 +8,9 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F import transformers +from torch.utils.checkpoint import checkpoint as gradient_checkpoint from flash_attn.bert_padding import pad_input, unpad_input from loguru import logger from packaging.version import Version @@ -33,6 +35,36 @@ ) +def _chunked_logprobs_only(hidden_chunk, labels_chunk, weight, bias, temperature): + """Compute logprobs for one chunk. Used inside gradient_checkpoint. + Uses F.linear with raw weight/bias to avoid DTensor issues with FSDP2. + """ + logits = F.linear(hidden_chunk, weight, bias) + logits = logits / temperature + return logprobs_from_logits(logits, labels_chunk, inplace_backward=False) + + +def _chunked_logprobs_and_entropy(hidden_chunk, labels_chunk, weight, bias, temperature): + """Compute logprobs and entropy for one chunk. Used inside gradient_checkpoint. + Uses F.linear with raw weight/bias to avoid DTensor issues with FSDP2. + """ + logits = F.linear(hidden_chunk, weight, bias) + logits = logits / temperature + lp = logprobs_from_logits(logits, labels_chunk, inplace_backward=False) + log_softmax_vals = F.log_softmax(logits, dim=-1) + entropy = -(log_softmax_vals.exp() * log_softmax_vals).sum(dim=-1) + return lp, entropy + + +class _IdentityLMHead(nn.Module): + """Dummy lm_head that passes hidden states through unchanged. + Used to prevent the HF model from materializing the full (B, S, vocab) logits tensor. + """ + + def forward(self, x): + return x + + class HFModelWrapper(nn.Module): """ Base class for wrapped HF models in reinforcement learning. @@ -74,6 +106,7 @@ def __init__( use_liger_kernel=False, sequence_parallel_size=1, use_sample_packing: bool = False, + loss_chunk_size: int = 0, use_torch_compile: bool = False, rope_scaling: Dict[str, Any] = {}, rope_theta: float | None = None, @@ -85,6 +118,7 @@ def __init__( self.sequence_parallel_size = sequence_parallel_size self.attn_implementation = "flash_attention_2" if use_flash_attention_2 else "sdpa" self.use_sample_packing = use_sample_packing + self.loss_chunk_size = loss_chunk_size self.is_vlm = False # packing samples using Flash Attention 2 if use_sample_packing: @@ -351,31 +385,62 @@ def forward( sequences_rolled, None, None, self.sequence_parallel_size ) - if self.is_vlm: - output = self.model( - sequences_fwd, - attention_mask=attention_mask_fwd, - position_ids=None, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - ) - # NOTE (sumanthrh): Once we have position_ids, we don't need attention mask with flash attention. - elif self.use_sample_packing and self.attn_implementation == "flash_attention_2": - # NOTE (sumanthrh): Don't use attention mask. position_ids is enough. - # Not using attention mask leads to higher perf since flash attention varlen func is enabled - output = self.model(sequences_fwd, attention_mask=None, position_ids=position_ids_fwd) - else: - output = self.model(sequences_fwd, attention_mask=attention_mask_fwd, position_ids=position_ids_fwd) + use_chunked = self.loss_chunk_size > 0 - logits_BSV = output["logits"] - logits_BSV.div_(temperature) + if use_chunked: + # Chunked lm_head: avoid materializing full (B, S, vocab_size) logits tensor. + # Replace lm_head with identity so the model returns hidden states instead. + lm_head = self.model.lm_head + self.model.lm_head = _IdentityLMHead() - # NOTE: this is slightly inaccurate with sample packing because last token from nth seq -> first token of n+1th seq loss is added. - log_probs = logprobs_from_logits( - logits_BSV, - sequences_rolled, - inplace_backward=True, - ) + try: + if self.is_vlm: + output = self.model( + sequences_fwd, + attention_mask=attention_mask_fwd, + position_ids=None, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + # NOTE (sumanthrh): Once we have position_ids, we don't need attention mask with flash attention. + elif self.use_sample_packing and self.attn_implementation == "flash_attention_2": + # NOTE (sumanthrh): Don't use attention mask. position_ids is enough. + # Not using attention mask leads to higher perf since flash attention varlen func is enabled + output = self.model(sequences_fwd, attention_mask=None, position_ids=position_ids_fwd) + else: + output = self.model(sequences_fwd, attention_mask=attention_mask_fwd, position_ids=position_ids_fwd) + finally: + if use_chunked: + self.model.lm_head = lm_head + + if use_chunked: + # output["logits"] is actually hidden_states (B, S, hidden_dim) since lm_head was identity + hidden_states = output["logits"] + entropy_mask = None + if compute_entropy and not self.use_sample_packing: + entropy_mask = attention_mask_fwd + log_probs, entropy_BS = self._chunked_lm_head_forward( + hidden_states, + lm_head, + sequences_rolled, + temperature, + self.loss_chunk_size, + compute_entropy=compute_entropy, + entropy_requires_grad=entropy_requires_grad, + attention_mask=entropy_mask, + ) + # Replace hidden_states in output with None to free memory + output["logits"] = None + else: + logits_BSV = output["logits"] + logits_BSV.div_(temperature) + + # NOTE: this is slightly inaccurate with sample packing because last token from nth seq -> first token of n+1th seq loss is added. + log_probs = logprobs_from_logits( + logits_BSV, + sequences_rolled, + inplace_backward=True, + ) # gather output if sp > 1 if self.sequence_parallel_size > 1: @@ -392,7 +457,20 @@ def forward( log_probs.transpose(0, 1), indices=nnz_indices, batch=batch_size, seqlen=seqlen ).squeeze(-1) - if compute_entropy: + if use_chunked: + # Entropy already computed in _chunked_lm_head_forward + if compute_entropy: + if self.sequence_parallel_size > 1: + dim = entropy_BS.ndim - 1 + entropy_BS = gather_outputs_and_unpad( + entropy_BS, gather_dim=dim, unpad_dim=dim, padding_size=pad_size + ) + if self.use_sample_packing: + entropy_BS = pad_input( + entropy_BS.transpose(0, 1), indices=nnz_indices, batch=batch_size, seqlen=seqlen + ).squeeze(-1) + output["entropy"] = entropy_BS + elif compute_entropy: # For sample packing: entropy is calculated on unpacked data, so no attention mask needed # For non-sample packing: pass the attention mask to exclude padding tokens entropy_mask = None @@ -431,6 +509,93 @@ def forward( else: return action_log_probs + def _chunked_lm_head_forward( + self, + hidden_states, + lm_head, + labels, + temperature, + chunk_size, + compute_entropy=False, + entropy_requires_grad=True, + attention_mask=None, + ): + """Compute log_probs (and optionally entropy) via chunked lm_head projection. + + Instead of materializing the full (B, S, vocab_size) logits tensor, this + computes lm_head in chunks of `chunk_size` tokens along the sequence dimension. + Each chunk uses gradient checkpointing so logits are recomputed during backward + rather than stored, keeping peak memory at (B, chunk_size, vocab_size). + """ + B, S, H = hidden_states.shape + all_log_probs = [] + all_entropy = [] if compute_entropy else None + + # Extract weight/bias from lm_head module. With FSDP2, parameters are DTensors; + # calling the module inside gradient_checkpoint causes DTensor/Tensor mismatch. + # We all-gather DTensors to regular tensors via full_tensor() which is differentiable. + weight = lm_head.weight + bias = lm_head.bias + try: + from torch.distributed.tensor import DTensor + + if isinstance(weight, DTensor): + weight = weight.full_tensor() + if bias is not None and isinstance(bias, DTensor): + bias = bias.full_tensor() + except ImportError: + pass + + # When not computing gradients (ref model), skip gradient_checkpoint entirely — + # just compute each chunk directly with no_grad already active from caller. + use_checkpointing = torch.is_grad_enabled() + + for start in range(0, S, chunk_size): + end = min(start + chunk_size, S) + chunk_hidden = hidden_states[:, start:end] + chunk_labels = labels[:, start:end] + + if compute_entropy: + if use_checkpointing: + chunk_lp, chunk_ent = gradient_checkpoint( + _chunked_logprobs_and_entropy, + chunk_hidden, + chunk_labels, + weight, + bias, + temperature, + use_reentrant=False, + ) + else: + chunk_lp, chunk_ent = _chunked_logprobs_and_entropy( + chunk_hidden, chunk_labels, weight, bias, temperature + ) + if not entropy_requires_grad: + chunk_ent = chunk_ent.detach() + if attention_mask is not None: + chunk_mask = attention_mask[:, start:end] + chunk_ent = chunk_ent * chunk_mask + all_entropy.append(chunk_ent) + else: + if use_checkpointing: + chunk_lp = gradient_checkpoint( + _chunked_logprobs_only, + chunk_hidden, + chunk_labels, + weight, + bias, + temperature, + use_reentrant=False, + ) + else: + chunk_lp = _chunked_logprobs_only(chunk_hidden, chunk_labels, weight, bias, temperature) + + all_log_probs.append(chunk_lp) + + log_probs = torch.cat(all_log_probs, dim=1) + entropy = torch.cat(all_entropy, dim=1) if compute_entropy else None + return log_probs, entropy + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs={"use_reentrant": False}): self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) From 3c7d2e1b32d4db539492e6e371dea30db37c573c Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 18:24:27 -0700 Subject: [PATCH 051/121] Fix apply_overlong_filtering call signature SkyRL-v2's apply_overlong_filtering takes (loss_masks, stop_reasons) not (loss_masks, response_ids, eos_token_id). Updated the call in prepare_training_data() to match. Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/entrypoints/main_fleet_tinker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/fleet/entrypoints/main_fleet_tinker.py b/integrations/fleet/entrypoints/main_fleet_tinker.py index d7f0464315..83ab6d2af5 100644 --- a/integrations/fleet/entrypoints/main_fleet_tinker.py +++ b/integrations/fleet/entrypoints/main_fleet_tinker.py @@ -323,10 +323,10 @@ def prepare_training_data( Returns: Tuple of (training_datums, truncated_count) """ - # Apply DAPO overlong filtering (zero out loss mask if response doesn't end with EOS) - all_response_ids = [r.response_ids for r in rollouts] + # Apply DAPO overlong filtering (zero out loss mask for truncated responses) all_loss_masks = [r.loss_mask for r in rollouts] - filtered_loss_masks = apply_overlong_filtering(all_loss_masks, all_response_ids, tokenizer.eos_token_id) + stop_reasons = [r.stop_reason for r in rollouts] + filtered_loss_masks = apply_overlong_filtering(all_loss_masks, stop_reasons) training_datums = [] truncated_count = 0 From 833a7ae40bd2088b9ad0928783925aac000cba27 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 18:45:52 -0700 Subject: [PATCH 052/121] chore: reduce VL eval samples to 1 for faster iteration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit eval_n_samples_per_prompt 3→1: 57 trajectories instead of 171, ~3x faster eval. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index 54da097cd2..aa934875f0 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -80,7 +80,7 @@ bash scripts/fleet-common-run.sh \ generator.batched=false \ generator.use_conversation_multi_turn=true \ generator.n_samples_per_prompt=4 \ - generator.eval_n_samples_per_prompt=3 \ + generator.eval_n_samples_per_prompt=1 \ generator.gpu_memory_utilization=0.80 \ trainer.logger="$LOGGER" \ trainer.project_name="fleet-browser-use-grpo" \ From 3dca64831af51432d0c5c49645ad20689e41b1cf Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 19:03:00 -0700 Subject: [PATCH 053/121] revert: restore eval_n_samples_per_prompt=3 for pass@3 k=3 sampling needed for pass@3 metric. Use MAX_TASKS env var to reduce eval set size instead. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index aa934875f0..54da097cd2 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -80,7 +80,7 @@ bash scripts/fleet-common-run.sh \ generator.batched=false \ generator.use_conversation_multi_turn=true \ generator.n_samples_per_prompt=4 \ - generator.eval_n_samples_per_prompt=1 \ + generator.eval_n_samples_per_prompt=3 \ generator.gpu_memory_utilization=0.80 \ trainer.logger="$LOGGER" \ trainer.project_name="fleet-browser-use-grpo" \ From 786b2af1e561a4d62ffda8d4271bb3ac4646bf19 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 19:46:53 -0700 Subject: [PATCH 054/121] Switch to flash_attn=true + update docs with corrected diagnosis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit flash_attn=false OOMs at 97K seq len even with chunked lm_head. The old fork uses flash_attn=true + chunked lm_head — matching that config. Updated CHANGELOG fix #5 and CLAUDE.md: earlier Xid 31 was from memory exhaustion (missing chunked lm_head), not GatedDeltaNet incompatibility. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 2 +- integrations/fleet/CHANGELOG.md | 12 +++++++----- scripts/fleet-35b-run.sh | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 520b5028e3..9b499ca65e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,7 +17,7 @@ Always consult the changelog before modifying Fleet training paths (`fsdp_worker 3. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation. Especially important because `expandable_segments` can't be used (see #5). -4. **`flash_attn=false` for GatedDeltaNet**: `fleet-35b-run.sh` uses SDPA, not flash_attention_2. Qwen3.5-35B's GatedDeltaNet linear attention layers crash with flash_attn in multi-node FSDP2. Memory savings come from chunked lm_head (#2), not flash attention. +4. **`flash_attn=true` required for memory headroom**: `fleet-35b-run.sh` uses flash_attention_2. SDPA uses too much memory at 97K sequence lengths — even with chunked lm_head, backward pass OOMs. Flash attention + chunked lm_head together provide sufficient headroom. (Earlier Xid 31 crashes were misattributed to GatedDeltaNet; actually caused by missing chunked lm_head.) 5. **No `expandable_segments` with vLLM 0.18.0**: `fleet-35b-run.sh` passes `--no-pytorch-alloc-conf` because vLLM 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`) conflicts with PyTorch's `expandable_segments:True`. Old SkyRL uses vLLM 0.17.0 (`cudaMalloc`) which has no conflict. diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index 37f1398953..9f36bbaceb 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -58,13 +58,15 @@ Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fle **Why the old fork doesn't have this:** Old SkyRL uses vLLM 0.17.0 (`cudaMalloc`/`cudaFree`, no cuMem APIs, no conflict). -#### 5. `flash_attn=false` for Qwen3.5-35B GatedDeltaNet (`fleet-35b-run.sh`) +#### 5. `flash_attn=true` required for memory headroom (`fleet-35b-run.sh`) -**Where:** `fleet-35b-run.sh` sets `trainer.flash_attn=false`. +**Where:** `fleet-35b-run.sh` sets `trainer.flash_attn=true`. -**Problem:** Qwen3.5-35B uses GatedDeltaNet architecture which alternates softmax attention layers with linear attention layers (`torch_chunk_gated_delta_rule`). Setting `flash_attn=true` → `attn_implementation="flash_attention_2"` causes Xid 31 FAULT_PDE during the GDN linear attention layers' tensor allocation in multi-node FSDP2 training. +**Problem:** SDPA (`flash_attn=false`) uses significantly more memory for attention activations than flash_attention_2, especially at 97K sequence lengths. Even with chunked lm_head (fix #2), SDPA doesn't leave enough headroom for backward pass gradients → OOM requesting 5.95 GiB during `strategy.backward()`. -**Fix:** Use `flash_attn=false` (SDPA). The old fork has `flash_attn=true` but the real memory savings there come from chunked lm_head (fix #2), not flash attention itself. With chunked lm_head ported, SDPA provides sufficient memory headroom. +Earlier Xid 31 crashes with `flash_attn=true` were misattributed to GatedDeltaNet incompatibility. The Xid 31 occurred in `torch_chunk_gated_delta_rule` (a linear attention kernel that doesn't use flash attention at all) — it was caused by memory exhaustion from the missing chunked lm_head (fix #2), not by flash_attn itself. + +**Fix:** Use `flash_attn=true` (matches old fork). Flash attention + chunked lm_head together provide sufficient memory headroom for training forward+backward on 97K sequences. #### 6. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) @@ -85,5 +87,5 @@ The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) s | `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Port chunked lm_head forward (loss_chunk_size) | | `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Pass loss_chunk_size to HFModelWrapper; synchronous ref offload + barrier | | `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | -| `scripts/fleet-35b-run.sh` | flash_attn=false, --no-pytorch-alloc-conf, wandb project rename | +| `scripts/fleet-35b-run.sh` | flash_attn=true, --no-pytorch-alloc-conf, wandb project rename | | `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index f11913155a..180415420a 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -40,7 +40,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.n_hint_samples=2 \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ - trainer.flash_attn=false \ + trainer.flash_attn=true \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ From a31d087d9eeb0ddba68e4e08c354f554594ca8ac Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 20:45:03 -0700 Subject: [PATCH 055/121] Fix logprobs/tokens shape mismatch and cap max_input_length 1. Guard against Tinker returning logprobs with different length than tokens in rollout collection (truncate/pad to match). 2. Safety check in prepare_training_data ensuring all arrays match target_tokens length before creating Datum. 3. These prevent the "target_tokens and logprobs must have the same shape" error from Tinker's forward_backward API. Co-Authored-By: Claude Opus 4.6 --- .../fleet/entrypoints/main_fleet_tinker.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/integrations/fleet/entrypoints/main_fleet_tinker.py b/integrations/fleet/entrypoints/main_fleet_tinker.py index 83ab6d2af5..440bd620be 100644 --- a/integrations/fleet/entrypoints/main_fleet_tinker.py +++ b/integrations/fleet/entrypoints/main_fleet_tinker.py @@ -349,8 +349,19 @@ def prepare_training_data( logprobs = logprobs[:response_len] if logprobs else [] loss_mask_data = loss_mask_data[:response_len] + # Ensure logprobs and response_ids are in sync before building training data + if len(logprobs) != len(response_ids): + logger.warning( + f"Datum {idx}: logprobs ({len(logprobs)}) != response_ids ({len(response_ids)}), fixing" + ) + if len(logprobs) > len(response_ids): + logprobs = logprobs[: len(response_ids)] + else: + logprobs = logprobs + [0.0] * (len(response_ids) - len(logprobs)) + # Target tokens (shifted by 1) target_tokens = full_sequence[1:] + seq_len = len(target_tokens) # Logprobs (0 for prompt, actual for response) full_logprobs = [0.0] * prompt_len + logprobs @@ -360,6 +371,10 @@ def prepare_training_data( full_mask = [0] * prompt_len + loss_mask_data full_mask = full_mask[1:] + # Safety: ensure all arrays match target_tokens length + full_logprobs = full_logprobs[:seq_len] + [0.0] * max(0, seq_len - len(full_logprobs)) + full_mask = full_mask[:seq_len] + [0] * max(0, seq_len - len(full_mask)) + # Advantages (apply only where loss mask is 1) advantage_value = advantages[idx] full_advantages = torch.zeros(len(full_sequence)) @@ -494,6 +509,16 @@ async def collect_fleet_rollout( output_ids = sequence.tokens output_logprobs = sequence.logprobs if sequence.logprobs else [] + # Guard: logprobs must match token count (Tinker may return different lengths) + if output_logprobs and len(output_logprobs) != len(output_ids): + logger.warning( + f"[{task_key}] Turn {turn_num}: logprobs length ({len(output_logprobs)}) != tokens length ({len(output_ids)}), truncating/padding" + ) + if len(output_logprobs) > len(output_ids): + output_logprobs = output_logprobs[: len(output_ids)] + else: + output_logprobs = output_logprobs + [0.0] * (len(output_ids) - len(output_logprobs)) + # Decode output output_text = tokenizer.decode(output_ids, skip_special_tokens=True) From 5e8ac67c9041cd53c390b53229a7ef22faf43fba Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 21:13:46 -0700 Subject: [PATCH 056/121] fix: add retry logic to _execute_meta_tool for transient connection errors Retries describe_db/query_db up to 3 times on connection-related errors (TCPTransport closed, connection reset). Defense-in-depth alongside the Fleet SDK keepalive_expiry fix (fleet-sdk PR #85). Co-Authored-By: Claude Opus 4.6 --- .../skyrl_gym/envs/task_gen/task_gen_env.py | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 211093316b..ec08b4737c 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -1186,22 +1186,27 @@ async def _execute_meta_tool(self, tool_call: Dict[str, Any]) -> str: if self.orch is None: return "Error: Fleet environment not provisioned. Generate a directly." - try: - if name == "describe_db": - result = await self.orch.describe_db_async(db_name=args.get("db_name", "seed")) - elif name == "query_db": - sql = args.get("sql", "") - if not sql: - return "Error: query_db requires a 'sql' argument." - result = await self.orch.query_db_async(sql=sql, db_name=args.get("db_name", "seed")) - else: - return f"Error: Unknown meta-tool '{name}'." - - if isinstance(result, dict): - return f"Tool result:\n{json.dumps(result, indent=2, default=str)}" - return f"Tool result:\n{result}" - except Exception as e: - return f"Error: {e}" + max_retries = 3 + for attempt in range(max_retries): + try: + if name == "describe_db": + result = await self.orch.describe_db_async(db_name=args.get("db_name", "seed")) + elif name == "query_db": + sql = args.get("sql", "") + if not sql: + return "Error: query_db requires a 'sql' argument." + result = await self.orch.query_db_async(sql=sql, db_name=args.get("db_name", "seed")) + else: + return f"Error: Unknown meta-tool '{name}'." + + if isinstance(result, dict): + return f"Tool result:\n{json.dumps(result, indent=2, default=str)}" + return f"Tool result:\n{result}" + except Exception as e: + if attempt < max_retries - 1 and ("closed" in str(e).lower() or "transport" in str(e).lower() or "connection" in str(e).lower()): + await asyncio.sleep(1) + continue + return f"Error: {e}" async def _execute_mcp_tool(self, tool_call: Dict[str, Any]) -> str: """Execute an MCP tool call via FleetMCPTools.""" From 3b2dc025bb61ec821f57cd8d896e575693d7c05b Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 21:13:48 -0700 Subject: [PATCH 057/121] Pin vLLM 0.17.0 + re-enable expandable_segments + update docs vLLM 0.18.0 CuMemAllocator conflicts with expandable_segments. Without it, memory fragmentation causes OOM on 35B. Pin 0.17.0 (cudaMalloc, no conflict). Consolidated CHANGELOG: 5 fixes (merged old #4/#5 into single vLLM pin fix). Updated CLAUDE.md to match. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 8 +++----- integrations/fleet/CHANGELOG.md | 24 ++++++++---------------- scripts/fleet-35b-run.sh | 9 ++++++++- 3 files changed, 19 insertions(+), 22 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 9b499ca65e..179d6ada89 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -15,13 +15,11 @@ Always consult the changelog before modifying Fleet training paths (`fsdp_worker 2. **Chunked lm_head forward**: `model_wrapper.py` has `loss_chunk_size` support ported from the old fork. Avoids materializing full `(B, S, vocab_size)` logits — critical for 35B with 131K vocab at 97K sequence length. Without it, OOM/Xid 31 during training forward. -3. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation. Especially important because `expandable_segments` can't be used (see #5). +3. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation. -4. **`flash_attn=true` required for memory headroom**: `fleet-35b-run.sh` uses flash_attention_2. SDPA uses too much memory at 97K sequence lengths — even with chunked lm_head, backward pass OOMs. Flash attention + chunked lm_head together provide sufficient headroom. (Earlier Xid 31 crashes were misattributed to GatedDeltaNet; actually caused by missing chunked lm_head.) +4. **Pin vLLM 0.17.0 for expandable_segments**: `fleet-35b-run.sh` pins `vllm==0.17.0` because 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`) conflicts with PyTorch's `expandable_segments:True`. Without expandable_segments, memory fragmentation causes OOM during backward. vLLM 0.17.0 uses `cudaMalloc` (no conflict), enabling the full proven config: expandable_segments + flash_attn=true + chunked lm_head. -5. **No `expandable_segments` with vLLM 0.18.0**: `fleet-35b-run.sh` passes `--no-pytorch-alloc-conf` because vLLM 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`) conflicts with PyTorch's `expandable_segments:True`. Old SkyRL uses vLLM 0.17.0 (`cudaMalloc`) which has no conflict. - -6. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. +5. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. ## Training Scripts diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index 9f36bbaceb..abb000c34e 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -48,27 +48,19 @@ Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fle **Why the old fork doesn't need this:** Targets smaller models (8B) with enough GPU headroom that fragmentation doesn't matter. -#### 4. Keep `--no-pytorch-alloc-conf` for vLLM 0.18.0 (`fleet-35b-run.sh`) +#### 4. Pin vLLM 0.17.0 to enable `expandable_segments` (`fleet-35b-run.sh`) -**Where:** `fleet-35b-run.sh` retains `--no-pytorch-alloc-conf` flag. +**Where:** `fleet-35b-run.sh` pins `vllm==0.17.0` before training starts. -**Problem:** SkyRL-v2 uses vLLM 0.18.0 which introduced `CuMemAllocator` — a custom CUDA memory allocator using `cuMemCreate`/`cuMemMap` for its memory pool. PyTorch's `expandable_segments:True` (set by `fleet-common-run.sh` when `--no-pytorch-alloc-conf` is absent) also uses `cuMemCreate`/`cuMemMap`. Two independent cuMem-based allocators in the same process conflict → `AssertionError: Expandable segments are not compatible with memory pool` at vLLM engine init. +**Problem:** SkyRL-v2's `pyproject.toml` pins `vllm==0.18.0`, which introduced `CuMemAllocator` using `cuMemCreate`/`cuMemMap`. This conflicts with PyTorch's `expandable_segments:True` (also uses `cuMemCreate`/`cuMemMap`). Without `expandable_segments`, memory fragmentation causes OOM during backward on the 35B model — PyTorch can't satisfy large contiguous allocations from scattered free blocks. -**Fix:** Keep `--no-pytorch-alloc-conf` so `expandable_segments` is never set. Anti-fragmentation is handled by `empty_cache()` (fix #3) and chunked lm_head (fix #2). +Additionally, `flash_attn=false` (SDPA) uses too much attention memory at 97K sequences (OOM requesting 5.95 GiB during backward). `flash_attn=true` with vLLM 0.18.0 triggers Xid 31 FAULT_PDE in GatedDeltaNet layers during ref model forward — not a memory issue but an incompatibility between vLLM 0.18.0's CUDA allocator and FSDP2 DTensor operations. -**Why the old fork doesn't have this:** Old SkyRL uses vLLM 0.17.0 (`cudaMalloc`/`cudaFree`, no cuMem APIs, no conflict). +**Fix:** Pin `vllm==0.17.0` at the start of `fleet-35b-run.sh` (`pip install --force-reinstall --no-deps`). vLLM 0.17.0 uses `cudaMalloc` (no `CuMemAllocator`), so `expandable_segments:True` works. Removed `--no-pytorch-alloc-conf` flag so `fleet-common-run.sh` enables `expandable_segments`. This restores the proven config: vLLM 0.17.0 + `expandable_segments` + `flash_attn=true` + chunked lm_head. -#### 5. `flash_attn=true` required for memory headroom (`fleet-35b-run.sh`) +**Why the rest of SkyRL-v2 keeps 0.18.0:** The pin is only in `fleet-35b-run.sh`. Other models (9B) may work fine with 0.18.0. The 35B model's extreme memory requirements make expandable_segments essential. -**Where:** `fleet-35b-run.sh` sets `trainer.flash_attn=true`. - -**Problem:** SDPA (`flash_attn=false`) uses significantly more memory for attention activations than flash_attention_2, especially at 97K sequence lengths. Even with chunked lm_head (fix #2), SDPA doesn't leave enough headroom for backward pass gradients → OOM requesting 5.95 GiB during `strategy.backward()`. - -Earlier Xid 31 crashes with `flash_attn=true` were misattributed to GatedDeltaNet incompatibility. The Xid 31 occurred in `torch_chunk_gated_delta_rule` (a linear attention kernel that doesn't use flash attention at all) — it was caused by memory exhaustion from the missing chunked lm_head (fix #2), not by flash_attn itself. - -**Fix:** Use `flash_attn=true` (matches old fork). Flash attention + chunked lm_head together provide sufficient memory headroom for training forward+backward on 97K sequences. - -#### 6. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) +#### 5. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) **Where:** `MeshDispatch.stage_chunks()` @@ -87,5 +79,5 @@ The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) s | `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Port chunked lm_head forward (loss_chunk_size) | | `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Pass loss_chunk_size to HFModelWrapper; synchronous ref offload + barrier | | `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | -| `scripts/fleet-35b-run.sh` | flash_attn=true, --no-pytorch-alloc-conf, wandb project rename | +| `scripts/fleet-35b-run.sh` | Pin vLLM 0.17.0, flash_attn=true, expandable_segments, wandb project rename | | `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 180415420a..8f164bc171 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -30,9 +30,16 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" : "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" : "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" +# Pin vLLM to 0.17.0 for 35B training. vLLM 0.18.0's CuMemAllocator +# (cuMemCreate/cuMemMap) conflicts with PyTorch's expandable_segments:True. +# Without expandable_segments, memory fragmentation causes OOM during backward. +# vLLM 0.17.0 uses cudaMalloc (no conflict), enabling expandable_segments. +source .venv/bin/activate +pip install --force-reinstall --no-deps "vllm==0.17.0" + bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ - --set-ulimit --no-pytorch-alloc-conf \ + --set-ulimit \ --nccl-heartbeat 1800 -- \ environment.skyrl_gym.fleet_task.ttl_seconds=900 \ environment.skyrl_gym.fleet_task.partial_reward=true \ From 0f34391fe0227ed4a5d0e1585b8adcfa20992c4d Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 21:25:48 -0700 Subject: [PATCH 058/121] chore: temporarily disable eval_before_train for training verification Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index 54da097cd2..d1912e312d 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -53,7 +53,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ - trainer.eval_before_train=true \ + trainer.eval_before_train=false \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=16 \ From 9c92108eecdb74f6fd1d089dac3e06dd18808db3 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 21:40:31 -0700 Subject: [PATCH 059/121] Revert vllm_engine.py to pre-0.18 for vLLM 0.17.0 compatibility The upstream vLLM 0.18 bump (d00b17e7) removed backward-compat shims and added 0.18-only APIs (OpenAIModelRegistry, OpenAIServingRender). Since fleet-35b-run.sh pins vLLM 0.17.0 (for expandable_segments compatibility), revert vllm_engine.py to the version with 0.17.0 support. Co-Authored-By: Claude Opus 4.6 --- .../inference_engines/vllm/vllm_engine.py | 130 +++++++----------- 1 file changed, 52 insertions(+), 78 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index fd28948fee..cfef56ace2 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -15,6 +15,7 @@ import ray import vllm from loguru import logger +from packaging import version from vllm import SamplingParams from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, @@ -27,12 +28,7 @@ ) from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse -from vllm.entrypoints.openai.models.serving import ( - BaseModelPath, - OpenAIModelRegistry, - OpenAIServingModels, -) -from vllm.entrypoints.serve.render.serving import OpenAIServingRender +from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels from vllm.inputs import TokensPrompt from vllm.lora.request import LoRARequest @@ -102,7 +98,8 @@ def __init__(self, *args, bundle_indices: list = None, **kwargs): setup_envvars_for_vllm(kwargs, bundle_indices) vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False) - if vllm_v1_disable_multiproc: + if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2": + # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11 os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" # Store common attributes @@ -135,7 +132,6 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): prompts = input_batch.get("prompts") prompt_token_ids = input_batch.get("prompt_token_ids") request_sampling_params = input_batch.get("sampling_params") - multi_modal_data = input_batch.get("multi_modal_data") assert ( prompts is None and prompt_token_ids is not None @@ -145,7 +141,7 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): SamplingParams(**request_sampling_params) if request_sampling_params is not None else SamplingParams() ) - return prompt_token_ids, sampling_params, multi_modal_data + return prompt_token_ids, sampling_params def _postprocess_outputs(self, outputs): """Common output processing logic.""" @@ -248,7 +244,7 @@ def _create_engine(self, *args, **kwargs): return vllm.LLM(*args, **kwargs) async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: - prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) # Check if LoRA is enabled and create LoRA requests lora_requests = None @@ -262,18 +258,9 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path") ] * batch_size - # Build prompts with multimodal data for VL models - prompts = [] - for i, token_ids in enumerate(prompt_token_ids): - mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None - if mm_data: - prompts.append({"prompt_token_ids": token_ids, "multi_modal_data": mm_data}) - else: - prompts.append(TokensPrompt(prompt_token_ids=token_ids)) - outputs = await asyncio.to_thread( self.llm.generate, - prompts=prompts, + prompts=[TokensPrompt(prompt_token_ids=r) for r in prompt_token_ids], sampling_params=sampling_params, lora_request=lora_requests, ) @@ -364,7 +351,10 @@ def _create_engine(self, *args, **kwargs): enable_log_requests = kwargs.pop("enable_log_requests", False) max_log_len = kwargs.pop("max_log_len", None) - engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) + if version.parse(vllm.__version__) >= version.parse("0.10.0"): + engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) + else: + engine_args = vllm.AsyncEngineArgs(disable_log_requests=not enable_log_requests, **kwargs) # Setup stat loggers for vLLM v1 if Ray Prometheus stats are enabled stat_loggers = None @@ -373,6 +363,8 @@ def _create_engine(self, *args, **kwargs): engine = vllm.AsyncLLMEngine.from_engine_args(engine_args, stat_loggers=stat_loggers) + # Adapted from https://github.com/volcengine/verl/blob/e90f18c40aa639cd25092b78a5ff7e2d2508c088/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L327 + model_config = engine.model_config model_path = kwargs.get("model") # Use served_model_name if provided (from generator.served_model_name config), # otherwise fall back to model_path. This allows using a different model name @@ -382,7 +374,15 @@ def _create_engine(self, *args, **kwargs): model_name = served_model_name if served_model_name is not None else model_path base_model_paths = [BaseModelPath(name=model_name, model_path=model_path)] - models = OpenAIServingModels(engine, base_model_paths) + + # vllm >= 0.11.2 removed model_config from OpenAI serving APIs + is_new_api = version.parse(vllm.__version__) >= version.parse("0.11.2") + legacy_kwargs = {} + if is_new_api: + models = OpenAIServingModels(engine, base_model_paths) + else: + models = OpenAIServingModels(engine, model_config, base_model_paths) + legacy_kwargs["model_config"] = model_config # Build request logger for debugging (off by default). # Enable via: generator.engine_init_kwargs.enable_log_requests=true @@ -393,39 +393,14 @@ def _create_engine(self, *args, **kwargs): request_logger = RequestLogger(max_log_len=max_log_len) - chat_template = openai_kwargs.pop("chat_template", None) - - from vllm.plugins.io_processors import get_io_processor - from vllm.renderers import renderer_from_config - - model_registry = OpenAIModelRegistry( - model_config=engine.model_config, - base_model_paths=base_model_paths, - ) - renderer = renderer_from_config(engine.vllm_config) - io_processor = get_io_processor( - engine.vllm_config, - renderer, - engine.model_config.io_processor_plugin, - ) - openai_serving_render = OpenAIServingRender( - model_config=engine.model_config, - renderer=renderer, - io_processor=io_processor, - model_registry=model_registry, - request_logger=request_logger, - chat_template=chat_template, - chat_template_content_format="auto", - ) - self.openai_serving_chat = OpenAIServingChat( engine_client=engine, models=models, response_role="assistant", - openai_serving_render=openai_serving_render, request_logger=request_logger, - chat_template=chat_template, + chat_template=openai_kwargs.pop("chat_template", None), # used to template /chat/completions requests chat_template_content_format="auto", + **legacy_kwargs, **openai_kwargs, ) @@ -434,8 +409,8 @@ def _create_engine(self, *args, **kwargs): self.openai_serving_completion = OpenAIServingCompletion( engine_client=engine, models=models, - openai_serving_render=openai_serving_render, request_logger=request_logger, + **legacy_kwargs, ) return engine @@ -470,13 +445,7 @@ async def _load_lora_from_disk(self, lora_path: str): result = await self.llm.add_lora(lora_request) return result - async def _collect_outputs( - self, - prompt_token_ids, - request_id: str, - sampling_params: SamplingParams, - multi_modal_data: Optional[Dict[str, Any]] = None, - ): + async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_params: SamplingParams): """Collect outputs for a single prompt.""" # Check if LoRA is enabled and create LoRA request final_output = None @@ -491,16 +460,8 @@ async def _collect_outputs( lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path" ) - # Build prompt with multimodal data for VL models - if multi_modal_data: - num_images = len(multi_modal_data.get("image", [])) - logger.info(f"VL generate: {num_images} images, {len(prompt_token_ids)} input tokens") - prompt = {"prompt_token_ids": prompt_token_ids, "multi_modal_data": multi_modal_data} - else: - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) - async for request_output in self.llm.generate( - prompt=prompt, + prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, @@ -511,15 +472,14 @@ async def _collect_outputs( async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: """Generate responses using vLLM's async engine.""" - prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) tasks = [] - for i, prompt in enumerate(prompt_token_ids): + for prompt in prompt_token_ids: # Schedule the collection of outputs for each prompt. # Avoid duplicate request_ids request_id = str(uuid4().hex) - mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None - task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params, mm_data)) + task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params)) tasks.append(task) outputs = await asyncio.gather(*tasks) @@ -602,13 +562,20 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint request = CompletionRequest(**body) assert request.stream is False, "Streaming is not supported in SkyRL yet, please set stream to False." except Exception as e: - return ErrorResponse( - error=ErrorInfo( + if version.parse(vllm.__version__) >= version.parse("0.10.0"): + return ErrorResponse( + error=ErrorInfo( + message=str(e), + type=HTTPStatus.BAD_REQUEST.phrase, + code=HTTPStatus.BAD_REQUEST.value, + ), + ).model_dump() + else: + return ErrorResponse( message=str(e), type=HTTPStatus.BAD_REQUEST.phrase, code=HTTPStatus.BAD_REQUEST.value, - ), - ).model_dump() + ).model_dump() # 2. Call vllm engine try: @@ -640,13 +607,20 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint else: http_status = HTTPStatus.INTERNAL_SERVER_ERROR - return ErrorResponse( - error=ErrorInfo( + if version.parse(vllm.__version__) >= version.parse("0.10.0"): + return ErrorResponse( + error=ErrorInfo( + message=str(e), + type=http_status.phrase, + code=http_status.value, + ), + ).model_dump() + else: + return ErrorResponse( message=str(e), type=http_status.phrase, code=http_status.value, - ), - ).model_dump() + ).model_dump() async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: """OpenAI-compatible HTTP endpoint for handling `/chat/completions` in Python vLLM engine. From dd9355467d43d95474091452a46806e2505c584e Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 21:55:03 -0700 Subject: [PATCH 060/121] Keep vLLM 0.18.0, reduce seq length to 72K, restore vllm_engine.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts vLLM 0.17.0 pin and vllm_engine.py pre-0.18 revert. Instead: - MAX_INPUT_LENGTH 96000→72000 to reduce memory pressure - --no-pytorch-alloc-conf (disables expandable_segments for 0.18.0 compat) - flash_attn=true + chunked lm_head + empty_cache at 72K Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 2 +- integrations/fleet/CHANGELOG.md | 19 ++- scripts/fleet-35b-run.sh | 11 +- .../inference_engines/vllm/vllm_engine.py | 130 +++++++++++------- 4 files changed, 90 insertions(+), 72 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 179d6ada89..e75733bd07 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,7 +17,7 @@ Always consult the changelog before modifying Fleet training paths (`fsdp_worker 3. **CUDA memory management for 35B**: `torch.cuda.empty_cache()` before backward pass in `worker.py` (policy + critic). Prevents OOM from fragmentation. -4. **Pin vLLM 0.17.0 for expandable_segments**: `fleet-35b-run.sh` pins `vllm==0.17.0` because 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`) conflicts with PyTorch's `expandable_segments:True`. Without expandable_segments, memory fragmentation causes OOM during backward. vLLM 0.17.0 uses `cudaMalloc` (no conflict), enabling the full proven config: expandable_segments + flash_attn=true + chunked lm_head. +4. **Reduced sequence length (72K) for 35B**: `fleet-35b-run.sh` uses `MAX_INPUT_LENGTH=72000` (down from 96000) with `--no-pytorch-alloc-conf` (disables `expandable_segments` which conflicts with vLLM 0.18.0's `CuMemAllocator`). At 97K, SDPA OOM'd and flash_attn hit Xid 31 in GatedDeltaNet. At 72K, flash_attn=true + chunked lm_head + empty_cache fits without expandable_segments. 5. **`stage_chunks` pre-staging**: `dispatch.py` has a `stage_chunks` optimization (not in upstream) that pre-stages mini-batch chunks in Ray object store. Includes dynamic `mini_batch_size` adjustment for hint augmentation's variable batch sizes. diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index abb000c34e..b4efb7d18b 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -9,7 +9,7 @@ Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fle 2-node training crashed with: 1. `cudaErrorIllegalAddress` during FSDP ref model offload/backload (multi-node race) 2. OOM / Xid 31 FAULT_PDE during policy training forward+backward (missing chunked lm_head) -3. `AssertionError: Expandable segments are not compatible with memory pool` at vLLM init (vLLM 0.18.0 vs expandable_segments) +3. OOM / Xid 31 at 97K sequence length — SDPA too memory-hungry, flash_attn triggers GatedDeltaNet crash 4. `AssertionError: data batch size must be divisible by mini_batch_size, got 160 and 128` (hint augmentation) ### Root causes and fixes @@ -48,17 +48,16 @@ Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fle **Why the old fork doesn't need this:** Targets smaller models (8B) with enough GPU headroom that fragmentation doesn't matter. -#### 4. Pin vLLM 0.17.0 to enable `expandable_segments` (`fleet-35b-run.sh`) +#### 4. Reduce sequence length to 72K and disable `expandable_segments` (`fleet-35b-run.sh`) -**Where:** `fleet-35b-run.sh` pins `vllm==0.17.0` before training starts. +**Where:** `fleet-35b-run.sh` — `MAX_INPUT_LENGTH` and `--no-pytorch-alloc-conf` flag. -**Problem:** SkyRL-v2's `pyproject.toml` pins `vllm==0.18.0`, which introduced `CuMemAllocator` using `cuMemCreate`/`cuMemMap`. This conflicts with PyTorch's `expandable_segments:True` (also uses `cuMemCreate`/`cuMemMap`). Without `expandable_segments`, memory fragmentation causes OOM during backward on the 35B model — PyTorch can't satisfy large contiguous allocations from scattered free blocks. +**Problem:** At 97K sequences (96000 input + 4096 generate), memory was too tight even with chunked lm_head and `empty_cache`: +- `flash_attn=false` (SDPA): OOM requesting 5.95 GiB during backward — SDPA's O(n²) attention memory is too large at 97K. +- `flash_attn=true`: Xid 31 FAULT_PDE in GatedDeltaNet layers during ref model forward. +- `expandable_segments:True` would help with fragmentation but conflicts with vLLM 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`). -Additionally, `flash_attn=false` (SDPA) uses too much attention memory at 97K sequences (OOM requesting 5.95 GiB during backward). `flash_attn=true` with vLLM 0.18.0 triggers Xid 31 FAULT_PDE in GatedDeltaNet layers during ref model forward — not a memory issue but an incompatibility between vLLM 0.18.0's CUDA allocator and FSDP2 DTensor operations. - -**Fix:** Pin `vllm==0.17.0` at the start of `fleet-35b-run.sh` (`pip install --force-reinstall --no-deps`). vLLM 0.17.0 uses `cudaMalloc` (no `CuMemAllocator`), so `expandable_segments:True` works. Removed `--no-pytorch-alloc-conf` flag so `fleet-common-run.sh` enables `expandable_segments`. This restores the proven config: vLLM 0.17.0 + `expandable_segments` + `flash_attn=true` + chunked lm_head. - -**Why the rest of SkyRL-v2 keeps 0.18.0:** The pin is only in `fleet-35b-run.sh`. Other models (9B) may work fine with 0.18.0. The 35B model's extreme memory requirements make expandable_segments essential. +**Fix:** Reduce `MAX_INPUT_LENGTH` from 96000 to 72000 (total seq ~76K). This lowers peak memory enough that `flash_attn=true` + chunked lm_head + `empty_cache` fits without needing `expandable_segments`. The `--no-pytorch-alloc-conf` flag disables `expandable_segments` to avoid the vLLM 0.18.0 CuMemAllocator conflict. #### 5. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) @@ -79,5 +78,5 @@ The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) s | `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Port chunked lm_head forward (loss_chunk_size) | | `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Pass loss_chunk_size to HFModelWrapper; synchronous ref offload + barrier | | `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | -| `scripts/fleet-35b-run.sh` | Pin vLLM 0.17.0, flash_attn=true, expandable_segments, wandb project rename | +| `scripts/fleet-35b-run.sh` | Reduce seq length to 72K, flash_attn=true, --no-pytorch-alloc-conf, wandb project rename | | `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 8f164bc171..6fb68a8a49 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -14,7 +14,7 @@ export DATA_VERSION="${DATA_VERSION:-v55}" export MODALITY="${MODALITY:-tool_use}" export NUM_EPOCHS="${NUM_EPOCHS:-20}" export MAX_TURNS="${MAX_TURNS:-50}" -export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-8}" export ENV_KEYS="${ENV_KEYS:-}" @@ -30,16 +30,9 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" : "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" : "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" -# Pin vLLM to 0.17.0 for 35B training. vLLM 0.18.0's CuMemAllocator -# (cuMemCreate/cuMemMap) conflicts with PyTorch's expandable_segments:True. -# Without expandable_segments, memory fragmentation causes OOM during backward. -# vLLM 0.17.0 uses cudaMalloc (no conflict), enabling expandable_segments. -source .venv/bin/activate -pip install --force-reinstall --no-deps "vllm==0.17.0" - bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ - --set-ulimit \ + --set-ulimit --no-pytorch-alloc-conf \ --nccl-heartbeat 1800 -- \ environment.skyrl_gym.fleet_task.ttl_seconds=900 \ environment.skyrl_gym.fleet_task.partial_reward=true \ diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py index cfef56ace2..fd28948fee 100644 --- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -15,7 +15,6 @@ import ray import vllm from loguru import logger -from packaging import version from vllm import SamplingParams from vllm.entrypoints.openai.chat_completion.protocol import ( ChatCompletionRequest, @@ -28,7 +27,12 @@ ) from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse -from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.models.serving import ( + BaseModelPath, + OpenAIModelRegistry, + OpenAIServingModels, +) +from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.inputs import TokensPrompt from vllm.lora.request import LoRARequest @@ -98,8 +102,7 @@ def __init__(self, *args, bundle_indices: list = None, **kwargs): setup_envvars_for_vllm(kwargs, bundle_indices) vllm_v1_disable_multiproc = kwargs.pop("vllm_v1_disable_multiproc", False) - if vllm_v1_disable_multiproc or vllm.__version__ == "0.8.2": - # https://github.com/vllm-project/vllm/blob/effc5d24fae10b29996256eb7a88668ff7941aed/examples/offline_inference/reproduciblity.py#L11 + if vllm_v1_disable_multiproc: os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" # Store common attributes @@ -132,6 +135,7 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): prompts = input_batch.get("prompts") prompt_token_ids = input_batch.get("prompt_token_ids") request_sampling_params = input_batch.get("sampling_params") + multi_modal_data = input_batch.get("multi_modal_data") assert ( prompts is None and prompt_token_ids is not None @@ -141,7 +145,7 @@ def _preprocess_prompts(self, input_batch: InferenceEngineInput): SamplingParams(**request_sampling_params) if request_sampling_params is not None else SamplingParams() ) - return prompt_token_ids, sampling_params + return prompt_token_ids, sampling_params, multi_modal_data def _postprocess_outputs(self, outputs): """Common output processing logic.""" @@ -244,7 +248,7 @@ def _create_engine(self, *args, **kwargs): return vllm.LLM(*args, **kwargs) async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: - prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) # Check if LoRA is enabled and create LoRA requests lora_requests = None @@ -258,9 +262,18 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu LoRARequest(lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path") ] * batch_size + # Build prompts with multimodal data for VL models + prompts = [] + for i, token_ids in enumerate(prompt_token_ids): + mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None + if mm_data: + prompts.append({"prompt_token_ids": token_ids, "multi_modal_data": mm_data}) + else: + prompts.append(TokensPrompt(prompt_token_ids=token_ids)) + outputs = await asyncio.to_thread( self.llm.generate, - prompts=[TokensPrompt(prompt_token_ids=r) for r in prompt_token_ids], + prompts=prompts, sampling_params=sampling_params, lora_request=lora_requests, ) @@ -351,10 +364,7 @@ def _create_engine(self, *args, **kwargs): enable_log_requests = kwargs.pop("enable_log_requests", False) max_log_len = kwargs.pop("max_log_len", None) - if version.parse(vllm.__version__) >= version.parse("0.10.0"): - engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) - else: - engine_args = vllm.AsyncEngineArgs(disable_log_requests=not enable_log_requests, **kwargs) + engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs) # Setup stat loggers for vLLM v1 if Ray Prometheus stats are enabled stat_loggers = None @@ -363,8 +373,6 @@ def _create_engine(self, *args, **kwargs): engine = vllm.AsyncLLMEngine.from_engine_args(engine_args, stat_loggers=stat_loggers) - # Adapted from https://github.com/volcengine/verl/blob/e90f18c40aa639cd25092b78a5ff7e2d2508c088/verl/workers/rollout/vllm_rollout/vllm_async_server.py#L327 - model_config = engine.model_config model_path = kwargs.get("model") # Use served_model_name if provided (from generator.served_model_name config), # otherwise fall back to model_path. This allows using a different model name @@ -374,15 +382,7 @@ def _create_engine(self, *args, **kwargs): model_name = served_model_name if served_model_name is not None else model_path base_model_paths = [BaseModelPath(name=model_name, model_path=model_path)] - - # vllm >= 0.11.2 removed model_config from OpenAI serving APIs - is_new_api = version.parse(vllm.__version__) >= version.parse("0.11.2") - legacy_kwargs = {} - if is_new_api: - models = OpenAIServingModels(engine, base_model_paths) - else: - models = OpenAIServingModels(engine, model_config, base_model_paths) - legacy_kwargs["model_config"] = model_config + models = OpenAIServingModels(engine, base_model_paths) # Build request logger for debugging (off by default). # Enable via: generator.engine_init_kwargs.enable_log_requests=true @@ -393,14 +393,39 @@ def _create_engine(self, *args, **kwargs): request_logger = RequestLogger(max_log_len=max_log_len) + chat_template = openai_kwargs.pop("chat_template", None) + + from vllm.plugins.io_processors import get_io_processor + from vllm.renderers import renderer_from_config + + model_registry = OpenAIModelRegistry( + model_config=engine.model_config, + base_model_paths=base_model_paths, + ) + renderer = renderer_from_config(engine.vllm_config) + io_processor = get_io_processor( + engine.vllm_config, + renderer, + engine.model_config.io_processor_plugin, + ) + openai_serving_render = OpenAIServingRender( + model_config=engine.model_config, + renderer=renderer, + io_processor=io_processor, + model_registry=model_registry, + request_logger=request_logger, + chat_template=chat_template, + chat_template_content_format="auto", + ) + self.openai_serving_chat = OpenAIServingChat( engine_client=engine, models=models, response_role="assistant", + openai_serving_render=openai_serving_render, request_logger=request_logger, - chat_template=openai_kwargs.pop("chat_template", None), # used to template /chat/completions requests + chat_template=chat_template, chat_template_content_format="auto", - **legacy_kwargs, **openai_kwargs, ) @@ -409,8 +434,8 @@ def _create_engine(self, *args, **kwargs): self.openai_serving_completion = OpenAIServingCompletion( engine_client=engine, models=models, + openai_serving_render=openai_serving_render, request_logger=request_logger, - **legacy_kwargs, ) return engine @@ -445,7 +470,13 @@ async def _load_lora_from_disk(self, lora_path: str): result = await self.llm.add_lora(lora_request) return result - async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_params: SamplingParams): + async def _collect_outputs( + self, + prompt_token_ids, + request_id: str, + sampling_params: SamplingParams, + multi_modal_data: Optional[Dict[str, Any]] = None, + ): """Collect outputs for a single prompt.""" # Check if LoRA is enabled and create LoRA request final_output = None @@ -460,8 +491,16 @@ async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_par lora_name=f"{lora_int_id}", lora_int_id=lora_int_id, lora_path="/dummy_lora_path" ) + # Build prompt with multimodal data for VL models + if multi_modal_data: + num_images = len(multi_modal_data.get("image", [])) + logger.info(f"VL generate: {num_images} images, {len(prompt_token_ids)} input tokens") + prompt = {"prompt_token_ids": prompt_token_ids, "multi_modal_data": multi_modal_data} + else: + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + async for request_output in self.llm.generate( - prompt=TokensPrompt(prompt_token_ids=prompt_token_ids), + prompt=prompt, sampling_params=sampling_params, request_id=request_id, lora_request=lora_request, @@ -472,14 +511,15 @@ async def _collect_outputs(self, prompt_token_ids, request_id: str, sampling_par async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: """Generate responses using vLLM's async engine.""" - prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch) + prompt_token_ids, sampling_params, multi_modal_data = self._preprocess_prompts(input_batch) tasks = [] - for prompt in prompt_token_ids: + for i, prompt in enumerate(prompt_token_ids): # Schedule the collection of outputs for each prompt. # Avoid duplicate request_ids request_id = str(uuid4().hex) - task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params)) + mm_data = multi_modal_data[i] if multi_modal_data and i < len(multi_modal_data) else None + task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params, mm_data)) tasks.append(task) outputs = await asyncio.gather(*tasks) @@ -562,20 +602,13 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint request = CompletionRequest(**body) assert request.stream is False, "Streaming is not supported in SkyRL yet, please set stream to False." except Exception as e: - if version.parse(vllm.__version__) >= version.parse("0.10.0"): - return ErrorResponse( - error=ErrorInfo( - message=str(e), - type=HTTPStatus.BAD_REQUEST.phrase, - code=HTTPStatus.BAD_REQUEST.value, - ), - ).model_dump() - else: - return ErrorResponse( + return ErrorResponse( + error=ErrorInfo( message=str(e), type=HTTPStatus.BAD_REQUEST.phrase, code=HTTPStatus.BAD_REQUEST.value, - ).model_dump() + ), + ).model_dump() # 2. Call vllm engine try: @@ -607,20 +640,13 @@ async def _handle_openai_request(self, request_payload: Dict[str, Any], endpoint else: http_status = HTTPStatus.INTERNAL_SERVER_ERROR - if version.parse(vllm.__version__) >= version.parse("0.10.0"): - return ErrorResponse( - error=ErrorInfo( - message=str(e), - type=http_status.phrase, - code=http_status.value, - ), - ).model_dump() - else: - return ErrorResponse( + return ErrorResponse( + error=ErrorInfo( message=str(e), type=http_status.phrase, code=http_status.value, - ).model_dump() + ), + ).model_dump() async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]: """OpenAI-compatible HTTP endpoint for handling `/chat/completions` in Python vLLM engine. From 5b7bb43de0c069f903ea6a8a56df502b939571b6 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 21:57:22 -0700 Subject: [PATCH 061/121] Update YAML MAX_INPUT_LENGTH to 72000 to match fleet-35b-run.sh Co-Authored-By: Claude Opus 4.6 --- tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml index c2039685ed..02067cf095 100644 --- a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -45,7 +45,7 @@ envs: DIFFICULTY: "" MODALITY: "tool_use" MAX_TURNS: 50 - MAX_INPUT_LENGTH: 96000 + MAX_INPUT_LENGTH: 72000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 20 RUN_ID: "" From d7a6b48336a17ef66dc90bf8dc8c3c92532583d3 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 21:58:57 -0700 Subject: [PATCH 062/121] chore: re-enable eval_before_train for production VL run Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index d1912e312d..54da097cd2 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -53,7 +53,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ - trainer.eval_before_train=false \ + trainer.eval_before_train=true \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=16 \ From 2d12453c02a9c188e70af0d3cedd476ab4707417 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 22:04:47 -0700 Subject: [PATCH 063/121] chore: disable eval_before_train to verify backward pass Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index 54da097cd2..d1912e312d 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -53,7 +53,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ - trainer.eval_before_train=true \ + trainer.eval_before_train=false \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=16 \ From a00402557d1713717034bec1dbeef8e0dea14d96 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 22:10:42 -0700 Subject: [PATCH 064/121] fix: parse all tool calls per turn + remove exploration gate 1. parse_tool_calls() returns ALL tags instead of just the first one. The model often batches multiple tool calls in one generation (73% of trajectories). Previously only the first was executed, the rest silently dropped. 2. Remove the "must explore before generating task" gate. With TCP errors on describe_db, this gate rejected 62-78% of generated tasks. The model should be free to generate tasks at any point. Co-Authored-By: Claude Opus 4.6 --- .../skyrl_gym/envs/task_gen/task_gen_env.py | 96 ++++--------------- .../envs/task_gen/tool_call_parser.py | 56 +++++++---- 2 files changed, 56 insertions(+), 96 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index ec08b4737c..8d8cfdbb79 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -36,7 +36,7 @@ BaseTextEnvStepOutput, ConversationType, ) -from skyrl_gym.envs.task_gen.tool_call_parser import parse_tool_call +from skyrl_gym.envs.task_gen.tool_call_parser import parse_tool_call, parse_tool_calls from skyrl_gym.envs.task_gen.verifier_sandbox import ( VerifierSandbox, parse_task_output, @@ -1054,82 +1054,25 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: # 1. Check for block → evaluation pipeline if "" in action: - # Gate: require describe_db + query_db + at least one env tool call - # before generating a task (unless single-turn or out of turns) - if self.max_turns > 1 and not max_turns_reached: - missing = [] - if not self.called_describe_db: - missing.append("`describe_db` (to see the schema)") - if not self.called_query_db: - missing.append("`query_db` (to inspect actual data)") - if self.mcp_tool_calls < 1: - missing.append("at least one environment API tool (to understand input/output formats)") - if missing: - observation = { - "role": "user", - "content": ( - "You must explore the environment before generating a task. " - "You still need to call: " - + "; ".join(missing) - + ". NEVER hardcode database IDs — always query to find them first." - ), - } - return BaseTextEnvStepOutput( - observations=[observation], - reward=0.0, - done=False, - metadata={ - "env_key": self.env_key, - "turn": self.turns, - "rejected": "no_exploration", - }, - ) return await self._handle_task_generation(action) - # 2. Check for tool call → execute via Fleet orchestrator or MCP - # Enforce exploration sequence: describe_db → query_db → env tool - tool_call = parse_tool_call(action) - if tool_call and tool_call["name"] in self.callable_tools: - if self.max_turns > 1 and not max_turns_reached: - name = tool_call["name"] - if name == "query_db" and not self.called_describe_db: - return BaseTextEnvStepOutput( - observations=[ - { - "role": "user", - "content": "Call `describe_db` first to see the schema before querying data.", - } - ], - reward=0.0, - done=False, - metadata={"env_key": self.env_key, "turn": self.turns, "rejected": "sequence_violation"}, - ) - if name not in _META_TOOLS and not self.called_query_db: - return BaseTextEnvStepOutput( - observations=[ - { - "role": "user", - "content": ( - "Call `describe_db` and `query_db` first to understand the schema and data " - "before calling environment tools." - ), - } - ], - reward=0.0, - done=False, - metadata={"env_key": self.env_key, "turn": self.turns, "rejected": "sequence_violation"}, - ) - - if tool_call["name"] in _META_TOOLS: - self.meta_tool_calls += 1 - if tool_call["name"] == "describe_db": - self.called_describe_db = True - elif tool_call["name"] == "query_db": - self.called_query_db = True - obs_content = await self._execute_meta_tool(tool_call) - else: - self.mcp_tool_calls += 1 - obs_content = await self._execute_mcp_tool(tool_call) + # 2. Check for tool calls → execute all via Fleet orchestrator or MCP + tool_calls = parse_tool_calls(action) + tool_calls = [tc for tc in tool_calls if tc["name"] in self.callable_tools] + if tool_calls: + results = [] + for tc in tool_calls: + if tc["name"] in _META_TOOLS: + self.meta_tool_calls += 1 + if tc["name"] == "describe_db": + self.called_describe_db = True + elif tc["name"] == "query_db": + self.called_query_db = True + result = await self._execute_meta_tool(tc) + else: + self.mcp_tool_calls += 1 + result = await self._execute_mcp_tool(tc) + results.append(f"[{tc['name']}] {result}") if max_turns_reached: return BaseTextEnvStepOutput( @@ -1139,12 +1082,13 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: metadata={"env_key": self.env_key, "turn": self.turns, "done_reason": "max_turns"}, ) + obs_content = "\n\n".join(results) observation = {"role": "user", "content": obs_content} return BaseTextEnvStepOutput( observations=[observation], reward=0.0, done=False, - metadata={"env_key": self.env_key, "turn": self.turns, "tool_call": tool_call}, + metadata={"env_key": self.env_key, "turn": self.turns, "tool_calls": [tc["name"] for tc in tool_calls]}, ) # 3. Neither task nor tool call → nudge diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py b/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py index f328507d18..95e21912fd 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/tool_call_parser.py @@ -7,7 +7,7 @@ import json import re -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional def _try_parse_json(raw: str) -> Optional[Dict[str, Any]]: @@ -33,9 +33,26 @@ def _try_parse_json(raw: str) -> Optional[Dict[str, Any]]: return None +def _parse_one(match_text: str) -> Optional[Dict[str, Any]]: + """Parse a single tool call from matched text.""" + parsed = _try_parse_json(match_text) + if parsed is None: + return None + name = parsed.get("name") or parsed.get("tool") + args = parsed.get("arguments") or parsed.get("params", {}) + if name: + return {"name": name, "arguments": args} + return None + + def parse_tool_call(action: str) -> Optional[Dict[str, Any]]: - """ - Parse tool call from LLM response. + """Parse the first tool call from LLM response. Returns None if not found.""" + calls = parse_tool_calls(action) + return calls[0] if calls else None + + +def parse_tool_calls(action: str) -> List[Dict[str, Any]]: + """Parse all tool calls from LLM response. Supports tag-based formats: - {"name": "...", "arguments": {...}} @@ -44,24 +61,23 @@ def parse_tool_call(action: str) -> Optional[Dict[str, Any]]: Also handles cases where the closing tag is missing (e.g., when is used as the stop string and not included in the output). - Returns dict with "name" and "arguments" keys, or None if not found. + Returns list of dicts with "name" and "arguments" keys. """ - # Try common tag formats + results: List[Dict[str, Any]] = [] + for tag in ["tool_call", "function_call"]: - # First try with closing tag - match = re.search(rf"<{tag}>(.*?)", action, re.DOTALL) - if not match: - # Try without closing tag (for when is the stop string) - # Match from opening tag to end of string or next special token + # Find all with closing tag + for match in re.finditer(rf"<{tag}>(.*?)", action, re.DOTALL): + parsed = _parse_one(match.group(1)) + if parsed: + results.append(parsed) + + # If none found with closing tags, try without (stop string case) + if not results: match = re.search(rf"<{tag}>(.*?)(?:<\||\Z)", action, re.DOTALL) - if match: - parsed = _try_parse_json(match.group(1)) - if parsed is None: - continue - # Normalize keys - name = parsed.get("name") or parsed.get("tool") - args = parsed.get("arguments") or parsed.get("params", {}) - if name: - return {"name": name, "arguments": args} + if match: + parsed = _parse_one(match.group(1)) + if parsed: + results.append(parsed) - return None + return results From 621ae49b331169e10bee5475328a940edecdf64b Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 22:48:18 -0700 Subject: [PATCH 065/121] Clarify --no-pytorch-alloc-conf mechanism in CHANGELOG Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index b4efb7d18b..fda16e3a3d 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -57,7 +57,7 @@ Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fle - `flash_attn=true`: Xid 31 FAULT_PDE in GatedDeltaNet layers during ref model forward. - `expandable_segments:True` would help with fragmentation but conflicts with vLLM 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`). -**Fix:** Reduce `MAX_INPUT_LENGTH` from 96000 to 72000 (total seq ~76K). This lowers peak memory enough that `flash_attn=true` + chunked lm_head + `empty_cache` fits without needing `expandable_segments`. The `--no-pytorch-alloc-conf` flag disables `expandable_segments` to avoid the vLLM 0.18.0 CuMemAllocator conflict. +**Fix:** Reduce `MAX_INPUT_LENGTH` from 96000 to 72000 (total seq ~76K). This lowers peak memory enough that `flash_attn=true` + chunked lm_head + `empty_cache` fits without needing `expandable_segments`. The `--no-pytorch-alloc-conf` flag passed to `fleet-common-run.sh` skips the default `export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, avoiding the vLLM 0.18.0 CuMemAllocator conflict. The 9B VL script (`fleet-vl-run.sh`) also passes this flag for the same reason. #### 5. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) From a50ca53daa8775af1d1656c05004364cde343384 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 22:50:48 -0700 Subject: [PATCH 066/121] =?UTF-8?q?Switch=20to=20flash=5Fattn=3Dfalse=20?= =?UTF-8?q?=E2=80=94=20flash=5Fattn=3Dtrue=20causes=20Xid=2031=20with=20vL?= =?UTF-8?q?LM=200.18.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit flash_attn=true + vLLM 0.18.0 triggers Xid 31 FAULT_PDE in GatedDeltaNet during ref forward at both 97K and 72K — not a memory issue but a CUDA memory mapping corruption from vLLM's CuMemAllocator. Trying SDPA at 72K. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 6fb68a8a49..3ea02704a1 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -40,7 +40,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.n_hint_samples=2 \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ - trainer.flash_attn=true \ + trainer.flash_attn=false \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ From 816064b32b51eb19bbcb3cf20433f14cd950b38c Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 29 Mar 2026 22:57:23 -0700 Subject: [PATCH 067/121] Re-enable eval_before_train for production VL run Backward pass verified working on sky-4da1-deniz. Re-enable step-0 eval for production training. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index d1912e312d..54da097cd2 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -53,7 +53,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ - trainer.eval_before_train=false \ + trainer.eval_before_train=true \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=16 \ From 32a8022c8e80456987457c6ef46fab5cd0fa13ac Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 00:52:17 -0700 Subject: [PATCH 068/121] =?UTF-8?q?docs:=20update=20CHANGELOG=20fix=20#4?= =?UTF-8?q?=20=E2=80=94=20flash=5Fattn=3Dfalse,=20verified=20working=20ste?= =?UTF-8?q?p=20timing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Corrected fix #4 to reflect the final working config: - flash_attn=false (SDPA), not flash_attn=true - flash_attn=true causes Xid 31 at both 97K AND 72K (CuMemAllocator issue) - Added "Verified working" note: ref forward 8.4 min, backward 45.6 min Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/CHANGELOG.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index fda16e3a3d..5658217b7b 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -54,10 +54,12 @@ Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fle **Problem:** At 97K sequences (96000 input + 4096 generate), memory was too tight even with chunked lm_head and `empty_cache`: - `flash_attn=false` (SDPA): OOM requesting 5.95 GiB during backward — SDPA's O(n²) attention memory is too large at 97K. -- `flash_attn=true`: Xid 31 FAULT_PDE in GatedDeltaNet layers during ref model forward. +- `flash_attn=true`: Xid 31 FAULT_PDE in GatedDeltaNet layers during ref model forward — reproduced at both 97K and 72K. Not a memory issue; vLLM 0.18.0's CuMemAllocator corrupts CUDA memory mappings that FSDP2 DTensor operations later touch. - `expandable_segments:True` would help with fragmentation but conflicts with vLLM 0.18.0's `CuMemAllocator` (`cuMemCreate`/`cuMemMap`). -**Fix:** Reduce `MAX_INPUT_LENGTH` from 96000 to 72000 (total seq ~76K). This lowers peak memory enough that `flash_attn=true` + chunked lm_head + `empty_cache` fits without needing `expandable_segments`. The `--no-pytorch-alloc-conf` flag passed to `fleet-common-run.sh` skips the default `export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, avoiding the vLLM 0.18.0 CuMemAllocator conflict. The 9B VL script (`fleet-vl-run.sh`) also passes this flag for the same reason. +**Fix:** Reduce `MAX_INPUT_LENGTH` from 96000 to 72000 (total seq ~76K) and use `flash_attn=false` (SDPA). At 72K, SDPA's O(n²) memory is ~55% of what it was at 97K — enough to fit with chunked lm_head + `empty_cache`. The `--no-pytorch-alloc-conf` flag passed to `fleet-common-run.sh` skips the default `export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, avoiding the vLLM 0.18.0 CuMemAllocator conflict. The 9B VL script (`fleet-vl-run.sh`) also passes this flag for the same reason. + +**Verified working:** Step 1 completed — ref forward 8.4 min, policy backward 45.6 min, total step ~80 min. SDPA is slower than flash_attn but stable. #### 5. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) @@ -78,5 +80,5 @@ The old fork's manual loop (`num_mini_batches = len(data) // mini_batch_size`) s | `skyrl/backends/skyrl_train/workers/model_wrapper.py` | Port chunked lm_head forward (loss_chunk_size) | | `skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py` | Pass loss_chunk_size to HFModelWrapper; synchronous ref offload + barrier | | `skyrl/backends/skyrl_train/workers/worker.py` | empty_cache before backward (3 sites) | -| `scripts/fleet-35b-run.sh` | Reduce seq length to 72K, flash_attn=true, --no-pytorch-alloc-conf, wandb project rename | +| `scripts/fleet-35b-run.sh` | Reduce seq length to 72K, flash_attn=false, --no-pytorch-alloc-conf, wandb project rename | | `skyrl/backends/skyrl_train/distributed/dispatch.py` | Dynamic mini_batch_size adjustment | From 54d74adc9cdaf5b76bd9767630d2a88c98662d5d Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 11:30:35 -0700 Subject: [PATCH 069/121] =?UTF-8?q?docs:=20update=20CHANGELOG=20=E2=80=94?= =?UTF-8?q?=2010=20steps=20verified,=20checkpoint=20at=20step=2010?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 12 hours on GCP spot 2×H200:8 with zero GPU errors. Avg step time ~70 min, checkpoint saved to S3 at step 10. Co-Authored-By: Claude Opus 4.6 --- integrations/fleet/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/fleet/CHANGELOG.md b/integrations/fleet/CHANGELOG.md index 5658217b7b..38ebbb2211 100644 --- a/integrations/fleet/CHANGELOG.md +++ b/integrations/fleet/CHANGELOG.md @@ -59,7 +59,7 @@ Fixes for 2-node (16 GPU) Qwen3.5-35B GRPO training on GCP H200. Ported from fle **Fix:** Reduce `MAX_INPUT_LENGTH` from 96000 to 72000 (total seq ~76K) and use `flash_attn=false` (SDPA). At 72K, SDPA's O(n²) memory is ~55% of what it was at 97K — enough to fit with chunked lm_head + `empty_cache`. The `--no-pytorch-alloc-conf` flag passed to `fleet-common-run.sh` skips the default `export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, avoiding the vLLM 0.18.0 CuMemAllocator conflict. The 9B VL script (`fleet-vl-run.sh`) also passes this flag for the same reason. -**Verified working:** Step 1 completed — ref forward 8.4 min, policy backward 45.6 min, total step ~80 min. SDPA is slower than flash_attn but stable. +**Verified working:** 10 steps completed on GCP spot 2×H200:8 (asia-south1-b) with zero GPU errors over 12 hours. Step timing: generation ~7 min, ref forward ~8 min, policy backward ~44 min, total step ~70 min avg. Checkpoint saved to S3 at step 10. SDPA is slower than flash_attn but stable. WandB: `fleet_qwen35_35b_tool_use_2c0e13b7` (run ID `f6kw15p2`). #### 5. Dynamic mini_batch_size for hint augmentation (`dispatch.py`) From ef7687f96783b676f6977cdf03c882532a8482e1 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 15:45:04 -0700 Subject: [PATCH 070/121] fix: disable hints during training by default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add enable_hints flag (default False) gated by env_config. Previously hints were always ON during training (controlled by is_eval). Now hints only run when enable_hints=True AND not in eval mode. Hints were net negative in iter#11 — verifier code dump confused evaluator. Reward now uses raw variance only: R = base_quality + judge_gate * alpha * var(raw). Co-Authored-By: Claude Opus 4.6 --- .../skyrl_gym/envs/task_gen/task_gen_env.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 8d8cfdbb79..50f517aa44 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -172,9 +172,12 @@ def __init__( self.openrouter_api_key = os.environ.get("OPENROUTER_API_KEY", "") self.fleet_api_key = os.environ.get("FLEET_API_KEY", "") - # Eval mode: k=8 raw only (no hints); Train mode: k with hints + # Eval mode: k=8 raw only (no hints) self.is_eval = extras.get("training_phase") == "eval" self.eval_k_rollouts = int(env_config.get("eval_k_rollouts", 8)) if env_config else 8 + # Whether to run hinted evaluation jobs (2nd harness job with verifier feedback). + # Default off — hints were net negative in iter#11 (verifier code dump confused evaluator). + self.enable_hints = bool(env_config.get("enable_hints", False)) if env_config else False # Lazy-init Fleet SDK client for harness evaluation self._fleet_client = None @@ -838,21 +841,15 @@ async def _evaluate_task(self, prompt: str, verifier: str) -> Dict[str, float]: start = time.time() try: - # Eval mode: k=8 raw only (no hints) for pass rate measurement - # Train mode: k raw + k hinted for hint_gap signal + # Eval: k=eval_k_rollouts for pass rate; Train: k=k_rollouts eval_k = self.eval_k_rollouts if self.is_eval else self.k_rollouts # 1. Raw job: k rollouts without hints raw_job_id, raw_results = await self._run_harness_job(prompt, verifier, k=eval_k) raw_scores = [r[0] for r in raw_results] - if self.is_eval: - # Eval: no hints, reward = alpha * var_raw (hint_gap=0) - hinted_scores = [] - hinted_job_id = None - hint_text = "" - result = compute_task_reward(raw_scores, raw_scores, validity=1.0) - else: + if self.enable_hints and not self.is_eval: + # Hinted training: k raw + k hinted for hint_gap signal # 2. Build hint from first failing session's stdout/error hint_stdout = None hint_error = None @@ -885,6 +882,12 @@ async def _evaluate_task(self, prompt: str, verifier: str) -> Dict[str, float]: # 4. Compute reward result = compute_task_reward(raw_scores, hinted_scores, validity=1.0) + else: + # No hints — reward based on raw variance only + hinted_scores = [] + hinted_job_id = None + hint_text = "" + result = compute_task_reward(raw_scores, raw_scores, validity=1.0) duration = time.time() - start From 36d7f325b2d19306c39c92ff2155d2506e7759c4 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 19:16:15 -0700 Subject: [PATCH 071/121] feat: tool-call reward shaping + increase context to 65K MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add tool_call_reward_per_call config (default 0.0, set to 0.02 in run script) Rewards each successful meta-tool call (describe_db, query_db) to incentivize multi-turn DB exploration instead of single-turn guessing from system prompt. - MAX_INPUT_LENGTH 30720 → 65536: baseline runs showed 30K forced single-turn convergence by step 5 (describe_db schemas overflow context budget). - MAX_GENERATE_LENGTH 2048 → 4096: more room for task+verifier output. - eval_interval 20 → 10: get eval signal earlier. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-task-gen-run.sh | 3 ++- skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py | 14 +++++++++++--- tasks/task-gen-grpo-qwen3_5-9b.yaml | 4 ++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/scripts/fleet-task-gen-run.sh b/scripts/fleet-task-gen-run.sh index e885b380d6..ca1717183a 100755 --- a/scripts/fleet-task-gen-run.sh +++ b/scripts/fleet-task-gen-run.sh @@ -41,7 +41,7 @@ bash scripts/fleet-common-run.sh \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ trainer.eval_before_train=false \ - trainer.eval_interval=20 \ + trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=12 \ trainer.use_hybrid_env_sampling=true \ @@ -86,5 +86,6 @@ bash scripts/fleet-common-run.sh \ ++environment.skyrl_gym.task_gen.max_eval_steps=$MAX_EVAL_STEPS \ ++environment.skyrl_gym.task_gen.evaluator_model="${EVALUATOR_MODEL:-anthropic/claude-sonnet-4.5}" \ ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 \ + ++environment.skyrl_gym.task_gen.tool_call_reward_per_call=0.02 \ "${ENV_FILTER_ARGS[@]}" \ "$@" diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 50f517aa44..1e99f83c2d 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -191,12 +191,16 @@ def __init__( # Provides GRPO gradient signal even when all harness evals return 0. self.base_quality_reward = float(env_config.get("base_quality_reward", 0.1)) if env_config else 0.1 + # Small per-tool-call reward to incentivize DB exploration (describe_db, query_db). + # Default 0.0 = off (no behavior change for existing runs). + self.tool_call_reward_per_call = float(env_config.get("tool_call_reward_per_call", 0.0)) if env_config else 0.0 + logger.info( f"TaskGenEnv: env={self.env_key}, max_turns={self.max_turns}, " f"judge={self.judge_model or 'none'}, " f"tools={len(self.env_tools)}, k={self.k_rollouts}, eval_k={self.eval_k_rollouts}, " f"evaluator={self.evaluator_model}, is_eval={self.is_eval}, " - f"base_quality={self.base_quality_reward}" + f"base_quality={self.base_quality_reward}, tool_call_reward={self.tool_call_reward_per_call}" ) def _format_tool_schema(self, tool: Dict[str, Any]) -> str: @@ -1020,18 +1024,22 @@ async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: # 4. Hint-based evaluation via Fleet harness eval_result = await self._evaluate_task(prompt, verifier) - # 5. R = base_quality + eval_signal + # 5. R = base_quality + tool_call_reward + eval_signal # base_quality: small reward for passing sandbox+judge (structural validity) + # tool_call_reward: incentivize DB exploration (describe_db, query_db) # eval_signal: judge_gate * compute_task_reward (harness-based quality) # This prevents GRPO zero-signal deadlock when all harness evals fail. base_quality = self.base_quality_reward + tool_call_reward = self.meta_tool_calls * self.tool_call_reward_per_call eval_signal = judge_gate * eval_result["total"] - reward = base_quality + eval_signal + reward = base_quality + tool_call_reward + eval_signal metadata["reward_breakdown"] = { "sandbox": 1.0, "judge": judge_gate, "base_quality": base_quality, + "tool_call_reward": tool_call_reward, + "meta_tool_calls": self.meta_tool_calls, "eval_signal": eval_signal, **eval_result, "total": reward, diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index 44b5a50831..be05736c8e 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -34,8 +34,8 @@ envs: DATA_VERSION: "v55" MODALITY: "tool_use" MAX_TURNS: 10 - MAX_INPUT_LENGTH: 30720 - MAX_GENERATE_LENGTH: 2048 + MAX_INPUT_LENGTH: 65536 + MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 20 JUDGE_MODEL: "anthropic/claude-sonnet-4.5" EVALUATOR_MODEL: "anthropic/claude-sonnet-4.5" From efe1fb0a0af74c960225063f48ec94c15d8bc73d Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 20:46:59 -0700 Subject: [PATCH 072/121] feat: LLM classifier gate to filter broken tasks before harness Rewrites _judge_task as a pre-filter optimized for very low false positive rate. Only rejects tasks with clear structural defects: 1. Phantom tables (not in env schema) 2. Undefined function/constant references 3. Vacuous checks (only user-exists or len>0) 4. Read-write mismatch (prompt asks reads, verifier checks writes) Passes env_schema to the classifier so it can verify table references. Uses Haiku via OpenRouter for low cost/latency (~$0.001 per call). Defaults to ACCEPT on any error (conservative). Co-Authored-By: Claude Opus 4.6 --- .../skyrl_gym/envs/task_gen/task_gen_env.py | 66 +++++++++++++------ tasks/task-gen-grpo-qwen3_5-9b.yaml | 3 +- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 1e99f83c2d..c6ab178b5d 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -548,31 +548,54 @@ def find_new_entries(table_name, id_field="id", filter_conditions=None): return "\n".join(parts) def _judge_task(self, prompt: str, verifier: str) -> float: - """LLM-as-a-judge gate: returns 0.0 (invalid) or 1.0 (valid). - - Uses a model to check if the generated (prompt, verifier) pair - is valid and coherent. This is the binary gate in the reward formula. + """LLM classifier gate: returns 0.0 (reject) or 1.0 (accept). + + Predicts whether the (prompt, verifier) pair will produce meaningful + evaluation signal. Optimized for very low false positive rate — only + rejects tasks that are near-certain to waste harness compute. + + Checks: + 1. Phantom tables: verifier references tables not in env schema + 2. Undefined references: calls to functions/constants not defined + 3. Vacuous checks: verifier only checks user existence or len>0 + 4. Prompt-verifier misalignment: verifier checks outcomes the + prompt never asked for (e.g., write checks for read prompts) """ if not self.judge_model or not self.openrouter_api_key: return 1.0 # No judge configured, pass through - # Build concise tool list for context + # Build context for the classifier tool_names = [t for t in self.env_tools if t != "computer"] tools_str = ", ".join(tool_names[:20]) if tool_names else "none discovered" + schema_block = self.env_schema if self.env_schema else "Schema not available." + judge_prompt = ( - f'Evaluate this task for the "{self.env_key}" environment.\n\n' + "You are a pre-filter that decides whether to run an expensive evaluation harness " + "on a generated (prompt, verifier) pair. Your goal: reject tasks that will CERTAINLY " + "produce zero useful signal, while accepting anything that MIGHT work.\n\n" + "ACCEPT unless you find a clear, concrete defect from this list:\n\n" + "1. PHANTOM TABLES — The verifier calls .table(\"X\") where X does not exist in the " + "schema below. If the table doesn't exist, the verifier crashes → zero signal.\n\n" + "2. UNDEFINED REFERENCES — The verifier calls functions (e.g., find_new_entries()) or " + "uses constants (e.g., TASK_FAILED_SCORE) that are not defined in the code and are " + "not Python builtins. These cause NameError → zero signal.\n\n" + "3. VACUOUS CHECKS — The verifier's ONLY checks are whether a user/account exists or " + "whether any table has rows (len > 0), without validating any task-specific outcome. " + "These always pass regardless of agent behavior → zero variance → zero signal.\n\n" + "4. PROMPT-VERIFIER MISMATCH — The prompt asks for read-only operations (search, list, " + "view, find, get) but the verifier checks for write operations (new rows in tables, " + "saved items, created records). Read operations don't create DB entries, so the " + "verifier will find nothing → zero signal.\n\n" + "IMPORTANT: When in doubt, ACCEPT. A false reject wastes a potentially good task. " + "A false accept only wastes one harness call. Err heavily toward ACCEPT.\n\n" + f'--- Environment: "{self.env_key}" ---\n\n' f"Available tools: {tools_str}\n\n" - f"Task prompt:\n{prompt}\n\n" - f"Verifier code:\n```python\n{verifier}\n```\n\n" - "A valid task must:\n" - "1. Have a clear, specific prompt describing what an agent should do\n" - "2. Have a verifier that checks the correct outcome via the DB API " - '(env.db("seed"), env.db("current"), .table().eq().all())\n' - "3. The verifier must check what the prompt actually asks\n" - "4. The prompt must not leak the answer or expected values\n" - "5. The verifier must return 0.0 on a fresh env (before agent acts)\n\n" - "Answer with exactly one word: VALID or INVALID" + f"Database schema (these are the ONLY valid tables):\n```\n{schema_block}\n```\n\n" + f"--- Generated task ---\n\n" + f"Prompt:\n{prompt}\n\n" + f"Verifier:\n```python\n{verifier}\n```\n\n" + "Answer with exactly one word: ACCEPT or REJECT" ) try: @@ -586,11 +609,14 @@ def _judge_task(self, prompt: str, verifier: str) -> float: api_key=self.openrouter_api_key, ) answer = response.choices[0].message.content.strip().upper() - is_valid = "VALID" in answer and "INVALID" not in answer - logger.info(f"LLM judge [{self.env_key}]: {answer} -> {'VALID' if is_valid else 'INVALID'}") - return 1.0 if is_valid else 0.0 + accepted = "ACCEPT" in answer and "REJECT" not in answer + logger.info( + f"LLM classifier [{self.env_key}]: {answer} -> " + f"{'ACCEPT' if accepted else 'REJECT'}" + ) + return 1.0 if accepted else 0.0 except Exception as e: - logger.warning(f"LLM judge failed, defaulting to valid: {e}") + logger.warning(f"LLM classifier failed, defaulting to accept: {e}") return 1.0 @staticmethod diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index be05736c8e..c6f41bf7bd 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -37,7 +37,8 @@ envs: MAX_INPUT_LENGTH: 65536 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 20 - JUDGE_MODEL: "anthropic/claude-sonnet-4.5" + JUDGE_MODEL: "anthropic/claude-3.5-haiku" + OPENROUTER_API_KEY: "" EVALUATOR_MODEL: "anthropic/claude-sonnet-4.5" K_ROLLOUTS: 4 ALPHA: "1.0" From f653a30e7cf3d0aa901e760c31ec6d1b47b9fabc Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 21:22:42 -0700 Subject: [PATCH 073/121] fix: remove read-write mismatch check, use Sonnet 4.5 for classifier Read-write mismatch is too subjective and risks false positives. Classifier now checks only: phantom tables, undefined refs, vacuous checks. Switch judge model from Haiku to Sonnet 4.5. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py | 6 ------ tasks/task-gen-grpo-qwen3_5-9b.yaml | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index c6ab178b5d..61b0d1b874 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -558,8 +558,6 @@ def _judge_task(self, prompt: str, verifier: str) -> float: 1. Phantom tables: verifier references tables not in env schema 2. Undefined references: calls to functions/constants not defined 3. Vacuous checks: verifier only checks user existence or len>0 - 4. Prompt-verifier misalignment: verifier checks outcomes the - prompt never asked for (e.g., write checks for read prompts) """ if not self.judge_model or not self.openrouter_api_key: return 1.0 # No judge configured, pass through @@ -583,10 +581,6 @@ def _judge_task(self, prompt: str, verifier: str) -> float: "3. VACUOUS CHECKS — The verifier's ONLY checks are whether a user/account exists or " "whether any table has rows (len > 0), without validating any task-specific outcome. " "These always pass regardless of agent behavior → zero variance → zero signal.\n\n" - "4. PROMPT-VERIFIER MISMATCH — The prompt asks for read-only operations (search, list, " - "view, find, get) but the verifier checks for write operations (new rows in tables, " - "saved items, created records). Read operations don't create DB entries, so the " - "verifier will find nothing → zero signal.\n\n" "IMPORTANT: When in doubt, ACCEPT. A false reject wastes a potentially good task. " "A false accept only wastes one harness call. Err heavily toward ACCEPT.\n\n" f'--- Environment: "{self.env_key}" ---\n\n' diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index c6f41bf7bd..30d2085491 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -37,7 +37,7 @@ envs: MAX_INPUT_LENGTH: 65536 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 20 - JUDGE_MODEL: "anthropic/claude-3.5-haiku" + JUDGE_MODEL: "anthropic/claude-sonnet-4.5" OPENROUTER_API_KEY: "" EVALUATOR_MODEL: "anthropic/claude-sonnet-4.5" K_ROLLOUTS: 4 From 4fa377c03a82d0451e55d5a1700dcf66a88b38f8 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 22:46:44 -0700 Subject: [PATCH 074/121] chore: update workdir.ref to main for task-gen YAML Co-Authored-By: Claude Opus 4.6 --- tasks/task-gen-grpo-qwen3_5-9b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index 30d2085491..40382beba9 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -24,7 +24,7 @@ num_nodes: 1 workdir: url: https://github.com/fleet-ai/SkyRL-v2.git - ref: fleet/all + ref: main envs: WANDB_API_KEY: "" From cf0098ccff6a6cc6ae580955b75e9ce879e1af99 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 30 Mar 2026 22:53:47 -0700 Subject: [PATCH 075/121] chore: update workdir.ref to main for VL and 35B YAMLs Co-Authored-By: Claude Opus 4.6 --- tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml index 02067cf095..6666663d95 100644 --- a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -33,7 +33,7 @@ num_nodes: 2 workdir: url: https://github.com/fleet-ai/SkyRL-v2.git - ref: fleet/all + ref: main envs: WANDB_API_KEY: "" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 04823bb0ea..f85700a488 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -32,7 +32,7 @@ num_nodes: 1 workdir: url: https://github.com/fleet-ai/SkyRL-v2.git - ref: fleet/all + ref: main envs: WANDB_API_KEY: "" From 9990be60cbd9d015cd05af5800448fe882d0d417 Mon Sep 17 00:00:00 2001 From: Deniz Date: Tue, 31 Mar 2026 20:03:44 -0700 Subject: [PATCH 076/121] =?UTF-8?q?feat(task-gen):=20v4=20reward=20hacking?= =?UTF-8?q?=20fixes=20=E2=80=94=20judge,=20exploration,=20schema=20(#8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four changes to address v3 reward hacking (89% keyword-only verifiers): 1. Enhanced LLM judge prompt: Replaces lenient pre-filter (93.8% false positive rate) with verifier rigor classification. ACCEPT only for DB-grounded verifiers (mutation diff, DB-queried answer validation, specific record lookup). REJECT keyword-only, prompt-echo, dead-code DB, cargo-cult, phantom tables, undefined refs. Backtested on 6,409 v3 trajectories: PASS 4.7%, FAIL 95.3%. 2. AST node limit 500 → 700: Unblocks outlook verifiers (avg 414 nodes, 336 rejected at old limit) without accepting degenerate verifiers. 3. Exploration enforcement: Gates submission on called_describe_db when max_turns > 1. Forces minimum 2 turns (describe_db → submit) so model sees actual schema before generating verifier. 4. Auto-populate env_schema: Calls describe_db("seed") during init_async() when env_schema is empty (all current datasets). Ensures judge prompt and system prompt always have the real schema for phantom table detection. Co-authored-by: Deniz Co-authored-by: Claude Opus 4.6 --- .../skyrl_gym/envs/task_gen/task_gen_env.py | 113 +++++++++++++++--- .../envs/task_gen/verifier_sandbox.py | 2 +- 2 files changed, 97 insertions(+), 18 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 61b0d1b874..f26dddb832 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -569,24 +569,58 @@ def _judge_task(self, prompt: str, verifier: str) -> float: schema_block = self.env_schema if self.env_schema else "Schema not available." judge_prompt = ( - "You are a pre-filter that decides whether to run an expensive evaluation harness " - "on a generated (prompt, verifier) pair. Your goal: reject tasks that will CERTAINLY " - "produce zero useful signal, while accepting anything that MIGHT work.\n\n" - "ACCEPT unless you find a clear, concrete defect from this list:\n\n" - "1. PHANTOM TABLES — The verifier calls .table(\"X\") where X does not exist in the " - "schema below. If the table doesn't exist, the verifier crashes → zero signal.\n\n" - "2. UNDEFINED REFERENCES — The verifier calls functions (e.g., find_new_entries()) or " - "uses constants (e.g., TASK_FAILED_SCORE) that are not defined in the code and are " - "not Python builtins. These cause NameError → zero signal.\n\n" - "3. VACUOUS CHECKS — The verifier's ONLY checks are whether a user/account exists or " - "whether any table has rows (len > 0), without validating any task-specific outcome. " - "These always pass regardless of agent behavior → zero variance → zero signal.\n\n" - "IMPORTANT: When in doubt, ACCEPT. A false reject wastes a potentially good task. " - "A false accept only wastes one harness call. Err heavily toward ACCEPT.\n\n" - f'--- Environment: "{self.env_key}" ---\n\n' + "You are a verifier quality judge for an AI task-generation system. You evaluate " + "whether a generated verifier function can reliably determine if an AI agent " + "correctly completed a task.\n\n" + "## Context\n\n" + "The verifier has access to:\n" + "- `env.db(\"seed\")` — database state BEFORE the agent acted\n" + "- `env.db(\"current\")` — database state AFTER the agent acted\n" + "- `final_answer` — the agent's text response\n" + "- DB query methods: `.table(name)`, `.eq(col, val)`, `.first()`, `.all()`, " + "`.select()`, `.neq()`, `.gt()`, `.lt()`\n\n" + f"Database schema (valid tables and columns):\n```\n{schema_block}\n```\n\n" + f'Environment: "{self.env_key}"\n' f"Available tools: {tools_str}\n\n" - f"Database schema (these are the ONLY valid tables):\n```\n{schema_block}\n```\n\n" - f"--- Generated task ---\n\n" + "## Classification Criteria\n\n" + "### ACCEPT if the verifier does ANY of:\n\n" + "1. **Mutation verification**: Compares seed vs current database state to detect " + "that the agent created, modified, or deleted records.\n\n" + "2. **DB-grounded answer validation**: Queries the database for specific records " + "and validates that values FROM those records appear in `final_answer`. The " + "expected values must come from the database, not from hardcoded strings or " + "the task prompt.\n\n" + "3. **Specific record validation**: Looks up a record by ID or unique field and " + "checks its field values match expected values.\n\n" + "### REJECT if the verifier does ANY of:\n\n" + "1. **Generic keyword checking**: Checks if generic category words appear in " + "`final_answer` (e.g., \"event\", \"venue\", \"concert\", \"price\", \"bedroom\", " + "\"listing\"). These words appear in any topically-relevant response regardless " + "of task completion.\n\n" + "2. **Prompt echo checking**: Checks if values from the task prompt appear in " + "`final_answer` (e.g., \"Los Angeles\" when the prompt asked about LA events). " + "The agent could echo prompt values without doing any work.\n\n" + "3. **Exists-check-only**: Only checks `final_answer is not None` or " + "`len(answer) > 0`.\n\n" + "4. **Dead code DB queries**: Has `seed.table()` or `current.table()` calls but " + "never uses the query results in conditional logic that affects the return value.\n\n" + "5. **Nonexistent API access**: References `env.instance.tool_calls`, " + "`get_call_history()`, or `env.call_tool()` — these don't exist in the verifier " + "runtime.\n\n" + "6. **Cargo-cult DB**: Queries the DB only for user/account existence (which always " + "passes for pre-existing entities), then gates on keyword checks for actual " + "validation.\n\n" + "7. **Phantom tables**: The verifier calls `.table(\"X\")` where X does not exist " + "in the schema above.\n\n" + "8. **Undefined references**: The verifier calls functions or uses constants that " + "are not defined in the code and are not Python builtins.\n\n" + "### Edge Cases:\n\n" + "- Read-only tasks with DB-grounded keywords: ACCEPT — if the verifier queries a " + "DB table to get specific values then checks those values appear in `final_answer`.\n" + "- JSON structure validation without DB cross-reference: REJECT.\n" + "- Existence checks on initially-empty tables (e.g., orders after \"place order\"): " + "weak ACCEPT.\n\n" + f"## Generated Task\n\n" f"Prompt:\n{prompt}\n\n" f"Verifier:\n```python\n{verifier}\n```\n\n" "Answer with exactly one word: ACCEPT or REJECT" @@ -1085,6 +1119,29 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: # 1. Check for block → evaluation pipeline if "" in action: + # Gate: model must call describe_db before submitting a task. + # This forces at least 2 turns (describe_db → submit) and ensures + # the model has seen the actual schema before generating a verifier. + if not self.called_describe_db and self.max_turns > 1: + remaining = self.max_turns - self.turns + nudge = ( + "You must call `describe_db` to see the database schema before submitting a task. " + "Use `describe_db` first, then explore with `query_db`, and finally submit your `` block." + ) + if remaining <= 1: + # Last turn and never explored — game over + return BaseTextEnvStepOutput( + observations=[], + reward=0.0, + done=True, + metadata={"env_key": self.env_key, "turn": self.turns, "done_reason": "no_exploration"}, + ) + return BaseTextEnvStepOutput( + observations=[{"role": "user", "content": nudge}], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns, "rejected": "no_describe_db"}, + ) return await self._handle_task_generation(action) # 2. Check for tool calls → execute all via Fleet orchestrator or MCP @@ -1234,6 +1291,28 @@ async def init_async(self, prompt: ConversationType) -> Tuple[ConversationType, await self.orch._fleet_env.instance.load() logger.info(f"TaskGenEnv [{self.env_key}]: Fleet env provisioned for DB + tool exploration") + # Auto-populate env_schema from describe_db if not provided in dataset. + # This ensures the judge prompt and system prompt always have the real schema. + if not self.env_schema: + try: + schema_result = await self.orch.describe_db_async(db_name="seed") + if isinstance(schema_result, dict): + # Format as compact "table: col1, col2, ..." lines + lines = [] + for table_name, columns in schema_result.items(): + if isinstance(columns, list): + col_names = ", ".join(str(c) for c in columns) + else: + col_names = str(columns) + lines.append(f"{table_name}: {col_names}") + self.env_schema = "\n".join(lines) + elif isinstance(schema_result, str): + self.env_schema = schema_result + if self.env_schema: + logger.info(f"TaskGenEnv [{self.env_key}]: Auto-populated env_schema from describe_db") + except Exception as e: + logger.warning(f"TaskGenEnv [{self.env_key}]: Failed to auto-populate env_schema: {e}") + # Discover MCP tools so the model can call them if self.mcp_tools: try: diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py index 5da7cd3093..a173404518 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py @@ -66,7 +66,7 @@ def score(self) -> float: # Min/max AST node count for verifier complexity MIN_AST_NODES = 5 # reject trivial verifiers like `return 1.0` -MAX_AST_NODES = 500 # reject overly complex verifiers +MAX_AST_NODES = 700 # reject overly complex verifiers class VerifierSandbox: From 690aad7bd2798d242f537a287587987bb0606333 Mon Sep 17 00:00:00 2001 From: Deniz Date: Tue, 31 Mar 2026 22:02:32 -0700 Subject: [PATCH 077/121] fix: update per-env launch script for v4 - Use TASK_GEN_ENV_CLASSES (not ENV_KEYS) to match run script - Add OPENROUTER_API_KEY (required for LLM judge) - Default to ticketmaster/zillow/outlook (not all 8) - Accept envs as CLI args Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-task-gen-launch-per-env.sh | 35 +++++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/scripts/fleet-task-gen-launch-per-env.sh b/scripts/fleet-task-gen-launch-per-env.sh index ee1e21c39d..b03127a94a 100755 --- a/scripts/fleet-task-gen-launch-per-env.sh +++ b/scripts/fleet-task-gen-launch-per-env.sh @@ -1,6 +1,13 @@ #!/usr/bin/env bash -# Launch per-env task-gen experiments — one SkyPilot cluster per environment +# Launch per-env task-gen experiments — one SkyPilot cluster per environment. # Targets ~40 training steps per env by computing NUM_EPOCHS from seed counts. +# +# Usage: +# export FLEET_API_KEY=... WANDB_API_KEY=... OPENROUTER_API_KEY=... +# export AWS_ACCESS_KEY_ID=... AWS_SECRET_ACCESS_KEY=... +# bash scripts/fleet-task-gen-launch-per-env.sh [env1 env2 ...] +# +# If no envs specified, defaults to: ticketmaster zillow outlook set -euo pipefail YAML="tasks/task-gen-grpo-qwen3_5-9b.yaml" @@ -8,28 +15,42 @@ EVAL_RATIO="0.05" TARGET_STEPS=40 BATCH_SIZE=12 +# Required env vars +: "${FLEET_API_KEY:?set FLEET_API_KEY}" +: "${WANDB_API_KEY:?set WANDB_API_KEY}" +: "${OPENROUTER_API_KEY:?set OPENROUTER_API_KEY}" +: "${AWS_ACCESS_KEY_ID:?set AWS_ACCESS_KEY_ID}" +: "${AWS_SECRET_ACCESS_KEY:?set AWS_SECRET_ACCESS_KEY}" + # Seed counts per env from v55 dataset (after EVAL_RATIO=0.05 split) declare -A SEEDS=( [booking]=539 [budget]=567 [carlisle]=336 [outlook]=181 [reddit]=505 [rops-mail]=44 [ticketmaster]=212 [zillow]=106 ) -for env in "${!SEEDS[@]}"; do - seeds=${SEEDS[$env]} +# Default envs if none specified on command line +ENVS=("${@:-ticketmaster zillow outlook}") +if [[ $# -eq 0 ]]; then + ENVS=(ticketmaster zillow outlook) +fi + +for env in "${ENVS[@]}"; do + seeds=${SEEDS[$env]:-100} steps_per_epoch=$(( (seeds + BATCH_SIZE - 1) / BATCH_SIZE )) num_epochs=$(( (TARGET_STEPS + steps_per_epoch - 1) / steps_per_epoch )) total_steps=$(( steps_per_epoch * num_epochs )) - echo "Launching task-gen-${env}: ${seeds} seeds, ${steps_per_epoch} steps/epoch, ${num_epochs} epochs (${total_steps} steps)" + echo "=== Launching task-gen-${env}: ${seeds} seeds, ${steps_per_epoch} steps/epoch, ${num_epochs} epochs (${total_steps} steps) ===" sky launch -c "task-gen-${env}" "$YAML" \ - --env ENV_KEYS="$env" \ - --env EVAL_RATIO="$EVAL_RATIO" \ + --env TASK_GEN_ENV_CLASSES="$env" \ --env NUM_EPOCHS="$num_epochs" \ --env FLEET_API_KEY="$FLEET_API_KEY" \ --env WANDB_API_KEY="$WANDB_API_KEY" \ + --env OPENROUTER_API_KEY="$OPENROUTER_API_KEY" \ --env AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \ --env AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ --yes --async done -echo "All 8 clusters launched. Monitor with: sky status" +echo "" +echo "Launched ${#ENVS[@]} clusters. Monitor with: sky status" From ffa523874e0d1836e079a4f458e67acc40d1d94a Mon Sep 17 00:00:00 2001 From: Deniz Date: Tue, 31 Mar 2026 22:07:25 -0700 Subject: [PATCH 078/121] fix: use case statement for seed counts (bash compat) declare -A breaks with set -u on some shells. Use case/esac instead. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-task-gen-launch-per-env.sh | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/scripts/fleet-task-gen-launch-per-env.sh b/scripts/fleet-task-gen-launch-per-env.sh index b03127a94a..a1d22970f8 100755 --- a/scripts/fleet-task-gen-launch-per-env.sh +++ b/scripts/fleet-task-gen-launch-per-env.sh @@ -8,10 +8,9 @@ # bash scripts/fleet-task-gen-launch-per-env.sh [env1 env2 ...] # # If no envs specified, defaults to: ticketmaster zillow outlook -set -euo pipefail +set -eo pipefail YAML="tasks/task-gen-grpo-qwen3_5-9b.yaml" -EVAL_RATIO="0.05" TARGET_STEPS=40 BATCH_SIZE=12 @@ -23,19 +22,23 @@ BATCH_SIZE=12 : "${AWS_SECRET_ACCESS_KEY:?set AWS_SECRET_ACCESS_KEY}" # Seed counts per env from v55 dataset (after EVAL_RATIO=0.05 split) -declare -A SEEDS=( - [booking]=539 [budget]=567 [carlisle]=336 [outlook]=181 - [reddit]=505 [rops-mail]=44 [ticketmaster]=212 [zillow]=106 -) +get_seeds() { + case "$1" in + booking) echo 539 ;; budget) echo 567 ;; carlisle) echo 336 ;; + outlook) echo 181 ;; reddit) echo 505 ;; rops-mail) echo 44 ;; + ticketmaster) echo 212 ;; zillow) echo 106 ;; *) echo 100 ;; + esac +} # Default envs if none specified on command line -ENVS=("${@:-ticketmaster zillow outlook}") -if [[ $# -eq 0 ]]; then +if [[ $# -gt 0 ]]; then + ENVS=("$@") +else ENVS=(ticketmaster zillow outlook) fi for env in "${ENVS[@]}"; do - seeds=${SEEDS[$env]:-100} + seeds=$(get_seeds "$env") steps_per_epoch=$(( (seeds + BATCH_SIZE - 1) / BATCH_SIZE )) num_epochs=$(( (TARGET_STEPS + steps_per_epoch - 1) / steps_per_epoch )) total_steps=$(( steps_per_epoch * num_epochs )) From b967f43c17ac75ba2fd8509007e10964b2d30a10 Mon Sep 17 00:00:00 2001 From: Deniz Date: Wed, 1 Apr 2026 14:56:52 -0700 Subject: [PATCH 079/121] Reduce VL max_input_length from 128K to 96K to prevent OOM Step 16 OOM'd during forward_backward with 128K context + VL screenshots. GPU 0 had 133.4 GiB used, only 436 MiB free, needed 2.32 GiB. 96K matches the 35B run's approach for avoiding OOM without expandable_segments. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index 54da097cd2..d12d96e210 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -21,7 +21,7 @@ export DATA_VERSION="${DATA_VERSION:-v52}" export MODALITY="${MODALITY:-computer_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" export MAX_TURNS="${MAX_TURNS:-50}" -export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-131072}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" export DIFFICULTY="${DIFFICULTY:-}" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index f85700a488..f9e71223a0 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -44,7 +44,7 @@ envs: DIFFICULTY: "" MODALITY: "computer_use" MAX_TURNS: 50 - MAX_INPUT_LENGTH: 131072 + MAX_INPUT_LENGTH: 96000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 RUN_ID: "" From 519293dd8cf3e32b38fe2064beb3723ba19913cb Mon Sep 17 00:00:00 2001 From: Deniz Date: Wed, 1 Apr 2026 22:52:24 -0700 Subject: [PATCH 080/121] Compact schema + remove describe_db tool MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add _format_compact_schema(): "table: col (type), ..." format instead of raw describe_db dump (152K → ~10K for zillow) - Remove describe_db from _META_TOOLS — schema is in the prompt - Remove describe_db exploration gate — was causing context starvation for large-schema envs (zillow 82 tables, outlook 62 tables) - Update system prompt: workflow starts with query_db Co-Authored-By: Claude Opus 4.6 --- .../skyrl_gym/envs/task_gen/task_gen_env.py | 128 +++++++----------- 1 file changed, 50 insertions(+), 78 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index f26dddb832..a8fd67b0e0 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -2,7 +2,7 @@ Task Generation Environment for SkyRL. Multi-turn BaseTextEnv where the LLM can explore the seed database via -``describe_db`` / ``query_db`` meta-tools before generating a task. +``query_db`` meta-tool before generating a task. When ``max_turns > 1`` (the default), the model explores the DB first and then produces a ```` block. When ``max_turns == 1`` it @@ -44,8 +44,28 @@ logger = logging.getLogger(__name__) +def _format_compact_schema(describe_result: Any) -> str: + """Convert a DescribeResponse dict to compact 'table: col (type), ...' format.""" + if not isinstance(describe_result, dict): + return str(describe_result) if describe_result else "" + tables = describe_result.get("tables") + if not tables or not isinstance(tables, list): + return "" + lines = [] + for t in tables: + name = t.get("name", "") + cols = t.get("columns", []) + col_parts = [] + for c in cols: + col_name = c.get("name", "") + col_type = c.get("type", "").lower() + col_parts.append(f"{col_name} ({col_type})" if col_type else col_name) + lines.append(f"{name}: {', '.join(col_parts)}") + return "\n".join(lines) + + # Meta-tools the model can call to explore the seed database. -_META_TOOLS = {"describe_db", "query_db"} +_META_TOOLS = {"query_db"} # All callable tools = meta-tools + any MCP env tools discovered at init time. # Populated per-instance in init_async(). @@ -55,8 +75,8 @@ class TaskGenEnv(BaseTextEnv): """Environment for RL-based task generation. The LLM generates (prompt, verifier) pairs for Fleet environments. - Supports multi-turn: the model can explore the seed DB via ``describe_db`` - and ``query_db`` meta-tools before outputting a ```` block. + Supports multi-turn: the model can explore the seed DB via ``query_db`` + meta-tool before outputting a ```` block. Schema is in the prompt. Reward = llm_validity * (alpha * var(raw_scores) + (p_hint - p_raw)) @@ -94,7 +114,6 @@ def __init__( # Set of all callable tool names (meta-tools + MCP tools) self.callable_tools = set(_META_TOOLS) # Exploration sequence tracking (reset in init_async) - self.called_describe_db = False self.called_query_db = False # Environment context from dataset (extras) @@ -191,7 +210,7 @@ def __init__( # Provides GRPO gradient signal even when all harness evals return 0. self.base_quality_reward = float(env_config.get("base_quality_reward", 0.1)) if env_config else 0.1 - # Small per-tool-call reward to incentivize DB exploration (describe_db, query_db). + # Small per-tool-call reward to incentivize DB exploration (query_db). # Default 0.0 = off (no behavior change for existing runs). self.tool_call_reward_per_call = float(env_config.get("tool_call_reward_per_call", 0.0)) if env_config else 0.0 @@ -457,7 +476,7 @@ def find_new_entries(table_name, id_field="id", filter_conditions=None): FIX option 1: Make the prompt specific: "Find the designer in Mexico City who joined after 2023" FIX option 2: Make the verifier accept all valid answers: check that ANY designer in Mexico is returned -Use `describe_db`/`query_db` to check the actual data before writing the prompt. If a query returns multiple rows, either narrow the prompt or widen the verifier. Always verify your assumptions by querying — don't guess. You MUST call all three of `describe_db`, `query_db`, and at least one environment API tool before writing the task — your task will be rejected otherwise. +Use `query_db` to check the actual data before writing the prompt. If a query returns multiple rows, either narrow the prompt or widen the verifier. Always verify your assumptions by querying — don't guess. ### Avoiding Overspecification A prompt is overspecified when it dictates HOW to accomplish the task rather than WHAT outcome is needed. This makes the task trivially easy (no learning signal) and doesn't test real problem-solving. @@ -491,34 +510,26 @@ def find_new_entries(table_name, id_field="id", filter_conditions=None): """ ## Exploration Tools -Before generating a task, explore the environment to understand the actual data and API behavior. +The database schema is provided above. Use `query_db` to inspect actual data and environment tools to understand API behavior. ### Database Tools -{"name": "describe_db", "arguments": {}} -Returns the full schema: table names, columns, types. - {"name": "query_db", "arguments": {"sql": "SELECT * FROM table_name LIMIT 5"}} Runs a read-only SQL query against the seed database. ### Environment Tools -You MUST call at least one of the environment's API tools listed above to understand their input/output formats. - -**REQUIRED before generating a task:** You must call ALL THREE of: (1) `describe_db`, (2) `query_db`, and (3) at least one environment API tool. Your task will be rejected if any are missing. - {"name": "tool_name", "arguments": {"param": "value"}} Calls the tool and returns its result. Use this to understand input/output formats. ### Workflow -1. **Explore**: Call `describe_db` to see all tables and columns. -2. **Inspect data**: Call `query_db` with SELECT queries to inspect real data (values, ranges, row counts, patterns). -3. **Try tools**: Call at least one environment API tool to understand its behavior, input/output formats, and edge cases. -4. **Draft a task idea**: Think about what prompt + verifier you could write based on the data you've seen. -5. **Validate your draft**: Before outputting the task, run `query_db` to verify your assumptions: +1. **Inspect data**: Call `query_db` with SELECT queries to inspect real data (values, ranges, row counts, patterns). The schema above shows table and column names. +2. **Try tools**: Call environment API tools to understand their behavior, input/output formats, and edge cases. +3. **Draft a task idea**: Think about what prompt + verifier you could write based on the data you've seen. +4. **Validate your draft**: Before outputting the task, run `query_db` to verify your assumptions: - Does the data your prompt references actually exist? (e.g., "Update Jamie's email" — is there a Jamie?) - Will the verifier return 0.0 on a fresh DB? (Check seed state) - Are there edge cases? (e.g., multiple matches, null values, empty tables) -6. **Iterate**: If your queries reveal problems (wrong assumptions, ambiguous data, too many/few matches), revise your task idea and verify again. Do NOT output the task until you've confirmed the data supports it. -7. **Output**: Only when confident, output the final task in the format below.""" +5. **Iterate**: If your queries reveal problems (wrong assumptions, ambiguous data, too many/few matches), revise your task idea and verify again. Do NOT output the task until you've confirmed the data supports it. +6. **Output**: Only when confident, output the final task in the format below.""" ) # --- D. Few-shot examples removed --- @@ -1080,7 +1091,7 @@ async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: # 5. R = base_quality + tool_call_reward + eval_signal # base_quality: small reward for passing sandbox+judge (structural validity) - # tool_call_reward: incentivize DB exploration (describe_db, query_db) + # tool_call_reward: incentivize DB exploration (query_db) # eval_signal: judge_gate * compute_task_reward (harness-based quality) # This prevents GRPO zero-signal deadlock when all harness evals fail. base_quality = self.base_quality_reward @@ -1110,7 +1121,7 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: Multi-turn flow: 1. block detected → evaluation pipeline (done=True) - 2. detected → execute describe_db/query_db (done=False) + 2. detected → execute query_db/MCP tools (done=False) 3. Neither → nudge observation (done=False) 4. max_turns reached → done=True, reward=0 """ @@ -1119,29 +1130,6 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: # 1. Check for block → evaluation pipeline if "" in action: - # Gate: model must call describe_db before submitting a task. - # This forces at least 2 turns (describe_db → submit) and ensures - # the model has seen the actual schema before generating a verifier. - if not self.called_describe_db and self.max_turns > 1: - remaining = self.max_turns - self.turns - nudge = ( - "You must call `describe_db` to see the database schema before submitting a task. " - "Use `describe_db` first, then explore with `query_db`, and finally submit your `` block." - ) - if remaining <= 1: - # Last turn and never explored — game over - return BaseTextEnvStepOutput( - observations=[], - reward=0.0, - done=True, - metadata={"env_key": self.env_key, "turn": self.turns, "done_reason": "no_exploration"}, - ) - return BaseTextEnvStepOutput( - observations=[{"role": "user", "content": nudge}], - reward=0.0, - done=False, - metadata={"env_key": self.env_key, "turn": self.turns, "rejected": "no_describe_db"}, - ) return await self._handle_task_generation(action) # 2. Check for tool calls → execute all via Fleet orchestrator or MCP @@ -1152,9 +1140,7 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: for tc in tool_calls: if tc["name"] in _META_TOOLS: self.meta_tool_calls += 1 - if tc["name"] == "describe_db": - self.called_describe_db = True - elif tc["name"] == "query_db": + if tc["name"] == "query_db": self.called_query_db = True result = await self._execute_meta_tool(tc) else: @@ -1211,26 +1197,24 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: ) async def _execute_meta_tool(self, tool_call: Dict[str, Any]) -> str: - """Execute a describe_db or query_db meta-tool call via the Fleet orchestrator.""" + """Execute a query_db meta-tool call via the Fleet orchestrator.""" name = tool_call["name"] args = tool_call.get("arguments", {}) if self.orch is None: return "Error: Fleet environment not provisioned. Generate a directly." + if name != "query_db": + return f"Error: Unknown meta-tool '{name}'." + + sql = args.get("sql", "") + if not sql: + return "Error: query_db requires a 'sql' argument." + max_retries = 3 for attempt in range(max_retries): try: - if name == "describe_db": - result = await self.orch.describe_db_async(db_name=args.get("db_name", "seed")) - elif name == "query_db": - sql = args.get("sql", "") - if not sql: - return "Error: query_db requires a 'sql' argument." - result = await self.orch.query_db_async(sql=sql, db_name=args.get("db_name", "seed")) - else: - return f"Error: Unknown meta-tool '{name}'." - + result = await self.orch.query_db_async(sql=sql, db_name=args.get("db_name", "seed")) if isinstance(result, dict): return f"Tool result:\n{json.dumps(result, indent=2, default=str)}" return f"Tool result:\n{result}" @@ -1246,7 +1230,7 @@ async def _execute_mcp_tool(self, tool_call: Dict[str, Any]) -> str: args = tool_call.get("arguments", {}) if self.mcp_tools is None: - return "Error: MCP tools not available. Use describe_db/query_db or generate a ." + return "Error: MCP tools not available. Use query_db or generate a ." try: result = await self.mcp_tools.call_tool(name, args) @@ -1261,13 +1245,12 @@ async def init_async(self, prompt: ConversationType) -> Tuple[ConversationType, When ``max_turns > 1``, provisions a Fleet environment via ``FleetEnvClient.from_fleet_async`` so the model can call - ``describe_db`` / ``query_db`` during exploration turns. + ``query_db`` during exploration turns. Falls back to single-turn if provisioning fails. """ self.turns = 0 self.meta_tool_calls = 0 self.mcp_tool_calls = 0 - self.called_describe_db = False self.called_query_db = False self.orch = None self.mcp_tools = None @@ -1292,24 +1275,13 @@ async def init_async(self, prompt: ConversationType) -> Tuple[ConversationType, logger.info(f"TaskGenEnv [{self.env_key}]: Fleet env provisioned for DB + tool exploration") # Auto-populate env_schema from describe_db if not provided in dataset. - # This ensures the judge prompt and system prompt always have the real schema. + # Compact format: "table: col1 (type), col2 (type), ..." — one line per table. if not self.env_schema: try: schema_result = await self.orch.describe_db_async(db_name="seed") - if isinstance(schema_result, dict): - # Format as compact "table: col1, col2, ..." lines - lines = [] - for table_name, columns in schema_result.items(): - if isinstance(columns, list): - col_names = ", ".join(str(c) for c in columns) - else: - col_names = str(columns) - lines.append(f"{table_name}: {col_names}") - self.env_schema = "\n".join(lines) - elif isinstance(schema_result, str): - self.env_schema = schema_result + self.env_schema = _format_compact_schema(schema_result) if self.env_schema: - logger.info(f"TaskGenEnv [{self.env_key}]: Auto-populated env_schema from describe_db") + logger.info(f"TaskGenEnv [{self.env_key}]: Auto-populated env_schema ({len(self.env_schema)} chars)") except Exception as e: logger.warning(f"TaskGenEnv [{self.env_key}]: Failed to auto-populate env_schema: {e}") From c6773a398a0656443525b16182237425d65b0ab8 Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 3 Apr 2026 11:05:44 -0700 Subject: [PATCH 081/121] =?UTF-8?q?feat(task-gen):=20verifier=20hardening?= =?UTF-8?q?=20=E2=80=94=20exploration=20gate,=20anti-permissiveness,=20unf?= =?UTF-8?q?iltered=20.all()=20rejection=20(#14)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adapt battle-tested patterns from orchestrator verifier into task-gen prompt and sandbox to fix v4.1 single-turn collapse and degenerate verifiers. - Exploration gate: bounce submission if model hasn't called query_db yet (multi-turn mode only), preventing single-turn collapse - Strengthened verifier template: find_new_entries docstring, set-based comparison example (order-independent validation) - Three new Rules: unfiltered .all() prohibition, set-based comparison, anti-permissiveness (must return 0 on unmodified DB) - AST hard fail: sandbox rejects verifiers with .table("X").all() without preceding .eq()/.neq()/.select() filter (prevents warm-pool saturation) Co-authored-by: Deniz Co-authored-by: Claude Opus 4.6 --- .../skyrl_gym/envs/task_gen/task_gen_env.py | 41 ++++++++++++++++ .../envs/task_gen/verifier_sandbox.py | 47 ++++++++++++++++++- 2 files changed, 87 insertions(+), 1 deletion(-) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index a8fd67b0e0..20c8bc5c44 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -410,6 +410,17 @@ def validate_task(env: Environment, final_answer: str | None = None) -> int: current = env.db("current") def find_new_entries(table_name, id_field="id", filter_conditions=None): + \"\"\"Compare seed vs current to find rows added by the agent. + + Args: + table_name: Table to compare. + id_field: Primary key column (default "id"). + filter_conditions: Optional dict of {{column: value}} filters + applied to BOTH seed and current before comparison. + + Returns: + List[dict] — rows present in current but not in seed. + \"\"\" before_query = seed.table(table_name) after_query = current.table(table_name) if filter_conditions: @@ -419,6 +430,15 @@ def find_new_entries(table_name, id_field="id", filter_conditions=None): before_ids = set(entry[id_field] for entry in before_query.select(id_field).all()) return [e for e in after_query.all() if e[id_field] not in before_ids] + # --- Validation: use SET-BASED comparison, never row-index --- + # GOOD: compare by content/ID sets, order-independent + # expected_ids = {{"id_1", "id_2"}} + # actual_ids = {{row["id"] for row in new_entries}} + # if not expected_ids.issubset(actual_ids): ... + # + # BAD: comparing by row index (fragile, order-dependent) + # if new_entries[0]["id"] == "id_1": ... + # Check conditions... # On early failure: if critical_failure: @@ -454,6 +474,9 @@ def find_new_entries(table_name, id_field="id", filter_conditions=None): - Look up the logged-in user by name/email from the users table, don't assume an ID - Compare `seed` (before) vs `current` (after) to detect what the agent did - Must return `TASK_FAILED_SCORE` on a fresh environment (before agent acts) +- **NEVER call `.table("X").all()` without a preceding `.eq()` or `.neq()` filter** — unfiltered `.all()` fetches every row, which is wasteful and causes warm-pool saturation with large tables. Always filter first: `current.table("orders").eq("user_id", uid).all()`. The only exception is inside `find_new_entries` where `.select(id_field).all()` fetches just IDs for comparison +- **Use order-independent (set-based) comparison** — never compare results by row index or list position. Rows may be returned in any order. Use sets: `actual_ids = {{r["id"] for r in rows}}; assert expected_ids.issubset(actual_ids)`. NEVER do `rows[0]["id"] == expected` — it breaks when row order changes +- **Verifier MUST return 0 on unmodified DB** — the verifier must fail when the agent has not acted. Always compare `seed` vs `current` state. A verifier that only checks `current` without comparing to `seed` is permissive — it may return 1 even when the agent did nothing. Pattern: `new_entries = find_new_entries("table"); if not new_entries: return TASK_FAILED_SCORE` - Use `final_answer` for tasks that require the agent to report a value - Reference actual tool names from this environment @@ -1130,6 +1153,24 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: # 1. Check for block → evaluation pipeline if "" in action: + # Exploration gate: in multi-turn mode, bounce back if model hasn't + # called query_db yet and still has turns remaining. Prevents + # single-turn collapse where model skips DB exploration entirely. + if self.max_turns > 1 and not self.called_query_db and not max_turns_reached: + remaining = self.max_turns - self.turns + nudge = ( + "You must explore the database with `query_db` before submitting a task. " + "Use SELECT queries to inspect actual data — table contents, value ranges, " + f"row counts — so your task and verifier are grounded in real data. " + f"You have {remaining} turn(s) remaining." + ) + observation = {"role": "user", "content": nudge} + return BaseTextEnvStepOutput( + observations=[observation], + reward=0.0, + done=False, + metadata={"env_key": self.env_key, "turn": self.turns, "exploration_gate": True}, + ) return await self._handle_task_generation(action) # 2. Check for tool calls → execute all via Fleet orchestrator or MCP diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py index a173404518..9d6c2ace44 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/verifier_sandbox.py @@ -126,7 +126,10 @@ def validate( # 6. Check for hardcoded return values self._check_hardcoded_returns(tree, result) - # 7. Check prompt length bounds (if prompt provided) + # 7. Check for unfiltered .all() calls + self._check_unfiltered_all(tree, result) + + # 8. Check prompt length bounds (if prompt provided) if prompt is not None: self._check_prompt_bounds(prompt, result) @@ -262,6 +265,48 @@ def _check_hardcoded_returns(self, tree: ast.AST, result: ValidationResult): else: result.checks_passed.append("return_logic") + def _check_unfiltered_all(self, tree: ast.AST, result: ValidationResult): + """Reject verifiers that call .table("X").all() without a filter. + + Unfiltered .all() fetches every row from a table, causing warm-pool + saturation with large tables (6.5k zombie verifiers in production). + + Allowed patterns (filter present in chain): + .table("X").eq("col", val).all() + .table("X").neq("col", val).all() + .table("X").select("col1").all() # ID-only in find_new_entries + + Rejected pattern: + .table("X").all() # no filter before .all() + """ + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + # Match .all() call + if not (isinstance(node.func, ast.Attribute) and node.func.attr == "all"): + continue + # Walk up the chain: .all() is called on some object + receiver = node.func.value + # Check if the receiver is a .table() call (direct: .table("X").all()) + if self._is_table_call(receiver): + result.checks_failed.append("unfiltered_all") + result.error = ( + 'Unfiltered .all() on table — use .eq()/.neq()/.select() ' + 'before .all() (e.g., table("X").eq("col", val).all())' + ) + return + + result.checks_passed.append("filtered_all") + + @staticmethod + def _is_table_call(node: ast.AST) -> bool: + """Check if an AST node is a .table("...") call.""" + return ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == "table" + ) + def _check_prompt_bounds(self, prompt: str, result: ValidationResult): """Check that prompt is within reasonable length bounds.""" word_count = len(prompt.split()) From 35efb8e30b66779e7d7d410a78d5d537e48d50dc Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 3 Apr 2026 11:07:46 -0700 Subject: [PATCH 082/121] feat(task-gen): add 35B task-gen YAML and run script (#15) Add Qwen3.5-35B-A3B task generation config: - 2-node (16 GPUs), TP=2, 8 inference engines - flash_attn=false (SDPA), 72K input, chunked lm_head - Task-gen entrypoint with judge, evaluator, and k_rollouts config - Lower LR (5e-7) matching 35B tool-use training Co-authored-by: Deniz Co-authored-by: Claude Opus 4.6 --- scripts/fleet-task-gen-35b-run.sh | 110 +++++++++++++++++++++++++++ tasks/task-gen-grpo-qwen3_5-35b.yaml | 62 +++++++++++++++ 2 files changed, 172 insertions(+) create mode 100755 scripts/fleet-task-gen-35b-run.sh create mode 100644 tasks/task-gen-grpo-qwen3_5-35b.yaml diff --git a/scripts/fleet-task-gen-35b-run.sh b/scripts/fleet-task-gen-35b-run.sh new file mode 100755 index 0000000000..cab355f70e --- /dev/null +++ b/scripts/fleet-task-gen-35b-run.sh @@ -0,0 +1,110 @@ +#!/usr/bin/env bash +# Task-gen specific run for Qwen3.5-35B: calls common run with task-gen entrypoint +# and 35B-specific config (TP=2, flash_attn=false, 72K input, chunked lm_head). +# +# Usage (from SkyPilot YAML run block): +# bash scripts/fleet-task-gen-35b-run.sh +# +# Required env vars: WANDB_API_KEY, FLEET_API_KEY +# SkyPilot env vars: SKYPILOT_NUM_GPUS_PER_NODE, SKYPILOT_NODE_IPS +set -euo pipefail + +# Export RUN_NAME so task_gen_env can tag rollout dumps +export RUN_NAME="task_gen_35b_$(python3 -c 'import os; print(os.urandom(4).hex())')" + +# Defaults for vars normally set by SkyPilot YAML envs block +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export MODALITY="${MODALITY:-tool_use}" +export NUM_EPOCHS="${NUM_EPOCHS:-20}" +export MAX_TURNS="${MAX_TURNS:-10}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-8}" +export JUDGE_MODEL="${JUDGE_MODEL:-anthropic/claude-sonnet-4.5}" +export EVALUATOR_MODEL="${EVALUATOR_MODEL:-anthropic/claude-sonnet-4.5}" +export K_ROLLOUTS="${K_ROLLOUTS:-4}" +export ALPHA="${ALPHA:-1.0}" +export MAX_EVAL_STEPS="${MAX_EVAL_STEPS:-20}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" + +# Optional: per-env dataset filtering via TASK_GEN_ENV_CLASSES env var +ENV_FILTER_ARGS=() +if [ -n "${TASK_GEN_ENV_CLASSES:-}" ]; then + echo "=== env_filter: $TASK_GEN_ENV_CLASSES ===" + ENV_FILTER_ARGS+=("data.env_filter=$TASK_GEN_ENV_CLASSES") +fi + +# Task-gen GRPO training with 35B model +# --entrypoint: task-gen entrypoint (not main_fleet) +# --env-class: task_gen environment (not fleet_task) +# TP=2: 8 engines × 2 GPUs each across 2 nodes (16 GPUs total) +# flash_attn=false: SDPA to avoid Xid 31 in GatedDeltaNet with vLLM 0.18.0 +# loss_chunk_size=4096: chunked lm_head to avoid OOM on 131K vocab +# --no-pytorch-alloc-conf: disables expandable_segments (conflicts with vLLM CuMemAllocator) +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf \ + --nccl-heartbeat 1800 \ + --entrypoint integrations.fleet.entrypoints.main_task_gen \ + --env-class task_gen -- \ + trainer.algorithm.advantage_estimator="grpo" \ + trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ + trainer.flash_attn=false \ + trainer.loss_chunk_size=4096 \ + trainer.use_sample_packing=false \ + generator.inference_engine_tensor_parallel_size=2 \ + trainer.epochs=${NUM_EPOCHS} \ + trainer.eval_batch_size=8 \ + trainer.eval_before_train=false \ + trainer.eval_interval=10 \ + trainer.update_epochs_per_batch=1 \ + trainer.train_batch_size=12 \ + trainer.use_hybrid_env_sampling=true \ + trainer.min_samples_per_env=1 \ + trainer.policy_mini_batch_size=12 \ + trainer.micro_forward_batch_size_per_gpu=1 \ + trainer.micro_train_batch_size_per_gpu=1 \ + trainer.ckpt_interval=10 \ + trainer.max_ckpts_to_keep=1 \ + trainer.max_prompt_length=4096 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.sampling_params.temperature=0.95 \ + generator.sampling_params.top_p=0.95 \ + 'generator.sampling_params.stop=["", ""]' \ + generator.eval_sampling_params.temperature=0.95 \ + generator.eval_sampling_params.top_p=0.95 \ + 'generator.eval_sampling_params.stop=["", ""]' \ + trainer.policy.optimizer_config.lr=5.0e-7 \ + trainer.algorithm.use_kl_loss=true \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.trajectory_timeout_seconds=1800 \ + generator.use_conversation_multi_turn=true \ + generator.n_samples_per_prompt=8 \ + generator.eval_n_samples_per_prompt=3 \ + generator.enforce_eager=false \ + generator.gpu_memory_utilization=0.65 \ + trainer.logger="$LOGGER" \ + trainer.project_name="fleet-task-gen" \ + trainer.run_name="$RUN_NAME" \ + trainer.resume_mode=latest \ + trainer.ckpt_path="$HOME/ckpts/task_gen_35b" \ + trainer.dump_data_batch=true \ + ++environment.skyrl_gym.task_gen.max_turns=$MAX_TURNS \ + ++environment.skyrl_gym.task_gen.judge_model="$JUDGE_MODEL" \ + ++environment.skyrl_gym.task_gen.k_rollouts=$K_ROLLOUTS \ + ++environment.skyrl_gym.task_gen.alpha=$ALPHA \ + ++environment.skyrl_gym.task_gen.max_eval_steps=$MAX_EVAL_STEPS \ + ++environment.skyrl_gym.task_gen.evaluator_model="$EVALUATOR_MODEL" \ + ++environment.skyrl_gym.task_gen.eval_k_rollouts=8 \ + ++environment.skyrl_gym.task_gen.tool_call_reward_per_call=0.02 \ + "${ENV_FILTER_ARGS[@]}" \ + "$@" diff --git a/tasks/task-gen-grpo-qwen3_5-35b.yaml b/tasks/task-gen-grpo-qwen3_5-35b.yaml new file mode 100644 index 0000000000..b52015672b --- /dev/null +++ b/tasks/task-gen-grpo-qwen3_5-35b.yaml @@ -0,0 +1,62 @@ +# Task Generation GRPO Training via SkyPilot - Qwen3.5-35B-A3B (MoE, Multi-Node) +# Usage: sky launch tasks/task-gen-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# +# MoE: 35B total, 3B active (256 experts, 9 active/token). GatedDeltaNet architecture. +# Multi-node (2-node default, 16 GPUs total): TP=2, 8 inference engines. +# flash_attn=false (SDPA) to avoid Xid 31 in GatedDeltaNet with vLLM 0.18.0. + +name: task-gen-grpo-qwen3-5-35b + +resources: + disk_size: 750 + memory: 1500+ + ports: 6479 + any_of: + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + - accelerators: H200:8 + cloud: runpod + - accelerators: H200:8 + cloud: lambda + +num_nodes: 2 + +workdir: + url: https://github.com/fleet-ai/SkyRL-v2.git + ref: main + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v55" + MODALITY: "tool_use" + MAX_TURNS: 10 + MAX_INPUT_LENGTH: 72000 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 20 + JUDGE_MODEL: "anthropic/claude-sonnet-4.5" + OPENROUTER_API_KEY: "" + EVALUATOR_MODEL: "anthropic/claude-sonnet-4.5" + K_ROLLOUTS: 4 + ALPHA: "1.0" + MAX_EVAL_STEPS: 20 + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + NUM_INFERENCE_ENGINES: 8 + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh \ + --env-class task_gen + +run: | + bash scripts/fleet-task-gen-35b-run.sh From 086807effd8f396af0950e55a177bfe3c8cc2f28 Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 3 Apr 2026 20:20:34 -0700 Subject: [PATCH 083/121] feat: LLM-synthesized hints for failed trajectories Replace static verifier-feedback hints with Claude Sonnet-powered hint synthesis that analyzes the full failed trajectory + verifier errors to produce actionable guidance. Falls back to static hints on any failure. Key changes: - New hint_synthesizer.py module (batch async synthesis with semaphore) - Expose chat_history in env_metrics for trajectory analysis - Track hint_category (llm_synthesized vs static_fallback) in metrics - Add use_llm_hints, hint_model, hint_llm_timeout config options - Add ANTHROPIC_API_KEY to 35B run script and SkyPilot YAML Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 5 +- scripts/fleet-common-setup.sh | 2 +- skyrl-gym/skyrl_gym/envs/fleet_task/env.py | 3 + .../envs/fleet_task/hint_synthesizer.py | 262 ++++++++++++++++++ .../config/skyrl_gym_config/default.yaml | 3 + skyrl/train/generators/skyrl_gym_generator.py | 123 ++++++-- tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 3 +- 7 files changed, 378 insertions(+), 23 deletions(-) create mode 100644 skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 3ea02704a1..2a879a3a9c 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -2,7 +2,7 @@ # Single source of truth for Qwen3.5-35B-A3B GRPO training config. # Called by the SkyPilot YAML and by fleet-research run.sh. # -# Required env vars: FLEET_API_KEY, WANDB_API_KEY +# Required env vars: FLEET_API_KEY, WANDB_API_KEY, ANTHROPIC_API_KEY # Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) set -euo pipefail cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo root) @@ -29,6 +29,8 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" : "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" : "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" +: "${ANTHROPIC_API_KEY:?Set ANTHROPIC_API_KEY before running (needed for LLM hint synthesis)}" +export ANTHROPIC_API_KEY bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ @@ -38,6 +40,7 @@ bash scripts/fleet-common-run.sh \ environment.skyrl_gym.fleet_task.partial_reward=true \ environment.skyrl_gym.fleet_task.enable_hints=true \ environment.skyrl_gym.fleet_task.n_hint_samples=2 \ + environment.skyrl_gym.fleet_task.use_llm_hints=true \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ trainer.flash_attn=false \ diff --git a/scripts/fleet-common-setup.sh b/scripts/fleet-common-setup.sh index 37fb55ecc8..3bf7d8b7da 100755 --- a/scripts/fleet-common-setup.sh +++ b/scripts/fleet-common-setup.sh @@ -92,7 +92,7 @@ source .venv/bin/activate uv sync --extra fsdp uv pip install wandb boto3 awscli # Pin fleet-python<=0.2.119: 0.2.120+ has async BaseWrapper bug (missing jwt/team_id params) -uv pip install "litellm>=1.75.5" "fleet-python<=0.2.119" logfire "mcp>=1.0.0" +uv pip install "litellm>=1.75.5" "fleet-python<=0.2.119" logfire "mcp>=1.0.0" anthropic # --- Extra pip packages (installed before extra-setup to avoid dependency downgrades) --- if [ -n "$EXTRA_PIP" ]; then diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py index 86b52de654..75b36cbdbe 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -829,6 +829,9 @@ def get_metrics(self) -> Dict[str, Any]: metrics["verifier_error"] = self._verifier_error if self._tool_error_messages: metrics["tool_error_messages"] = self._tool_error_messages + # Include chat_history for LLM hint synthesis (consumed then deleted by generator) + if self.chat_history: + metrics["chat_history"] = self.chat_history return metrics @staticmethod diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py new file mode 100644 index 0000000000..e04b3cb5fb --- /dev/null +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py @@ -0,0 +1,262 @@ +"""LLM-synthesized hints for failed trajectories. + +Analyzes the full failed trajectory + verifier errors and produces actionable +guidance via Claude Sonnet. Falls back to static build_hint_text() on failure. +""" + +import asyncio +import logging +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +# Category tag for LLM-synthesized hints +CATEGORY_LLM = "llm_synthesized" +CATEGORY_STATIC = "static_fallback" +CATEGORY_LLM_FAILED = "llm_failed_static_fallback" + +HINT_SYSTEM_PROMPT = """\ +You are a debugging assistant for an AI agent that failed a task. \ +Analyze the failed trajectory and verifier feedback, then provide \ +2-5 sentences of actionable guidance for the agent's next attempt. + +Rules: +- Be specific: reference exact actions that failed and why. +- Be actionable: tell the agent what to do differently, not just what went wrong. +- If the agent ran out of context/turns, suggest being more efficient (fewer unnecessary steps). +- If tool calls errored, explain the correct usage pattern. +- Do NOT repeat the task instructions verbatim. +- Do NOT say "the previous attempt failed" — the agent already knows that.""" + + +def format_trajectory_for_hint( + chat_history: List[Dict[str, Any]], + max_turns: int = 15, + max_msg_chars: int = 3000, + max_total_chars: int = 150_000, +) -> str: + """Format chat_history into readable text for LLM hint synthesis. + + Truncates to the last `max_turns` messages, caps individual messages, + and enforces a total character budget. + """ + if not chat_history: + return "(empty trajectory)" + + # Take last N turns + recent = chat_history[-max_turns:] + parts = [] + total = 0 + + for msg in recent: + role = msg.get("role", "unknown") + content = msg.get("content", "") + + # Handle list-type content (multimodal messages) + if isinstance(content, list): + text_parts = [] + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + text_parts.append(block.get("text", "")) + elif block.get("type") == "image_url": + text_parts.append("[image]") + elif block.get("type") == "tool_use": + name = block.get("name", "unknown_tool") + inp = str(block.get("input", ""))[:500] + text_parts.append(f"[tool_use: {name}({inp})]") + elif block.get("type") == "tool_result": + text_parts.append(f"[tool_result: {str(block.get('content', ''))[:500]}]") + else: + text_parts.append(str(block)[:200]) + else: + text_parts.append(str(block)[:200]) + content = "\n".join(text_parts) + + if isinstance(content, str) and len(content) > max_msg_chars: + content = content[:max_msg_chars] + f"... [truncated, {len(content)} chars total]" + + line = f"[{role}]: {content}" + if total + len(line) > max_total_chars: + parts.append(f"... [trajectory truncated at {max_total_chars} chars]") + break + parts.append(line) + total += len(line) + + return "\n\n".join(parts) + + +def format_verifier_feedback( + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], +) -> str: + """Extract verifier errors/successes and tool errors into readable text.""" + import ast + import re + + parts = [] + + if verifier_stdout: + err_match = re.search( + r">>> ERROR_ACCUMULATOR >>>\n(.+?)\n<<< ERROR_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + suc_match = re.search( + r">>> SUCCESS_ACCUMULATOR >>>\n(.+?)\n<<< SUCCESS_ACCUMULATOR <<<", + verifier_stdout, + re.DOTALL, + ) + if err_match or suc_match: + try: + errors = ast.literal_eval(err_match.group(1)) if err_match else [] + successes = ast.literal_eval(suc_match.group(1)) if suc_match else [] + except Exception: + errors, successes = [], [] + if successes: + parts.append(f"Verifier checks PASSED ({len(successes)}):") + for s in successes[:10]: + parts.append(f" - {str(s)[:200]}") + if errors: + parts.append(f"Verifier checks FAILED ({len(errors)}):") + for e in errors[:10]: + parts.append(f" - {str(e)[:200]}") + + if verifier_error: + parts.append(f"Verifier error: {verifier_error[:500]}") + + if tool_error_messages: + unique = list(dict.fromkeys(tool_error_messages))[:10] + parts.append("Tool errors encountered:") + for e in unique: + parts.append(f" - {e[:300]}") + + return "\n".join(parts) if parts else "(no verifier feedback available)" + + +async def synthesize_hint( + task_prompt: str, + chat_history: List[Dict[str, Any]], + verifier_stdout: Optional[str], + verifier_error: Optional[str], + tool_error_messages: Optional[List[str]], + model: str = "claude-sonnet-4-20250514", + timeout: float = 30.0, + static_fallback_fn=None, +) -> Tuple[str, str]: + """Synthesize a hint from a failed trajectory using an LLM. + + Returns: + (hint_text, hint_category) where category is one of + CATEGORY_LLM, CATEGORY_STATIC, CATEGORY_LLM_FAILED. + """ + try: + import anthropic + except ImportError: + logger.warning("anthropic package not installed, falling back to static hints") + if static_fallback_fn: + return static_fallback_fn(verifier_stdout, verifier_error, tool_error_messages), CATEGORY_STATIC + return "The previous attempt failed. Try a different approach.", CATEGORY_STATIC + + trajectory_text = format_trajectory_for_hint(chat_history) + verifier_text = format_verifier_feedback(verifier_stdout, verifier_error, tool_error_messages) + + user_message = f"""## Task +{task_prompt[:5000]} + +## Agent Trajectory (last turns) +{trajectory_text} + +## Verifier Feedback +{verifier_text} + +Based on the trajectory and feedback above, provide 2-5 sentences of specific, actionable guidance for the agent's next attempt.""" + + try: + client = anthropic.AsyncAnthropic( + api_key=os.environ.get("ANTHROPIC_API_KEY"), + ) + response = await asyncio.wait_for( + client.messages.create( + model=model, + max_tokens=300, + temperature=0.3, + system=HINT_SYSTEM_PROMPT, + messages=[{"role": "user", "content": user_message}], + ), + timeout=timeout, + ) + hint_text = response.content[0].text.strip() + if hint_text: + return hint_text, CATEGORY_LLM + else: + logger.warning("LLM returned empty hint, falling back to static") + except asyncio.TimeoutError: + logger.warning(f"LLM hint synthesis timed out after {timeout}s") + except Exception as e: + logger.warning(f"LLM hint synthesis failed: {e}") + + # Fallback to static hint + if static_fallback_fn: + return static_fallback_fn(verifier_stdout, verifier_error, tool_error_messages), CATEGORY_LLM_FAILED + return "The previous attempt failed. Try a different approach.", CATEGORY_LLM_FAILED + + +async def synthesize_hints_batch( + hint_requests: List[Dict[str, Any]], + model: str = "claude-sonnet-4-20250514", + timeout: float = 30.0, + max_concurrency: int = 20, + static_fallback_fn=None, +) -> List[Tuple[str, str]]: + """Synthesize hints for a batch of failed trajectories concurrently. + + Args: + hint_requests: List of dicts with keys: + - task_prompt: str + - chat_history: List[Dict] + - verifier_stdout: Optional[str] + - verifier_error: Optional[str] + - tool_error_messages: Optional[List[str]] + - instance_id: str (for logging) + model: LLM model to use + timeout: per-request timeout + max_concurrency: max concurrent LLM calls + static_fallback_fn: fallback function for static hints + + Returns: + List of (hint_text, hint_category) tuples, one per request. + """ + if not hint_requests: + return [] + + sem = asyncio.Semaphore(max_concurrency) + start = time.monotonic() + + async def _synth(req: Dict[str, Any]) -> Tuple[str, str]: + async with sem: + return await synthesize_hint( + task_prompt=req["task_prompt"], + chat_history=req.get("chat_history", []), + verifier_stdout=req.get("verifier_stdout"), + verifier_error=req.get("verifier_error"), + tool_error_messages=req.get("tool_error_messages"), + model=model, + timeout=timeout, + static_fallback_fn=static_fallback_fn, + ) + + results = await asyncio.gather(*[_synth(req) for req in hint_requests]) + + elapsed = time.monotonic() - start + n_llm = sum(1 for _, cat in results if cat == CATEGORY_LLM) + n_fallback = sum(1 for _, cat in results if cat in (CATEGORY_STATIC, CATEGORY_LLM_FAILED)) + logger.info( + f"Hint synthesis batch: {len(results)} total, {n_llm} LLM-synthesized, " + f"{n_fallback} fallback, {elapsed:.1f}s elapsed" + ) + + return list(results) diff --git a/skyrl/train/config/skyrl_gym_config/default.yaml b/skyrl/train/config/skyrl_gym_config/default.yaml index 7d450d8f12..d7826528ac 100644 --- a/skyrl/train/config/skyrl_gym_config/default.yaml +++ b/skyrl/train/config/skyrl_gym_config/default.yaml @@ -24,6 +24,9 @@ fleet_task: hint_reward_threshold: 0.0 n_hint_samples: 2 enable_context_tools: false + use_llm_hints: false + hint_model: "claude-sonnet-4-20250514" + hint_llm_timeout: 30.0 task_gen: max_turns: 10 diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 56602b5f9a..7694c14763 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -791,6 +791,20 @@ def _update_chat_history( chat_history += new_obs return chat_history + @staticmethod + def _extract_task_prompt(prompt: ConversationType) -> str: + """Extract the user's task prompt text from a conversation.""" + for msg in prompt: + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, list): + return " ".join( + b.get("text", "") for b in content + if isinstance(b, dict) and b.get("type") == "text" + ) + return str(content) + return "" + async def _run_hint_augmentation( self, all_outputs: List[TrajectoryOutput], @@ -806,7 +820,7 @@ async def _run_hint_augmentation( """Run hinted rollouts for prompts where all raw samples failed. Groups raw outputs by instance_id, identifies groups where max_reward < threshold, - builds hint text from verifier feedback, and launches additional hinted rollouts. + synthesizes hints (LLM or static), and launches additional hinted rollouts. Uses RLTF-SD: hinted rollout prompt_ids are replaced with the original unhinted prompt_ids so the model learns to produce hint-quality outputs conditioned on the @@ -817,18 +831,18 @@ async def _run_hint_augmentation( """ from skyrl_gym.envs.fleet_task.env import FleetTaskEnv + use_llm_hints = hint_cfg.get("use_llm_hints", False) if hasattr(hint_cfg, "get") else False + hint_model = hint_cfg.get("hint_model", "claude-sonnet-4-20250514") if hasattr(hint_cfg, "get") else "claude-sonnet-4-20250514" + hint_timeout = hint_cfg.get("hint_llm_timeout", 30.0) if hasattr(hint_cfg, "get") else 30.0 + # 1. Group outputs by instance_id groups: Dict[str, List[Tuple[int, TrajectoryOutput]]] = defaultdict(list) for i, output in enumerate(all_outputs): iid = trajectory_ids[i].instance_id groups[iid].append((i, output)) - # 2. Identify prompts needing hints - hint_tasks = [] - hint_tids = [] - hint_envs = [] - orig_prompt_ids = [] # unhinted prompt_ids for RLTF-SD - prompts_hinted = 0 + # 2. Identify prompts needing hints and collect data for LLM synthesis + failed_groups = [] # (iid, best_orig_idx, best_output, best_reward) hint_reward_threshold = hint_cfg.get("hint_reward_threshold", 0.0) if hasattr(hint_cfg, "get") else 0.0 for iid, items in groups.items(): @@ -841,34 +855,73 @@ async def _run_hint_augmentation( if max_reward > hint_reward_threshold: continue # at least one raw sample has signal - # Find best raw rollout (highest partial reward) for feedback best_idx = max(range(len(items)), key=lambda j: rewards[j]) best_orig_idx, best_output = items[best_idx] - metrics = best_output.env_metrics - - # Build hint from verifier feedback - hint_text = FleetTaskEnv.build_hint_text( - verifier_stdout=metrics.get("verifier_stdout"), - verifier_error=metrics.get("verifier_error"), - tool_error_messages=metrics.get("tool_error_messages"), + failed_groups.append((iid, best_orig_idx, best_output, rewards[best_idx])) + + if not failed_groups: + return [], [], [] + + # 3. Build hints — LLM-synthesized or static + if use_llm_hints: + from skyrl_gym.envs.fleet_task.hint_synthesizer import synthesize_hints_batch + + hint_requests = [] + for iid, best_orig_idx, best_output, _ in failed_groups: + metrics = best_output.env_metrics + hint_requests.append({ + "task_prompt": self._extract_task_prompt(prompts[best_orig_idx]), + "chat_history": metrics.get("chat_history", []), + "verifier_stdout": metrics.get("verifier_stdout"), + "verifier_error": metrics.get("verifier_error"), + "tool_error_messages": metrics.get("tool_error_messages"), + "instance_id": iid, + }) + + hint_results = await synthesize_hints_batch( + hint_requests=hint_requests, + model=hint_model, + timeout=hint_timeout, + static_fallback_fn=FleetTaskEnv.build_hint_text, ) + else: + hint_results = [] + for iid, best_orig_idx, best_output, _ in failed_groups: + metrics = best_output.env_metrics + hint_text = FleetTaskEnv.build_hint_text( + verifier_stdout=metrics.get("verifier_stdout"), + verifier_error=metrics.get("verifier_error"), + tool_error_messages=metrics.get("tool_error_messages"), + ) + hint_results.append((hint_text, "static_fallback")) + + # 4. Create hinted agent_loop tasks + hint_tasks = [] + hint_tids = [] + hint_envs = [] + orig_prompt_ids = [] + hint_categories = [] + prompts_hinted = 0 + + for group_idx, (iid, best_orig_idx, best_output, best_reward) in enumerate(failed_groups): + hint_text, hint_category = hint_results[group_idx] if not hint_text: continue logger.info( - f"Hint for instance {iid} (best_reward={rewards[best_idx]:.3f}, " - f"verifier_stdout={bool(metrics.get('verifier_stdout'))}, " - f"verifier_error={bool(metrics.get('verifier_error'))}):\n{hint_text}" + f"Hint [{hint_category}] for instance {iid} " + f"(best_reward={best_reward:.3f}):\n{hint_text[:500]}" ) prompts_hinted += 1 - # Create hinted agent_loop tasks (new env instances) + items = groups[iid] base_rep_id = max(item[0] for item in items) + 1 n_hint = hint_cfg.get("n_hint_samples", 2) if hasattr(hint_cfg, "get") else 2 for h in range(n_hint): hinted_extras = dict(env_extras[best_orig_idx]) hinted_extras["hint"] = hint_text hinted_extras["is_hinted"] = True + hinted_extras["hint_category"] = hint_category tid = TrajectoryID(instance_id=iid, repetition_id=base_rep_id + h) hint_tasks.append( self.agent_loop( @@ -884,8 +937,13 @@ async def _run_hint_augmentation( hint_tids.append(tid) hint_envs.append(env_classes[best_orig_idx]) orig_prompt_ids.append(best_output.prompt_ids) + hint_categories.append(hint_category) - # 3. Run all hinted rollouts in parallel + # 5. Clean up chat_history from env_metrics to free memory + for _, _, best_output, _ in failed_groups: + best_output.env_metrics.pop("chat_history", None) + + # 6. Run all hinted rollouts in parallel if hint_tasks: logger.info( f"Hint augmentation: {prompts_hinted} prompts need hints, " @@ -904,6 +962,9 @@ async def _run_hint_augmentation( for i, output in enumerate(hint_outputs): hinted_len = len(output.prompt_ids) output.prompt_ids = orig_prompt_ids[i] + # Propagate hint_category into env_metrics for tracking + if isinstance(output.env_metrics, dict): + output.env_metrics["hint_category"] = hint_categories[i] logger.debug( f"RLTF-SD: replaced hinted prompt ({hinted_len} tokens) " f"with original prompt ({len(output.prompt_ids)} tokens)" @@ -1201,6 +1262,28 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False break rollout_metrics["hint/signal_rescued"] = rescued / len(hinted_iids) if hinted_iids else 0.0 + # Category-level metrics (LLM-synthesized vs static) + from skyrl_gym.envs.fleet_task.hint_synthesizer import CATEGORY_LLM + llm_metrics = [m for m in hinted_metrics if m.get("hint_category") == CATEGORY_LLM] + n_llm = len(llm_metrics) + if n_llm > 0: + llm_rewards = [m.get("final_reward", 0.0) or 0.0 for m in llm_metrics] + llm_success = sum(1 for r in llm_rewards if r > 0) + rollout_metrics["hint/category_llm_synthesized_count"] = n_llm + rollout_metrics["hint/category_llm_synthesized_success_rate"] = llm_success / n_llm + static_metrics = [m for m in hinted_metrics if m.get("hint_category", "").endswith("fallback")] + n_static = len(static_metrics) + if n_static > 0: + static_rewards = [m.get("final_reward", 0.0) or 0.0 for m in static_metrics] + static_success = sum(1 for r in static_rewards if r > 0) + rollout_metrics["hint/category_static_fallback_count"] = n_static + rollout_metrics["hint/category_static_fallback_success_rate"] = static_success / n_static + + # Clean up chat_history from env_metrics to prevent it from being serialized downstream + for m in env_metrics: + if isinstance(m, dict): + m.pop("chat_history", None) + if self.generator_cfg.zero_reward_on_non_stop: # set reward to 0 if the stop reason is not "stop" rewards = self._zero_reward_if_not_stop(rewards, stop_reasons) diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml index 6666663d95..2d6a28ea3d 100644 --- a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -1,5 +1,5 @@ # Fleet Task GRPO Training via SkyPilot - Qwen3.5-35B-A3B (MoE, Multi-Node) -# Usage: sky launch tasks/openenv-fleet-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= +# Usage: sky launch tasks/openenv-fleet-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= --env ANTHROPIC_API_KEY= # # MoE: 35B total, 3B active (256 experts, 9 active/token). GatedDeltaNet architecture. # 262K native context. All 35B params in memory (~70GB fp16), optimizer ~140GB, gradients ~70GB. @@ -38,6 +38,7 @@ workdir: envs: WANDB_API_KEY: "" FLEET_API_KEY: "" + ANTHROPIC_API_KEY: "" LOGGER: "wandb" INFERENCE_BACKEND: "vllm" DATA_VERSION: "v55" From 19ad98c76e3b1a82d497d5121ca45db5fba7d55f Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 3 Apr 2026 20:27:25 -0700 Subject: [PATCH 084/121] Enable partial_reward for VL training Switch partial_reward=false to partial_reward=true in fleet-vl-run.sh. Old fork v3 showed 5x faster learning with partial rewards enabled. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index d12d96e210..572bf519c5 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -40,7 +40,7 @@ bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ --set-ulimit --no-pytorch-alloc-conf -- \ environment.skyrl_gym.fleet_task.ttl_seconds=1800 \ - environment.skyrl_gym.fleet_task.partial_reward=false \ + environment.skyrl_gym.fleet_task.partial_reward=true \ environment.skyrl_gym.fleet_task.enable_hints=false \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-9B" \ From b6174df58ddb54cdd28a26059bade91d7ce8c54c Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 3 Apr 2026 21:22:00 -0700 Subject: [PATCH 085/121] fix: use OpenRouter via litellm instead of direct Anthropic API Switch hint synthesis from anthropic SDK to litellm acompletion() with openrouter/anthropic/claude-sonnet model routing. Replaces ANTHROPIC_API_KEY with OPENROUTER_API_KEY everywhere. Co-Authored-By: Claude Opus 4.6 --- scripts/fleet-35b-run.sh | 6 ++--- scripts/fleet-common-setup.sh | 2 +- .../envs/fleet_task/hint_synthesizer.py | 24 +++++++++---------- .../config/skyrl_gym_config/default.yaml | 2 +- skyrl/train/generators/skyrl_gym_generator.py | 2 +- tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 4 ++-- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 2a879a3a9c..6e7411faca 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -2,7 +2,7 @@ # Single source of truth for Qwen3.5-35B-A3B GRPO training config. # Called by the SkyPilot YAML and by fleet-research run.sh. # -# Required env vars: FLEET_API_KEY, WANDB_API_KEY, ANTHROPIC_API_KEY +# Required env vars: FLEET_API_KEY, WANDB_API_KEY, OPENROUTER_API_KEY # Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) set -euo pipefail cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo root) @@ -29,8 +29,8 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" : "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" : "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" -: "${ANTHROPIC_API_KEY:?Set ANTHROPIC_API_KEY before running (needed for LLM hint synthesis)}" -export ANTHROPIC_API_KEY +: "${OPENROUTER_API_KEY:?Set OPENROUTER_API_KEY before running (needed for LLM hint synthesis)}" +export OPENROUTER_API_KEY bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ diff --git a/scripts/fleet-common-setup.sh b/scripts/fleet-common-setup.sh index 3bf7d8b7da..37fb55ecc8 100755 --- a/scripts/fleet-common-setup.sh +++ b/scripts/fleet-common-setup.sh @@ -92,7 +92,7 @@ source .venv/bin/activate uv sync --extra fsdp uv pip install wandb boto3 awscli # Pin fleet-python<=0.2.119: 0.2.120+ has async BaseWrapper bug (missing jwt/team_id params) -uv pip install "litellm>=1.75.5" "fleet-python<=0.2.119" logfire "mcp>=1.0.0" anthropic +uv pip install "litellm>=1.75.5" "fleet-python<=0.2.119" logfire "mcp>=1.0.0" # --- Extra pip packages (installed before extra-setup to avoid dependency downgrades) --- if [ -n "$EXTRA_PIP" ]; then diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py index e04b3cb5fb..cbb029d6a4 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py @@ -1,7 +1,8 @@ """LLM-synthesized hints for failed trajectories. Analyzes the full failed trajectory + verifier errors and produces actionable -guidance via Claude Sonnet. Falls back to static build_hint_text() on failure. +guidance via an LLM (via litellm/OpenRouter). Falls back to static +build_hint_text() on failure. """ import asyncio @@ -143,20 +144,20 @@ async def synthesize_hint( verifier_stdout: Optional[str], verifier_error: Optional[str], tool_error_messages: Optional[List[str]], - model: str = "claude-sonnet-4-20250514", + model: str = "openrouter/anthropic/claude-sonnet-4-20250514", timeout: float = 30.0, static_fallback_fn=None, ) -> Tuple[str, str]: - """Synthesize a hint from a failed trajectory using an LLM. + """Synthesize a hint from a failed trajectory using an LLM via litellm. Returns: (hint_text, hint_category) where category is one of CATEGORY_LLM, CATEGORY_STATIC, CATEGORY_LLM_FAILED. """ try: - import anthropic + from litellm import acompletion except ImportError: - logger.warning("anthropic package not installed, falling back to static hints") + logger.warning("litellm not installed, falling back to static hints") if static_fallback_fn: return static_fallback_fn(verifier_stdout, verifier_error, tool_error_messages), CATEGORY_STATIC return "The previous attempt failed. Try a different approach.", CATEGORY_STATIC @@ -176,20 +177,19 @@ async def synthesize_hint( Based on the trajectory and feedback above, provide 2-5 sentences of specific, actionable guidance for the agent's next attempt.""" try: - client = anthropic.AsyncAnthropic( - api_key=os.environ.get("ANTHROPIC_API_KEY"), - ) response = await asyncio.wait_for( - client.messages.create( + acompletion( model=model, max_tokens=300, temperature=0.3, - system=HINT_SYSTEM_PROMPT, - messages=[{"role": "user", "content": user_message}], + messages=[ + {"role": "system", "content": HINT_SYSTEM_PROMPT}, + {"role": "user", "content": user_message}, + ], ), timeout=timeout, ) - hint_text = response.content[0].text.strip() + hint_text = response.choices[0].message.content.strip() if hint_text: return hint_text, CATEGORY_LLM else: diff --git a/skyrl/train/config/skyrl_gym_config/default.yaml b/skyrl/train/config/skyrl_gym_config/default.yaml index d7826528ac..3518b61c23 100644 --- a/skyrl/train/config/skyrl_gym_config/default.yaml +++ b/skyrl/train/config/skyrl_gym_config/default.yaml @@ -25,7 +25,7 @@ fleet_task: n_hint_samples: 2 enable_context_tools: false use_llm_hints: false - hint_model: "claude-sonnet-4-20250514" + hint_model: "openrouter/anthropic/claude-sonnet-4-20250514" hint_llm_timeout: 30.0 task_gen: diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 7694c14763..eef664685d 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -832,7 +832,7 @@ async def _run_hint_augmentation( from skyrl_gym.envs.fleet_task.env import FleetTaskEnv use_llm_hints = hint_cfg.get("use_llm_hints", False) if hasattr(hint_cfg, "get") else False - hint_model = hint_cfg.get("hint_model", "claude-sonnet-4-20250514") if hasattr(hint_cfg, "get") else "claude-sonnet-4-20250514" + hint_model = hint_cfg.get("hint_model", "openrouter/anthropic/claude-sonnet-4-20250514") if hasattr(hint_cfg, "get") else "openrouter/anthropic/claude-sonnet-4-20250514" hint_timeout = hint_cfg.get("hint_llm_timeout", 30.0) if hasattr(hint_cfg, "get") else 30.0 # 1. Group outputs by instance_id diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml index 2d6a28ea3d..b0e85b4f55 100644 --- a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -1,5 +1,5 @@ # Fleet Task GRPO Training via SkyPilot - Qwen3.5-35B-A3B (MoE, Multi-Node) -# Usage: sky launch tasks/openenv-fleet-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= --env ANTHROPIC_API_KEY= +# Usage: sky launch tasks/openenv-fleet-grpo-qwen3_5-35b.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= --env OPENROUTER_API_KEY= # # MoE: 35B total, 3B active (256 experts, 9 active/token). GatedDeltaNet architecture. # 262K native context. All 35B params in memory (~70GB fp16), optimizer ~140GB, gradients ~70GB. @@ -38,7 +38,7 @@ workdir: envs: WANDB_API_KEY: "" FLEET_API_KEY: "" - ANTHROPIC_API_KEY: "" + OPENROUTER_API_KEY: "" LOGGER: "wandb" INFERENCE_BACKEND: "vllm" DATA_VERSION: "v55" From 8cf2fd8c5df11191d7ce4815227b83fa2bdcb06b Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 3 Apr 2026 22:20:43 -0700 Subject: [PATCH 086/121] fix: use correct OpenRouter model ID for hint synthesis anthropic/claude-sonnet-4-20250514 is not a valid OpenRouter model ID. The correct ID is anthropic/claude-sonnet-4. Co-Authored-By: Claude Opus 4.6 --- skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py | 4 ++-- skyrl/train/config/skyrl_gym_config/default.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py index cbb029d6a4..f71edbac53 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/hint_synthesizer.py @@ -144,7 +144,7 @@ async def synthesize_hint( verifier_stdout: Optional[str], verifier_error: Optional[str], tool_error_messages: Optional[List[str]], - model: str = "openrouter/anthropic/claude-sonnet-4-20250514", + model: str = "openrouter/anthropic/claude-sonnet-4", timeout: float = 30.0, static_fallback_fn=None, ) -> Tuple[str, str]: @@ -207,7 +207,7 @@ async def synthesize_hint( async def synthesize_hints_batch( hint_requests: List[Dict[str, Any]], - model: str = "claude-sonnet-4-20250514", + model: str = "openrouter/anthropic/claude-sonnet-4", timeout: float = 30.0, max_concurrency: int = 20, static_fallback_fn=None, diff --git a/skyrl/train/config/skyrl_gym_config/default.yaml b/skyrl/train/config/skyrl_gym_config/default.yaml index 3518b61c23..c6a2fec7a8 100644 --- a/skyrl/train/config/skyrl_gym_config/default.yaml +++ b/skyrl/train/config/skyrl_gym_config/default.yaml @@ -25,7 +25,7 @@ fleet_task: n_hint_samples: 2 enable_context_tools: false use_llm_hints: false - hint_model: "openrouter/anthropic/claude-sonnet-4-20250514" + hint_model: "openrouter/anthropic/claude-sonnet-4" hint_llm_timeout: 30.0 task_gen: From c03cf7919195b6efee7cc991f4399302f67dce53 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 6 Apr 2026 23:23:28 -0700 Subject: [PATCH 087/121] Binary reward + truncate query_db responses - Reward: 1.0 if mixed solver results (pass+fail), 0.0 otherwise - Truncate query_db to 5 rows / 3000 chars to prevent context blowout - Close env before get_metrics() so final_reward is captured Co-Authored-By: Claude Opus 4.6 (1M context) --- integrations/fleet/task_gen_reward.py | 70 +++---------------- .../skyrl_gym/envs/task_gen/task_gen_env.py | 26 +++---- skyrl/train/generators/skyrl_gym_generator.py | 13 ++-- 3 files changed, 31 insertions(+), 78 deletions(-) diff --git a/integrations/fleet/task_gen_reward.py b/integrations/fleet/task_gen_reward.py index 80bdf5a886..55248ece7d 100644 --- a/integrations/fleet/task_gen_reward.py +++ b/integrations/fleet/task_gen_reward.py @@ -1,89 +1,41 @@ """ Reward functions for task generation RL. -Computes: - R(task) = llm_validity * (alpha * var(raw_scores) + (p_hint - p_raw)) - -Components: - - var(raw_scores): Variance of k raw (no-hint) evaluator rollouts. - Measures difficulty calibration — maximized at p_raw ≈ 0.5 - (Bernoulli variance = 0.25). Tasks at the evaluator's frontier. - - p_hint - p_raw: Hint gap — mean(hinted) minus mean(raw). - Positive when hints help, meaning the task is hard but solvable. - Captures learnability beyond current capability. - - llm_validity: LLM-as-a-judge gate (0/1). Kills reward for broken tasks. - - alpha: Weight balancing variance (frontier difficulty) vs hint gap (learnability). Default 1.0 (equal weight). +Binary reward: 1.0 if solver rollouts have mixed results (at least one pass +and one fail), 0.0 otherwise. Mixed results = the task is at the right +difficulty frontier. """ from typing import Dict, List def compute_variance(scores: List[float]) -> float: - """Compute variance of binary rollout scores. - - Args: - scores: List of binary (0/1) rollout outcomes. - - Returns: - Variance in [0, 0.25]. Zero when all same, max at p=0.5. - """ if len(scores) < 2: return 0.0 mean = sum(scores) / len(scores) return sum((s - mean) ** 2 for s in scores) / len(scores) -def compute_hint_gap(raw_scores: List[float], hinted_scores: List[float]) -> float: - """Compute hint gap: mean(hinted) - mean(raw). - - Positive when hints help the evaluator solve the task. - Zero or negative when hints don't help (task too easy or too hard). - - Args: - raw_scores: Scores from evaluator rollouts without hints. - hinted_scores: Scores from evaluator rollouts with hints. - - Returns: - Hint gap in [-1, 1]. - """ - if not raw_scores or not hinted_scores: - return 0.0 - p_raw = sum(raw_scores) / len(raw_scores) - p_hint = sum(hinted_scores) / len(hinted_scores) - return p_hint - p_raw - - def compute_task_reward( raw_scores: List[float], hinted_scores: List[float], validity: float = 1.0, alpha: float = 1.0, ) -> Dict[str, float]: - """Compute the full task generation reward. + """Binary reward: 1.0 if mixed solver results, 0.0 otherwise.""" + if not raw_scores: + return {"validity": validity, "p_raw": 0.0, "var_raw": 0.0, "total": 0.0} - R = validity * (alpha * var(raw) + (p_hint - p_raw)) - - Args: - raw_scores: Scores from k evaluator rollouts without hints. - hinted_scores: Scores from k evaluator rollouts with hints. - validity: LLM-as-a-judge gate (0.0 or 1.0). - alpha: Weight for variance term. - - Returns: - Dict with all reward components and total. - """ - p_raw = sum(raw_scores) / len(raw_scores) if raw_scores else 0.0 - p_hint = sum(hinted_scores) / len(hinted_scores) if hinted_scores else 0.0 + p_raw = sum(raw_scores) / len(raw_scores) var_raw = compute_variance(raw_scores) - hint_gap = p_hint - p_raw - total = validity * (alpha * var_raw + hint_gap) + has_pass = any(s > 0 for s in raw_scores) + has_fail = any(s == 0 for s in raw_scores) + total = 1.0 if (has_pass and has_fail and validity > 0) else 0.0 return { "validity": validity, "p_raw": p_raw, - "p_hint": p_hint, "var_raw": var_raw, - "hint_gap": hint_gap, - "alpha": alpha, + "hint_gap": 0.0, "total": total, } diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 20c8bc5c44..3a4f1012f8 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -1112,23 +1112,12 @@ async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: # 4. Hint-based evaluation via Fleet harness eval_result = await self._evaluate_task(prompt, verifier) - # 5. R = base_quality + tool_call_reward + eval_signal - # base_quality: small reward for passing sandbox+judge (structural validity) - # tool_call_reward: incentivize DB exploration (query_db) - # eval_signal: judge_gate * compute_task_reward (harness-based quality) - # This prevents GRPO zero-signal deadlock when all harness evals fail. - base_quality = self.base_quality_reward - tool_call_reward = self.meta_tool_calls * self.tool_call_reward_per_call - eval_signal = judge_gate * eval_result["total"] - reward = base_quality + tool_call_reward + eval_signal + # Binary reward: 1.0 if mixed solver results, 0.0 otherwise + reward = eval_result["total"] metadata["reward_breakdown"] = { "sandbox": 1.0, "judge": judge_gate, - "base_quality": base_quality, - "tool_call_reward": tool_call_reward, - "meta_tool_calls": self.meta_tool_calls, - "eval_signal": eval_signal, **eval_result, "total": reward, } @@ -1257,8 +1246,15 @@ async def _execute_meta_tool(self, tool_call: Dict[str, Any]) -> str: try: result = await self.orch.query_db_async(sql=sql, db_name=args.get("db_name", "seed")) if isinstance(result, dict): - return f"Tool result:\n{json.dumps(result, indent=2, default=str)}" - return f"Tool result:\n{result}" + # Truncate rows to save context — model only needs a sample + if "rows" in result and isinstance(result["rows"], list) and len(result["rows"]) > 5: + result["rows"] = result["rows"][:5] + result["message"] = f"Query returned more rows; showing first 5." + formatted = json.dumps(result, indent=2, default=str) + if len(formatted) > 3000: + formatted = formatted[:3000] + "\n... (truncated)" + return f"Tool result:\n{formatted}" + return f"Tool result:\n{str(result)[:3000]}" except Exception as e: if attempt < max_retries - 1 and ("closed" in str(e).lower() or "transport" in str(e).lower() or "connection" in str(e).lower()): await asyncio.sleep(1) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index eef664685d..45d9b8e1ab 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -565,10 +565,14 @@ async def agent_loop( per_step_rewards.append((step_reward, agent_loop_state.response_end_idx)) + # Close the environment first so final_reward and verifier feedback + # are captured into the env before we read metrics. Otherwise + # env_metrics is missing final_reward / verifier_stdout / tool_errors, + # which breaks downstream hint recovery metrics (they read + # m.get("final_reward", 0.0) and get 0 for every hinted rollout). + await self._env_close(env) # Get environment-specific metrics after the episode is done env_metrics = env.get_metrics() - # Close the environment - await self._env_close(env) prompt_ids = agent_loop_state.input_ids[:initial_prompt_length] rollout_logprobs = None @@ -1044,10 +1048,11 @@ async def generate_batched( prompt_len = len(prompt_token_ids[i]) truncated_indices.append(sample_indices[: prompt_len + len(response)]) + # Close the environment first so final_reward and verifier + # feedback are populated before get_metrics() reads them. + await self._env_close(env) # Get environment-specific metrics env_metrics.append(env.get_metrics()) - # Close the environment - await self._env_close(env) rollout_metrics = get_rollout_metrics(responses, rewards, env_metrics, env_classes) From bb2e5de57772f2fa76e2635bf89c386b3f76aa33 Mon Sep 17 00:00:00 2001 From: Deniz Date: Tue, 7 Apr 2026 09:50:25 -0700 Subject: [PATCH 088/121] Fix binary reward: restore base_quality + ablation config - Restore base_quality=0.1 for gate-passing gradient signal - Binary eval signal on top: R = 0.1 + (1.0 if mixed else 0.0) - Apply recommended ablation config: kl_loss_coef=1.0, entropy_loss_coef=0.001 Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-task-gen-35b-run.sh | 2 ++ skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-task-gen-35b-run.sh b/scripts/fleet-task-gen-35b-run.sh index cab355f70e..e1333688cc 100755 --- a/scripts/fleet-task-gen-35b-run.sh +++ b/scripts/fleet-task-gen-35b-run.sh @@ -80,6 +80,8 @@ bash scripts/fleet-common-run.sh \ 'generator.eval_sampling_params.stop=["", ""]' \ trainer.policy.optimizer_config.lr=5.0e-7 \ trainer.algorithm.use_kl_loss=true \ + trainer.algorithm.kl_loss_coef=1.0 \ + trainer.algorithm.entropy_loss_coef=0.001 \ generator.max_turns=$MAX_TURNS \ generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 3a4f1012f8..6ea8a00ff9 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -1112,12 +1112,16 @@ async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: # 4. Hint-based evaluation via Fleet harness eval_result = await self._evaluate_task(prompt, verifier) - # Binary reward: 1.0 if mixed solver results, 0.0 otherwise - reward = eval_result["total"] + # R = base_quality + binary_eval_signal + # base_quality (0.1): gradient for passing sandbox+judge gates + # binary_eval_signal (1.0 if mixed, 0.0 otherwise): difficulty frontier + base_quality = self.base_quality_reward + reward = base_quality + eval_result["total"] metadata["reward_breakdown"] = { "sandbox": 1.0, "judge": judge_gate, + "base_quality": base_quality, **eval_result, "total": reward, } From 1b08b600fcf6678b1fd87fb38c66002166058a1e Mon Sep 17 00:00:00 2001 From: Deniz Date: Wed, 8 Apr 2026 14:05:29 -0700 Subject: [PATCH 089/121] Fix submission nudge: append to tool results, not dead branch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The "submit NOW" nudge was unreachable — only fired when model output plain text, but model always outputs . Now appended to tool results when remaining <= 2 turns. Co-Authored-By: Claude Opus 4.6 (1M context) --- skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 6ea8a00ff9..4d252502ac 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -1191,6 +1191,13 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: ) obs_content = "\n\n".join(results) + remaining = self.max_turns - self.turns + if remaining <= 2 and self.called_query_db: + obs_content += ( + f"\n\n⚠️ You have {remaining} turn(s) left. " + "You MUST output your block on your next turn or you will get reward 0. " + "Stop exploring and generate the task NOW." + ) observation = {"role": "user", "content": obs_content} return BaseTextEnvStepOutput( observations=[observation], From d89caf6e563e2b94f4ddfefab67b719c8b12eb1e Mon Sep 17 00:00:00 2001 From: Deniz Date: Wed, 8 Apr 2026 15:07:22 -0700 Subject: [PATCH 090/121] CLAUDE.md: report binary variance reward, not just pass@8 pass@8 includes base_quality (0.1) gate-passing, which inflates the metric. Binary variance reward (1.0 for mixed solver results) is the actual learning signal. Co-Authored-By: Claude Opus 4.6 (1M context) --- CLAUDE.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index e75733bd07..81b86d7f81 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -29,6 +29,14 @@ Always consult the changelog before modifying Fleet training paths (`fsdp_worker All training flags live in these scripts. Never duplicate flags in SkyPilot YAMLs or fleet-research scripts. +## Task-Gen Metrics + +When reporting task-gen training metrics, distinguish between: +- **pass@8 / avg_raw_reward**: includes `base_quality=0.1` for passing sandbox+judge. Misleading — inflated by gate-passing alone. +- **binary variance reward**: the actual learning signal. `1.0` when solver rollouts are mixed (at least 1 pass + 1 fail), `0.0` otherwise. This is what matters. + +Report binary variance reward count (how many tasks got `reward >= 1.0`) separately from gate-pass count. Check `EVAL` log lines for `total=1.0000` vs `total=0.0000`. + ## Branch Primary development branch: `fleet/all` From 06cb395caca98baaa997a3e086e88340ce50e04e Mon Sep 17 00:00:00 2001 From: Deniz Date: Wed, 8 Apr 2026 22:43:27 -0700 Subject: [PATCH 091/121] v5.1: verifier dry-run, MCP tool prompt, earlier nudge, zero_variance_filter - Verifier dry-run on seed DB before harness eval. Broken verifiers (return 1 on seed or crash) get feedback, model can retry. - Prompt requires calling MCP tools during exploration, not just query_db. - Nudge threshold: remaining <= 3 (was 2), gives model 3 retry chances. - Sandbox failures also return feedback instead of terminating when turns remain. - Enable zero_variance_filter=true to mask zero-advantage prompt groups. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-task-gen-35b-run.sh | 1 + .../skyrl_gym/envs/task_gen/task_gen_env.py | 90 ++++++++++++++----- 2 files changed, 68 insertions(+), 23 deletions(-) diff --git a/scripts/fleet-task-gen-35b-run.sh b/scripts/fleet-task-gen-35b-run.sh index e1333688cc..6cad5e5dab 100755 --- a/scripts/fleet-task-gen-35b-run.sh +++ b/scripts/fleet-task-gen-35b-run.sh @@ -82,6 +82,7 @@ bash scripts/fleet-common-run.sh \ trainer.algorithm.use_kl_loss=true \ trainer.algorithm.kl_loss_coef=1.0 \ trainer.algorithm.entropy_loss_coef=0.001 \ + trainer.algorithm.zero_variance_filter=true \ generator.max_turns=$MAX_TURNS \ generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ diff --git a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py index 4d252502ac..c966e133d0 100644 --- a/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py +++ b/skyrl-gym/skyrl_gym/envs/task_gen/task_gen_env.py @@ -533,26 +533,26 @@ def find_new_entries(table_name, id_field="id", filter_conditions=None): """ ## Exploration Tools -The database schema is provided above. Use `query_db` to inspect actual data and environment tools to understand API behavior. +The database schema is provided above. Use BOTH `query_db` AND environment API tools during exploration. ### Database Tools {"name": "query_db", "arguments": {"sql": "SELECT * FROM table_name LIMIT 5"}} Runs a read-only SQL query against the seed database. -### Environment Tools +### Environment API Tools {"name": "tool_name", "arguments": {"param": "value"}} -Calls the tool and returns its result. Use this to understand input/output formats. +Calls the environment API tool and returns its result. **You MUST call at least one API tool** (e.g., searchEvents, getAvailability) during exploration to understand what the solver agent will experience. The solver uses these API tools, not SQL — if you only explore via SQL, you won't know whether the API tools actually work for your task. ### Workflow -1. **Inspect data**: Call `query_db` with SELECT queries to inspect real data (values, ranges, row counts, patterns). The schema above shows table and column names. -2. **Try tools**: Call environment API tools to understand their behavior, input/output formats, and edge cases. -3. **Draft a task idea**: Think about what prompt + verifier you could write based on the data you've seen. -4. **Validate your draft**: Before outputting the task, run `query_db` to verify your assumptions: - - Does the data your prompt references actually exist? (e.g., "Update Jamie's email" — is there a Jamie?) - - Will the verifier return 0.0 on a fresh DB? (Check seed state) - - Are there edge cases? (e.g., multiple matches, null values, empty tables) -5. **Iterate**: If your queries reveal problems (wrong assumptions, ambiguous data, too many/few matches), revise your task idea and verify again. Do NOT output the task until you've confirmed the data supports it. -6. **Output**: Only when confident, output the final task in the format below.""" +1. **Inspect data**: Call `query_db` to inspect real data (values, ranges, row counts). +2. **Try API tools**: Call at least one environment API tool to understand its behavior, input/output format, and what data it returns. This is critical — your task must be achievable using these tools. +3. **Draft a task idea**: Based on the data AND tool behavior you've observed. +4. **Validate**: Before outputting, verify: + - Does the data your prompt references actually exist? (Query to confirm.) + - Is the task achievable using the available API tools? (You tested them.) + - Does your verifier check for a DB mutation (e.g., new order, new cart item)? If so, does the task actually cause that mutation? + - Will the verifier return 0 on the unmodified DB? (If it uses `find_new_entries`, the task MUST involve a write action like buy/reserve/create — NOT just search/list.) +5. **Output**: Only when confident, output the final task in the format below.""" ) # --- D. Few-shot examples removed --- @@ -1062,20 +1062,49 @@ def _save_rollout( except Exception as e: logger.warning(f"[{task_id}] Failed to save rollout: {e}") + async def _dryrun_verifier(self, verifier: str) -> Tuple[bool, str]: + """Run verifier against seed DB (no agent actions). Returns (ok, error_msg). + + A correct verifier should return 0 on unmodified DB (task not done yet). + Returns 1 → broken (permissive). Crashes → broken. + """ + if self.orch is None: + return True, "" # Can't dry-run without orchestrator, skip + try: + from fleet._async.tasks import Task as AsyncFleetTask + task = AsyncFleetTask( + key=f"dryrun_{self.env_key}", + prompt="dry-run", + env_id=self.env_key, + verifier_func=verifier, + ) + result = await task.verify_detailed_async(self.orch._fleet_env) + if result.success: + return False, "Verifier returned 1 on the unmodified database — it passes even when no agent has acted. Your verifier must return 0 on seed state. Check that your task involves a write/mutation action and your verifier checks for that mutation (e.g., find_new_entries)." + return True, "" + except Exception as e: + err_msg = str(e) + # Truncate long tracebacks + if len(err_msg) > 500: + err_msg = err_msg[:500] + "..." + return False, f"Verifier crashed on seed DB: {err_msg}" + async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: """Evaluate a generated task through the full pipeline. Pipeline: 1. Parse output -> fail = reward 0 2. Sandbox validation -> fail = reward 0 - 3. LLM-as-a-judge -> gate (0/1), fail = reward 0 - 4. Hint-based evaluation via Fleet harness (k raw + k hinted rollouts) - 5. R = base_quality + judge_gate * compute_task_reward(raw, hinted) + 3. Verifier dry-run on seed DB -> if broken, return feedback (retry) + 4. LLM-as-a-judge -> gate (0/1), fail = reward 0 + 5. Hint-based evaluation via Fleet harness (k raw + k hinted rollouts) + 6. R = base_quality + binary_eval_signal base_quality (default 0.1) rewards structural validity (sandbox+judge pass), providing GRPO gradient signal even when harness evals return all zeros. """ metadata: Dict[str, Any] = {"env_key": self.env_key, "turn": self.turns} + max_turns_reached = self.turns >= self.max_turns # 1. Parse parsed = parse_task_output(action) @@ -1098,10 +1127,26 @@ async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: "error": validation.error, } if not validation.valid: + if not max_turns_reached: + remaining = self.max_turns - self.turns + obs = {"role": "user", "content": f"Sandbox rejected your verifier: {', '.join(validation.checks_failed)}. Fix and resubmit. {remaining} turn(s) left."} + return BaseTextEnvStepOutput(observations=[obs], reward=0.0, done=False, metadata=metadata) metadata["reward_breakdown"] = {"sandbox": 0.0, "total": 0.0} return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) - # 3. LLM-as-a-judge gate + # 3. Verifier dry-run on seed DB + dryrun_ok, dryrun_error = await self._dryrun_verifier(verifier) + metadata["dryrun_ok"] = dryrun_ok + if not dryrun_ok: + logger.info(f"TaskGenEnv [{self.env_key}]: Verifier dry-run failed: {dryrun_error[:200]}") + if not max_turns_reached: + remaining = self.max_turns - self.turns + obs = {"role": "user", "content": f"⚠️ Verifier dry-run FAILED: {dryrun_error}\n\nFix your verifier and resubmit. {remaining} turn(s) left."} + return BaseTextEnvStepOutput(observations=[obs], reward=0.0, done=False, metadata=metadata) + metadata["reward_breakdown"] = {"dryrun": 0.0, "total": 0.0} + return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) + + # 4. LLM-as-a-judge gate judge_gate = self._judge_task(prompt, verifier) metadata["judge_gate"] = judge_gate @@ -1109,17 +1154,16 @@ async def _handle_task_generation(self, action: str) -> BaseTextEnvStepOutput: metadata["reward_breakdown"] = {"sandbox": 1.0, "judge": 0.0, "total": 0.0} return BaseTextEnvStepOutput(observations=[], reward=0.0, done=True, metadata=metadata) - # 4. Hint-based evaluation via Fleet harness + # 5. Hint-based evaluation via Fleet harness eval_result = await self._evaluate_task(prompt, verifier) - # R = base_quality + binary_eval_signal - # base_quality (0.1): gradient for passing sandbox+judge gates - # binary_eval_signal (1.0 if mixed, 0.0 otherwise): difficulty frontier + # 6. R = base_quality + binary_eval_signal base_quality = self.base_quality_reward reward = base_quality + eval_result["total"] metadata["reward_breakdown"] = { "sandbox": 1.0, + "dryrun": 1.0, "judge": judge_gate, "base_quality": base_quality, **eval_result, @@ -1192,11 +1236,11 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: obs_content = "\n\n".join(results) remaining = self.max_turns - self.turns - if remaining <= 2 and self.called_query_db: + if remaining <= 3 and self.called_query_db: obs_content += ( f"\n\n⚠️ You have {remaining} turn(s) left. " - "You MUST output your block on your next turn or you will get reward 0. " - "Stop exploring and generate the task NOW." + "You MUST output your block NOW. " + "Stop exploring and generate the task." ) observation = {"role": "user", "content": obs_content} return BaseTextEnvStepOutput( From 8a8fbde083dbb4fe46013398d3e67049fa640cda Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 10 Apr 2026 15:39:39 -0700 Subject: [PATCH 092/121] 35b: baseline on v6, disable hints - DATA_VERSION default v55 -> v6 (new v6 tool_use dataset at s3://fleet-internal-datasets/v6/openenv/all_tool_use.json) - enable_hints=false (drop n_hint_samples, use_llm_hints flags) - OPENROUTER_API_KEY now optional (only needed for LLM hint synthesis) Previous 35B run (e0b2cd94) had 4% hint recovery vs an expected naive-retry baseline of ~10-19%. Resetting to a clean baseline on the updated v6 dataset before further hint work. --- scripts/fleet-35b-run.sh | 10 ++++------ tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index 6e7411faca..bccd3220f9 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -10,7 +10,7 @@ cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo roo # Defaults for vars normally set by SkyPilot YAML envs block export LOGGER="${LOGGER:-wandb}" export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" -export DATA_VERSION="${DATA_VERSION:-v55}" +export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-tool_use}" export NUM_EPOCHS="${NUM_EPOCHS:-20}" export MAX_TURNS="${MAX_TURNS:-50}" @@ -29,8 +29,8 @@ export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" : "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" : "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" -: "${OPENROUTER_API_KEY:?Set OPENROUTER_API_KEY before running (needed for LLM hint synthesis)}" -export OPENROUTER_API_KEY +# OPENROUTER_API_KEY only needed when enable_hints=true (LLM hint synthesis) +export OPENROUTER_API_KEY="${OPENROUTER_API_KEY:-}" bash scripts/fleet-common-run.sh \ --use-python-direct --cuda-env "$HOME/.cuda_env" \ @@ -38,9 +38,7 @@ bash scripts/fleet-common-run.sh \ --nccl-heartbeat 1800 -- \ environment.skyrl_gym.fleet_task.ttl_seconds=900 \ environment.skyrl_gym.fleet_task.partial_reward=true \ - environment.skyrl_gym.fleet_task.enable_hints=true \ - environment.skyrl_gym.fleet_task.n_hint_samples=2 \ - environment.skyrl_gym.fleet_task.use_llm_hints=true \ + environment.skyrl_gym.fleet_task.enable_hints=false \ trainer.algorithm.advantage_estimator=grpo \ trainer.policy.model.path="Qwen/Qwen3.5-35B-A3B" \ trainer.flash_attn=false \ diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml index b0e85b4f55..48c74a56bb 100644 --- a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -41,7 +41,7 @@ envs: OPENROUTER_API_KEY: "" LOGGER: "wandb" INFERENCE_BACKEND: "vllm" - DATA_VERSION: "v55" + DATA_VERSION: "v6" ENV_KEYS: "" DIFFICULTY: "" MODALITY: "tool_use" From d7d087eedbac336df034ff5d1f72e290781644dd Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 11 Apr 2026 22:27:53 -0700 Subject: [PATCH 093/121] VL training: add browser_use modality support, switch to v6 data v6 dataset splits the old computer_use into browser_use (10,678 tasks, browser-only) and computer_use (414 tasks, full desktop fos-* envs). This updates env.py, prepare_dataset.py, and the VL YAML/script to support browser_use as a first-class modality for VL training. Co-Authored-By: Claude Opus 4.6 (1M context) --- integrations/fleet/prepare_dataset.py | 4 +++- scripts/fleet-vl-run.sh | 6 +++--- skyrl-gym/skyrl_gym/envs/fleet_task/env.py | 8 ++++---- tasks/openenv-fleet-grpo-vl.yaml | 6 +++--- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/integrations/fleet/prepare_dataset.py b/integrations/fleet/prepare_dataset.py index c4a5a937bd..3bc8601ad9 100644 --- a/integrations/fleet/prepare_dataset.py +++ b/integrations/fleet/prepare_dataset.py @@ -44,6 +44,7 @@ HELD_OUT_ENVS = { "tool_use": [], # v0.3: all envs split normally (outlook now included in train) "computer_use": [], + "browser_use": [], } # Excluded environments (removed from both train and eval) @@ -52,6 +53,7 @@ EXCLUDED_ENVS = { "tool_use": ["dropbox"], "computer_use": ["dropbox"], + "browser_use": ["dropbox"], } # Tasks excluded due to missing CURRENT_DATE in env_variables (v0.4.0) @@ -561,7 +563,7 @@ def main(): "--modality", type=str, default="tool_use", - choices=["tool_use", "computer_use", "all"], + choices=["tool_use", "computer_use", "browser_use", "all"], help="Task modality filter ('all' for no filter)", ) parser.add_argument( diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index 572bf519c5..d284a91753 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -7,7 +7,7 @@ # # Model: Qwen/Qwen3.5-9B (9B params, natively multimodal, GatedDeltaNet) # TP=1 (single GPU per engine, 8 engines on 8x H200) -# Modality: computer_use (screenshots + coordinate normalization) +# Modality: browser_use (screenshots + coordinate normalization) # # Required env vars: FLEET_API_KEY, WANDB_API_KEY # Optional: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (for S3 checkpoints) @@ -17,8 +17,8 @@ cd "$(dirname "$0")/.." # cd to SkyRL root (scripts/ is directly under repo roo # Defaults for vars normally set by SkyPilot YAML envs block export LOGGER="${LOGGER:-wandb}" export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" -export DATA_VERSION="${DATA_VERSION:-v52}" -export MODALITY="${MODALITY:-computer_use}" +export DATA_VERSION="${DATA_VERSION:-v6}" +export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" export MAX_TURNS="${MAX_TURNS:-50}" export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py index 75b36cbdbe..1255744782 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -5,7 +5,7 @@ keeping a clean separation between SkyRL's training interface and Fleet's environment management. -Multi-modal support: When the task modality is "computer_use", step() returns +Multi-modal support: When the task modality is "computer_use" or "browser_use", step() returns multimodal observations in OpenAI format (image_url content blocks). Upstream SkyRL's generator already handles these via extract_images_from_conversation() and passes them as multi_modal_data to vLLM — no upstream changes needed. @@ -371,7 +371,7 @@ async def init_async( # VL: adapt computer tool for Qwen's normalized coordinate space modality = self.task_config.get("task_modality", "tool_use") - if modality == "computer_use": + if modality in ("computer_use", "browser_use"): self._adapt_computer_tool_for_qwen() # Build initial prompt with task instruction @@ -433,7 +433,7 @@ async def init_async( # Computer-use hints for VL models computer_use_hints = "" - if modality == "computer_use": + if modality in ("computer_use", "browser_use"): computer_use_hints = ( "\n## Browser Interaction Strategy\n" "You are controlling a web browser via screenshots. Follow this loop:\n\n" @@ -528,7 +528,7 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: Parses the action for tool calls, executes via OpenEnv's FleetTaskEnv, and returns observation. Reward is computed by the verifier on completion. - For computer_use modality, observations may include multimodal content + For computer_use/browser_use modality, observations may include multimodal content (image_url blocks with base64 screenshots). Upstream SkyRL's generator handles these via extract_images_from_conversation(). """ diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index f9e71223a0..9ae8f14715 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -1,7 +1,7 @@ # Fleet VL/CUA GRPO Training via SkyPilot - Qwen3.5-9B (Vision-Language) # Usage: sky launch tasks/openenv-fleet-grpo-vl.yaml --env FLEET_API_KEY= --env WANDB_API_KEY= # -# VL (Vision-Language) training for computer_use environments with screenshots. +# VL (Vision-Language) training for browser_use environments with screenshots. # Based on working config from SkyRL PR #288 (feat/vl-support-clean). # # Model: Qwen/Qwen3.5-9B (9B params, natively multimodal, GatedDeltaNet) @@ -39,10 +39,10 @@ envs: FLEET_API_KEY: "" LOGGER: "wandb" INFERENCE_BACKEND: "vllm" - DATA_VERSION: "v52" + DATA_VERSION: "v6" ENV_KEYS: "" DIFFICULTY: "" - MODALITY: "computer_use" + MODALITY: "browser_use" MAX_TURNS: 50 MAX_INPUT_LENGTH: 96000 MAX_GENERATE_LENGTH: 4096 From e9ebbef67cf7f8387c6e03964b3c9c252ec08f47 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 11 Apr 2026 22:35:22 -0700 Subject: [PATCH 094/121] fix: allow browser_use modality in fleet-common-setup.sh validation Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-common-setup.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-common-setup.sh b/scripts/fleet-common-setup.sh index 37fb55ecc8..fe07c020fa 100755 --- a/scripts/fleet-common-setup.sh +++ b/scripts/fleet-common-setup.sh @@ -64,8 +64,8 @@ fi if [ -z "${AWS_ACCESS_KEY_ID:-}" ] || [ -z "${AWS_SECRET_ACCESS_KEY:-}" ]; then echo "ERROR: AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are required for S3 dataset download"; exit 1 fi -if [ "${MODALITY:-}" != "tool_use" ] && [ "${MODALITY:-}" != "computer_use" ]; then - echo "ERROR: MODALITY must be 'tool_use' or 'computer_use', got: ${MODALITY:-unset}"; exit 1 +if [ "${MODALITY:-}" != "tool_use" ] && [ "${MODALITY:-}" != "computer_use" ] && [ "${MODALITY:-}" != "browser_use" ]; then + echo "ERROR: MODALITY must be 'tool_use', 'computer_use', or 'browser_use', got: ${MODALITY:-unset}"; exit 1 fi echo "Environment validation passed" From 1ff40659d1b73c3329b81b4715c92b94345e2050 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 12 Apr 2026 00:29:12 -0700 Subject: [PATCH 095/121] fix: add 900s trajectory timeout to VL training Prevents stuck trajectories from hanging the entire training run indefinitely. The previous run hung for 35+ min on a single eval trajectory with no timeout configured. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index d284a91753..c42d61795e 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -82,6 +82,7 @@ bash scripts/fleet-common-run.sh \ generator.n_samples_per_prompt=4 \ generator.eval_n_samples_per_prompt=3 \ generator.gpu_memory_utilization=0.80 \ + generator.trajectory_timeout_seconds=900 \ trainer.logger="$LOGGER" \ trainer.project_name="fleet-browser-use-grpo" \ trainer.run_name="fleet_qwen35_${MODALITY}_${RUN_ID:-$(head -c 4 /dev/urandom | xxd -p)}" \ From 310982931c8d7b0a97694fece6bef3e5f808c583 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 12 Apr 2026 02:19:42 -0700 Subject: [PATCH 096/121] 35b: use triton GDN prefill to avoid FlashInfer JIT hang FlashInfer GDN prefill kernel JIT compilation hangs silently on GCP DL images (CUDA 12.8 + driver 570). vLLM logs suggest `--gdn-prefill-backend triton` as workaround. Pass via engine_init_kwargs. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-35b-run.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index bccd3220f9..c4fcd6b466 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -45,6 +45,7 @@ bash scripts/fleet-common-run.sh \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ + +generator.inference_engine.engine_init_kwargs.gdn_prefill_backend=triton \ generator.inference_engine_tensor_parallel_size=2 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=8 \ From af81a439282cf4103e1763f22b17e8398ebc3c88 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 12 Apr 2026 12:26:15 -0700 Subject: [PATCH 097/121] Revert "35b: use triton GDN prefill to avoid FlashInfer JIT hang" This reverts commit 310982931c8d7b0a97694fece6bef3e5f808c583. --- scripts/fleet-35b-run.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index c4fcd6b466..bccd3220f9 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -45,7 +45,6 @@ bash scripts/fleet-common-run.sh \ trainer.loss_chunk_size=4096 \ trainer.use_sample_packing=false \ +generator.chat_template_kwargs='{enable_thinking:true}' \ - +generator.inference_engine.engine_init_kwargs.gdn_prefill_backend=triton \ generator.inference_engine_tensor_parallel_size=2 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=8 \ From 8a87a254be2cb1ce2e6f1491a74ad60a333108e6 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 13 Apr 2026 15:33:16 -0700 Subject: [PATCH 098/121] 35b: enable eval_before_train for step 0 baseline Need step 0 eval to measure training improvement vs base model. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index bccd3220f9..f26e686194 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -48,7 +48,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=2 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=8 \ - trainer.eval_before_train=false \ + trainer.eval_before_train=true \ trainer.eval_interval=20 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=16 \ From 0dc943e6272dcebe117cc9353987c907c8771cd4 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 13 Apr 2026 17:27:02 -0700 Subject: [PATCH 099/121] fix: CLAUDE.md primary branch is main, not fleet/all Co-Authored-By: Claude Opus 4.6 (1M context) --- CLAUDE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CLAUDE.md b/CLAUDE.md index 81b86d7f81..5ab9ef4a90 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -39,4 +39,4 @@ Report binary variance reward count (how many tasks got `reward >= 1.0`) separat ## Branch -Primary development branch: `fleet/all` +Primary development branch: `main` From 9b7fc1fa669a24466943e067818a4fa4fde233e9 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 13 Apr 2026 17:38:41 -0700 Subject: [PATCH 100/121] fix: wire S3 upload for eval results after every eval upload_eval_results_to_s3 was defined but never called from the trainer. Eval results were dumped to local disk only and lost when clusters terminated. Now uploads to s3://skyrl-trajectories/evals/ after every eval (both eval_before_train and periodic). Co-Authored-By: Claude Opus 4.6 (1M context) --- skyrl/train/trainer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 0c0b127d1d..e3bb0d3028 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -173,6 +173,22 @@ async def eval(self) -> Dict[str, float]: global_step=self.global_step, tokenizer=self.tokenizer, ) + + # Upload eval results to S3 + if self.cfg.trainer.dump_eval_results: + try: + from integrations.fleet.s3_checkpoints import upload_eval_results_to_s3 + + step_suffix = "eval_only" if self.global_step is None else f"global_step_{self.global_step}_evals" + local_dir = os.path.join(self.cfg.trainer.export_path, "dumped_evals", step_suffix) + upload_eval_results_to_s3( + local_dir=local_dir, + run_name=self.cfg.trainer.run_name, + global_step=self.global_step, + ) + except Exception as e: + logger.warning(f"Failed to upload eval results to S3: {e}") + return eval_metrics async def train(self): From 5a84fde3b822b97763efdcbb4577264fbe65d7a8 Mon Sep 17 00:00:00 2001 From: Deniz Date: Tue, 14 Apr 2026 14:31:49 -0700 Subject: [PATCH 101/121] 35b: eval_interval=10 (was 20) Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-35b-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-35b-run.sh b/scripts/fleet-35b-run.sh index f26e686194..0192acc262 100755 --- a/scripts/fleet-35b-run.sh +++ b/scripts/fleet-35b-run.sh @@ -49,7 +49,7 @@ bash scripts/fleet-common-run.sh \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=8 \ trainer.eval_before_train=true \ - trainer.eval_interval=20 \ + trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=16 \ trainer.use_hybrid_env_sampling=true \ From e95a47ba59d689d967c53d7d8c94e10417de8b7e Mon Sep 17 00:00:00 2001 From: Sumiran Date: Thu, 16 Apr 2026 20:22:12 -0700 Subject: [PATCH 102/121] feat: add checkpoint broadcast to workers for multi-node resume --- integrations/fleet/entrypoints/main_fleet.py | 4 + integrations/fleet/s3_checkpoints.py | 80 ++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/integrations/fleet/entrypoints/main_fleet.py b/integrations/fleet/entrypoints/main_fleet.py index a2a1c1b2ac..940b24e1c4 100644 --- a/integrations/fleet/entrypoints/main_fleet.py +++ b/integrations/fleet/entrypoints/main_fleet.py @@ -74,6 +74,10 @@ def run(self): project_name=project_name, model_name=model_name, ) + + # Broadcast checkpoint to worker nodes (FSDP requires shards on every node) + from integrations.fleet.s3_checkpoints import broadcast_checkpoint_to_workers + broadcast_checkpoint_to_workers(ckpt_path) except Exception as e: logger.warning(f"Failed to download checkpoint from S3: {e}") diff --git a/integrations/fleet/s3_checkpoints.py b/integrations/fleet/s3_checkpoints.py index 441da3e002..760ae0bab5 100644 --- a/integrations/fleet/s3_checkpoints.py +++ b/integrations/fleet/s3_checkpoints.py @@ -242,6 +242,86 @@ def save_checkpoints_with_cleanup(): return trainer +def broadcast_checkpoint_to_workers(ckpt_path: str) -> None: + """Broadcast checkpoint from head node to all worker nodes via rsync. + + FSDP requires checkpoint shards on every node. The S3 download only runs + on the head node, so we rsync the checkpoint directory to all workers. + + Discovers worker IPs from SKYPILOT_NODE_IPS (shell env) or Ray cluster + nodes (when running inside a Ray task). No-op on single-node. + """ + import subprocess + import socket + + # Try SKYPILOT_NODE_IPS first (set by SkyPilot run script) + node_ips_str = os.environ.get("SKYPILOT_NODE_IPS", "").strip() + if node_ips_str: + node_ips = [ip.strip() for ip in node_ips_str.split("\n") if ip.strip()] + else: + # Fall back to Ray cluster node discovery + try: + import ray + nodes = ray.nodes() + node_ips = sorted(set( + n["NodeManagerAddress"] for n in nodes + if n.get("Alive", False) + )) + logger.info(f"Discovered {len(node_ips)} nodes from Ray cluster") + except Exception as e: + logger.warning(f"Could not discover nodes: {e}") + return + + if len(node_ips) <= 1: + return # single node, nothing to broadcast + + # Head IP is the current node + head_ip = socket.gethostbyname(socket.gethostname()) + worker_ips = [ip for ip in node_ips if ip != head_ip] + + if not worker_ips: + # Try: head is first in the list + worker_ips = node_ips[1:] + + if not worker_ips: + logger.info("No worker nodes found, skipping checkpoint broadcast") + return + + # Find SSH key — SkyPilot uses ~/.ssh/sky-cluster-key on provisioned VMs + ssh_key = None + for key_path in ["~/.ssh/sky-cluster-key", "~/.ssh/sky-key", "~/.ssh/id_rsa"]: + expanded = os.path.expanduser(key_path) + if os.path.exists(expanded): + ssh_key = expanded + break + ssh_cmd = f"ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -i {ssh_key}" if ssh_key else "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30" + + for worker_ip in worker_ips: + logger.info(f"Broadcasting checkpoint to worker {worker_ip} (ssh key: {ssh_key})...") + try: + # Create parent directory on worker (rsync can't create it) + subprocess.run( + ["ssh"] + ssh_cmd.split()[1:] + [f"gcpuser@{worker_ip}", f"mkdir -p {ckpt_path}"], + check=True, + timeout=30, + ) + subprocess.run( + [ + "rsync", "-az", + "-e", ssh_cmd, + f"{ckpt_path}/", + f"gcpuser@{worker_ip}:{ckpt_path}/", + ], + check=True, + timeout=600, # 10 min max for ~140GB checkpoint + ) + logger.info(f"Checkpoint broadcast to {worker_ip} complete") + except subprocess.TimeoutExpired: + logger.warning(f"Checkpoint broadcast to {worker_ip} timed out") + except subprocess.CalledProcessError as e: + logger.warning(f"Checkpoint broadcast to {worker_ip} failed: {e}") + + def download_checkpoint_from_s3( ckpt_path: str, run_name: str, From d643e4194995421fbfaf3638b8616e242e0fa82a Mon Sep 17 00:00:00 2001 From: Sumiran Date: Fri, 17 Apr 2026 00:26:40 -0700 Subject: [PATCH 103/121] fix: gather checkpoint shards from workers before S3 upload --- integrations/fleet/s3_checkpoints.py | 63 ++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/integrations/fleet/s3_checkpoints.py b/integrations/fleet/s3_checkpoints.py index 760ae0bab5..63dd3bba85 100644 --- a/integrations/fleet/s3_checkpoints.py +++ b/integrations/fleet/s3_checkpoints.py @@ -63,9 +63,72 @@ def __init__( self._pending: set = set() self._lock = threading.Lock() + def _gather_from_workers(self, local_dir: str) -> None: + """Gather checkpoint shards from worker nodes before S3 upload. + + FSDP saves each rank's shards locally on its node. The head has ranks 0-N, + workers have ranks N+1-M. We rsync worker shards to the head so the S3 + upload gets all shards. + """ + import subprocess + import socket + + node_ips_str = os.environ.get("SKYPILOT_NODE_IPS", "").strip() + if node_ips_str: + node_ips = [ip.strip() for ip in node_ips_str.split("\n") if ip.strip()] + else: + try: + import ray + nodes = ray.nodes() + node_ips = sorted(set( + n["NodeManagerAddress"] for n in nodes + if n.get("Alive", False) + )) + except Exception: + return + + if len(node_ips) <= 1: + return + + head_ip = socket.gethostbyname(socket.gethostname()) + worker_ips = [ip for ip in node_ips if ip != head_ip] + if not worker_ips: + worker_ips = node_ips[1:] + if not worker_ips: + return + + ssh_key = None + for key_path in ["~/.ssh/sky-cluster-key", "~/.ssh/sky-key", "~/.ssh/id_rsa"]: + expanded = os.path.expanduser(key_path) + if os.path.exists(expanded): + ssh_key = expanded + break + ssh_cmd = f"ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -i {ssh_key}" if ssh_key else "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30" + + for worker_ip in worker_ips: + logger.info(f"Gathering checkpoint shards from worker {worker_ip}...") + try: + subprocess.run( + [ + "rsync", "-az", + "-e", ssh_cmd, + f"gcpuser@{worker_ip}:{local_dir}/", + f"{local_dir}/", + ], + check=True, + timeout=600, + ) + logger.info(f"Gathered shards from {worker_ip}") + except subprocess.TimeoutExpired: + logger.warning(f"Gathering from {worker_ip} timed out") + except subprocess.CalledProcessError as e: + logger.warning(f"Gathering from {worker_ip} failed: {e}") + def _upload_sync(self, local_dir: str) -> bool: """Synchronous upload that runs in thread pool.""" try: + # Gather shards from worker nodes before uploading + self._gather_from_workers(local_dir) import boto3 from botocore.config import Config from boto3.s3.transfer import TransferConfig From 9f7137e24a72abb6e407bf9f7359a2c8e368a306 Mon Sep 17 00:00:00 2001 From: Sumiran Date: Fri, 17 Apr 2026 13:42:59 -0700 Subject: [PATCH 104/121] fix: increase broadcast timeout to 30min for large checkpoints --- integrations/fleet/s3_checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/fleet/s3_checkpoints.py b/integrations/fleet/s3_checkpoints.py index 63dd3bba85..785cca5b3d 100644 --- a/integrations/fleet/s3_checkpoints.py +++ b/integrations/fleet/s3_checkpoints.py @@ -376,7 +376,7 @@ def broadcast_checkpoint_to_workers(ckpt_path: str) -> None: f"gcpuser@{worker_ip}:{ckpt_path}/", ], check=True, - timeout=600, # 10 min max for ~140GB checkpoint + timeout=1800, # 30 min for large checkpoints (model + optimizer can be 300GB+) ) logger.info(f"Checkpoint broadcast to {worker_ip} complete") except subprocess.TimeoutExpired: From d1b9d8701fba95968b6a9146d092e1fa8868d406 Mon Sep 17 00:00:00 2001 From: Sumiran Date: Fri, 17 Apr 2026 16:04:16 -0700 Subject: [PATCH 105/121] fix: dynamic rsync timeout based on checkpoint size --- integrations/fleet/s3_checkpoints.py | 44 ++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/integrations/fleet/s3_checkpoints.py b/integrations/fleet/s3_checkpoints.py index 785cca5b3d..f6e75f8a7c 100644 --- a/integrations/fleet/s3_checkpoints.py +++ b/integrations/fleet/s3_checkpoints.py @@ -105,8 +105,10 @@ def _gather_from_workers(self, local_dir: str) -> None: break ssh_cmd = f"ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -i {ssh_key}" if ssh_key else "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30" + timeout = _estimate_rsync_timeout(local_dir) + for worker_ip in worker_ips: - logger.info(f"Gathering checkpoint shards from worker {worker_ip}...") + logger.info(f"Gathering checkpoint shards from worker {worker_ip} (timeout={timeout}s)...") try: subprocess.run( [ @@ -116,11 +118,11 @@ def _gather_from_workers(self, local_dir: str) -> None: f"{local_dir}/", ], check=True, - timeout=600, + timeout=timeout, ) logger.info(f"Gathered shards from {worker_ip}") except subprocess.TimeoutExpired: - logger.warning(f"Gathering from {worker_ip} timed out") + logger.warning(f"Gathering from {worker_ip} timed out ({timeout}s)") except subprocess.CalledProcessError as e: logger.warning(f"Gathering from {worker_ip} failed: {e}") @@ -305,7 +307,30 @@ def save_checkpoints_with_cleanup(): return trainer -def broadcast_checkpoint_to_workers(ckpt_path: str) -> None: +def _estimate_rsync_timeout(path: str, min_timeout: int = 300) -> int: + """Estimate rsync timeout based on directory size. + + Assumes ~100MB/s conservative transfer speed + 60s buffer. + + Args: + path: Directory to measure. + min_timeout: Minimum timeout in seconds (default 5 min). + + Returns: + Timeout in seconds. + """ + try: + total_size = sum( + f.stat().st_size for f in Path(path).rglob("*") if f.is_file() + ) + timeout = max(min_timeout, int(total_size / (100 * 1024 * 1024)) + 60) + logger.info(f"Estimated rsync timeout for {total_size / 1e9:.1f}GB: {timeout}s") + return timeout + except Exception: + return min_timeout + + +def broadcast_checkpoint_to_workers(ckpt_path: str, timeout: Optional[int] = None) -> None: """Broadcast checkpoint from head node to all worker nodes via rsync. FSDP requires checkpoint shards on every node. The S3 download only runs @@ -313,6 +338,10 @@ def broadcast_checkpoint_to_workers(ckpt_path: str) -> None: Discovers worker IPs from SKYPILOT_NODE_IPS (shell env) or Ray cluster nodes (when running inside a Ray task). No-op on single-node. + + Args: + ckpt_path: Local checkpoint directory to broadcast. + timeout: Rsync timeout in seconds. If None, auto-calculated from checkpoint size. """ import subprocess import socket @@ -359,8 +388,11 @@ def broadcast_checkpoint_to_workers(ckpt_path: str) -> None: break ssh_cmd = f"ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -i {ssh_key}" if ssh_key else "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30" + if timeout is None: + timeout = _estimate_rsync_timeout(ckpt_path) + for worker_ip in worker_ips: - logger.info(f"Broadcasting checkpoint to worker {worker_ip} (ssh key: {ssh_key})...") + logger.info(f"Broadcasting checkpoint to worker {worker_ip} (ssh key: {ssh_key}, timeout={timeout}s)...") try: # Create parent directory on worker (rsync can't create it) subprocess.run( @@ -376,7 +408,7 @@ def broadcast_checkpoint_to_workers(ckpt_path: str) -> None: f"gcpuser@{worker_ip}:{ckpt_path}/", ], check=True, - timeout=1800, # 30 min for large checkpoints (model + optimizer can be 300GB+) + timeout=timeout, ) logger.info(f"Checkpoint broadcast to {worker_ip} complete") except subprocess.TimeoutExpired: From c7631c720055f3709fc63e1f9c3d45cf16da79de Mon Sep 17 00:00:00 2001 From: Deniz Date: Fri, 17 Apr 2026 23:03:38 -0700 Subject: [PATCH 106/121] VL v1: lr 5e-7, max_turns 64, eval_before_train false MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - lr 1e-6 → 5e-7: prevent entropy collapse (v0 collapsed at step 3) - max_turns 50 → 64: more headroom for browser workflows - eval_before_train false: skip 11h initial eval (have step 0 baseline) Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 6 +++--- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index c42d61795e..df594ac2d1 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -20,7 +20,7 @@ export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" -export MAX_TURNS="${MAX_TURNS:-50}" +export MAX_TURNS="${MAX_TURNS:-64}" export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" @@ -53,7 +53,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ - trainer.eval_before_train=true \ + trainer.eval_before_train=false \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=16 \ @@ -70,7 +70,7 @@ bash scripts/fleet-common-run.sh \ generator.sampling_params.top_p=0.95 \ 'generator.sampling_params.stop=[""]' \ 'generator.eval_sampling_params.stop=[""]' \ - trainer.policy.optimizer_config.lr=1.0e-6 \ + trainer.policy.optimizer_config.lr=5.0e-7 \ trainer.algorithm.use_kl_loss=true \ generator.max_turns=$MAX_TURNS \ generator.backend=$INFERENCE_BACKEND \ diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 9ae8f14715..85c4c42552 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -43,7 +43,7 @@ envs: ENV_KEYS: "" DIFFICULTY: "" MODALITY: "browser_use" - MAX_TURNS: 50 + MAX_TURNS: 64 MAX_INPUT_LENGTH: 96000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 From 6a5a81ad1580a6e630c76462cf180849477ea76a Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 18 Apr 2026 13:34:47 -0700 Subject: [PATCH 107/121] =?UTF-8?q?VL=20v1.1:=20max=5Finput=5Flength=2096K?= =?UTF-8?q?=20=E2=86=92=2072K=20(fix=20NaN=20gradients)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Steps with ~95K response length produce grad_norm=NaN with SDPA, causing entropy collapse on the next step. Reducing to 72K matches the stable 35B parity run. Sequences truncate earlier, avoiding the NaN gradient region. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index df594ac2d1..f9a17e0032 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -21,7 +21,7 @@ export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" export MAX_TURNS="${MAX_TURNS:-64}" -export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" export DIFFICULTY="${DIFFICULTY:-}" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 85c4c42552..56160a4c35 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -44,7 +44,7 @@ envs: DIFFICULTY: "" MODALITY: "browser_use" MAX_TURNS: 64 - MAX_INPUT_LENGTH: 96000 + MAX_INPUT_LENGTH: 72000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 RUN_ID: "" From 179b23c577ff1553f0550c64b385b54d9595ab08 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 18 Apr 2026 13:49:44 -0700 Subject: [PATCH 108/121] =?UTF-8?q?Revert=20"VL=20v1.1:=20max=5Finput=5Fle?= =?UTF-8?q?ngth=2096K=20=E2=86=92=2072K=20(fix=20NaN=20gradients)"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 6a5a81ad1580a6e630c76462cf180849477ea76a. --- scripts/fleet-vl-run.sh | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index f9a17e0032..df594ac2d1 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -21,7 +21,7 @@ export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" export MAX_TURNS="${MAX_TURNS:-64}" -export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" export DIFFICULTY="${DIFFICULTY:-}" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 56160a4c35..85c4c42552 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -44,7 +44,7 @@ envs: DIFFICULTY: "" MODALITY: "browser_use" MAX_TURNS: 64 - MAX_INPUT_LENGTH: 72000 + MAX_INPUT_LENGTH: 96000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 RUN_ID: "" From 3360624cb55bed2fe04987c96c53933aa53c7587 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 18 Apr 2026 14:24:19 -0700 Subject: [PATCH 109/121] =?UTF-8?q?VL=20v2:=20max=5Finput=5Flength=2096K?= =?UTF-8?q?=20=E2=86=92=2080K=20(fix=20NaN=20gradients)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit grad_norm=NaN on every step with padded response_length > 90K. Capping at 80K to stay below the NaN threshold while keeping more context than the 72K 35B config. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index df594ac2d1..a761ae8457 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -21,7 +21,7 @@ export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" export MAX_TURNS="${MAX_TURNS:-64}" -export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-96000}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-80000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" export DIFFICULTY="${DIFFICULTY:-}" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 85c4c42552..95262eb8d7 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -44,7 +44,7 @@ envs: DIFFICULTY: "" MODALITY: "browser_use" MAX_TURNS: 64 - MAX_INPUT_LENGTH: 96000 + MAX_INPUT_LENGTH: 80000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 RUN_ID: "" From 27182513383481be4dab3b643ec53bf631558835 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 18 Apr 2026 21:23:33 -0700 Subject: [PATCH 110/121] VL v3: max_input_length 64K, zero_variance_filter=true MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 80K still produced NaN grad_norm at 74K response_length. Threshold is between 63K (fine) and 74K (NaN). Dropping to 64K. zero_variance_filter=true drops prompts where all 4 rollouts get same reward (no GRPO learning signal). With shorter context, more trajectories will truncate → more zero-reward prompts → filter them. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 3 ++- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index a761ae8457..de7591b81d 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -21,7 +21,7 @@ export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" export MAX_TURNS="${MAX_TURNS:-64}" -export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-80000}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-64000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" export DIFFICULTY="${DIFFICULTY:-}" @@ -72,6 +72,7 @@ bash scripts/fleet-common-run.sh \ 'generator.eval_sampling_params.stop=[""]' \ trainer.policy.optimizer_config.lr=5.0e-7 \ trainer.algorithm.use_kl_loss=true \ + trainer.algorithm.zero_variance_filter=true \ generator.max_turns=$MAX_TURNS \ generator.backend=$INFERENCE_BACKEND \ generator.run_engines_locally=true \ diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 95262eb8d7..ee07ac9dc0 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -44,7 +44,7 @@ envs: DIFFICULTY: "" MODALITY: "browser_use" MAX_TURNS: 64 - MAX_INPUT_LENGTH: 80000 + MAX_INPUT_LENGTH: 64000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 RUN_ID: "" From 3b9a85e8954547795395755ee0321a38cf76db0a Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 25 Apr 2026 09:40:22 -0700 Subject: [PATCH 111/121] feat: port HybridEnvSampler from SkyRL-archived MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ensures minimum representation from each environment per batch. With 26 envs and batch_size=16, each batch gets at least 1 sample from every env (if min_samples_per_env=1 and enough envs fit). Remaining slots filled proportionally by dataset size. Previously the config fields existed but nothing read them — sampling was purely proportional, so small envs (rops-mail 93 tasks) got zero samples while large zero-reward envs (zillow 1000, ticketmaster 1000) dominated batches. Ported from fleet-ai/SkyRL-archived skyrl_train/utils/trainer_utils.py Co-Authored-By: Claude Opus 4.6 (1M context) --- skyrl/train/utils/trainer_utils.py | 149 ++++++++++++++++++++++++++--- 1 file changed, 136 insertions(+), 13 deletions(-) diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index aaa33fba7e..ddb2482bf1 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -804,6 +804,102 @@ def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses: ) +class HybridEnvSampler(torch.utils.data.Sampler): + """Ensures minimum representation from each environment per batch. + + Prevents batches dominated by large envs (zillow 1000 tasks) while small + envs (rops-mail 93 tasks) get zero samples. Each batch gets at least + min_samples_per_env from every env, remaining slots filled proportionally. + + Ported from fleet-ai/SkyRL-archived. + """ + + def __init__(self, dataset, batch_size, min_samples_per_env=1, generator=None, drop_last=True): + self.dataset = dataset + self.batch_size = batch_size + self.min_samples_per_env = min_samples_per_env + self.generator = generator + self.drop_last = drop_last + + self.env_to_indices: Dict[str, List[int]] = defaultdict(list) + for idx in range(len(dataset)): + row = dataset.dataframe[idx] + group = row.get("data_source") or row.get(dataset.env_class_key, "unknown") + self.env_to_indices[group].append(idx) + + self.env_classes = list(self.env_to_indices.keys()) + self.num_envs = len(self.env_classes) + + min_required = self.num_envs * min_samples_per_env + if min_required > batch_size: + logger.warning( + f"HybridEnvSampler: {self.num_envs} envs × {min_samples_per_env} = {min_required} " + f"> batch_size {batch_size}. Reducing min_samples_per_env." + ) + self.min_samples_per_env = max(1, batch_size // self.num_envs) + + total_samples = len(dataset) + self.env_weights = {env: len(indices) / total_samples for env, indices in self.env_to_indices.items()} + + logger.info(f"HybridEnvSampler: {self.num_envs} envs, batch_size={batch_size}, min_per_env={self.min_samples_per_env}") + for env, indices in sorted(self.env_to_indices.items()): + logger.info(f" {env}: {len(indices)} samples ({self.env_weights[env]*100:.1f}%)") + + def __iter__(self): + env_indices_shuffled = {} + for env, indices in self.env_to_indices.items(): + shuffled = indices.copy() + perm = torch.randperm(len(shuffled), generator=self.generator).tolist() + env_indices_shuffled[env] = [shuffled[i] for i in perm] + + env_positions = {env: 0 for env in self.env_classes} + + min_batches_per_env = [len(indices) // self.min_samples_per_env for indices in self.env_to_indices.values()] + num_batches = min(min_batches_per_env) + total_samples = sum(len(indices) for indices in self.env_to_indices.values()) + num_batches = min(num_batches, total_samples // self.batch_size) + + for _ in range(num_batches): + batch_indices = [] + + for env in self.env_classes: + available = len(env_indices_shuffled[env]) - env_positions[env] + samples_to_take = min(self.min_samples_per_env, available) + for _ in range(samples_to_take): + batch_indices.append(env_indices_shuffled[env][env_positions[env]]) + env_positions[env] += 1 + + remaining = self.batch_size - len(batch_indices) + if remaining > 0: + available_by_env = {env: env_indices_shuffled[env][env_positions[env]:] for env in self.env_classes} + for _ in range(remaining): + envs_with_samples = [env for env, avail in available_by_env.items() if avail] + if not envs_with_samples: + break + weights = [self.env_weights[env] for env in envs_with_samples] + total_w = sum(weights) + weights = [w / total_w for w in weights] + rand_val = torch.rand(1, generator=self.generator).item() + cumsum = 0 + chosen = envs_with_samples[-1] + for env, w in zip(envs_with_samples, weights): + cumsum += w + if rand_val < cumsum: + chosen = env + break + batch_indices.append(available_by_env[chosen].pop(0)) + env_positions[chosen] += 1 + + perm = torch.randperm(len(batch_indices), generator=self.generator).tolist() + yield [batch_indices[i] for i in perm] + + def __len__(self): + min_batches_per_env = [len(indices) // self.min_samples_per_env for indices in self.env_to_indices.values()] + num_batches = min(min_batches_per_env) + total_samples = sum(len(indices) for indices in self.env_to_indices.values()) + return min(num_batches, total_samples // self.batch_size) + + def build_dataloader( cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False ) -> StatefulDataLoader: @@ -824,20 +920,47 @@ def build_dataloader( seeded_generator = torch.Generator() seeded_generator.manual_seed(cfg.trainer.seed) - dataloader = StatefulDataLoader( - dataset, - batch_size=batch_size if not is_fully_async else 1, - shuffle=True if is_train else False, - collate_fn=dataset.collate_fn, - # TODO(Charlie): debug why inference http endpoint is slow when num_workers is 8 - num_workers=0 if cfg.generator.inference_engine.enable_http_endpoint else 8, - drop_last=True if is_train else False, - generator=seeded_generator, - # NOTE (sumanthrh): We use ray and thus use `spawn` start method. - # forking within ray leads to undefined behaviour and often causes hard to debug - # memory leaks. See: https://docs.ray.io/en/latest/ray-core/patterns/fork-new-processes.html - multiprocessing_context="spawn" if not cfg.generator.inference_engine.enable_http_endpoint else None, + use_hybrid_sampling = ( + is_train + and not is_fully_async + and getattr(cfg.trainer, "use_hybrid_env_sampling", False) + and hasattr(dataset, "dataframe") + and hasattr(dataset, "env_class_key") ) + + num_workers = 0 if cfg.generator.inference_engine.enable_http_endpoint else 8 + mp_context = "spawn" if not cfg.generator.inference_engine.enable_http_endpoint else None + + if use_hybrid_sampling: + from torch.utils.data import DataLoader + + min_samples_per_env = getattr(cfg.trainer, "min_samples_per_env", 1) + sampler = HybridEnvSampler( + dataset=dataset, + batch_size=batch_size, + min_samples_per_env=min_samples_per_env, + generator=seeded_generator, + drop_last=True, + ) + dataloader = DataLoader( + dataset, + batch_sampler=sampler, + collate_fn=dataset.collate_fn, + num_workers=num_workers, + ) + logger.info(f"Using HybridEnvSampler with min_samples_per_env={min_samples_per_env}") + else: + dataloader = StatefulDataLoader( + dataset, + batch_size=batch_size if not is_fully_async else 1, + shuffle=True if is_train else False, + collate_fn=dataset.collate_fn, + num_workers=num_workers, + drop_last=True if is_train else False, + generator=seeded_generator, + multiprocessing_context=mp_context, + ) + if is_train: if not is_fully_async: logger.info(f"Total steps: {len(dataloader) * cfg.trainer.epochs}") From 7fc32ab6b3857091d87d7c271a14a4c3e16596d7 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sat, 25 Apr 2026 10:20:17 -0700 Subject: [PATCH 112/121] VL: 2-node, batch_size=50, min_samples_per_env=2 (#18) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - num_nodes: 1 → 2 (16 H200 GPUs, 16 inference engines) - train_batch_size: 16 → 50 (50 prompts × 4 samples = 200 trajectories/step) - min_samples_per_env: 1 → 2 (guarantees 2 prompts from each of 25 envs per batch) - policy_mini_batch_size: 16 → 50 With HybridEnvSampler, every batch covers all 25 envs with at least 2 samples each. Previously 11/25 envs were never sampled in 20 steps. Co-authored-by: Deniz Co-authored-by: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 6 +++--- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index de7591b81d..d9857c8b03 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -56,10 +56,10 @@ bash scripts/fleet-common-run.sh \ trainer.eval_before_train=false \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ - trainer.train_batch_size=16 \ + trainer.train_batch_size=50 \ trainer.use_hybrid_env_sampling=true \ - trainer.min_samples_per_env=1 \ - trainer.policy_mini_batch_size=16 \ + trainer.min_samples_per_env=2 \ + trainer.policy_mini_batch_size=50 \ trainer.micro_forward_batch_size_per_gpu=1 \ trainer.micro_train_batch_size_per_gpu=1 \ trainer.ckpt_interval=10 \ diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index ee07ac9dc0..54489bdb88 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -28,7 +28,7 @@ resources: - accelerators: H200-SXM:8 cloud: vast -num_nodes: 1 +num_nodes: 2 workdir: url: https://github.com/fleet-ai/SkyRL-v2.git From 1ab8ef437dff0652801bfc7f5505d3c703b2f569 Mon Sep 17 00:00:00 2001 From: Sumiran Date: Sat, 25 Apr 2026 17:12:33 -0700 Subject: [PATCH 113/121] feat: add Fleet eval-only entrypoint with S3 checkpoint resume Adds `integrations/fleet/entrypoints/main_eval.py` (FleetEvalExp), a sibling to FleetPPOExp that mirrors the S3 download / FSDP weight load / inference-engine sync path but skips the training loop. Calls `trainer.eval()` once; the trainer's existing dump_eval_results path handles S3 upload of the eval results. Why: until now, replaying a checkpoint to get an extra independent eval required launching `main_fleet` with `eval_before_train=true`, which then proceeded into the full training loop and burned ~50min of GPU per cluster after the eval was already done. The new entrypoint keeps the exact same eval contract (same cfg, same trainer.eval(), same S3 upload prefix), just without the trailing train loop. Also updates `scripts/fleet-eval-only-run.sh` to drive the new entrypoint, with RESUME_RUN_NAME / RESUME_CKPT_PATH / RESUME_MODE env vars (auto-defaults RESUME_MODE=latest when RESUME_RUN_NAME is set, none otherwise so the same script works for base-model evals). Adds unit tests covering all branches of `_load_policy_only` (NONE / LATEST without marker / LATEST with valid marker / FROM_PATH / FROM_PATH missing dir / FROM_PATH invalid dir name). Run: uv run --extra dev --extra skyrl-train pytest \ integrations/fleet/tests/test_main_eval.py Co-Authored-By: Claude Opus 4.7 (1M context) --- integrations/fleet/entrypoints/main_eval.py | 221 ++++++++++++++++++++ integrations/fleet/tests/test_main_eval.py | 177 ++++++++++++++++ scripts/fleet-eval-only-run.sh | 115 ++++++++++ 3 files changed, 513 insertions(+) create mode 100644 integrations/fleet/entrypoints/main_eval.py create mode 100644 integrations/fleet/tests/test_main_eval.py create mode 100644 scripts/fleet-eval-only-run.sh diff --git a/integrations/fleet/entrypoints/main_eval.py b/integrations/fleet/entrypoints/main_eval.py new file mode 100644 index 0000000000..d70fd7249a --- /dev/null +++ b/integrations/fleet/entrypoints/main_eval.py @@ -0,0 +1,221 @@ +""" +Fleet Task Eval-Only Entrypoint for SkyRL. + +Resumes a Fleet GRPO checkpoint from S3 (FSDP shards on every node), runs a +single evaluation pass over the eval dataset, and uploads the dumped eval +results to S3. No training loop, no optimizer state. + +Mirrors the resume + weight-sync path used by `main_fleet.py:FleetPPOExp.run()` +so the same FSDP checkpoints can be replayed against the same eval set on a +fresh cluster (e.g. for variance bars across seeds). + +Usage: + python -m integrations.fleet.entrypoints.main_eval \ + environment.env_class=fleet_task \ + environment.skyrl_gym.fleet_task.tasks_file=/path/to/tasks.json \ + data.val_data=['/path/to/validation.parquet'] \ + trainer.policy.model.path=Qwen/Qwen3.5-9B \ + trainer.run_name=my_eval_run \ + trainer.dump_eval_results=true + +Environment Variables for S3 Checkpoint Management: + AWS_ACCESS_KEY_ID: AWS access key + AWS_SECRET_ACCESS_KEY: AWS secret key + AWS_REGION: AWS region (default: us-east-1) + S3_CHECKPOINT_BUCKET: S3 bucket for FSDP checkpoints (default: skyrl-checkpoints) + S3_TRAJECTORY_BUCKET: S3 bucket for eval result uploads (default: skyrl-trajectories) + RESUME_RUN_NAME: Run name to resume from. If unset, evaluates the base + weights at trainer.policy.model.path with no FSDP load. +""" + +import asyncio +import logging +import os +import sys +from pathlib import Path + +import ray +from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.entrypoints.main_base import BasePPOExp +from skyrl.train.utils import validate_cfg +from skyrl.train.utils.utils import initialize_ray + +logger = logging.getLogger(__name__) + + +def _strip_hydra_prefixes(args: list[str]) -> list[str]: + """Strip Hydra ++ and + prefixes from CLI args. + + Matches `main_fleet.py`. `from_cli_overrides` rejects +/++ prefixed args, + but our run scripts use them for environment-specific config keys that + now exist in the dataclass — so we can strip the prefix safely. + """ + cleaned = [] + for arg in args: + if arg.startswith("++"): + cleaned.append(arg[2:]) + elif arg.startswith("+"): + cleaned.append(arg[1:]) + else: + cleaned.append(arg) + return cleaned + + +class FleetEvalExp(BasePPOExp): + """Fleet eval-only experiment with optional S3 checkpoint resume. + + Reuses the trainer's FSDP weight loading and inference-engine weight sync, + then calls `trainer.eval()` once. `trainer.eval()` already handles local + eval dump and S3 upload when `trainer.dump_eval_results=true`, so this + entrypoint just needs to wire up resume + run a single eval pass. + """ + + def get_train_dataset(self): + """No train dataset is needed for eval-only runs.""" + return None + + def run(self): + trainer = self._setup_trainer() + assert trainer.eval_dataloader is not None, ( + "FleetEvalExp requires an eval dataset. Set `data.val_data` " + "and `trainer.eval_interval > 0`." + ) + + # Optional S3 resume: download FSDP shards on this VM and broadcast + # to the rest of the cluster. Mirrors FleetPPOExp.run(). + resume_run_name = os.environ.get("RESUME_RUN_NAME", "") + if resume_run_name: + try: + from integrations.fleet.s3_checkpoints import ( + broadcast_checkpoint_to_workers, + download_checkpoint_from_s3, + ) + + ckpt_path = trainer.cfg.trainer.ckpt_path + model_path = getattr(trainer.cfg.trainer.policy.model, "path", "unknown-model") + model_name = Path(model_path).name + project_name = getattr(trainer.cfg.trainer, "project_name", "skyrl") + download_checkpoint_from_s3( + ckpt_path=ckpt_path, + run_name=resume_run_name, + project_name=project_name, + model_name=model_name, + ) + broadcast_checkpoint_to_workers(ckpt_path) + except Exception as e: + logger.warning(f"Failed to download checkpoint from S3: {e}") + + asyncio.run(self._run_eval(trainer)) + + async def _run_eval(self, trainer): + """Initialize weight sync, load policy weights, and run eval once.""" + trainer.init_weight_sync_state() + + # Load only the policy FSDP shards. We bypass `trainer.load_checkpoints()` + # because it also restores `train_dataloader.state_dict()`, which is None + # in eval-only mode. Optimizer / lr scheduler state are skipped too. + self._load_policy_only(trainer) + + # Push fresh weights to the inference engine for evaluation. + await trainer.dispatch.save_weights_for_sampler() + + # `trainer.eval()` runs the eval loop and uploads to S3 when + # `dump_eval_results=true`. The S3 prefix uses `trainer.global_step`, + # which `_load_policy_only` sets from the resumed checkpoint. + eval_metrics = await trainer.eval() + trainer.tracker.log(eval_metrics, step=trainer.global_step, commit=True) + trainer.tracker.finish() + logger.info(f"Eval-only metrics: {eval_metrics}") + + def _load_policy_only(self, trainer): + """Load only the policy FSDP shards from a `global_step_` directory. + + Resolves the checkpoint path the same way `trainer.load_checkpoints()` + does (LATEST via `latest_ckpt_global_step.txt`, or FROM_PATH via + `cfg.trainer.resume_path`), then calls `dispatch.load_checkpoint` + with optimizer / scheduler state disabled. Sets `trainer.global_step` + so downstream eval dumps and S3 uploads land under the correct step. + + TODO: This duplicates the path-resolution half of + `RayPPOTrainer.load_checkpoints()`. The reason for the duplication is + that `load_checkpoints()` unconditionally calls + `self.train_dataloader.load_state_dict(...)`, which crashes when + `train_dataloader is None` (eval-only). If trainer ever grows a + `skip_dataloader_state` / `policy_only` flag, drop this helper and + call `trainer.load_checkpoints(...)` directly. + """ + from skyrl.backends.skyrl_train.utils.io import io + from skyrl.train.utils.trainer_utils import ( + GLOBAL_STEP_PREFIX, + ResumeMode, + extract_step_from_path, + validate_consistency_for_latest_checkpoint, + ) + + if trainer.resume_mode == ResumeMode.NONE: + logger.info("resume_mode=none; evaluating base model weights") + return + + if trainer.resume_mode == ResumeMode.LATEST: + latest_file = os.path.join( + trainer.cfg.trainer.ckpt_path, "latest_ckpt_global_step.txt" + ) + if not io.exists(latest_file): + logger.warning( + "resume_mode=latest but no checkpoint found at " + f"{trainer.cfg.trainer.ckpt_path}; using base weights" + ) + return + with io.open_file(latest_file, "r") as f: + step = int(f.read().strip()) + ckpt_dir = os.path.join( + trainer.cfg.trainer.ckpt_path, f"{GLOBAL_STEP_PREFIX}{step}" + ) + validate_consistency_for_latest_checkpoint( + trainer.cfg.trainer.ckpt_path, + step, + ckpt_dir, + latest_file, + trainer.cfg.trainer.ckpt_interval, + ) + else: # ResumeMode.FROM_PATH + ckpt_dir = str(trainer.cfg.trainer.resume_path) + + if not io.exists(ckpt_dir): + raise FileNotFoundError(f"Checkpoint path not found: {ckpt_dir}") + + global_step = extract_step_from_path(Path(ckpt_dir)) + if global_step == -1: + raise ValueError(f"Checkpoint path is not a valid global_step dir: {ckpt_dir}") + trainer.global_step = global_step + + policy_ckpt_dir = os.path.join(ckpt_dir, "policy") + logger.info(f"Loading policy checkpoint from {policy_ckpt_dir} (step {global_step})") + trainer.dispatch.load_checkpoint( + "policy", + policy_ckpt_dir, + load_optimizer_states=False, + load_lr_scheduler_states=False, + ) + logger.info("Successfully loaded policy checkpoint for eval") + + +@ray.remote(num_cpus=1) +def skyrl_eval_entrypoint(cfg: SkyRLTrainConfig): + """Ray remote function that runs Fleet eval-only.""" + # fleet_task env is auto-registered by skyrl_gym.envs.__init__ + exp = FleetEvalExp(cfg) + exp.run() + + +def main() -> None: + """Main entry point for Fleet task eval-only.""" + args = _strip_hydra_prefixes(sys.argv[1:]) + cfg = SkyRLTrainConfig.from_cli_overrides(args) + validate_cfg(cfg) + initialize_ray(cfg) + ray.get(skyrl_eval_entrypoint.remote(cfg)) + + +if __name__ == "__main__": + main() diff --git a/integrations/fleet/tests/test_main_eval.py b/integrations/fleet/tests/test_main_eval.py new file mode 100644 index 0000000000..58a95660ce --- /dev/null +++ b/integrations/fleet/tests/test_main_eval.py @@ -0,0 +1,177 @@ +"""Unit tests for the Fleet eval-only entrypoint. + +uv run --extra dev --extra skyrl-train pytest integrations/fleet/tests/test_main_eval.py +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from integrations.fleet.entrypoints.main_eval import ( + FleetEvalExp, + _strip_hydra_prefixes, +) + + +# --------------------------------------------------------------------------- +# _strip_hydra_prefixes +# --------------------------------------------------------------------------- + + +def test_strip_hydra_prefixes_handles_all_three_arg_shapes(): + args = [ + "trainer.run_name=my_run", + "+trainer.eval_interval=1", + "++environment.skyrl_gym.fleet_task.tasks_file=/tmp/tasks.json", + ] + out = _strip_hydra_prefixes(args) + assert out == [ + "trainer.run_name=my_run", + "trainer.eval_interval=1", + "environment.skyrl_gym.fleet_task.tasks_file=/tmp/tasks.json", + ] + + +def test_strip_hydra_prefixes_empty(): + assert _strip_hydra_prefixes([]) == [] + + +def test_strip_hydra_prefixes_double_plus_takes_precedence_over_single(): + # "++" matches startswith("++") first, so it strips two characters, not one. + assert _strip_hydra_prefixes(["++key=value"]) == ["key=value"] + + +# --------------------------------------------------------------------------- +# FleetEvalExp.get_train_dataset +# --------------------------------------------------------------------------- + + +def test_get_train_dataset_returns_none(): + # Bypass __init__ so we don't pull in tokenizer / placement group. + exp = FleetEvalExp.__new__(FleetEvalExp) + assert exp.get_train_dataset() is None + + +# --------------------------------------------------------------------------- +# FleetEvalExp._load_policy_only — path resolution + dispatch wiring +# --------------------------------------------------------------------------- + + +def _make_trainer_mock(resume_mode_value: str, ckpt_path: str, resume_path: str = "") -> MagicMock: + """Build a minimal trainer mock for _load_policy_only tests. + + Mirrors the attribute shape `_load_policy_only` reads: trainer.resume_mode, + trainer.cfg.trainer.{ckpt_path, ckpt_interval, resume_path}, trainer.dispatch. + """ + from skyrl.train.utils.trainer_utils import ResumeMode + + trainer = MagicMock() + trainer.resume_mode = ResumeMode(resume_mode_value) + trainer.cfg = SimpleNamespace( + trainer=SimpleNamespace( + ckpt_path=ckpt_path, + ckpt_interval=10, + resume_path=resume_path, + ) + ) + trainer.global_step = 0 + return trainer + + +def _make_exp() -> FleetEvalExp: + """Create a FleetEvalExp bypassing __init__ (which loads a tokenizer).""" + return FleetEvalExp.__new__(FleetEvalExp) + + +def test_load_policy_only_resume_none_is_noop(): + exp = _make_exp() + trainer = _make_trainer_mock("none", ckpt_path="/tmp/does-not-matter") + + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() + assert trainer.global_step == 0 + + +def test_load_policy_only_latest_with_no_marker_file_is_noop(tmp_path): + exp = _make_exp() + trainer = _make_trainer_mock("latest", ckpt_path=str(tmp_path)) + # No latest_ckpt_global_step.txt written → fall through, no load. + + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() + assert trainer.global_step == 0 + + +def test_load_policy_only_latest_loads_policy_and_sets_global_step(tmp_path): + # Build a realistic checkpoint layout that the resolver expects. + ckpt_dir = tmp_path / "global_step_30" + (ckpt_dir / "policy").mkdir(parents=True) + (tmp_path / "latest_ckpt_global_step.txt").write_text("30") + + exp = _make_exp() + trainer = _make_trainer_mock("latest", ckpt_path=str(tmp_path)) + + # The consistency validator hits the filesystem in non-trivial ways; stub + # it out so the test stays focused on this method's contract. + with patch( + "skyrl.train.utils.trainer_utils.validate_consistency_for_latest_checkpoint" + ) as validator: + exp._load_policy_only(trainer) + + validator.assert_called_once() + trainer.dispatch.load_checkpoint.assert_called_once_with( + "policy", + str(ckpt_dir / "policy"), + load_optimizer_states=False, + load_lr_scheduler_states=False, + ) + assert trainer.global_step == 30 + + +def test_load_policy_only_from_path_loads_specified_checkpoint(tmp_path): + ckpt_dir = tmp_path / "global_step_42" + (ckpt_dir / "policy").mkdir(parents=True) + + exp = _make_exp() + trainer = _make_trainer_mock("from_path", ckpt_path=str(tmp_path), resume_path=str(ckpt_dir)) + + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_called_once_with( + "policy", + str(ckpt_dir / "policy"), + load_optimizer_states=False, + load_lr_scheduler_states=False, + ) + assert trainer.global_step == 42 + + +def test_load_policy_only_from_path_missing_dir_raises(tmp_path): + exp = _make_exp() + trainer = _make_trainer_mock( + "from_path", + ckpt_path=str(tmp_path), + resume_path=str(tmp_path / "global_step_99"), # never created + ) + + with pytest.raises(FileNotFoundError): + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() + + +def test_load_policy_only_from_path_invalid_dir_name_raises(tmp_path): + # extract_step_from_path returns -1 when the dir name has no global_step prefix. + bad_dir = tmp_path / "not_a_step_dir" + (bad_dir / "policy").mkdir(parents=True) + + exp = _make_exp() + trainer = _make_trainer_mock("from_path", ckpt_path=str(tmp_path), resume_path=str(bad_dir)) + + with pytest.raises(ValueError, match="not a valid global_step dir"): + exp._load_policy_only(trainer) + + trainer.dispatch.load_checkpoint.assert_not_called() diff --git a/scripts/fleet-eval-only-run.sh b/scripts/fleet-eval-only-run.sh new file mode 100644 index 0000000000..4d1a0b07d4 --- /dev/null +++ b/scripts/fleet-eval-only-run.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash +# Eval-only run on Fleet envs with optional S3 checkpoint resume. +# +# When RESUME_RUN_NAME is set, downloads the latest FSDP checkpoint from S3, +# broadcasts it to worker nodes, loads policy weights, and runs a single eval +# pass. Eval results are dumped locally and uploaded to S3. +# +# When RESUME_RUN_NAME is unset, evaluates the base model at trainer.policy.model.path. +# +# Required env vars: FLEET_API_KEY, WANDB_API_KEY +# Optional env vars: +# RESUME_RUN_NAME Run name to resume from (S3 prefix). Empty = base model eval. +# RESUME_CKPT_PATH Local checkpoint dir to download into. Default: $HOME/ckpts/eval_only +# AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY (required for S3 resume / upload) +# MODEL_PATH HF model repo or path. Default: Qwen/Qwen3.5-9B +# PROJECT_NAME W&B / S3 project prefix. Default: fleet-tool-use-grpo +# RUN_NAME W&B run name + S3 eval upload prefix. Default: fleet_eval_only__ +# EVAL_N_SAMPLES pass@K samples per prompt. Default: 8 +# +set -euo pipefail +cd "$(dirname "$0")/.." # cd to SkyRL root + +export LOGGER="${LOGGER:-wandb}" +export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" +export MODALITY="${MODALITY:-tool_use}" +export MAX_TURNS="${MAX_TURNS:-50}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-72000}" +export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" +export NUM_INFERENCE_ENGINES="${NUM_INFERENCE_ENGINES:-8}" +export EVAL_N_SAMPLES="${EVAL_N_SAMPLES:-8}" +export AWS_REGION="${AWS_REGION:-us-east-1}" +export S3_DATASET_BUCKET="${S3_DATASET_BUCKET:-fleet-internal-datasets}" +export S3_CHECKPOINT_BUCKET="${S3_CHECKPOINT_BUCKET:-skyrl-checkpoints}" +export S3_TRAJECTORY_BUCKET="${S3_TRAJECTORY_BUCKET:-skyrl-trajectories}" + +: "${FLEET_API_KEY:?Set FLEET_API_KEY before running}" +: "${WANDB_API_KEY:?Set WANDB_API_KEY before running}" +export OPENROUTER_API_KEY="${OPENROUTER_API_KEY:-}" + +MODEL_PATH="${MODEL_PATH:-Qwen/Qwen3.5-9B}" +PROJECT_NAME="${PROJECT_NAME:-fleet-tool-use-grpo}" +RUN_NAME="${RUN_NAME:-fleet_eval_only_${MODALITY}_pass_at_${EVAL_N_SAMPLES}}" +RESUME_CKPT_PATH="${RESUME_CKPT_PATH:-$HOME/ckpts/eval_only}" + +# resume_mode controls how main_eval picks the checkpoint inside RESUME_CKPT_PATH. +# latest = read latest_ckpt_global_step.txt (written by S3 download); none = base weights. +if [ -n "${RESUME_RUN_NAME:-}" ]; then + RESUME_MODE="${RESUME_MODE:-latest}" +else + RESUME_MODE="${RESUME_MODE:-none}" +fi +export RESUME_RUN_NAME="${RESUME_RUN_NAME:-}" + +DATA_ROOT="" +if [ -d "/workspace" ] && [ -w "/workspace" ]; then + DATA_ROOT="/workspace" +else + DATA_ROOT="$HOME" +fi + +EVAL_PARQUET="${DATA_ROOT}/data/fleet/${MODALITY}/validation.parquet" +TASKS_FILE="${DATA_ROOT}/data/fleet/tasks_${MODALITY}.json" + +echo "=== Fleet Eval-Only Run ===" +echo "Model: $MODEL_PATH" +echo "Project / Run: $PROJECT_NAME / $RUN_NAME" +echo "Resume run name: ${RESUME_RUN_NAME:-(none — base model eval)}" +echo "Resume mode: $RESUME_MODE" +echo "Ckpt path: $RESUME_CKPT_PATH" +echo "Eval data: $EVAL_PARQUET" +echo "Samples/prompt: $EVAL_N_SAMPLES" + +bash scripts/fleet-common-run.sh \ + --use-python-direct --cuda-env "$HOME/.cuda_env" \ + --set-ulimit --no-pytorch-alloc-conf \ + --entrypoint integrations.fleet.entrypoints.main_eval \ + --nccl-heartbeat 1800 -- \ + environment.skyrl_gym.fleet_task.ttl_seconds=900 \ + environment.skyrl_gym.fleet_task.partial_reward=true \ + environment.skyrl_gym.fleet_task.enable_hints=false \ + trainer.policy.model.path="$MODEL_PATH" \ + trainer.flash_attn=false \ + trainer.use_sample_packing=false \ + trainer.resume_mode="$RESUME_MODE" \ + trainer.ckpt_path="$RESUME_CKPT_PATH" \ + trainer.eval_batch_size=4 \ + trainer.eval_interval=1 \ + trainer.max_prompt_length=2048 \ + trainer.dump_eval_results=true \ + trainer.export_path="$HOME/exports" \ + generator.chat_template_kwargs='{enable_thinking:true}' \ + generator.inference_engine_tensor_parallel_size=1 \ + generator.max_input_length=$MAX_INPUT_LENGTH \ + generator.eval_sampling_params.max_generate_length=$MAX_GENERATE_LENGTH \ + generator.eval_sampling_params.temperature=0.9 \ + generator.eval_sampling_params.top_p=0.95 \ + 'generator.eval_sampling_params.stop=[""]' \ + generator.max_turns=$MAX_TURNS \ + generator.backend=$INFERENCE_BACKEND \ + generator.run_engines_locally=true \ + generator.weight_sync_backend=nccl \ + generator.async_engine=true \ + generator.batched=false \ + generator.use_conversation_multi_turn=true \ + generator.eval_n_samples_per_prompt=$EVAL_N_SAMPLES \ + generator.enforce_eager=false \ + generator.gpu_memory_utilization=0.65 \ + generator.inject_context_status=true \ + generator.context_warning_threshold=0.90 \ + trainer.logger="$LOGGER" \ + trainer.project_name="$PROJECT_NAME" \ + trainer.run_name="$RUN_NAME" \ + "data.val_data=['${EVAL_PARQUET}']" \ + "environment.skyrl_gym.fleet_task.tasks_file=${TASKS_FILE}" \ + "$@" From 0865c8a781fb53ebd7f1a14dfb839f6504670be2 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 26 Apr 2026 12:54:07 -0700 Subject: [PATCH 114/121] eval_before_train=true for checkpoint resume eval --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index d9857c8b03..cb993a2182 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -53,7 +53,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ - trainer.eval_before_train=false \ + trainer.eval_before_train=true \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=50 \ From 5dfd1983bd50aa1a577f046b397cfe77c18f8587 Mon Sep 17 00:00:00 2001 From: Deniz Date: Wed, 29 Apr 2026 15:26:05 -0700 Subject: [PATCH 115/121] Prioritize RunPod reserved H200s in SkyPilot task configs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch from any_of to ordered in all 4 task YAMLs. Ordering: RunPod reserved H200s → GKE/GCP → other providers. Co-Authored-By: Claude Opus 4.6 (1M context) --- tasks/openenv-fleet-grpo-qwen3_5-35b.yaml | 13 ++++++++----- tasks/openenv-fleet-grpo-vl.yaml | 9 ++++++--- tasks/task-gen-grpo-qwen3_5-35b.yaml | 9 ++++++--- tasks/task-gen-grpo-qwen3_5-9b.yaml | 9 ++++++--- 4 files changed, 26 insertions(+), 14 deletions(-) diff --git a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml index 48c74a56bb..c65312b3db 100644 --- a/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml +++ b/tasks/openenv-fleet-grpo-qwen3_5-35b.yaml @@ -13,17 +13,20 @@ resources: disk_size: 750 memory: 1500+ ports: 6479 - any_of: + ordered: + # RunPod reserved H200s (priority) - accelerators: H200:8 - cloud: gcp - use_spot: true - image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + cloud: runpod + # GKE fallback - accelerators: H200:8 cloud: kubernetes network_tier: best use_spot: true - accelerators: H200:8 - cloud: runpod + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + # Other providers - accelerators: H200:8 cloud: lambda - accelerators: H200:8 diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 54489bdb88..e9d073367d 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -14,13 +14,16 @@ name: fleet-vl-grpo-qwen3-5-9b resources: disk_size: 750 ports: 6479 - any_of: + ordered: + # RunPod reserved H200s (priority) + - accelerators: H200:8 + cloud: runpod + # GCP fallback - accelerators: H200:8 cloud: gcp use_spot: true image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 - - accelerators: H200:8 - cloud: runpod + # Other providers - accelerators: H200:8 cloud: lambda - accelerators: H200:8 diff --git a/tasks/task-gen-grpo-qwen3_5-35b.yaml b/tasks/task-gen-grpo-qwen3_5-35b.yaml index b52015672b..db2d53b8e8 100644 --- a/tasks/task-gen-grpo-qwen3_5-35b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-35b.yaml @@ -11,13 +11,16 @@ resources: disk_size: 750 memory: 1500+ ports: 6479 - any_of: + ordered: + # RunPod reserved H200s (priority) + - accelerators: H200:8 + cloud: runpod + # GCP fallback - accelerators: H200:8 cloud: gcp use_spot: true image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 - - accelerators: H200:8 - cloud: runpod + # Other providers - accelerators: H200:8 cloud: lambda diff --git a/tasks/task-gen-grpo-qwen3_5-9b.yaml b/tasks/task-gen-grpo-qwen3_5-9b.yaml index 40382beba9..7c3c1edbf5 100644 --- a/tasks/task-gen-grpo-qwen3_5-9b.yaml +++ b/tasks/task-gen-grpo-qwen3_5-9b.yaml @@ -10,13 +10,16 @@ resources: disk_size: 500 memory: 800+ ports: 6479 - any_of: + ordered: + # RunPod reserved H200s (priority) + - accelerators: H200:8 + cloud: runpod + # GCP fallback - accelerators: H200:8 cloud: gcp use_spot: true image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 - - accelerators: H200:8 - cloud: runpod + # Other providers - accelerators: H200:8 cloud: lambda From e64bc7e07d440ca38bab17be464c179b6ce1fa29 Mon Sep 17 00:00:00 2001 From: Deniz Date: Wed, 29 Apr 2026 19:25:55 -0700 Subject: [PATCH 116/121] VL: increase max_input_length to 80K for longer browser trajectories MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NaN grad_norm at long sequences is safely handled by the optimizer (skips step, zeros grads) — no weight corruption. 80K allows browser_use trajectories to complete without truncation for envs that need 50+ turns. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index cb993a2182..92e3dabe12 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -21,7 +21,7 @@ export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" export MAX_TURNS="${MAX_TURNS:-64}" -export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-64000}" +export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-80000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" export DIFFICULTY="${DIFFICULTY:-}" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index e9d073367d..4783b3225d 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -47,7 +47,7 @@ envs: DIFFICULTY: "" MODALITY: "browser_use" MAX_TURNS: 64 - MAX_INPUT_LENGTH: 64000 + MAX_INPUT_LENGTH: 80000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 RUN_ID: "" From 6ad8c76c464cedd883442b820aee7d441f778435 Mon Sep 17 00:00:00 2001 From: Deniz Date: Thu, 30 Apr 2026 21:46:20 -0700 Subject: [PATCH 117/121] =?UTF-8?q?VL:=20increase=20max=5Fturns=2064?= =?UTF-8?q?=E2=86=9280=20for=20browser-use=20turn=20limit=20ablation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit v6-v2 showed 70% of trajectories hit the 64-turn ceiling at step 1 while using only 28-70% of context. Turn limit, not context length, is the binding constraint. Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 2 +- tasks/openenv-fleet-grpo-vl.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index 92e3dabe12..b33dfcbd22 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -20,7 +20,7 @@ export INFERENCE_BACKEND="${INFERENCE_BACKEND:-vllm}" export DATA_VERSION="${DATA_VERSION:-v6}" export MODALITY="${MODALITY:-browser_use}" export NUM_EPOCHS="${NUM_EPOCHS:-10}" -export MAX_TURNS="${MAX_TURNS:-64}" +export MAX_TURNS="${MAX_TURNS:-80}" export MAX_INPUT_LENGTH="${MAX_INPUT_LENGTH:-80000}" export MAX_GENERATE_LENGTH="${MAX_GENERATE_LENGTH:-4096}" export ENV_KEYS="${ENV_KEYS:-}" diff --git a/tasks/openenv-fleet-grpo-vl.yaml b/tasks/openenv-fleet-grpo-vl.yaml index 4783b3225d..27d9152ae1 100644 --- a/tasks/openenv-fleet-grpo-vl.yaml +++ b/tasks/openenv-fleet-grpo-vl.yaml @@ -46,7 +46,7 @@ envs: ENV_KEYS: "" DIFFICULTY: "" MODALITY: "browser_use" - MAX_TURNS: 64 + MAX_TURNS: 80 MAX_INPUT_LENGTH: 80000 MAX_GENERATE_LENGTH: 4096 NUM_EPOCHS: 10 From 84b49de80a5cfb4cdf1c302b59befa1bf296e05d Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 3 May 2026 21:24:24 -0700 Subject: [PATCH 118/121] feat: save screenshots in trajectory dumps for VL training - Propagate accumulated_images from AgentLoopState through TrajectoryOutput to GeneratorOutput (multi_modal_data field was defined but never set) - Save PIL images as JPEG alongside trajectory JSONL in dumped_trajectories/global_step_N_images/ - Store image_paths and num_screenshots in trajectory entries - URLs are stored as-is (not downloaded during training) Co-Authored-By: Claude Opus 4.6 (1M context) --- skyrl/train/generators/skyrl_gym_generator.py | 7 +++ skyrl/train/utils/trainer_utils.py | 45 ++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/skyrl/train/generators/skyrl_gym_generator.py b/skyrl/train/generators/skyrl_gym_generator.py index 45d9b8e1ab..cbc9137349 100644 --- a/skyrl/train/generators/skyrl_gym_generator.py +++ b/skyrl/train/generators/skyrl_gym_generator.py @@ -651,6 +651,9 @@ async def agent_loop( rollout_logprobs=rollout_logprobs, env_metrics=env_metrics, rollout_expert_indices=rollout_expert_indices_out, + multi_modal_data={"images": agent_loop_state.accumulated_images} + if agent_loop_state.accumulated_images + else None, ) return agent_loop_output @@ -1297,6 +1300,9 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False # set loss mask to 0 if the stop reason is not "stop" loss_masks = apply_overlong_filtering(loss_masks, stop_reasons) + # Collect per-trajectory images (for dump_training_trajectories) + multi_modal_data_list = [output.multi_modal_data for output in all_outputs] + generator_output: GeneratorOutput = { "prompt_token_ids": prompt_token_ids, "response_ids": responses, @@ -1310,6 +1316,7 @@ async def generate(self, input_batch: GeneratorInput, disable_tqdm: bool = False "env_metrics": env_metrics, "is_last_step": is_last_step, "is_hinted": is_hinted, + "multi_modal_data": multi_modal_data_list, } return generator_output diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index ddb2482bf1..2a5cd81b26 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -308,10 +308,19 @@ def dump_training_trajectories( filename = traj_dir / f"global_step_{global_step}.jsonl" env_metrics_list = generator_output.get("env_metrics") or [] + multi_modal_data_list = generator_output.get("multi_modal_data") or [] rewards_list = generator_output["rewards"] stop_reasons = generator_output.get("stop_reasons") or [] ts = time.time() + # Save screenshots alongside JSONL if any trajectories have images + images_dir = traj_dir / f"global_step_{global_step}_images" + has_any_images = any( + mm and mm.get("images") for mm in multi_modal_data_list if isinstance(mm, dict) + ) + if has_any_images: + images_dir.mkdir(parents=True, exist_ok=True) + with open(filename, "w") as f: for i in range(len(generator_output["response_ids"])): env_m = env_metrics_list[i] if i < len(env_metrics_list) and env_metrics_list[i] else {} @@ -329,6 +338,28 @@ def dump_training_trajectories( stop_reason = stop_reasons[i] if i < len(stop_reasons) else "unknown" tokens = len(generator_output["response_ids"][i]) + # Save screenshots for this trajectory + image_paths = [] + mm_data = multi_modal_data_list[i] if i < len(multi_modal_data_list) else None + if isinstance(mm_data, dict) and mm_data.get("images"): + for j, img in enumerate(mm_data["images"]): + img_filename = f"traj_{i:03d}_img_{j:03d}.jpg" + img_path = images_dir / img_filename + try: + if hasattr(img, "save"): + # PIL Image + img.save(str(img_path), "JPEG", quality=85) + image_paths.append(str(img_path)) + elif isinstance(img, str) and img.startswith(("http://", "https://")): + # URL — store the URL, don't download during training + image_paths.append(img) + elif isinstance(img, bytes): + with open(img_path, "wb") as img_f: + img_f.write(img) + image_paths.append(str(img_path)) + except Exception as e: + logger.warning(f"Failed to save image {j} for trajectory {i}: {e}") + entry = { "step": global_step, "env_key": env_key, @@ -341,9 +372,21 @@ def dump_training_trajectories( "text": tokenizer.decode(generator_output["response_ids"][i]), "timestamp": ts, } + if image_paths: + entry["image_paths"] = image_paths + entry["num_screenshots"] = len(image_paths) f.write(json.dumps(entry, ensure_ascii=False) + "\n") - logger.info(f"Dumped {len(generator_output['response_ids'])} training trajectories to {filename}") + n_images = sum( + len(entry.get("images", [])) + for mm in multi_modal_data_list + if isinstance(mm, dict) + for entry in [mm] + ) + logger.info( + f"Dumped {len(generator_output['response_ids'])} training trajectories to {filename}" + + (f" ({n_images} screenshots saved)" if has_any_images else "") + ) return str(filename) From 646d5a92736ac9abc26d65d6ecab620f9a308277 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 3 May 2026 21:49:26 -0700 Subject: [PATCH 119/121] feat: save screenshots in eval trajectory dumps too Extends screenshot saving from training dumps to eval dumps (dump_per_dataset_eval_results). Eval JSONL entries now include image_paths and num_screenshots when VL images are available. Co-Authored-By: Claude Opus 4.6 (1M context) --- skyrl/train/utils/trainer_utils.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 2a5cd81b26..d82f4fc1c7 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -254,6 +254,15 @@ def dump_per_dataset_eval_results( # Prepare common data input_prompts = [tokenizer.decode(prompt) for prompt in concat_generator_outputs["prompt_token_ids"]] output_responses = [tokenizer.decode(response) for response in concat_generator_outputs["response_ids"]] + multi_modal_data_list = concat_generator_outputs.get("multi_modal_data") or [] + + # Save screenshots if any trajectories have images + images_dir = dump_dir_path / "images" + has_any_images = any( + mm and mm.get("images") for mm in multi_modal_data_list if isinstance(mm, dict) + ) + if has_any_images: + images_dir.mkdir(parents=True, exist_ok=True) # Group indices by data source data_source_indices = {} @@ -265,12 +274,36 @@ def dump_per_dataset_eval_results( data_source_indices[data_source].append(i) # Dump per-dataset files + total_images_saved = 0 for data_source, indices in data_source_indices.items(): sanitized_data_source = sanitize_data_source(data_source) filename = dump_dir_path / f"{sanitized_data_source}.jsonl" with open(filename, "w") as f: for i in indices: + # Save screenshots for this eval trajectory + image_paths = [] + mm_data = multi_modal_data_list[i] if i < len(multi_modal_data_list) else None + if isinstance(mm_data, dict) and mm_data.get("images"): + for j, img in enumerate(mm_data["images"]): + img_filename = f"eval_{i:04d}_img_{j:03d}.jpg" + img_path = images_dir / img_filename + try: + if hasattr(img, "save"): + img.save(str(img_path), "JPEG", quality=85) + image_paths.append(str(img_path)) + total_images_saved += 1 + elif isinstance(img, str) and img.startswith(("http://", "https://")): + image_paths.append(img) + total_images_saved += 1 + elif isinstance(img, bytes): + with open(img_path, "wb") as img_f: + img_f.write(img) + image_paths.append(str(img_path)) + total_images_saved += 1 + except Exception as e: + logger.warning(f"Failed to save eval image {j} for trajectory {i}: {e}") + entry = { "input_prompt": input_prompts[i], "output_response": output_responses[i], @@ -280,10 +313,16 @@ def dump_per_dataset_eval_results( "env_extras": concat_env_extras[i], "data_source": data_source, } + if image_paths: + entry["image_paths"] = image_paths + entry["num_screenshots"] = len(image_paths) f.write(json.dumps(entry, ensure_ascii=False) + "\n") logger.info(f"Dumped eval data for {data_source} to {filename}") + if total_images_saved: + logger.info(f"Saved {total_images_saved} eval screenshots to {images_dir}") + # Dump aggregated results file aggregated_filename = dump_dir_path / "aggregated_results.jsonl" with open(aggregated_filename, "w") as f: From abf40083e6086fe1b7992acb8be175fdd39f0ff1 Mon Sep 17 00:00:00 2001 From: Deniz Date: Sun, 3 May 2026 23:08:19 -0700 Subject: [PATCH 120/121] VL: set eval_before_train=false to skip 10h eval overhead Co-Authored-By: Claude Opus 4.6 (1M context) --- scripts/fleet-vl-run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fleet-vl-run.sh b/scripts/fleet-vl-run.sh index b33dfcbd22..12e728fbcb 100755 --- a/scripts/fleet-vl-run.sh +++ b/scripts/fleet-vl-run.sh @@ -53,7 +53,7 @@ bash scripts/fleet-common-run.sh \ generator.inference_engine_tensor_parallel_size=1 \ trainer.epochs=${NUM_EPOCHS} \ trainer.eval_batch_size=12 \ - trainer.eval_before_train=true \ + trainer.eval_before_train=false \ trainer.eval_interval=10 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=50 \ From 9e9f648f53b37f10b3f4ebfc878cdf55f6795134 Mon Sep 17 00:00:00 2001 From: Deniz Date: Mon, 4 May 2026 15:46:45 -0700 Subject: [PATCH 121/121] Add taste-reward shaping on top of main (rebased) Cherry-picks the taste judge integration from the original taste-reward-shaping branch onto current main, which includes all fleet scripts and the async env wrapper fix for MCP transport errors. - skyrl_taste/ package: async judge wrapper with provider routing - env.py: taste_floor config, _apply_taste_reward gating at episode end - YAML: single-node H200, corrected workdir URL to skyrl-fleet Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/taste/LAUNCH.md | 147 ++++++++++++++ docs/taste/integration_map.md | 219 +++++++++++++++++++++ docs/taste/smoke_test.py | 219 +++++++++++++++++++++ skyrl-gym/pyproject.toml | 2 +- skyrl-gym/skyrl_gym/envs/fleet_task/env.py | 124 +++++++++++- skyrl-gym/skyrl_taste/__init__.py | 13 ++ skyrl-gym/skyrl_taste/judge.py | 177 +++++++++++++++++ tasks/openenv-fleet-grpo-vl-taste.yaml | 131 ++++++++++++ 8 files changed, 1027 insertions(+), 5 deletions(-) create mode 100644 docs/taste/LAUNCH.md create mode 100644 docs/taste/integration_map.md create mode 100644 docs/taste/smoke_test.py create mode 100644 skyrl-gym/skyrl_taste/__init__.py create mode 100644 skyrl-gym/skyrl_taste/judge.py create mode 100644 tasks/openenv-fleet-grpo-vl-taste.yaml diff --git a/docs/taste/LAUNCH.md b/docs/taste/LAUNCH.md new file mode 100644 index 0000000000..af5bd75ae4 --- /dev/null +++ b/docs/taste/LAUNCH.md @@ -0,0 +1,147 @@ +# Taste-Judge GRPO Launch Recipe + +Wires `research/judge/judge.py` into the SkyRL Fleet GRPO training loop. +Reward shape is **GATED TASTE**: + +``` +effective_taste = max(taste_floor, taste_score) # 1.0 if judge fails / None +reward = verifier_reward * effective_taste +``` + +Blended only on the terminal step of each rollout, with a 10s judge timeout +and verifier-only fallback (`effective_taste = 1.0`, so reward collapses to +`verifier_reward`) on timeout/exception/None. + +### Why gated > additive + +The previous additive shape `R = alpha * verifier + (1-alpha) * taste` +rewarded "pretty failures" — a trajectory that fails the verifier (v=0) +but narrates clean intent (t high) earned `(1-alpha) * t > 0`, which +incentivized the policy to learn good-looking failure modes. Gated taste +closes this hack: `verifier=0` forces `reward=0` regardless of taste, so +there is zero gradient toward pretty-failure mimicry. Among successes, +ugly successes still earn `floor * verifier` (default `floor=0.1`) so GRPO +sees within-group taste variance and can prefer pretty successes; setting +`floor=1.0` collapses the shape to pure verifier and serves as a clean +ablation baseline. **The floor is set to 0.1 (not 0.3) because offline +analysis showed mean rescaled taste of verifier=1 trajectories is ~0.13; +floor=0.3 would clip nearly all successes and kill within-group variance. +Re-tune floor after a 50-100 step pilot using the empirical effective_taste +P25 logged in WandB.** + +## One-block launch + +```bash +# 0. From your machine: +cd /tmp && rm -rf skyrl-fleet && git clone https://github.com/fleet-ai/skyrl-fleet.git +cd /tmp/skyrl-fleet + +# 1. Apply the env patch (adds taste_floor config, _apply_taste_reward helper, +# and updates the three terminal returns + get_metrics). +git apply /Users/alliegu/Desktop/fleet/integration/env.py.diff + +# 2. Vendor the taste-judge package into the workdir Python path. +cp -r /Users/alliegu/Desktop/fleet/integration/skyrl_taste skyrl-gym/skyrl_taste +cp -r /Users/alliegu/Desktop/fleet/research/judge research/judge + +# 3. Drop the new YAML config into tasks/. +cp /Users/alliegu/Desktop/fleet/integration/configs/openenv-fleet-grpo-vl-taste.yaml \ + tasks/openenv-fleet-grpo-vl-taste.yaml + +# 4. Sky launch with the new yaml + new env vars (judge keys are NEW; the rest +# are unchanged from the existing VL launch). +sky launch tasks/openenv-fleet-grpo-vl-taste.yaml \ + --env FLEET_API_KEY="$FLEET_API_KEY" \ + --env WANDB_API_KEY="$WANDB_API_KEY" \ + --env AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \ + --env AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ + --env ANTHROPIC_API_KEY="$ANTHROPIC_API_KEY" \ + --env OPENAI_API_KEY="$OPENAI_API_KEY" +``` + +## Required env vars + +- `ANTHROPIC_API_KEY` — **required**. Default judge backend (Claude via + `research/judge/judge.py`). Without it the judge import fails and the env + silently falls back to verifier-only reward (you'll see + `taste_judge_failed=True` in WandB). +- `OPENAI_API_KEY` — **only required if running inter-rater agreement + passes** (GPT-4o judge for cross-checking Claude scores during eval). Not + needed for the standard training run. +- `FLEET_API_KEY`, `WANDB_API_KEY`, `AWS_ACCESS_KEY_ID`, + `AWS_SECRET_ACCESS_KEY` — same as the upstream VL launch. + +**Important:** Invoke `judge.py` with `blind_outcome=True` at training time +to suppress outcome bleed (Stream 4 finding — when the judge sees the +verifier outcome, taste scores correlate ~0.7 with verifier and the +shaping signal collapses to a noisy duplicate of the binary reward). The +async wrapper in `skyrl_taste/judge.py` handles this; double-check the +flag is forwarded if you swap the wrapper. + +## WandB metrics to watch + +- `reward/train/mean` — gated reward; bounded above by verifier mean. +- `env/taste_reward` — judge's [0,1] raw score per trajectory. +- `env/effective_taste` — `max(floor, taste_reward)`; what actually + multiplies the verifier. +- `env/verifier_reward` — raw binary verifier per trajectory. +- `env/taste_floor` — the configured floor; sanity-check. +- `env/taste_judge_failed` — should stay near 0; spikes mean Anthropic + outage or judge parse failures (auto-fallback to pure verifier engaged). +- **Cross-check**: in within-group runs, plot Pearson(`taste_reward`, + `verifier_reward`). If correlation collapses below ~0.3, the judge is + scoring a different signal than the verifier — that's the expected case + and where the shaped-reward gradient comes from. If it climbs above + ~0.7, suspect outcome bleed (re-verify `blind_outcome=True`). +- `reward/train/variance_per_prompt` and `signal_ratio` (from + `integrations/fleet/reward_metrics.py`) should *increase* relative to a + verifier-only baseline on groups with mixed pretty/ugly successes. + +## Rollback + +**Runtime kill switch (no redeploy):** +```bash +sky exec "echo SKYRL_TASTE_DISABLED=1 >> ~/.bashrc && pkill -HUP -f main_fleet" +# or update the SkyPilot env block and re-launch with --env SKYRL_TASTE_DISABLED=1 +``` +This makes `score_trajectory_async` return `None`, the env's +`effective_taste` becomes `1.0`, and reward collapses to pure verifier. + +**Full revert (uncheck-out the patch):** +```bash +cd /tmp/skyrl-fleet +git apply -R /Users/alliegu/Desktop/fleet/integration/env.py.diff +rm -rf skyrl-gym/skyrl_taste research/judge +``` + +## Two-knob ablation (floor x grpo_norm_by_std) + +| floor \ grpo_norm_by_std | true (default) | false (recommended w/ gated taste) | +|---|---|---| +| 0.0 (pure multiplicative) | Ugly successes get R=0; group std collapses on all-ugly groups. Heavy gradient damping; expect slow learning. | Same dynamics, undamped; risk of policy ignoring ugly successes entirely. | +| 0.1 | Tiny within-success variance; std-norm wipes most of the gradient. | Tight bonus for pretty successes; conservative shaping. | +| 0.1 (default) | Tiny within-success variance from floor itself; std-norm still wipes most of the gradient. | **Headline candidate.** Multiplicative-with-cushion; closes hack and matches the empirical taste distribution. | +| 0.3 | Within-success std damped; offline data shows nearly all successes clip to floor — kills the signal. | Heavier shaping; only sensible if live taste distribution skews high. | +| 0.5 | Floor close to pretty-mid; less taste differentiation among successes. | Shallower shaping; useful as sensitivity check. | +| 1.0 (pure verifier) | **Identical to upstream baseline.** A/B control, no taste in std. | Identical to upstream too (no taste in std). | + +Recommended order: run cell `(0.1, false)` first as the headline candidate, +then `(0.1, true)` to measure the std-norm effect, then `(1.0, true)` as +the upstream baseline. `(0.0, false)` is a diagnostic: confirms the gate +itself bites (ugly successes get zero) without floor compensation. + +## Risks / gotchas + +- **Judge latency budget**: 10s timeout x `n_samples_per_prompt=4` at + `train_batch_size=50` = ~200 concurrent judge calls per training step. + Anthropic rate limits will throttle you before the GPU does. Watch + `taste_judge_failed` — sustained >10% means raise the limit or batch. +- **Reward range**: gated reward is in `[0, 1]` — same as verifier — so + pass@n threshold (`reward >= 1.0` in `reward_metrics.py:79-82`) only + triggers on `(verifier=1, taste=1.0)`. With `floor=0.1` and `verifier=1`, + blended max is 1.0 only when `taste_score=1.0`. **Pass@n will look + worse than verifier-only**; report it alongside the new gated-reward + mean, and consider plotting `verifier_reward >= 1.0` as a separate + pass@n line for direct comparison to the baseline. +- **Outcome bleed**: confirmed Stream 4 risk if the judge ever sees the + verifier outcome. Keep `blind_outcome=True` in `score_trajectory_async`. diff --git a/docs/taste/integration_map.md b/docs/taste/integration_map.md new file mode 100644 index 0000000000..9c6ef7fc00 --- /dev/null +++ b/docs/taste/integration_map.md @@ -0,0 +1,219 @@ +# Fleet GRPO Reward Integration Map + +Repo: `https://github.com/fleet-ai/skyrl-fleet` (cloned to `/tmp/skyrl-fleet-2` in sandbox; `git clone` into `/sessions/.../outputs` failed because the existing mount blocked write to `.git/`, so we cloned to `/tmp/skyrl-fleet-2`). + +The `skyrl-train` package has been merged into `skyrl/` (per `skyrl-train/README.md`). Modern code paths live under `skyrl/train/...`. + +--- + +## Reward emit point + +The Fleet env returns reward in **two places**, both in `skyrl-gym/skyrl_gym/envs/fleet_task/env.py`: + +### Per-step reward — `step_async()` returns +File: `skyrl-gym/skyrl_gym/envs/fleet_task/env.py` + +The reward is initialized to `0.0` at line **552**, populated from OpenEnv at lines **590–592** and **615–617**, and finally emitted on the `BaseTextEnvStepOutput` returns at lines **674, 708, 762**. + +``` +552 reward = 0.0 +... +588 try: +589 mcp_start = time.time() +590 obs, reward, done, info = ( +591 await self.openenv_task_env.step_async(openenv_action) +592 ) +... +613 try: +614 mcp_start = time.time() +615 obs, reward, done, info = ( +616 await self.openenv_task_env.step_async(openenv_action) +617 ) +... +672 return BaseTextEnvStepOutput( +673 observations=[], +674 reward=reward, +675 done=True, +676 metadata={...}, +677 ) +... +706 return BaseTextEnvStepOutput( +707 observations=[new_obs], +708 reward=reward, +709 done=episode_done, +710 metadata=metadata, +711 ) +... +760 return BaseTextEnvStepOutput( +761 observations=[new_obs], +762 reward=reward, +763 done=episode_done, +764 metadata=metadata, +765 ) +``` + +### Final reward fallback — `close()` / `close_async()` +For trajectories that get terminated by SkyRL (context overflow, timeout) **without** the agent emitting ``, OpenEnv's verifier is run inside `close()` / `close_async()` and the result is stashed on `self.last_reward` (lines **789–790** and **805–806**). It then surfaces via `get_metrics()` as `final_reward` (line **824**): + +``` +784 def close(self): +785 """Close the Fleet environment and cleanup resources.""" +786 if self.openenv_task_env: +787 try: +788 self.openenv_task_env.close() +789 if self.openenv_task_env.final_reward is not None: +790 self.last_reward = self.openenv_task_env.final_reward +... +796 async def close_async(self): +... +802 if self.openenv_task_env: +803 try: +804 await self.openenv_task_env.close_async() +805 if self.openenv_task_env.final_reward is not None: +806 self.last_reward = self.openenv_task_env.final_reward +``` + +The terminal reward used by the GRPO trainer comes from the last step where `done=True`, i.e. one of the three return sites above. **The clean place to inject `taste_score` is inside `step_async()` immediately before each of those three returns, when `episode_done is True`.** + +--- + +## Verifier source + +The binary `0.0 / 1.0` reward is **not computed inside this repo**. It comes back from OpenEnv's `FleetTaskEnv.step_async()` (and `close_async()`) at: + +- `skyrl-gym/skyrl_gym/envs/fleet_task/env.py:590-592` — happy path during the step where the agent submits its tool call +- `skyrl-gym/skyrl_gym/envs/fleet_task/env.py:615-617` — when the agent emits `` with no tool call +- `skyrl-gym/skyrl_gym/envs/fleet_task/env.py:788-790, 804-806` — orphaned-trajectory fallback via `openenv_task_env.final_reward` + +OpenEnv runs a programmatic Python verifier server-side; the Fleet wrapper only consumes its return value. There is also a **partial-reward** mode (not binary) toggled by `env_config.partial_reward` (constructor lines **176–181**); the VL launch script enables it (`scripts/fleet-vl-run.sh:42` — `environment.skyrl_gym.fleet_task.partial_reward=true`). Per `reward_metrics.py:79-82`, only `reward >= 1.0` counts as a "pass" in pass@n, so partial values land in `(0,1)`. + +For task-generation runs there is also `integrations/fleet/task_gen_reward.py` which applies a derived "mixed result" reward — orthogonal to the browser-use loop but worth noting because it's a precedent for shaping rewards inside this repo. + +--- + +## LLM-as-judge example + +Path: `examples/train/llm_as_a_judge/`. Four files (5 with `__init__.py`): + +| File | Purpose | +|---|---| +| `llm_judge_env.py` | The env: `GSM8kLLMJudgeEnv(BaseTextEnv)` with a synchronous `step()` that calls the OpenAI client to score an answer. | +| `main_llm_judge.py` | Ray entrypoint that registers the env id `"llm_as_a_judge"` and calls `BasePPOExp(cfg).run()`. | +| `gsm8k_dataset_judge.py` | Dataset prep: emits parquet with `env_class="llm_as_a_judge"` and `reward_spec.ground_truth`. | +| `run_llm_judge.sh` | GRPO launch (Qwen2.5-1.5B-Instruct, 4× GPU). Sets `environment.skyrl_gym.llm_as_a_judge.model="gpt-4o-mini"`. | + +What the env actually does — quoting the **only** reward-relevant section of `llm_judge_env.py`: + +```python +def _get_reward(self, action: str) -> float: + message = PROMPT + f"\n\nGOLD SOLUTION:\n{self.ground_truth}\n\nPREDICTED SOLUTION:\n{action}\n\nAnswer:" + try: + response = self.llm_judge_client.chat.completions.create( + model=self.model, messages=[{"role": "user", "content": message}] + ) + reply = response.choices[0].message.content.strip() + match = re.search(r"### Final Score:\s*([01](?:\.0)?)", reply) + if match: + return float(match.group(1)) + if reply.strip() in {"1", "0"}: + return float(reply.strip()) + return 0.0 + except Exception as e: + print(f"LLM Judge error: {type(e).__name__}: {e}") + return 0.0 + +def step(self, action: str) -> BaseTextEnvStepOutput: + done = True + reward = self._get_reward(action) + return BaseTextEnvStepOutput(observations=[], reward=reward, done=done, metadata={}) +``` + +Properties: + +- **Synchronous and blocking.** `step()` is sync and uses the `openai.OpenAI` blocking client. Each rollout's `step()` call blocks the worker thread for the full judge latency. +- **Single-turn.** The env always returns `done=True` on the first step, so the judge runs exactly once per trajectory at the very end. +- **No batching.** One judge call per rollout, no aggregation across the GRPO group of `n_samples_per_prompt` trajectories. +- **No async / no thread pool / no retry / no timeout.** Errors swallow to `0.0` (silent failure mode). +- **Caller is the env itself.** Reward is computed inline in `step()` — the trainer never knows there is an LLM judge in the loop. + +The reason this is acceptable in the GSM8k example: it is **single-turn**, run on cheap CPU-side I/O, with a tiny batch (the script uses `train_batch_size=32`, `n_samples_per_prompt=5`), and rollouts in SkyRL's generator already run concurrently via the async generator + Ray, so blocking calls in different envs proceed in parallel. The pattern does **not** scale cleanly to long multi-turn browser_use rollouts where you don't want to hold the env alive for an extra 1–3 s × group_size at the very end. + +--- + +## Async strategy + +**Recommendation: post-hoc, parallel, and out-of-step.** Specifically: + +1. **Do not call the judge inside `step_async()` per turn.** Browser-use trajectories have 50–80 turns (`MAX_TURNS=80` in the YAML); judging every step is wasteful and the judge can't reasonably score before the trajectory is done anyway. +2. **At episode end** (the `episode_done` branch in `step_async()` and inside `close_async()`), kick off the judge call **as an awaitable**. Two options, in order of cleanliness: + - **Option A (preferred):** make `score_trajectory` an `async def` that uses `httpx.AsyncClient` or `openai.AsyncOpenAI`, with `asyncio.wait_for(..., timeout=judge_timeout_s)`. SkyRL's generator already runs `step_async` inside an asyncio task per rollout, so judge calls across the entire GRPO group naturally overlap. With `n_samples_per_prompt=4` and ~50 prompts, you get 200 judge calls running concurrently and the wall-clock cost collapses to ~max(judge_latency). + - **Option B (escape hatch):** wrap the sync judge in `asyncio.to_thread(...)` (Python 3.9+) so the existing sync OpenAI/Anthropic client doesn't block the event loop. Slightly worse than A under load but a one-line change. +3. **Use `asyncio.gather` or `asyncio.wait_for` with a hard timeout** of e.g. 10 s. On timeout/exception, log a warning and fall back to `verifier_reward` only (i.e. effectively `alpha = 1.0` for that trajectory). This keeps a slow Anthropic API outage from stalling a training step. +4. **Do not gate trajectory cleanup on the judge.** Resolve the judge future, attach the score to the final `BaseTextEnvStepOutput`, and let `close_async()` proceed independently if the judge is still pending. (In practice, since `step_async` returns the terminal `done=True` step, you must either `await` the judge before the final return or do post-hoc reward attribution at the trainer level.) +5. **Optional optimization — batch by prompt-group at the trainer layer.** A more invasive variant: store the trajectory transcripts in `metadata`, then have the trainer call the judge once per GRPO group (with all `n` trajectories in one prompt) before computing advantages. This gives the judge cross-trajectory context for relative ranking and is what most production RLAIF setups do. Requires patching the trainer's reward post-processing path (where `flatten_rewards` in `integrations/fleet/reward_metrics.py` is called), not the env. Out of scope for the minimal patch but worth flagging. + +**The existing `llm_as_a_judge` example uses none of these**: it is sync, inline, single-call, single-turn, no timeout, no retry. **Do not copy it as-is for browser_use** — copy the *interface shape* (judge runs inside the env at episode end and emits a scalar in `[0,1]`) and rewrite the call to be async + timed-out. + +--- + +## GRPO config knobs + +Defaults in `skyrl/train/config/ppo_base_config.yaml` (lines 96–124), with VL overrides from `scripts/fleet-vl-run.sh`: + +| Knob | Default | VL launch override | Interaction with shaped reward | +|---|---|---|---| +| `trainer.algorithm.advantage_estimator` | `"grpo"` | `grpo` | Computes per-prompt-group advantages from raw rewards. A continuous `taste_score` increases within-group variance and produces non-zero advantages even when all trajectories pass/fail the binary verifier — exactly the desired effect. | +| `trainer.algorithm.grpo_norm_by_std` | `true` | (default) | GRPO divides advantage by group-level reward std. With binary rewards, std is 0 when the whole group passes/fails; mixing in `taste_score` raises std, which **also damps the advantage magnitude**. Watch for: groups where verifier is unanimous now produce small but non-zero advantages — the gradient signal will be tiny. May want `grpo_norm_by_std=false` once shaped reward is on. | +| `trainer.algorithm.zero_variance_filter` | `false` | `true` (line 73) | Currently masks out prompts where all rewards are identical (no signal). With shaped reward this filter would fire **far less often** since `taste_score` is approximately continuous → almost every prompt now contributes a gradient. This is good for sample efficiency but may also amplify judge noise into the policy. Consider keeping it on but with a tolerance threshold. | +| `trainer.algorithm.use_kl_loss` | `true` | `true` | KL is on the policy loss, so it is independent of reward scale. Good. | +| `trainer.algorithm.kl_loss_coef` | `0.001` | (default) | Independent of reward, no change needed. | +| `trainer.algorithm.use_kl_in_reward` | `false` | (default, mutually exclusive with `use_kl_loss`) | If you ever flip to `use_kl_in_reward=true`, the KL term gets *added to the reward* and competes directly with `taste_score`. Keep this `false`. | +| `trainer.algorithm.eps_clip_low / eps_clip_high` | `0.2 / 0.2` | (default) | PPO ratio clip. Independent of reward magnitude (operates on log-prob ratio), so safe. | +| `trainer.algorithm.advantage_batch_normalize` | `false` | (default) | If turned on, would re-normalize advantages across the whole batch. Consider enabling if the taste_score's scale + verifier mix produces unstable cross-prompt advantage magnitudes. | +| `trainer.algorithm.loss_reduction` | `"token_mean"` | `"sequence_mean"` (line 47) | Doesn't touch reward, but `sequence_mean` is what's used for VL — keep aware that gradient is per-trajectory averaged. | + +**Concrete suggestions:** +- Start with `alpha=0.5` (balanced). +- Keep `grpo_norm_by_std=true` initially; if you observe gradient norm collapse, set it to `false`. +- Bound `taste_score` to `[0,1]` (same range as verifier) so the mixed reward stays in `[0,1]` and existing pass@n / signal-ratio metrics in `integrations/fleet/reward_metrics.py` still parse correctly. +- Consider reporting `verifier_reward` and `taste_reward` separately as wandb metrics so you can disentangle their contributions — fits naturally into the existing metric schema. + +--- + +## Existing evals + +**Eval entrypoint:** `integrations/fleet/entrypoints/main_eval.py` — `FleetEvalExp(BasePPOExp).run()`. Resumes FSDP weights from S3, calls `await trainer.eval()` once (line 125), logs the dict via `trainer.tracker.log(...)`, and (optionally) uploads dump to S3. + +**Metric computation:** `integrations/fleet/reward_metrics.py` exposes: +- `flatten_rewards(rewards)` — collapses token-level rewards to scalars. +- `compute_pass_at_n(rewards, uids)` — fraction of unique prompts with **at least one rollout `>= 1.0`**. +- The module's docstring documents the wandb naming convention: `reward/{group}/pass_at_n`, `reward/{group}/variance_per_prompt`, `reward/{group}/signal_ratio`, `reward/{group}/mean_positive_reward`. + +**What gets measured today:** +- Final reward distribution (pass@n with threshold ≥ 1.0). +- Within-prompt reward variance (the GRPO learning-signal proxy). +- Signal ratio (% prompts with non-zero variance). +- Mean positive reward. +- Per-env metrics emitted from `FleetTaskEnv.get_metrics()` at lines 812–835: `task_key`, `env_key`, `turns`, `tool_calls`, `tool_errors`, `is_hinted`, `final_reward`, `verifier_stdout`, `verifier_error`, `tool_error_messages`, `chat_history`. + +**How to add a new metric (e.g. `taste_reward_mean`):** +1. In `step_async()`'s terminal returns, also stash `self.last_taste_reward` and `self.last_verifier_reward` on the env. +2. Append both to the metadata dict and to `get_metrics()` output (line 814 onward) so they flow into the trainer's metric aggregator alongside `final_reward`. +3. The trainer's `_get_response_level_rewards`/eval-dump path picks up env metadata — no further patching needed if the keys are scalar-typed. For aggregated metrics (group-level), add a function to `reward_metrics.py` modeled on `compute_pass_at_n` and call it from wherever `pass_at_n` is logged in the trainer (search `compute_pass_at_n` to find the call sites — they live in `skyrl/train/trainer.py` and `integrations/fleet/entrypoints/main_fleet_tinker.py`). +4. Test path: `integrations/fleet/tests/test_main_eval.py`. + +--- + +## Files referenced (absolute, in cloned repo) + +- `/tmp/skyrl-fleet-2/skyrl-gym/skyrl_gym/envs/fleet_task/env.py` — env, reward emit point (lines 525–765, 784–810). +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/llm_judge_env.py` — sync judge example. +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/main_llm_judge.py` +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/gsm8k_dataset_judge.py` +- `/tmp/skyrl-fleet-2/examples/train/llm_as_a_judge/run_llm_judge.sh` +- `/tmp/skyrl-fleet-2/tasks/openenv-fleet-grpo-vl.yaml` — VL launch task. +- `/tmp/skyrl-fleet-2/scripts/fleet-vl-run.sh` — actual GRPO CLI args. +- `/tmp/skyrl-fleet-2/skyrl/train/config/ppo_base_config.yaml` — GRPO/PPO defaults (lines 96–124). +- `/tmp/skyrl-fleet-2/integrations/fleet/reward_metrics.py` — metric helpers. +- `/tmp/skyrl-fleet-2/integrations/fleet/entrypoints/main_eval.py` — eval entrypoint. +- `/tmp/skyrl-fleet-2/integrations/fleet/task_gen_reward.py` — precedent for shaping rewards inside this repo (task-gen, not browser-use). diff --git a/docs/taste/smoke_test.py b/docs/taste/smoke_test.py new file mode 100644 index 0000000000..2e7002fd17 --- /dev/null +++ b/docs/taste/smoke_test.py @@ -0,0 +1,219 @@ +"""Smoke test for the patched FleetTaskEnv reward-gating logic. + +Runs WITHOUT a real Fleet env: we duplicate the small `_apply_taste_reward` +helper that the diff installs on FleetTaskEnv (lifted verbatim from the diff +body) and exercise it against stubbed `score_trajectory_async` callables. + +Reward shape under test: + effective_taste = max(taste_floor, taste_score) (1.0 on judge fail/None) + reward = verifier_reward * effective_taste + +Cases: + (a) success + pretty taste (v=1, t=1.0, floor=0.1) -> R=1.0 + (b) success + mid taste (v=1, t=0.5, floor=0.1) -> R=0.5 + (c) success + ugly taste (v=1, t=0.0, floor=0.1) -> R=0.1 (floor) + (d) failure + pretty taste (v=0, t=1.0) -> R=0.0 (gated to 0) + (e) failure + ugly taste (v=0, t=0.0) -> R=0.0 + (f) judge timeout + success -> R=verifier (1.0) + (g) judge exception + success -> R=verifier (1.0) + (h) SKYRL_TASTE_DISABLED=1 + success -> R=verifier (1.0) + +Prints PASS/FAIL per test. Exits 0 if all pass, 1 otherwise. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from typing import Any, Optional + + +# ----------------------------------------------------------------------------- +# Reproduce the helper installed by env.py.diff. Keep this in sync with the +# diff body (search for "async def _apply_taste_reward" in env.py.diff). +# ----------------------------------------------------------------------------- + +logger = logging.getLogger("smoke_test") + + +class _FakeFleetTaskEnv: + """Minimal stand-in for FleetTaskEnv with the attributes the helper reads.""" + + def __init__(self, floor: float, timeout_s: float, judge_async): + self.taste_floor = floor + self.taste_judge_timeout_s = timeout_s + self.task_key = "smoke-task-1" + self.task_config = {"prompt": "Send an email to bob@example.com saying hi"} + self.chat_history = [ + {"role": "system", "content": "you are a CU agent"}, + {"role": "user", "content": "Send an email..."}, + {"role": "assistant", "content": "I will click Compose."}, + {"role": "user", "content": "ok"}, + {"role": "assistant", "content": "Now I type the address."}, + ] + self.last_verifier_reward: Optional[float] = None + self.last_taste_reward: Optional[float] = None + self.last_effective_taste: Optional[float] = None + self.last_taste_judge_failed: bool = False + # Inject the stubbed judge in place of the real package. + self._judge_async = judge_async + + async def _apply_taste_reward(self, verifier_reward: float, episode_done: bool) -> float: + # Body lifted from env.py.diff (kept tight). + if not episode_done: + return verifier_reward + + self.last_verifier_reward = float(verifier_reward) + self.last_taste_reward = None + self.last_effective_taste = None + self.last_taste_judge_failed = False + + score_trajectory_async = self._judge_async + + actions = [ + {"role": m.get("role"), "content": m.get("content")} + for m in self.chat_history + if m.get("role") == "assistant" + ] + task_text = self.task_config.get("prompt", "") + outcome = bool(self.last_verifier_reward >= 1.0) + + taste_score: Optional[float] + try: + taste_score = await asyncio.wait_for( + score_trajectory_async(task_text, actions, outcome), + timeout=self.taste_judge_timeout_s, + ) + except asyncio.TimeoutError: + self.last_taste_judge_failed = True + taste_score = None + except Exception: + self.last_taste_judge_failed = True + taste_score = None + + if taste_score is None: + self.last_effective_taste = 1.0 + return verifier_reward + + taste_score = max(0.0, min(1.0, float(taste_score))) + self.last_taste_reward = taste_score + effective_taste = max(self.taste_floor, taste_score) + self.last_effective_taste = effective_taste + return verifier_reward * effective_taste + + +# ----------------------------------------------------------------------------- +# Stubbed judges +# ----------------------------------------------------------------------------- + + +def _judge_returning(value: float): + async def _inner(task: str, actions, outcome: bool) -> float: + return value + return _inner + + +async def _judge_returns_none_if_disabled(task: str, actions, outcome: bool) -> Optional[float]: + # Mimics the SKYRL_TASTE_DISABLED=1 short-circuit in skyrl_taste.judge. + if os.environ.get("SKYRL_TASTE_DISABLED") == "1": + return None + return 1.0 + + +async def _judge_slow(task: str, actions, outcome: bool) -> float: + await asyncio.sleep(5.0) + return 0.9 + + +async def _judge_raises(task: str, actions, outcome: bool) -> float: + raise RuntimeError("simulated API outage") + + +# ----------------------------------------------------------------------------- +# Test cases +# ----------------------------------------------------------------------------- + + +def _ok(name: str) -> None: + print(f"PASS: {name}") + + +def _fail(name: str, msg: str) -> None: + print(f"FAIL: {name} -> {msg}") + + +async def _check(name: str, env: _FakeFleetTaskEnv, verifier: float, expected: float, + *, expect_failed: bool = False) -> int: + r = await env._apply_taste_reward(verifier_reward=verifier, episode_done=True) + ok = abs(r - expected) < 1e-9 and env.last_taste_judge_failed is expect_failed + if ok: + _ok(name) + return 0 + _fail(name, f"r={r} expected={expected} failed={env.last_taste_judge_failed} " + f"verifier={env.last_verifier_reward} taste={env.last_taste_reward} " + f"effective={env.last_effective_taste}") + return 1 + + +async def run() -> int: + failures = 0 + floor = 0.1 + + # (a) success + pretty taste -> 1.0 + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(1.0)) + failures += await _check("a_success_pretty_v1_t1_floor0.1_R1.0", env, 1.0, 1.0) + + # (b) success + mid taste -> 0.5 + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(0.5)) + failures += await _check("b_success_mid_v1_t0.5_floor0.1_R0.5", env, 1.0, 0.5) + + # (c) success + ugly taste -> floor (0.1) + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(0.0)) + failures += await _check("c_success_ugly_v1_t0_floor0.1_R0.1", env, 1.0, 0.1) + + # (d) failure + pretty taste -> 0.0 (the hack is closed) + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(1.0)) + failures += await _check("d_failure_pretty_v0_t1_R0.0_HACK_CLOSED", env, 0.0, 0.0) + + # (e) failure + ugly taste -> 0.0 + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_returning(0.0)) + failures += await _check("e_failure_ugly_v0_t0_R0.0", env, 0.0, 0.0) + + # (f) judge timeout + success -> verifier (1.0), failed=True + env = _FakeFleetTaskEnv(floor=floor, timeout_s=0.05, judge_async=_judge_slow) + failures += await _check("f_timeout_success_R_eq_verifier_1.0", env, 1.0, 1.0, + expect_failed=True) + + # (g) judge exception + success -> verifier (1.0), failed=True + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, judge_async=_judge_raises) + failures += await _check("g_exception_success_R_eq_verifier_1.0", env, 1.0, 1.0, + expect_failed=True) + + # (h) SKYRL_TASTE_DISABLED=1 + success -> verifier (1.0), failed=False + os.environ["SKYRL_TASTE_DISABLED"] = "1" + try: + env = _FakeFleetTaskEnv(floor=floor, timeout_s=2.0, + judge_async=_judge_returns_none_if_disabled) + failures += await _check("h_disabled_env_var_R_eq_verifier_1.0", env, 1.0, 1.0, + expect_failed=False) + # Extra invariant: effective_taste should be 1.0 in the disabled path. + if env.last_effective_taste != 1.0: + _fail("h_disabled_env_var_R_eq_verifier_1.0", + f"effective_taste={env.last_effective_taste} expected 1.0") + failures += 1 + finally: + del os.environ["SKYRL_TASTE_DISABLED"] + + return failures + + +if __name__ == "__main__": + failures = asyncio.run(run()) + if failures == 0: + print("\nALL SMOKE TESTS PASSED (8/8)") + sys.exit(0) + else: + print(f"\n{failures} TEST(S) FAILED") + sys.exit(1) diff --git a/skyrl-gym/pyproject.toml b/skyrl-gym/pyproject.toml index 82c4078e51..88f42c68eb 100644 --- a/skyrl-gym/pyproject.toml +++ b/skyrl-gym/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ Repository = "https://github.com/NovaSky-AI/SkyRL" [tool.setuptools.packages.find] -include = ["skyrl_gym*"] +include = ["skyrl_gym*", "skyrl_taste*"] [project.optional-dependencies] dev = [ diff --git a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py index 1255744782..1dc4471fb6 100644 --- a/skyrl-gym/skyrl_gym/envs/fleet_task/env.py +++ b/skyrl-gym/skyrl_gym/envs/fleet_task/env.py @@ -180,6 +180,28 @@ def __init__( else False ) + # Taste judge (LLM-as-judge) GATED reward: + # effective_taste = max(taste_floor, taste_score) (1.0 on judge fail) + # final_reward = verifier_reward * effective_taste + self.taste_floor = float( + env_config.get("taste_floor", 0.1) + if hasattr(env_config, "get") + else 0.1 + ) + if not 0.0 <= self.taste_floor <= 1.0: + raise ValueError( + f"taste_floor must be in [0,1], got {self.taste_floor}" + ) + self.taste_judge_timeout_s = float( + env_config.get("taste_judge_timeout_s", 10.0) + if hasattr(env_config, "get") + else 10.0 + ) + self.last_verifier_reward: Optional[float] = None + self.last_taste_reward: Optional[float] = None + self.last_effective_taste: Optional[float] = None + self.last_taste_judge_failed: bool = False + # Hint config self.enable_hints = ( env_config.get("enable_hints", False) @@ -667,16 +689,25 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: f"Failed to upload trace for {self.task_key}: {e}" ) + # Apply taste reward gating at episode end + if episode_done: + reward = await self._apply_taste_reward(reward, episode_done) + # Build observation message if max_turns_reached: + metadata = { + "done_reason": "max_turns", + "task_key": self.task_key, + "taste_reward": self.last_taste_reward, + "effective_taste": self.last_effective_taste, + "taste_floor": self.taste_floor, + "taste_judge_failed": self.last_taste_judge_failed, + } return BaseTextEnvStepOutput( observations=[], reward=reward, done=True, - metadata={ - "done_reason": "max_turns", - "task_key": self.task_key, - }, + metadata=metadata, ) # Build response observation @@ -703,6 +734,11 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: "step_time": step_time, "mcp_time": mcp_time, } + if episode_done: + metadata["taste_reward"] = self.last_taste_reward + metadata["effective_taste"] = self.last_effective_taste + metadata["taste_floor"] = self.taste_floor + metadata["taste_judge_failed"] = self.last_taste_judge_failed return BaseTextEnvStepOutput( observations=[new_obs], reward=reward, @@ -757,6 +793,12 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: if tool_call["name"] == "manage_context": metadata["modified_chat_history"] = self.chat_history.copy() + if episode_done: + metadata["taste_reward"] = self.last_taste_reward + metadata["effective_taste"] = self.last_effective_taste + metadata["taste_floor"] = self.taste_floor + metadata["taste_judge_failed"] = self.last_taste_judge_failed + return BaseTextEnvStepOutput( observations=[new_obs], reward=reward, @@ -764,6 +806,73 @@ async def step_async(self, action: str) -> BaseTextEnvStepOutput: metadata=metadata, ) + async def _apply_taste_reward( + self, verifier_reward: float, episode_done: bool + ) -> float: + """Gate the binary verifier reward by the taste-judge score. + + On non-terminal steps we pass through verifier_reward unchanged. + On terminal steps we call the judge with a hard timeout; on + timeout/exception/None we set effective_taste=1.0 (pure verifier). + """ + if not episode_done: + return verifier_reward + + self.last_verifier_reward = float(verifier_reward) + self.last_taste_reward = None + self.last_effective_taste = None + self.last_taste_judge_failed = False + + try: + from skyrl_taste.judge import score_trajectory_async + except Exception as e: + logger.warning( + "skyrl_taste import failed (%s); verifier-only reward", e + ) + self.last_taste_judge_failed = True + self.last_effective_taste = 1.0 + return verifier_reward + + actions = [ + {"role": m.get("role"), "content": m.get("content")} + for m in self.chat_history + if m.get("role") == "assistant" + ] + task_text = self.task_config.get("prompt", "") if self.task_config else "" + outcome = bool(self.last_verifier_reward >= 1.0) + + taste_score: Optional[float] + try: + taste_score = await asyncio.wait_for( + score_trajectory_async(task_text, actions, outcome), + timeout=self.taste_judge_timeout_s, + ) + except asyncio.TimeoutError: + self.last_taste_judge_failed = True + logger.warning( + "taste judge timed out after %.1fs for task_key=%s", + self.taste_judge_timeout_s, + getattr(self, "task_key", "?"), + ) + taste_score = None + except Exception as e: + self.last_taste_judge_failed = True + logger.warning( + "taste judge raised %s: %s for task_key=%s", + type(e).__name__, e, getattr(self, "task_key", "?"), + ) + taste_score = None + + if taste_score is None: + self.last_effective_taste = 1.0 + return verifier_reward + + taste_score = max(0.0, min(1.0, float(taste_score))) + self.last_taste_reward = taste_score + effective_taste = max(self.taste_floor, taste_score) + self.last_effective_taste = effective_taste + return verifier_reward * effective_taste + def step(self, action: str) -> BaseTextEnvStepOutput: """Execute one step in the Fleet environment (sync wrapper).""" return asyncio.run(self.step_async(action)) @@ -822,6 +931,13 @@ def get_metrics(self) -> Dict[str, Any]: } if self.last_reward is not None: metrics["final_reward"] = self.last_reward + # Taste judge metrics + if self.last_taste_reward is not None: + metrics["taste_reward"] = self.last_taste_reward + if self.last_effective_taste is not None: + metrics["effective_taste"] = self.last_effective_taste + metrics["taste_floor"] = self.taste_floor + metrics["taste_judge_failed"] = self.last_taste_judge_failed # Include verifier feedback for hint generation if self._verifier_stdout is not None: metrics["verifier_stdout"] = self._verifier_stdout diff --git a/skyrl-gym/skyrl_taste/__init__.py b/skyrl-gym/skyrl_taste/__init__.py new file mode 100644 index 0000000000..65aa9054ec --- /dev/null +++ b/skyrl-gym/skyrl_taste/__init__.py @@ -0,0 +1,13 @@ +"""skyrl_taste: thin async wrapper around the taste-judge for SkyRL GRPO. + +Public API: + score_trajectory_async(task, actions, outcome) -> Optional[float] + get_judge_provider_info() -> {"taste_judge_provider", "taste_judge_model"} + +Returns a value in [0, 1] (rescaled from the 1-5 weighted_total) or None +when the judge is disabled / errored. +""" + +from .judge import score_trajectory_async, get_judge_provider_info + +__all__ = ["score_trajectory_async", "get_judge_provider_info"] diff --git a/skyrl-gym/skyrl_taste/judge.py b/skyrl-gym/skyrl_taste/judge.py new file mode 100644 index 0000000000..3bd1459ead --- /dev/null +++ b/skyrl-gym/skyrl_taste/judge.py @@ -0,0 +1,177 @@ +"""skyrl_taste.judge +==================== + +Async wrapper around the synchronous taste judge defined in +`research/judge/judge.py`. Re-exposes the judge with the contract the +SkyRL Fleet env expects: + + async def score_trajectory_async(task, actions, outcome) -> Optional[float] + +Provider routing (env vars, read at call-time so swaps don't require a +restart of the process -- only a fresh rollout): +- ``SKYRL_TASTE_PROVIDER`` in {"anthropic", "openai", "openrouter"}. + Default: "openrouter" (cheapest production path). +- ``SKYRL_TASTE_MODEL``: model identifier. Default + "anthropic/claude-haiku-4.5" (an OpenRouter slug). For provider="anthropic" + this would be e.g. "claude-sonnet-4-6"; for provider="openai" e.g. + "gpt-4o-mini". +- ``SKYRL_TASTE_BLIND_OUTCOME``: "1" (default) suppresses the verifier + outcome from the judge prompt. Stream 4 found that exposing the outcome + causes taste scores to correlate ~0.7 with verifier (outcome bleed) and + collapses the shaping signal. + +Behavior: +- The underlying judges are *synchronous* and use blocking SDKs. We run + them in `asyncio.to_thread(...)` so they do not stall the event loop. + SkyRL's generator runs each rollout's `step_async` as its own asyncio + task, so judge calls across the GRPO group naturally overlap. +- Returns the rubric's `weighted_total`, rescaled from [1, 5] -> [0, 1] so + the blended reward stays in [0, 1] and existing pass@n / signal-ratio + metrics in `integrations/fleet/reward_metrics.py` keep working. +- Returns None on: + * `SKYRL_TASTE_DISABLED=1` (env-var bypass) + * The underlying judge returning a None-shaped result (parse / API failure) + The caller is expected to fall back to verifier-only reward when None. +- Screenshots are NOT passed in this version (text-only judge). Trade-off: + text-only keeps judge latency around 1-3 s/trajectory and avoids blowing + the judge's context with 50-80 base64 PNGs per browser_use rollout. We + lose direct ui_grounding signal, but the judge can still infer it from + action targets + tool errors. Re-enable screenshots later by sampling the + `tool_result` image_url blocks out of `chat_history` and threading them + through the judge call with `screenshots=...`. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from pathlib import Path +from typing import Any, Callable, Optional + +logger = logging.getLogger("skyrl_taste") + +# Make the research-side judge importable. In a packaged install this would +# be a sibling import; for the launch-ready integration we add the research +# tree to sys.path so we don't have to vendor it. +_RESEARCH_JUDGE_DIR = Path(__file__).resolve().parents[2] / "research" / "judge" +if _RESEARCH_JUDGE_DIR.is_dir() and str(_RESEARCH_JUDGE_DIR) not in sys.path: + sys.path.insert(0, str(_RESEARCH_JUDGE_DIR)) + +try: + from judge import ( # type: ignore + score_trajectory as _score_trajectory_anthropic, + score_trajectory_gpt4o as _score_trajectory_openai, + score_trajectory_openrouter as _score_trajectory_openrouter, + ) +except Exception as e: # pragma: no cover + logger.warning("could not import research judge: %s", e) + _score_trajectory_anthropic = None # type: ignore[assignment] + _score_trajectory_openai = None # type: ignore[assignment] + _score_trajectory_openrouter = None # type: ignore[assignment] + + +_DEFAULT_PROVIDER = "openrouter" +_DEFAULT_MODEL = "anthropic/claude-haiku-4.5" + + +def _resolve_provider() -> tuple[str, str, bool, Optional[Callable[..., dict]]]: + """Read SKYRL_TASTE_PROVIDER / SKYRL_TASTE_MODEL / SKYRL_TASTE_BLIND_OUTCOME + and return (provider, model, blind_outcome, callable). The callable is + None if the corresponding research-side function failed to import.""" + provider = os.environ.get("SKYRL_TASTE_PROVIDER", _DEFAULT_PROVIDER).strip().lower() + model = os.environ.get("SKYRL_TASTE_MODEL", _DEFAULT_MODEL) + blind_outcome = os.environ.get("SKYRL_TASTE_BLIND_OUTCOME", "1") == "1" + + if provider == "anthropic": + return provider, model, blind_outcome, _score_trajectory_anthropic + if provider == "openai": + return provider, model, blind_outcome, _score_trajectory_openai + if provider == "openrouter": + return provider, model, blind_outcome, _score_trajectory_openrouter + logger.warning( + "unknown SKYRL_TASTE_PROVIDER=%r; falling back to %s", + provider, + _DEFAULT_PROVIDER, + ) + return _DEFAULT_PROVIDER, model, blind_outcome, _score_trajectory_openrouter + + +def _rescale_to_unit_interval(weighted_total: Optional[float]) -> Optional[float]: + """Rescale weighted_total from [1, 5] (rubric) to [0, 1] (RL reward). + + Returns None passthrough; clips defensively. + """ + if weighted_total is None: + return None + try: + v = (float(weighted_total) - 1.0) / 4.0 + except (TypeError, ValueError): + return None + if v < 0.0: + return 0.0 + if v > 1.0: + return 1.0 + return v + + +def get_judge_provider_info() -> dict[str, str]: + """Return the resolved (provider, model) for run-once metric logging.""" + provider, model, _, _ = _resolve_provider() + return {"taste_judge_provider": provider, "taste_judge_model": model} + + +async def score_trajectory_async( + task: str, + actions: list[dict[str, Any]], + outcome: bool, +) -> Optional[float]: + """Async-friendly entrypoint to the taste judge. + + Args: + task: natural-language task description (`task_config["prompt"]`). + actions: ordered list of action dicts pulled from the trajectory. + outcome: bool from the verifier (verifier_reward >= 1.0). + + Returns: + A scalar in [0, 1] = rescaled `weighted_total`, or None if the + judge is disabled or failed. The caller must treat None as + "fall back to verifier-only reward". + """ + if os.environ.get("SKYRL_TASTE_DISABLED") == "1": + # Hard kill switch for runtime rollback. + return None + + provider, model, blind_outcome, fn = _resolve_provider() + if fn is None: + logger.warning( + "taste judge module unavailable for provider=%s; returning None", + provider, + ) + return None + + # Run the blocking judge in a thread so we don't stall the event loop. + # screenshots=None: see module docstring for the rationale. + try: + result = await asyncio.to_thread( + fn, + task, + actions, + outcome, + None, # screenshots + model, + blind_outcome, + ) + except Exception as e: + logger.warning("taste judge (%s) raised in thread: %s", provider, e) + return None + + if not isinstance(result, dict): + return None + + if result.get("error"): + # The judge already logged; signal fall-back. + return None + + return _rescale_to_unit_interval(result.get("weighted_total")) diff --git a/tasks/openenv-fleet-grpo-vl-taste.yaml b/tasks/openenv-fleet-grpo-vl-taste.yaml new file mode 100644 index 0000000000..e74fc24e50 --- /dev/null +++ b/tasks/openenv-fleet-grpo-vl-taste.yaml @@ -0,0 +1,131 @@ +# Fleet VL/CUA GRPO Training WITH TASTE JUDGE - Qwen3.5-9B (Vision-Language) +# +# Reward shape: GATED TASTE +# effective_taste = max(taste_floor, taste_score) (1.0 on judge fail) +# reward = verifier_reward * effective_taste +# Closes the "pretty failure" hack (verifier=0 -> reward=0 always) while +# preserving within-success taste variance via the floor (default 0.1). +# +# Delta from tasks/openenv-fleet-grpo-vl.yaml: +# - environment.skyrl_gym.fleet_task.taste_floor=0.1 (NEW) +# - environment.skyrl_gym.fleet_task.taste_judge_timeout_s=10.0 (NEW) +# - trainer.algorithm.grpo_norm_by_std=false (FLIPPED, was true default) +# - ANTHROPIC_API_KEY / OPENAI_API_KEY env vars added (NEW) +# +# Required env vars (pass each via `sky launch --env KEY=VALUE`): +# FLEET_API_KEY - Fleet API access for OpenEnv environments +# WANDB_API_KEY - WandB logging +# AWS_ACCESS_KEY_ID - S3 dataset/checkpoint/trajectory buckets +# AWS_SECRET_ACCESS_KEY - S3 credentials +# ANTHROPIC_API_KEY - NEW. Claude judge (research/judge/judge.py default). +# OPENAI_API_KEY - NEW. Reserved for inter-rater / GPT-4o judge path. +# +# Usage: +# sky launch configs/openenv-fleet-grpo-vl-taste.yaml \ +# --env FLEET_API_KEY=... \ +# --env WANDB_API_KEY=... \ +# --env AWS_ACCESS_KEY_ID=... \ +# --env AWS_SECRET_ACCESS_KEY=... \ +# --env ANTHROPIC_API_KEY=... \ +# --env OPENAI_API_KEY=... + +name: fleet-vl-grpo-qwen3-5-9b-taste + +resources: + disk_size: 750 + ports: 6479 + ordered: + - accelerators: H200:8 + cloud: runpod + - accelerators: H200:8 + cloud: gcp + use_spot: true + image_id: projects/deeplearning-platform-release/global/images/common-cu128-ubuntu-2204-nvidia-570-v20260305 + - accelerators: H200:8 + cloud: lambda + - accelerators: H200:8 + cloud: nebius + - accelerators: H200-SXM:8 + cloud: vast + +num_nodes: 1 + +workdir: + url: https://github.com/fleet-ai/skyrl-fleet.git + ref: taste-reward-shaping + +envs: + WANDB_API_KEY: "" + FLEET_API_KEY: "" + # NEW: judge credentials. Anthropic is the default judge backend; OpenAI + # is used for inter-rater agreement / fallback paths. + ANTHROPIC_API_KEY: "" + OPENAI_API_KEY: "" + LOGGER: "wandb" + INFERENCE_BACKEND: "vllm" + DATA_VERSION: "v6" + ENV_KEYS: "" + DIFFICULTY: "" + MODALITY: "browser_use" + MAX_TURNS: 80 + MAX_INPUT_LENGTH: 80000 + MAX_GENERATE_LENGTH: 4096 + NUM_EPOCHS: 10 + RUN_ID: "" + MAX_TASKS: "" + RESUME_RUN_NAME: "" + AWS_ACCESS_KEY_ID: "" + AWS_SECRET_ACCESS_KEY: "" + AWS_REGION: "us-east-1" + S3_DATASET_BUCKET: "fleet-internal-datasets" + S3_CHECKPOINT_BUCKET: "skyrl-checkpoints" + S3_TRAJECTORY_BUCKET: "skyrl-trajectories" + # NEW: runtime kill-switch for taste judge. Set to "1" to fall back to + # verifier-only reward without touching the patched env or restarting. + SKYRL_TASTE_DISABLED: "0" + # NEW: floor for the gated taste reward (forwarded into Hydra override). + # reward = verifier * max(taste_floor, taste_score) + # floor=1.0 -> pure verifier (clean ablation baseline) + # floor=0.0 -> pure multiplicative (every ugly success -> 0 reward) + # floor=0.1 -> default; offline taste of verifier=1 trajectories sits + # around 0.13 so floor=0.1 acts as multiplicative-with-cushion; + # re-tune after a 50-100 step pilot using effective_taste P25. + TASTE_FLOOR: "0.1" + # NEW: production judge selection. OpenRouter is used at training time so + # we don't hit per-org Anthropic rate limits during burst-end-of-step + # judge calls. + OPENROUTER_API_KEY: "" + SKYRL_TASTE_PROVIDER: "openrouter" + SKYRL_TASTE_MODEL: "anthropic/claude-haiku-4.5" + # Stream 4 finding: pass blind_outcome=True at training time to suppress + # outcome bleed (judge sees outcome=True and inflates ~+1.4 weighted-pts). + SKYRL_TASTE_BLIND_OUTCOME: "1" + +setup: | + bash scripts/fleet-common-setup.sh \ + --openenv-branch deniz/fleet_client \ + --extra-setup scripts/fleet-qwen35-extra-setup.sh + # NEW: install the taste-judge package next to skyrl_gym so the env's + # `from skyrl_taste.judge import score_trajectory_async` import resolves. + pip install --no-deps anthropic openai + # skyrl_taste is in skyrl-gym/ — ensure it's importable even if editable install misses it + pip install --no-deps -e ./skyrl-gym 2>/dev/null || true + +run: | + # We delegate to the existing fleet-vl-run.sh wrapper, which forwards extra + # Hydra overrides via the trailing args. The new flags below are the ONLY + # delta from the upstream script. + bash scripts/fleet-vl-run.sh \ + environment.skyrl_gym.fleet_task.taste_floor=${TASTE_FLOOR} \ + environment.skyrl_gym.fleet_task.taste_judge_timeout_s=10.0 \ + trainer.algorithm.grpo_norm_by_std=false + # ^ grpo_norm_by_std=false (flipped from default true): + # Even under gated taste, within-group reward std is inflated whenever a + # group has a mix of pretty and ugly successes (rewards in {0.1, ..., 1.0}) + # on top of the binary verifier signal. Default GRPO normalization would + # divide advantages by that larger denominator and damp the gradient. + # Stream 1's analysis showed std-norm collapses learning under shaped + # reward; that conclusion still applies here. Disabling std normalization + # keeps advantage magnitudes proportional to (reward - mean), which is + # the part the taste signal actually increases. Re-enable and tune + # advantage_batch_normalize=true if cross-prompt magnitudes get unstable.