diff --git a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
index 30303c8927..9892d98e3c 100644
--- a/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/ray_wrapped_inference_engine.py
@@ -70,6 +70,12 @@ async def teardown(self):
     async def reset_prefix_cache(self):
         return await self.inference_engine_actor.reset_prefix_cache.remote()
 
+    async def start_profile(self, profile_prefix: str | None = None):
+        return await self.inference_engine_actor.start_profile.remote(profile_prefix=profile_prefix)
+
+    async def stop_profile(self):
+        return await self.inference_engine_actor.stop_profile.remote()
+
     async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         return await self.inference_engine_actor.chat_completion.remote(request_payload)
 
diff --git a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
index 6d827ac327..715370e393 100644
--- a/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
+++ b/skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -345,6 +345,9 @@ class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._weight_loader = VLLMWeightLoader(self.llm, is_async=True)
+        # vLLM raises if profile() is called without profiler_config; gate on it.
+        self._profile_enabled = self.llm.vllm_config.profiler_config.profiler is not None
+        self._profile_counter = 0
 
     def _create_engine(self, *args, **kwargs):
         openai_kwargs = pop_openai_kwargs(kwargs)
@@ -489,17 +492,30 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOu
         """Generate responses using vLLM's async engine."""
         prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch)
 
-        tasks = []
-        for prompt in prompt_token_ids:
-            # Schedule the collection of outputs for each prompt.
-            # Avoid duplicate request_ids
-            request_id = str(uuid4().hex)
-            task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params))
-            tasks.append(task)
-        outputs = await asyncio.gather(*tasks)
+        if self._profile_enabled:
+            await self.llm.start_profile(profile_prefix=f"sample_{self._profile_counter}")
+            self._profile_counter += 1
+        try:
+            tasks = []
+            for prompt in prompt_token_ids:
+                # Schedule the collection of outputs for each prompt.
+                # Avoid duplicate request_ids
+                request_id = str(uuid4().hex)
+                task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params))
+                tasks.append(task)
+            outputs = await asyncio.gather(*tasks)
+        finally:
+            if self._profile_enabled:
+                await self.llm.stop_profile()
 
         return self._postprocess_outputs(outputs)
 
+    async def start_profile(self, profile_prefix: Optional[str] = None) -> None:
+        await self.llm.start_profile(profile_prefix=profile_prefix)
+
+    async def stop_profile(self) -> None:
+        await self.llm.stop_profile()
+
     async def wake_up(self, *args: Any, **kwargs: Any):
         await self.llm.wake_up(tags=kwargs.get("tags", None))
 
diff --git a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
index 2efbe3ae44..a555564284 100644
--- a/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
+++ b/skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py
@@ -209,12 +209,20 @@ class RemoteInferenceClient:
     tokenizer: Optional[Any] = None
     """Optional HF tokenizer for local tokenize/detokenize (avoids HTTP round-trips)."""
 
+    profile_each_sample: bool = False
+    """If True, hit ``/start_profile`` and ``/stop_profile`` around each ``sample()``
+    call. Requires the server to have been launched with ``profiler_config`` set
+    (via ``engine_init_kwargs.profiler_config``). Concurrent ``sample()`` calls
+    are serialized by a lock so traces don't overlap."""
+
     # Private fields excluded from repr for cleaner output
     _session: Optional[aiohttp.ClientSession] = field(default=None, repr=False)
     _world_size: Optional[Tuple[int, int]] = field(default=None, repr=False)
     _gen_sem: Optional[asyncio.Semaphore] = field(default=None, repr=False)
     _detok_sem: Optional[asyncio.Semaphore] = field(default=None, repr=False)
     _sem_loop: Optional[asyncio.AbstractEventLoop] = field(default=None, repr=False)
+    _profile_counter: int = field(default=0, repr=False)
+    _profile_lock: Optional[asyncio.Lock] = field(default=None, repr=False)
 
     def __post_init__(self):
         if self.data_parallel_size <= 0:
@@ -289,10 +297,14 @@ async def _post(self, url: str, json: Dict[str, Any], headers: Optional[Dict[str
                 try:
                     body = await resp.json(content_type=None)
                 except Exception as e:
+                    text = ""
+                    try:
+                        text = await resp.text()
+                    except Exception:
+                        pass
                     if 400 <= resp.status < 500:
                         # Non-JSON client error (e.g. plain text 422 from vllm-router).
                         # Raise immediately — client errors won't succeed on retry.
-                        text = await resp.text()
                         raise aiohttp.ClientResponseError(
                             resp.request_info,
                             resp.history,
@@ -301,7 +313,10 @@ async def _post(self, url: str, json: Dict[str, Any], headers: Optional[Dict[str
                             headers=resp.headers,
                         )
                     last_exc = e
-                    logger.debug(f"retry {attempt + 1}/{_DATA_PLANE_RETRIES} for {url=}: {e}")
+                    logger.debug(
+                        f"retry {attempt + 1}/{_DATA_PLANE_RETRIES} for {url=}: "
+                        f"status={resp.status} body={text[:200]!r}: {e}"
+                    )
                     await asyncio.sleep(1)
                     continue
                 raise_for_status(resp, body)
@@ -612,11 +627,28 @@ async def sample(self, request_payload: SampleRequestPayload) -> SampleResponse:
         url = f"{self.proxy_url}/inference/v1/generate"
 
         gen_sem, _ = self._get_semaphores()
-        if gen_sem is None:
-            response = await self._post(url, json=payload, headers=headers)
-        else:
+
+        async def _do_post() -> Dict[str, Any]:
+            if gen_sem is None:
+                return await self._post(url, json=payload, headers=headers)
             async with gen_sem:
-                response = await self._post(url, json=payload, headers=headers)
+                return await self._post(url, json=payload, headers=headers)
+
+        if self.profile_each_sample:
+            # start/stop_profile is global per-engine, so serialize concurrent
+            # samples to keep traces clean.
+            if self._profile_lock is None:
+                self._profile_lock = asyncio.Lock()
+            async with self._profile_lock:
+                prefix = f"sample_{self._profile_counter}"
+                self._profile_counter += 1
+                await self.start_profile(profile_prefix=prefix)
+                try:
+                    response = await _do_post()
+                finally:
+                    await self.stop_profile()
+        else:
+            response = await _do_post()
 
         # vLLM returns: list[dict[str(token_id) → {"logprob": float, ...}] | None]
         result_prompt_logprobs: Optional[List[Optional[float]]] = None
@@ -880,6 +912,20 @@ async def _call_all_servers(
         )
         return {url: resp for url, resp in results}
 
+    async def start_profile(self, profile_prefix: Optional[str] = None) -> Dict[str, Any]:
+        """Open a profiler span on every backend server.
+
+        Requires the server to have been launched with ``profiler_config`` set
+        (otherwise vLLM raises a 500). vLLM's ``/start_profile`` endpoint accepts
+        ``profile_prefix`` as a query param (used as the trace filename prefix).
+        """
+        params = {"profile_prefix": profile_prefix} if profile_prefix else None
+        return await self._call_all_servers("/start_profile", params=params)
+
+    async def stop_profile(self) -> Dict[str, Any]:
+        """Close the profiler span on every backend and flush traces to ``torch_profiler_dir``."""
+        return await self._call_all_servers("/stop_profile")
+
     async def pause(self, mode: Union[PauseMode, str] = PauseMode.KEEP, clear_cache: bool = False) -> Dict[str, Any]:
         """
         Pause generation on all backends.
diff --git a/skyrl/backends/skyrl_train/inference_servers/setup.py b/skyrl/backends/skyrl_train/inference_servers/setup.py
index ec3177c76f..7a69f628c7 100644
--- a/skyrl/backends/skyrl_train/inference_servers/setup.py
+++ b/skyrl/backends/skyrl_train/inference_servers/setup.py
@@ -283,6 +283,16 @@ def build_new_inference_client(
     active_lora_name = (
         _SKYRL_LORA_ADAPTER_NAME if lora_cfg and lora_cfg.rank > 0 and cfg.trainer.strategy != "megatron" else None
     )
+
+    # Auto-enable per-sample profiling when the user configured a vLLM profiler
+    # via engine_init_kwargs.profiler_config. Accept both dict (raw user input)
+    # and ProfilerConfig (post-coercion in build_vllm_cli_args).
+    profiler_cfg = ie_cfg.engine_init_kwargs.get("profiler_config") if ie_cfg.engine_init_kwargs else None
+    if isinstance(profiler_cfg, dict):
+        profile_each_sample = bool(profiler_cfg.get("profiler"))
+    else:
+        profile_each_sample = bool(profiler_cfg and getattr(profiler_cfg, "profiler", None))
+
     client = RemoteInferenceClient(
         proxy_url=server_setup.proxy_url,
         server_urls=server_setup.server_urls,
@@ -292,6 +302,7 @@ def build_new_inference_client(
         uses_lora_weight_sync=_uses_lora_weight_sync(cfg),
         data_parallel_size=ie_cfg.data_parallel_size,
         tokenizer=tokenizer,
+        profile_each_sample=profile_each_sample,
     )
     return client, server_setup
 
diff --git a/skyrl/backends/skyrl_train/inference_servers/utils.py b/skyrl/backends/skyrl_train/inference_servers/utils.py
index 4f6fab74cf..a7965c5c5d 100644
--- a/skyrl/backends/skyrl_train/inference_servers/utils.py
+++ b/skyrl/backends/skyrl_train/inference_servers/utils.py
@@ -102,8 +102,15 @@ def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace:
     else:
         args.enable_lora = False
 
-    # Add any extra engine_init_kwargs
-    engine_kwargs = get_config_as_dict(ie_cfg.engine_init_kwargs)
+    # Add any extra engine_init_kwargs. Copy so we don't mutate the source
+    # config (downstream readers in setup.py expect the original shape).
+    engine_kwargs = dict(get_config_as_dict(ie_cfg.engine_init_kwargs))
+    # vLLM's API server asserts args.profiler_config is a ProfilerConfig
+    # instance (not a dict), so coerce here when the user supplies it as a dict.
+    if isinstance(engine_kwargs.get("profiler_config"), dict):
+        from vllm.config.profiler import ProfilerConfig
+
+        engine_kwargs["profiler_config"] = ProfilerConfig(**engine_kwargs["profiler_config"])
     for key, value in engine_kwargs.items():
         setattr(args, key, value)
 