From 86c02e61059021434bac423405b6f3221ad75e4e Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 14 Jan 2026 19:55:15 -0800 Subject: [PATCH 01/16] memorysession and asyncmemorysession --- integrations/langchain/pyproject.toml | 2 +- integrations/openai/pyproject.toml | 5 + .../src/databricks_openai/agents/session.py | 647 +++++++++++ .../openai/tests/unit_tests/test_session.py | 1035 +++++++++++++++++ 4 files changed, 1688 insertions(+), 1 deletion(-) create mode 100644 integrations/openai/src/databricks_openai/agents/session.py create mode 100644 integrations/openai/tests/unit_tests/test_session.py diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 4fa284b39..10ee6e01c 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ [project.optional-dependencies] memory = [ "langgraph-checkpoint-postgres>=2.0.5", - "databricks-ai-bridge[memory]>=0.10.0", + "databricks-ai-bridge[memory]>=0.11.0", ] [dependency-groups] diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index b3eea7f83..492f3810a 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -20,6 +20,11 @@ dependencies = [ "openai-agents>=0.5.0" ] +[project.optional-dependencies] +memory = [ + "databricks-ai-bridge[memory]>=0.11.0", +] + [dependency-groups] dev = [ "typing_extensions>=4.15.0", diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py new file mode 100644 index 000000000..3c6e7549b --- /dev/null +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -0,0 +1,647 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime, timezone +from threading import Lock +from typing import Any, Dict, Optional, Tuple, Union, cast +from uuid import UUID + +from databricks.sdk import WorkspaceClient + +try: + from agents.items import TResponseInputItem + from agents.memory.session import SessionABC + from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool + from psycopg import sql + from psycopg.sql import Composed +except ImportError as e: + raise ImportError( + "MemorySession requires databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) from e + +logger = logging.getLogger(__name__) + +# Module-level pool cache: instance_name -> LakebasePool +_pool_cache: Dict[str, LakebasePool] = {} +_pool_cache_lock = Lock() + +# Module-level async pool cache: instance_name -> AsyncLakebasePool +_async_pool_cache: Dict[str, AsyncLakebasePool] = {} +_async_pool_cache_lock = asyncio.Lock() + + +def _get_or_create_pool( + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + **pool_kwargs, +) -> LakebasePool: + """Get cached pool or create new one for this instance.""" + cache_key = instance_name + + with _pool_cache_lock: + if cache_key not in _pool_cache: + logger.info(f"Creating new LakebasePool for {cache_key}") + _pool_cache[cache_key] = LakebasePool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + return _pool_cache[cache_key] + + +async def _get_or_create_async_pool( + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + **pool_kwargs, +) -> AsyncLakebasePool: + """Get cached async pool or create new one for this instance.""" + cache_key = instance_name + + async with _async_pool_cache_lock: + if cache_key not in _async_pool_cache: + logger.info(f"Creating new AsyncLakebasePool for {cache_key}") + pool = AsyncLakebasePool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + await pool.open() + _async_pool_cache[cache_key] = pool + return _async_pool_cache[cache_key] + + +class _MemorySessionBase(SessionABC): + """ + Base class with shared SQL, configuration, and helper methods for memory sessions. + + Subclasses implement sync or async pool initialization and database operations. + """ + + # Table names + SESSIONS_TABLE = "agent_sessions" + MESSAGES_TABLE = "agent_messages" + + CREATE_SESSIONS_TABLE_SQL = """ + CREATE TABLE IF NOT EXISTS {sessions_table} ( + session_id UUID PRIMARY KEY, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + ); + """ + + CREATE_MESSAGES_TABLE_SQL = """ + CREATE TABLE IF NOT EXISTS {messages_table} ( + id BIGSERIAL PRIMARY KEY, + session_id UUID NOT NULL REFERENCES {sessions_table}(session_id) ON DELETE CASCADE, + message_data JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS {idx_session_id} + ON {messages_table}(session_id); + CREATE INDEX IF NOT EXISTS {idx_session_order} + ON {messages_table}(session_id, id); + """ + + session_id: UUID + + def __init__( + self, + session_id: UUID, + *, + sessions_table: str = SESSIONS_TABLE, + messages_table: str = MESSAGES_TABLE, + ) -> None: + """ + Initialize base session attributes. + + Args: + session_id: UUID identifier for this conversation session. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + """ + self.session_id = session_id + self.sessions_table = sessions_table + self.messages_table = messages_table + + # --- SQL Building Helpers --- + + def _build_create_sessions_sql(self) -> Composed: + """Build SQL to create the sessions table.""" + return sql.SQL(self.CREATE_SESSIONS_TABLE_SQL).format( + sessions_table=sql.Identifier(self.sessions_table) + ) + + def _build_create_messages_sql(self) -> Composed: + """Build SQL to create the messages table.""" + return sql.SQL(self.CREATE_MESSAGES_TABLE_SQL).format( + sessions_table=sql.Identifier(self.sessions_table), + messages_table=sql.Identifier(self.messages_table), + idx_session_id=sql.Identifier(f"idx_{self.messages_table}_session_id"), + idx_session_order=sql.Identifier(f"idx_{self.messages_table}_session_order"), + ) + + def _build_ensure_session_sql(self) -> Composed: + """Build SQL to insert session record if not exists.""" + return sql.SQL( + """ + INSERT INTO {} (session_id, created_at, updated_at) + VALUES (%s, %s, %s) + ON CONFLICT (session_id) DO NOTHING + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_get_items_query( + self, limit: int | None + ) -> Tuple[Composed, Tuple[UUID, ...] | Tuple[UUID, int]]: + """Build SQL query and params to get items.""" + if limit is not None: + query = sql.SQL( + """ + SELECT message_data FROM ( + SELECT message_data, id + FROM {} + WHERE session_id = %s + ORDER BY id DESC + LIMIT %s + ) sub + ORDER BY id ASC + """ + ).format(sql.Identifier(self.messages_table)) + params: Tuple[UUID, ...] | Tuple[UUID, int] = (self.session_id, limit) + else: + query = sql.SQL( + """ + SELECT message_data + FROM {} + WHERE session_id = %s + ORDER BY id ASC + """ + ).format(sql.Identifier(self.messages_table)) + params = (self.session_id,) + return query, params + + def _build_add_items_sql(self) -> Composed: + """Build SQL to insert message items.""" + return sql.SQL( + """ + INSERT INTO {} (session_id, message_data) + VALUES (%s, %s) + """ + ).format(sql.Identifier(self.messages_table)) + + def _build_update_session_timestamp_sql(self) -> Composed: + """Build SQL to update session timestamp.""" + return sql.SQL( + """ + UPDATE {} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = %s + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_update_session_timestamp_with_value_sql(self) -> Composed: + """Build SQL to update session timestamp with explicit value.""" + return sql.SQL( + """ + UPDATE {} + SET updated_at = %s + WHERE session_id = %s + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_pop_item_sql(self) -> Composed: + """Build SQL to delete and return most recent item.""" + messages_table_id = sql.Identifier(self.messages_table) + return sql.SQL( + """ + DELETE FROM {messages_table} + WHERE id = ( + SELECT id + FROM {messages_table} + WHERE session_id = %s + ORDER BY id DESC + LIMIT 1 + ) + RETURNING message_data + """ + ).format(messages_table=messages_table_id) + + def _build_clear_session_sql(self) -> Composed: + """Build SQL to delete all messages for session.""" + return sql.SQL("DELETE FROM {} WHERE session_id = %s").format( + sql.Identifier(self.messages_table) + ) + + def _prepare_items_for_insert(self, items: list[TResponseInputItem]) -> list[Tuple[UUID, str]]: + """Prepare items for database insertion.""" + return [(self.session_id, json.dumps(item)) for item in items] + + def _parse_message_data(self, message_data: Union[str, dict[str, Any]]) -> TResponseInputItem: + """Parse message_data from database (may be JSON string or dict).""" + if isinstance(message_data, str): + return cast(TResponseInputItem, json.loads(message_data)) + return cast(TResponseInputItem, message_data) + + def _parse_rows_to_items(self, rows: list) -> list[TResponseInputItem]: + """Parse database rows to list of items.""" + return [self._parse_message_data(row["message_data"]) for row in rows] + + +class MemorySession(_MemorySessionBase): + """ + OpenAI Agents SDK Session implementation using Lakebase for persistent storage. + + This class follows the Session protocol for conversation memory, + storing session data in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + SessionABC: https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.SessionABC + + Example: + ```python + from uuid import UUID + from databricks_openai.agents.session import MemorySession + from agents import Agent, Runner + + async def run_agent(thread_id: UUID | None, message: str): + # Use uuid7 for time-ordered UUIDs (better for database indexing) + session_id = thread_id + session = MemorySession( + session_id=session_id, + instance_name="my-lakebase-instance" + ) + agent = Agent(name="Assistant") + return await Runner.run(agent, message, session=session) + ``` + """ + + def __init__( + self, + session_id: UUID, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + sessions_table: str = _MemorySessionBase.SESSIONS_TABLE, + messages_table: str = _MemorySessionBase.MESSAGES_TABLE, + **pool_kwargs, + ) -> None: + """ + Initialize a MemorySession. + + On first initialization for a given Lakebase instance, this will automatically + create the required tables if they don't exist. + + Args: + session_id: UUID identifier for this conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + **pool_kwargs: Additional arguments passed to LakebasePool. + """ + super().__init__( + session_id=session_id, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + self._pool = _get_or_create_pool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + + if not self._tables_exist(): + self._create_tables() + + self._ensure_session() + + def _tables_exist(self) -> bool: + """Check if both session tables already exist.""" + with self._pool.connection() as conn: + result = conn.execute( + """ + SELECT COUNT(*) as cnt FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name IN (%s, %s) + """, + (self.sessions_table, self.messages_table), + ) + row = result.fetchone() + return row["cnt"] == 2 + + def _create_tables(self) -> None: + """Create the required tables.""" + with self._pool.connection() as conn: + conn.execute(self._build_create_sessions_sql()) + conn.execute(self._build_create_messages_sql()) + logger.info(f"Created tables {self.sessions_table}, {self.messages_table}") + + def _ensure_session(self) -> None: + """Ensure the session record exists in agent_sessions table.""" + now = datetime.now(timezone.utc) + with self._pool.connection() as conn: + conn.execute( + self._build_ensure_session_sql(), + (self.session_id, now, now), + ) + logger.debug(f"Ensured session {self.session_id} exists") + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """ + Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history. + """ + query, params = self._build_get_items_query(limit) + with self._pool.connection() as conn: + result = conn.execute(query, params) + rows = result.fetchall() + return self._parse_rows_to_items(rows) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """ + Add new items to the conversation history. + + Args: + items: List of input items to add to the history. + """ + if not items: + return + + with self._pool.connection() as conn: + with conn.cursor() as cur: + cur.executemany( + self._build_add_items_sql(), + self._prepare_items_for_insert(items), + ) + conn.execute( + self._build_update_session_timestamp_sql(), + (self.session_id,), + ) + logger.debug(f"Added {len(items)} items to session {self.session_id}") + + async def pop_item(self) -> TResponseInputItem | None: + """ + Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty. + """ + with self._pool.connection() as conn: + result = conn.execute( + self._build_pop_item_sql(), + (self.session_id,), + ) + row = result.fetchone() + + if row: + now = datetime.now(timezone.utc) + conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + + if row: + logger.debug(f"Popped item from session {self.session_id}") + return self._parse_message_data(row["message_data"]) + return None + + async def clear_session(self) -> None: + """Clear all items for this session.""" + with self._pool.connection() as conn: + result = conn.execute( + self._build_clear_session_sql(), + (self.session_id,), + ) + count = result.rowcount + + now = datetime.now(timezone.utc) + conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + logger.info(f"Cleared {count} items from session {self.session_id}") + + +class AsyncMemorySession(_MemorySessionBase): + """ + OpenAI Agents SDK Session implementation using Lakebase for persistent storage (async version). + + This class follows the Session protocol for conversation memory, + storing session data in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + SessionABC: https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.SessionABC + + Example: + ```python + from uuid import UUID + from databricks_openai.agents.session import AsyncMemorySession + from agents import Agent, Runner + + async def run_agent(thread_id: UUID | None, message: str): + session_id = thread_id + session = AsyncMemorySession( + session_id=session_id, + instance_name="my-lakebase-instance" + ) + agent = Agent(name="Assistant") + return await Runner.run(agent, message, session=session) + ``` + """ + + def __init__( + self, + session_id: UUID, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + sessions_table: str = _MemorySessionBase.SESSIONS_TABLE, + messages_table: str = _MemorySessionBase.MESSAGES_TABLE, + **pool_kwargs, + ) -> None: + """ + Initialize an AsyncMemorySession. + + Note: The async pool and tables are initialized lazily on first use. + + Args: + session_id: UUID identifier for this conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + **pool_kwargs: Additional arguments passed to AsyncLakebasePool. + """ + super().__init__( + session_id=session_id, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + self._instance_name = instance_name + self._workspace_client = workspace_client + self._pool_kwargs = pool_kwargs + + self._pool: Optional[AsyncLakebasePool] = None + self._initialized = False + self._init_lock = asyncio.Lock() + + async def _ensure_initialized(self) -> None: + """Ensure the pool is created and tables exist (lazy initialization).""" + if self._initialized: + return + + async with self._init_lock: + if self._initialized: + return + + self._pool = await _get_or_create_async_pool( + instance_name=self._instance_name, + workspace_client=self._workspace_client, + **self._pool_kwargs, + ) + + if not await self._tables_exist(): + await self._create_tables() + + await self._ensure_session() + self._initialized = True + + async def _tables_exist(self) -> bool: + """Check if both session tables already exist.""" + assert self._pool is not None + async with self._pool.connection() as conn: + result = await conn.execute( + """ + SELECT COUNT(*) as cnt FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name IN (%s, %s) + """, + (self.sessions_table, self.messages_table), + ) + row = await result.fetchone() + return row["cnt"] == 2 + + async def _create_tables(self) -> None: + """Create the required tables.""" + assert self._pool is not None + async with self._pool.connection() as conn: + await conn.execute(self._build_create_sessions_sql()) + await conn.execute(self._build_create_messages_sql()) + logger.info(f"Created tables {self.sessions_table}, {self.messages_table}") + + async def _ensure_session(self) -> None: + """Ensure the session record exists in agent_sessions table.""" + assert self._pool is not None + now = datetime.now(timezone.utc) + async with self._pool.connection() as conn: + await conn.execute( + self._build_ensure_session_sql(), + (self.session_id, now, now), + ) + logger.debug(f"Ensured session {self.session_id} exists") + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """ + Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history. + """ + await self._ensure_initialized() + assert self._pool is not None + + query, params = self._build_get_items_query(limit) + async with self._pool.connection() as conn: + result = await conn.execute(query, params) + rows = await result.fetchall() + return self._parse_rows_to_items(rows) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """ + Add new items to the conversation history. + + Args: + items: List of input items to add to the history. + """ + if not items: + return + + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + async with conn.cursor() as cur: + await cur.executemany( + self._build_add_items_sql(), + self._prepare_items_for_insert(items), + ) + await conn.execute( + self._build_update_session_timestamp_sql(), + (self.session_id,), + ) + logger.debug(f"Added {len(items)} items to session {self.session_id}") + + async def pop_item(self) -> TResponseInputItem | None: + """ + Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty. + """ + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + result = await conn.execute( + self._build_pop_item_sql(), + (self.session_id,), + ) + row = await result.fetchone() + + if row: + now = datetime.now(timezone.utc) + await conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + + if row: + logger.debug(f"Popped item from session {self.session_id}") + return self._parse_message_data(row["message_data"]) + return None + + async def clear_session(self) -> None: + """Clear all items for this session.""" + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + result = await conn.execute( + self._build_clear_session_sql(), + (self.session_id,), + ) + count = result.rowcount + + now = datetime.now(timezone.utc) + await conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + logger.info(f"Cleared {count} items from session {self.session_id}") diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py new file mode 100644 index 000000000..26f2035d1 --- /dev/null +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -0,0 +1,1035 @@ +from __future__ import annotations + +import json +from unittest.mock import MagicMock +from uuid import UUID + +import pytest + +pytest.importorskip("psycopg") +pytest.importorskip("psycopg_pool") +pytest.importorskip("agents.memory.session") + +from databricks_ai_bridge import lakebase +from psycopg import sql + +from databricks_openai.agents.session import ( + AsyncMemorySession, + MemorySession, + _async_pool_cache, + _pool_cache, +) + +# Use UUID (V7) for performance +TEST_SESSION_ID = UUID("12345678-1234-5678-1234-567812345678") +TEST_SESSION_ID_2 = UUID("22345678-1234-5678-1234-567812345678") +TEST_SESSION_ID_3 = UUID("32345678-1234-5678-1234-567812345678") + + +def query_to_string(query): + """Convert a query (string or sql.Composed) to a string for testing.""" + if isinstance(query, str): + return query + if isinstance(query, (sql.Composed, sql.SQL, sql.Identifier)): + return query.as_string(None) + return str(query) + + +class MockCursor: + """Mock cursor for executemany operations.""" + + def __init__(self): + self.executed_queries = [] + + def executemany(self, query, params): + self.executed_queries.append((query, params)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + +class MockAsyncCursor: + """Mock async cursor for executemany operations.""" + + def __init__(self): + self.executed_queries = [] + + async def executemany(self, query, params): + self.executed_queries.append((query, params)) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + pass + + +class MockResult: + """Mock result object for database queries.""" + + def __init__(self, rows=None, rowcount=0): + self._rows = rows or [] + self.rowcount = rowcount + self._index = 0 + + def fetchall(self): + return self._rows + + def fetchone(self): + if self._index < len(self._rows): + row = self._rows[self._index] + self._index += 1 + return row + return None + + +class MockAsyncResult: + """Mock async result object for database queries.""" + + def __init__(self, rows=None, rowcount=0): + self._rows = rows or [] + self.rowcount = rowcount + self._index = 0 + + async def fetchall(self): + return self._rows + + async def fetchone(self): + if self._index < len(self._rows): + row = self._rows[self._index] + self._index += 1 + return row + return None + + +class MockConnection: + """Mock database connection.""" + + def __init__(self): + self.executed_queries = [] + self._cursor = MockCursor() + self._next_result = MockResult() + self._results_queue = [] + + def execute(self, query, params=None): + self.executed_queries.append((query, params)) + if self._results_queue: + return self._results_queue.pop(0) + return self._next_result + + def cursor(self): + return self._cursor + + def set_next_result(self, result): + self._next_result = result + + def queue_result(self, result): + self._results_queue.append(result) + + +class MockAsyncConnection: + """Mock async database connection.""" + + def __init__(self): + self.executed_queries = [] + self._cursor = MockAsyncCursor() + self._next_result = MockAsyncResult() + self._results_queue = [] + + async def execute(self, query, params=None): + self.executed_queries.append((query, params)) + if self._results_queue: + return self._results_queue.pop(0) + return self._next_result + + def cursor(self): + return self._cursor + + def set_next_result(self, result): + self._next_result = result + + def queue_result(self, result): + self._results_queue.append(result) + + +class MockConnectionPool: + """Mock connection pool for testing.""" + + def __init__(self, connection_value=None): + self.connection_value = connection_value or MockConnection() + self.conninfo = "" + + def __call__(self, *, conninfo, connection_class=None, **kwargs): + self.conninfo = conninfo + return self + + def connection(self): + class _Ctx: + def __init__(self, outer): + self.outer = outer + + def __enter__(self): + return self.outer.connection_value + + def __exit__(self, exc_type, exc, tb): + pass + + return _Ctx(self) + + +class MockAsyncConnectionPool: + """Mock async connection pool for testing.""" + + def __init__(self, connection_value=None): + self.connection_value = connection_value or MockAsyncConnection() + self.conninfo = "" + self._opened = False + self._closed = False + + def __call__(self, *, conninfo, connection_class=None, **kwargs): + self.conninfo = conninfo + return self + + async def open(self): + self._opened = True + + async def close(self): + self._closed = True + + def connection(self): + class _AsyncCtx: + def __init__(self, outer): + self.outer = outer + + async def __aenter__(self): + return self.outer.connection_value + + async def __aexit__(self, exc_type, exc, tb): + pass + + return _AsyncCtx(self) + + +@pytest.fixture(autouse=True) +def clear_pool_cache(): + """Clear the pool cache before each test.""" + _pool_cache.clear() + _async_pool_cache.clear() + yield + _pool_cache.clear() + _async_pool_cache.clear() + + +@pytest.fixture +def mock_workspace(): + """Create a mock workspace client.""" + workspace = MagicMock() + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.database.get_database_instance.return_value.read_write_dns = "db-host" + workspace.current_service_principal.me.side_effect = RuntimeError("no sp") + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + return workspace + + +@pytest.fixture +def mock_connection(): + """Create a mock connection.""" + return MockConnection() + + +@pytest.fixture +def mock_async_connection(): + """Create a mock async connection.""" + return MockAsyncConnection() + + +@pytest.fixture +def mock_pool(mock_connection): + """Create a mock connection pool.""" + return MockConnectionPool(connection_value=mock_connection) + + +@pytest.fixture +def mock_async_pool(mock_async_connection): + """Create a mock async connection pool.""" + return MockAsyncConnectionPool(connection_value=mock_async_connection) + + +# ============================================================================= +# MemorySession Tests (Sync) +# ============================================================================= + + +def test_session_configures_lakebase(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that MemorySession correctly configures the Lakebase pool.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables already exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + assert ( + mock_pool.conninfo + == "dbname=databricks_postgres user=test@databricks.com host=db-host port=5432 sslmode=require" + ) + assert session.session_id == TEST_SESSION_ID + assert session.sessions_table == "agent_sessions" + assert session.messages_table == "agent_messages" + + +def test_session_creates_tables_on_init_when_not_exist( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test that MemorySession creates tables when they don't exist.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables don't exist (count=0) + mock_connection.queue_result(MockResult(rows=[{"cnt": 0}])) + # CREATE sessions table + mock_connection.queue_result(MockResult()) + # CREATE messages table + mock_connection.queue_result(MockResult()) + # INSERT session + mock_connection.queue_result(MockResult()) + + MemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Should have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + create_sessions_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_sessions" in q for q in queries + ) + create_messages_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_messages" in q for q in queries + ) + + assert create_sessions_found, "Should create sessions table" + assert create_messages_found, "Should create messages table" + + +def test_session_skips_table_creation_when_tables_exist( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test that MemorySession skips table creation when tables already exist.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: both tables exist (count=2) + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + # INSERT session (no CREATE TABLE calls) + mock_connection.queue_result(MockResult()) + + MemorySession( + session_id=TEST_SESSION_ID_3, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Should NOT have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + create_table_found = any("CREATE TABLE" in q for q in queries) + + assert not create_table_found, "Should not create tables when they already exist" + + +def test_session_ensures_session_record(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that MemorySession ensures the session record exists.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables already exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Find the INSERT INTO agent_sessions query + insert_queries = [ + (q, p) + for q, p in mock_connection.executed_queries + if "INSERT INTO" in query_to_string(q) and "agent_sessions" in query_to_string(q) + ] + + assert len(insert_queries) > 0, "Should insert session record" + query, params = insert_queries[0] + assert params[0] == TEST_SESSION_ID, "Should use correct session_id" + + +@pytest.mark.asyncio +async def test_get_items_empty_session(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test get_items returns empty list for new session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist, then INSERT session, then SELECT messages + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=[])) # SELECT messages + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items() + assert items == [] + + +@pytest.mark.asyncio +async def test_get_items_returns_parsed_json( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test get_items correctly parses JSON data.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + test_messages = [ + {"message_data": json.dumps({"role": "user", "content": "Hello"})}, + {"message_data": json.dumps({"role": "assistant", "content": "Hi there!"})}, + ] + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=test_messages)) # SELECT messages + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items() + + assert len(items) == 2 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "Hi there!" + + +@pytest.mark.asyncio +async def test_get_items_with_limit(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test get_items respects limit parameter.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result( + MockResult(rows=[{"message_data": json.dumps({"role": "user", "content": "Latest"})}]) + ) + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items(limit=1) + + assert len(items) == 1 + + # Verify the query used LIMIT + select_queries = [ + query_to_string(q) + for q, p in mock_connection.executed_queries + if "SELECT message_data" in query_to_string(q) + ] + assert any("LIMIT" in q for q in select_queries), "Should use LIMIT in query" + + +@pytest.mark.asyncio +async def test_add_items(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test add_items inserts messages correctly.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + test_items = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + + await session.add_items(test_items) + + # Check that executemany was called on cursor + assert len(mock_connection._cursor.executed_queries) > 0 + query, params = mock_connection._cursor.executed_queries[-1] + query_str = query_to_string(query) + assert "INSERT INTO" in query_str and "agent_messages" in query_str + assert len(params) == 2 # Two items + + +@pytest.mark.asyncio +async def test_add_items_empty_list(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test add_items handles empty list gracefully.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + initial_query_count = len(mock_connection.executed_queries) + + await session.add_items([]) + + # Should not execute any additional queries for empty list + # (only the queries from init should be present) + assert len(mock_connection.executed_queries) == initial_query_count + + +@pytest.mark.asyncio +async def test_pop_item_returns_last_item(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test pop_item removes and returns the most recent item.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + # DELETE RETURNING result + mock_connection.queue_result( + MockResult( + rows=[{"message_data": json.dumps({"role": "assistant", "content": "Last msg"})}] + ) + ) + mock_connection.queue_result(MockResult()) # UPDATE session timestamp + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + popped = await session.pop_item() + + assert popped is not None + assert popped["role"] == "assistant" + assert popped["content"] == "Last msg" + + +@pytest.mark.asyncio +async def test_pop_item_returns_none_when_empty( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test pop_item returns None for empty session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=[])) # DELETE RETURNING - empty + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + popped = await session.pop_item() + + assert popped is None + + +@pytest.mark.asyncio +async def test_clear_session(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test clear_session deletes all messages for the session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rowcount=5)) # DELETE messages + mock_connection.queue_result(MockResult()) # UPDATE session timestamp + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + await session.clear_session() + + # Find the DELETE query + delete_queries = [ + (q, p) + for q, p in mock_connection.executed_queries + if "DELETE FROM" in query_to_string(q) + and "agent_messages" in query_to_string(q) + and "WHERE session_id" in query_to_string(q) + ] + + assert len(delete_queries) > 0, "Should execute DELETE query" + query, params = delete_queries[0] + assert params == (TEST_SESSION_ID,), "Should use correct session_id" + + +def test_custom_table_names(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that custom table names are used correctly.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables don't exist (custom names), so they will be created + mock_connection.queue_result(MockResult(rows=[{"cnt": 0}])) # tables don't exist + mock_connection.queue_result(MockResult()) # CREATE sessions + mock_connection.queue_result(MockResult()) # CREATE messages + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + assert session.sessions_table == "custom_sessions" + assert session.messages_table == "custom_messages" + + # Check that CREATE TABLE uses custom names + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + assert any("custom_sessions" in q for q in queries) + assert any("custom_messages" in q for q in queries) + + +def test_pool_caching(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that pools are cached and reused.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock for both session creations + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist + mock_connection.queue_result(MockResult()) # INSERT session 1 + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist + mock_connection.queue_result(MockResult()) # INSERT session 2 + + session1 = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="shared-instance", + workspace_client=mock_workspace, + ) + + session2 = MemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="shared-instance", + workspace_client=mock_workspace, + ) + + # Both sessions should share the same pool + assert session1._pool is session2._pool + + +@pytest.mark.asyncio +async def test_get_items_handles_dict_message_data( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test get_items handles message_data that's already a dict (not JSON string).""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Some database drivers return JSONB as dict directly + test_messages = [ + {"message_data": {"role": "user", "content": "Already parsed"}}, + ] + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=test_messages)) + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items() + + assert len(items) == 1 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Already parsed" + + +# ============================================================================= +# AsyncMemorySession Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_async_session_lazy_initialization( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession initializes lazily on first use.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + # Create session - should NOT trigger any DB operations yet + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # No queries should have been executed yet + assert len(mock_async_connection.executed_queries) == 0 + assert not session._initialized + + # Mock: tables exist, then INSERT session, then SELECT messages + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT messages + + # First use triggers initialization + await session.get_items() + + # Now it should be initialized + assert session._initialized + assert mock_async_pool._opened + + +@pytest.mark.asyncio +async def test_async_session_configures_lakebase( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession correctly configures the Lakebase pool.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables already exist + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + # Trigger initialization + await session.get_items() + + assert ( + mock_async_pool.conninfo + == "dbname=databricks_postgres user=test@databricks.com host=db-host port=5432 sslmode=require" + ) + assert session.session_id == TEST_SESSION_ID + assert session.sessions_table == "agent_sessions" + assert session.messages_table == "agent_messages" + + +@pytest.mark.asyncio +async def test_async_session_creates_tables_when_not_exist( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession creates tables when they don't exist.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables don't exist (count=0) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 0}])) + mock_async_connection.queue_result(MockAsyncResult()) # CREATE sessions table + mock_async_connection.queue_result(MockAsyncResult()) # CREATE messages table + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + # Should have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + create_sessions_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_sessions" in q for q in queries + ) + create_messages_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_messages" in q for q in queries + ) + + assert create_sessions_found, "Should create sessions table" + assert create_messages_found, "Should create messages table" + + +@pytest.mark.asyncio +async def test_async_session_skips_table_creation_when_tables_exist( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession skips table creation when tables already exist.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: both tables exist (count=2) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + # Should NOT have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + create_table_found = any("CREATE TABLE" in q for q in queries) + + assert not create_table_found, "Should not create tables when they already exist" + + +@pytest.mark.asyncio +async def test_async_get_items_empty_session( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async get_items returns empty list for new session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT messages + + items = await session.get_items() + assert items == [] + + +@pytest.mark.asyncio +async def test_async_get_items_returns_parsed_json( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async get_items correctly parses JSON data.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + test_messages = [ + {"message_data": json.dumps({"role": "user", "content": "Hello"})}, + {"message_data": json.dumps({"role": "assistant", "content": "Hi there!"})}, + ] + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=test_messages)) # SELECT messages + + items = await session.get_items() + + assert len(items) == 2 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "Hi there!" + + +@pytest.mark.asyncio +async def test_async_add_items(monkeypatch, mock_workspace, mock_async_pool, mock_async_connection): + """Test async add_items inserts messages correctly.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables exist + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + + test_items = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + + await session.add_items(test_items) + + # Check that executemany was called on cursor + assert len(mock_async_connection._cursor.executed_queries) > 0 + query, params = mock_async_connection._cursor.executed_queries[-1] + query_str = query_to_string(query) + assert "INSERT INTO" in query_str and "agent_messages" in query_str + assert len(params) == 2 # Two items + + +@pytest.mark.asyncio +async def test_async_pop_item_returns_last_item( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async pop_item removes and returns the most recent item.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + # DELETE RETURNING result + mock_async_connection.queue_result( + MockAsyncResult( + rows=[{"message_data": json.dumps({"role": "assistant", "content": "Last msg"})}] + ) + ) + mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp + + popped = await session.pop_item() + + assert popped is not None + assert popped["role"] == "assistant" + assert popped["content"] == "Last msg" + + +@pytest.mark.asyncio +async def test_async_pop_item_returns_none_when_empty( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async pop_item returns None for empty session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # DELETE RETURNING - empty + + popped = await session.pop_item() + + assert popped is None + + +@pytest.mark.asyncio +async def test_async_clear_session( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async clear_session deletes all messages for the session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rowcount=5)) # DELETE messages + mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp + + await session.clear_session() + + # Find the DELETE query + delete_queries = [ + (q, p) + for q, p in mock_async_connection.executed_queries + if "DELETE FROM" in query_to_string(q) + and "agent_messages" in query_to_string(q) + and "WHERE session_id" in query_to_string(q) + ] + + assert len(delete_queries) > 0, "Should execute DELETE query" + query, params = delete_queries[0] + assert params == (TEST_SESSION_ID,), "Should use correct session_id" + + +@pytest.mark.asyncio +async def test_async_custom_table_names( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that async session uses custom table names correctly.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + # Mock: tables don't exist (custom names), so they will be created + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 0}])) # tables don't exist + mock_async_connection.queue_result(MockAsyncResult()) # CREATE sessions + mock_async_connection.queue_result(MockAsyncResult()) # CREATE messages + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + assert session.sessions_table == "custom_sessions" + assert session.messages_table == "custom_messages" + + # Check that CREATE TABLE uses custom names + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + assert any("custom_sessions" in q for q in queries) + assert any("custom_messages" in q for q in queries) + + +@pytest.mark.asyncio +async def test_async_pool_caching( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that async pools are cached and reused.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session1 = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="shared-async-instance", + workspace_client=mock_workspace, + ) + + session2 = AsyncMemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="shared-async-instance", + workspace_client=mock_workspace, + ) + + # Mock for initialization of first session + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session 1 + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT + + await session1.get_items() + + # Mock for second session (tables already exist check, but pool should be reused) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session 2 + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT + + await session2.get_items() + + # Both sessions should share the same pool + assert session1._pool is session2._pool \ No newline at end of file From c23e1706cbc9b00cb4aec21e6930ebf3240a05a2 Mon Sep 17 00:00:00 2001 From: Jenny Date: Fri, 16 Jan 2026 10:41:30 -0800 Subject: [PATCH 02/16] update dependencies after dbaibridge release --- integrations/langchain/pyproject.toml | 2 +- integrations/openai/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 10ee6e01c..4fa284b39 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ [project.optional-dependencies] memory = [ "langgraph-checkpoint-postgres>=2.0.5", - "databricks-ai-bridge[memory]>=0.11.0", + "databricks-ai-bridge[memory]>=0.10.0", ] [dependency-groups] diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 492f3810a..52893cc14 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ [project.optional-dependencies] memory = [ - "databricks-ai-bridge[memory]>=0.11.0", + "databricks-ai-bridge[memory]>=0.12.0", ] [dependency-groups] From 91365abcdd4c6101aedbd5005aff8865710a5807 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 10:15:24 -0800 Subject: [PATCH 03/16] handle openai omit sentinel --- .../src/databricks_openai/utils/clients.py | 7 ++++--- .../openai/tests/unit_tests/test_clients.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 99bb9d08c..fca933824 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -18,14 +18,15 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: yield request -def _strip_strict_from_tools(tools: list | None) -> list | None: +def _strip_strict_from_tools(tools) -> list | None: """Remove 'strict' field from tool function definitions. Databricks model endpoints (except GPT) don't support the 'strict' field in tool schemas, but openai-agents SDK v0.6.4+ includes it. """ - if tools is None: - return None + # Handle None or OpenAI's Omit sentinel (non-iterable placeholder) + if tools is None or not isinstance(tools, list): + return tools for tool in tools: if isinstance(tool, dict) and "function" in tool: tool.get("function", {}).pop("strict", None) diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index df44f276e..6f30d5b6a 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -126,6 +126,22 @@ def test_strip_strict_from_tools_handles_none(self): assert _strip_strict_from_tools(None) is None + def test_strip_strict_from_tools_handles_non_list(self): + """Test that non-list values (like OpenAI's Omit sentinel) are passed through.""" + from databricks_openai.utils.clients import _strip_strict_from_tools + + # Simulate OpenAI's Omit sentinel (a non-iterable placeholder) + class Omit: + pass + + omit_sentinel = Omit() + result = _strip_strict_from_tools(omit_sentinel) + assert result is omit_sentinel # Should return unchanged + + # Also test with other non-list types + assert _strip_strict_from_tools("not a list") == "not a list" + assert _strip_strict_from_tools(123) == 123 + def test_strip_strict_from_tools_handles_empty_list(self): from databricks_openai.utils.clients import _strip_strict_from_tools From 3a04fde09e125da46a90420c94211da96b19cfb7 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 10:38:24 -0800 Subject: [PATCH 04/16] type checking/format checks --- .../openai/tests/unit_tests/test_session.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 26f2035d1..dede5e8fd 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING, Any, cast from unittest.mock import MagicMock from uuid import UUID @@ -11,7 +12,11 @@ pytest.importorskip("agents.memory.session") from databricks_ai_bridge import lakebase -from psycopg import sql + +if TYPE_CHECKING: + from psycopg import sql +else: + from psycopg import sql from databricks_openai.agents.session import ( AsyncMemorySession, @@ -412,7 +417,7 @@ async def test_get_items_returns_parsed_json( workspace_client=mock_workspace, ) - items = await session.get_items() + items = cast(list[dict[str, Any]], await session.get_items()) assert len(items) == 2 assert items[0]["role"] == "user" @@ -466,12 +471,12 @@ async def test_add_items(monkeypatch, mock_workspace, mock_pool, mock_connection workspace_client=mock_workspace, ) - test_items = [ + test_items: list[dict[str, Any]] = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}, ] - await session.add_items(test_items) + await session.add_items(cast(Any, test_items)) # Check that executemany was called on cursor assert len(mock_connection._cursor.executed_queries) > 0 @@ -526,7 +531,7 @@ async def test_pop_item_returns_last_item(monkeypatch, mock_workspace, mock_pool workspace_client=mock_workspace, ) - popped = await session.pop_item() + popped = cast(dict[str, Any], await session.pop_item()) assert popped is not None assert popped["role"] == "assistant" @@ -662,7 +667,7 @@ async def test_get_items_handles_dict_message_data( workspace_client=mock_workspace, ) - items = await session.get_items() + items = cast(list[dict[str, Any]], await session.get_items()) assert len(items) == 1 assert items[0]["role"] == "user" @@ -840,7 +845,7 @@ async def test_async_get_items_returns_parsed_json( mock_async_connection.queue_result(MockAsyncResult()) # INSERT session mock_async_connection.queue_result(MockAsyncResult(rows=test_messages)) # SELECT messages - items = await session.get_items() + items = cast(list[dict[str, Any]], await session.get_items()) assert len(items) == 2 assert items[0]["role"] == "user" @@ -864,12 +869,12 @@ async def test_async_add_items(monkeypatch, mock_workspace, mock_async_pool, moc mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) mock_async_connection.queue_result(MockAsyncResult()) # INSERT session - test_items = [ + test_items: list[dict[str, Any]] = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}, ] - await session.add_items(test_items) + await session.add_items(cast(Any, test_items)) # Check that executemany was called on cursor assert len(mock_async_connection._cursor.executed_queries) > 0 @@ -902,7 +907,7 @@ async def test_async_pop_item_returns_last_item( ) mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp - popped = await session.pop_item() + popped = cast(dict[str, Any], await session.pop_item()) assert popped is not None assert popped["role"] == "assistant" @@ -1032,4 +1037,4 @@ async def test_async_pool_caching( await session2.get_items() # Both sessions should share the same pool - assert session1._pool is session2._pool \ No newline at end of file + assert session1._pool is session2._pool From 22653a40e8d159061b0ebe07865f73b56f490e03 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 11:53:01 -0800 Subject: [PATCH 05/16] add memory subpackage dependency in dev --- integrations/openai/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 52893cc14..5d576d6bd 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -31,6 +31,8 @@ dev = [ "databricks-sdk>=0.34.0", "ruff==0.14.10", "ty>=0.0.11", + # memory subpackage dependencies + "databricks-ai-bridge[memory]>=0.12.0", { include-group = "tests" }, ] From 3269b3c0dbf93d4dbb8c28d3a2986eb8010a420a Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 17:36:36 -0800 Subject: [PATCH 06/16] remove memory subpackage in pyproject --- integrations/openai/pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 5d576d6bd..52893cc14 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -31,8 +31,6 @@ dev = [ "databricks-sdk>=0.34.0", "ruff==0.14.10", "ty>=0.0.11", - # memory subpackage dependencies - "databricks-ai-bridge[memory]>=0.12.0", { include-group = "tests" }, ] From 1dea2e08e5b9f702e525550820a6aa71f13a41f2 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 18:02:58 -0800 Subject: [PATCH 07/16] allow lakebasepool to resolve hostname --- src/databricks_ai_bridge/lakebase.py | 90 +++++++++++++---- tests/databricks_ai_bridge/test_lakebase.py | 101 ++++++++++++++++++++ 2 files changed, 173 insertions(+), 18 deletions(-) diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 7ebc036cd..11a43ac39 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -50,34 +50,88 @@ def __init__( token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, ) -> None: self.workspace_client: WorkspaceClient = workspace_client or WorkspaceClient() - self.instance_name: str = instance_name self.token_cache_duration_seconds: int = token_cache_duration_seconds - # Resolve host from the Lakebase name - try: - instance = self.workspace_client.database.get_database_instance(instance_name) - except Exception as exc: - raise ValueError( - f"Unable to resolve Lakebase instance '{instance_name}'. " - "Ensure the instance name is correct." - ) from exc + # If input is hostname (e.g., from Databricks Apps valueFrom resolution) + # resolve to lakebase name + if self._is_hostname(instance_name): + # Input is a hostname - resolve to instance name + self.instance_name, self.host = self._resolve_from_hostname(instance_name) + else: + # Input is an instance name + self.instance_name = instance_name + try: + instance = self.workspace_client.database.get_database_instance(instance_name) + except Exception as exc: + raise ValueError( + f"Unable to resolve Lakebase instance '{instance_name}'. " + "Ensure the instance name is correct." + ) from exc + + resolved_host = getattr(instance, "read_write_dns", None) or getattr( + instance, "read_only_dns", None + ) - resolved_host = getattr(instance, "read_write_dns", None) or getattr( - instance, "read_only_dns", None - ) + if not resolved_host: + raise ValueError( + f"Lakebase host not found for instance '{instance_name}'. " + "Ensure the instance is running and in AVAILABLE state." + ) - if not resolved_host: - raise ValueError( - f"Lakebase host not found for instance '{instance_name}'. " - "Ensure the instance is running and in AVAILABLE state." - ) + self.host = resolved_host - self.host: str = resolved_host self.username: str = self._infer_username() self._cached_token: str | None = None self._cache_ts: float | None = None + @staticmethod + def _is_hostname(value: str) -> bool: + """Check if the value looks like a Lakebase hostname rather than an instance name.""" + # Hostname pattern: instance-{uuid}.database.{env}.cloud.databricks.com + # or similar patterns containing ".database." and ending with a domain + return ".database." in value and (value.endswith(".com") or value.endswith(".net")) + + def _resolve_from_hostname(self, hostname: str) -> tuple[str, str]: + """ + Resolve instance name from a hostname by listing database instances. + + Args: + hostname: The database hostname (e.g., from Databricks Apps valueFrom: "database") + + Returns: + Tuple of (instance_name, host) + + Raises: + ValueError: If no matching instance is found + """ + try: + instances = list(self.workspace_client.database.list_database_instances()) + except Exception as exc: + raise ValueError( + f"Unable to list database instances to resolve hostname '{hostname}'. " + "Ensure you have access to database instances." + ) from exc + + # Find the instance that matches this hostname + for instance in instances: + rw_dns = getattr(instance, "read_write_dns", None) + ro_dns = getattr(instance, "read_only_dns", None) + + if hostname in (rw_dns, ro_dns): + instance_name = getattr(instance, "name", None) + if not instance_name: + raise ValueError( + f"Found matching instance for hostname '{hostname}' " + "but instance name is not available." + ) + return instance_name, hostname + + raise ValueError( + f"Unable to find database instance matching hostname '{hostname}'. " + "Ensure the hostname is correct and the instance exists." + ) + def _get_cached_token(self) -> str | None: """Check if the cached token is still valid.""" if not self._cached_token or not self._cache_ts: diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index 72835aa95..5a8d042fd 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -410,3 +410,104 @@ async def test_async_lakebase_pool_open_close_methods(monkeypatch): await pool.close() assert pool.pool._closed + + +# ============================================================================= +# Hostname Resolution Tests +# ============================================================================= + + +def test_is_hostname_detects_database_hostname(): + """Test that _is_hostname correctly identifies database hostnames.""" + from databricks_ai_bridge.lakebase import _LakebasePoolBase + + # Should be detected as hostnames + assert _LakebasePoolBase._is_hostname( + "instance-f757b615-f2fd-4614-87cc-9ba35f2eeb61.database.staging.cloud.databricks.com" + ) + assert _LakebasePoolBase._is_hostname("instance-abc123.database.prod.cloud.databricks.com") + assert _LakebasePoolBase._is_hostname("my-db.database.example.net") + + # Should NOT be detected as hostnames (regular instance names) + assert not _LakebasePoolBase._is_hostname("lakebase") + assert not _LakebasePoolBase._is_hostname("my-database-instance") + assert not _LakebasePoolBase._is_hostname("production_db") + + +def test_lakebase_pool_accepts_hostname(monkeypatch): + """Test that LakebasePool accepts hostname and resolves instance name.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace() + + # Mock list_database_instances to return an instance matching the hostname + hostname = "instance-abc123.database.staging.cloud.databricks.com" + mock_instance = MagicMock() + mock_instance.name = "my-lakebase-instance" + mock_instance.read_write_dns = hostname + mock_instance.read_only_dns = None + workspace.database.list_database_instances.return_value = [mock_instance] + + pool = LakebasePool( + instance_name=hostname, # Pass hostname instead of instance name + workspace_client=workspace, + ) + + # Should have resolved to the instance name + assert pool.instance_name == "my-lakebase-instance" + assert pool.host == hostname + + # get_database_instance should NOT have been called (we used list instead) + workspace.database.get_database_instance.assert_not_called() + + +def test_lakebase_pool_hostname_not_found_raises_error(monkeypatch): + """Test that LakebasePool raises error when hostname doesn't match any instance.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace() + + # Mock list_database_instances to return instances that don't match + other_instance = MagicMock() + other_instance.name = "other-instance" + other_instance.read_write_dns = "other-host.database.staging.cloud.databricks.com" + other_instance.read_only_dns = None + workspace.database.list_database_instances.return_value = [other_instance] + + hostname = "instance-not-found.database.staging.cloud.databricks.com" + + with pytest.raises(ValueError, match="Unable to find database instance matching hostname"): + LakebasePool( + instance_name=hostname, + workspace_client=workspace, + ) + + +@pytest.mark.asyncio +async def test_async_lakebase_pool_accepts_hostname(monkeypatch): + """Test that AsyncLakebasePool accepts hostname and resolves instance name.""" + TestAsyncConnectionPool = _make_async_connection_pool_class() + monkeypatch.setattr( + "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool + ) + + workspace = _make_workspace() + + # Mock list_database_instances to return an instance matching the hostname + hostname = "instance-xyz789.database.prod.cloud.databricks.com" + mock_instance = MagicMock() + mock_instance.name = "prod-lakebase" + mock_instance.read_write_dns = hostname + mock_instance.read_only_dns = None + workspace.database.list_database_instances.return_value = [mock_instance] + + pool = AsyncLakebasePool( + instance_name=hostname, # Pass hostname instead of instance name + workspace_client=workspace, + ) + + # Should have resolved to the instance name + assert pool.instance_name == "prod-lakebase" + assert pool.host == hostname From 3894ec0ff198dfec276f13e34f3e39da128d48be Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 14 Jan 2026 19:55:15 -0800 Subject: [PATCH 08/16] memorysession and asyncmemorysession --- integrations/langchain/pyproject.toml | 2 +- integrations/openai/pyproject.toml | 5 + .../src/databricks_openai/agents/session.py | 647 +++++++++++ .../openai/tests/unit_tests/test_session.py | 1035 +++++++++++++++++ 4 files changed, 1688 insertions(+), 1 deletion(-) create mode 100644 integrations/openai/src/databricks_openai/agents/session.py create mode 100644 integrations/openai/tests/unit_tests/test_session.py diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 4fa284b39..10ee6e01c 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ [project.optional-dependencies] memory = [ "langgraph-checkpoint-postgres>=2.0.5", - "databricks-ai-bridge[memory]>=0.10.0", + "databricks-ai-bridge[memory]>=0.11.0", ] [dependency-groups] diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index b3eea7f83..492f3810a 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -20,6 +20,11 @@ dependencies = [ "openai-agents>=0.5.0" ] +[project.optional-dependencies] +memory = [ + "databricks-ai-bridge[memory]>=0.11.0", +] + [dependency-groups] dev = [ "typing_extensions>=4.15.0", diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py new file mode 100644 index 000000000..3c6e7549b --- /dev/null +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -0,0 +1,647 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime, timezone +from threading import Lock +from typing import Any, Dict, Optional, Tuple, Union, cast +from uuid import UUID + +from databricks.sdk import WorkspaceClient + +try: + from agents.items import TResponseInputItem + from agents.memory.session import SessionABC + from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool + from psycopg import sql + from psycopg.sql import Composed +except ImportError as e: + raise ImportError( + "MemorySession requires databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) from e + +logger = logging.getLogger(__name__) + +# Module-level pool cache: instance_name -> LakebasePool +_pool_cache: Dict[str, LakebasePool] = {} +_pool_cache_lock = Lock() + +# Module-level async pool cache: instance_name -> AsyncLakebasePool +_async_pool_cache: Dict[str, AsyncLakebasePool] = {} +_async_pool_cache_lock = asyncio.Lock() + + +def _get_or_create_pool( + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + **pool_kwargs, +) -> LakebasePool: + """Get cached pool or create new one for this instance.""" + cache_key = instance_name + + with _pool_cache_lock: + if cache_key not in _pool_cache: + logger.info(f"Creating new LakebasePool for {cache_key}") + _pool_cache[cache_key] = LakebasePool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + return _pool_cache[cache_key] + + +async def _get_or_create_async_pool( + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + **pool_kwargs, +) -> AsyncLakebasePool: + """Get cached async pool or create new one for this instance.""" + cache_key = instance_name + + async with _async_pool_cache_lock: + if cache_key not in _async_pool_cache: + logger.info(f"Creating new AsyncLakebasePool for {cache_key}") + pool = AsyncLakebasePool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + await pool.open() + _async_pool_cache[cache_key] = pool + return _async_pool_cache[cache_key] + + +class _MemorySessionBase(SessionABC): + """ + Base class with shared SQL, configuration, and helper methods for memory sessions. + + Subclasses implement sync or async pool initialization and database operations. + """ + + # Table names + SESSIONS_TABLE = "agent_sessions" + MESSAGES_TABLE = "agent_messages" + + CREATE_SESSIONS_TABLE_SQL = """ + CREATE TABLE IF NOT EXISTS {sessions_table} ( + session_id UUID PRIMARY KEY, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + ); + """ + + CREATE_MESSAGES_TABLE_SQL = """ + CREATE TABLE IF NOT EXISTS {messages_table} ( + id BIGSERIAL PRIMARY KEY, + session_id UUID NOT NULL REFERENCES {sessions_table}(session_id) ON DELETE CASCADE, + message_data JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS {idx_session_id} + ON {messages_table}(session_id); + CREATE INDEX IF NOT EXISTS {idx_session_order} + ON {messages_table}(session_id, id); + """ + + session_id: UUID + + def __init__( + self, + session_id: UUID, + *, + sessions_table: str = SESSIONS_TABLE, + messages_table: str = MESSAGES_TABLE, + ) -> None: + """ + Initialize base session attributes. + + Args: + session_id: UUID identifier for this conversation session. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + """ + self.session_id = session_id + self.sessions_table = sessions_table + self.messages_table = messages_table + + # --- SQL Building Helpers --- + + def _build_create_sessions_sql(self) -> Composed: + """Build SQL to create the sessions table.""" + return sql.SQL(self.CREATE_SESSIONS_TABLE_SQL).format( + sessions_table=sql.Identifier(self.sessions_table) + ) + + def _build_create_messages_sql(self) -> Composed: + """Build SQL to create the messages table.""" + return sql.SQL(self.CREATE_MESSAGES_TABLE_SQL).format( + sessions_table=sql.Identifier(self.sessions_table), + messages_table=sql.Identifier(self.messages_table), + idx_session_id=sql.Identifier(f"idx_{self.messages_table}_session_id"), + idx_session_order=sql.Identifier(f"idx_{self.messages_table}_session_order"), + ) + + def _build_ensure_session_sql(self) -> Composed: + """Build SQL to insert session record if not exists.""" + return sql.SQL( + """ + INSERT INTO {} (session_id, created_at, updated_at) + VALUES (%s, %s, %s) + ON CONFLICT (session_id) DO NOTHING + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_get_items_query( + self, limit: int | None + ) -> Tuple[Composed, Tuple[UUID, ...] | Tuple[UUID, int]]: + """Build SQL query and params to get items.""" + if limit is not None: + query = sql.SQL( + """ + SELECT message_data FROM ( + SELECT message_data, id + FROM {} + WHERE session_id = %s + ORDER BY id DESC + LIMIT %s + ) sub + ORDER BY id ASC + """ + ).format(sql.Identifier(self.messages_table)) + params: Tuple[UUID, ...] | Tuple[UUID, int] = (self.session_id, limit) + else: + query = sql.SQL( + """ + SELECT message_data + FROM {} + WHERE session_id = %s + ORDER BY id ASC + """ + ).format(sql.Identifier(self.messages_table)) + params = (self.session_id,) + return query, params + + def _build_add_items_sql(self) -> Composed: + """Build SQL to insert message items.""" + return sql.SQL( + """ + INSERT INTO {} (session_id, message_data) + VALUES (%s, %s) + """ + ).format(sql.Identifier(self.messages_table)) + + def _build_update_session_timestamp_sql(self) -> Composed: + """Build SQL to update session timestamp.""" + return sql.SQL( + """ + UPDATE {} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = %s + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_update_session_timestamp_with_value_sql(self) -> Composed: + """Build SQL to update session timestamp with explicit value.""" + return sql.SQL( + """ + UPDATE {} + SET updated_at = %s + WHERE session_id = %s + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_pop_item_sql(self) -> Composed: + """Build SQL to delete and return most recent item.""" + messages_table_id = sql.Identifier(self.messages_table) + return sql.SQL( + """ + DELETE FROM {messages_table} + WHERE id = ( + SELECT id + FROM {messages_table} + WHERE session_id = %s + ORDER BY id DESC + LIMIT 1 + ) + RETURNING message_data + """ + ).format(messages_table=messages_table_id) + + def _build_clear_session_sql(self) -> Composed: + """Build SQL to delete all messages for session.""" + return sql.SQL("DELETE FROM {} WHERE session_id = %s").format( + sql.Identifier(self.messages_table) + ) + + def _prepare_items_for_insert(self, items: list[TResponseInputItem]) -> list[Tuple[UUID, str]]: + """Prepare items for database insertion.""" + return [(self.session_id, json.dumps(item)) for item in items] + + def _parse_message_data(self, message_data: Union[str, dict[str, Any]]) -> TResponseInputItem: + """Parse message_data from database (may be JSON string or dict).""" + if isinstance(message_data, str): + return cast(TResponseInputItem, json.loads(message_data)) + return cast(TResponseInputItem, message_data) + + def _parse_rows_to_items(self, rows: list) -> list[TResponseInputItem]: + """Parse database rows to list of items.""" + return [self._parse_message_data(row["message_data"]) for row in rows] + + +class MemorySession(_MemorySessionBase): + """ + OpenAI Agents SDK Session implementation using Lakebase for persistent storage. + + This class follows the Session protocol for conversation memory, + storing session data in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + SessionABC: https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.SessionABC + + Example: + ```python + from uuid import UUID + from databricks_openai.agents.session import MemorySession + from agents import Agent, Runner + + async def run_agent(thread_id: UUID | None, message: str): + # Use uuid7 for time-ordered UUIDs (better for database indexing) + session_id = thread_id + session = MemorySession( + session_id=session_id, + instance_name="my-lakebase-instance" + ) + agent = Agent(name="Assistant") + return await Runner.run(agent, message, session=session) + ``` + """ + + def __init__( + self, + session_id: UUID, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + sessions_table: str = _MemorySessionBase.SESSIONS_TABLE, + messages_table: str = _MemorySessionBase.MESSAGES_TABLE, + **pool_kwargs, + ) -> None: + """ + Initialize a MemorySession. + + On first initialization for a given Lakebase instance, this will automatically + create the required tables if they don't exist. + + Args: + session_id: UUID identifier for this conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + **pool_kwargs: Additional arguments passed to LakebasePool. + """ + super().__init__( + session_id=session_id, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + self._pool = _get_or_create_pool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + + if not self._tables_exist(): + self._create_tables() + + self._ensure_session() + + def _tables_exist(self) -> bool: + """Check if both session tables already exist.""" + with self._pool.connection() as conn: + result = conn.execute( + """ + SELECT COUNT(*) as cnt FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name IN (%s, %s) + """, + (self.sessions_table, self.messages_table), + ) + row = result.fetchone() + return row["cnt"] == 2 + + def _create_tables(self) -> None: + """Create the required tables.""" + with self._pool.connection() as conn: + conn.execute(self._build_create_sessions_sql()) + conn.execute(self._build_create_messages_sql()) + logger.info(f"Created tables {self.sessions_table}, {self.messages_table}") + + def _ensure_session(self) -> None: + """Ensure the session record exists in agent_sessions table.""" + now = datetime.now(timezone.utc) + with self._pool.connection() as conn: + conn.execute( + self._build_ensure_session_sql(), + (self.session_id, now, now), + ) + logger.debug(f"Ensured session {self.session_id} exists") + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """ + Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history. + """ + query, params = self._build_get_items_query(limit) + with self._pool.connection() as conn: + result = conn.execute(query, params) + rows = result.fetchall() + return self._parse_rows_to_items(rows) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """ + Add new items to the conversation history. + + Args: + items: List of input items to add to the history. + """ + if not items: + return + + with self._pool.connection() as conn: + with conn.cursor() as cur: + cur.executemany( + self._build_add_items_sql(), + self._prepare_items_for_insert(items), + ) + conn.execute( + self._build_update_session_timestamp_sql(), + (self.session_id,), + ) + logger.debug(f"Added {len(items)} items to session {self.session_id}") + + async def pop_item(self) -> TResponseInputItem | None: + """ + Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty. + """ + with self._pool.connection() as conn: + result = conn.execute( + self._build_pop_item_sql(), + (self.session_id,), + ) + row = result.fetchone() + + if row: + now = datetime.now(timezone.utc) + conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + + if row: + logger.debug(f"Popped item from session {self.session_id}") + return self._parse_message_data(row["message_data"]) + return None + + async def clear_session(self) -> None: + """Clear all items for this session.""" + with self._pool.connection() as conn: + result = conn.execute( + self._build_clear_session_sql(), + (self.session_id,), + ) + count = result.rowcount + + now = datetime.now(timezone.utc) + conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + logger.info(f"Cleared {count} items from session {self.session_id}") + + +class AsyncMemorySession(_MemorySessionBase): + """ + OpenAI Agents SDK Session implementation using Lakebase for persistent storage (async version). + + This class follows the Session protocol for conversation memory, + storing session data in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + SessionABC: https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.SessionABC + + Example: + ```python + from uuid import UUID + from databricks_openai.agents.session import AsyncMemorySession + from agents import Agent, Runner + + async def run_agent(thread_id: UUID | None, message: str): + session_id = thread_id + session = AsyncMemorySession( + session_id=session_id, + instance_name="my-lakebase-instance" + ) + agent = Agent(name="Assistant") + return await Runner.run(agent, message, session=session) + ``` + """ + + def __init__( + self, + session_id: UUID, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + sessions_table: str = _MemorySessionBase.SESSIONS_TABLE, + messages_table: str = _MemorySessionBase.MESSAGES_TABLE, + **pool_kwargs, + ) -> None: + """ + Initialize an AsyncMemorySession. + + Note: The async pool and tables are initialized lazily on first use. + + Args: + session_id: UUID identifier for this conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + **pool_kwargs: Additional arguments passed to AsyncLakebasePool. + """ + super().__init__( + session_id=session_id, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + self._instance_name = instance_name + self._workspace_client = workspace_client + self._pool_kwargs = pool_kwargs + + self._pool: Optional[AsyncLakebasePool] = None + self._initialized = False + self._init_lock = asyncio.Lock() + + async def _ensure_initialized(self) -> None: + """Ensure the pool is created and tables exist (lazy initialization).""" + if self._initialized: + return + + async with self._init_lock: + if self._initialized: + return + + self._pool = await _get_or_create_async_pool( + instance_name=self._instance_name, + workspace_client=self._workspace_client, + **self._pool_kwargs, + ) + + if not await self._tables_exist(): + await self._create_tables() + + await self._ensure_session() + self._initialized = True + + async def _tables_exist(self) -> bool: + """Check if both session tables already exist.""" + assert self._pool is not None + async with self._pool.connection() as conn: + result = await conn.execute( + """ + SELECT COUNT(*) as cnt FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name IN (%s, %s) + """, + (self.sessions_table, self.messages_table), + ) + row = await result.fetchone() + return row["cnt"] == 2 + + async def _create_tables(self) -> None: + """Create the required tables.""" + assert self._pool is not None + async with self._pool.connection() as conn: + await conn.execute(self._build_create_sessions_sql()) + await conn.execute(self._build_create_messages_sql()) + logger.info(f"Created tables {self.sessions_table}, {self.messages_table}") + + async def _ensure_session(self) -> None: + """Ensure the session record exists in agent_sessions table.""" + assert self._pool is not None + now = datetime.now(timezone.utc) + async with self._pool.connection() as conn: + await conn.execute( + self._build_ensure_session_sql(), + (self.session_id, now, now), + ) + logger.debug(f"Ensured session {self.session_id} exists") + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """ + Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history. + """ + await self._ensure_initialized() + assert self._pool is not None + + query, params = self._build_get_items_query(limit) + async with self._pool.connection() as conn: + result = await conn.execute(query, params) + rows = await result.fetchall() + return self._parse_rows_to_items(rows) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """ + Add new items to the conversation history. + + Args: + items: List of input items to add to the history. + """ + if not items: + return + + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + async with conn.cursor() as cur: + await cur.executemany( + self._build_add_items_sql(), + self._prepare_items_for_insert(items), + ) + await conn.execute( + self._build_update_session_timestamp_sql(), + (self.session_id,), + ) + logger.debug(f"Added {len(items)} items to session {self.session_id}") + + async def pop_item(self) -> TResponseInputItem | None: + """ + Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty. + """ + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + result = await conn.execute( + self._build_pop_item_sql(), + (self.session_id,), + ) + row = await result.fetchone() + + if row: + now = datetime.now(timezone.utc) + await conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + + if row: + logger.debug(f"Popped item from session {self.session_id}") + return self._parse_message_data(row["message_data"]) + return None + + async def clear_session(self) -> None: + """Clear all items for this session.""" + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + result = await conn.execute( + self._build_clear_session_sql(), + (self.session_id,), + ) + count = result.rowcount + + now = datetime.now(timezone.utc) + await conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + logger.info(f"Cleared {count} items from session {self.session_id}") diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py new file mode 100644 index 000000000..26f2035d1 --- /dev/null +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -0,0 +1,1035 @@ +from __future__ import annotations + +import json +from unittest.mock import MagicMock +from uuid import UUID + +import pytest + +pytest.importorskip("psycopg") +pytest.importorskip("psycopg_pool") +pytest.importorskip("agents.memory.session") + +from databricks_ai_bridge import lakebase +from psycopg import sql + +from databricks_openai.agents.session import ( + AsyncMemorySession, + MemorySession, + _async_pool_cache, + _pool_cache, +) + +# Use UUID (V7) for performance +TEST_SESSION_ID = UUID("12345678-1234-5678-1234-567812345678") +TEST_SESSION_ID_2 = UUID("22345678-1234-5678-1234-567812345678") +TEST_SESSION_ID_3 = UUID("32345678-1234-5678-1234-567812345678") + + +def query_to_string(query): + """Convert a query (string or sql.Composed) to a string for testing.""" + if isinstance(query, str): + return query + if isinstance(query, (sql.Composed, sql.SQL, sql.Identifier)): + return query.as_string(None) + return str(query) + + +class MockCursor: + """Mock cursor for executemany operations.""" + + def __init__(self): + self.executed_queries = [] + + def executemany(self, query, params): + self.executed_queries.append((query, params)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + +class MockAsyncCursor: + """Mock async cursor for executemany operations.""" + + def __init__(self): + self.executed_queries = [] + + async def executemany(self, query, params): + self.executed_queries.append((query, params)) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + pass + + +class MockResult: + """Mock result object for database queries.""" + + def __init__(self, rows=None, rowcount=0): + self._rows = rows or [] + self.rowcount = rowcount + self._index = 0 + + def fetchall(self): + return self._rows + + def fetchone(self): + if self._index < len(self._rows): + row = self._rows[self._index] + self._index += 1 + return row + return None + + +class MockAsyncResult: + """Mock async result object for database queries.""" + + def __init__(self, rows=None, rowcount=0): + self._rows = rows or [] + self.rowcount = rowcount + self._index = 0 + + async def fetchall(self): + return self._rows + + async def fetchone(self): + if self._index < len(self._rows): + row = self._rows[self._index] + self._index += 1 + return row + return None + + +class MockConnection: + """Mock database connection.""" + + def __init__(self): + self.executed_queries = [] + self._cursor = MockCursor() + self._next_result = MockResult() + self._results_queue = [] + + def execute(self, query, params=None): + self.executed_queries.append((query, params)) + if self._results_queue: + return self._results_queue.pop(0) + return self._next_result + + def cursor(self): + return self._cursor + + def set_next_result(self, result): + self._next_result = result + + def queue_result(self, result): + self._results_queue.append(result) + + +class MockAsyncConnection: + """Mock async database connection.""" + + def __init__(self): + self.executed_queries = [] + self._cursor = MockAsyncCursor() + self._next_result = MockAsyncResult() + self._results_queue = [] + + async def execute(self, query, params=None): + self.executed_queries.append((query, params)) + if self._results_queue: + return self._results_queue.pop(0) + return self._next_result + + def cursor(self): + return self._cursor + + def set_next_result(self, result): + self._next_result = result + + def queue_result(self, result): + self._results_queue.append(result) + + +class MockConnectionPool: + """Mock connection pool for testing.""" + + def __init__(self, connection_value=None): + self.connection_value = connection_value or MockConnection() + self.conninfo = "" + + def __call__(self, *, conninfo, connection_class=None, **kwargs): + self.conninfo = conninfo + return self + + def connection(self): + class _Ctx: + def __init__(self, outer): + self.outer = outer + + def __enter__(self): + return self.outer.connection_value + + def __exit__(self, exc_type, exc, tb): + pass + + return _Ctx(self) + + +class MockAsyncConnectionPool: + """Mock async connection pool for testing.""" + + def __init__(self, connection_value=None): + self.connection_value = connection_value or MockAsyncConnection() + self.conninfo = "" + self._opened = False + self._closed = False + + def __call__(self, *, conninfo, connection_class=None, **kwargs): + self.conninfo = conninfo + return self + + async def open(self): + self._opened = True + + async def close(self): + self._closed = True + + def connection(self): + class _AsyncCtx: + def __init__(self, outer): + self.outer = outer + + async def __aenter__(self): + return self.outer.connection_value + + async def __aexit__(self, exc_type, exc, tb): + pass + + return _AsyncCtx(self) + + +@pytest.fixture(autouse=True) +def clear_pool_cache(): + """Clear the pool cache before each test.""" + _pool_cache.clear() + _async_pool_cache.clear() + yield + _pool_cache.clear() + _async_pool_cache.clear() + + +@pytest.fixture +def mock_workspace(): + """Create a mock workspace client.""" + workspace = MagicMock() + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.database.get_database_instance.return_value.read_write_dns = "db-host" + workspace.current_service_principal.me.side_effect = RuntimeError("no sp") + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + return workspace + + +@pytest.fixture +def mock_connection(): + """Create a mock connection.""" + return MockConnection() + + +@pytest.fixture +def mock_async_connection(): + """Create a mock async connection.""" + return MockAsyncConnection() + + +@pytest.fixture +def mock_pool(mock_connection): + """Create a mock connection pool.""" + return MockConnectionPool(connection_value=mock_connection) + + +@pytest.fixture +def mock_async_pool(mock_async_connection): + """Create a mock async connection pool.""" + return MockAsyncConnectionPool(connection_value=mock_async_connection) + + +# ============================================================================= +# MemorySession Tests (Sync) +# ============================================================================= + + +def test_session_configures_lakebase(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that MemorySession correctly configures the Lakebase pool.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables already exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + assert ( + mock_pool.conninfo + == "dbname=databricks_postgres user=test@databricks.com host=db-host port=5432 sslmode=require" + ) + assert session.session_id == TEST_SESSION_ID + assert session.sessions_table == "agent_sessions" + assert session.messages_table == "agent_messages" + + +def test_session_creates_tables_on_init_when_not_exist( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test that MemorySession creates tables when they don't exist.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables don't exist (count=0) + mock_connection.queue_result(MockResult(rows=[{"cnt": 0}])) + # CREATE sessions table + mock_connection.queue_result(MockResult()) + # CREATE messages table + mock_connection.queue_result(MockResult()) + # INSERT session + mock_connection.queue_result(MockResult()) + + MemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Should have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + create_sessions_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_sessions" in q for q in queries + ) + create_messages_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_messages" in q for q in queries + ) + + assert create_sessions_found, "Should create sessions table" + assert create_messages_found, "Should create messages table" + + +def test_session_skips_table_creation_when_tables_exist( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test that MemorySession skips table creation when tables already exist.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: both tables exist (count=2) + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + # INSERT session (no CREATE TABLE calls) + mock_connection.queue_result(MockResult()) + + MemorySession( + session_id=TEST_SESSION_ID_3, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Should NOT have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + create_table_found = any("CREATE TABLE" in q for q in queries) + + assert not create_table_found, "Should not create tables when they already exist" + + +def test_session_ensures_session_record(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that MemorySession ensures the session record exists.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables already exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Find the INSERT INTO agent_sessions query + insert_queries = [ + (q, p) + for q, p in mock_connection.executed_queries + if "INSERT INTO" in query_to_string(q) and "agent_sessions" in query_to_string(q) + ] + + assert len(insert_queries) > 0, "Should insert session record" + query, params = insert_queries[0] + assert params[0] == TEST_SESSION_ID, "Should use correct session_id" + + +@pytest.mark.asyncio +async def test_get_items_empty_session(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test get_items returns empty list for new session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist, then INSERT session, then SELECT messages + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=[])) # SELECT messages + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items() + assert items == [] + + +@pytest.mark.asyncio +async def test_get_items_returns_parsed_json( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test get_items correctly parses JSON data.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + test_messages = [ + {"message_data": json.dumps({"role": "user", "content": "Hello"})}, + {"message_data": json.dumps({"role": "assistant", "content": "Hi there!"})}, + ] + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=test_messages)) # SELECT messages + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items() + + assert len(items) == 2 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "Hi there!" + + +@pytest.mark.asyncio +async def test_get_items_with_limit(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test get_items respects limit parameter.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result( + MockResult(rows=[{"message_data": json.dumps({"role": "user", "content": "Latest"})}]) + ) + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items(limit=1) + + assert len(items) == 1 + + # Verify the query used LIMIT + select_queries = [ + query_to_string(q) + for q, p in mock_connection.executed_queries + if "SELECT message_data" in query_to_string(q) + ] + assert any("LIMIT" in q for q in select_queries), "Should use LIMIT in query" + + +@pytest.mark.asyncio +async def test_add_items(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test add_items inserts messages correctly.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + test_items = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + + await session.add_items(test_items) + + # Check that executemany was called on cursor + assert len(mock_connection._cursor.executed_queries) > 0 + query, params = mock_connection._cursor.executed_queries[-1] + query_str = query_to_string(query) + assert "INSERT INTO" in query_str and "agent_messages" in query_str + assert len(params) == 2 # Two items + + +@pytest.mark.asyncio +async def test_add_items_empty_list(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test add_items handles empty list gracefully.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + initial_query_count = len(mock_connection.executed_queries) + + await session.add_items([]) + + # Should not execute any additional queries for empty list + # (only the queries from init should be present) + assert len(mock_connection.executed_queries) == initial_query_count + + +@pytest.mark.asyncio +async def test_pop_item_returns_last_item(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test pop_item removes and returns the most recent item.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + # DELETE RETURNING result + mock_connection.queue_result( + MockResult( + rows=[{"message_data": json.dumps({"role": "assistant", "content": "Last msg"})}] + ) + ) + mock_connection.queue_result(MockResult()) # UPDATE session timestamp + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + popped = await session.pop_item() + + assert popped is not None + assert popped["role"] == "assistant" + assert popped["content"] == "Last msg" + + +@pytest.mark.asyncio +async def test_pop_item_returns_none_when_empty( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test pop_item returns None for empty session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=[])) # DELETE RETURNING - empty + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + popped = await session.pop_item() + + assert popped is None + + +@pytest.mark.asyncio +async def test_clear_session(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test clear_session deletes all messages for the session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rowcount=5)) # DELETE messages + mock_connection.queue_result(MockResult()) # UPDATE session timestamp + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + await session.clear_session() + + # Find the DELETE query + delete_queries = [ + (q, p) + for q, p in mock_connection.executed_queries + if "DELETE FROM" in query_to_string(q) + and "agent_messages" in query_to_string(q) + and "WHERE session_id" in query_to_string(q) + ] + + assert len(delete_queries) > 0, "Should execute DELETE query" + query, params = delete_queries[0] + assert params == (TEST_SESSION_ID,), "Should use correct session_id" + + +def test_custom_table_names(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that custom table names are used correctly.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables don't exist (custom names), so they will be created + mock_connection.queue_result(MockResult(rows=[{"cnt": 0}])) # tables don't exist + mock_connection.queue_result(MockResult()) # CREATE sessions + mock_connection.queue_result(MockResult()) # CREATE messages + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + assert session.sessions_table == "custom_sessions" + assert session.messages_table == "custom_messages" + + # Check that CREATE TABLE uses custom names + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + assert any("custom_sessions" in q for q in queries) + assert any("custom_messages" in q for q in queries) + + +def test_pool_caching(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that pools are cached and reused.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock for both session creations + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist + mock_connection.queue_result(MockResult()) # INSERT session 1 + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist + mock_connection.queue_result(MockResult()) # INSERT session 2 + + session1 = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="shared-instance", + workspace_client=mock_workspace, + ) + + session2 = MemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="shared-instance", + workspace_client=mock_workspace, + ) + + # Both sessions should share the same pool + assert session1._pool is session2._pool + + +@pytest.mark.asyncio +async def test_get_items_handles_dict_message_data( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test get_items handles message_data that's already a dict (not JSON string).""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Some database drivers return JSONB as dict directly + test_messages = [ + {"message_data": {"role": "user", "content": "Already parsed"}}, + ] + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=test_messages)) + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items() + + assert len(items) == 1 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Already parsed" + + +# ============================================================================= +# AsyncMemorySession Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_async_session_lazy_initialization( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession initializes lazily on first use.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + # Create session - should NOT trigger any DB operations yet + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # No queries should have been executed yet + assert len(mock_async_connection.executed_queries) == 0 + assert not session._initialized + + # Mock: tables exist, then INSERT session, then SELECT messages + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT messages + + # First use triggers initialization + await session.get_items() + + # Now it should be initialized + assert session._initialized + assert mock_async_pool._opened + + +@pytest.mark.asyncio +async def test_async_session_configures_lakebase( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession correctly configures the Lakebase pool.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables already exist + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + # Trigger initialization + await session.get_items() + + assert ( + mock_async_pool.conninfo + == "dbname=databricks_postgres user=test@databricks.com host=db-host port=5432 sslmode=require" + ) + assert session.session_id == TEST_SESSION_ID + assert session.sessions_table == "agent_sessions" + assert session.messages_table == "agent_messages" + + +@pytest.mark.asyncio +async def test_async_session_creates_tables_when_not_exist( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession creates tables when they don't exist.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables don't exist (count=0) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 0}])) + mock_async_connection.queue_result(MockAsyncResult()) # CREATE sessions table + mock_async_connection.queue_result(MockAsyncResult()) # CREATE messages table + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + # Should have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + create_sessions_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_sessions" in q for q in queries + ) + create_messages_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_messages" in q for q in queries + ) + + assert create_sessions_found, "Should create sessions table" + assert create_messages_found, "Should create messages table" + + +@pytest.mark.asyncio +async def test_async_session_skips_table_creation_when_tables_exist( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession skips table creation when tables already exist.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: both tables exist (count=2) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + # Should NOT have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + create_table_found = any("CREATE TABLE" in q for q in queries) + + assert not create_table_found, "Should not create tables when they already exist" + + +@pytest.mark.asyncio +async def test_async_get_items_empty_session( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async get_items returns empty list for new session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT messages + + items = await session.get_items() + assert items == [] + + +@pytest.mark.asyncio +async def test_async_get_items_returns_parsed_json( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async get_items correctly parses JSON data.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + test_messages = [ + {"message_data": json.dumps({"role": "user", "content": "Hello"})}, + {"message_data": json.dumps({"role": "assistant", "content": "Hi there!"})}, + ] + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=test_messages)) # SELECT messages + + items = await session.get_items() + + assert len(items) == 2 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "Hi there!" + + +@pytest.mark.asyncio +async def test_async_add_items(monkeypatch, mock_workspace, mock_async_pool, mock_async_connection): + """Test async add_items inserts messages correctly.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables exist + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + + test_items = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + + await session.add_items(test_items) + + # Check that executemany was called on cursor + assert len(mock_async_connection._cursor.executed_queries) > 0 + query, params = mock_async_connection._cursor.executed_queries[-1] + query_str = query_to_string(query) + assert "INSERT INTO" in query_str and "agent_messages" in query_str + assert len(params) == 2 # Two items + + +@pytest.mark.asyncio +async def test_async_pop_item_returns_last_item( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async pop_item removes and returns the most recent item.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + # DELETE RETURNING result + mock_async_connection.queue_result( + MockAsyncResult( + rows=[{"message_data": json.dumps({"role": "assistant", "content": "Last msg"})}] + ) + ) + mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp + + popped = await session.pop_item() + + assert popped is not None + assert popped["role"] == "assistant" + assert popped["content"] == "Last msg" + + +@pytest.mark.asyncio +async def test_async_pop_item_returns_none_when_empty( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async pop_item returns None for empty session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # DELETE RETURNING - empty + + popped = await session.pop_item() + + assert popped is None + + +@pytest.mark.asyncio +async def test_async_clear_session( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async clear_session deletes all messages for the session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rowcount=5)) # DELETE messages + mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp + + await session.clear_session() + + # Find the DELETE query + delete_queries = [ + (q, p) + for q, p in mock_async_connection.executed_queries + if "DELETE FROM" in query_to_string(q) + and "agent_messages" in query_to_string(q) + and "WHERE session_id" in query_to_string(q) + ] + + assert len(delete_queries) > 0, "Should execute DELETE query" + query, params = delete_queries[0] + assert params == (TEST_SESSION_ID,), "Should use correct session_id" + + +@pytest.mark.asyncio +async def test_async_custom_table_names( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that async session uses custom table names correctly.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + # Mock: tables don't exist (custom names), so they will be created + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 0}])) # tables don't exist + mock_async_connection.queue_result(MockAsyncResult()) # CREATE sessions + mock_async_connection.queue_result(MockAsyncResult()) # CREATE messages + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + assert session.sessions_table == "custom_sessions" + assert session.messages_table == "custom_messages" + + # Check that CREATE TABLE uses custom names + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + assert any("custom_sessions" in q for q in queries) + assert any("custom_messages" in q for q in queries) + + +@pytest.mark.asyncio +async def test_async_pool_caching( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that async pools are cached and reused.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session1 = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="shared-async-instance", + workspace_client=mock_workspace, + ) + + session2 = AsyncMemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="shared-async-instance", + workspace_client=mock_workspace, + ) + + # Mock for initialization of first session + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session 1 + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT + + await session1.get_items() + + # Mock for second session (tables already exist check, but pool should be reused) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session 2 + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT + + await session2.get_items() + + # Both sessions should share the same pool + assert session1._pool is session2._pool \ No newline at end of file From 02f7f188049629d982bf6e00c1b70ca2d3b29bbf Mon Sep 17 00:00:00 2001 From: Jenny Date: Fri, 16 Jan 2026 10:41:30 -0800 Subject: [PATCH 09/16] update dependencies after dbaibridge release --- integrations/langchain/pyproject.toml | 2 +- integrations/openai/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/langchain/pyproject.toml b/integrations/langchain/pyproject.toml index 10ee6e01c..4fa284b39 100644 --- a/integrations/langchain/pyproject.toml +++ b/integrations/langchain/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ [project.optional-dependencies] memory = [ "langgraph-checkpoint-postgres>=2.0.5", - "databricks-ai-bridge[memory]>=0.11.0", + "databricks-ai-bridge[memory]>=0.10.0", ] [dependency-groups] diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 492f3810a..52893cc14 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ [project.optional-dependencies] memory = [ - "databricks-ai-bridge[memory]>=0.11.0", + "databricks-ai-bridge[memory]>=0.12.0", ] [dependency-groups] From c88c3178d596cea65065df10c18693efde645045 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 10:15:24 -0800 Subject: [PATCH 10/16] handle openai omit sentinel --- .../src/databricks_openai/utils/clients.py | 7 ++++--- .../openai/tests/unit_tests/test_clients.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 99bb9d08c..fca933824 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -18,14 +18,15 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: yield request -def _strip_strict_from_tools(tools: list | None) -> list | None: +def _strip_strict_from_tools(tools) -> list | None: """Remove 'strict' field from tool function definitions. Databricks model endpoints (except GPT) don't support the 'strict' field in tool schemas, but openai-agents SDK v0.6.4+ includes it. """ - if tools is None: - return None + # Handle None or OpenAI's Omit sentinel (non-iterable placeholder) + if tools is None or not isinstance(tools, list): + return tools for tool in tools: if isinstance(tool, dict) and "function" in tool: tool.get("function", {}).pop("strict", None) diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index df44f276e..6f30d5b6a 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -126,6 +126,22 @@ def test_strip_strict_from_tools_handles_none(self): assert _strip_strict_from_tools(None) is None + def test_strip_strict_from_tools_handles_non_list(self): + """Test that non-list values (like OpenAI's Omit sentinel) are passed through.""" + from databricks_openai.utils.clients import _strip_strict_from_tools + + # Simulate OpenAI's Omit sentinel (a non-iterable placeholder) + class Omit: + pass + + omit_sentinel = Omit() + result = _strip_strict_from_tools(omit_sentinel) + assert result is omit_sentinel # Should return unchanged + + # Also test with other non-list types + assert _strip_strict_from_tools("not a list") == "not a list" + assert _strip_strict_from_tools(123) == 123 + def test_strip_strict_from_tools_handles_empty_list(self): from databricks_openai.utils.clients import _strip_strict_from_tools From 8af9db38b37d09e05b91f73010f23f21bdcf15d8 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 10:38:24 -0800 Subject: [PATCH 11/16] type checking/format checks --- .../openai/tests/unit_tests/test_session.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 26f2035d1..dede5e8fd 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from typing import TYPE_CHECKING, Any, cast from unittest.mock import MagicMock from uuid import UUID @@ -11,7 +12,11 @@ pytest.importorskip("agents.memory.session") from databricks_ai_bridge import lakebase -from psycopg import sql + +if TYPE_CHECKING: + from psycopg import sql +else: + from psycopg import sql from databricks_openai.agents.session import ( AsyncMemorySession, @@ -412,7 +417,7 @@ async def test_get_items_returns_parsed_json( workspace_client=mock_workspace, ) - items = await session.get_items() + items = cast(list[dict[str, Any]], await session.get_items()) assert len(items) == 2 assert items[0]["role"] == "user" @@ -466,12 +471,12 @@ async def test_add_items(monkeypatch, mock_workspace, mock_pool, mock_connection workspace_client=mock_workspace, ) - test_items = [ + test_items: list[dict[str, Any]] = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}, ] - await session.add_items(test_items) + await session.add_items(cast(Any, test_items)) # Check that executemany was called on cursor assert len(mock_connection._cursor.executed_queries) > 0 @@ -526,7 +531,7 @@ async def test_pop_item_returns_last_item(monkeypatch, mock_workspace, mock_pool workspace_client=mock_workspace, ) - popped = await session.pop_item() + popped = cast(dict[str, Any], await session.pop_item()) assert popped is not None assert popped["role"] == "assistant" @@ -662,7 +667,7 @@ async def test_get_items_handles_dict_message_data( workspace_client=mock_workspace, ) - items = await session.get_items() + items = cast(list[dict[str, Any]], await session.get_items()) assert len(items) == 1 assert items[0]["role"] == "user" @@ -840,7 +845,7 @@ async def test_async_get_items_returns_parsed_json( mock_async_connection.queue_result(MockAsyncResult()) # INSERT session mock_async_connection.queue_result(MockAsyncResult(rows=test_messages)) # SELECT messages - items = await session.get_items() + items = cast(list[dict[str, Any]], await session.get_items()) assert len(items) == 2 assert items[0]["role"] == "user" @@ -864,12 +869,12 @@ async def test_async_add_items(monkeypatch, mock_workspace, mock_async_pool, moc mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) mock_async_connection.queue_result(MockAsyncResult()) # INSERT session - test_items = [ + test_items: list[dict[str, Any]] = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}, ] - await session.add_items(test_items) + await session.add_items(cast(Any, test_items)) # Check that executemany was called on cursor assert len(mock_async_connection._cursor.executed_queries) > 0 @@ -902,7 +907,7 @@ async def test_async_pop_item_returns_last_item( ) mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp - popped = await session.pop_item() + popped = cast(dict[str, Any], await session.pop_item()) assert popped is not None assert popped["role"] == "assistant" @@ -1032,4 +1037,4 @@ async def test_async_pool_caching( await session2.get_items() # Both sessions should share the same pool - assert session1._pool is session2._pool \ No newline at end of file + assert session1._pool is session2._pool From 1ae89352391d1a6ec5101d3605d674927b42f25b Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 11:53:01 -0800 Subject: [PATCH 12/16] add memory subpackage dependency in dev --- integrations/openai/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 52893cc14..5d576d6bd 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -31,6 +31,8 @@ dev = [ "databricks-sdk>=0.34.0", "ruff==0.14.10", "ty>=0.0.11", + # memory subpackage dependencies + "databricks-ai-bridge[memory]>=0.12.0", { include-group = "tests" }, ] From 38a7e6077014aa42fa4d66599be8c6faf57cf8dd Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 17:36:36 -0800 Subject: [PATCH 13/16] remove memory subpackage in pyproject --- integrations/openai/pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 5d576d6bd..52893cc14 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -31,8 +31,6 @@ dev = [ "databricks-sdk>=0.34.0", "ruff==0.14.10", "ty>=0.0.11", - # memory subpackage dependencies - "databricks-ai-bridge[memory]>=0.12.0", { include-group = "tests" }, ] From e4498302c1ef5fd0936e0e7a897f28484f3a6fea Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 21 Jan 2026 18:02:58 -0800 Subject: [PATCH 14/16] allow lakebasepool to resolve hostname --- src/databricks_ai_bridge/lakebase.py | 90 +++++++++++++---- tests/databricks_ai_bridge/test_lakebase.py | 101 ++++++++++++++++++++ 2 files changed, 173 insertions(+), 18 deletions(-) diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 1cb0e83fa..c2649b877 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -97,34 +97,88 @@ def __init__( token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, ) -> None: self.workspace_client: WorkspaceClient = workspace_client or WorkspaceClient() - self.instance_name: str = instance_name self.token_cache_duration_seconds: int = token_cache_duration_seconds - # Resolve host from the Lakebase name - try: - instance = self.workspace_client.database.get_database_instance(instance_name) - except Exception as exc: - raise ValueError( - f"Unable to resolve Lakebase instance '{instance_name}'. " - "Ensure the instance name is correct." - ) from exc - - resolved_host = getattr(instance, "read_write_dns", None) or getattr( - instance, "read_only_dns", None - ) + # If input is hostname (e.g., from Databricks Apps valueFrom resolution) + # resolve to lakebase name + if self._is_hostname(instance_name): + # Input is a hostname - resolve to instance name + self.instance_name, self.host = self._resolve_from_hostname(instance_name) + else: + # Input is an instance name + self.instance_name = instance_name + try: + instance = self.workspace_client.database.get_database_instance(instance_name) + except Exception as exc: + raise ValueError( + f"Unable to resolve Lakebase instance '{instance_name}'. " + "Ensure the instance name is correct." + ) from exc - if not resolved_host: - raise ValueError( - f"Lakebase host not found for instance '{instance_name}'. " - "Ensure the instance is running and in AVAILABLE state." + resolved_host = getattr(instance, "read_write_dns", None) or getattr( + instance, "read_only_dns", None ) - self.host: str = resolved_host + if not resolved_host: + raise ValueError( + f"Lakebase host not found for instance '{instance_name}'. " + "Ensure the instance is running and in AVAILABLE state." + ) + + self.host = resolved_host + self.username: str = self._infer_username() self._cached_token: str | None = None self._cache_ts: float | None = None + @staticmethod + def _is_hostname(value: str) -> bool: + """Check if the value looks like a Lakebase hostname rather than an instance name.""" + # Hostname pattern: instance-{uuid}.database.{env}.cloud.databricks.com + # or similar patterns containing ".database." and ending with a domain + return ".database." in value and (value.endswith(".com") or value.endswith(".net")) + + def _resolve_from_hostname(self, hostname: str) -> tuple[str, str]: + """ + Resolve instance name from a hostname by listing database instances. + + Args: + hostname: The database hostname (e.g., from Databricks Apps valueFrom: "database") + + Returns: + Tuple of (instance_name, host) + + Raises: + ValueError: If no matching instance is found + """ + try: + instances = list(self.workspace_client.database.list_database_instances()) + except Exception as exc: + raise ValueError( + f"Unable to list database instances to resolve hostname '{hostname}'. " + "Ensure you have access to database instances." + ) from exc + + # Find the instance that matches this hostname + for instance in instances: + rw_dns = getattr(instance, "read_write_dns", None) + ro_dns = getattr(instance, "read_only_dns", None) + + if hostname in (rw_dns, ro_dns): + instance_name = getattr(instance, "name", None) + if not instance_name: + raise ValueError( + f"Found matching instance for hostname '{hostname}' " + "but instance name is not available." + ) + return instance_name, hostname + + raise ValueError( + f"Unable to find database instance matching hostname '{hostname}'. " + "Ensure the hostname is correct and the instance exists." + ) + def _get_cached_token(self) -> str | None: """Check if the cached token is still valid.""" if not self._cached_token or not self._cache_ts: diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index b914f7217..a831f8f9f 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -1056,3 +1056,104 @@ def test_execute_grant_handles_insufficient_privilege(self): error_msg = str(exc_info.value) assert "Insufficient privileges" in error_msg assert "CAN MANAGE" in error_msg + + +# ============================================================================= +# Hostname Resolution Tests +# ============================================================================= + + +def test_is_hostname_detects_database_hostname(): + """Test that _is_hostname correctly identifies database hostnames.""" + from databricks_ai_bridge.lakebase import _LakebasePoolBase + + # Should be detected as hostnames + assert _LakebasePoolBase._is_hostname( + "instance-f757b615-f2fd-4614-87cc-9ba35f2eeb61.database.staging.cloud.databricks.com" + ) + assert _LakebasePoolBase._is_hostname("instance-abc123.database.prod.cloud.databricks.com") + assert _LakebasePoolBase._is_hostname("my-db.database.example.net") + + # Should NOT be detected as hostnames (regular instance names) + assert not _LakebasePoolBase._is_hostname("lakebase") + assert not _LakebasePoolBase._is_hostname("my-database-instance") + assert not _LakebasePoolBase._is_hostname("production_db") + + +def test_lakebase_pool_accepts_hostname(monkeypatch): + """Test that LakebasePool accepts hostname and resolves instance name.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace() + + # Mock list_database_instances to return an instance matching the hostname + hostname = "instance-abc123.database.staging.cloud.databricks.com" + mock_instance = MagicMock() + mock_instance.name = "my-lakebase-instance" + mock_instance.read_write_dns = hostname + mock_instance.read_only_dns = None + workspace.database.list_database_instances.return_value = [mock_instance] + + pool = LakebasePool( + instance_name=hostname, # Pass hostname instead of instance name + workspace_client=workspace, + ) + + # Should have resolved to the instance name + assert pool.instance_name == "my-lakebase-instance" + assert pool.host == hostname + + # get_database_instance should NOT have been called (we used list instead) + workspace.database.get_database_instance.assert_not_called() + + +def test_lakebase_pool_hostname_not_found_raises_error(monkeypatch): + """Test that LakebasePool raises error when hostname doesn't match any instance.""" + TestConnectionPool = _make_connection_pool_class() + monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) + + workspace = _make_workspace() + + # Mock list_database_instances to return instances that don't match + other_instance = MagicMock() + other_instance.name = "other-instance" + other_instance.read_write_dns = "other-host.database.staging.cloud.databricks.com" + other_instance.read_only_dns = None + workspace.database.list_database_instances.return_value = [other_instance] + + hostname = "instance-not-found.database.staging.cloud.databricks.com" + + with pytest.raises(ValueError, match="Unable to find database instance matching hostname"): + LakebasePool( + instance_name=hostname, + workspace_client=workspace, + ) + + +@pytest.mark.asyncio +async def test_async_lakebase_pool_accepts_hostname(monkeypatch): + """Test that AsyncLakebasePool accepts hostname and resolves instance name.""" + TestAsyncConnectionPool = _make_async_connection_pool_class() + monkeypatch.setattr( + "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool + ) + + workspace = _make_workspace() + + # Mock list_database_instances to return an instance matching the hostname + hostname = "instance-xyz789.database.prod.cloud.databricks.com" + mock_instance = MagicMock() + mock_instance.name = "prod-lakebase" + mock_instance.read_write_dns = hostname + mock_instance.read_only_dns = None + workspace.database.list_database_instances.return_value = [mock_instance] + + pool = AsyncLakebasePool( + instance_name=hostname, # Pass hostname instead of instance name + workspace_client=workspace, + ) + + # Should have resolved to the instance name + assert pool.instance_name == "prod-lakebase" + assert pool.host == hostname From bb3fc6566aa1e158f950821ed4eff139dc3ff3c7 Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 27 Jan 2026 15:56:02 -0800 Subject: [PATCH 15/16] add integration tests for lakebase memory session --- .../integration_tests/test_memory_session.py | 408 ++++++++++++++++++ 1 file changed, 408 insertions(+) create mode 100644 integrations/openai/tests/integration_tests/test_memory_session.py diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py new file mode 100644 index 000000000..8ab74a006 --- /dev/null +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -0,0 +1,408 @@ +"""Integration tests for MemorySession and AsyncMemorySession. + +These tests require: +1. A Lakebase instance to be available +2. Valid Databricks authentication (DATABRICKS_HOST + DATABRICKS_TOKEN as env variables) + +Set the environment variable: + LAKEBASE_INSTANCE_NAME: Name or hostname of the Lakebase instance + +Example: + LAKEBASE_INSTANCE_NAME=lakebase pytest tests/integration_tests/test_memory_session.py -v +""" + +from __future__ import annotations + +import asyncio +import os +import uuid + +import pytest + +# Skip all tests if LAKEBASE_INSTANCE_NAME is not set +pytestmark = pytest.mark.skipif( + not os.environ.get("LAKEBASE_INSTANCE_NAME"), + reason="LAKEBASE_INSTANCE_NAME environment variable not set", +) + + +def get_instance_name() -> str: + """Get the Lakebase instance name from environment.""" + return os.environ["LAKEBASE_INSTANCE_NAME"] + + +def get_unique_table_names() -> tuple[str, str]: + """Generate unique table names for test isolation.""" + suffix = uuid.uuid4().hex[:8] + return f"test_sessions_{suffix}", f"test_messages_{suffix}" + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_pool_cache(): + """Session-scoped fixture to close cached pools after all tests complete.""" + yield + + # Close sync pool cache + from databricks_openai.agents import session as session_module + + for pool in session_module._pool_cache.values(): + try: + pool.close() + except Exception: + pass + session_module._pool_cache.clear() + + # Close async pool cache - need to handle event loop carefully + for pool in list(session_module._async_pool_cache.values()): + try: + # Access the underlying pool and close it synchronously if possible + # The pool's _pool attribute is the actual AsyncConnectionPool + if hasattr(pool, "_pool") and pool._pool is not None: + # Use wait=False to avoid blocking on workers + pool._pool.close(timeout=0) + except Exception: + pass + session_module._async_pool_cache.clear() + + +@pytest.fixture +def cleanup_tables(): + """Fixture to track and clean up test tables after tests.""" + tables_to_cleanup: list[tuple[str, str]] = [] + + yield tables_to_cleanup + + # Cleanup after test + if tables_to_cleanup: + from databricks_ai_bridge.lakebase import LakebasePool + + pool = LakebasePool(instance_name=get_instance_name()) + with pool.connection() as conn: + for sessions_table, messages_table in tables_to_cleanup: + # Drop messages first (foreign key constraint) + conn.execute(f"DROP TABLE IF EXISTS {messages_table}") + conn.execute(f"DROP TABLE IF EXISTS {sessions_table}") + + +# ============================================================================= +# Sync MemorySession Tests +# ============================================================================= + + +def test_memory_session_crud_operations(cleanup_tables): + """ + Comprehensive CRUD test for sync MemorySession. + + Tests the full lifecycle: + - clear_session() on fresh session + - get_items() returns empty list for new session + - add_items() stores messages + - get_items() retrieves stored messages + - get_items(limit=N) returns latest N items in order + - pop_item() removes and returns most recent item + - clear_session() removes all items + """ + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id = uuid.uuid4() + session = MemorySession( + session_id=session_id, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Clear any existing data (should be no-op for new session) + asyncio.run(session.clear_session()) + + # Test get_items on empty session + items = asyncio.run(session.get_items()) + assert items == [], f"Expected empty list, got {items}" + + # Test add_items + test_items = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + asyncio.run(session.add_items(test_items)) + + # Test get_items returns what we added + items = asyncio.run(session.get_items()) + assert len(items) == 2, f"Expected 2 items, got {len(items)}" + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello, how are you?" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "I'm doing well, thank you!" + + # Test get_items with limit - should return latest N items in chronological order + items = asyncio.run(session.get_items(limit=1)) + assert len(items) == 1, f"Expected 1 item with limit, got {len(items)}" + assert items[0]["role"] == "assistant" # Latest item + + # Test pop_item - removes and returns the last item + popped = asyncio.run(session.pop_item()) + assert popped is not None + assert popped["role"] == "assistant" # Should be the last item + + # Verify only 1 item remains + items = asyncio.run(session.get_items()) + assert len(items) == 1, f"Expected 1 item after pop, got {len(items)}" + assert items[0]["role"] == "user" + + # Test clear_session + asyncio.run(session.clear_session()) + items = asyncio.run(session.get_items()) + assert items == [], f"Expected empty after clear, got {items}" + + +def test_memory_session_multiple_sessions_isolated(cleanup_tables): + """Test that different session_ids have isolated data.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id_1 = uuid.uuid4() + session_id_2 = uuid.uuid4() + + session_1 = MemorySession( + session_id=session_id_1, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + session_2 = MemorySession( + session_id=session_id_2, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add different items to each session + asyncio.run(session_1.add_items([{"role": "user", "content": "Session 1 message"}])) + asyncio.run(session_2.add_items([{"role": "user", "content": "Session 2 message"}])) + + # Verify isolation + items_1 = asyncio.run(session_1.get_items()) + items_2 = asyncio.run(session_2.get_items()) + + assert len(items_1) == 1 + assert len(items_2) == 1 + assert items_1[0]["content"] == "Session 1 message" + assert items_2[0]["content"] == "Session 2 message" + + # Clear one session shouldn't affect the other + asyncio.run(session_1.clear_session()) + items_1 = asyncio.run(session_1.get_items()) + items_2 = asyncio.run(session_2.get_items()) + assert len(items_1) == 0 + assert len(items_2) == 1 + + +def test_memory_session_pop_empty_returns_none(cleanup_tables): + """Test that pop_item returns None on empty session.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Pop on empty session should return None + popped = asyncio.run(session.pop_item()) + assert popped is None + + +def test_memory_session_add_empty_items_noop(cleanup_tables): + """Test that add_items with empty list is a no-op.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add empty list - should not raise + asyncio.run(session.add_items([])) + + # Session should still be empty + items = asyncio.run(session.get_items()) + assert items == [] + + +# ============================================================================= +# Async AsyncMemorySession Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_async_memory_session_crud_operations(cleanup_tables): + """ + Comprehensive CRUD test for AsyncMemorySession. + + Tests the full lifecycle: + - clear_session() on fresh session + - get_items() returns empty list for new session + - add_items() stores messages + - get_items() retrieves stored messages + - get_items(limit=N) returns latest N items in order + - pop_item() removes and returns most recent item + - clear_session() removes all items + """ + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id = uuid.uuid4() + session = AsyncMemorySession( + session_id=session_id, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Clear any existing data (should be no-op for new session) + await session.clear_session() + + # Test get_items on empty session + items = await session.get_items() + assert items == [], f"Expected empty list, got {items}" + + # Test add_items + test_items = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + await session.add_items(test_items) + + # Test get_items returns what we added + items = await session.get_items() + assert len(items) == 2, f"Expected 2 items, got {len(items)}" + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello, how are you?" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "I'm doing well, thank you!" + + # Test get_items with limit - should return latest N items in chronological order + items = await session.get_items(limit=1) + assert len(items) == 1, f"Expected 1 item with limit, got {len(items)}" + assert items[0]["role"] == "assistant" # Latest item + + # Test pop_item - removes and returns the last item + popped = await session.pop_item() + assert popped is not None + assert popped["role"] == "assistant" # Should be the last item + + # Verify only 1 item remains + items = await session.get_items() + assert len(items) == 1, f"Expected 1 item after pop, got {len(items)}" + assert items[0]["role"] == "user" + + # Test clear_session + await session.clear_session() + items = await session.get_items() + assert items == [], f"Expected empty after clear, got {items}" + + +@pytest.mark.asyncio +async def test_async_memory_session_multiple_sessions_isolated(cleanup_tables): + """Test that different session_ids have isolated data (async version).""" + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id_1 = uuid.uuid4() + session_id_2 = uuid.uuid4() + + session_1 = AsyncMemorySession( + session_id=session_id_1, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + session_2 = AsyncMemorySession( + session_id=session_id_2, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add different items to each session + await session_1.add_items([{"role": "user", "content": "Async Session 1 message"}]) + await session_2.add_items([{"role": "user", "content": "Async Session 2 message"}]) + + # Verify isolation + items_1 = await session_1.get_items() + items_2 = await session_2.get_items() + + assert len(items_1) == 1 + assert len(items_2) == 1 + assert items_1[0]["content"] == "Async Session 1 message" + assert items_2[0]["content"] == "Async Session 2 message" + + # Clear one session shouldn't affect the other + await session_1.clear_session() + items_1 = await session_1.get_items() + items_2 = await session_2.get_items() + assert len(items_1) == 0 + assert len(items_2) == 1 + + +@pytest.mark.asyncio +async def test_async_memory_session_pop_empty_returns_none(cleanup_tables): + """Test that pop_item returns None on empty session (async version).""" + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncMemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Pop on empty session should return None + popped = await session.pop_item() + assert popped is None + + +@pytest.mark.asyncio +async def test_async_memory_session_add_empty_items_noop(cleanup_tables): + """Test that add_items with empty list is a no-op (async version).""" + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncMemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add empty list - should not raise + await session.add_items([]) + + # Session should still be empty + items = await session.get_items() + assert items == [] From 3a4682b476e65b9db34347ce06f836bc812914da Mon Sep 17 00:00:00 2001 From: Jenny Date: Fri, 30 Jan 2026 15:55:41 -0800 Subject: [PATCH 16/16] remove lakebase hostname resolution changes --- src/databricks_ai_bridge/lakebase.py | 90 ++++------------- tests/databricks_ai_bridge/test_lakebase.py | 101 -------------------- 2 files changed, 18 insertions(+), 173 deletions(-) diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index c2649b877..1cb0e83fa 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -97,88 +97,34 @@ def __init__( token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, ) -> None: self.workspace_client: WorkspaceClient = workspace_client or WorkspaceClient() + self.instance_name: str = instance_name self.token_cache_duration_seconds: int = token_cache_duration_seconds - # If input is hostname (e.g., from Databricks Apps valueFrom resolution) - # resolve to lakebase name - if self._is_hostname(instance_name): - # Input is a hostname - resolve to instance name - self.instance_name, self.host = self._resolve_from_hostname(instance_name) - else: - # Input is an instance name - self.instance_name = instance_name - try: - instance = self.workspace_client.database.get_database_instance(instance_name) - except Exception as exc: - raise ValueError( - f"Unable to resolve Lakebase instance '{instance_name}'. " - "Ensure the instance name is correct." - ) from exc - - resolved_host = getattr(instance, "read_write_dns", None) or getattr( - instance, "read_only_dns", None - ) + # Resolve host from the Lakebase name + try: + instance = self.workspace_client.database.get_database_instance(instance_name) + except Exception as exc: + raise ValueError( + f"Unable to resolve Lakebase instance '{instance_name}'. " + "Ensure the instance name is correct." + ) from exc - if not resolved_host: - raise ValueError( - f"Lakebase host not found for instance '{instance_name}'. " - "Ensure the instance is running and in AVAILABLE state." - ) + resolved_host = getattr(instance, "read_write_dns", None) or getattr( + instance, "read_only_dns", None + ) - self.host = resolved_host + if not resolved_host: + raise ValueError( + f"Lakebase host not found for instance '{instance_name}'. " + "Ensure the instance is running and in AVAILABLE state." + ) + self.host: str = resolved_host self.username: str = self._infer_username() self._cached_token: str | None = None self._cache_ts: float | None = None - @staticmethod - def _is_hostname(value: str) -> bool: - """Check if the value looks like a Lakebase hostname rather than an instance name.""" - # Hostname pattern: instance-{uuid}.database.{env}.cloud.databricks.com - # or similar patterns containing ".database." and ending with a domain - return ".database." in value and (value.endswith(".com") or value.endswith(".net")) - - def _resolve_from_hostname(self, hostname: str) -> tuple[str, str]: - """ - Resolve instance name from a hostname by listing database instances. - - Args: - hostname: The database hostname (e.g., from Databricks Apps valueFrom: "database") - - Returns: - Tuple of (instance_name, host) - - Raises: - ValueError: If no matching instance is found - """ - try: - instances = list(self.workspace_client.database.list_database_instances()) - except Exception as exc: - raise ValueError( - f"Unable to list database instances to resolve hostname '{hostname}'. " - "Ensure you have access to database instances." - ) from exc - - # Find the instance that matches this hostname - for instance in instances: - rw_dns = getattr(instance, "read_write_dns", None) - ro_dns = getattr(instance, "read_only_dns", None) - - if hostname in (rw_dns, ro_dns): - instance_name = getattr(instance, "name", None) - if not instance_name: - raise ValueError( - f"Found matching instance for hostname '{hostname}' " - "but instance name is not available." - ) - return instance_name, hostname - - raise ValueError( - f"Unable to find database instance matching hostname '{hostname}'. " - "Ensure the hostname is correct and the instance exists." - ) - def _get_cached_token(self) -> str | None: """Check if the cached token is still valid.""" if not self._cached_token or not self._cache_ts: diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index a831f8f9f..b914f7217 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -1056,104 +1056,3 @@ def test_execute_grant_handles_insufficient_privilege(self): error_msg = str(exc_info.value) assert "Insufficient privileges" in error_msg assert "CAN MANAGE" in error_msg - - -# ============================================================================= -# Hostname Resolution Tests -# ============================================================================= - - -def test_is_hostname_detects_database_hostname(): - """Test that _is_hostname correctly identifies database hostnames.""" - from databricks_ai_bridge.lakebase import _LakebasePoolBase - - # Should be detected as hostnames - assert _LakebasePoolBase._is_hostname( - "instance-f757b615-f2fd-4614-87cc-9ba35f2eeb61.database.staging.cloud.databricks.com" - ) - assert _LakebasePoolBase._is_hostname("instance-abc123.database.prod.cloud.databricks.com") - assert _LakebasePoolBase._is_hostname("my-db.database.example.net") - - # Should NOT be detected as hostnames (regular instance names) - assert not _LakebasePoolBase._is_hostname("lakebase") - assert not _LakebasePoolBase._is_hostname("my-database-instance") - assert not _LakebasePoolBase._is_hostname("production_db") - - -def test_lakebase_pool_accepts_hostname(monkeypatch): - """Test that LakebasePool accepts hostname and resolves instance name.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = _make_workspace() - - # Mock list_database_instances to return an instance matching the hostname - hostname = "instance-abc123.database.staging.cloud.databricks.com" - mock_instance = MagicMock() - mock_instance.name = "my-lakebase-instance" - mock_instance.read_write_dns = hostname - mock_instance.read_only_dns = None - workspace.database.list_database_instances.return_value = [mock_instance] - - pool = LakebasePool( - instance_name=hostname, # Pass hostname instead of instance name - workspace_client=workspace, - ) - - # Should have resolved to the instance name - assert pool.instance_name == "my-lakebase-instance" - assert pool.host == hostname - - # get_database_instance should NOT have been called (we used list instead) - workspace.database.get_database_instance.assert_not_called() - - -def test_lakebase_pool_hostname_not_found_raises_error(monkeypatch): - """Test that LakebasePool raises error when hostname doesn't match any instance.""" - TestConnectionPool = _make_connection_pool_class() - monkeypatch.setattr("databricks_ai_bridge.lakebase.ConnectionPool", TestConnectionPool) - - workspace = _make_workspace() - - # Mock list_database_instances to return instances that don't match - other_instance = MagicMock() - other_instance.name = "other-instance" - other_instance.read_write_dns = "other-host.database.staging.cloud.databricks.com" - other_instance.read_only_dns = None - workspace.database.list_database_instances.return_value = [other_instance] - - hostname = "instance-not-found.database.staging.cloud.databricks.com" - - with pytest.raises(ValueError, match="Unable to find database instance matching hostname"): - LakebasePool( - instance_name=hostname, - workspace_client=workspace, - ) - - -@pytest.mark.asyncio -async def test_async_lakebase_pool_accepts_hostname(monkeypatch): - """Test that AsyncLakebasePool accepts hostname and resolves instance name.""" - TestAsyncConnectionPool = _make_async_connection_pool_class() - monkeypatch.setattr( - "databricks_ai_bridge.lakebase.AsyncConnectionPool", TestAsyncConnectionPool - ) - - workspace = _make_workspace() - - # Mock list_database_instances to return an instance matching the hostname - hostname = "instance-xyz789.database.prod.cloud.databricks.com" - mock_instance = MagicMock() - mock_instance.name = "prod-lakebase" - mock_instance.read_write_dns = hostname - mock_instance.read_only_dns = None - workspace.database.list_database_instances.return_value = [mock_instance] - - pool = AsyncLakebasePool( - instance_name=hostname, # Pass hostname instead of instance name - workspace_client=workspace, - ) - - # Should have resolved to the instance name - assert pool.instance_name == "prod-lakebase" - assert pool.host == hostname