Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,11 @@ async def _run_node_async(
with tracer.start_as_current_span('invocation'):
# 1. Setup
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
user_id=user_id,
session_id=session_id,
get_session_config=run_config.get_session_config
if run_config
else None,
)

# Validate and resolve resume inputs
Expand Down Expand Up @@ -1000,7 +1004,9 @@ async def run_async(

if self.agent.mode == 'chat':
session = await self._get_or_create_session(
user_id=user_id, session_id=session_id
user_id=user_id,
session_id=session_id,
get_session_config=run_config.get_session_config,
)
# when the chat coordinator has task-mode sub-agents,
# the wrapper handles delegation via ctx.run_node. Don't let
Expand Down
28 changes: 27 additions & 1 deletion tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,9 +1695,28 @@ async def test_run_async_passes_get_session_config():
),
)

events_seen_by_agent = []

class EventCheckingAgent(BaseAgent):

def __init__(self, name: str):
super().__init__(name=name, sub_agents=[])

async def _run_async_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
events_seen_by_agent.extend(invocation_context.session.events)
yield Event(
invocation_id=invocation_context.invocation_id,
author=self.name,
content=types.Content(
role="model", parts=[types.Part(text="Test response")]
),
)

runner = Runner(
app_name=TEST_APP_ID,
agent=MockAgent("test_agent"),
agent=EventCheckingAgent("test_agent"),
session_service=session_service,
artifact_service=InMemoryArtifactService(),
)
Expand All @@ -1720,6 +1739,13 @@ async def test_run_async_passes_get_session_config():
assert len(events) >= 1
assert events[0].author == "test_agent"

# The agent should have only seen 3 historical events + 1 new message = 4 events.
assert len(events_seen_by_agent) == 4
assert events_seen_by_agent[0].invocation_id == "inv_7"
assert events_seen_by_agent[1].invocation_id == "inv_8"
assert events_seen_by_agent[2].invocation_id == "inv_9"
assert events_seen_by_agent[3].content.parts[0].text == "hello"


@pytest.mark.asyncio
async def test_run_async_teardown_on_aclose():
Expand Down
Loading