diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 54bfab2ed28..fa3bde54472 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,22 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_pipeline_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> Exception: + """Callback executed when the runner pipeline encounters an error. + + This callback provides an opportunity to handle pipeline errors globally. + + Args: + invocation_context: The context for the entire invocation. + error: The exception that was raised during runner execution. + + Returns: + An Exception to be raised (either the original error or a new/modified one). + """ + return error diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index 5566349516a..732284346bb 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,7 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_pipeline_error_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -272,6 +273,33 @@ async def run_on_tool_error_callback( error=error, ) + async def run_on_pipeline_error_callback( + self, + *, + invocation_context: InvocationContext, + error: Exception, + ) -> Exception: + """Runs the `on_pipeline_error_callback` for all plugins sequentially, chaining the error.""" + for plugin in self.plugins: + try: + error = await plugin.on_pipeline_error_callback( + invocation_context=invocation_context, error=error + ) + except Exception as e: + error_message = ( + f"Error in plugin '{plugin.name}' during " + f"'on_pipeline_error_callback' callback: {e}" + ) + logger.error( + "Error in plugin '%s' during 'on_pipeline_error_callback'" + " callback: %s", + plugin.name, + e, + exc_info=True, + ) + raise RuntimeError(error_message) from e + return error + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: @@ -316,7 +344,13 @@ async def _run_callbacks( f"Error in plugin '{plugin.name}' during '{callback_name}'" f" callback: {e}" ) - logger.error(error_message, exc_info=True) + logger.error( + "Error in plugin '%s' during '%s' callback: %s", + plugin.name, + callback_name, + e, + exc_info=True, + ) raise RuntimeError(error_message) from e return None diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 90f004e4f9a..086a80df95c 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1355,66 +1355,73 @@ async def _exec_with_plugin( plugin_manager = invocation_context.plugin_manager - # Step 1: Run the before_run callbacks to see if we should early exit. - early_exit_result = await plugin_manager.run_before_run_callback( - invocation_context=invocation_context - ) - if isinstance(early_exit_result, types.Content): - early_exit_event = Event( - invocation_id=invocation_context.invocation_id, - author='model', - content=early_exit_result, - ) - _apply_run_config_custom_metadata( - early_exit_event, invocation_context.run_config + try: + # Step 1: Run the before_run callbacks to see if we should early exit. + early_exit_result = await plugin_manager.run_before_run_callback( + invocation_context=invocation_context ) - if self._should_append_event(early_exit_event, is_live_call): - await self.session_service.append_event( - session=invocation_context.session, - event=early_exit_event, + if isinstance(early_exit_result, types.Content): + early_exit_event = Event( + invocation_id=invocation_context.invocation_id, + author='model', + content=early_exit_result, ) - yield early_exit_event - else: - # Step 2: Otherwise continue with normal execution - async with aclosing(execute_fn(invocation_context)) as agen: - async for event in agen: - _apply_run_config_custom_metadata( - event, invocation_context.run_config - ) - # Step 3: Run the on_event callbacks before persisting so callback - # changes are stored in the session and match the streamed event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - output_event = self._get_output_event( - original_event=event, - modified_event=modified_event, - run_config=invocation_context.run_config, + _apply_run_config_custom_metadata( + early_exit_event, invocation_context.run_config + ) + if self._should_append_event(early_exit_event, is_live_call): + await self.session_service.append_event( + session=invocation_context.session, + event=early_exit_event, ) + yield early_exit_event + else: + # Step 2: Otherwise continue with normal execution + async with aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + _apply_run_config_custom_metadata( + event, invocation_context.run_config + ) + # Step 3: Run the on_event callbacks before persisting so callback + # changes are stored in the session and match the streamed event. + modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event + ) + output_event = self._get_output_event( + original_event=event, + modified_event=modified_event, + run_config=invocation_context.run_config, + ) - if is_live_call: - # Skip partial transcriptions for Live - if event.partial is not True and self._should_append_event( - event, is_live_call - ): - logger.debug('Appending live event: %s', output_event) - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) - else: - if event.partial is not True: - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) - - yield output_event + if is_live_call: + # Skip partial transcriptions for Live + if event.partial is not True and self._should_append_event( + event, is_live_call + ): + logger.debug('Appending live event: %s', output_event) + await self.session_service.append_event( + session=invocation_context.session, event=output_event + ) + else: + if event.partial is not True: + await self.session_service.append_event( + session=invocation_context.session, event=output_event + ) - # Step 4: Run the after_run callbacks to perform global cleanup tasks or - # finalizing logs and metrics data. - # This does NOT emit any event. - await plugin_manager.run_after_run_callback( - invocation_context=invocation_context - ) + yield output_event + except Exception as e: + if plugin_manager: + e = await plugin_manager.run_on_pipeline_error_callback( + invocation_context=invocation_context, error=e + ) + raise e + finally: + # Step 4: Run the after_run callbacks to perform global cleanup tasks or + # finalizing logs and metrics data. + # This does NOT emit any event. + await plugin_manager.run_after_run_callback( + invocation_context=invocation_context + ) async def _append_new_message_to_session( self, diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py index 6c72a2a6650..668efe8a129 100644 --- a/tests/unittests/plugins/test_plugin_manager.py +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -91,6 +91,12 @@ async def after_model_callback(self, **kwargs): async def on_model_error_callback(self, **kwargs): return await self._handle_callback("on_model_error_callback") + async def on_pipeline_error_callback(self, error: Exception, **kwargs): + self.call_log.append("on_pipeline_error_callback") + if "on_pipeline_error_callback" in self.exceptions_to_raise: + raise self.exceptions_to_raise["on_pipeline_error_callback"] + return self.return_values.get("on_pipeline_error_callback", error) + @pytest.fixture def service() -> PluginManager: @@ -252,6 +258,10 @@ async def test_all_callbacks_are_supported( llm_request=mock_context, error=mock_context, ) + await service.run_on_pipeline_error_callback( + invocation_context=mock_context, + error=ValueError("err"), + ) # Verify all callbacks were logged expected_callbacks = [ @@ -267,6 +277,7 @@ async def test_all_callbacks_are_supported( "before_model_callback", "after_model_callback", "on_model_error_callback", + "on_pipeline_error_callback", ] assert set(plugin1.call_log) == set(expected_callbacks) @@ -363,3 +374,43 @@ async def test_set_skip_closing_plugins_false_reverts_to_closing( await service.close() plugin1.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_pipeline_error_callback_chaining( + service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin +): + """Tests that on_pipeline_error_callback is called and errors are chained.""" + error1 = ValueError("Original error") + error2 = RuntimeError("Chained error") + plugin1.return_values["on_pipeline_error_callback"] = error2 + + service.register_plugin(plugin1) + service.register_plugin(plugin2) + + result_err = await service.run_on_pipeline_error_callback( + invocation_context=Mock(), error=error1 + ) + + assert result_err is error2 + assert "on_pipeline_error_callback" in plugin1.call_log + assert "on_pipeline_error_callback" in plugin2.call_log + + +@pytest.mark.asyncio +async def test_pipeline_error_callback_exception_wrap( + service: PluginManager, plugin1: TestPlugin +): + """Tests that if on_pipeline_error_callback raises, it wraps in RuntimeError.""" + plugin1.exceptions_to_raise["on_pipeline_error_callback"] = ValueError( + "Callback crashed" + ) + service.register_plugin(plugin1) + + with pytest.raises(RuntimeError) as excinfo: + await service.run_on_pipeline_error_callback( + invocation_context=Mock(), error=ValueError("Original") + ) + + assert "Error in plugin 'plugin1'" in str(excinfo.value) + assert "on_pipeline_error_callback" in str(excinfo.value)