From ff3dfd8f764b841c8d5d6d36bd28214d88f880b0 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 21 May 2026 17:57:31 +0000 Subject: [PATCH 1/4] refactor(llm): dedup async/sync completion, responses, and transport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract shared helpers to eliminate ~400 lines of duplication between sync and async LLM paths: - _prepare_completion() / _CompletionCtx: shared setup for completion/acompletion - _completion_process_response(): shared post-processing (mock tools, telemetry) - _build_completion_llm_response(): shared ModelResponse → LLMResponse conversion - _prepare_responses() / _ResponsesCtx: shared setup for responses/aresponses - _responses_common_litellm_kwargs(): shared litellm call kwargs builder - _responses_process_stream_event(): shared stream event extraction - _responses_finalize_stream(): shared stream completion validation - _build_responses_llm_response(): shared ResponsesAPI → LLMResponse conversion - _suppress_litellm_warnings(): shared warning filters for transport calls - _common_completion_kwargs(): shared litellm completion call kwargs builder - _metrics_snapshot() / _retry_kwargs(): eliminate repeated boilerplate - _raise_mapped(): shared exception mapping for error handlers Co-authored-by: openhands --- openhands-sdk/openhands/sdk/llm/llm.py | 949 +++++++++++-------------- 1 file changed, 422 insertions(+), 527 deletions(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 3bee201e0b..db3a7b7818 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -7,7 +7,16 @@ import warnings from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, ClassVar, Literal, get_args, get_origin +from dataclasses import dataclass +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + Never, + get_args, + get_origin, +) import httpx # noqa: F401 from pydantic import ( @@ -715,10 +724,7 @@ def _handle_error( ) if result is not None: return result - mapped = map_provider_exception(error) - if mapped is not error: - raise mapped from error - raise + self._raise_mapped(error) async def _ahandle_error( self, @@ -748,66 +754,50 @@ async def _ahandle_error( ) if result is not None: return result + self._raise_mapped(error) + + def _raise_mapped(self, error: Exception) -> Never: + """Map a provider exception and re-raise. Must be called inside except.""" mapped = map_provider_exception(error) if mapped is not error: raise mapped from error raise - def completion( - self, - messages: list[Message], - tools: Sequence[ToolDefinition] | None = None, - _return_metrics: bool = False, - add_security_risk_prediction: bool = False, - on_token: TokenCallbackType | None = None, - **kwargs, - ) -> LLMResponse: - """Generate a completion from the language model. - - This is the method for getting responses from the model via Completion API. - It handles message formatting, tool calling, and response processing. - - Args: - messages: List of conversation messages. - tools: Optional list of tools available to the model. - _return_metrics: Whether to return usage metrics. - add_security_risk_prediction: Add security_risk field to tool schemas. - on_token: Optional callback for streaming tokens. - **kwargs: Additional arguments passed to the LLM API. - - Returns: - LLMResponse containing the model's response and metadata. - - Note: - Summary field is always added to tool schemas for transparency and - explainability of agent actions. + # ========================================================================= + # Shared helpers for completion / responses + # ========================================================================= - Raises: - ValueError: If streaming is requested (not supported). + @dataclass + class _CompletionCtx: + """Pre-computed state shared by sync and async completion paths.""" - Example: - ```python - from openhands.sdk.llm import Message, TextContent + formatted_messages: list[dict[str, Any]] + call_kwargs: dict[str, Any] + telemetry_ctx: dict[str, Any] + use_mock_tools: bool + cc_tools: list[ChatCompletionToolParam] + enable_streaming: bool + add_security_risk_prediction: bool - messages = [Message(role="user", content=[TextContent(text="Hello")])] - response = llm.completion(messages) - print(response.content) - ``` - """ + def _prepare_completion( + self, + messages: list[Message], + tools: Sequence[ToolDefinition] | None, + add_security_risk_prediction: bool, + on_token: AnyTokenCallbackType | None, + kwargs: dict[str, Any], + ) -> _CompletionCtx: + """Build all state needed before the transport call (shared by sync/async).""" enable_streaming = bool(kwargs.get("stream", False)) or self.stream if enable_streaming: if on_token is None: raise ValueError("Streaming requires an on_token callback") kwargs["stream"] = True - # 1) serialize messages formatted_messages = self.format_messages_for_llm(messages) - - # 2) choose function-calling strategy use_native_fc = self.native_tool_calling original_fncall_msgs = copy.deepcopy(formatted_messages) - # Convert Tool objects to ChatCompletionToolParam once here cc_tools: list[ChatCompletionToolParam] = [] if tools: cc_tools = [ @@ -830,23 +820,18 @@ def completion( include_security_params=add_security_risk_prediction, ) - # 3) normalize provider params - # Only pass tools when native FC is active kwargs["tools"] = cc_tools if (bool(cc_tools) and use_native_fc) else None has_tools_flag = bool(cc_tools) and use_native_fc - # Behavior-preserving: delegate to select_chat_options call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag) - # 4) request context for telemetry (always include context_window for metrics) assert self._telemetry is not None - # Always pass context_window so metrics are tracked even when logging disabled telemetry_ctx: dict[str, Any] = { "context_window": self.effective_max_input_tokens or 0 } if self._telemetry.log_enabled: telemetry_ctx.update( { - "messages": formatted_messages[:], # already simple dicts + "messages": formatted_messages[:], "tools": tools, "kwargs": {k: v for k, v in call_kwargs.items()}, } @@ -854,8 +839,62 @@ def completion( if tools and not use_native_fc: telemetry_ctx["raw_messages"] = original_fncall_msgs - # 5) do the call with retries - @self.retry_decorator( + return LLM._CompletionCtx( + formatted_messages=formatted_messages, + call_kwargs=call_kwargs, + telemetry_ctx=telemetry_ctx, + use_mock_tools=use_mock_tools, + cc_tools=cc_tools, + enable_streaming=enable_streaming, + add_security_risk_prediction=add_security_risk_prediction, + ) + + def _completion_process_response( + self, + resp: ModelResponse, + ctx: _CompletionCtx, + ) -> ModelResponse: + """Post-process a completion response (mock tools, telemetry). + + Called inside the retry boundary for both sync and async paths. + """ + assert self._telemetry is not None + raw_resp: ModelResponse | None = None + if ctx.use_mock_tools: + raw_resp = copy.deepcopy(resp) + resp = self.post_response_prompt_mock( + resp, + nonfncall_msgs=ctx.formatted_messages, + tools=ctx.cc_tools, + include_security_params=ctx.add_security_risk_prediction, + ) + self._telemetry.on_response(resp, raw_resp=raw_resp) + if not resp.get("choices") or len(resp["choices"]) < 1: + raise LLMNoResponseError( + "Response choices is less than 1. Response: " + str(resp) + ) + return resp + + def _build_completion_llm_response(self, resp: ModelResponse) -> LLMResponse: + """Convert a validated ModelResponse into an LLMResponse.""" + first_choice = resp["choices"][0] + message = Message.from_llm_chat_message(first_choice["message"]) + return LLMResponse( + message=message, + metrics=self._metrics_snapshot(), + raw_response=resp, + ) + + def _metrics_snapshot(self) -> MetricsSnapshot: + return MetricsSnapshot( + model_name=self.metrics.model_name, + accumulated_cost=self.metrics.accumulated_cost, + max_budget_per_task=self.metrics.max_budget_per_task, + accumulated_token_usage=self.metrics.accumulated_token_usage, + ) + + def _retry_kwargs(self) -> dict[str, Any]: + return dict( num_retries=self.num_retries, retry_exceptions=LLM_RETRY_EXCEPTIONS, retry_min_wait=self.retry_min_wait, @@ -863,58 +902,71 @@ def completion( retry_multiplier=self.retry_multiplier, retry_listener=self._retry_listener_fn, ) + + # ========================================================================= + # Chat Completion API + # ========================================================================= + + def completion( + self, + messages: list[Message], + tools: Sequence[ToolDefinition] | None = None, + _return_metrics: bool = False, + add_security_risk_prediction: bool = False, + on_token: TokenCallbackType | None = None, + **kwargs, + ) -> LLMResponse: + """Generate a completion from the language model. + + This is the method for getting responses from the model via Completion API. + It handles message formatting, tool calling, and response processing. + + Args: + messages: List of conversation messages. + tools: Optional list of tools available to the model. + _return_metrics: Whether to return usage metrics. + add_security_risk_prediction: Add security_risk field to tool schemas. + on_token: Optional callback for streaming tokens. + **kwargs: Additional arguments passed to the LLM API. + + Returns: + LLMResponse containing the model's response and metadata. + + Note: + Summary field is always added to tool schemas for transparency and + explainability of agent actions. + + Raises: + ValueError: If streaming is requested (not supported). + + Example: + ```python + from openhands.sdk.llm import Message, TextContent + + messages = [Message(role="user", content=[TextContent(text="Hello")])] + response = llm.completion(messages) + print(response.content) + ``` + """ + ctx = self._prepare_completion( + messages, tools, add_security_risk_prediction, on_token, kwargs + ) + + @self.retry_decorator(**self._retry_kwargs()) def _one_attempt(**retry_kwargs) -> ModelResponse: assert self._telemetry is not None - self._telemetry.on_request(telemetry_ctx=telemetry_ctx) - # Merge retry-modified kwargs (like temperature) with call_kwargs - final_kwargs = {**call_kwargs, **retry_kwargs} + self._telemetry.on_request(telemetry_ctx=ctx.telemetry_ctx) + final_kwargs = {**ctx.call_kwargs, **retry_kwargs} resp = self._transport_call( - messages=formatted_messages, + messages=ctx.formatted_messages, **final_kwargs, - enable_streaming=enable_streaming, + enable_streaming=ctx.enable_streaming, on_token=on_token, ) - raw_resp: ModelResponse | None = None - if use_mock_tools: - raw_resp = copy.deepcopy(resp) - resp = self.post_response_prompt_mock( - resp, - nonfncall_msgs=formatted_messages, - tools=cc_tools, - include_security_params=add_security_risk_prediction, - ) - # 6) telemetry - self._telemetry.on_response(resp, raw_resp=raw_resp) - - # Ensure at least one choice. - # Gemini sometimes returns empty choices; we raise LLMNoResponseError here - # inside the retry boundary so it is retried. - if not resp.get("choices") or len(resp["choices"]) < 1: - raise LLMNoResponseError( - "Response choices is less than 1. Response: " + str(resp) - ) - - return resp + return self._completion_process_response(resp, ctx) try: - resp = _one_attempt() - - # Convert the first choice to an OpenHands Message - first_choice = resp["choices"][0] - message = Message.from_llm_chat_message(first_choice["message"]) - - # Get current metrics snapshot - metrics_snapshot = MetricsSnapshot( - model_name=self.metrics.model_name, - accumulated_cost=self.metrics.accumulated_cost, - max_budget_per_task=self.metrics.max_budget_per_task, - accumulated_token_usage=self.metrics.accumulated_token_usage, - ) - - # Create and return LLMResponse - return LLMResponse( - message=message, metrics=metrics_snapshot, raw_response=resp - ) + return self._build_completion_llm_response(_one_attempt()) except Exception as e: return self._handle_error( e, @@ -944,103 +996,27 @@ async def acompletion( Uses ``litellm.acompletion`` under the hood, freeing the event loop while waiting for the LLM provider response. """ - enable_streaming = bool(kwargs.get("stream", False)) or self.stream - if enable_streaming: - if on_token is None: - raise ValueError("Streaming requires an on_token callback") - kwargs["stream"] = True - - formatted_messages = self.format_messages_for_llm(messages) - - use_native_fc = self.native_tool_calling - original_fncall_msgs = copy.deepcopy(formatted_messages) - - cc_tools: list[ChatCompletionToolParam] = [] - if tools: - cc_tools = [ - t.to_openai_tool( - add_security_risk_prediction=add_security_risk_prediction, - ) - for t in tools - ] - - use_mock_tools = self.should_mock_tool_calls(cc_tools) - if use_mock_tools: - formatted_messages, kwargs = self.pre_request_prompt_mock( - formatted_messages, - cc_tools or [], - kwargs, - include_security_params=add_security_risk_prediction, - ) - - kwargs["tools"] = cc_tools if (bool(cc_tools) and use_native_fc) else None - has_tools_flag = bool(cc_tools) and use_native_fc - call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag) - - assert self._telemetry is not None - telemetry_ctx: dict[str, Any] = { - "context_window": self.effective_max_input_tokens or 0 - } - if self._telemetry.log_enabled: - telemetry_ctx.update( - { - "messages": formatted_messages[:], - "tools": tools, - "kwargs": {k: v for k, v in call_kwargs.items()}, - } - ) - if tools and not use_native_fc: - telemetry_ctx["raw_messages"] = original_fncall_msgs + ctx = self._prepare_completion( + messages, tools, add_security_risk_prediction, on_token, kwargs + ) resp: ModelResponse | None = None - async for attempt in self.async_retry( - num_retries=self.num_retries, - retry_exceptions=LLM_RETRY_EXCEPTIONS, - retry_min_wait=self.retry_min_wait, - retry_max_wait=self.retry_max_wait, - retry_multiplier=self.retry_multiplier, - retry_listener=self._retry_listener_fn, - ): + async for attempt in self.async_retry(**self._retry_kwargs()): with attempt: assert self._telemetry is not None - self._telemetry.on_request(telemetry_ctx=telemetry_ctx) - resp = await self._atransport_call( - messages=formatted_messages, - **call_kwargs, - enable_streaming=enable_streaming, + self._telemetry.on_request(telemetry_ctx=ctx.telemetry_ctx) + raw = await self._atransport_call( + messages=ctx.formatted_messages, + **ctx.call_kwargs, + enable_streaming=ctx.enable_streaming, on_token=on_token, ) - raw_resp: ModelResponse | None = None - if use_mock_tools: - raw_resp = copy.deepcopy(resp) - resp = self.post_response_prompt_mock( - resp, - nonfncall_msgs=formatted_messages, - tools=cc_tools, - include_security_params=add_security_risk_prediction, - ) - self._telemetry.on_response(resp, raw_resp=raw_resp) - if not resp.get("choices") or len(resp["choices"]) < 1: - raise LLMNoResponseError( - "Response choices is less than 1. Response: " + str(resp) - ) + resp = self._completion_process_response(raw, ctx) try: assert resp is not None - first_choice = resp["choices"][0] - message = Message.from_llm_chat_message(first_choice["message"]) - metrics_snapshot = MetricsSnapshot( - model_name=self.metrics.model_name, - accumulated_cost=self.metrics.accumulated_cost, - max_budget_per_task=self.metrics.max_budget_per_task, - accumulated_token_usage=self.metrics.accumulated_token_usage, - ) - return LLMResponse( - message=message, metrics=metrics_snapshot, raw_response=resp - ) + return self._build_completion_llm_response(resp) except Exception as e: - # Fallback is synchronous; cast the token callback since the - # fallback LLM's sync path accepts TokenCallbackType. _fb_token = cast("TokenCallbackType | None", on_token) return await self._ahandle_error( e, @@ -1056,47 +1032,36 @@ async def acompletion( # ========================================================================= # Responses API (v1) # ========================================================================= - def responses( + @dataclass + class _ResponsesCtx: + """Pre-computed state shared by sync and async responses paths.""" + + instructions: str | None + input_items: list[dict[str, Any]] + resp_tools: list[Any] | None + call_kwargs: dict[str, Any] + telemetry_ctx: dict[str, Any] + user_enable_streaming: bool + + def _prepare_responses( self, messages: list[Message], - tools: Sequence[ToolDefinition] | None = None, - include: list[str] | None = None, - store: bool | None = None, - _return_metrics: bool = False, - add_security_risk_prediction: bool = False, - on_token: TokenCallbackType | None = None, - **kwargs, - ) -> LLMResponse: - """Alternative invocation path using OpenAI Responses API via LiteLLM. - - Maps Message[] -> (instructions, input[]) and returns LLMResponse. - - Args: - messages: List of conversation messages - tools: Optional list of tools available to the model - include: Optional list of fields to include in response - store: Whether to store the conversation - _return_metrics: Whether to return usage metrics - add_security_risk_prediction: Add security_risk field to tool schemas - on_token: Optional callback for streaming deltas - **kwargs: Additional arguments passed to the API - - Note: - Summary field is always added to tool schemas for transparency and - explainability of agent actions. - """ + tools: Sequence[ToolDefinition] | None, + include: list[str] | None, + store: bool | None, + add_security_risk_prediction: bool, + on_token: AnyTokenCallbackType | None, + kwargs: dict[str, Any], + ) -> _ResponsesCtx: + """Build all state needed before the responses transport call.""" user_enable_streaming = bool(kwargs.get("stream", False)) or self.stream if user_enable_streaming: if on_token is None and not self.is_subscription: - # We allow on_token to be None for subscription mode raise ValueError("Streaming requires an on_token callback") kwargs["stream"] = True - # Build instructions + input list using dedicated Responses formatter instructions, input_items = self.format_messages_for_responses(messages) - # Convert Tool objects to Responses ToolParam - # (Responses path always supports function tools) resp_tools = ( [ t.to_responses_tool( @@ -1108,14 +1073,11 @@ def responses( else None ) - # Normalize/override Responses kwargs consistently call_kwargs = select_responses_options( self, kwargs, include=include, store=store ) - # Request context for telemetry (always include context_window for metrics) assert self._telemetry is not None - # Always pass context_window so metrics are tracked even when logging disabled telemetry_ctx: dict[str, Any] = { "context_window": self.effective_max_input_tokens or 0 } @@ -1130,139 +1092,188 @@ def responses( } ) - # Perform call with retries - @self.retry_decorator( - num_retries=self.num_retries, - retry_exceptions=LLM_RETRY_EXCEPTIONS, - retry_min_wait=self.retry_min_wait, - retry_max_wait=self.retry_max_wait, - retry_multiplier=self.retry_multiplier, - retry_listener=self._retry_listener_fn, + return LLM._ResponsesCtx( + instructions=instructions, + input_items=input_items, + resp_tools=resp_tools, + call_kwargs=call_kwargs, + telemetry_ctx=telemetry_ctx, + user_enable_streaming=user_enable_streaming, + ) + + def _responses_common_litellm_kwargs( + self, ctx: _ResponsesCtx, extra: dict[str, Any] + ) -> dict[str, Any]: + """Build kwargs for ``litellm_responses`` / ``litellm_aresponses``.""" + typed_input: ResponseInputParam | str = ( + cast(ResponseInputParam, ctx.input_items) if ctx.input_items else "" + ) + return dict( + model=self.model, + input=typed_input, + instructions=ctx.instructions, + tools=ctx.resp_tools, + api_key=self._get_litellm_api_key_value(), + api_base=self.base_url, + api_version=self.api_version, + timeout=self.timeout, + drop_params=self.drop_params, + seed=self.seed, + **{**self._aws_kwargs(), **extra}, ) + + @staticmethod + def _responses_process_stream_event(event: Any) -> tuple[Any | None, str | None]: + """Extract collected output item and text delta from a stream event. + + Returns ``(item_or_none, delta_text_or_none)``. + """ + if event is None: + return None, None + item = None + evt_type = getattr(event, "type", None) + if evt_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: + raw_item = getattr(event, "item", None) + if raw_item is not None: + item = raw_item + delta: str | None = None + if isinstance( + event, + ( + OutputTextDeltaEvent, + RefusalDeltaEvent, + ReasoningSummaryTextDeltaEvent, + ), + ): + d = event.delta + if d: + delta = d + return item, delta + + @staticmethod + def _responses_finalize_stream( + completed_event: Any, + collected_output_items: list[Any], + ) -> ResponsesAPIResponse: + """Validate the completed stream event and patch output if needed.""" + if completed_event is None: + raise LLMNoResponseError( + "Responses stream finished without a completed response" + ) + if not isinstance(completed_event, ResponseCompletedEvent): + raise LLMNoResponseError( + f"Unexpected completed event: {type(completed_event)}" + ) + completed_resp = completed_event.response + if not completed_resp.output and collected_output_items: + completed_resp.output = collected_output_items + return completed_resp + + def _build_responses_llm_response( + self, completed: ResponsesAPIResponse + ) -> LLMResponse: + """Convert a ResponsesAPIResponse into an LLMResponse.""" + output_seq = cast(Sequence[Any], completed.output or []) + message = Message.from_llm_responses_output(output_seq) + return LLMResponse( + message=message, + metrics=self._metrics_snapshot(), + raw_response=completed, + ) + + def responses( + self, + messages: list[Message], + tools: Sequence[ToolDefinition] | None = None, + include: list[str] | None = None, + store: bool | None = None, + _return_metrics: bool = False, + add_security_risk_prediction: bool = False, + on_token: TokenCallbackType | None = None, + **kwargs, + ) -> LLMResponse: + """Alternative invocation path using OpenAI Responses API via LiteLLM. + + Maps Message[] -> (instructions, input[]) and returns LLMResponse. + + Args: + messages: List of conversation messages + tools: Optional list of tools available to the model + include: Optional list of fields to include in response + store: Whether to store the conversation + _return_metrics: Whether to return usage metrics + add_security_risk_prediction: Add security_risk field to tool schemas + on_token: Optional callback for streaming deltas + **kwargs: Additional arguments passed to the API + + Note: + Summary field is always added to tool schemas for transparency and + explainability of agent actions. + """ + ctx = self._prepare_responses( + messages, + tools, + include, + store, + add_security_risk_prediction, + on_token, + kwargs, + ) + + @self.retry_decorator(**self._retry_kwargs()) def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: assert self._telemetry is not None - self._telemetry.on_request(telemetry_ctx=telemetry_ctx) - final_kwargs = {**call_kwargs, **retry_kwargs} + self._telemetry.on_request(telemetry_ctx=ctx.telemetry_ctx) + final_kwargs = {**ctx.call_kwargs, **retry_kwargs} with self._litellm_modify_params_ctx(self.modify_params): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - typed_input: ResponseInputParam | str = ( - cast(ResponseInputParam, input_items) if input_items else "" - ) - api_key_value = self._get_litellm_api_key_value() - - ret = litellm_responses( - model=self.model, - input=typed_input, - instructions=instructions, - tools=resp_tools, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - **{**self._aws_kwargs(), **final_kwargs}, + litellm_kw = self._responses_common_litellm_kwargs( + ctx, final_kwargs ) + ret = litellm_responses(**litellm_kw) + if isinstance(ret, ResponsesAPIResponse): - if user_enable_streaming: + if ctx.user_enable_streaming: logger.warning( - "Responses streaming was requested, but the provider " - "returned a non-streaming response; no on_token deltas " - "will be emitted." + "Responses streaming was requested, but the " + "provider returned a non-streaming response; " + "no on_token deltas will be emitted." ) self._telemetry.on_response(ret) return ret - # When stream=True, LiteLLM returns a streaming iterator rather than - # a single ResponsesAPIResponse. Drain the iterator and use the - # completed response. if final_kwargs.get("stream", False): if not isinstance(ret, SyncResponsesAPIStreamingIterator): raise AssertionError( f"Expected Responses stream iterator, got {type(ret)}" ) - - stream_callback = on_token if user_enable_streaming else None - # Collect output items from streaming events. - # Some endpoints (e.g., Codex subscription) send output - # items as separate events but the final response.completed - # event has output=[]. We accumulate them here and patch - # the completed response if needed. - collected_output_items: list[Any] = [] + stream_cb = on_token if ctx.user_enable_streaming else None + collected: list[Any] = [] for event in ret: - if event is None: - continue - # Collect finished output items - evt_type = getattr(event, "type", None) - if evt_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: - item = getattr(event, "item", None) - if item is not None: - collected_output_items.append(item) - if stream_callback is None: - continue - if isinstance( - event, - ( - OutputTextDeltaEvent, - RefusalDeltaEvent, - ReasoningSummaryTextDeltaEvent, - ), - ): - delta = event.delta - if delta: - stream_callback( - ModelResponseStream( - choices=[ - StreamingChoices( - delta=Delta(content=delta) - ) - ] - ) + item, delta = self._responses_process_stream_event(event) + if item is not None: + collected.append(item) + if stream_cb is not None and delta: + stream_cb( + ModelResponseStream( + choices=[ + StreamingChoices(delta=Delta(content=delta)) + ] ) - - completed_event = ret.completed_response - if completed_event is None: - raise LLMNoResponseError( - "Responses stream finished without a completed response" - ) - if not isinstance(completed_event, ResponseCompletedEvent): - raise LLMNoResponseError( - f"Unexpected completed event: {type(completed_event)}" - ) - - completed_resp = completed_event.response - - # Patch empty output with items collected from stream - if not completed_resp.output and collected_output_items: - completed_resp.output = collected_output_items - - self._telemetry.on_response(completed_resp) - return completed_resp + ) + resp = self._responses_finalize_stream( + ret.completed_response, collected + ) + self._telemetry.on_response(resp) + return resp raise AssertionError( f"Expected ResponsesAPIResponse, got {type(ret)}" ) try: - resp: ResponsesAPIResponse = _one_attempt() - - # Parse output -> Message (typed) - # Cast to a typed sequence - # accepted by from_llm_responses_output - output_seq = cast(Sequence[Any], resp.output or []) - message = Message.from_llm_responses_output(output_seq) - - metrics_snapshot = MetricsSnapshot( - model_name=self.metrics.model_name, - accumulated_cost=self.metrics.accumulated_cost, - max_budget_per_task=self.metrics.max_budget_per_task, - accumulated_token_usage=self.metrics.accumulated_token_usage, - ) - - return LLMResponse( - message=message, metrics=metrics_snapshot, raw_response=resp - ) + return self._build_responses_llm_response(_one_attempt()) except Exception as e: return self._handle_error( e, @@ -1296,84 +1307,37 @@ async def aresponses( Uses ``litellm.aresponses`` under the hood, freeing the event loop while waiting for the LLM provider response. """ - user_enable_streaming = bool(kwargs.get("stream", False)) or self.stream - if user_enable_streaming: - if on_token is None and not self.is_subscription: - raise ValueError("Streaming requires an on_token callback") - kwargs["stream"] = True - - instructions, input_items = self.format_messages_for_responses(messages) - - resp_tools = ( - [ - t.to_responses_tool( - add_security_risk_prediction=add_security_risk_prediction, - ) - for t in tools - ] - if tools - else None - ) - - call_kwargs = select_responses_options( - self, kwargs, include=include, store=store + ctx = self._prepare_responses( + messages, + tools, + include, + store, + add_security_risk_prediction, + on_token, + kwargs, ) - assert self._telemetry is not None - telemetry_ctx: dict[str, Any] = { - "context_window": self.effective_max_input_tokens or 0 - } - if self._telemetry.log_enabled: - telemetry_ctx.update( - { - "llm_path": "responses", - "instructions": instructions, - "input": input_items[:], - "tools": tools, - "kwargs": {k: v for k, v in call_kwargs.items()}, - } - ) - completed: ResponsesAPIResponse | None = None - async for attempt in self.async_retry( - num_retries=self.num_retries, - retry_exceptions=LLM_RETRY_EXCEPTIONS, - retry_min_wait=self.retry_min_wait, - retry_max_wait=self.retry_max_wait, - retry_multiplier=self.retry_multiplier, - retry_listener=self._retry_listener_fn, - ): + async for attempt in self.async_retry(**self._retry_kwargs()): with attempt: assert self._telemetry is not None - self._telemetry.on_request(telemetry_ctx=telemetry_ctx) - final_kwargs = {**call_kwargs} + self._telemetry.on_request(telemetry_ctx=ctx.telemetry_ctx) + final_kwargs = {**ctx.call_kwargs} with self._litellm_modify_params_ctx(self.modify_params): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - typed_input: ResponseInputParam | str = ( - cast(ResponseInputParam, input_items) if input_items else "" - ) - api_key_value = self._get_litellm_api_key_value() - - ret = await litellm_aresponses( - model=self.model, - input=typed_input, - instructions=instructions, - tools=resp_tools, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - **{**self._aws_kwargs(), **final_kwargs}, + litellm_kw = self._responses_common_litellm_kwargs( + ctx, final_kwargs ) + ret = await litellm_aresponses(**litellm_kw) + if isinstance(ret, ResponsesAPIResponse): - if user_enable_streaming: + if ctx.user_enable_streaming: logger.warning( - "Responses streaming was requested, but the " - "provider returned a non-streaming response; " - "no on_token deltas will be emitted." + "Responses streaming was requested, but " + "the provider returned a non-streaming " + "response; no on_token deltas will be " + "emitted." ) self._telemetry.on_response(ret) completed = ret @@ -1383,61 +1347,30 @@ async def aresponses( "Expected Responses async stream " f"iterator, got {type(ret)}" ) - - stream_cb = on_token if user_enable_streaming else None - collected_output_items: list[Any] = [] + stream_cb = on_token if ctx.user_enable_streaming else None + collected: list[Any] = [] async for event in ret: - if event is None: - continue - evt_type = getattr(event, "type", None) - if ( - evt_type - == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE - ): - item = getattr(event, "item", None) - if item is not None: - collected_output_items.append(item) - if stream_cb is None: - continue - if isinstance( - event, - ( - OutputTextDeltaEvent, - RefusalDeltaEvent, - ReasoningSummaryTextDeltaEvent, - ), - ): - delta = event.delta - if delta: - await _invoke_token_callback( - stream_cb, - ModelResponseStream( - choices=[ - StreamingChoices( - delta=Delta(content=delta) - ) - ] - ), - ) - - completed_event = ret.completed_response - if completed_event is None: - raise LLMNoResponseError( - "Responses stream finished without a " - "completed response" + item, delta = self._responses_process_stream_event( + event ) - if not isinstance(completed_event, ResponseCompletedEvent): - raise LLMNoResponseError( - "Unexpected completed event: " - f"{type(completed_event)}" - ) - - completed_resp = completed_event.response - if not completed_resp.output and collected_output_items: - completed_resp.output = collected_output_items - - self._telemetry.on_response(completed_resp) - completed = completed_resp + if item is not None: + collected.append(item) + if stream_cb is not None and delta: + await _invoke_token_callback( + stream_cb, + ModelResponseStream( + choices=[ + StreamingChoices( + delta=Delta(content=delta) + ) + ] + ), + ) + resp = self._responses_finalize_stream( + ret.completed_response, collected + ) + self._telemetry.on_response(resp) + completed = resp else: raise AssertionError( f"Expected ResponsesAPIResponse, got {type(ret)}" @@ -1445,17 +1378,7 @@ async def aresponses( try: assert completed is not None - output_seq = cast(Sequence[Any], completed.output or []) - message = Message.from_llm_responses_output(output_seq) - metrics_snapshot = MetricsSnapshot( - model_name=self.metrics.model_name, - accumulated_cost=self.metrics.accumulated_cost, - max_budget_per_task=self.metrics.max_budget_per_task, - accumulated_token_usage=self.metrics.accumulated_token_usage, - ) - return LLMResponse( - message=message, metrics=metrics_snapshot, raw_response=completed - ) + return self._build_responses_llm_response(completed) except Exception as e: _fb_token = cast("TokenCallbackType | None", on_token) return await self._ahandle_error( @@ -1506,6 +1429,43 @@ def _get_litellm_api_key_value(self) -> str | None: return api_key_value + @staticmethod + def _suppress_litellm_warnings() -> None: + """Install warning filters used by transport and responses calls.""" + warnings.filterwarnings("ignore", category=DeprecationWarning, module="httpx.*") + warnings.filterwarnings( + "ignore", + message=r".*content=.*upload.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"There is no current event loop", + category=DeprecationWarning, + ) + warnings.filterwarnings("ignore", category=UserWarning) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="Accessing the 'model_fields' attribute.*", + ) + + def _common_completion_kwargs( + self, messages: list[dict[str, Any]], extra: dict[str, Any] + ) -> dict[str, Any]: + """Build kwargs for ``litellm_completion`` / ``litellm_acompletion``.""" + return dict( + model=self.model, + api_key=self._get_litellm_api_key_value(), + api_base=self.base_url, + api_version=self.api_version, + timeout=self.timeout, + drop_params=self.drop_params, + seed=self.seed, + messages=messages, + **{**self._aws_kwargs(), **extra}, + ) + def _transport_call( self, *, @@ -1514,51 +1474,14 @@ def _transport_call( on_token: TokenCallbackType | None = None, **kwargs, ) -> ModelResponse: - # litellm.modify_params is GLOBAL; guard it for thread-safety with self._litellm_modify_params_ctx(self.modify_params): with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=DeprecationWarning, module="httpx.*" - ) - warnings.filterwarnings( - "ignore", - message=r".*content=.*upload.*", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - message=r"There is no current event loop", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - category=UserWarning, - ) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message="Accessing the 'model_fields' attribute.*", - ) - api_key_value = self._get_litellm_api_key_value() - - # When streaming, request usage in the final chunk so that - # detailed token breakdowns (prompt_tokens_details with - # cached_tokens, etc.) are not silently discarded by - # litellm's streaming handler. + self._suppress_litellm_warnings() if enable_streaming: kwargs.setdefault("stream_options", {"include_usage": True}) - # Some providers need renames handled in _normalize_call_kwargs. ret = litellm_completion( - model=self.model, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - messages=messages, - **{**self._aws_kwargs(), **kwargs}, + **self._common_completion_kwargs(messages, kwargs) ) if enable_streaming and on_token is not None: assert isinstance(ret, CustomStreamWrapper) @@ -1584,40 +1507,12 @@ async def _atransport_call( """Async variant of :meth:`_transport_call`.""" with self._litellm_modify_params_ctx(self.modify_params): with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", category=DeprecationWarning, module="httpx.*" - ) - warnings.filterwarnings( - "ignore", - message=r".*content=.*upload.*", - category=DeprecationWarning, - ) - warnings.filterwarnings( - "ignore", - message=r"There is no current event loop", - category=DeprecationWarning, - ) - warnings.filterwarnings("ignore", category=UserWarning) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message="Accessing the 'model_fields' attribute.*", - ) - api_key_value = self._get_litellm_api_key_value() - + self._suppress_litellm_warnings() if enable_streaming: kwargs.setdefault("stream_options", {"include_usage": True}) ret = await litellm_acompletion( - model=self.model, - api_key=api_key_value, - api_base=self.base_url, - api_version=self.api_version, - timeout=self.timeout, - drop_params=self.drop_params, - seed=self.seed, - messages=messages, - **{**self._aws_kwargs(), **kwargs}, + **self._common_completion_kwargs(messages, kwargs) ) if enable_streaming and on_token is not None: assert isinstance(ret, CustomStreamWrapper) From 5d3cb9eb685e6b16a47efd4709e3687893378c10 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 21 May 2026 18:02:53 +0000 Subject: [PATCH 2/4] refactor(agent): dedup async/sync step, execute_actions, handle_tool_calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract shared helpers to eliminate ~250 lines of duplication: - _step_check_pending(): shared pending actions & blocked message check - _handle_llm_error(): shared 3-way LLM error handling - _dispatch_llm_response(): shared response type routing - _prepare_action_events(): shared tool call → action event loop - _finalize_batch() / _batch_kwargs(): shared batch setup & finalization Co-authored-by: openhands --- openhands-sdk/openhands/sdk/agent/agent.py | 398 +++++++++--------- .../openhands/sdk/agent/response_dispatch.py | 80 ++-- 2 files changed, 226 insertions(+), 252 deletions(-) diff --git a/openhands-sdk/openhands/sdk/agent/agent.py b/openhands-sdk/openhands/sdk/agent/agent.py index efafbe69c9..c5cae109e1 100644 --- a/openhands-sdk/openhands/sdk/agent/agent.py +++ b/openhands-sdk/openhands/sdk/agent/agent.py @@ -4,7 +4,7 @@ import re from collections.abc import Callable from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from pydantic import PrivateAttr, ValidationError, model_validator @@ -488,22 +488,13 @@ def get_dynamic_context(self, state: ConversationState) -> str | None: additional_secret_infos=secret_infos, ) - def _execute_actions( + def _finalize_batch( self, + batch: _ActionBatch, conversation: LocalConversation, - action_events: list[ActionEvent], on_event: ConversationCallbackType, ) -> None: - """Prepare a batch, emit results, and handle finish.""" - state = conversation.state - batch = _ActionBatch.prepare( - action_events, - state=state, - executor=self._parallel_executor, - tool_runner=lambda ae: self._execute_action_event(conversation, ae), - tools=self.tools_map, - cancel_token=conversation.cancel_token, - ) + """Emit batch results and handle finish (shared by sync/async).""" batch.emit(on_event) batch.finalize( on_event=on_event, @@ -511,140 +502,123 @@ def _execute_actions( self._check_iterative_refinement(conversation, ae) ), mark_finished=lambda: setattr( - state, + conversation.state, "execution_status", ConversationExecutionStatus.FINISHED, ), ) - async def _aexecute_actions( + def _batch_kwargs( self, conversation: LocalConversation, action_events: list[ActionEvent], - on_event: ConversationCallbackType, - ) -> None: - """Async variant of :meth:`_execute_actions`. - - Each tool call runs in its own thread via - :meth:`ParallelToolExecutor.aexecute_batch`, giving the event - loop an ``await`` boundary between every tool invocation. - """ - state = conversation.state - batch = await _ActionBatch.aprepare( - action_events, - state=state, + ) -> dict[str, Any]: + """Build common kwargs for _ActionBatch.prepare / aprepare.""" + return dict( + action_events=action_events, + state=conversation.state, executor=self._parallel_executor, tool_runner=lambda ae: self._execute_action_event(conversation, ae), tools=self.tools_map, cancel_token=conversation.cancel_token, ) - batch.emit(on_event) - batch.finalize( - on_event=on_event, - check_iterative_refinement=lambda ae: ( - self._check_iterative_refinement(conversation, ae) - ), - mark_finished=lambda: setattr( - state, - "execution_status", - ConversationExecutionStatus.FINISHED, - ), - ) - @observe(name="agent.step", ignore_inputs=["state", "on_event"]) - def step( + def _execute_actions( self, conversation: LocalConversation, + action_events: list[ActionEvent], + on_event: ConversationCallbackType, + ) -> None: + """Prepare a batch, emit results, and handle finish.""" + batch = _ActionBatch.prepare(**self._batch_kwargs(conversation, action_events)) + self._finalize_batch(batch, conversation, on_event) + + async def _aexecute_actions( + self, + conversation: LocalConversation, + action_events: list[ActionEvent], on_event: ConversationCallbackType, - on_token: ConversationTokenCallbackType | None = None, ) -> None: + """Async variant of :meth:`_execute_actions`.""" + batch = await _ActionBatch.aprepare( + **self._batch_kwargs(conversation, action_events) + ) + self._finalize_batch(batch, conversation, on_event) + + def _step_check_pending( + self, + conversation: LocalConversation, + ) -> list[ActionEvent] | None: + """Return pending actions if they exist, else check blocked messages. + + Returns pending actions list (possibly empty if blocked message found + and FINISHED was set), or None if step should proceed to LLM call. + """ state = conversation.state - # Check for pending actions (implicit confirmation) - # and execute them before sampling new actions. - pending_actions = ConversationState.get_unmatched_actions(state.events) - if pending_actions: + pending = ConversationState.get_unmatched_actions(state.events) + if pending: logger.info( "Confirmation mode: Executing %d pending action(s)", - len(pending_actions), + len(pending), ) - self._execute_actions(conversation, pending_actions, on_event) - return + return pending - # Check if the last user message was blocked by a UserPromptSubmit hook - # If so, skip processing and mark conversation as finished if state.last_user_message_id is not None: reason = state.pop_blocked_message(state.last_user_message_id) if reason is not None: logger.info(f"User message blocked by hook: {reason}") state.execution_status = ConversationExecutionStatus.FINISHED - return + return [] # empty = handled, don't proceed elif state.blocked_messages: logger.debug( - "Blocked messages exist but last_user_message_id is None; " - "skipping hook check for legacy conversation state." + "Blocked messages exist but last_user_message_id is " + "None; skipping hook check for legacy conversation " + "state." ) + return None - # Prepare LLM messages using the utility function - _messages_or_condensation = prepare_llm_messages( - state.events, condenser=self.condenser, llm=self.llm - ) - - # Process condensation event before agent sampels another action - if isinstance(_messages_or_condensation, Condensation): - on_event(_messages_or_condensation) - return - - _messages = _messages_or_condensation - - logger.debug( - "Sending messages to LLM: " - f"{json.dumps([m.model_dump() for m in _messages[1:]], indent=2)}" - ) - - try: - llm_response = make_llm_completion( - self.llm, - _messages, - tools=list(self.tools_map.values()), - on_token=on_token, - ) - except FunctionCallValidationError as e: - logger.warning(f"LLM generated malformed function call: {e}") - error_message = MessageEvent( - source="user", - llm_message=Message( - role="user", - content=[TextContent(text=str(e))], - ), + def _handle_llm_error( + self, + error: Exception, + on_event: ConversationCallbackType, + ) -> bool: + """Handle LLM errors shared by step/astep. Returns True if handled.""" + if isinstance(error, FunctionCallValidationError): + logger.warning(f"LLM generated malformed function call: {error}") + on_event( + MessageEvent( + source="user", + llm_message=Message( + role="user", + content=[TextContent(text=str(error))], + ), + ) ) - on_event(error_message) - return - except LLMMalformedConversationHistoryError as e: - # The provider rejected the current message history as structurally - # invalid (for example, broken tool_use/tool_result pairing). Route - # this into condensation recovery, but keep the logs distinct from - # true context-window exhaustion so upstream event-stream bugs remain - # visible. + return True + + if isinstance(error, LLMMalformedConversationHistoryError): if ( self.condenser is not None and self.condenser.handles_condensation_requests() ): logger.warning( "LLM raised malformed conversation history error, " - "triggering condensation retry with condensed history: " - f"{e}" + "triggering condensation retry with condensed " + "history: %s", + error, ) on_event(CondensationRequest()) - return + return True logger.warning( - "LLM raised malformed conversation history error but no " - "condenser can handle condensation requests. This usually " - "indicates an upstream event-stream or resume bug: " - f"{e}" + "LLM raised malformed conversation history error but " + "no condenser can handle condensation requests. This " + "usually indicates an upstream event-stream or resume " + "bug: %s", + error, ) - raise e - except LLMContextWindowExceedError as e: - # If condenser is available and handles requests, trigger condensation + return False + + if isinstance(error, LLMContextWindowExceedError): if ( self.condenser is not None and self.condenser.handles_condensation_requests() @@ -653,23 +627,53 @@ def step( "LLM raised context window exceeded error, triggering condensation" ) on_event(CondensationRequest()) - return - # No condenser available or doesn't handle requests; log helpful warning + return True self._log_context_window_exceeded_warning() - raise e + return False + + return False - # LLMResponse already contains the converted message and metrics snapshot + def _dispatch_llm_response( + self, + llm_response: LLMResponse, + conversation: LocalConversation, + on_event: ConversationCallbackType, + *, + async_mode: bool = False, + ) -> Any: + """Route LLM response to the appropriate handler. + + Returns None for sync, or a coroutine for async (when + ``async_mode=True`` and the response has tool calls). + """ message: Message = llm_response.message + state = conversation.state response_type = classify_response(message) match response_type: case LLMResponseType.TOOL_CALLS: + if async_mode: + return self._ahandle_tool_calls( + message, + llm_response, + conversation, + state, + on_event, + ) self._handle_tool_calls( - message, llm_response, conversation, state, on_event + message, + llm_response, + conversation, + state, + on_event, ) case LLMResponseType.CONTENT: self._handle_content_response( - message, llm_response, conversation, state, on_event + message, + llm_response, + conversation, + state, + on_event, ) case LLMResponseType.REASONING_ONLY | LLMResponseType.EMPTY: self._handle_no_content_response( @@ -680,6 +684,57 @@ def step( on_event, response_type=response_type, ) + return None + + @observe(name="agent.step", ignore_inputs=["state", "on_event"]) + def step( + self, + conversation: LocalConversation, + on_event: ConversationCallbackType, + on_token: ConversationTokenCallbackType | None = None, + ) -> None: + pending = self._step_check_pending(conversation) + if pending is not None: + if pending: + self._execute_actions(conversation, pending, on_event) + return + + _messages_or_condensation = prepare_llm_messages( + conversation.state.events, + condenser=self.condenser, + llm=self.llm, + ) + if isinstance(_messages_or_condensation, Condensation): + on_event(_messages_or_condensation) + return + + _messages = _messages_or_condensation + logger.debug( + "Sending messages to LLM: " + f"{json.dumps([m.model_dump() for m in _messages[1:]], indent=2)}" + ) + + try: + llm_response = make_llm_completion( + self.llm, + _messages, + tools=list(self.tools_map.values()), + on_token=on_token, + ) + except ( + FunctionCallValidationError, + LLMMalformedConversationHistoryError, + LLMContextWindowExceedError, + ) as e: + if self._handle_llm_error(e, on_event): + return + raise + + self._dispatch_llm_response( + llm_response, + conversation, + on_event, + ) async def astep( self, @@ -696,39 +751,26 @@ async def astep( parallel calls with :func:`asyncio.gather`, keeping the event loop responsive during blocking tool I/O. """ - state = conversation.state - # Check for pending actions (implicit confirmation) - pending_actions = ConversationState.get_unmatched_actions(state.events) - if pending_actions: - logger.info( - "Confirmation mode: Executing %d pending action(s)", - len(pending_actions), - ) - await self._aexecute_actions(conversation, pending_actions, on_event) + pending = self._step_check_pending(conversation) + if pending is not None: + if pending: + await self._aexecute_actions( + conversation, + pending, + on_event, + ) return - if state.last_user_message_id is not None: - reason = state.pop_blocked_message(state.last_user_message_id) - if reason is not None: - logger.info(f"User message blocked by hook: {reason}") - state.execution_status = ConversationExecutionStatus.FINISHED - return - elif state.blocked_messages: - logger.debug( - "Blocked messages exist but last_user_message_id is None; " - "skipping hook check for legacy conversation state." - ) - _messages_or_condensation = await aprepare_llm_messages( - state.events, condenser=self.condenser, llm=self.llm + conversation.state.events, + condenser=self.condenser, + llm=self.llm, ) - if isinstance(_messages_or_condensation, Condensation): on_event(_messages_or_condensation) return _messages = _messages_or_condensation - logger.debug( "Sending messages to LLM: " f"{json.dumps([m.model_dump() for m in _messages[1:]], indent=2)}" @@ -741,81 +783,23 @@ async def astep( tools=list(self.tools_map.values()), on_token=on_token, ) - except FunctionCallValidationError as e: - logger.warning(f"LLM generated malformed function call: {e}") - error_message = MessageEvent( - source="user", - llm_message=Message( - role="user", - content=[TextContent(text=str(e))], - ), - ) - on_event(error_message) - return - except LLMMalformedConversationHistoryError as e: - # The provider rejected the current message history as - # structurally invalid (for example, broken - # tool_use/tool_result pairing). Route this into - # condensation recovery, but keep the logs distinct from - # true context-window exhaustion so upstream event-stream - # bugs remain visible. - if ( - self.condenser is not None - and self.condenser.handles_condensation_requests() - ): - logger.warning( - "LLM raised malformed conversation history error, " - "triggering condensation retry with condensed " - "history: %s", - e, - ) - on_event(CondensationRequest()) + except ( + FunctionCallValidationError, + LLMMalformedConversationHistoryError, + LLMContextWindowExceedError, + ) as e: + if self._handle_llm_error(e, on_event): return - logger.warning( - "LLM raised malformed conversation history error but " - "no condenser can handle condensation requests. This " - "usually indicates an upstream event-stream or resume " - "bug: %s", - e, - ) - raise e - except LLMContextWindowExceedError as e: - # If condenser is available and handles requests, trigger - # condensation - if ( - self.condenser is not None - and self.condenser.handles_condensation_requests() - ): - logger.warning( - "LLM raised context window exceeded error, triggering condensation" - ) - on_event(CondensationRequest()) - return - # No condenser available; log helpful warning - self._log_context_window_exceeded_warning() - raise e + raise - message: Message = llm_response.message - response_type = classify_response(message) - - match response_type: - case LLMResponseType.TOOL_CALLS: - await self._ahandle_tool_calls( - message, llm_response, conversation, state, on_event - ) - case LLMResponseType.CONTENT: - self._handle_content_response( - message, llm_response, conversation, state, on_event - ) - case LLMResponseType.REASONING_ONLY | LLMResponseType.EMPTY: - self._handle_no_content_response( - message, - llm_response, - conversation, - state, - on_event, - response_type=response_type, - ) + coro = self._dispatch_llm_response( + llm_response, + conversation, + on_event, + async_mode=True, + ) + if coro is not None: + await coro def _requires_user_confirmation( self, state: ConversationState, action_events: list[ActionEvent] diff --git a/openhands-sdk/openhands/sdk/agent/response_dispatch.py b/openhands-sdk/openhands/sdk/agent/response_dispatch.py index bd5166b075..9b65f14a69 100644 --- a/openhands-sdk/openhands/sdk/agent/response_dispatch.py +++ b/openhands-sdk/openhands/sdk/agent/response_dispatch.py @@ -184,15 +184,15 @@ def _evaluate_with_critic( event: ActionEvent | MessageEvent, ) -> CriticResult | None: ... - def _handle_tool_calls( + def _prepare_action_events( self, message: Message, llm_response: LLMResponse, conversation: LocalConversation, state: ConversationState, on_event: ConversationCallbackType, - ) -> None: - """Handle LLM response containing tool calls.""" + ) -> list[ActionEvent] | None: + """Build action events from tool calls. Returns None if confirmation needed.""" if not all(isinstance(c, TextContent) for c in message.content): logger.warning( "LLM returned tool calls but message content is not all " @@ -222,11 +222,27 @@ def _handle_tool_calls( action_events.append(action_event) if self._requires_user_confirmation(state, action_events): - return + return None + return action_events - if action_events: + def _handle_tool_calls( + self, + message: Message, + llm_response: LLMResponse, + conversation: LocalConversation, + state: ConversationState, + on_event: ConversationCallbackType, + ) -> None: + """Handle LLM response containing tool calls.""" + action_events = self._prepare_action_events( + message, + llm_response, + conversation, + state, + on_event, + ) + if action_events is not None and action_events: self._execute_actions(conversation, action_events, on_event) - self._maybe_emit_vllm_tokens(llm_response, on_event) async def _ahandle_tool_calls( @@ -237,46 +253,20 @@ async def _ahandle_tool_calls( state: ConversationState, on_event: ConversationCallbackType, ) -> None: - """Async variant of :meth:`_handle_tool_calls`. - - Delegates tool execution to :meth:`_aexecute_actions` so each - tool call runs in its own thread and multiple calls are scheduled - concurrently via :func:`asyncio.gather`. - """ - if not all(isinstance(c, TextContent) for c in message.content): - logger.warning( - "LLM returned tool calls but message content is not all " - "TextContent - ignoring non-text content" - ) - - thought_content = [c for c in message.content if isinstance(c, TextContent)] - - action_events: list[ActionEvent] = [] - assert message.tool_calls, "classify_response guarantees tool_calls" - for i, tool_call in enumerate(message.tool_calls): - action_event = self._get_action_event( - tool_call, - conversation=conversation, - llm_response_id=llm_response.id, - on_event=on_event, - security_analyzer=state.security_analyzer, - thought=thought_content if i == 0 else [], - reasoning_content=(message.reasoning_content if i == 0 else None), - thinking_blocks=(list(message.thinking_blocks) if i == 0 else []), - responses_reasoning_item=( - message.responses_reasoning_item if i == 0 else None - ), + """Async variant of :meth:`_handle_tool_calls`.""" + action_events = self._prepare_action_events( + message, + llm_response, + conversation, + state, + on_event, + ) + if action_events is not None and action_events: + await self._aexecute_actions( + conversation, + action_events, + on_event, ) - if action_event is None: - continue - action_events.append(action_event) - - if self._requires_user_confirmation(state, action_events): - return - - if action_events: - await self._aexecute_actions(conversation, action_events, on_event) - self._maybe_emit_vllm_tokens(llm_response, on_event) def _handle_content_response( From 5872197a2bf7f570ee2bf59318963f87ba7a7451 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 21 May 2026 18:20:44 +0000 Subject: [PATCH 3/4] refactor: dedup async/sync in agent utils, condensers, and conversation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Layer 3 (agent utils): - _finalize_messages(): shared condenser result processing Layer 4 (condensers): - _build_condensation_prompt(): shared prompt building - _condensation_from_response(): shared response → Condensation - _validate_forgotten_events(): shared validation for get/aget_condensation Layer 5 (conversation): - _check_loop_state(): shared status/hook/stuck checking - _check_post_step(): shared post-step iteration checking - _enter_run(): shared run initialization - _handle_run_error(): shared error handling Co-authored-by: openhands --- openhands-sdk/openhands/sdk/agent/utils.py | 82 ++-- .../condenser/llm_summarizing_condenser.py | 161 ++++---- .../conversation/impl/local_conversation.py | 356 +++++++----------- 3 files changed, 238 insertions(+), 361 deletions(-) diff --git a/openhands-sdk/openhands/sdk/agent/utils.py b/openhands-sdk/openhands/sdk/agent/utils.py index e636fee5fb..921dd0c18e 100644 --- a/openhands-sdk/openhands/sdk/agent/utils.py +++ b/openhands-sdk/openhands/sdk/agent/utils.py @@ -467,6 +467,24 @@ def prepare_llm_messages( ) -> list[Message] | Condensation: ... +def _finalize_messages( + condensation_result: View | Condensation, + llm_convertible_events: list[LLMConvertibleEvent], + additional_messages: list[Message] | None, +) -> list[Message] | Condensation: + """Apply condenser result and build final message list (shared helper).""" + match condensation_result: + case View(): + llm_convertible_events = condensation_result.events + case Condensation(): + return condensation_result + + messages = LLMConvertibleEvent.events_to_messages(llm_convertible_events) + if additional_messages: + messages.extend(additional_messages) + return messages + + def prepare_llm_messages( events: Sequence[Event], condenser: CondenserBase | None = None, @@ -493,31 +511,20 @@ def prepare_llm_messages( Raises: RuntimeError: If condensation is needed but no callback is provided """ - view = View.from_events(events) llm_convertible_events: list[LLMConvertibleEvent] = view.events - # If a condenser is registered, we need to give it an - # opportunity to transform the events. This will either - # produce a list of events, exactly as expected, or a - # new condensation that needs to be processed if condenser is not None: - condensation_result = condenser.condense(view, agent_llm=llm) - - match condensation_result: - case View(): - llm_convertible_events = condensation_result.events - - case Condensation(): - return condensation_result + result = condenser.condense(view, agent_llm=llm) + return _finalize_messages( + result, + llm_convertible_events, + additional_messages, + ) - # Convert events to messages messages = LLMConvertibleEvent.events_to_messages(llm_convertible_events) - - # Add any additional messages (e.g., user question for ask_agent) if additional_messages: messages.extend(additional_messages) - return messages @@ -559,13 +566,12 @@ def make_llm_completion( add_security_risk_prediction=True, on_token=on_token, ) - else: - return llm.completion( - messages=messages, - tools=tools or [], - add_security_risk_prediction=True, - on_token=on_token, - ) + return llm.completion( + messages=messages, + tools=tools or [], + add_security_risk_prediction=True, + on_token=on_token, + ) # --------------------------------------------------------------------------- @@ -588,19 +594,16 @@ async def aprepare_llm_messages( llm_convertible_events: list[LLMConvertibleEvent] = view.events if condenser is not None: - condensation_result = await condenser.acondense(view, agent_llm=llm) - - match condensation_result: - case View(): - llm_convertible_events = condensation_result.events - case Condensation(): - return condensation_result + result = await condenser.acondense(view, agent_llm=llm) + return _finalize_messages( + result, + llm_convertible_events, + additional_messages, + ) messages = LLMConvertibleEvent.events_to_messages(llm_convertible_events) - if additional_messages: messages.extend(additional_messages) - return messages @@ -620,10 +623,9 @@ async def amake_llm_completion( add_security_risk_prediction=True, on_token=on_token, ) - else: - return await llm.acompletion( - messages=messages, - tools=tools or [], - add_security_risk_prediction=True, - on_token=on_token, - ) + return await llm.acompletion( + messages=messages, + tools=tools or [], + add_security_risk_prediction=True, + on_token=on_token, + ) diff --git a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py index aae387be9b..7e2004055d 100644 --- a/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py +++ b/openhands-sdk/openhands/sdk/context/condenser/llm_summarizing_condenser.py @@ -17,7 +17,7 @@ from openhands.sdk.context.view import View from openhands.sdk.event.base import LLMConvertibleEvent from openhands.sdk.event.condenser import Condensation -from openhands.sdk.llm import LLM, Message, TextContent +from openhands.sdk.llm import LLM, LLMResponse, Message, TextContent from openhands.sdk.logger import get_logger from openhands.sdk.observability.laminar import observe from openhands.sdk.utils import maybe_truncate @@ -138,33 +138,20 @@ def condensation_requirement( if Reason.REQUEST in reasons: return CondensationRequirement.HARD - def _generate_condensation( + def _build_condensation_prompt( self, forgotten_events: Sequence[LLMConvertibleEvent], - summary_offset: int, max_event_str_length: int | None = None, - ) -> Condensation: - """Generate a condensation by using the condenser's LLM to summarize forgotten - events. - - Args: - forgotten_events: The list of events to be summarized. - summary_offset: The index where the summary event should be inserted. - max_event_str_length: Optional maximum length for each event string. If - provided, event strings longer than this will be truncated. - - Returns: - Condensation: The generated condensation object. - - Raises: - ValueError: If forgotten_events is empty (0 events to condense). - """ + ) -> list[Message]: + """Build the summarization prompt messages (shared by sync/async).""" assert len(forgotten_events) > 0, "No events to condense." - # Convert events to strings for the template event_strings = [ - maybe_truncate(str(forgotten_event), truncate_after=max_event_str_length) - for forgotten_event in forgotten_events + maybe_truncate( + str(fe), + truncate_after=max_event_str_length, + ) + for fe in forgotten_events ] prompt = render_template( @@ -172,15 +159,15 @@ def _generate_condensation( "summarizing_prompt.j2", events=event_strings, ) + return [Message(role="user", content=[TextContent(text=prompt)])] - messages = [Message(role="user", content=[TextContent(text=prompt)])] - - # Do not pass extra_body explicitly. The LLM handles forwarding - # litellm_extra_body only when it is non-empty. - llm_response = self.llm.completion( - messages=messages, - ) - # Extract summary from the LLMResponse message + @staticmethod + def _condensation_from_response( + llm_response: LLMResponse, + forgotten_events: Sequence[LLMConvertibleEvent], + summary_offset: int, + ) -> Condensation: + """Extract summary and build Condensation (shared by sync/async).""" summary = None if llm_response.message.content: first_content = llm_response.message.content[0] @@ -188,12 +175,30 @@ def _generate_condensation( summary = first_content.text return Condensation( - forgotten_event_ids={event.id for event in forgotten_events}, + forgotten_event_ids={e.id for e in forgotten_events}, summary=summary, summary_offset=summary_offset, llm_response_id=llm_response.id, ) + def _generate_condensation( + self, + forgotten_events: Sequence[LLMConvertibleEvent], + summary_offset: int, + max_event_str_length: int | None = None, + ) -> Condensation: + """Generate a condensation using the condenser's LLM.""" + messages = self._build_condensation_prompt( + forgotten_events, + max_event_str_length, + ) + llm_response = self.llm.completion(messages=messages) + return self._condensation_from_response( + llm_response, + forgotten_events, + summary_offset, + ) + def _get_forgotten_events( self, view: View, agent_llm: LLM | None = None ) -> tuple[Sequence[LLMConvertibleEvent], int]: @@ -309,31 +314,9 @@ def hard_context_reset( def get_condensation( self, view: View, agent_llm: LLM | None = None ) -> Condensation: - # The condensation is dependent on the events we want to drop and the previous - # summary. If we fail to find an appropriate set of events to forget raise an - # exception so the conversation can keep going until conditions change. - try: - forgotten_events, summary_offset = self._get_forgotten_events( - view, agent_llm=agent_llm - ) - except ValueError as e: - raise NoCondensationAvailableException( - "Unable to compute forgotten events" - ) from e - - if not forgotten_events: - raise NoCondensationAvailableException( - "Cannot condense 0 events. This typically occurs when a tool loop " - "spans almost the entire view, leaving no valid range for forgetting " - "events. Consider adjusting keep_first or max_size parameters." - ) - - if len(forgotten_events) < len(view) * self.minimum_progress: - raise NoCondensationAvailableException( - "Cannot apply condensation: events forgotten below minimum progress " - "threshold." - ) - + forgotten_events, summary_offset = self._validate_forgotten_events( + view, agent_llm + ) return self._generate_condensation( forgotten_events=forgotten_events, summary_offset=summary_offset, @@ -350,39 +333,23 @@ async def _agenerate_condensation( max_event_str_length: int | None = None, ) -> Condensation: """Async variant of :meth:`_generate_condensation`.""" - assert len(forgotten_events) > 0, "No events to condense." - - event_strings = [ - maybe_truncate(str(fe), truncate_after=max_event_str_length) - for fe in forgotten_events - ] - - prompt = render_template( - os.path.join(os.path.dirname(__file__), "prompts"), - "summarizing_prompt.j2", - events=event_strings, + messages = self._build_condensation_prompt( + forgotten_events, + max_event_str_length, ) - - messages = [Message(role="user", content=[TextContent(text=prompt)])] llm_response = await self.llm.acompletion(messages=messages) - - summary = None - if llm_response.message.content: - first_content = llm_response.message.content[0] - if isinstance(first_content, TextContent): - summary = first_content.text - - return Condensation( - forgotten_event_ids={event.id for event in forgotten_events}, - summary=summary, - summary_offset=summary_offset, - llm_response_id=llm_response.id, + return self._condensation_from_response( + llm_response, + forgotten_events, + summary_offset, ) - async def aget_condensation( - self, view: View, agent_llm: LLM | None = None - ) -> Condensation: - """Async variant of :meth:`get_condensation`.""" + def _validate_forgotten_events( + self, + view: View, + agent_llm: LLM | None = None, + ) -> tuple[Sequence[LLMConvertibleEvent], int]: + """Shared validation for get_condensation / aget_condensation.""" try: forgotten_events, summary_offset = self._get_forgotten_events( view, agent_llm=agent_llm @@ -394,18 +361,27 @@ async def aget_condensation( if not forgotten_events: raise NoCondensationAvailableException( - "Cannot condense 0 events. This typically occurs when a tool loop " - "spans almost the entire view, leaving no valid range for " - "forgetting events. Consider adjusting keep_first or max_size " - "parameters." + "Cannot condense 0 events. This typically occurs " + "when a tool loop spans almost the entire view, " + "leaving no valid range for forgetting events. " + "Consider adjusting keep_first or max_size." ) if len(forgotten_events) < len(view) * self.minimum_progress: raise NoCondensationAvailableException( - "Cannot apply condensation: events forgotten below minimum " - "progress threshold." + "Cannot apply condensation: events forgotten " + "below minimum progress threshold." ) + return forgotten_events, summary_offset + + async def aget_condensation( + self, view: View, agent_llm: LLM | None = None + ) -> Condensation: + """Async variant of :meth:`get_condensation`.""" + forgotten_events, summary_offset = self._validate_forgotten_events( + view, agent_llm + ) return await self._agenerate_condensation( forgotten_events=forgotten_events, summary_offset=summary_offset, @@ -436,7 +412,8 @@ async def ahard_context_reset( ) logger.warning( f"Hard context reset summarization failed: {e}. " - f"Reducing max event size to {max_event_str_length}." + f"Reducing max event size to " + f"{max_event_str_length}." ) attempts_remaining -= 1 diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index a6bb1ec60e..3d44f72101 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -765,23 +765,89 @@ def send_message(self, message: str | Message, sender: str | None = None) -> Non ) self._on_event(user_msg_event) - @observe(name="conversation.run") - def run(self) -> None: - """Runs the conversation until the agent finishes. + def _check_loop_state(self) -> str: + """Check execution status and return a loop directive. - In confirmation mode: - - First call: creates actions but doesn't execute them, stops and waits - - Second call: executes pending actions (implicit confirmation) + Must be called while ``self._state`` lock is held. - In normal mode: - - Creates and executes actions immediately - - Can be paused between steps + Returns: + ``"break"`` – exit the loop + ``"continue"`` – skip to the next iteration + ``"step"`` – proceed with an agent step """ - # Ensure agent is fully initialized (loads plugins and initializes agent) + if self._state.execution_status in [ + ConversationExecutionStatus.PAUSED, + ConversationExecutionStatus.STUCK, + ]: + return "break" + + if self._state.execution_status == ConversationExecutionStatus.FINISHED: + if self._hook_processor is not None: + should_stop, feedback = self._hook_processor.run_stop( + reason="agent_finished" + ) + if not should_stop: + logger.info("Stop hook denied agent stopping") + if feedback: + prefixed = f"[Stop hook feedback] {feedback}" + self._on_event( + MessageEvent( + source="environment", + llm_message=Message( + role="user", + content=[TextContent(text=prefixed)], + ), + ) + ) + self._state.execution_status = ConversationExecutionStatus.RUNNING + return "continue" + return "break" + + if self._stuck_detector: + if self._stuck_detector.is_stuck(): + logger.warning("Stuck pattern detected.") + self._state.execution_status = ConversationExecutionStatus.STUCK + return "continue" + + if ( + self._state.execution_status + == ConversationExecutionStatus.WAITING_FOR_CONFIRMATION + ): + self._state.execution_status = ConversationExecutionStatus.RUNNING + + return "step" + + def _check_post_step(self, iteration: int) -> str | None: + """Check post-step conditions. Returns ``"break"`` or None.""" + if ( + self.state.execution_status + == ConversationExecutionStatus.WAITING_FOR_CONFIRMATION + ): + return "break" + + if iteration >= self.max_iteration_per_run: + if self._state.execution_status == ConversationExecutionStatus.FINISHED: + return "break" + error_msg = ( + f"Agent reached maximum iterations limit " + f"({self.max_iteration_per_run})." + ) + logger.error(error_msg) + self._state.execution_status = ConversationExecutionStatus.ERROR + self._on_event( + ConversationErrorEvent( + source="environment", + code="MaxIterationsReached", + detail=error_msg, + ) + ) + return "break" + return None + + def _enter_run(self) -> None: + """Common setup for run/arun.""" self._ensure_agent_ready() self._cancel_token = CancellationToken() - with self._state: if self._state.execution_status in [ ConversationExecutionStatus.IDLE, @@ -791,126 +857,59 @@ def run(self) -> None: ]: self._state.execution_status = ConversationExecutionStatus.RUNNING + def _handle_run_error(self, error: Exception) -> None: + """Common error handling for run/arun (not CancelledError).""" + with self._state: + self._state.execution_status = ConversationExecutionStatus.ERROR + self._on_event( + ConversationErrorEvent( + source="environment", + code=error.__class__.__name__, + detail=str(error), + ) + ) + + @observe(name="conversation.run") + def run(self) -> None: + """Runs the conversation until the agent finishes. + + In confirmation mode: + - First call: creates actions but doesn't execute them, stops and waits + - Second call: executes pending actions (implicit confirmation) + + In normal mode: + - Creates and executes actions immediately + + Can be paused between steps + """ + self._enter_run() + iteration = 0 try: while True: logger.debug(f"Conversation run iteration {iteration}") with self._state: - # Pause attempts to acquire the state lock - # Before value can be modified step can be taken - # Ensure step conditions are checked when lock is already acquired - if self._state.execution_status in [ - ConversationExecutionStatus.PAUSED, - ConversationExecutionStatus.STUCK, - ]: - break - - # Handle stop hooks on FINISHED - if ( - self._state.execution_status - == ConversationExecutionStatus.FINISHED - ): - if self._hook_processor is not None: - should_stop, feedback = self._hook_processor.run_stop( - reason="agent_finished" - ) - if not should_stop: - logger.info("Stop hook denied agent stopping") - if feedback: - prefixed = f"[Stop hook feedback] {feedback}" - feedback_msg = MessageEvent( - source="environment", - llm_message=Message( - role="user", - content=[TextContent(text=prefixed)], - ), - ) - self._on_event(feedback_msg) - self._state.execution_status = ( - ConversationExecutionStatus.RUNNING - ) - continue - # No hooks or hooks allowed stopping + directive = self._check_loop_state() + if directive == "break": break - - # Check for stuck patterns if enabled - if self._stuck_detector: - is_stuck = self._stuck_detector.is_stuck() - - if is_stuck: - logger.warning("Stuck pattern detected.") - self._state.execution_status = ( - ConversationExecutionStatus.STUCK - ) - continue - - # clear the flag before calling agent.step() (user approved) - if ( - self._state.execution_status - == ConversationExecutionStatus.WAITING_FOR_CONFIRMATION - ): - self._state.execution_status = ( - ConversationExecutionStatus.RUNNING - ) + if directive == "continue": + continue self.agent.step( - self, on_event=self._on_event, on_token=self._on_token + self, + on_event=self._on_event, + on_token=self._on_token, ) iteration += 1 - # Check for non-finished terminal conditions - # Note: We intentionally do NOT check for FINISHED status here. - # This allows concurrent user messages to be processed: - # 1. Agent finishes and sets status to FINISHED - # 2. User sends message concurrently via send_message() - # 3. send_message() waits for FIFO lock, then sets status to IDLE - # 4. Run loop continues to next iteration and processes the message - # 5. Without this design, concurrent messages would be lost - if ( - self.state.execution_status - == ConversationExecutionStatus.WAITING_FOR_CONFIRMATION - ): - break - - if iteration >= self.max_iteration_per_run: - # If the agent finished on this final iteration, - # preserve the FINISHED status rather than - # overwriting it with ERROR. - if ( - self._state.execution_status - == ConversationExecutionStatus.FINISHED - ): - break - error_msg = ( - f"Agent reached maximum iterations limit " - f"({self.max_iteration_per_run})." - ) - logger.error(error_msg) - self._state.execution_status = ConversationExecutionStatus.ERROR - self._on_event( - ConversationErrorEvent( - source="environment", - code="MaxIterationsReached", - detail=error_msg, - ) - ) + if self._check_post_step(iteration) == "break": break except Exception as e: - with self._state: - self._state.execution_status = ConversationExecutionStatus.ERROR - - # Add an error event - self._on_event( - ConversationErrorEvent( - source="environment", - code=e.__class__.__name__, - detail=str(e), - ) - ) - - # Re-raise with conversation id and persistence dir for better UX + self._handle_run_error(e) raise ConversationRunError( - self._state.id, e, persistence_dir=self._state.persistence_dir + self._state.id, + e, + persistence_dir=self._state.persistence_dir, ) from e finally: self._cancel_token = None @@ -926,79 +925,20 @@ async def arun(self) -> None: :meth:`interrupt` can cancel it mid-LLM-call. On ``CancelledError`` the conversation transitions to ``PAUSED`` and emits an :class:`InterruptEvent`. - - A fresh :class:`CancellationToken` is created per run so that - ``interrupt()`` can signal in-flight tool calls to abort. After - ``CancelledError`` any ``ActionEvent`` without a matching - observation is patched with a synthetic ``AgentErrorEvent`` so - the LLM conversation history stays consistent. """ - self._ensure_agent_ready() + self._enter_run() self._arun_task = asyncio.current_task() - self._cancel_token = CancellationToken() - - with self._state: - if self._state.execution_status in [ - ConversationExecutionStatus.IDLE, - ConversationExecutionStatus.PAUSED, - ConversationExecutionStatus.ERROR, - ConversationExecutionStatus.STUCK, - ]: - self._state.execution_status = ConversationExecutionStatus.RUNNING iteration = 0 try: while True: logger.debug(f"Conversation arun iteration {iteration}") with self._state: - if self._state.execution_status in [ - ConversationExecutionStatus.PAUSED, - ConversationExecutionStatus.STUCK, - ]: + directive = self._check_loop_state() + if directive == "break": break - - if ( - self._state.execution_status - == ConversationExecutionStatus.FINISHED - ): - if self._hook_processor is not None: - should_stop, feedback = self._hook_processor.run_stop( - reason="agent_finished" - ) - if not should_stop: - logger.info("Stop hook denied agent stopping") - if feedback: - prefixed = f"[Stop hook feedback] {feedback}" - feedback_msg = MessageEvent( - source="environment", - llm_message=Message( - role="user", - content=[TextContent(text=prefixed)], - ), - ) - self._on_event(feedback_msg) - self._state.execution_status = ( - ConversationExecutionStatus.RUNNING - ) - continue - break - - if self._stuck_detector: - is_stuck = self._stuck_detector.is_stuck() - if is_stuck: - logger.warning("Stuck pattern detected.") - self._state.execution_status = ( - ConversationExecutionStatus.STUCK - ) - continue - - if ( - self._state.execution_status - == ConversationExecutionStatus.WAITING_FOR_CONFIRMATION - ): - self._state.execution_status = ( - ConversationExecutionStatus.RUNNING - ) + if directive == "continue": + continue await self.agent.astep( self, @@ -1007,62 +947,20 @@ async def arun(self) -> None: ) iteration += 1 - if ( - self.state.execution_status - == ConversationExecutionStatus.WAITING_FOR_CONFIRMATION - ): - break - - if iteration >= self.max_iteration_per_run: - if ( - self._state.execution_status - == ConversationExecutionStatus.FINISHED - ): - break - error_msg = ( - f"Agent reached maximum iterations limit " - f"({self.max_iteration_per_run})." - ) - logger.error(error_msg) - self._state.execution_status = ConversationExecutionStatus.ERROR - self._on_event( - ConversationErrorEvent( - source="environment", - code="MaxIterationsReached", - detail=error_msg, - ) - ) + if self._check_post_step(iteration) == "break": break except asyncio.CancelledError: - # CancelledError is intentionally NOT re-raised. ``interrupt()`` - # uses ``asyncio.Task.cancel()`` to break out of ``arun()`` and - # expects the task to terminate cleanly. Re-raising would - # propagate the cancellation to EventService/caller which would - # surface it as an unexpected error. Instead we transition to - # PAUSED so the conversation can be resumed later. logger.info("arun() interrupted via task cancellation") with self._state: - # Emit synthetic error observations for any ActionEvents - # that were in-flight when the interrupt landed. Without - # these the LLM history would contain tool-call requests - # with no tool-result, which causes provider errors on - # the next completion call. self._emit_orphaned_action_errors() - self._state.execution_status = ConversationExecutionStatus.PAUSED self._on_event(InterruptEvent()) except Exception as e: - with self._state: - self._state.execution_status = ConversationExecutionStatus.ERROR - self._on_event( - ConversationErrorEvent( - source="environment", - code=e.__class__.__name__, - detail=str(e), - ) - ) + self._handle_run_error(e) raise ConversationRunError( - self._state.id, e, persistence_dir=self._state.persistence_dir + self._state.id, + e, + persistence_dir=self._state.persistence_dir, ) from e finally: self._cancel_token = None From ca64da202e75b432f14d090414f7d35042936aba Mon Sep 17 00:00:00 2001 From: enyst Date: Thu, 21 May 2026 19:14:07 +0000 Subject: [PATCH 4/4] refactor(sdk): restore llm comments Co-authored-by: openhands --- openhands-sdk/openhands/sdk/llm/llm.py | 60 +++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index db3a7b7818..469e0abb23 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -794,10 +794,14 @@ def _prepare_completion( raise ValueError("Streaming requires an on_token callback") kwargs["stream"] = True + # 1) serialize messages formatted_messages = self.format_messages_for_llm(messages) + + # 2) choose function-calling strategy use_native_fc = self.native_tool_calling original_fncall_msgs = copy.deepcopy(formatted_messages) + # Convert Tool objects to ChatCompletionToolParam once here cc_tools: list[ChatCompletionToolParam] = [] if tools: cc_tools = [ @@ -820,18 +824,23 @@ def _prepare_completion( include_security_params=add_security_risk_prediction, ) + # 3) normalize provider params + # Only pass tools when native FC is active kwargs["tools"] = cc_tools if (bool(cc_tools) and use_native_fc) else None has_tools_flag = bool(cc_tools) and use_native_fc + # Behavior-preserving: delegate to select_chat_options call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag) + # 4) request context for telemetry (always include context_window for metrics) assert self._telemetry is not None + # Always pass context_window so metrics are tracked even when logging disabled telemetry_ctx: dict[str, Any] = { "context_window": self.effective_max_input_tokens or 0 } if self._telemetry.log_enabled: telemetry_ctx.update( { - "messages": formatted_messages[:], + "messages": formatted_messages[:], # already simple dicts "tools": tools, "kwargs": {k: v for k, v in call_kwargs.items()}, } @@ -868,7 +877,12 @@ def _completion_process_response( tools=ctx.cc_tools, include_security_params=ctx.add_security_risk_prediction, ) + # 6) telemetry self._telemetry.on_response(resp, raw_resp=raw_resp) + + # Ensure at least one choice. + # Gemini sometimes returns empty choices; we raise LLMNoResponseError here + # inside the retry boundary so it is retried. if not resp.get("choices") or len(resp["choices"]) < 1: raise LLMNoResponseError( "Response choices is less than 1. Response: " + str(resp) @@ -877,8 +891,11 @@ def _completion_process_response( def _build_completion_llm_response(self, resp: ModelResponse) -> LLMResponse: """Convert a validated ModelResponse into an LLMResponse.""" + # Convert the first choice to an OpenHands Message first_choice = resp["choices"][0] message = Message.from_llm_chat_message(first_choice["message"]) + + # Create and return LLMResponse return LLMResponse( message=message, metrics=self._metrics_snapshot(), @@ -886,6 +903,7 @@ def _build_completion_llm_response(self, resp: ModelResponse) -> LLMResponse: ) def _metrics_snapshot(self) -> MetricsSnapshot: + # Get current metrics snapshot return MetricsSnapshot( model_name=self.metrics.model_name, accumulated_cost=self.metrics.accumulated_cost, @@ -952,10 +970,12 @@ def completion( messages, tools, add_security_risk_prediction, on_token, kwargs ) + # 5) do the call with retries @self.retry_decorator(**self._retry_kwargs()) def _one_attempt(**retry_kwargs) -> ModelResponse: assert self._telemetry is not None self._telemetry.on_request(telemetry_ctx=ctx.telemetry_ctx) + # Merge retry-modified kwargs (like temperature) with call_kwargs final_kwargs = {**ctx.call_kwargs, **retry_kwargs} resp = self._transport_call( messages=ctx.formatted_messages, @@ -1017,6 +1037,8 @@ async def acompletion( assert resp is not None return self._build_completion_llm_response(resp) except Exception as e: + # Fallback is synchronous; cast the token callback since the + # fallback LLM's sync path accepts TokenCallbackType. _fb_token = cast("TokenCallbackType | None", on_token) return await self._ahandle_error( e, @@ -1056,12 +1078,16 @@ def _prepare_responses( """Build all state needed before the responses transport call.""" user_enable_streaming = bool(kwargs.get("stream", False)) or self.stream if user_enable_streaming: + # We allow on_token to be None for subscription mode if on_token is None and not self.is_subscription: raise ValueError("Streaming requires an on_token callback") kwargs["stream"] = True + # Build instructions + input list using dedicated Responses formatter instructions, input_items = self.format_messages_for_responses(messages) + # Convert Tool objects to Responses ToolParam + # (Responses path always supports function tools) resp_tools = ( [ t.to_responses_tool( @@ -1073,11 +1099,14 @@ def _prepare_responses( else None ) + # Normalize/override Responses kwargs consistently call_kwargs = select_responses_options( self, kwargs, include=include, store=store ) + # Request context for telemetry (always include context_window for metrics) assert self._telemetry is not None + # Always pass context_window so metrics are tracked even when logging disabled telemetry_ctx: dict[str, Any] = { "context_window": self.effective_max_input_tokens or 0 } @@ -1131,6 +1160,7 @@ def _responses_process_stream_event(event: Any) -> tuple[Any | None, str | None] if event is None: return None, None item = None + # Collect finished output items evt_type = getattr(event, "type", None) if evt_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: raw_item = getattr(event, "item", None) @@ -1165,6 +1195,8 @@ def _responses_finalize_stream( f"Unexpected completed event: {type(completed_event)}" ) completed_resp = completed_event.response + + # Patch empty output with items collected from stream if not completed_resp.output and collected_output_items: completed_resp.output = collected_output_items return completed_resp @@ -1173,6 +1205,8 @@ def _build_responses_llm_response( self, completed: ResponsesAPIResponse ) -> LLMResponse: """Convert a ResponsesAPIResponse into an LLMResponse.""" + # Parse output -> Message (typed) + # Cast to a typed sequence accepted by from_llm_responses_output output_seq = cast(Sequence[Any], completed.output or []) message = Message.from_llm_responses_output(output_seq) return LLMResponse( @@ -1220,6 +1254,7 @@ def responses( kwargs, ) + # Perform call with retries @self.retry_decorator(**self._retry_kwargs()) def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: assert self._telemetry is not None @@ -1243,12 +1278,20 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse: self._telemetry.on_response(ret) return ret + # When stream=True, LiteLLM returns a streaming iterator rather + # than a single ResponsesAPIResponse. Drain the iterator and + # use the completed response. if final_kwargs.get("stream", False): if not isinstance(ret, SyncResponsesAPIStreamingIterator): raise AssertionError( f"Expected Responses stream iterator, got {type(ret)}" ) stream_cb = on_token if ctx.user_enable_streaming else None + # Collect output items from streaming events. + # Some endpoints (e.g., Codex subscription) send output + # items as separate events but the final response.completed + # event has output=[]. We accumulate them here and patch + # the completed response if needed. collected: list[Any] = [] for event in ret: item, delta = self._responses_process_stream_event(event) @@ -1348,6 +1391,11 @@ async def aresponses( f"iterator, got {type(ret)}" ) stream_cb = on_token if ctx.user_enable_streaming else None + # Collect output items from streaming events. + # Some endpoints (e.g., Codex subscription) send output + # items as separate events but the final response.completed + # event has output=[]. We accumulate them here and patch + # the completed response if needed. collected: list[Any] = [] async for event in ret: item, delta = self._responses_process_stream_event( @@ -1474,10 +1522,15 @@ def _transport_call( on_token: TokenCallbackType | None = None, **kwargs, ) -> ModelResponse: + # litellm.modify_params is GLOBAL; guard it for thread-safety with self._litellm_modify_params_ctx(self.modify_params): with warnings.catch_warnings(): self._suppress_litellm_warnings() if enable_streaming: + # When streaming, request usage in the final chunk so that + # detailed token breakdowns (prompt_tokens_details with + # cached_tokens, etc.) are not silently discarded by + # litellm's streaming handler. kwargs.setdefault("stream_options", {"include_usage": True}) ret = litellm_completion( @@ -1505,10 +1558,15 @@ async def _atransport_call( **kwargs, ) -> ModelResponse: """Async variant of :meth:`_transport_call`.""" + # litellm.modify_params is GLOBAL; guard it for thread-safety with self._litellm_modify_params_ctx(self.modify_params): with warnings.catch_warnings(): self._suppress_litellm_warnings() if enable_streaming: + # When streaming, request usage in the final chunk so that + # detailed token breakdowns (prompt_tokens_details with + # cached_tokens, etc.) are not silently discarded by + # litellm's streaming handler. kwargs.setdefault("stream_options", {"include_usage": True}) ret = await litellm_acompletion(