From a98a7a0171ff2a0e577f00bb051a8b126acac2f4 Mon Sep 17 00:00:00 2001 From: Liang Wu <18244712+wuliang229@users.noreply.github.com> Date: Wed, 17 Jun 2026 12:59:12 -0700 Subject: [PATCH 1/2] fix: remove live event buffering in runner --- src/google/adk/runners.py | 87 ++++----------------------------------- 1 file changed, 8 insertions(+), 79 deletions(-) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 850c26bbba8..397bb3aca4b 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -66,10 +66,6 @@ logger = logging.getLogger('google_adk.' + __name__) -def _is_tool_call_or_response(event: Event) -> bool: - return bool(event.get_function_calls() or event.get_function_responses()) - - def _get_function_responses_from_content( content: types.Content, ) -> list[types.FunctionResponse]: @@ -80,21 +76,6 @@ def _get_function_responses_from_content( ] -def _is_transcription(event: Event) -> bool: - return ( - event.input_transcription is not None - or event.output_transcription is not None - ) - - -def _has_non_empty_transcription_text( - transcription: types.Transcription, -) -> bool: - return bool( - transcription and transcription.text and transcription.text.strip() - ) - - def _apply_run_config_custom_metadata( event: Event, run_config: RunConfig | None ) -> None: @@ -873,22 +854,6 @@ async def _exec_with_plugin( yield early_exit_event else: # Step 2: Otherwise continue with normal execution - # Note for live/bidi: - # the transcription may arrive later than the action(function call - # event and thus function response event). In this case, the order of - # transcription and function call event will be wrong if we just - # append as it arrives. To address this, we should check if there is - # transcription going on. If there is transcription going on, we - # should hold on appending the function call event until the - # transcription is finished. The transcription in progress can be - # identified by checking if the transcription event is partial. When - # the next transcription event is not partial, it means the previous - # transcription is finished. Then if there is any buffered function - # call event, we should append them after this finished(non-partial) - # transcription event. - buffered_events: list[Event] = [] - is_transcribing: bool = False - async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: _apply_run_config_custom_metadata( @@ -906,50 +871,14 @@ async def _exec_with_plugin( ) if is_live_call: - if event.partial and _is_transcription(event): - is_transcribing = True - if is_transcribing and _is_tool_call_or_response(event): - # only buffer function call and function response event which is - # non-partial - buffered_events.append(output_event) - continue - # Note for live/bidi: for audio response, it's considered as - # non-partial event(event.partial=None) - # event.partial=False and event.partial=None are considered as - # non-partial event; event.partial=True is considered as partial - # event. - if event.partial is not True: - if _is_transcription(event) and ( - _has_non_empty_transcription_text(event.input_transcription) - or _has_non_empty_transcription_text( - event.output_transcription - ) - ): - # transcription end signal, append buffered events - is_transcribing = False - logger.debug( - 'Appending transcription finished event: %s', event - ) - if self._should_append_event(event, is_live_call): - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) - - for buffered_event in buffered_events: - logger.debug('Appending buffered event: %s', buffered_event) - await self.session_service.append_event( - session=invocation_context.session, event=buffered_event - ) - yield buffered_event # yield buffered events to caller - buffered_events = [] - else: - # non-transcription event or empty transcription event, for - # example, event that stores blob reference, should be appended. - if self._should_append_event(event, is_live_call): - logger.debug('Appending non-buffered event: %s', event) - await self.session_service.append_event( - session=invocation_context.session, event=output_event - ) + # Skip partial transcriptions for Live + if event.partial is not True and self._should_append_event( + event, is_live_call + ): + logger.debug('Appending live event: %s', output_event) + await self.session_service.append_event( + session=invocation_context.session, event=output_event + ) else: if event.partial is not True: await self.session_service.append_event( From 1de0881157ba31895a622a02d2c3ff4ccd06cf45 Mon Sep 17 00:00:00 2001 From: Liang Wu <18244712+wuliang229@users.noreply.github.com> Date: Wed, 17 Jun 2026 13:21:08 -0700 Subject: [PATCH 2/2] Modify tests. --- tests/unittests/streaming/test_streaming.py | 484 ++++++++++---------- 1 file changed, 246 insertions(+), 238 deletions(-) diff --git a/tests/unittests/streaming/test_streaming.py b/tests/unittests/streaming/test_streaming.py index d77b13e5385..409243a09e7 100644 --- a/tests/unittests/streaming/test_streaming.py +++ b/tests/unittests/streaming/test_streaming.py @@ -34,36 +34,36 @@ def test_streaming(): mock_model = testing_utils.MockModel.create([response1]) root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[], ) runner = testing_utils.InMemoryRunner( - root_agent=root_agent, response_modalities=['AUDIO'] + root_agent=root_agent, response_modalities=["AUDIO"] ) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'\x00\xFF', mime_type='audio/pcm') + blob=types.Blob(data=b"\x00\xFF", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' + assert res_events is not None, "Expected a list of events, got None." assert ( len(res_events) > 0 - ), 'Expected at least one response, but got an empty list.' + ), "Expected at least one response, but got an empty list." def test_live_streaming_function_call_single(): """Test live streaming with a single function call response.""" # Create a function call response function_call = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco', 'unit': 'celsius'} + name="get_weather", args={"location": "San Francisco", "unit": "celsius"} ) # Create LLM responses: function call followed by turn completion response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -73,16 +73,16 @@ def test_live_streaming_function_call_single(): mock_model = testing_utils.MockModel.create([response1, response2]) # Mock function that would be called - def get_weather(location: str, unit: str = 'celsius') -> dict: + def get_weather(location: str, unit: str = "celsius") -> dict: return { - 'temperature': 22, - 'condition': 'sunny', - 'location': location, - 'unit': unit, + "temperature": 22, + "condition": "sunny", + "location": location, + "unit": unit, } root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -136,14 +136,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'What is the weather in San Francisco?', mime_type='audio/pcm' + data=b"What is the weather in San Francisco?", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got a function call event function_call_found = False @@ -152,19 +152,19 @@ async def consume_responses(session: testing_utils.Session): for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': + if part.function_call and part.function_call.name == "get_weather": function_call_found = True - assert part.function_call.args['location'] == 'San Francisco' - assert part.function_call.args['unit'] == 'celsius' + assert part.function_call.args["location"] == "San Francisco" + assert part.function_call.args["unit"] == "celsius" elif ( part.function_response - and part.function_response.name == 'get_weather' + and part.function_response.name == "get_weather" ): function_response_found = True - assert part.function_response.response['temperature'] == 22 - assert part.function_response.response['condition'] == 'sunny' + assert part.function_response.response["temperature"] == 22 + assert part.function_response.response["condition"] == "sunny" - assert function_call_found, 'Expected a function call event.' + assert function_call_found, "Expected a function call event." # Note: In live streaming, function responses might be handled differently, # so we check for the function call which is the primary indicator of function calling working @@ -173,19 +173,19 @@ def test_live_streaming_function_call_multiple(): """Test live streaming with multiple function calls in sequence.""" # Create multiple function call responses function_call1 = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco'} + name="get_weather", args={"location": "San Francisco"} ) function_call2 = types.Part.from_function_call( - name='get_time', args={'timezone': 'PST'} + name="get_time", args={"timezone": "PST"} ) # Create LLM responses: two function calls followed by turn completion response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call1]), + content=types.Content(role="model", parts=[function_call1]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[function_call2]), + content=types.Content(role="model", parts=[function_call2]), turn_complete=False, ) response3 = LlmResponse( @@ -196,13 +196,13 @@ def test_live_streaming_function_call_multiple(): # Mock functions def get_weather(location: str) -> dict: - return {'temperature': 22, 'condition': 'sunny', 'location': location} + return {"temperature": 22, "condition": "sunny", "location": location} def get_time(timezone: str) -> dict: - return {'time': '14:30', 'timezone': timezone} + return {"time": "14:30", "timezone": timezone} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather, get_time], ) @@ -255,14 +255,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'What is the weather and current time?', mime_type='audio/pcm' + data=b"What is the weather and current time?", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check function calls weather_call_found = False @@ -272,33 +272,33 @@ async def consume_responses(session: testing_utils.Session): if event.content and event.content.parts: for part in event.content.parts: if part.function_call: - if part.function_call.name == 'get_weather': + if part.function_call.name == "get_weather": weather_call_found = True - assert part.function_call.args['location'] == 'San Francisco' - elif part.function_call.name == 'get_time': + assert part.function_call.args["location"] == "San Francisco" + elif part.function_call.name == "get_time": time_call_found = True - assert part.function_call.args['timezone'] == 'PST' + assert part.function_call.args["timezone"] == "PST" # In live streaming, we primarily check that function calls are generated correctly assert ( weather_call_found or time_call_found - ), 'Expected at least one function call.' + ), "Expected at least one function call." def test_live_streaming_function_call_parallel(): """Test live streaming with parallel function calls.""" # Create parallel function calls in the same response function_call1 = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco'} + name="get_weather", args={"location": "San Francisco"} ) function_call2 = types.Part.from_function_call( - name='get_weather', args={'location': 'New York'} + name="get_weather", args={"location": "New York"} ) # Create LLM response with parallel function calls response1 = LlmResponse( content=types.Content( - role='model', parts=[function_call1, function_call2] + role="model", parts=[function_call1, function_call2] ), turn_complete=False, ) @@ -310,11 +310,11 @@ def test_live_streaming_function_call_parallel(): # Mock function def get_weather(location: str) -> dict: - temperatures = {'San Francisco': 22, 'New York': 15} - return {'temperature': temperatures.get(location, 20), 'location': location} + temperatures = {"San Francisco": 22, "New York": 15} + return {"temperature": temperatures.get(location, 20), "location": location} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -367,14 +367,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'Compare weather in SF and NYC', mime_type='audio/pcm' + data=b"Compare weather in SF and NYC", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check parallel function calls sf_call_found = False @@ -383,28 +383,28 @@ async def consume_responses(session: testing_utils.Session): for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': - location = part.function_call.args['location'] - if location == 'San Francisco': + if part.function_call and part.function_call.name == "get_weather": + location = part.function_call.args["location"] + if location == "San Francisco": sf_call_found = True - elif location == 'New York': + elif location == "New York": nyc_call_found = True assert ( sf_call_found and nyc_call_found - ), 'Expected both location function calls.' + ), "Expected both location function calls." def test_live_streaming_function_call_with_error(): """Test live streaming with function call that returns an error.""" # Create a function call response function_call = types.Part.from_function_call( - name='get_weather', args={'location': 'Invalid Location'} + name="get_weather", args={"location": "Invalid Location"} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -415,12 +415,12 @@ def test_live_streaming_function_call_with_error(): # Mock function that returns an error for invalid locations def get_weather(location: str) -> dict: - if location == 'Invalid Location': - return {'error': 'Location not found'} - return {'temperature': 22, 'condition': 'sunny', 'location': location} + if location == "Invalid Location": + return {"error": "Location not found"} + return {"temperature": 22, "condition": "sunny", "location": location} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -473,37 +473,37 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'What is weather in Invalid Location?', mime_type='audio/pcm' + data=b"What is weather in Invalid Location?", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got the function call (error handling happens at execution time) function_call_found = False for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': + if part.function_call and part.function_call.name == "get_weather": function_call_found = True - assert part.function_call.args['location'] == 'Invalid Location' + assert part.function_call.args["location"] == "Invalid Location" - assert function_call_found, 'Expected function call event with error case.' + assert function_call_found, "Expected function call event with error case." def test_live_streaming_function_call_sync_tool(): """Test live streaming with synchronous function call.""" # Create a function call response function_call = types.Part.from_function_call( - name='calculate', args={'x': 5, 'y': 3} + name="calculate", args={"x": 5, "y": 3} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -514,10 +514,10 @@ def test_live_streaming_function_call_sync_tool(): # Mock sync function def calculate(x: int, y: int) -> dict: - return {'result': x + y, 'operation': 'addition'} + return {"result": x + y, "operation": "addition"} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[calculate], ) @@ -569,37 +569,37 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Calculate 5 plus 3', mime_type='audio/pcm') + blob=types.Blob(data=b"Calculate 5 plus 3", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check function call function_call_found = False for event in res_events: if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'calculate': + if part.function_call and part.function_call.name == "calculate": function_call_found = True - assert part.function_call.args['x'] == 5 - assert part.function_call.args['y'] == 3 + assert part.function_call.args["x"] == 5 + assert part.function_call.args["y"] == 3 - assert function_call_found, 'Expected calculate function call event.' + assert function_call_found, "Expected calculate function call event." def test_live_streaming_simple_streaming_tool(): """Test live streaming with a simple streaming tool (non-video).""" # Create a function call response for the streaming tool function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'AAPL'} + name="monitor_stock_price", args={"stock_symbol": "AAPL"} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -612,18 +612,18 @@ def test_live_streaming_simple_streaming_tool(): async def monitor_stock_price(stock_symbol: str): """Mock streaming tool that monitors stock prices.""" # Simulate some streaming updates - yield f'Stock {stock_symbol} price: $150' + yield f"Stock {stock_symbol} price: $150" await asyncio.sleep(0.1) - yield f'Stock {stock_symbol} price: $155' + yield f"Stock {stock_symbol} price: $155" await asyncio.sleep(0.1) - yield f'Stock {stock_symbol} price: $160' + yield f"Stock {stock_symbol} price: $160" def stop_streaming(function_name: str): """Stop the streaming tool.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, stop_streaming], ) @@ -675,13 +675,13 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor AAPL stock price', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor AAPL stock price", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got the streaming tool function call function_call_found = False @@ -690,26 +690,26 @@ async def consume_responses(session: testing_utils.Session): for part in event.content.parts: if ( part.function_call - and part.function_call.name == 'monitor_stock_price' + and part.function_call.name == "monitor_stock_price" ): function_call_found = True - assert part.function_call.args['stock_symbol'] == 'AAPL' + assert part.function_call.args["stock_symbol"] == "AAPL" assert ( function_call_found - ), 'Expected monitor_stock_price function call event.' + ), "Expected monitor_stock_price function call event." def test_live_streaming_video_streaming_tool(): """Test live streaming with a video streaming tool.""" # Create a function call response for the video streaming tool function_call = types.Part.from_function_call( - name='monitor_video_stream', args={} + name="monitor_video_stream", args={} ) # Create LLM responses response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse( @@ -727,13 +727,13 @@ async def monitor_video_stream(input_stream: LiveRequestQueue): try: # Try to get a frame from the queue with timeout live_req = await asyncio.wait_for(input_stream.get(), timeout=0.1) - if live_req.blob and live_req.blob.mime_type == 'image/jpeg': + if live_req.blob and live_req.blob.mime_type == "image/jpeg": frame_count += 1 - yield f'Processed frame {frame_count}: detected 2 people' + yield f"Processed frame {frame_count}: detected 2 people" except asyncio.TimeoutError: # No more frames, simulate detection anyway for testing frame_count += 1 - yield f'Simulated frame {frame_count}: detected 1 person' + yield f"Simulated frame {frame_count}: detected 1 person" await asyncio.sleep(0.1) def stop_streaming(function_name: str): @@ -741,7 +741,7 @@ def stop_streaming(function_name: str): pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_video_stream, stop_streaming], ) @@ -795,19 +795,19 @@ async def consume_responses(session: testing_utils.Session): # Send some mock video frames live_request_queue.send_realtime( - blob=types.Blob(data=b'fake_jpeg_data_1', mime_type='image/jpeg') + blob=types.Blob(data=b"fake_jpeg_data_1", mime_type="image/jpeg") ) live_request_queue.send_realtime( - blob=types.Blob(data=b'fake_jpeg_data_2', mime_type='image/jpeg') + blob=types.Blob(data=b"fake_jpeg_data_2", mime_type="image/jpeg") ) live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor video stream', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor video stream", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got the video streaming tool function call function_call_found = False @@ -816,32 +816,32 @@ async def consume_responses(session: testing_utils.Session): for part in event.content.parts: if ( part.function_call - and part.function_call.name == 'monitor_video_stream' + and part.function_call.name == "monitor_video_stream" ): function_call_found = True assert ( function_call_found - ), 'Expected monitor_video_stream function call event.' + ), "Expected monitor_video_stream function call event." def test_live_streaming_stop_streaming_tool(): """Test live streaming with stop_streaming functionality.""" # Create function calls for starting and stopping a streaming tool start_function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'TSLA'} + name="monitor_stock_price", args={"stock_symbol": "TSLA"} ) stop_function_call = types.Part.from_function_call( - name='stop_streaming', args={'function_name': 'monitor_stock_price'} + name="stop_streaming", args={"function_name": "monitor_stock_price"} ) # Create LLM responses: start streaming, then stop streaming response1 = LlmResponse( - content=types.Content(role='model', parts=[start_function_call]), + content=types.Content(role="model", parts=[start_function_call]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[stop_function_call]), + content=types.Content(role="model", parts=[stop_function_call]), turn_complete=False, ) response3 = LlmResponse( @@ -853,17 +853,17 @@ def test_live_streaming_stop_streaming_tool(): # Mock streaming tool and stop function async def monitor_stock_price(stock_symbol: str): """Mock streaming tool that monitors stock prices.""" - yield f'Started monitoring {stock_symbol}' + yield f"Started monitoring {stock_symbol}" while True: # Infinite stream (would be stopped by stop_streaming) - yield f'Stock {stock_symbol} price update' + yield f"Stock {stock_symbol} price update" await asyncio.sleep(0.1) def stop_streaming(function_name: str): """Stop the streaming tool.""" - return f'Stopped streaming for {function_name}' + return f"Stopped streaming for {function_name}" root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, stop_streaming], ) @@ -915,13 +915,13 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor TSLA and then stop', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor TSLA and then stop", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got both function calls monitor_call_found = False @@ -931,34 +931,34 @@ async def consume_responses(session: testing_utils.Session): if event.content and event.content.parts: for part in event.content.parts: if part.function_call: - if part.function_call.name == 'monitor_stock_price': + if part.function_call.name == "monitor_stock_price": monitor_call_found = True - assert part.function_call.args['stock_symbol'] == 'TSLA' - elif part.function_call.name == 'stop_streaming': + assert part.function_call.args["stock_symbol"] == "TSLA" + elif part.function_call.name == "stop_streaming": stop_call_found = True assert ( - part.function_call.args['function_name'] - == 'monitor_stock_price' + part.function_call.args["function_name"] + == "monitor_stock_price" ) - assert monitor_call_found, 'Expected monitor_stock_price function call event.' - assert stop_call_found, 'Expected stop_streaming function call event.' + assert monitor_call_found, "Expected monitor_stock_price function call event." + assert stop_call_found, "Expected stop_streaming function call event." def test_live_streaming_multiple_streaming_tools(): """Test live streaming with multiple streaming tools running simultaneously.""" # Create function calls for multiple streaming tools stock_function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'NVDA'} + name="monitor_stock_price", args={"stock_symbol": "NVDA"} ) video_function_call = types.Part.from_function_call( - name='monitor_video_stream', args={} + name="monitor_video_stream", args={} ) # Create LLM responses: start both streaming tools response1 = LlmResponse( content=types.Content( - role='model', parts=[stock_function_call, video_function_call] + role="model", parts=[stock_function_call, video_function_call] ), turn_complete=False, ) @@ -971,22 +971,22 @@ def test_live_streaming_multiple_streaming_tools(): # Mock streaming tools async def monitor_stock_price(stock_symbol: str): """Mock streaming tool that monitors stock prices.""" - yield f'Stock {stock_symbol} price: $800' + yield f"Stock {stock_symbol} price: $800" await asyncio.sleep(0.1) - yield f'Stock {stock_symbol} price: $805' + yield f"Stock {stock_symbol} price: $805" async def monitor_video_stream(input_stream: LiveRequestQueue): """Mock video streaming tool.""" - yield 'Video monitoring started' + yield "Video monitoring started" await asyncio.sleep(0.1) - yield 'Detected motion in video stream' + yield "Detected motion in video stream" def stop_streaming(function_name: str): """Stop the streaming tool.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, monitor_video_stream, stop_streaming], ) @@ -1039,14 +1039,14 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( blob=types.Blob( - data=b'Monitor both stock and video', mime_type='audio/pcm' + data=b"Monitor both stock and video", mime_type="audio/pcm" ) ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." # Check that we got both streaming tool function calls stock_call_found = False @@ -1056,39 +1056,38 @@ async def consume_responses(session: testing_utils.Session): if event.content and event.content.parts: for part in event.content.parts: if part.function_call: - if part.function_call.name == 'monitor_stock_price': + if part.function_call.name == "monitor_stock_price": stock_call_found = True - assert part.function_call.args['stock_symbol'] == 'NVDA' - elif part.function_call.name == 'monitor_video_stream': + assert part.function_call.args["stock_symbol"] == "NVDA" + elif part.function_call.name == "monitor_video_stream": video_call_found = True - assert stock_call_found, 'Expected monitor_stock_price function call event.' - assert video_call_found, 'Expected monitor_video_stream function call event.' + assert stock_call_found, "Expected monitor_stock_price function call event." + assert video_call_found, "Expected monitor_video_stream function call event." -def test_live_streaming_buffered_function_call_yielded_during_transcription(): - """Test that function calls buffered during transcription are yielded. +def test_live_streaming_function_call_yielded_before_finished_transcription(): + """Test that function calls arriving during live transcription are yielded immediately. - This tests the fix for the bug where function_call and function_response - events were buffered during active transcription but never yielded to the - caller. The fix ensures buffered events are yielded after transcription ends. + This verifies that tool call events are not buffered and are permitted to + arrive in the stream before the final completed transcription event. """ function_call = types.Part.from_function_call( - name='get_weather', args={'location': 'San Francisco'} + name="get_weather", args={"location": "San Francisco"} ) response1 = LlmResponse( - input_transcription=types.Transcription(text='Show'), + input_transcription=types.Transcription(text="Show"), partial=True, # ← Triggers is_transcribing = True ) response2 = LlmResponse( content=types.Content( - role='model', parts=[function_call] + role="model", parts=[function_call] ), # ← Gets buffered turn_complete=False, ) response3 = LlmResponse( - input_transcription=types.Transcription(text='Show me the weather'), + input_transcription=types.Transcription(text="Show me the weather"), partial=False, # ← Transcription ends, buffered events yielded ) response4 = LlmResponse( @@ -1100,10 +1099,10 @@ def test_live_streaming_buffered_function_call_yielded_during_transcription(): ) def get_weather(location: str) -> dict: - return {'temperature': 22, 'location': location} + return {"temperature": 22, "location": location} root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[get_weather], ) @@ -1154,41 +1153,50 @@ async def consume_responses(session: testing_utils.Session): runner = CustomTestRunner(root_agent=root_agent) live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Show me the weather', mime_type='audio/pcm') + blob=types.Blob(data=b"Show me the weather", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' - assert len(res_events) >= 1, 'Expected at least one event.' + assert res_events is not None, "Expected a list of events, got None." + assert len(res_events) >= 1, "Expected at least one event." - function_call_found = False - function_response_found = False + function_call_index = -1 + finished_transcription_index = -1 - for event in res_events: + for idx, event in enumerate(res_events): if event.content and event.content.parts: for part in event.content.parts: - if part.function_call and part.function_call.name == 'get_weather': - function_call_found = True - assert part.function_call.args['location'] == 'San Francisco' + if part.function_call and part.function_call.name == "get_weather": + function_call_index = idx + assert part.function_call.args["location"] == "San Francisco" if ( part.function_response - and part.function_response.name == 'get_weather' + and part.function_response.name == "get_weather" ): - function_response_found = True - assert part.function_response.response['temperature'] == 22 + assert part.function_response.response["temperature"] == 22 + if ( + event.input_transcription + and event.input_transcription.text == "Show me the weather" + ): + finished_transcription_index = idx - assert function_call_found, 'Buffered function_call event was not yielded.' + assert function_call_index != -1, "Function call event was not yielded." assert ( - function_response_found - ), 'Buffered function_response event was not yielded.' + finished_transcription_index != -1 + ), "Finished transcription event was not yielded." + assert function_call_index < finished_transcription_index, ( + f"Expected function call (at index {function_call_index}) to arrive" + " before finished transcription (at index" + f" {finished_transcription_index})." + ) def test_live_streaming_text_content_persisted_in_session(): """Test that user text content sent via send_content is persisted in session.""" response1 = LlmResponse( content=types.Content( - role='model', parts=[types.Part(text='Hello! How can I help you?')] + role="model", parts=[types.Part(text="Hello! How can I help you?")] ), turn_complete=True, ) @@ -1196,7 +1204,7 @@ def test_live_streaming_text_content_persisted_in_session(): mock_model = testing_utils.MockModel.create([response1]) root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[], ) @@ -1253,19 +1261,19 @@ async def consume_responses(session: testing_utils.Session): live_request_queue = LiveRequestQueue() # Send text content (not audio blob) - user_text = 'Hello, this is a test message' + user_text = "Hello, this is a test message" live_request_queue.send_content( - types.Content(role='user', parts=[types.Part(text=user_text)]) + types.Content(role="user", parts=[types.Part(text=user_text)]) ) res_events, session = runner.run_live_and_get_session(live_request_queue) - assert res_events is not None, 'Expected a list of events, got None.' + assert res_events is not None, "Expected a list of events, got None." # Check that user text content was persisted in the session user_content_found = False for event in session.events: - if event.author == 'user' and event.content: + if event.author == "user" and event.content: for part in event.content.parts: if part.text and user_text in part.text: user_content_found = True @@ -1273,7 +1281,7 @@ async def consume_responses(session: testing_utils.Session): assert user_content_found, ( f'Expected user text content "{user_text}" to be persisted in session. ' - f'Session events: {[e.content for e in session.events]}' + f"Session events: {[e.content for e in session.events]}" ) @@ -1328,16 +1336,16 @@ def test_input_streaming_tool_registered_lazily_with_stream(): # tool is NOT registered before the model calls it. text_response = LlmResponse( content=types.Content( - role='model', - parts=[types.Part(text='Processing...')], + role="model", + parts=[types.Part(text="Processing...")], ), turn_complete=False, ) function_call = types.Part.from_function_call( - name='monitor_video_stream', args={} + name="monitor_video_stream", args={} ) call_response = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) done_response = LlmResponse(turn_complete=True) @@ -1354,10 +1362,10 @@ async def monitor_video_stream( """Record whether input_stream was provided.""" nonlocal stream_state_during_call stream_state_during_call = input_stream is not None - yield 'monitoring started' + yield "monitoring started" root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_video_stream], ) @@ -1378,7 +1386,7 @@ def capturing_method(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test_data', mime_type='audio/pcm') + blob=types.Blob(data=b"test_data", mime_type="audio/pcm") ) # Collect events and check that the tool is NOT registered before @@ -1403,7 +1411,7 @@ async def consume(session: testing_utils.Session): and not response.get_function_calls() ): not_registered_before_call = ( - active is None or 'monitor_video_stream' not in active + active is None or "monitor_video_stream" not in active ) if len(collected) >= 4: return @@ -1413,28 +1421,28 @@ async def consume(session: testing_utils.Session): # Tool should not be registered before the model calls it. assert ( not_registered_before_call is True - ), 'Expected tool to NOT be registered before the model calls it' + ), "Expected tool to NOT be registered before the model calls it" # When the model calls the tool, input_stream should be provided. assert ( stream_state_during_call is True - ), 'Expected input_stream to be provided to the streaming tool when called' + ), "Expected input_stream to be provided to the streaming tool when called" def test_stop_streaming_resets_stream_to_none(): """Test that stop_streaming sets stream back to None.""" start_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'GOOG'} + name="monitor_stock_price", args={"stock_symbol": "GOOG"} ) stop_call = types.Part.from_function_call( - name='stop_streaming', args={'function_name': 'monitor_stock_price'} + name="stop_streaming", args={"function_name": "monitor_stock_price"} ) response1 = LlmResponse( - content=types.Content(role='model', parts=[start_call]), + content=types.Content(role="model", parts=[start_call]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[stop_call]), + content=types.Content(role="model", parts=[stop_call]), turn_complete=False, ) response3 = LlmResponse(turn_complete=True) @@ -1445,17 +1453,17 @@ async def monitor_stock_price( stock_symbol: str, ) -> AsyncGenerator[str, None]: """Yield periodic price updates for the given stock symbol.""" - yield f'Monitoring {stock_symbol}' + yield f"Monitoring {stock_symbol}" while True: await asyncio.sleep(0.1) - yield f'{stock_symbol} price update' + yield f"{stock_symbol} price update" def stop_streaming(function_name: str) -> None: """Stop a running streaming tool by name.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price, stop_streaming], ) @@ -1479,7 +1487,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'Monitor GOOG then stop', mime_type='audio/pcm') + blob=types.Blob(data=b"Monitor GOOG then stop", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue, max_responses=4) @@ -1487,32 +1495,32 @@ def capturing_create(*args, **kwargs) -> Any: # Verify both function calls were processed. call_names = _collect_function_call_names(res_events) assert ( - 'monitor_stock_price' in call_names - ), 'Expected monitor_stock_price function call.' + "monitor_stock_price" in call_names + ), "Expected monitor_stock_price function call." assert ( - 'stop_streaming' in call_names - ), 'Expected stop_streaming function call.' + "stop_streaming" in call_names + ), "Expected stop_streaming function call." # Verify that stop_streaming reset the stream to None. assert ( captured_child_context is not None - ), 'Expected child invocation context to be captured' + ), "Expected child invocation context to be captured" active_tools = captured_child_context.active_streaming_tools or {} assert ( - 'monitor_stock_price' in active_tools - ), 'Expected monitor_stock_price in active_streaming_tools' + "monitor_stock_price" in active_tools + ), "Expected monitor_stock_price in active_streaming_tools" assert ( - active_tools['monitor_stock_price'].stream is None - ), 'Expected stream to be reset to None after stop_streaming' + active_tools["monitor_stock_price"].stream is None + ), "Expected stream to be reset to None after stop_streaming" def test_output_streaming_tool_registered_lazily_without_stream(): """Test that output-streaming tools are registered lazily when called, with stream=None.""" function_call = types.Part.from_function_call( - name='monitor_stock_price', args={'stock_symbol': 'GOOG'} + name="monitor_stock_price", args={"stock_symbol": "GOOG"} ) response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse(turn_complete=True) @@ -1523,10 +1531,10 @@ async def monitor_stock_price( stock_symbol: str, ) -> AsyncGenerator[str, None]: """Yield periodic price updates.""" - yield f'price for {stock_symbol}' + yield f"price for {stock_symbol}" root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_stock_price], ) @@ -1548,7 +1556,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test', mime_type='audio/pcm') + blob=types.Blob(data=b"test", mime_type="audio/pcm") ) runner.run_live(live_request_queue, max_responses=3) @@ -1558,11 +1566,11 @@ def capturing_create(*args, **kwargs) -> Any: assert captured_child_context is not None active_tools = captured_child_context.active_streaming_tools or {} assert ( - 'monitor_stock_price' in active_tools - ), 'Expected output-streaming tool to be registered when called' + "monitor_stock_price" in active_tools + ), "Expected output-streaming tool to be registered when called" assert ( - active_tools['monitor_stock_price'].stream is None - ), 'Expected stream to be None for output-streaming tool' + active_tools["monitor_stock_price"].stream is None + ), "Expected stream to be None for output-streaming tool" def _run_single_tool_live( @@ -1581,7 +1589,7 @@ def _run_single_tool_live( name=func_name, args=func_args or {} ) response1 = LlmResponse( - content=types.Content(role='model', parts=[function_call]), + content=types.Content(role="model", parts=[function_call]), turn_complete=False, ) response2 = LlmResponse(turn_complete=True) @@ -1589,7 +1597,7 @@ def _run_single_tool_live( mock_model = testing_utils.MockModel.create([response1, response2]) root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[tool_func], ) @@ -1609,7 +1617,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test', mime_type='audio/pcm') + blob=types.Blob(data=b"test", mime_type="audio/pcm") ) runner.run_live(live_request_queue, max_responses=max_responses) @@ -1625,42 +1633,42 @@ async def monitor_video_stream( input_stream: LiveRequestQueue, ) -> AsyncGenerator[str, None]: """Simulate an input-streaming tool.""" - yield 'started' + yield "started" active_tools = _run_single_tool_live( - monitor_video_stream, 'monitor_video_stream' + monitor_video_stream, "monitor_video_stream" ) assert ( - 'monitor_video_stream' in active_tools - ), 'Expected input-streaming tool to be registered when called' + "monitor_video_stream" in active_tools + ), "Expected input-streaming tool to be registered when called" # Stream should be a LiveRequestQueue, not None. assert ( - active_tools['monitor_video_stream'].stream is not None - ), 'Expected .stream to be set for input-streaming tool' + active_tools["monitor_video_stream"].stream is not None + ), "Expected .stream to be set for input-streaming tool" assert isinstance( - active_tools['monitor_video_stream'].stream, LiveRequestQueue - ), 'Expected .stream to be a LiveRequestQueue instance' + active_tools["monitor_video_stream"].stream, LiveRequestQueue + ), "Expected .stream to be a LiveRequestQueue instance" def test_input_streaming_tool_stream_recreated_after_stop(): """Test that re-invoking an input-streaming tool after stop creates a new stream.""" - start_call = types.Part.from_function_call(name='monitor_video', args={}) + start_call = types.Part.from_function_call(name="monitor_video", args={}) stop_call = types.Part.from_function_call( - name='stop_streaming', args={'function_name': 'monitor_video'} + name="stop_streaming", args={"function_name": "monitor_video"} ) - restart_call = types.Part.from_function_call(name='monitor_video', args={}) + restart_call = types.Part.from_function_call(name="monitor_video", args={}) response1 = LlmResponse( - content=types.Content(role='model', parts=[start_call]), + content=types.Content(role="model", parts=[start_call]), turn_complete=False, ) response2 = LlmResponse( - content=types.Content(role='model', parts=[stop_call]), + content=types.Content(role="model", parts=[stop_call]), turn_complete=False, ) response3 = LlmResponse( - content=types.Content(role='model', parts=[restart_call]), + content=types.Content(role="model", parts=[restart_call]), turn_complete=False, ) response4 = LlmResponse(turn_complete=True) @@ -1677,17 +1685,17 @@ async def monitor_video( """Simulate an input-streaming tool that tracks invocation count.""" nonlocal call_count call_count += 1 - yield f'started (call {call_count})' + yield f"started (call {call_count})" while True: await asyncio.sleep(0.1) - yield 'frame' + yield "frame" def stop_streaming(function_name: str) -> None: """Stop a running streaming tool by name.""" pass root_agent = Agent( - name='root_agent', + name="root_agent", model=mock_model, tools=[monitor_video, stop_streaming], ) @@ -1707,7 +1715,7 @@ def capturing_create(*args, **kwargs) -> Any: live_request_queue = LiveRequestQueue() live_request_queue.send_realtime( - blob=types.Blob(data=b'test', mime_type='audio/pcm') + blob=types.Blob(data=b"test", mime_type="audio/pcm") ) res_events = runner.run_live(live_request_queue, max_responses=8) @@ -1719,16 +1727,16 @@ def capturing_create(*args, **kwargs) -> Any: fc.name for event in res_events for fc in event.get_function_calls() ] assert ( - call_names.count('monitor_video') >= 2 - ), f'Expected monitor_video called at least twice, got: {call_names}' + call_names.count("monitor_video") >= 2 + ), f"Expected monitor_video called at least twice, got: {call_names}" # After re-invocation, stream should be set again (not None). assert captured_child_context is not None active_tools = captured_child_context.active_streaming_tools or {} - assert 'monitor_video' in active_tools + assert "monitor_video" in active_tools assert ( - active_tools['monitor_video'].stream is not None - ), 'Expected .stream to be recreated after stop + re-invocation' + active_tools["monitor_video"].stream is not None + ), "Expected .stream to be recreated after stop + re-invocation" def test_async_gen_with_input_stream_wrong_annotation_gets_no_stream(): @@ -1739,22 +1747,22 @@ async def my_tool(input_stream: str) -> AsyncGenerator[str, None]: """Simulate an async generator whose input_stream is typed as str.""" nonlocal received_input_stream received_input_stream = input_stream - yield f'got: {input_stream}' + yield f"got: {input_stream}" active_tools = _run_single_tool_live( - my_tool, 'my_tool', func_args={'input_stream': 'some_value'} + my_tool, "my_tool", func_args={"input_stream": "some_value"} ) assert ( - 'my_tool' in active_tools - ), 'Expected async generator tool to be registered' + "my_tool" in active_tools + ), "Expected async generator tool to be registered" # Stream should be None because annotation is str, not LiveRequestQueue. - assert active_tools['my_tool'].stream is None, ( - 'Expected .stream to be None when input_stream annotation is not' - ' LiveRequestQueue' + assert active_tools["my_tool"].stream is None, ( + "Expected .stream to be None when input_stream annotation is not" + " LiveRequestQueue" ) # The tool should have received the model-provided arg value, not a # LiveRequestQueue. assert ( - received_input_stream == 'some_value' - ), 'Expected input_stream to be the model-provided string value' + received_input_stream == "some_value" + ), "Expected input_stream to be the model-provided string value"