From 3351a0955e00d485f435d8f5d5d238dda91394ff Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 3 Feb 2026 13:32:36 -0800 Subject: [PATCH 01/17] memorysession subclassing SQLAlchemySession --- integrations/openai/pyproject.toml | 7 + .../src/databricks_openai/agents/session.py | 232 +++++++++ .../integration_tests/test_memory_session.py | 338 ++++++++++++ .../openai/tests/unit_tests/test_session.py | 491 ++++++++++++++++++ 4 files changed, 1068 insertions(+) create mode 100644 integrations/openai/src/databricks_openai/agents/session.py create mode 100644 integrations/openai/tests/integration_tests/test_memory_session.py create mode 100644 integrations/openai/tests/unit_tests/test_session.py diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 945c78a75..b14e2dde7 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -20,6 +20,13 @@ dependencies = [ "openai-agents>=0.5.0" ] +[project.optional-dependencies] +memory = [ + "sqlalchemy>=2.0.0", + "greenlet>=3.0.0", + "databricks-ai-bridge[memory]>=0.13.0", +] + [dependency-groups] dev = [ "typing_extensions>=4.15.0", diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py new file mode 100644 index 000000000..36a0ff604 --- /dev/null +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -0,0 +1,232 @@ +""" +MemorySession - SQLAlchemy-based session storage for Databricks Lakebase. + +This module provides a MemorySession class that subclasses OpenAI's SQLAlchemySession +to provide persistent conversation history storage in Databricks Lakebase. + +Usage:: + + from databricks_openai.agents.session import MemorySession + from agents import Agent, Runner + + session = MemorySession( + session_id="user-123", + instance_name="my-lakebase-instance", + ) + + agent = Agent(name="Assistant") + result = await Runner.run(agent, "Hello!", session=session) +""" + +from __future__ import annotations + +import logging +import time +from threading import Lock +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncEngine + +try: + from agents.extensions.memory import SQLAlchemySession + from databricks.sdk import WorkspaceClient + from databricks_ai_bridge.lakebase import _LakebasePoolBase + from sqlalchemy import URL, event + from sqlalchemy.ext.asyncio import create_async_engine + from sqlalchemy.pool import NullPool +except ImportError as e: + raise ImportError( + "MemorySession requires databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) from e + +logger = logging.getLogger(__name__) + +# Constants for Lakebase connection +DEFAULT_TOKEN_CACHE_DURATION_SECONDS = 50 * 60 # 50 minutes +DEFAULT_SSLMODE = "require" +DEFAULT_PORT = 5432 +DEFAULT_DATABASE = "databricks_postgres" + + +class _LakebaseCredentials(_LakebasePoolBase): + """ + Lightweight credential provider that reuses _LakebasePoolBase for: + - Instance name → host resolution + - Username inference from workspace client + - Token minting and caching + + Does NOT create a connection pool - just provides credentials for SQLAlchemy. + """ + + def __init__( + self, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, + ) -> None: + super().__init__( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + ) + self._cache_lock = Lock() + + def get_token(self) -> str: + """Get cached token or mint a new one (thread-safe).""" + with self._cache_lock: + if cached := self._get_cached_token(): + return cached + token = self._mint_token() + self._cached_token = token + self._cache_ts = time.time() + return token + + +class MemorySession(SQLAlchemySession): + """ + OpenAI Agents SDK Session implementation for Databricks Lakebase. + + This class subclasses SQLAlchemySession to provide: + - Lakebase instance resolution + - OAuth token rotation for authentication + - SQL logic inherited from SQLAlchemySession + + The session stores conversation history in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + Example: + ```python + from databricks_openai.agents.session import MemorySession + from agents import Agent, Runner + + async def run_agent(session_id: str, message: str): + session = MemorySession( + session_id=session_id, + instance_name="my-lakebase-instance", + ) + agent = Agent(name="Assistant") + return await Runner.run(agent, message, session=session) + ``` + + For more information on the Session protocol, see: + https://openai.github.io/openai-agents-python/ref/memory/session/ + """ + + def __init__( + self, + session_id: str, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, + create_tables: bool = True, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + **engine_kwargs, + ) -> None: + """ + Initialize a MemorySession for Databricks Lakebase. + + Args: + session_id: Unique identifier for the conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + If not provided, a default client will be created. + token_cache_duration_seconds: How long to cache OAuth tokens. + Defaults to 50 minutes. + create_tables: Whether to auto-create tables on first use. + Defaults to True. + sessions_table: Name of the sessions table. + Defaults to "agent_sessions". + messages_table: Name of the messages table. + Defaults to "agent_messages". + **engine_kwargs: Additional keyword arguments passed to + SQLAlchemy's create_async_engine(). + """ + # Create credential provider + self._credentials = _LakebaseCredentials( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + ) + + engine = self._create_engine(**engine_kwargs) + + # Initialize parent SQLAlchemySession - inherits all SQL logic + super().__init__( + session_id=session_id, + engine=engine, + create_tables=create_tables, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + logger.info( + "MemorySession initialized: instance=%s host=%s session_id=%s", + instance_name, + self._credentials.host, + session_id, + ) + + def _create_engine(self, **engine_kwargs) -> "AsyncEngine": + """Create an AsyncEngine with do_connect event for token injection.""" + # https://docs.sqlalchemy.org/en/21/core/engines.html#creating-urls-programmatically + url = URL.create( + drivername="postgresql+psycopg", + username=self._credentials.username, + host=self._credentials.host, + port=DEFAULT_PORT, + database=DEFAULT_DATABASE, + ) + + # Create engine with NullPool to avoid connection caching issues + # Each operation gets a fresh connection with a fresh token check + engine = create_async_engine( + url, + poolclass=NullPool, + connect_args={"sslmode": DEFAULT_SSLMODE}, + **engine_kwargs, + ) + + # Attach event to inject Lakebase token before each connection + # Note: do_connect fires on sync_engine even for async operations + credentials = self._credentials + + @event.listens_for(engine.sync_engine, "do_connect") + def inject_lakebase_token(dialect, conn_rec, cargs, cparams): + cparams["password"] = credentials.get_token() + logger.debug("Injected Lakebase token for connection") + + return engine + + @property + def instance_name(self) -> str: + """The Lakebase instance name.""" + return self._credentials.instance_name + + @property + def host(self) -> str: + """The resolved Lakebase host.""" + return self._credentials.host + + @property + def username(self) -> str: + """The database username.""" + return self._credentials.username + + @property + def connection_url(self) -> str: + """The SQLAlchemy connection URL (without password, for debugging).""" + url = URL.create( + drivername="postgresql+psycopg", + username=self._credentials.username, + host=self._credentials.host, + port=DEFAULT_PORT, + database=DEFAULT_DATABASE, + query={"sslmode": DEFAULT_SSLMODE}, + ) + return url.render_as_string(hide_password=True) diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py new file mode 100644 index 000000000..1867c621b --- /dev/null +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -0,0 +1,338 @@ +""" +Integration tests for MemorySession. + +These tests require: +1. A Lakebase instance to be available +2. Valid Databricks authentication (DATABRICKS_HOST + DATABRICKS_TOKEN or profile) + +Set the environment variable: + LAKEBASE_INSTANCE_NAME: Name of the Lakebase instance + +Example: + LAKEBASE_INSTANCE_NAME=lakebase pytest tests/integration_tests/test_memory_session.py -v +""" + +from __future__ import annotations + +import asyncio +import os +import uuid +from typing import Any, cast + +import pytest + +# Skip all tests if LAKEBASE_INSTANCE_NAME is not set +pytestmark = pytest.mark.skipif( + not os.environ.get("LAKEBASE_INSTANCE_NAME"), + reason="LAKEBASE_INSTANCE_NAME environment variable not set", +) + + +def get_instance_name() -> str: + """Get the Lakebase instance name from environment.""" + return os.environ["LAKEBASE_INSTANCE_NAME"] + + +def get_unique_table_names() -> tuple[str, str]: + """Generate unique table names for test isolation.""" + suffix = uuid.uuid4().hex[:8] + return f"test_sessions_{suffix}", f"test_messages_{suffix}" + + +@pytest.fixture +def cleanup_tables(): + """Fixture to track and clean up test tables after tests.""" + tables_to_cleanup: list[tuple[str, str]] = [] + + yield tables_to_cleanup + + # Cleanup after test + if tables_to_cleanup: + from databricks_ai_bridge.lakebase import LakebasePool + + pool = LakebasePool(instance_name=get_instance_name()) + with pool.connection() as conn: + for sessions_table, messages_table in tables_to_cleanup: + # Drop messages first (foreign key constraint) + conn.execute(f"DROP TABLE IF EXISTS {messages_table}") + conn.execute(f"DROP TABLE IF EXISTS {sessions_table}") + pool.close() + + +# ============================================================================= +# MemorySession Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_memory_session_crud_operations(cleanup_tables): + """ + Comprehensive CRUD test for MemorySession. + + Tests the full lifecycle: + - clear_session() on fresh session + - get_items() returns empty list for new session + - add_items() stores messages + - get_items() retrieves stored messages + - get_items(limit=N) returns latest N items in order + - pop_item() removes and returns most recent item + - clear_session() removes all items + """ + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id = str(uuid.uuid4()) + session = MemorySession( + session_id=session_id, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Clear any existing data (should be no-op for new session) + await session.clear_session() + + # Test get_items on empty session + items = cast(list[Any], await session.get_items()) + assert items == [], f"Expected empty list, got {items}" + + # Test add_items + test_items: list[Any] = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + await session.add_items(test_items) + + # Test get_items returns what we added + items = cast(list[Any], await session.get_items()) + assert len(items) == 2, f"Expected 2 items, got {len(items)}" + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello, how are you?" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "I'm doing well, thank you!" + + # Test get_items with limit - should return latest N items in chronological order + items = cast(list[Any], await session.get_items(limit=1)) + assert len(items) == 1, f"Expected 1 item with limit, got {len(items)}" + assert items[0]["role"] == "assistant" # Latest item + + # Test pop_item - removes and returns the last item + popped = cast(Any, await session.pop_item()) + assert popped is not None + assert popped["role"] == "assistant" # Should be the last item + + # Verify only 1 item remains + items = cast(list[Any], await session.get_items()) + assert len(items) == 1, f"Expected 1 item after pop, got {len(items)}" + assert items[0]["role"] == "user" + + # Test clear_session + await session.clear_session() + items = cast(list[Any], await session.get_items()) + assert items == [], f"Expected empty after clear, got {items}" + + +@pytest.mark.asyncio +async def test_memory_session_multiple_sessions_isolated(cleanup_tables): + """Test that different session_ids have isolated data.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id_1 = str(uuid.uuid4()) + session_id_2 = str(uuid.uuid4()) + + session_1 = MemorySession( + session_id=session_id_1, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + session_2 = MemorySession( + session_id=session_id_2, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + create_tables=False, # Tables already created by session_1 + ) + + # Add different items to each session + items_1_data: list[Any] = [{"role": "user", "content": "Session 1 message"}] + items_2_data: list[Any] = [{"role": "user", "content": "Session 2 message"}] + await session_1.add_items(items_1_data) + await session_2.add_items(items_2_data) + + # Verify isolation + items_1 = cast(list[Any], await session_1.get_items()) + items_2 = cast(list[Any], await session_2.get_items()) + + assert len(items_1) == 1 + assert len(items_2) == 1 + assert items_1[0]["content"] == "Session 1 message" + assert items_2[0]["content"] == "Session 2 message" + + # Clear one session shouldn't affect the other + await session_1.clear_session() + items_1 = cast(list[Any], await session_1.get_items()) + items_2 = cast(list[Any], await session_2.get_items()) + assert len(items_1) == 0 + assert len(items_2) == 1 + + # Cleanup + await session_2.clear_session() + + +@pytest.mark.asyncio +async def test_memory_session_pop_empty_returns_none(cleanup_tables): + """Test that pop_item returns None on empty session.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Pop on empty session should return None + popped = await session.pop_item() + assert popped is None + + +@pytest.mark.asyncio +async def test_memory_session_add_empty_items_noop(cleanup_tables): + """Test that add_items with empty list is a no-op.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add empty list - should not raise + await session.add_items([]) + + # Session should still be empty + items = cast(list[Any], await session.get_items()) + assert items == [] + + +@pytest.mark.asyncio +async def test_memory_session_complex_message_data(cleanup_tables): + """Test storing complex message data with nested structures.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add complex item with nested data + complex_item: dict[str, Any] = { + "role": "assistant", + "content": "Here's your result", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + } + ], + "metadata": { + "model": "gpt-4", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + }, + } + await session.add_items([complex_item]) + + # Retrieve and verify + items = cast(list[Any], await session.get_items()) + assert len(items) == 1 + assert items[0]["role"] == "assistant" + assert items[0]["tool_calls"][0]["function"]["name"] == "get_weather" + assert items[0]["metadata"]["usage"]["prompt_tokens"] == 10 + + # Cleanup + await session.clear_session() + + +@pytest.mark.asyncio +async def test_memory_session_properties(cleanup_tables): + """Test that session properties return correct values.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id = str(uuid.uuid4()) + session = MemorySession( + session_id=session_id, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Verify properties + assert session.instance_name == get_instance_name() + assert session.host is not None + assert len(session.host) > 0 + assert session.username is not None + assert len(session.username) > 0 + assert session.session_id == session_id + + +@pytest.mark.asyncio +async def test_memory_session_get_items_ordering(cleanup_tables): + """Test that get_items returns items in correct chronological order.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add multiple items + test_items: list[Any] = [{"role": "user", "content": f"Message {i}"} for i in range(5)] + await session.add_items(test_items) + + # Get all items - should be in chronological order + items = cast(list[Any], await session.get_items()) + assert len(items) == 5 + for i, item in enumerate(items): + assert item["content"] == f"Message {i}" + + # Get with limit - should return LATEST N items in chronological order + items = cast(list[Any], await session.get_items(limit=3)) + assert len(items) == 3 + # Should be messages 2, 3, 4 (the latest 3) in order + assert items[0]["content"] == "Message 2" + assert items[1]["content"] == "Message 3" + assert items[2]["content"] == "Message 4" + + # Cleanup + await session.clear_session() diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py new file mode 100644 index 000000000..0bccc1b59 --- /dev/null +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -0,0 +1,491 @@ +"""Unit tests for MemorySession.""" + +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def mock_workspace_client(): + """Create a mock WorkspaceClient.""" + mock_client = MagicMock() + mock_client.config.host = "https://test.databricks.com" + + # Mock current_user.me() for username inference + mock_user = MagicMock() + mock_user.user_name = "test_user@databricks.com" + mock_client.current_user.me.return_value = mock_user + + # Mock database.get_database_instance() for host resolution + mock_instance = MagicMock() + mock_instance.read_write_dns = "test-instance.lakebase.databricks.com" + mock_client.database.get_database_instance.return_value = mock_instance + + # Mock database.generate_database_credential() for token minting + mock_credential = MagicMock() + mock_credential.token = "test-oauth-token" + mock_client.database.generate_database_credential.return_value = mock_credential + + return mock_client + + +@pytest.fixture +def mock_engine(): + """Create a mock SQLAlchemy AsyncEngine.""" + mock_eng = MagicMock() + mock_eng.sync_engine = MagicMock() + return mock_eng + + +@pytest.fixture +def mock_event_listens_for(): + """Create a mock for event.listens_for that captures the handler.""" + + def create_decorator(engine, event_name): + def decorator(fn): + return fn + + return decorator + + return create_decorator + + +class TestLakebaseCredentials: + """Tests for _LakebaseCredentials class.""" + + def test_init_resolves_host(self, mock_workspace_client): + """Test that initialization resolves the Lakebase host.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ): + from databricks_openai.agents.session import _LakebaseCredentials + + creds = _LakebaseCredentials( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert creds.host == "test-instance.lakebase.databricks.com" + mock_workspace_client.database.get_database_instance.assert_called_once_with( + "test-instance" + ) + + def test_init_infers_username(self, mock_workspace_client): + """Test that initialization infers the username.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ): + from databricks_openai.agents.session import _LakebaseCredentials + + creds = _LakebaseCredentials( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert creds.username == "test_user@databricks.com" + mock_workspace_client.current_user.me.assert_called() + + def test_get_token_mints_new_token(self, mock_workspace_client): + """Test that get_token mints a new token when cache is empty.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ): + from databricks_openai.agents.session import _LakebaseCredentials + + creds = _LakebaseCredentials( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + token = creds.get_token() + + assert token == "test-oauth-token" + mock_workspace_client.database.generate_database_credential.assert_called_once() + + def test_get_token_returns_cached_token(self, mock_workspace_client): + """Test that get_token returns cached token when valid.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ): + from databricks_openai.agents.session import _LakebaseCredentials + + creds = _LakebaseCredentials( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # First call - mints token + token1 = creds.get_token() + # Second call - should return cached + token2 = creds.get_token() + + assert token1 == token2 == "test-oauth-token" + # Should only mint once + assert mock_workspace_client.database.generate_database_credential.call_count == 1 + + def test_get_token_refreshes_expired_token(self, mock_workspace_client): + """Test that get_token refreshes token when expired.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ): + import time + + from databricks_openai.agents.session import _LakebaseCredentials + + creds = _LakebaseCredentials( + instance_name="test-instance", + workspace_client=mock_workspace_client, + token_cache_duration_seconds=0, # Immediate expiration + ) + + # First call + creds.get_token() + # Wait a tiny bit + time.sleep(0.01) + # Second call - should mint new token + creds.get_token() + + # Should mint twice due to expiration + assert mock_workspace_client.database.generate_database_credential.call_count == 2 + + def test_init_raises_on_invalid_instance(self, mock_workspace_client): + """Test that initialization raises ValueError for invalid instance.""" + mock_workspace_client.database.get_database_instance.side_effect = Exception( + "Instance not found" + ) + + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ): + from databricks_openai.agents.session import _LakebaseCredentials + + with pytest.raises(ValueError, match="Unable to resolve Lakebase instance"): + _LakebaseCredentials( + instance_name="invalid-instance", + workspace_client=mock_workspace_client, + ) + + +class TestMemorySessionInit: + """Tests for MemorySession initialization.""" + + def test_init_creates_engine_with_correct_url( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization creates engine with correct connection URL.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ): + from databricks_openai.agents.session import MemorySession + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Verify engine was created with correct URL object + call_args = mock_create_engine.call_args + url = call_args[0][0] + # URL.create() returns a URL object, check its properties + assert url.drivername == "postgresql+psycopg" + assert url.username == "test_user@databricks.com" + assert url.host == "test-instance.lakebase.databricks.com" + assert url.port == 5432 + assert url.database == "databricks_postgres" + + def test_init_uses_null_pool( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization uses NullPool to avoid connection caching.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ): + from databricks_openai.agents.session import MemorySession, NullPool + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + call_kwargs = mock_create_engine.call_args[1] + assert call_kwargs["poolclass"] == NullPool + + def test_init_sets_ssl_mode( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization sets SSL mode to require.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ): + from databricks_openai.agents.session import MemorySession + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + call_kwargs = mock_create_engine.call_args[1] + assert call_kwargs["connect_args"]["sslmode"] == "require" + + def test_init_registers_do_connect_event(self, mock_workspace_client, mock_engine): + """Test that initialization registers do_connect event for token injection.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), patch( + "databricks_openai.agents.session.event.listens_for" + ) as mock_listens_for: + from databricks_openai.agents.session import MemorySession + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Verify event listener was registered on sync_engine for "do_connect" + mock_listens_for.assert_called_once_with(mock_engine.sync_engine, "do_connect") + + def test_init_with_custom_table_names( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test initialization with custom table names.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), patch( + "agents.extensions.memory.SQLAlchemySession.__init__", + return_value=None, + ) as mock_parent_init: + from databricks_openai.agents.session import MemorySession + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + # Verify parent was called with custom table names + call_kwargs = mock_parent_init.call_args[1] + assert call_kwargs["sessions_table"] == "custom_sessions" + assert call_kwargs["messages_table"] == "custom_messages" + + def test_init_with_create_tables_false( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test initialization with create_tables=False.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), patch( + "agents.extensions.memory.SQLAlchemySession.__init__", + return_value=None, + ) as mock_parent_init: + from databricks_openai.agents.session import MemorySession + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + create_tables=False, + ) + + call_kwargs = mock_parent_init.call_args[1] + assert call_kwargs["create_tables"] is False + + +class TestMemorySessionProperties: + """Tests for MemorySession properties.""" + + def test_instance_name_property( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test instance_name property returns correct value.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ): + from databricks_openai.agents.session import MemorySession + + session = MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert session.instance_name == "test-instance" + + def test_host_property( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test host property returns resolved host.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ): + from databricks_openai.agents.session import MemorySession + + session = MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert session.host == "test-instance.lakebase.databricks.com" + + def test_username_property( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test username property returns inferred username.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ): + from databricks_openai.agents.session import MemorySession + + session = MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert session.username == "test_user@databricks.com" + + +class TestMemorySessionTokenInjection: + """Tests for token injection via do_connect event.""" + + def test_do_connect_injects_token(self, mock_workspace_client, mock_engine): + """Test that do_connect event handler injects token into connection params.""" + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=capture_handler, + ): + from databricks_openai.agents.session import MemorySession + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Simulate do_connect event + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + + # Verify token was injected + assert cparams["password"] == "test-oauth-token" + + +class TestMemorySessionEngineKwargs: + """Tests for passing additional engine kwargs.""" + + def test_extra_engine_kwargs_passed( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that extra kwargs are passed to create_async_engine.""" + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ): + from databricks_openai.agents.session import MemorySession + + MemorySession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, # Extra kwarg for SQLAlchemy + pool_pre_ping=True, # Another extra kwarg + ) + + call_kwargs = mock_create_engine.call_args[1] + assert call_kwargs.get("echo") is True + assert call_kwargs.get("pool_pre_ping") is True From 1e7b16559f1864489cfaf4505dde9a11f70a2f56 Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 3 Feb 2026 13:49:11 -0800 Subject: [PATCH 02/17] fix type check issues --- .../src/databricks_openai/agents/session.py | 1 + .../integration_tests/test_memory_session.py | 1 - .../openai/tests/unit_tests/test_session.py | 274 ++++++++++-------- 3 files changed, 156 insertions(+), 120 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 36a0ff604..63bbe0cc0 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -103,6 +103,7 @@ class MemorySession(SQLAlchemySession): from databricks_openai.agents.session import MemorySession from agents import Agent, Runner + async def run_agent(session_id: str, message: str): session = MemorySession( session_id=session_id, diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py index 1867c621b..078080131 100644 --- a/integrations/openai/tests/integration_tests/test_memory_session.py +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -14,7 +14,6 @@ from __future__ import annotations -import asyncio import os import uuid from typing import Any, cast diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 0bccc1b59..d94139bd8 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -179,15 +179,19 @@ def test_init_creates_engine_with_correct_url( self, mock_workspace_client, mock_engine, mock_event_listens_for ): """Test that initialization creates engine with correct connection URL.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), ): from databricks_openai.agents.session import MemorySession @@ -207,19 +211,21 @@ def test_init_creates_engine_with_correct_url( assert url.port == 5432 assert url.database == "databricks_postgres" - def test_init_uses_null_pool( - self, mock_workspace_client, mock_engine, mock_event_listens_for - ): + def test_init_uses_null_pool(self, mock_workspace_client, mock_engine, mock_event_listens_for): """Test that initialization uses NullPool to avoid connection caching.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), ): from databricks_openai.agents.session import MemorySession, NullPool @@ -232,19 +238,21 @@ def test_init_uses_null_pool( call_kwargs = mock_create_engine.call_args[1] assert call_kwargs["poolclass"] == NullPool - def test_init_sets_ssl_mode( - self, mock_workspace_client, mock_engine, mock_event_listens_for - ): + def test_init_sets_ssl_mode(self, mock_workspace_client, mock_engine, mock_event_listens_for): """Test that initialization sets SSL mode to require.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), ): from databricks_openai.agents.session import MemorySession @@ -259,15 +267,17 @@ def test_init_sets_ssl_mode( def test_init_registers_do_connect_event(self, mock_workspace_client, mock_engine): """Test that initialization registers do_connect event for token injection.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), patch( - "databricks_openai.agents.session.event.listens_for" - ) as mock_listens_for: + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch("databricks_openai.agents.session.event.listens_for") as mock_listens_for, + ): from databricks_openai.agents.session import MemorySession MemorySession( @@ -283,19 +293,24 @@ def test_init_with_custom_table_names( self, mock_workspace_client, mock_engine, mock_event_listens_for ): """Test initialization with custom table names.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, - ), patch( - "agents.extensions.memory.SQLAlchemySession.__init__", - return_value=None, - ) as mock_parent_init: + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + patch( + "agents.extensions.memory.SQLAlchemySession.__init__", + return_value=None, + ) as mock_parent_init, + ): from databricks_openai.agents.session import MemorySession MemorySession( @@ -315,19 +330,24 @@ def test_init_with_create_tables_false( self, mock_workspace_client, mock_engine, mock_event_listens_for ): """Test initialization with create_tables=False.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, - ), patch( - "agents.extensions.memory.SQLAlchemySession.__init__", - return_value=None, - ) as mock_parent_init: + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + patch( + "agents.extensions.memory.SQLAlchemySession.__init__", + return_value=None, + ) as mock_parent_init, + ): from databricks_openai.agents.session import MemorySession MemorySession( @@ -348,15 +368,19 @@ def test_instance_name_property( self, mock_workspace_client, mock_engine, mock_event_listens_for ): """Test instance_name property returns correct value.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), ): from databricks_openai.agents.session import MemorySession @@ -368,19 +392,21 @@ def test_instance_name_property( assert session.instance_name == "test-instance" - def test_host_property( - self, mock_workspace_client, mock_engine, mock_event_listens_for - ): + def test_host_property(self, mock_workspace_client, mock_engine, mock_event_listens_for): """Test host property returns resolved host.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), ): from databricks_openai.agents.session import MemorySession @@ -392,19 +418,21 @@ def test_host_property( assert session.host == "test-instance.lakebase.databricks.com" - def test_username_property( - self, mock_workspace_client, mock_engine, mock_event_listens_for - ): + def test_username_property(self, mock_workspace_client, mock_engine, mock_event_listens_for): """Test username property returns inferred username.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), ): from databricks_openai.agents.session import MemorySession @@ -432,15 +460,19 @@ def decorator(fn): return decorator - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=capture_handler, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=capture_handler, + ), ): from databricks_openai.agents.session import MemorySession @@ -466,15 +498,19 @@ def test_extra_engine_kwargs_passed( self, mock_workspace_client, mock_engine, mock_event_listens_for ): """Test that extra kwargs are passed to create_async_engine.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), ): from databricks_openai.agents.session import MemorySession From 9ac71d9b5bc4ea80f90de3a357a32aeeadfa5005 Mon Sep 17 00:00:00 2001 From: Jenny Date: Tue, 3 Feb 2026 14:53:07 -0800 Subject: [PATCH 03/17] update tests --- integrations/openai/pyproject.toml | 3 +-- .../tests/integration_tests/test_memory_session.py | 4 ++-- integrations/openai/tests/unit_tests/test_session.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index b14e2dde7..36887a7b1 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -22,8 +22,6 @@ dependencies = [ [project.optional-dependencies] memory = [ - "sqlalchemy>=2.0.0", - "greenlet>=3.0.0", "databricks-ai-bridge[memory]>=0.13.0", ] @@ -41,6 +39,7 @@ tests = [ "pytest-asyncio>=1.3.0", "pytest-timeout>=2.3.1", "pytest-cov>=4.1.0", + "databricks-ai-bridge[memory]>=0.13.0", ] [build-system] diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py index 078080131..b7de1912a 100644 --- a/integrations/openai/tests/integration_tests/test_memory_session.py +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -244,7 +244,7 @@ async def test_memory_session_complex_message_data(cleanup_tables): ) # Add complex item with nested data - complex_item: dict[str, Any] = { + complex_item = { "role": "assistant", "content": "Here's your result", "tool_calls": [ @@ -262,7 +262,7 @@ async def test_memory_session_complex_message_data(cleanup_tables): "usage": {"prompt_tokens": 10, "completion_tokens": 20}, }, } - await session.add_items([complex_item]) + await session.add_items([complex_item]) # type: ignore[list-item] # Retrieve and verify items = cast(list[Any], await session.get_items()) diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index d94139bd8..f2e6e9cb6 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -4,6 +4,16 @@ import pytest +try: + from databricks_ai_bridge.lakebase import _LakebasePoolBase # noqa: F401 + from psycopg.rows import DictRow # noqa: F401 + from sqlalchemy.ext.asyncio import AsyncEngine # noqa: F401 +except ImportError as e: + raise ImportError( + "MemorySession tests require databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) from e + @pytest.fixture def mock_workspace_client(): From 1433da04e6c2872b917cb4334640205ace2d5caf Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 4 Feb 2026 08:54:32 -0800 Subject: [PATCH 04/17] enable connection pooling using SQLAlchemy's default QueuePool --- .../src/databricks_openai/agents/session.py | 9 +++++---- .../openai/tests/unit_tests/test_session.py | 16 ++++++++++++---- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 63bbe0cc0..4828d53dc 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -34,7 +34,6 @@ from databricks_ai_bridge.lakebase import _LakebasePoolBase from sqlalchemy import URL, event from sqlalchemy.ext.asyncio import create_async_engine - from sqlalchemy.pool import NullPool except ImportError as e: raise ImportError( "MemorySession requires databricks-openai[memory]. " @@ -45,6 +44,7 @@ # Constants for Lakebase connection DEFAULT_TOKEN_CACHE_DURATION_SECONDS = 50 * 60 # 50 minutes +DEFAULT_POOL_RECYCLE_SECONDS = 45 * 60 # 45 minutes (before token cache expires) DEFAULT_SSLMODE = "require" DEFAULT_PORT = 5432 DEFAULT_DATABASE = "databricks_postgres" @@ -184,11 +184,12 @@ def _create_engine(self, **engine_kwargs) -> "AsyncEngine": database=DEFAULT_DATABASE, ) - # Create engine with NullPool to avoid connection caching issues - # Each operation gets a fresh connection with a fresh token check + # Use default QueuePool with connection recycling. + # Connections are recycled before token cache expires (50 min), + # ensuring fresh tokens are injected via do_connect event. engine = create_async_engine( url, - poolclass=NullPool, + pool_recycle=DEFAULT_POOL_RECYCLE_SECONDS, connect_args={"sslmode": DEFAULT_SSLMODE}, **engine_kwargs, ) diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index f2e6e9cb6..25f93569e 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -221,8 +221,10 @@ def test_init_creates_engine_with_correct_url( assert url.port == 5432 assert url.database == "databricks_postgres" - def test_init_uses_null_pool(self, mock_workspace_client, mock_engine, mock_event_listens_for): - """Test that initialization uses NullPool to avoid connection caching.""" + def test_init_uses_pool_recycle( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization uses pool_recycle for connection recycling.""" with ( patch( "databricks_ai_bridge.lakebase.WorkspaceClient", @@ -237,7 +239,10 @@ def test_init_uses_null_pool(self, mock_workspace_client, mock_engine, mock_even side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import MemorySession, NullPool + from databricks_openai.agents.session import ( + DEFAULT_POOL_RECYCLE_SECONDS, + MemorySession, + ) MemorySession( session_id="test-session-123", @@ -246,7 +251,10 @@ def test_init_uses_null_pool(self, mock_workspace_client, mock_engine, mock_even ) call_kwargs = mock_create_engine.call_args[1] - assert call_kwargs["poolclass"] == NullPool + # Should use pool_recycle for connection recycling (default QueuePool) + assert call_kwargs["pool_recycle"] == DEFAULT_POOL_RECYCLE_SECONDS + # Should NOT use NullPool (uses default QueuePool instead) + assert "poolclass" not in call_kwargs def test_init_sets_ssl_mode(self, mock_workspace_client, mock_engine, mock_event_listens_for): """Test that initialization sets SSL mode to require.""" From d1afb16486bae64ff956bf85cb9c940544ace0d1 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 4 Feb 2026 16:33:10 -0800 Subject: [PATCH 05/17] update to asyncdatabrickssession, class-level cache for engine, rename LakebasePoolBase to LakebaseBase --- .../src/databricks_openai/agents/__init__.py | 3 +- .../src/databricks_openai/agents/session.py | 180 +++++++--- .../integration_tests/test_memory_session.py | 36 +- .../openai/tests/unit_tests/test_session.py | 315 ++++++++++++++++-- src/databricks_ai_bridge/lakebase.py | 10 +- 5 files changed, 438 insertions(+), 106 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/__init__.py b/integrations/openai/src/databricks_openai/agents/__init__.py index f78fbd25b..2301fe320 100644 --- a/integrations/openai/src/databricks_openai/agents/__init__.py +++ b/integrations/openai/src/databricks_openai/agents/__init__.py @@ -1,3 +1,4 @@ from databricks_openai.agents.mcp_server import McpServer +from databricks_openai.agents.session import AsyncDatabricksSession -__all__ = ["McpServer"] +__all__ = ["McpServer", "AsyncDatabricksSession"] diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 4828d53dc..3c9c76b4b 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -1,21 +1,31 @@ """ -MemorySession - SQLAlchemy-based session storage for Databricks Lakebase. +AsyncDatabricksSession - Async SQLAlchemy-based session storage for Databricks Lakebase. -This module provides a MemorySession class that subclasses OpenAI's SQLAlchemySession +This module provides a AsyncDatabricksSession class that subclasses OpenAI's SQLAlchemySession to provide persistent conversation history storage in Databricks Lakebase. +Note: + This class is **async-only** as it follows the Session Protocol. Use within async context + https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.Session + Usage:: - from databricks_openai.agents.session import MemorySession + import asyncio + from databricks_openai.agents import AsyncDatabricksSession from agents import Agent, Runner - session = MemorySession( - session_id="user-123", - instance_name="my-lakebase-instance", - ) - agent = Agent(name="Assistant") - result = await Runner.run(agent, "Hello!", session=session) + async def main(): + session = AsyncDatabricksSession( + session_id="user-123", + instance_name="my-lakebase-instance", + ) + + agent = Agent(name="Assistant") + result = await Runner.run(agent, "Hello!", session=session) + + + asyncio.run(main()) """ from __future__ import annotations @@ -31,12 +41,12 @@ try: from agents.extensions.memory import SQLAlchemySession from databricks.sdk import WorkspaceClient - from databricks_ai_bridge.lakebase import _LakebasePoolBase + from databricks_ai_bridge.lakebase import _LakebaseBase from sqlalchemy import URL, event from sqlalchemy.ext.asyncio import create_async_engine except ImportError as e: raise ImportError( - "MemorySession requires databricks-openai[memory]. " + "AsyncDatabricksSession requires databricks-openai[memory]. " "Please install with: pip install databricks-openai[memory]" ) from e @@ -50,9 +60,9 @@ DEFAULT_DATABASE = "databricks_postgres" -class _LakebaseCredentials(_LakebasePoolBase): +class _LakebaseCredentials(_LakebaseBase): """ - Lightweight credential provider that reuses _LakebasePoolBase for: + Lightweight credential provider that reuses _LakebaseBase for: - Instance name → host resolution - Username inference from workspace client - Token minting and caching @@ -85,38 +95,57 @@ def get_token(self) -> str: return token -class MemorySession(SQLAlchemySession): +class AsyncDatabricksSession(SQLAlchemySession): """ - OpenAI Agents SDK Session implementation for Databricks Lakebase. + Async OpenAI Agents SDK Session implementation for Databricks Lakebase. This class subclasses SQLAlchemySession to provide: - Lakebase instance resolution - OAuth token rotation for authentication + - Connection pooling with automatic token refresh - SQL logic inherited from SQLAlchemySession + Note: + This class is **async-only**. All session methods (get_items, add_items, + clear_session, etc.) are coroutines and must be awaited. + + Note: + Engines are cached and reused across sessions with the same instance_name. + This means multiple AsyncDatabricksSession instances share a single connection pool, + rather than creating a new pool per session + The session stores conversation history in two tables: - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) - agent_messages: Stores conversation items (id, session_id, message_data, created_at) Example: ```python - from databricks_openai.agents.session import MemorySession + import asyncio + from databricks_openai.agents import AsyncDatabricksSession from agents import Agent, Runner - async def run_agent(session_id: str, message: str): - session = MemorySession( - session_id=session_id, + async def main(): + session = AsyncDatabricksSession( + session_id="user-123", instance_name="my-lakebase-instance", ) agent = Agent(name="Assistant") - return await Runner.run(agent, message, session=session) + result = await Runner.run(agent, "Hello!", session=session) + + + asyncio.run(main()) ``` For more information on the Session protocol, see: https://openai.github.io/openai-agents-python/ref/memory/session/ """ + # Class-level cache for engines and credentials, keyed by instance_name. + # This allows multiple AsyncDatabricksSession instances to share a single engine/pool. + _engine_cache: "dict[str, tuple[AsyncEngine, _LakebaseCredentials]]" = {} + _engine_cache_lock = Lock() + def __init__( self, session_id: str, @@ -130,7 +159,7 @@ def __init__( **engine_kwargs, ) -> None: """ - Initialize a MemorySession for Databricks Lakebase. + Initialize a AsyncDatabricksSession for Databricks Lakebase. Args: session_id: Unique identifier for the conversation session. @@ -148,14 +177,13 @@ def __init__( **engine_kwargs: Additional keyword arguments passed to SQLAlchemy's create_async_engine(). """ - # Create credential provider - self._credentials = _LakebaseCredentials( + engine, credentials = self._get_or_create_engine( instance_name=instance_name, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, + **engine_kwargs, ) - - engine = self._create_engine(**engine_kwargs) + self._credentials = credentials # Initialize parent SQLAlchemySession - inherits all SQL logic super().__init__( @@ -167,26 +195,88 @@ def __init__( ) logger.info( - "MemorySession initialized: instance=%s host=%s session_id=%s", + "AsyncDatabricksSession initialized: instance=%s host=%s session_id=%s", instance_name, self._credentials.host, session_id, ) - def _create_engine(self, **engine_kwargs) -> "AsyncEngine": - """Create an AsyncEngine with do_connect event for token injection.""" - # https://docs.sqlalchemy.org/en/21/core/engines.html#creating-urls-programmatically - url = URL.create( + @classmethod + def _get_or_create_engine( + cls, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient], + token_cache_duration_seconds: int, + **engine_kwargs, + ) -> "tuple[AsyncEngine, _LakebaseCredentials]": + """Get cached engine or create a new one (thread-safe). + + Engines are cached by instance_name so multiple sessions can share + the same connection pool. + """ + with cls._engine_cache_lock: + if instance_name in cls._engine_cache: + logger.debug("Reusing cached engine for instance=%s", instance_name) + return cls._engine_cache[instance_name] + + credentials = _LakebaseCredentials( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + ) + + engine = cls._create_engine(credentials, **engine_kwargs) + cls._engine_cache[instance_name] = (engine, credentials) + logger.info( + "Created new engine for instance=%s host=%s", + instance_name, + credentials.host, + ) + + return engine, credentials + + @classmethod + def clear_engine_cache(cls, instance_name: Optional[str] = None) -> None: + """Clear cached engines. + + Args: + instance_name: If provided, only clear the engine for this instance. + If None, clear all cached engines. + + Note: + This does not close the engines. Use this when you need to force + creation of a new engine with different settings. + """ + with cls._engine_cache_lock: + if instance_name is not None: + cls._engine_cache.pop(instance_name, None) + logger.info("Cleared engine cache for instance=%s", instance_name) + else: + cls._engine_cache.clear() + logger.info("Cleared all engine caches") + + @staticmethod + def _create_url(credentials: _LakebaseCredentials): + """Create a SQLAlchemy URL for Lakebase connection. + + https://docs.sqlalchemy.org/en/21/core/engines.html#creating-urls-programmatically + """ + return URL.create( drivername="postgresql+psycopg", - username=self._credentials.username, - host=self._credentials.host, + username=credentials.username, + host=credentials.host, port=DEFAULT_PORT, database=DEFAULT_DATABASE, ) - # Use default QueuePool with connection recycling. - # Connections are recycled before token cache expires (50 min), - # ensuring fresh tokens are injected via do_connect event. + @staticmethod + def _create_engine( + credentials: _LakebaseCredentials, **engine_kwargs + ) -> "AsyncEngine": + """Create an AsyncEngine with do_connect event for token injection.""" + url = AsyncDatabricksSession._create_url(credentials) + engine = create_async_engine( url, pool_recycle=DEFAULT_POOL_RECYCLE_SECONDS, @@ -194,10 +284,9 @@ def _create_engine(self, **engine_kwargs) -> "AsyncEngine": **engine_kwargs, ) - # Attach event to inject Lakebase token before each connection - # Note: do_connect fires on sync_engine even for async operations - credentials = self._credentials - + # AsyncEngine wraps a sync Engine internally - connection events like + # do_connect must be registered on sync_engine, not the async wrapper. + # https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#using-events-with-the-asyncio-extension @event.listens_for(engine.sync_engine, "do_connect") def inject_lakebase_token(dialect, conn_rec, cargs, cparams): cparams["password"] = credentials.get_token() @@ -220,15 +309,4 @@ def username(self) -> str: """The database username.""" return self._credentials.username - @property - def connection_url(self) -> str: - """The SQLAlchemy connection URL (without password, for debugging).""" - url = URL.create( - drivername="postgresql+psycopg", - username=self._credentials.username, - host=self._credentials.host, - port=DEFAULT_PORT, - database=DEFAULT_DATABASE, - query={"sslmode": DEFAULT_SSLMODE}, - ) - return url.render_as_string(hide_password=True) + diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py index b7de1912a..273bf8843 100644 --- a/integrations/openai/tests/integration_tests/test_memory_session.py +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -1,5 +1,5 @@ """ -Integration tests for MemorySession. +Integration tests for AsyncDatabricksSession. These tests require: 1. A Lakebase instance to be available @@ -59,14 +59,14 @@ def cleanup_tables(): # ============================================================================= -# MemorySession Tests +# AsyncDatabricksSession Tests # ============================================================================= @pytest.mark.asyncio async def test_memory_session_crud_operations(cleanup_tables): """ - Comprehensive CRUD test for MemorySession. + Comprehensive CRUD test for AsyncDatabricksSession. Tests the full lifecycle: - clear_session() on fresh session @@ -77,13 +77,13 @@ async def test_memory_session_crud_operations(cleanup_tables): - pop_item() removes and returns most recent item - clear_session() removes all items """ - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents import AsyncDatabricksSession sessions_table, messages_table = get_unique_table_names() cleanup_tables.append((sessions_table, messages_table)) session_id = str(uuid.uuid4()) - session = MemorySession( + session = AsyncDatabricksSession( session_id=session_id, instance_name=get_instance_name(), sessions_table=sessions_table, @@ -136,7 +136,7 @@ async def test_memory_session_crud_operations(cleanup_tables): @pytest.mark.asyncio async def test_memory_session_multiple_sessions_isolated(cleanup_tables): """Test that different session_ids have isolated data.""" - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents import AsyncDatabricksSession sessions_table, messages_table = get_unique_table_names() cleanup_tables.append((sessions_table, messages_table)) @@ -144,14 +144,14 @@ async def test_memory_session_multiple_sessions_isolated(cleanup_tables): session_id_1 = str(uuid.uuid4()) session_id_2 = str(uuid.uuid4()) - session_1 = MemorySession( + session_1 = AsyncDatabricksSession( session_id=session_id_1, instance_name=get_instance_name(), sessions_table=sessions_table, messages_table=messages_table, ) - session_2 = MemorySession( + session_2 = AsyncDatabricksSession( session_id=session_id_2, instance_name=get_instance_name(), sessions_table=sessions_table, @@ -188,12 +188,12 @@ async def test_memory_session_multiple_sessions_isolated(cleanup_tables): @pytest.mark.asyncio async def test_memory_session_pop_empty_returns_none(cleanup_tables): """Test that pop_item returns None on empty session.""" - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents import AsyncDatabricksSession sessions_table, messages_table = get_unique_table_names() cleanup_tables.append((sessions_table, messages_table)) - session = MemorySession( + session = AsyncDatabricksSession( session_id=str(uuid.uuid4()), instance_name=get_instance_name(), sessions_table=sessions_table, @@ -208,12 +208,12 @@ async def test_memory_session_pop_empty_returns_none(cleanup_tables): @pytest.mark.asyncio async def test_memory_session_add_empty_items_noop(cleanup_tables): """Test that add_items with empty list is a no-op.""" - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents import AsyncDatabricksSession sessions_table, messages_table = get_unique_table_names() cleanup_tables.append((sessions_table, messages_table)) - session = MemorySession( + session = AsyncDatabricksSession( session_id=str(uuid.uuid4()), instance_name=get_instance_name(), sessions_table=sessions_table, @@ -231,12 +231,12 @@ async def test_memory_session_add_empty_items_noop(cleanup_tables): @pytest.mark.asyncio async def test_memory_session_complex_message_data(cleanup_tables): """Test storing complex message data with nested structures.""" - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents import AsyncDatabricksSession sessions_table, messages_table = get_unique_table_names() cleanup_tables.append((sessions_table, messages_table)) - session = MemorySession( + session = AsyncDatabricksSession( session_id=str(uuid.uuid4()), instance_name=get_instance_name(), sessions_table=sessions_table, @@ -278,13 +278,13 @@ async def test_memory_session_complex_message_data(cleanup_tables): @pytest.mark.asyncio async def test_memory_session_properties(cleanup_tables): """Test that session properties return correct values.""" - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents import AsyncDatabricksSession sessions_table, messages_table = get_unique_table_names() cleanup_tables.append((sessions_table, messages_table)) session_id = str(uuid.uuid4()) - session = MemorySession( + session = AsyncDatabricksSession( session_id=session_id, instance_name=get_instance_name(), sessions_table=sessions_table, @@ -303,12 +303,12 @@ async def test_memory_session_properties(cleanup_tables): @pytest.mark.asyncio async def test_memory_session_get_items_ordering(cleanup_tables): """Test that get_items returns items in correct chronological order.""" - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents import AsyncDatabricksSession sessions_table, messages_table = get_unique_table_names() cleanup_tables.append((sessions_table, messages_table)) - session = MemorySession( + session = AsyncDatabricksSession( session_id=str(uuid.uuid4()), instance_name=get_instance_name(), sessions_table=sessions_table, diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 25f93569e..72053d111 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1,20 +1,38 @@ -"""Unit tests for MemorySession.""" +"""Unit tests for AsyncDatabricksSession.""" from unittest.mock import MagicMock, patch import pytest try: - from databricks_ai_bridge.lakebase import _LakebasePoolBase # noqa: F401 + from databricks_ai_bridge.lakebase import _LakebaseBase # noqa: F401 from psycopg.rows import DictRow # noqa: F401 from sqlalchemy.ext.asyncio import AsyncEngine # noqa: F401 except ImportError as e: raise ImportError( - "MemorySession tests require databricks-openai[memory]. " + "AsyncDatabricksSession tests require databricks-openai[memory]. " "Please install with: pip install databricks-openai[memory]" ) from e +@pytest.fixture(autouse=True) +def clear_engine_cache(): + """Clear AsyncDatabricksSession engine cache before each test.""" + try: + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession.clear_engine_cache() + except ImportError: + pass + yield + try: + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession.clear_engine_cache() + except ImportError: + pass + + @pytest.fixture def mock_workspace_client(): """Create a mock WorkspaceClient.""" @@ -182,8 +200,8 @@ def test_init_raises_on_invalid_instance(self, mock_workspace_client): ) -class TestMemorySessionInit: - """Tests for MemorySession initialization.""" +class TestAsyncDatabricksSessionInit: + """Tests for AsyncDatabricksSession initialization.""" def test_init_creates_engine_with_correct_url( self, mock_workspace_client, mock_engine, mock_event_listens_for @@ -203,9 +221,9 @@ def test_init_creates_engine_with_correct_url( side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -241,10 +259,10 @@ def test_init_uses_pool_recycle( ): from databricks_openai.agents.session import ( DEFAULT_POOL_RECYCLE_SECONDS, - MemorySession, + AsyncDatabricksSession, ) - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -272,9 +290,9 @@ def test_init_sets_ssl_mode(self, mock_workspace_client, mock_engine, mock_event side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -296,9 +314,9 @@ def test_init_registers_do_connect_event(self, mock_workspace_client, mock_engin ), patch("databricks_openai.agents.session.event.listens_for") as mock_listens_for, ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -329,9 +347,9 @@ def test_init_with_custom_table_names( return_value=None, ) as mock_parent_init, ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -366,9 +384,9 @@ def test_init_with_create_tables_false( return_value=None, ) as mock_parent_init, ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -379,8 +397,8 @@ def test_init_with_create_tables_false( assert call_kwargs["create_tables"] is False -class TestMemorySessionProperties: - """Tests for MemorySession properties.""" +class TestAsyncDatabricksSessionProperties: + """Tests for AsyncDatabricksSession properties.""" def test_instance_name_property( self, mock_workspace_client, mock_engine, mock_event_listens_for @@ -400,9 +418,9 @@ def test_instance_name_property( side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - session = MemorySession( + session = AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -426,9 +444,9 @@ def test_host_property(self, mock_workspace_client, mock_engine, mock_event_list side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - session = MemorySession( + session = AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -452,9 +470,9 @@ def test_username_property(self, mock_workspace_client, mock_engine, mock_event_ side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - session = MemorySession( + session = AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -463,7 +481,7 @@ def test_username_property(self, mock_workspace_client, mock_engine, mock_event_ assert session.username == "test_user@databricks.com" -class TestMemorySessionTokenInjection: +class TestAsyncDatabricksSessionTokenInjection: """Tests for token injection via do_connect event.""" def test_do_connect_injects_token(self, mock_workspace_client, mock_engine): @@ -492,9 +510,9 @@ def decorator(fn): side_effect=capture_handler, ), ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -509,7 +527,7 @@ def decorator(fn): assert cparams["password"] == "test-oauth-token" -class TestMemorySessionEngineKwargs: +class TestAsyncDatabricksSessionEngineKwargs: """Tests for passing additional engine kwargs.""" def test_extra_engine_kwargs_passed( @@ -530,9 +548,9 @@ def test_extra_engine_kwargs_passed( side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import MemorySession + from databricks_openai.agents.session import AsyncDatabricksSession - MemorySession( + AsyncDatabricksSession( session_id="test-session-123", instance_name="test-instance", workspace_client=mock_workspace_client, @@ -543,3 +561,238 @@ def test_extra_engine_kwargs_passed( call_kwargs = mock_create_engine.call_args[1] assert call_kwargs.get("echo") is True assert call_kwargs.get("pool_pre_ping") is True + + +class TestAsyncDatabricksSessionEngineCaching: + """Tests for engine caching behavior.""" + + def test_sessions_share_engine_for_same_instance( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that multiple sessions with the same instance_name share an engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + # Create two sessions with the same instance_name + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + + # Both sessions should reference the same engine + assert session1._engine is session2._engine + + def test_different_instances_get_different_engines( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that sessions with different instance_names get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + # Create sessions with different instance_names + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="instance-a", + workspace_client=mock_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="instance-b", + workspace_client=mock_workspace_client, + ) + + # Engine should be created twice (once per instance) + assert mock_create_engine.call_count == 2 + + # Sessions should have different engines + assert session1._engine is not session2._engine + + def test_clear_engine_cache( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that clear_engine_cache forces new engine creation.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + # Create first session + AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + assert mock_create_engine.call_count == 1 + + # Clear cache + AsyncDatabricksSession.clear_engine_cache() + + # Create second session - should create new engine + AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + assert mock_create_engine.call_count == 2 + + +class TestAsyncDatabricksSessionAsyncOnly: + """Tests verifying AsyncDatabricksSession is async-only.""" + + def test_get_items_returns_coroutine_without_await( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that calling get_items() without await returns a coroutine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + import inspect + + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Calling without await returns a coroutine, not actual data + result = session.get_items() + assert inspect.iscoroutine(result) + + # Clean up the coroutine to avoid RuntimeWarning + result.close() + + def test_add_items_returns_coroutine_without_await( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that calling add_items() without await returns a coroutine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + import inspect + + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Calling without await returns a coroutine, not actual data + result = session.add_items([{"role": "user", "content": "test"}]) + assert inspect.iscoroutine(result) + + # Clean up the coroutine to avoid RuntimeWarning + result.close() + + def test_methods_are_coroutine_functions( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that all session methods are async (coroutine functions).""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "databricks_openai.agents.session.create_async_engine", + return_value=mock_engine, + ), + patch( + "databricks_openai.agents.session.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + import inspect + + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # All data methods should be async + assert inspect.iscoroutinefunction(session.get_items) + assert inspect.iscoroutinefunction(session.add_items) + assert inspect.iscoroutinefunction(session.pop_item) + assert inspect.iscoroutinefunction(session.clear_session) diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 1cb0e83fa..6969d40e2 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -81,12 +81,12 @@ class SequencePrivilege(str, Enum): ALL = "ALL" # Renders as ALL PRIVILEGES -class _LakebasePoolBase: +class _LakebaseBase: """ - Base logic for Lakebase connection pools: resolve host, infer username, + Base class for Lakebase connections: resolve host, infer username, token cache + minting, and conninfo building. - Subclasses implement pool-specific initialization and lifecycle methods. + Subclasses implement specific initialization and lifecycle methods. """ def __init__( @@ -168,7 +168,7 @@ def _infer_username(self) -> str: raise ValueError("Unable to infer username for Lakebase connection.") -class LakebasePool(_LakebasePoolBase): +class LakebasePool(_LakebaseBase): """Sync Lakebase connection pool built on psycopg with rotating credentials. instance_name: Name of Lakebase Instance @@ -262,7 +262,7 @@ def close(self) -> None: self._pool.close() -class AsyncLakebasePool(_LakebasePoolBase): +class AsyncLakebasePool(_LakebaseBase): """Async Lakebase connection pool built on psycopg with rotating credentials. instance_name: Name of Lakebase Instance From d7c4739c391754e9438ac59d5beebb4d87aaaaf3 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 4 Feb 2026 17:05:15 -0800 Subject: [PATCH 06/17] update dependnecies needed in openai package --- integrations/openai/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 36887a7b1..b7014cf7e 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ [project.optional-dependencies] memory = [ "databricks-ai-bridge[memory]>=0.13.0", + "sqlalchemy[asyncio]>=2.0.0", ] [dependency-groups] From 5aa6d1a210c1c26cc3d686364a22f31034148706 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 4 Feb 2026 17:11:15 -0800 Subject: [PATCH 07/17] format --- integrations/openai/src/databricks_openai/agents/session.py | 6 +----- integrations/openai/tests/unit_tests/test_session.py | 4 +--- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 3c9c76b4b..9490867e1 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -271,9 +271,7 @@ def _create_url(credentials: _LakebaseCredentials): ) @staticmethod - def _create_engine( - credentials: _LakebaseCredentials, **engine_kwargs - ) -> "AsyncEngine": + def _create_engine(credentials: _LakebaseCredentials, **engine_kwargs) -> "AsyncEngine": """Create an AsyncEngine with do_connect event for token injection.""" url = AsyncDatabricksSession._create_url(credentials) @@ -308,5 +306,3 @@ def host(self) -> str: def username(self) -> str: """The database username.""" return self._credentials.username - - diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 72053d111..262390139 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -650,9 +650,7 @@ def test_different_instances_get_different_engines( # Sessions should have different engines assert session1._engine is not session2._engine - def test_clear_engine_cache( - self, mock_workspace_client, mock_engine, mock_event_listens_for - ): + def test_clear_engine_cache(self, mock_workspace_client, mock_engine, mock_event_listens_for): """Test that clear_engine_cache forces new engine creation.""" with ( patch( From bbb0b0e38c5bfaae35e715c87015a7f4770a694b Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 4 Feb 2026 17:17:43 -0800 Subject: [PATCH 08/17] integrations/openai typecheck --- integrations/openai/tests/unit_tests/test_clients.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index df44f276e..07bc689d5 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -54,6 +54,7 @@ def test_bearer_auth_flow(self, mock_workspace_client): request = Request("GET", "https://test.databricks.com/api/test") # Authenticate the request + assert http_client.auth is not None auth_flow = http_client.auth.auth_flow(request) authenticated_request = next(auth_flow) @@ -98,6 +99,7 @@ def test_bearer_auth_flow(self, mock_workspace_client): request = Request("GET", "https://test.databricks.com/api/test") # Authenticate the request + assert http_client.auth is not None auth_flow = http_client.auth.auth_flow(request) authenticated_request = next(auth_flow) @@ -200,7 +202,7 @@ def test_chat_completions_strips_strict_for_claude(self): with patch.object(Completions, "create") as mock_create: mock_create.return_value = MagicMock() - tools = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], @@ -225,7 +227,7 @@ def test_chat_completions_preserves_strict_for_gpt(self): with patch.object(Completions, "create") as mock_create: mock_create.return_value = MagicMock() - tools = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], @@ -275,7 +277,7 @@ async def test_chat_completions_strips_strict_for_claude(self): client = AsyncDatabricksOpenAI() with patch.object(AsyncCompletions, "create", new_callable=AsyncMock) as mock_create: - tools = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] await client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], @@ -300,7 +302,7 @@ async def test_chat_completions_preserves_strict_for_gpt(self): client = AsyncDatabricksOpenAI() with patch.object(AsyncCompletions, "create", new_callable=AsyncMock) as mock_create: - tools = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] await client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], From e1ff83c3050a75f800a9c88681aa24b81e940f21 Mon Sep 17 00:00:00 2001 From: Jenny Date: Wed, 4 Feb 2026 17:20:59 -0800 Subject: [PATCH 09/17] Linting / typechecking for integrations/openai --- .../openai/tests/unit_tests/test_clients.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index 07bc689d5..a112c3552 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -202,11 +202,11 @@ def test_chat_completions_strips_strict_for_claude(self): with patch.object(Completions, "create") as mock_create: mock_create.return_value = MagicMock() - tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools = [{"type": "function", "function": {"name": "test", "strict": True}}] client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -227,11 +227,11 @@ def test_chat_completions_preserves_strict_for_gpt(self): with patch.object(Completions, "create") as mock_create: mock_create.return_value = MagicMock() - tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools = [{"type": "function", "function": {"name": "test", "strict": True}}] client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -277,11 +277,11 @@ async def test_chat_completions_strips_strict_for_claude(self): client = AsyncDatabricksOpenAI() with patch.object(AsyncCompletions, "create", new_callable=AsyncMock) as mock_create: - tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools = [{"type": "function", "function": {"name": "test", "strict": True}}] await client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -302,11 +302,11 @@ async def test_chat_completions_preserves_strict_for_gpt(self): client = AsyncDatabricksOpenAI() with patch.object(AsyncCompletions, "create", new_callable=AsyncMock) as mock_create: - tools: Any = [{"type": "function", "function": {"name": "test", "strict": True}}] + tools = [{"type": "function", "function": {"name": "test", "strict": True}}] await client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs From b9b2cf8f128db5eddbb0deeed9498ef4178cc8c0 Mon Sep 17 00:00:00 2001 From: Jenny Date: Thu, 5 Feb 2026 11:44:10 -0800 Subject: [PATCH 10/17] add dependency to test dev --- integrations/openai/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index b7014cf7e..262546b9f 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -41,6 +41,7 @@ tests = [ "pytest-timeout>=2.3.1", "pytest-cov>=4.1.0", "databricks-ai-bridge[memory]>=0.13.0", + "sqlalchemy[asyncio]>=2.0.0", ] [build-system] From 04b7e4639913a31d852e565d1d4986969247fb43 Mon Sep 17 00:00:00 2001 From: Jenny Date: Thu, 5 Feb 2026 17:51:27 -0800 Subject: [PATCH 11/17] refactor AsyncDatabricksSession to use AsyncLakebaseSqlAlchemy session that returns engines to manage connections to db --- .../src/databricks_openai/agents/session.py | 192 ++--------- .../integration_tests/test_memory_session.py | 25 -- .../openai/tests/unit_tests/test_session.py | 314 +++++++----------- pyproject.toml | 1 + src/databricks_ai_bridge/lakebase.py | 138 +++++++- tests/databricks_ai_bridge/test_lakebase.py | 4 +- 6 files changed, 292 insertions(+), 382 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 9490867e1..6d170f5ff 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -31,19 +31,17 @@ async def main(): from __future__ import annotations import logging -import time from threading import Lock -from typing import TYPE_CHECKING, Optional - -if TYPE_CHECKING: - from sqlalchemy.ext.asyncio import AsyncEngine +from typing import Optional try: from agents.extensions.memory import SQLAlchemySession from databricks.sdk import WorkspaceClient - from databricks_ai_bridge.lakebase import _LakebaseBase - from sqlalchemy import URL, event - from sqlalchemy.ext.asyncio import create_async_engine + from databricks_ai_bridge.lakebase import ( + DEFAULT_POOL_RECYCLE_SECONDS, + DEFAULT_TOKEN_CACHE_DURATION_SECONDS, + AsyncLakebaseSQLAlchemy, + ) except ImportError as e: raise ImportError( "AsyncDatabricksSession requires databricks-openai[memory]. " @@ -52,68 +50,17 @@ async def main(): logger = logging.getLogger(__name__) -# Constants for Lakebase connection -DEFAULT_TOKEN_CACHE_DURATION_SECONDS = 50 * 60 # 50 minutes -DEFAULT_POOL_RECYCLE_SECONDS = 45 * 60 # 45 minutes (before token cache expires) -DEFAULT_SSLMODE = "require" -DEFAULT_PORT = 5432 -DEFAULT_DATABASE = "databricks_postgres" - - -class _LakebaseCredentials(_LakebaseBase): - """ - Lightweight credential provider that reuses _LakebaseBase for: - - Instance name → host resolution - - Username inference from workspace client - - Token minting and caching - - Does NOT create a connection pool - just provides credentials for SQLAlchemy. - """ - - def __init__( - self, - *, - instance_name: str, - workspace_client: Optional[WorkspaceClient] = None, - token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, - ) -> None: - super().__init__( - instance_name=instance_name, - workspace_client=workspace_client, - token_cache_duration_seconds=token_cache_duration_seconds, - ) - self._cache_lock = Lock() - - def get_token(self) -> str: - """Get cached token or mint a new one (thread-safe).""" - with self._cache_lock: - if cached := self._get_cached_token(): - return cached - token = self._mint_token() - self._cached_token = token - self._cache_ts = time.time() - return token - class AsyncDatabricksSession(SQLAlchemySession): """ Async OpenAI Agents SDK Session implementation for Databricks Lakebase. - - This class subclasses SQLAlchemySession to provide: - - Lakebase instance resolution - - OAuth token rotation for authentication - - Connection pooling with automatic token refresh - - SQL logic inherited from SQLAlchemySession + For more information on the Session protocol, see: + https://openai.github.io/openai-agents-python/ref/memory/session/ Note: This class is **async-only**. All session methods (get_items, add_items, clear_session, etc.) are coroutines and must be awaited. - Note: - Engines are cached and reused across sessions with the same instance_name. - This means multiple AsyncDatabricksSession instances share a single connection pool, - rather than creating a new pool per session - The session stores conversation history in two tables: - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) - agent_messages: Stores conversation items (id, session_id, message_data, created_at) @@ -136,15 +83,12 @@ async def main(): asyncio.run(main()) ``` - - For more information on the Session protocol, see: - https://openai.github.io/openai-agents-python/ref/memory/session/ """ - # Class-level cache for engines and credentials, keyed by instance_name. + # Class-level cache for AsyncLakebaseSQLAlchemy instances, keyed by instance_name. # This allows multiple AsyncDatabricksSession instances to share a single engine/pool. - _engine_cache: "dict[str, tuple[AsyncEngine, _LakebaseCredentials]]" = {} - _engine_cache_lock = Lock() + _lakebase_sql_alchemy_cache: dict[str, AsyncLakebaseSQLAlchemy] = {} + _lakebase_sql_alchemy_cache_lock = Lock() def __init__( self, @@ -167,7 +111,7 @@ def __init__( workspace_client: Optional WorkspaceClient for authentication. If not provided, a default client will be created. token_cache_duration_seconds: How long to cache OAuth tokens. - Defaults to 50 minutes. + Defaults to 15 minutes. create_tables: Whether to auto-create tables on first use. Defaults to True. sessions_table: Name of the sessions table. @@ -177,132 +121,52 @@ def __init__( **engine_kwargs: Additional keyword arguments passed to SQLAlchemy's create_async_engine(). """ - engine, credentials = self._get_or_create_engine( + self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, + pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS), **engine_kwargs, ) - self._credentials = credentials # Initialize parent SQLAlchemySession - inherits all SQL logic super().__init__( session_id=session_id, - engine=engine, + engine=self._lakebase.engine, create_tables=create_tables, sessions_table=sessions_table, messages_table=messages_table, ) logger.info( - "AsyncDatabricksSession initialized: instance=%s host=%s session_id=%s", + "AsyncDatabricksSession initialized: instance=%s session_id=%s", instance_name, - self._credentials.host, session_id, ) @classmethod - def _get_or_create_engine( + def _get_or_create_lakebase( cls, *, instance_name: str, workspace_client: Optional[WorkspaceClient], token_cache_duration_seconds: int, + pool_recycle: int, **engine_kwargs, - ) -> "tuple[AsyncEngine, _LakebaseCredentials]": - """Get cached engine or create a new one (thread-safe). - - Engines are cached by instance_name so multiple sessions can share - the same connection pool. - """ - with cls._engine_cache_lock: - if instance_name in cls._engine_cache: + ) -> AsyncLakebaseSQLAlchemy: + """Get cached AsyncLakebaseSQLAlchemy or create a new one (thread-safe).""" + with cls._lakebase_sql_alchemy_cache_lock: + if instance_name in cls._lakebase_sql_alchemy_cache: logger.debug("Reusing cached engine for instance=%s", instance_name) - return cls._engine_cache[instance_name] + return cls._lakebase_sql_alchemy_cache[instance_name] - credentials = _LakebaseCredentials( + lakebase = AsyncLakebaseSQLAlchemy( instance_name=instance_name, workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, + pool_recycle=pool_recycle, + **engine_kwargs, ) + cls._lakebase_sql_alchemy_cache[instance_name] = lakebase + return lakebase - engine = cls._create_engine(credentials, **engine_kwargs) - cls._engine_cache[instance_name] = (engine, credentials) - logger.info( - "Created new engine for instance=%s host=%s", - instance_name, - credentials.host, - ) - - return engine, credentials - - @classmethod - def clear_engine_cache(cls, instance_name: Optional[str] = None) -> None: - """Clear cached engines. - - Args: - instance_name: If provided, only clear the engine for this instance. - If None, clear all cached engines. - - Note: - This does not close the engines. Use this when you need to force - creation of a new engine with different settings. - """ - with cls._engine_cache_lock: - if instance_name is not None: - cls._engine_cache.pop(instance_name, None) - logger.info("Cleared engine cache for instance=%s", instance_name) - else: - cls._engine_cache.clear() - logger.info("Cleared all engine caches") - - @staticmethod - def _create_url(credentials: _LakebaseCredentials): - """Create a SQLAlchemy URL for Lakebase connection. - - https://docs.sqlalchemy.org/en/21/core/engines.html#creating-urls-programmatically - """ - return URL.create( - drivername="postgresql+psycopg", - username=credentials.username, - host=credentials.host, - port=DEFAULT_PORT, - database=DEFAULT_DATABASE, - ) - - @staticmethod - def _create_engine(credentials: _LakebaseCredentials, **engine_kwargs) -> "AsyncEngine": - """Create an AsyncEngine with do_connect event for token injection.""" - url = AsyncDatabricksSession._create_url(credentials) - - engine = create_async_engine( - url, - pool_recycle=DEFAULT_POOL_RECYCLE_SECONDS, - connect_args={"sslmode": DEFAULT_SSLMODE}, - **engine_kwargs, - ) - - # AsyncEngine wraps a sync Engine internally - connection events like - # do_connect must be registered on sync_engine, not the async wrapper. - # https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#using-events-with-the-asyncio-extension - @event.listens_for(engine.sync_engine, "do_connect") - def inject_lakebase_token(dialect, conn_rec, cargs, cparams): - cparams["password"] = credentials.get_token() - logger.debug("Injected Lakebase token for connection") - - return engine - - @property - def instance_name(self) -> str: - """The Lakebase instance name.""" - return self._credentials.instance_name - - @property - def host(self) -> str: - """The resolved Lakebase host.""" - return self._credentials.host - - @property - def username(self) -> str: - """The database username.""" - return self._credentials.username diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py index 273bf8843..bc6307422 100644 --- a/integrations/openai/tests/integration_tests/test_memory_session.py +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -275,31 +275,6 @@ async def test_memory_session_complex_message_data(cleanup_tables): await session.clear_session() -@pytest.mark.asyncio -async def test_memory_session_properties(cleanup_tables): - """Test that session properties return correct values.""" - from databricks_openai.agents import AsyncDatabricksSession - - sessions_table, messages_table = get_unique_table_names() - cleanup_tables.append((sessions_table, messages_table)) - - session_id = str(uuid.uuid4()) - session = AsyncDatabricksSession( - session_id=session_id, - instance_name=get_instance_name(), - sessions_table=sessions_table, - messages_table=messages_table, - ) - - # Verify properties - assert session.instance_name == get_instance_name() - assert session.host is not None - assert len(session.host) > 0 - assert session.username is not None - assert len(session.username) > 0 - assert session.session_id == session_id - - @pytest.mark.asyncio async def test_memory_session_get_items_ordering(cleanup_tables): """Test that get_items returns items in correct chronological order.""" diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 262390139..5f8d880e6 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -21,14 +21,14 @@ def clear_engine_cache(): try: from databricks_openai.agents.session import AsyncDatabricksSession - AsyncDatabricksSession.clear_engine_cache() + AsyncDatabricksSession._lakebase_sql_alchemy_cache.clear() except ImportError: pass yield try: from databricks_openai.agents.session import AsyncDatabricksSession - AsyncDatabricksSession.clear_engine_cache() + AsyncDatabricksSession._lakebase_sql_alchemy_cache.clear() except ImportError: pass @@ -78,78 +78,106 @@ def decorator(fn): return create_decorator -class TestLakebaseCredentials: - """Tests for _LakebaseCredentials class.""" +class TestAsyncLakebaseSQLAlchemy: + """Tests for AsyncLakebaseSQLAlchemy class.""" def test_init_resolves_host(self, mock_workspace_client): """Test that initialization resolves the Lakebase host.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), ): - from databricks_openai.agents.session import _LakebaseCredentials + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy - creds = _LakebaseCredentials( + lakebase = AsyncLakebaseSQLAlchemy( instance_name="test-instance", workspace_client=mock_workspace_client, ) - assert creds.host == "test-instance.lakebase.databricks.com" + assert lakebase.host == "test-instance.lakebase.databricks.com" mock_workspace_client.database.get_database_instance.assert_called_once_with( "test-instance" ) def test_init_infers_username(self, mock_workspace_client): """Test that initialization infers the username.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), ): - from databricks_openai.agents.session import _LakebaseCredentials + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy - creds = _LakebaseCredentials( + lakebase = AsyncLakebaseSQLAlchemy( instance_name="test-instance", workspace_client=mock_workspace_client, ) - assert creds.username == "test_user@databricks.com" + assert lakebase.username == "test_user@databricks.com" mock_workspace_client.current_user.me.assert_called() def test_get_token_mints_new_token(self, mock_workspace_client): """Test that get_token mints a new token when cache is empty.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), ): - from databricks_openai.agents.session import _LakebaseCredentials + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy - creds = _LakebaseCredentials( + lakebase = AsyncLakebaseSQLAlchemy( instance_name="test-instance", workspace_client=mock_workspace_client, ) - token = creds.get_token() + token = lakebase.get_token() assert token == "test-oauth-token" mock_workspace_client.database.generate_database_credential.assert_called_once() def test_get_token_returns_cached_token(self, mock_workspace_client): """Test that get_token returns cached token when valid.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), ): - from databricks_openai.agents.session import _LakebaseCredentials + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy - creds = _LakebaseCredentials( + lakebase = AsyncLakebaseSQLAlchemy( instance_name="test-instance", workspace_client=mock_workspace_client, ) # First call - mints token - token1 = creds.get_token() + token1 = lakebase.get_token() # Second call - should return cached - token2 = creds.get_token() + token2 = lakebase.get_token() assert token1 == token2 == "test-oauth-token" # Should only mint once @@ -157,26 +185,33 @@ def test_get_token_returns_cached_token(self, mock_workspace_client): def test_get_token_refreshes_expired_token(self, mock_workspace_client): """Test that get_token refreshes token when expired.""" - with patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), ): import time - from databricks_openai.agents.session import _LakebaseCredentials + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy - creds = _LakebaseCredentials( + lakebase = AsyncLakebaseSQLAlchemy( instance_name="test-instance", workspace_client=mock_workspace_client, token_cache_duration_seconds=0, # Immediate expiration ) # First call - creds.get_token() + lakebase.get_token() # Wait a tiny bit time.sleep(0.01) # Second call - should mint new token - creds.get_token() + lakebase.get_token() # Should mint twice due to expiration assert mock_workspace_client.database.generate_database_credential.call_count == 2 @@ -191,14 +226,37 @@ def test_init_raises_on_invalid_instance(self, mock_workspace_client): "databricks_ai_bridge.lakebase.WorkspaceClient", return_value=mock_workspace_client, ): - from databricks_openai.agents.session import _LakebaseCredentials + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy with pytest.raises(ValueError, match="Unable to resolve Lakebase instance"): - _LakebaseCredentials( + AsyncLakebaseSQLAlchemy( instance_name="invalid-instance", workspace_client=mock_workspace_client, ) + def test_engine_property_returns_engine(self, mock_workspace_client): + """Test that engine property returns the created engine.""" + mock_eng = MagicMock(sync_engine=MagicMock()) + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_eng, + ), + ): + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert lakebase.engine is mock_eng + class TestAsyncDatabricksSessionInit: """Tests for AsyncDatabricksSession initialization.""" @@ -213,11 +271,11 @@ def test_init_creates_engine_with_correct_url( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): @@ -249,18 +307,17 @@ def test_init_uses_pool_recycle( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): - from databricks_openai.agents.session import ( - DEFAULT_POOL_RECYCLE_SECONDS, - AsyncDatabricksSession, - ) + from databricks_ai_bridge.lakebase import DEFAULT_POOL_RECYCLE_SECONDS + + from databricks_openai.agents.session import AsyncDatabricksSession AsyncDatabricksSession( session_id="test-session-123", @@ -282,11 +339,11 @@ def test_init_sets_ssl_mode(self, mock_workspace_client, mock_engine, mock_event return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): @@ -309,10 +366,10 @@ def test_init_registers_do_connect_event(self, mock_workspace_client, mock_engin return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ), - patch("databricks_openai.agents.session.event.listens_for") as mock_listens_for, + patch("sqlalchemy.event.listens_for") as mock_listens_for, ): from databricks_openai.agents.session import AsyncDatabricksSession @@ -335,11 +392,11 @@ def test_init_with_custom_table_names( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ), patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), patch( @@ -372,11 +429,11 @@ def test_init_with_create_tables_false( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ), patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), patch( @@ -397,90 +454,6 @@ def test_init_with_create_tables_false( assert call_kwargs["create_tables"] is False -class TestAsyncDatabricksSessionProperties: - """Tests for AsyncDatabricksSession properties.""" - - def test_instance_name_property( - self, mock_workspace_client, mock_engine, mock_event_listens_for - ): - """Test instance_name property returns correct value.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), - patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), - patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - session = AsyncDatabricksSession( - session_id="test-session-123", - instance_name="test-instance", - workspace_client=mock_workspace_client, - ) - - assert session.instance_name == "test-instance" - - def test_host_property(self, mock_workspace_client, mock_engine, mock_event_listens_for): - """Test host property returns resolved host.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), - patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), - patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - session = AsyncDatabricksSession( - session_id="test-session-123", - instance_name="test-instance", - workspace_client=mock_workspace_client, - ) - - assert session.host == "test-instance.lakebase.databricks.com" - - def test_username_property(self, mock_workspace_client, mock_engine, mock_event_listens_for): - """Test username property returns inferred username.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), - patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ), - patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - session = AsyncDatabricksSession( - session_id="test-session-123", - instance_name="test-instance", - workspace_client=mock_workspace_client, - ) - - assert session.username == "test_user@databricks.com" - - class TestAsyncDatabricksSessionTokenInjection: """Tests for token injection via do_connect event.""" @@ -502,11 +475,11 @@ def decorator(fn): return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ), patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=capture_handler, ), ): @@ -540,11 +513,11 @@ def test_extra_engine_kwargs_passed( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): @@ -576,11 +549,11 @@ def test_sessions_share_engine_for_same_instance( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): @@ -622,11 +595,11 @@ def test_different_instances_get_different_engines( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", side_effect=lambda *args, **kwargs: next(engine_iter), ) as mock_create_engine, patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): @@ -650,43 +623,6 @@ def test_different_instances_get_different_engines( # Sessions should have different engines assert session1._engine is not session2._engine - def test_clear_engine_cache(self, mock_workspace_client, mock_engine, mock_event_listens_for): - """Test that clear_engine_cache forces new engine creation.""" - with ( - patch( - "databricks_ai_bridge.lakebase.WorkspaceClient", - return_value=mock_workspace_client, - ), - patch( - "databricks_openai.agents.session.create_async_engine", - return_value=mock_engine, - ) as mock_create_engine, - patch( - "databricks_openai.agents.session.event.listens_for", - side_effect=mock_event_listens_for, - ), - ): - from databricks_openai.agents.session import AsyncDatabricksSession - - # Create first session - AsyncDatabricksSession( - session_id="session-1", - instance_name="test-instance", - workspace_client=mock_workspace_client, - ) - assert mock_create_engine.call_count == 1 - - # Clear cache - AsyncDatabricksSession.clear_engine_cache() - - # Create second session - should create new engine - AsyncDatabricksSession( - session_id="session-2", - instance_name="test-instance", - workspace_client=mock_workspace_client, - ) - assert mock_create_engine.call_count == 2 - class TestAsyncDatabricksSessionAsyncOnly: """Tests verifying AsyncDatabricksSession is async-only.""" @@ -701,11 +637,11 @@ def test_get_items_returns_coroutine_without_await( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ), patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): @@ -736,11 +672,11 @@ def test_add_items_returns_coroutine_without_await( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ), patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): @@ -771,11 +707,11 @@ def test_methods_are_coroutine_functions( return_value=mock_workspace_client, ), patch( - "databricks_openai.agents.session.create_async_engine", + "sqlalchemy.ext.asyncio.create_async_engine", return_value=mock_engine, ), patch( - "databricks_openai.agents.session.event.listens_for", + "sqlalchemy.event.listens_for", side_effect=mock_event_listens_for, ), ): diff --git a/pyproject.toml b/pyproject.toml index 32ed1e2f8..044f436c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ file = "LICENSE.txt" [project.optional-dependencies] memory = [ "psycopg[binary,pool]>=3.2.10", + "sqlalchemy[asyncio]>=2.0.0", ] [dependency-groups] diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 6969d40e2..f582044bd 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -6,11 +6,15 @@ import uuid from enum import Enum from threading import Lock -from typing import Any, List, Literal, Optional, Sequence +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence from databricks.sdk import WorkspaceClient from psycopg.rows import DictRow +if TYPE_CHECKING: + from sqlalchemy import URL + from sqlalchemy.ext.asyncio import AsyncEngine + try: import psycopg from psycopg import sql @@ -24,12 +28,16 @@ __all__ = [ "AsyncLakebasePool", + "AsyncLakebaseSQLAlchemy", "LakebasePool", ] logger = logging.getLogger(__name__) -DEFAULT_TOKEN_CACHE_DURATION_SECONDS = 50 * 60 # Cache token for 50 minutes +# Token cache duration based on Databricks Lakebase docs (15 minutes) +# https://docs.databricks.com/aws/en/oltp/projects/authentication?language=Python%3A+SQLAlchemy +DEFAULT_TOKEN_CACHE_DURATION_SECONDS = 15 * 60 # 15 minutes (900 seconds) +DEFAULT_POOL_RECYCLE_SECONDS = 14 * 60 # 14 minutes (before token cache expires) DEFAULT_MIN_SIZE = 1 DEFAULT_MAX_SIZE = 10 DEFAULT_TIMEOUT = 30.0 @@ -869,3 +877,129 @@ def grant_all_sequences_in_schema( schema, grantee, ) + + +# ============================================================================= +# AsyncLakebaseSQLAlchemy - SQLAlchemy async engine factory +# ============================================================================= + + +class AsyncLakebaseSQLAlchemy(_LakebaseBase): + """Async SQLAlchemy engine factory for Databricks Lakebase. + + Provides an AsyncEngine with automatic OAuth token injection via + SQLAlchemy's do_connect event. Tokens are cached and refreshed + every 15 minutes. + + Note: + This class is **async-only**. The engine uses SQLAlchemy's + async extension with the psycopg driver. + + Reference: + https://docs.databricks.com/aws/en/oltp/instances/authentication + + Example: + ```python + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + # Create once and reuse the engine + lakebase = AsyncLakebaseSQLAlchemy(instance_name="my-lakebase") + + async with lakebase.engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + ``` + """ + + def __init__( + self, + *, + instance_name: str, + workspace_client: WorkspaceClient | None = None, + token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, + pool_recycle: int = DEFAULT_POOL_RECYCLE_SECONDS, + **engine_kwargs, + ) -> None: + """ + Initialize AsyncLakebaseSQLAlchemy for Databricks Lakebase. + + Args: + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + If not provided, a default client will be created. + token_cache_duration_seconds: How long to cache OAuth tokens. + Defaults to 15 minutes. + pool_recycle: Connection pool recycle time in seconds. + Defaults to 14 minutes (before token cache expires). + **engine_kwargs: Additional keyword arguments passed to + SQLAlchemy's create_async_engine(). + """ + super().__init__( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + ) + + # Thread-safe lock for token caching (do_connect is sync context) + self._cache_lock = Lock() + self._pool_recycle = pool_recycle + self._engine = self._create_engine(**engine_kwargs) + + logger.info( + "AsyncLakebaseSQLAlchemy initialized: instance=%s host=%s", + instance_name, + self.host, + ) + + @property + def engine(self) -> "AsyncEngine": + """The SQLAlchemy AsyncEngine.""" + return self._engine + + def get_token(self) -> str: + """Get cached token or mint a new one (thread-safe).""" + with self._cache_lock: + if cached := self._get_cached_token(): + return cached + token = self._mint_token() + self._cached_token = token + self._cache_ts = time.time() + return token + + def _create_url(self) -> "URL": + """Create SQLAlchemy URL for Lakebase connection.""" + from sqlalchemy import URL + + # Create engine without password - token injected via event listener + # Note: empty password in URL, actual token provided on connect + return URL.create( + drivername="postgresql+psycopg", + username=self.username, + host=self.host, + port=DEFAULT_PORT, + database=DEFAULT_DATABASE, + ) + + def _create_engine(self, **engine_kwargs) -> "AsyncEngine": + """Create AsyncEngine with do_connect event for token injection.""" + from sqlalchemy import event + from sqlalchemy.ext.asyncio import create_async_engine + + url = self._create_url() + + engine: AsyncEngine = create_async_engine( + url, + pool_recycle=self._pool_recycle, + connect_args={"sslmode": DEFAULT_SSLMODE}, + **engine_kwargs, + ) + + # AsyncEngine wraps a sync Engine internally - connection events like + # do_connect must be registered on sync_engine, not the async wrapper. + # https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#using-events-with-the-asyncio-extension + # Lakebase docs https://docs.databricks.com/aws/en/oltp/projects/authentication?language=Python%3A+SQLAlchemy + @event.listens_for(engine.sync_engine, "do_connect") + def inject_token(dialect, conn_rec, cargs, cparams): + cparams["password"] = self.get_token() + logger.debug("Injected Lakebase token for connection") + + return engine diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index b914f7217..d17be23b2 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -83,7 +83,7 @@ def test_lakebase_pool_logs_cache_seconds(monkeypatch, caplog): ) assert any( - record.levelno == logging.INFO and re.search(r"cache=3000s$", record.getMessage()) + record.levelno == logging.INFO and re.search(r"cache=900s$", record.getMessage()) for record in caplog.records ) @@ -255,7 +255,7 @@ async def test_async_lakebase_pool_logs_cache_seconds(monkeypatch, caplog): ) assert any( - record.levelno == logging.INFO and re.search(r"cache=3000s$", record.getMessage()) + record.levelno == logging.INFO and re.search(r"cache=900s$", record.getMessage()) for record in caplog.records ) From 00280db650619b17a1d112a9acbb5e80119af02a Mon Sep 17 00:00:00 2001 From: Jenny Date: Thu, 5 Feb 2026 17:55:13 -0800 Subject: [PATCH 12/17] ruff lint --- integrations/openai/src/databricks_openai/agents/session.py | 1 - 1 file changed, 1 deletion(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 6d170f5ff..d8bc45f0a 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -169,4 +169,3 @@ def _get_or_create_lakebase( ) cls._lakebase_sql_alchemy_cache[instance_name] = lakebase return lakebase - From ec0c1c9da536599c00cb7ab6a5b28a4ecc6354db Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 9 Feb 2026 11:56:49 -0800 Subject: [PATCH 13/17] openai integration cleanup - ensure imports available, AsyncLakebaseSQLAlchemy Tests --- .github/workflows/main.yml | 6 + .../src/databricks_openai/agents/__init__.py | 2 +- .../src/databricks_openai/agents/session.py | 16 +- .../src/databricks_openai/utils/clients.py | 6 +- .../openai/tests/unit_tests/test_clients.py | 29 ++- tests/databricks_ai_bridge/test_lakebase.py | 241 ++++++++++++++++++ 6 files changed, 288 insertions(+), 12 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 603eac1da..e708e5316 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -89,6 +89,12 @@ jobs: --exact --group tests --extra memory \ pytest tests/unit_tests/test_checkpoint.py \ tests/unit_tests/test_store.py + - name: Run databricks-openai Tests + working-directory: integrations/openai + run: | + uv run --resolution ${{ matrix.uv-resolution }} \ + --exact --group tests --extra memory \ + pytest tests/unit_tests/test_session.py langchain_cross_version_test: runs-on: ubuntu-latest diff --git a/integrations/openai/src/databricks_openai/agents/__init__.py b/integrations/openai/src/databricks_openai/agents/__init__.py index 2301fe320..1b67279c9 100644 --- a/integrations/openai/src/databricks_openai/agents/__init__.py +++ b/integrations/openai/src/databricks_openai/agents/__init__.py @@ -1,4 +1,4 @@ from databricks_openai.agents.mcp_server import McpServer from databricks_openai.agents.session import AsyncDatabricksSession -__all__ = ["McpServer", "AsyncDatabricksSession"] +__all__ = ["AsyncDatabricksSession", "McpServer"] diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index d8bc45f0a..0506d314b 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -42,11 +42,11 @@ async def main(): DEFAULT_TOKEN_CACHE_DURATION_SECONDS, AsyncLakebaseSQLAlchemy, ) -except ImportError as e: - raise ImportError( - "AsyncDatabricksSession requires databricks-openai[memory]. " - "Please install with: pip install databricks-openai[memory]" - ) from e + + _session_imports_available = True +except ImportError: + SQLAlchemySession = object # type: ignore + _session_imports_available = False logger = logging.getLogger(__name__) @@ -121,6 +121,12 @@ def __init__( **engine_kwargs: Additional keyword arguments passed to SQLAlchemy's create_async_engine(). """ + if not _session_imports_available: + raise ImportError( + "AsyncDatabricksSession requires databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) + self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, workspace_client=workspace_client, diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 99bb9d08c..d7b5f95b5 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -24,8 +24,10 @@ def _strip_strict_from_tools(tools: list | None) -> list | None: Databricks model endpoints (except GPT) don't support the 'strict' field in tool schemas, but openai-agents SDK v0.6.4+ includes it. """ - if tools is None: - return None + # Handle None or OpenAI's NOT_GIVEN/Omit sentinel types (non-iterable placeholders). + # See https://deepwiki.com/openai/openai-python/5-data-types-and-models#special-types-and-sentinels + if not tools: + return tools for tool in tools: if isinstance(tool, dict) and "function" in tool: tool.get("function", {}).pop("strict", None) diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index a112c3552..b3e8191b1 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -128,6 +128,27 @@ def test_strip_strict_from_tools_handles_none(self): assert _strip_strict_from_tools(None) is None + def test_strip_strict_from_tools_handles_openai_not_given_sentinel(self): + """OpenAI Agents SDK may pass NOT_GIVEN instead of None or a list.""" + from openai._types import NOT_GIVEN + + from databricks_openai.utils.clients import _strip_strict_from_tools + + # Should not raise TypeError: 'NotGiven' object is not iterable + result = _strip_strict_from_tools(NOT_GIVEN) + assert result is NOT_GIVEN + + def test_strip_strict_from_tools_handles_openai_omit_sentinel(self): + """OpenAI Agents SDK may pass Omit() instead of None or a list.""" + from openai._types import Omit + + from databricks_openai.utils.clients import _strip_strict_from_tools + + omit = Omit() + # Should not raise TypeError: 'Omit' object is not iterable + result = _strip_strict_from_tools(omit) + assert result is omit + def test_strip_strict_from_tools_handles_empty_list(self): from databricks_openai.utils.clients import _strip_strict_from_tools @@ -206,7 +227,7 @@ def test_chat_completions_strips_strict_for_claude(self): client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=cast(Any, tools), + tools=tools, ) call_kwargs = mock_create.call_args.kwargs @@ -231,7 +252,7 @@ def test_chat_completions_preserves_strict_for_gpt(self): client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=cast(Any, tools), + tools=tools, ) call_kwargs = mock_create.call_args.kwargs @@ -281,7 +302,7 @@ async def test_chat_completions_strips_strict_for_claude(self): await client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=cast(Any, tools), + tools=tools, ) call_kwargs = mock_create.call_args.kwargs @@ -306,7 +327,7 @@ async def test_chat_completions_preserves_strict_for_gpt(self): await client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=cast(Any, tools), + tools=tools, ) call_kwargs = mock_create.call_args.kwargs diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index d17be23b2..335e1ac5b 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -12,6 +12,7 @@ import databricks_ai_bridge.lakebase as lakebase from databricks_ai_bridge.lakebase import ( AsyncLakebasePool, + AsyncLakebaseSQLAlchemy, LakebaseClient, LakebasePool, SchemaPrivilege, @@ -1056,3 +1057,243 @@ def test_execute_grant_handles_insufficient_privilege(self): error_msg = str(exc_info.value) assert "Insufficient privileges" in error_msg assert "CAN MANAGE" in error_msg + + +# ============================================================================= +# AsyncLakebaseSQLAlchemy Tests +# ============================================================================= + +pytest.importorskip("sqlalchemy") + + +def _make_sqlalchemy_patches(workspace): + """Return a context manager that patches SQLAlchemy internals for AsyncLakebaseSQLAlchemy.""" + from unittest.mock import patch + + mock_engine = MagicMock(sync_engine=MagicMock()) + + return ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + return_value=lambda f: f, + ), + mock_engine, + ) + + +def test_async_lakebase_sqlalchemy_resolves_host(): + """Test that AsyncLakebaseSQLAlchemy resolves the Lakebase host from instance name.""" + workspace = _make_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert sa.host == "db.host" + workspace.database.get_database_instance.assert_called_once_with("lake-instance") + + +def test_async_lakebase_sqlalchemy_infers_username(): + """Test that AsyncLakebaseSQLAlchemy infers the username from workspace client.""" + workspace = _make_workspace(user_name="alice@databricks.com") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert sa.username == "alice@databricks.com" + + +def test_async_lakebase_sqlalchemy_engine_property(): + """Test that engine property returns the created AsyncEngine.""" + workspace = _make_workspace() + patch_engine, patch_event, mock_engine = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert sa.engine is mock_engine + + +def test_async_lakebase_sqlalchemy_creates_engine_with_correct_url(): + """Test that the engine is created with the correct SQLAlchemy URL.""" + from unittest.mock import patch + + workspace = _make_workspace() + mock_engine = MagicMock(sync_engine=MagicMock()) + + with ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, + patch( + "sqlalchemy.event.listens_for", + return_value=lambda f: f, + ), + ): + AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + url = mock_create.call_args[0][0] + assert url.drivername == "postgresql+psycopg" + assert url.username == "test@databricks.com" + assert url.host == "db.host" + assert url.port == 5432 + assert url.database == "databricks_postgres" + + +def test_async_lakebase_sqlalchemy_passes_extra_engine_kwargs(): + """Test that additional kwargs are forwarded to create_async_engine.""" + from unittest.mock import patch + + workspace = _make_workspace() + mock_engine = MagicMock(sync_engine=MagicMock()) + + with ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, + patch( + "sqlalchemy.event.listens_for", + return_value=lambda f: f, + ), + ): + AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + echo=True, + pool_pre_ping=True, + ) + + call_kwargs = mock_create.call_args[1] + assert call_kwargs["echo"] is True + assert call_kwargs["pool_pre_ping"] is True + + +def test_async_lakebase_sqlalchemy_do_connect_injects_token(): + """Test that the do_connect handler injects the OAuth token into cparams.""" + from unittest.mock import patch + + workspace = _make_workspace(credential_token="my-secret-token") + mock_engine = MagicMock(sync_engine=MagicMock()) + + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=capture_handler, + ), + ): + AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + assert cparams["password"] == "my-secret-token" + + +def test_async_lakebase_sqlalchemy_get_token_caches(): + """Test that get_token returns cached token on repeated calls.""" + workspace = _make_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + token1 = sa.get_token() + token2 = sa.get_token() + + assert token1 == token2 == "token-1" + assert workspace.database.generate_database_credential.call_count == 1 + + +def test_async_lakebase_sqlalchemy_get_token_refreshes_after_expiry(monkeypatch): + """Test that get_token mints a new token after cache expiry.""" + import time + + call_count = [] + + def mock_generate_credential(**kwargs): + call_count.append(1) + return MagicMock(token=f"token-{len(call_count)}") + + workspace = _make_workspace() + workspace.database.generate_database_credential = mock_generate_credential + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + token_cache_duration_seconds=1, + ) + + test_time = [100.0] + monkeypatch.setattr(time, "time", lambda: test_time[0]) + + token1 = sa.get_token() + assert token1 == "token-1" + + # Within cache window + test_time[0] = 100.5 + token2 = sa.get_token() + assert token2 == "token-1" + assert len(call_count) == 1 + + # After cache expiry + test_time[0] = 101.5 + token3 = sa.get_token() + assert token3 == "token-2" + assert len(call_count) == 2 + + +def test_async_lakebase_sqlalchemy_invalid_instance_raises(): + """Test that an invalid instance name raises ValueError.""" + workspace = _make_workspace() + workspace.database.get_database_instance.side_effect = Exception("Not found") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + with pytest.raises(ValueError, match="Unable to resolve Lakebase instance"): + AsyncLakebaseSQLAlchemy( + instance_name="bad-instance", + workspace_client=workspace, + ) + + From be3a9cb6e93ec7815a69e11c0265de91b26bac55 Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 9 Feb 2026 13:11:23 -0800 Subject: [PATCH 14/17] lint --- .../openai/src/databricks_openai/utils/clients.py | 4 ++-- integrations/openai/tests/unit_tests/test_clients.py | 8 ++++---- tests/databricks_ai_bridge/test_lakebase.py | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index d7b5f95b5..21152090e 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -1,4 +1,4 @@ -from typing import Generator +from typing import Any, Generator from databricks.sdk import WorkspaceClient from httpx import AsyncClient, Auth, Client, Request, Response @@ -18,7 +18,7 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: yield request -def _strip_strict_from_tools(tools: list | None) -> list | None: +def _strip_strict_from_tools(tools: Any) -> Any: """Remove 'strict' field from tool function definitions. Databricks model endpoints (except GPT) don't support the 'strict' field diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index b3e8191b1..588051f2f 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -227,7 +227,7 @@ def test_chat_completions_strips_strict_for_claude(self): client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -252,7 +252,7 @@ def test_chat_completions_preserves_strict_for_gpt(self): client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -302,7 +302,7 @@ async def test_chat_completions_strips_strict_for_claude(self): await client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -327,7 +327,7 @@ async def test_chat_completions_preserves_strict_for_gpt(self): await client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index 335e1ac5b..bdc7ddcbd 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -1295,5 +1295,3 @@ def test_async_lakebase_sqlalchemy_invalid_instance_raises(): instance_name="bad-instance", workspace_client=workspace, ) - - From 6135b57798825d4cd4634c1be800ab0f5dc83f2d Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 9 Feb 2026 13:44:54 -0800 Subject: [PATCH 15/17] remove redundant sqlalchemy import since ai-bridge[memory] dependency includes it --- integrations/openai/pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 262546b9f..36887a7b1 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -23,7 +23,6 @@ dependencies = [ [project.optional-dependencies] memory = [ "databricks-ai-bridge[memory]>=0.13.0", - "sqlalchemy[asyncio]>=2.0.0", ] [dependency-groups] @@ -41,7 +40,6 @@ tests = [ "pytest-timeout>=2.3.1", "pytest-cov>=4.1.0", "databricks-ai-bridge[memory]>=0.13.0", - "sqlalchemy[asyncio]>=2.0.0", ] [build-system] From 0ce9d423607fc479804e1eb09ca75a21a3cc908f Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 9 Feb 2026 16:27:24 -0800 Subject: [PATCH 16/17] cache using both instances name + engine kwargs, can tune with use_cached_engine --- .../src/databricks_openai/agents/session.py | 66 ++++++--- .../openai/tests/unit_tests/test_session.py | 133 ++++++++++++++++++ 2 files changed, 179 insertions(+), 20 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 0506d314b..ba992e77d 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -1,7 +1,7 @@ """ AsyncDatabricksSession - Async SQLAlchemy-based session storage for Databricks Lakebase. -This module provides a AsyncDatabricksSession class that subclasses OpenAI's SQLAlchemySession +This module provides an AsyncDatabricksSession class that subclasses OpenAI's SQLAlchemySession to provide persistent conversation history storage in Databricks Lakebase. Note: @@ -30,9 +30,10 @@ async def main(): from __future__ import annotations +import json import logging from threading import Lock -from typing import Optional +from typing import Any, Optional try: from agents.extensions.memory import SQLAlchemySession @@ -85,8 +86,9 @@ async def main(): ``` """ - # Class-level cache for AsyncLakebaseSQLAlchemy instances, keyed by instance_name. - # This allows multiple AsyncDatabricksSession instances to share a single engine/pool. + # Class-level cache for AsyncLakebaseSQLAlchemy instances keyed by + # (instance_name, engine_kwargs). This allows multiple sessions to share + # a single engine/connection pool when the configuration is identical. _lakebase_sql_alchemy_cache: dict[str, AsyncLakebaseSQLAlchemy] = {} _lakebase_sql_alchemy_cache_lock = Lock() @@ -100,10 +102,11 @@ def __init__( create_tables: bool = True, sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", + use_cached_engine: bool = True, **engine_kwargs, ) -> None: """ - Initialize a AsyncDatabricksSession for Databricks Lakebase. + Initialize an AsyncDatabricksSession for Databricks Lakebase. Args: session_id: Unique identifier for the conversation session. @@ -118,6 +121,9 @@ def __init__( Defaults to "agent_sessions". messages_table: Name of the messages table. Defaults to "agent_messages". + use_cached_engine: Whether to reuse a cached engine for the same + instance_name and engine_kwargs combination. Set to False to + always create a new engine. Defaults to True. **engine_kwargs: Additional keyword arguments passed to SQLAlchemy's create_async_engine(). """ @@ -132,6 +138,7 @@ def __init__( workspace_client=workspace_client, token_cache_duration_seconds=token_cache_duration_seconds, pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS), + use_cached_engine=use_cached_engine, **engine_kwargs, ) @@ -150,6 +157,13 @@ def __init__( session_id, ) + @classmethod + def _build_cache_key(cls, instance_name: str, **engine_kwargs: Any) -> str: + """Build a cache key from instance_name and engine_kwargs.""" + # Sort kwargs for deterministic key; use JSON for serializable values + kwargs_key = json.dumps(engine_kwargs, sort_keys=True, default=str) + return f"{instance_name}::{kwargs_key}" + @classmethod def _get_or_create_lakebase( cls, @@ -158,20 +172,32 @@ def _get_or_create_lakebase( workspace_client: Optional[WorkspaceClient], token_cache_duration_seconds: int, pool_recycle: int, + use_cached_engine: bool = True, **engine_kwargs, ) -> AsyncLakebaseSQLAlchemy: - """Get cached AsyncLakebaseSQLAlchemy or create a new one (thread-safe).""" - with cls._lakebase_sql_alchemy_cache_lock: - if instance_name in cls._lakebase_sql_alchemy_cache: - logger.debug("Reusing cached engine for instance=%s", instance_name) - return cls._lakebase_sql_alchemy_cache[instance_name] - - lakebase = AsyncLakebaseSQLAlchemy( - instance_name=instance_name, - workspace_client=workspace_client, - token_cache_duration_seconds=token_cache_duration_seconds, - pool_recycle=pool_recycle, - **engine_kwargs, - ) - cls._lakebase_sql_alchemy_cache[instance_name] = lakebase - return lakebase + """Get cached AsyncLakebaseSQLAlchemy or create a new one. + The cache key uses both instance_name and engine_kwargs + """ + cache_key = cls._build_cache_key( + instance_name, pool_recycle=pool_recycle, **engine_kwargs + ) + + if use_cached_engine: + with cls._lakebase_sql_alchemy_cache_lock: + if cache_key in cls._lakebase_sql_alchemy_cache: + logger.debug("Reusing cached engine for key=%s", cache_key) + return cls._lakebase_sql_alchemy_cache[cache_key] + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + pool_recycle=pool_recycle, + **engine_kwargs, + ) + + if use_cached_engine: + with cls._lakebase_sql_alchemy_cache_lock: + cls._lakebase_sql_alchemy_cache[cache_key] = lakebase + + return lakebase diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 5f8d880e6..043726cf9 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -623,6 +623,139 @@ def test_different_instances_get_different_engines( # Sessions should have different engines assert session1._engine is not session2._engine + def test_different_engine_kwargs_get_different_engines( + self, mock_workspace_client, mock_event_listens_for + ): + """Test that same instance_name with different engine_kwargs get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=False, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, + ) + + # Engine should be created twice (different engine_kwargs) + assert mock_create_engine.call_count == 2 + + # Sessions should have different engines + assert session1._engine is not session2._engine + + def test_use_cached_engine_false_creates_new_engine( + self, mock_workspace_client, mock_event_listens_for + ): + """Test that use_cached_engine=False always creates a new engine.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + use_cached_engine=False, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + use_cached_engine=False, + ) + + # Engine should be created twice despite same instance_name + assert mock_create_engine.call_count == 2 + + # Sessions should have different engines + assert session1._engine is not session2._engine + + def test_same_instance_and_kwargs_share_engine( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that same instance_name with same engine_kwargs reuse the cached engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + + # Both sessions should share the same engine + assert session1._engine is session2._engine + class TestAsyncDatabricksSessionAsyncOnly: """Tests verifying AsyncDatabricksSession is async-only.""" From 62bd09f3c23e61dca7d249459dacff317ddaea5b Mon Sep 17 00:00:00 2001 From: Jenny Date: Mon, 9 Feb 2026 16:30:46 -0800 Subject: [PATCH 17/17] lint fix --- integrations/openai/src/databricks_openai/agents/session.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index ba992e77d..5d96826b1 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -178,9 +178,7 @@ def _get_or_create_lakebase( """Get cached AsyncLakebaseSQLAlchemy or create a new one. The cache key uses both instance_name and engine_kwargs """ - cache_key = cls._build_cache_key( - instance_name, pool_recycle=pool_recycle, **engine_kwargs - ) + cache_key = cls._build_cache_key(instance_name, pool_recycle=pool_recycle, **engine_kwargs) if use_cached_engine: with cls._lakebase_sql_alchemy_cache_lock: