diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 603eac1da..e708e5316 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -89,6 +89,12 @@ jobs: --exact --group tests --extra memory \ pytest tests/unit_tests/test_checkpoint.py \ tests/unit_tests/test_store.py + - name: Run databricks-openai Tests + working-directory: integrations/openai + run: | + uv run --resolution ${{ matrix.uv-resolution }} \ + --exact --group tests --extra memory \ + pytest tests/unit_tests/test_session.py langchain_cross_version_test: runs-on: ubuntu-latest diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index 945c78a75..36887a7b1 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -20,6 +20,11 @@ dependencies = [ "openai-agents>=0.5.0" ] +[project.optional-dependencies] +memory = [ + "databricks-ai-bridge[memory]>=0.13.0", +] + [dependency-groups] dev = [ "typing_extensions>=4.15.0", @@ -34,6 +39,7 @@ tests = [ "pytest-asyncio>=1.3.0", "pytest-timeout>=2.3.1", "pytest-cov>=4.1.0", + "databricks-ai-bridge[memory]>=0.13.0", ] [build-system] diff --git a/integrations/openai/src/databricks_openai/agents/__init__.py b/integrations/openai/src/databricks_openai/agents/__init__.py index f78fbd25b..1b67279c9 100644 --- a/integrations/openai/src/databricks_openai/agents/__init__.py +++ b/integrations/openai/src/databricks_openai/agents/__init__.py @@ -1,3 +1,4 @@ from databricks_openai.agents.mcp_server import McpServer +from databricks_openai.agents.session import AsyncDatabricksSession -__all__ = ["McpServer"] +__all__ = ["AsyncDatabricksSession", "McpServer"] diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py new file mode 100644 index 000000000..5d96826b1 --- /dev/null +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -0,0 +1,201 @@ +""" +AsyncDatabricksSession - Async SQLAlchemy-based session storage for Databricks Lakebase. + +This module provides an AsyncDatabricksSession class that subclasses OpenAI's SQLAlchemySession +to provide persistent conversation history storage in Databricks Lakebase. + +Note: + This class is **async-only** as it follows the Session Protocol. Use within async context + https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.Session + +Usage:: + + import asyncio + from databricks_openai.agents import AsyncDatabricksSession + from agents import Agent, Runner + + + async def main(): + session = AsyncDatabricksSession( + session_id="user-123", + instance_name="my-lakebase-instance", + ) + + agent = Agent(name="Assistant") + result = await Runner.run(agent, "Hello!", session=session) + + + asyncio.run(main()) +""" + +from __future__ import annotations + +import json +import logging +from threading import Lock +from typing import Any, Optional + +try: + from agents.extensions.memory import SQLAlchemySession + from databricks.sdk import WorkspaceClient + from databricks_ai_bridge.lakebase import ( + DEFAULT_POOL_RECYCLE_SECONDS, + DEFAULT_TOKEN_CACHE_DURATION_SECONDS, + AsyncLakebaseSQLAlchemy, + ) + + _session_imports_available = True +except ImportError: + SQLAlchemySession = object # type: ignore + _session_imports_available = False + +logger = logging.getLogger(__name__) + + +class AsyncDatabricksSession(SQLAlchemySession): + """ + Async OpenAI Agents SDK Session implementation for Databricks Lakebase. + For more information on the Session protocol, see: + https://openai.github.io/openai-agents-python/ref/memory/session/ + + Note: + This class is **async-only**. All session methods (get_items, add_items, + clear_session, etc.) are coroutines and must be awaited. + + The session stores conversation history in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + Example: + ```python + import asyncio + from databricks_openai.agents import AsyncDatabricksSession + from agents import Agent, Runner + + + async def main(): + session = AsyncDatabricksSession( + session_id="user-123", + instance_name="my-lakebase-instance", + ) + agent = Agent(name="Assistant") + result = await Runner.run(agent, "Hello!", session=session) + + + asyncio.run(main()) + ``` + """ + + # Class-level cache for AsyncLakebaseSQLAlchemy instances keyed by + # (instance_name, engine_kwargs). This allows multiple sessions to share + # a single engine/connection pool when the configuration is identical. + _lakebase_sql_alchemy_cache: dict[str, AsyncLakebaseSQLAlchemy] = {} + _lakebase_sql_alchemy_cache_lock = Lock() + + def __init__( + self, + session_id: str, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, + create_tables: bool = True, + sessions_table: str = "agent_sessions", + messages_table: str = "agent_messages", + use_cached_engine: bool = True, + **engine_kwargs, + ) -> None: + """ + Initialize an AsyncDatabricksSession for Databricks Lakebase. + + Args: + session_id: Unique identifier for the conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + If not provided, a default client will be created. + token_cache_duration_seconds: How long to cache OAuth tokens. + Defaults to 15 minutes. + create_tables: Whether to auto-create tables on first use. + Defaults to True. + sessions_table: Name of the sessions table. + Defaults to "agent_sessions". + messages_table: Name of the messages table. + Defaults to "agent_messages". + use_cached_engine: Whether to reuse a cached engine for the same + instance_name and engine_kwargs combination. Set to False to + always create a new engine. Defaults to True. + **engine_kwargs: Additional keyword arguments passed to + SQLAlchemy's create_async_engine(). + """ + if not _session_imports_available: + raise ImportError( + "AsyncDatabricksSession requires databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) + + self._lakebase = self._get_or_create_lakebase( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS), + use_cached_engine=use_cached_engine, + **engine_kwargs, + ) + + # Initialize parent SQLAlchemySession - inherits all SQL logic + super().__init__( + session_id=session_id, + engine=self._lakebase.engine, + create_tables=create_tables, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + logger.info( + "AsyncDatabricksSession initialized: instance=%s session_id=%s", + instance_name, + session_id, + ) + + @classmethod + def _build_cache_key(cls, instance_name: str, **engine_kwargs: Any) -> str: + """Build a cache key from instance_name and engine_kwargs.""" + # Sort kwargs for deterministic key; use JSON for serializable values + kwargs_key = json.dumps(engine_kwargs, sort_keys=True, default=str) + return f"{instance_name}::{kwargs_key}" + + @classmethod + def _get_or_create_lakebase( + cls, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient], + token_cache_duration_seconds: int, + pool_recycle: int, + use_cached_engine: bool = True, + **engine_kwargs, + ) -> AsyncLakebaseSQLAlchemy: + """Get cached AsyncLakebaseSQLAlchemy or create a new one. + The cache key uses both instance_name and engine_kwargs + """ + cache_key = cls._build_cache_key(instance_name, pool_recycle=pool_recycle, **engine_kwargs) + + if use_cached_engine: + with cls._lakebase_sql_alchemy_cache_lock: + if cache_key in cls._lakebase_sql_alchemy_cache: + logger.debug("Reusing cached engine for key=%s", cache_key) + return cls._lakebase_sql_alchemy_cache[cache_key] + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + pool_recycle=pool_recycle, + **engine_kwargs, + ) + + if use_cached_engine: + with cls._lakebase_sql_alchemy_cache_lock: + cls._lakebase_sql_alchemy_cache[cache_key] = lakebase + + return lakebase diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 99bb9d08c..21152090e 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -1,4 +1,4 @@ -from typing import Generator +from typing import Any, Generator from databricks.sdk import WorkspaceClient from httpx import AsyncClient, Auth, Client, Request, Response @@ -18,14 +18,16 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: yield request -def _strip_strict_from_tools(tools: list | None) -> list | None: +def _strip_strict_from_tools(tools: Any) -> Any: """Remove 'strict' field from tool function definitions. Databricks model endpoints (except GPT) don't support the 'strict' field in tool schemas, but openai-agents SDK v0.6.4+ includes it. """ - if tools is None: - return None + # Handle None or OpenAI's NOT_GIVEN/Omit sentinel types (non-iterable placeholders). + # See https://deepwiki.com/openai/openai-python/5-data-types-and-models#special-types-and-sentinels + if not tools: + return tools for tool in tools: if isinstance(tool, dict) and "function" in tool: tool.get("function", {}).pop("strict", None) diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py new file mode 100644 index 000000000..bc6307422 --- /dev/null +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -0,0 +1,312 @@ +""" +Integration tests for AsyncDatabricksSession. + +These tests require: +1. A Lakebase instance to be available +2. Valid Databricks authentication (DATABRICKS_HOST + DATABRICKS_TOKEN or profile) + +Set the environment variable: + LAKEBASE_INSTANCE_NAME: Name of the Lakebase instance + +Example: + LAKEBASE_INSTANCE_NAME=lakebase pytest tests/integration_tests/test_memory_session.py -v +""" + +from __future__ import annotations + +import os +import uuid +from typing import Any, cast + +import pytest + +# Skip all tests if LAKEBASE_INSTANCE_NAME is not set +pytestmark = pytest.mark.skipif( + not os.environ.get("LAKEBASE_INSTANCE_NAME"), + reason="LAKEBASE_INSTANCE_NAME environment variable not set", +) + + +def get_instance_name() -> str: + """Get the Lakebase instance name from environment.""" + return os.environ["LAKEBASE_INSTANCE_NAME"] + + +def get_unique_table_names() -> tuple[str, str]: + """Generate unique table names for test isolation.""" + suffix = uuid.uuid4().hex[:8] + return f"test_sessions_{suffix}", f"test_messages_{suffix}" + + +@pytest.fixture +def cleanup_tables(): + """Fixture to track and clean up test tables after tests.""" + tables_to_cleanup: list[tuple[str, str]] = [] + + yield tables_to_cleanup + + # Cleanup after test + if tables_to_cleanup: + from databricks_ai_bridge.lakebase import LakebasePool + + pool = LakebasePool(instance_name=get_instance_name()) + with pool.connection() as conn: + for sessions_table, messages_table in tables_to_cleanup: + # Drop messages first (foreign key constraint) + conn.execute(f"DROP TABLE IF EXISTS {messages_table}") + conn.execute(f"DROP TABLE IF EXISTS {sessions_table}") + pool.close() + + +# ============================================================================= +# AsyncDatabricksSession Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_memory_session_crud_operations(cleanup_tables): + """ + Comprehensive CRUD test for AsyncDatabricksSession. + + Tests the full lifecycle: + - clear_session() on fresh session + - get_items() returns empty list for new session + - add_items() stores messages + - get_items() retrieves stored messages + - get_items(limit=N) returns latest N items in order + - pop_item() removes and returns most recent item + - clear_session() removes all items + """ + from databricks_openai.agents import AsyncDatabricksSession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id = str(uuid.uuid4()) + session = AsyncDatabricksSession( + session_id=session_id, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Clear any existing data (should be no-op for new session) + await session.clear_session() + + # Test get_items on empty session + items = cast(list[Any], await session.get_items()) + assert items == [], f"Expected empty list, got {items}" + + # Test add_items + test_items: list[Any] = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + await session.add_items(test_items) + + # Test get_items returns what we added + items = cast(list[Any], await session.get_items()) + assert len(items) == 2, f"Expected 2 items, got {len(items)}" + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello, how are you?" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "I'm doing well, thank you!" + + # Test get_items with limit - should return latest N items in chronological order + items = cast(list[Any], await session.get_items(limit=1)) + assert len(items) == 1, f"Expected 1 item with limit, got {len(items)}" + assert items[0]["role"] == "assistant" # Latest item + + # Test pop_item - removes and returns the last item + popped = cast(Any, await session.pop_item()) + assert popped is not None + assert popped["role"] == "assistant" # Should be the last item + + # Verify only 1 item remains + items = cast(list[Any], await session.get_items()) + assert len(items) == 1, f"Expected 1 item after pop, got {len(items)}" + assert items[0]["role"] == "user" + + # Test clear_session + await session.clear_session() + items = cast(list[Any], await session.get_items()) + assert items == [], f"Expected empty after clear, got {items}" + + +@pytest.mark.asyncio +async def test_memory_session_multiple_sessions_isolated(cleanup_tables): + """Test that different session_ids have isolated data.""" + from databricks_openai.agents import AsyncDatabricksSession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id_1 = str(uuid.uuid4()) + session_id_2 = str(uuid.uuid4()) + + session_1 = AsyncDatabricksSession( + session_id=session_id_1, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + session_2 = AsyncDatabricksSession( + session_id=session_id_2, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + create_tables=False, # Tables already created by session_1 + ) + + # Add different items to each session + items_1_data: list[Any] = [{"role": "user", "content": "Session 1 message"}] + items_2_data: list[Any] = [{"role": "user", "content": "Session 2 message"}] + await session_1.add_items(items_1_data) + await session_2.add_items(items_2_data) + + # Verify isolation + items_1 = cast(list[Any], await session_1.get_items()) + items_2 = cast(list[Any], await session_2.get_items()) + + assert len(items_1) == 1 + assert len(items_2) == 1 + assert items_1[0]["content"] == "Session 1 message" + assert items_2[0]["content"] == "Session 2 message" + + # Clear one session shouldn't affect the other + await session_1.clear_session() + items_1 = cast(list[Any], await session_1.get_items()) + items_2 = cast(list[Any], await session_2.get_items()) + assert len(items_1) == 0 + assert len(items_2) == 1 + + # Cleanup + await session_2.clear_session() + + +@pytest.mark.asyncio +async def test_memory_session_pop_empty_returns_none(cleanup_tables): + """Test that pop_item returns None on empty session.""" + from databricks_openai.agents import AsyncDatabricksSession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncDatabricksSession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Pop on empty session should return None + popped = await session.pop_item() + assert popped is None + + +@pytest.mark.asyncio +async def test_memory_session_add_empty_items_noop(cleanup_tables): + """Test that add_items with empty list is a no-op.""" + from databricks_openai.agents import AsyncDatabricksSession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncDatabricksSession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add empty list - should not raise + await session.add_items([]) + + # Session should still be empty + items = cast(list[Any], await session.get_items()) + assert items == [] + + +@pytest.mark.asyncio +async def test_memory_session_complex_message_data(cleanup_tables): + """Test storing complex message data with nested structures.""" + from databricks_openai.agents import AsyncDatabricksSession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncDatabricksSession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add complex item with nested data + complex_item = { + "role": "assistant", + "content": "Here's your result", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "San Francisco"}', + }, + } + ], + "metadata": { + "model": "gpt-4", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + }, + } + await session.add_items([complex_item]) # type: ignore[list-item] + + # Retrieve and verify + items = cast(list[Any], await session.get_items()) + assert len(items) == 1 + assert items[0]["role"] == "assistant" + assert items[0]["tool_calls"][0]["function"]["name"] == "get_weather" + assert items[0]["metadata"]["usage"]["prompt_tokens"] == 10 + + # Cleanup + await session.clear_session() + + +@pytest.mark.asyncio +async def test_memory_session_get_items_ordering(cleanup_tables): + """Test that get_items returns items in correct chronological order.""" + from databricks_openai.agents import AsyncDatabricksSession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncDatabricksSession( + session_id=str(uuid.uuid4()), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add multiple items + test_items: list[Any] = [{"role": "user", "content": f"Message {i}"} for i in range(5)] + await session.add_items(test_items) + + # Get all items - should be in chronological order + items = cast(list[Any], await session.get_items()) + assert len(items) == 5 + for i, item in enumerate(items): + assert item["content"] == f"Message {i}" + + # Get with limit - should return LATEST N items in chronological order + items = cast(list[Any], await session.get_items(limit=3)) + assert len(items) == 3 + # Should be messages 2, 3, 4 (the latest 3) in order + assert items[0]["content"] == "Message 2" + assert items[1]["content"] == "Message 3" + assert items[2]["content"] == "Message 4" + + # Cleanup + await session.clear_session() diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index df44f276e..588051f2f 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -54,6 +54,7 @@ def test_bearer_auth_flow(self, mock_workspace_client): request = Request("GET", "https://test.databricks.com/api/test") # Authenticate the request + assert http_client.auth is not None auth_flow = http_client.auth.auth_flow(request) authenticated_request = next(auth_flow) @@ -98,6 +99,7 @@ def test_bearer_auth_flow(self, mock_workspace_client): request = Request("GET", "https://test.databricks.com/api/test") # Authenticate the request + assert http_client.auth is not None auth_flow = http_client.auth.auth_flow(request) authenticated_request = next(auth_flow) @@ -126,6 +128,27 @@ def test_strip_strict_from_tools_handles_none(self): assert _strip_strict_from_tools(None) is None + def test_strip_strict_from_tools_handles_openai_not_given_sentinel(self): + """OpenAI Agents SDK may pass NOT_GIVEN instead of None or a list.""" + from openai._types import NOT_GIVEN + + from databricks_openai.utils.clients import _strip_strict_from_tools + + # Should not raise TypeError: 'NotGiven' object is not iterable + result = _strip_strict_from_tools(NOT_GIVEN) + assert result is NOT_GIVEN + + def test_strip_strict_from_tools_handles_openai_omit_sentinel(self): + """OpenAI Agents SDK may pass Omit() instead of None or a list.""" + from openai._types import Omit + + from databricks_openai.utils.clients import _strip_strict_from_tools + + omit = Omit() + # Should not raise TypeError: 'Omit' object is not iterable + result = _strip_strict_from_tools(omit) + assert result is omit + def test_strip_strict_from_tools_handles_empty_list(self): from databricks_openai.utils.clients import _strip_strict_from_tools @@ -204,7 +227,7 @@ def test_chat_completions_strips_strict_for_claude(self): client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -229,7 +252,7 @@ def test_chat_completions_preserves_strict_for_gpt(self): client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -279,7 +302,7 @@ async def test_chat_completions_strips_strict_for_claude(self): await client.chat.completions.create( model="databricks-claude-3-7-sonnet", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs @@ -304,7 +327,7 @@ async def test_chat_completions_preserves_strict_for_gpt(self): await client.chat.completions.create( model="databricks-gpt-4o", messages=[{"role": "user", "content": "hi"}], - tools=tools, + tools=cast(Any, tools), ) call_kwargs = mock_create.call_args.kwargs diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py new file mode 100644 index 000000000..043726cf9 --- /dev/null +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -0,0 +1,865 @@ +"""Unit tests for AsyncDatabricksSession.""" + +from unittest.mock import MagicMock, patch + +import pytest + +try: + from databricks_ai_bridge.lakebase import _LakebaseBase # noqa: F401 + from psycopg.rows import DictRow # noqa: F401 + from sqlalchemy.ext.asyncio import AsyncEngine # noqa: F401 +except ImportError as e: + raise ImportError( + "AsyncDatabricksSession tests require databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) from e + + +@pytest.fixture(autouse=True) +def clear_engine_cache(): + """Clear AsyncDatabricksSession engine cache before each test.""" + try: + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession._lakebase_sql_alchemy_cache.clear() + except ImportError: + pass + yield + try: + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession._lakebase_sql_alchemy_cache.clear() + except ImportError: + pass + + +@pytest.fixture +def mock_workspace_client(): + """Create a mock WorkspaceClient.""" + mock_client = MagicMock() + mock_client.config.host = "https://test.databricks.com" + + # Mock current_user.me() for username inference + mock_user = MagicMock() + mock_user.user_name = "test_user@databricks.com" + mock_client.current_user.me.return_value = mock_user + + # Mock database.get_database_instance() for host resolution + mock_instance = MagicMock() + mock_instance.read_write_dns = "test-instance.lakebase.databricks.com" + mock_client.database.get_database_instance.return_value = mock_instance + + # Mock database.generate_database_credential() for token minting + mock_credential = MagicMock() + mock_credential.token = "test-oauth-token" + mock_client.database.generate_database_credential.return_value = mock_credential + + return mock_client + + +@pytest.fixture +def mock_engine(): + """Create a mock SQLAlchemy AsyncEngine.""" + mock_eng = MagicMock() + mock_eng.sync_engine = MagicMock() + return mock_eng + + +@pytest.fixture +def mock_event_listens_for(): + """Create a mock for event.listens_for that captures the handler.""" + + def create_decorator(engine, event_name): + def decorator(fn): + return fn + + return decorator + + return create_decorator + + +class TestAsyncLakebaseSQLAlchemy: + """Tests for AsyncLakebaseSQLAlchemy class.""" + + def test_init_resolves_host(self, mock_workspace_client): + """Test that initialization resolves the Lakebase host.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), + ): + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert lakebase.host == "test-instance.lakebase.databricks.com" + mock_workspace_client.database.get_database_instance.assert_called_once_with( + "test-instance" + ) + + def test_init_infers_username(self, mock_workspace_client): + """Test that initialization infers the username.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), + ): + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert lakebase.username == "test_user@databricks.com" + mock_workspace_client.current_user.me.assert_called() + + def test_get_token_mints_new_token(self, mock_workspace_client): + """Test that get_token mints a new token when cache is empty.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), + ): + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + token = lakebase.get_token() + + assert token == "test-oauth-token" + mock_workspace_client.database.generate_database_credential.assert_called_once() + + def test_get_token_returns_cached_token(self, mock_workspace_client): + """Test that get_token returns cached token when valid.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), + ): + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # First call - mints token + token1 = lakebase.get_token() + # Second call - should return cached + token2 = lakebase.get_token() + + assert token1 == token2 == "test-oauth-token" + # Should only mint once + assert mock_workspace_client.database.generate_database_credential.call_count == 1 + + def test_get_token_refreshes_expired_token(self, mock_workspace_client): + """Test that get_token refreshes token when expired.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=MagicMock(sync_engine=MagicMock()), + ), + ): + import time + + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name="test-instance", + workspace_client=mock_workspace_client, + token_cache_duration_seconds=0, # Immediate expiration + ) + + # First call + lakebase.get_token() + # Wait a tiny bit + time.sleep(0.01) + # Second call - should mint new token + lakebase.get_token() + + # Should mint twice due to expiration + assert mock_workspace_client.database.generate_database_credential.call_count == 2 + + def test_init_raises_on_invalid_instance(self, mock_workspace_client): + """Test that initialization raises ValueError for invalid instance.""" + mock_workspace_client.database.get_database_instance.side_effect = Exception( + "Instance not found" + ) + + with patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ): + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + with pytest.raises(ValueError, match="Unable to resolve Lakebase instance"): + AsyncLakebaseSQLAlchemy( + instance_name="invalid-instance", + workspace_client=mock_workspace_client, + ) + + def test_engine_property_returns_engine(self, mock_workspace_client): + """Test that engine property returns the created engine.""" + mock_eng = MagicMock(sync_engine=MagicMock()) + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch("sqlalchemy.event.listens_for", return_value=lambda f: f), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_eng, + ), + ): + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + lakebase = AsyncLakebaseSQLAlchemy( + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + assert lakebase.engine is mock_eng + + +class TestAsyncDatabricksSessionInit: + """Tests for AsyncDatabricksSession initialization.""" + + def test_init_creates_engine_with_correct_url( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization creates engine with correct connection URL.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Verify engine was created with correct URL object + call_args = mock_create_engine.call_args + url = call_args[0][0] + # URL.create() returns a URL object, check its properties + assert url.drivername == "postgresql+psycopg" + assert url.username == "test_user@databricks.com" + assert url.host == "test-instance.lakebase.databricks.com" + assert url.port == 5432 + assert url.database == "databricks_postgres" + + def test_init_uses_pool_recycle( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that initialization uses pool_recycle for connection recycling.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_ai_bridge.lakebase import DEFAULT_POOL_RECYCLE_SECONDS + + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + call_kwargs = mock_create_engine.call_args[1] + # Should use pool_recycle for connection recycling (default QueuePool) + assert call_kwargs["pool_recycle"] == DEFAULT_POOL_RECYCLE_SECONDS + # Should NOT use NullPool (uses default QueuePool instead) + assert "poolclass" not in call_kwargs + + def test_init_sets_ssl_mode(self, mock_workspace_client, mock_engine, mock_event_listens_for): + """Test that initialization sets SSL mode to require.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + call_kwargs = mock_create_engine.call_args[1] + assert call_kwargs["connect_args"]["sslmode"] == "require" + + def test_init_registers_do_connect_event(self, mock_workspace_client, mock_engine): + """Test that initialization registers do_connect event for token injection.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch("sqlalchemy.event.listens_for") as mock_listens_for, + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Verify event listener was registered on sync_engine for "do_connect" + mock_listens_for.assert_called_once_with(mock_engine.sync_engine, "do_connect") + + def test_init_with_custom_table_names( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test initialization with custom table names.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + patch( + "agents.extensions.memory.SQLAlchemySession.__init__", + return_value=None, + ) as mock_parent_init, + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + # Verify parent was called with custom table names + call_kwargs = mock_parent_init.call_args[1] + assert call_kwargs["sessions_table"] == "custom_sessions" + assert call_kwargs["messages_table"] == "custom_messages" + + def test_init_with_create_tables_false( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test initialization with create_tables=False.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + patch( + "agents.extensions.memory.SQLAlchemySession.__init__", + return_value=None, + ) as mock_parent_init, + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + create_tables=False, + ) + + call_kwargs = mock_parent_init.call_args[1] + assert call_kwargs["create_tables"] is False + + +class TestAsyncDatabricksSessionTokenInjection: + """Tests for token injection via do_connect event.""" + + def test_do_connect_injects_token(self, mock_workspace_client, mock_engine): + """Test that do_connect event handler injects token into connection params.""" + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=capture_handler, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Simulate do_connect event + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + + # Verify token was injected + assert cparams["password"] == "test-oauth-token" + + +class TestAsyncDatabricksSessionEngineKwargs: + """Tests for passing additional engine kwargs.""" + + def test_extra_engine_kwargs_passed( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that extra kwargs are passed to create_async_engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, # Extra kwarg for SQLAlchemy + pool_pre_ping=True, # Another extra kwarg + ) + + call_kwargs = mock_create_engine.call_args[1] + assert call_kwargs.get("echo") is True + assert call_kwargs.get("pool_pre_ping") is True + + +class TestAsyncDatabricksSessionEngineCaching: + """Tests for engine caching behavior.""" + + def test_sessions_share_engine_for_same_instance( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that multiple sessions with the same instance_name share an engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + # Create two sessions with the same instance_name + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + + # Both sessions should reference the same engine + assert session1._engine is session2._engine + + def test_different_instances_get_different_engines( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that sessions with different instance_names get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + # Create sessions with different instance_names + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="instance-a", + workspace_client=mock_workspace_client, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="instance-b", + workspace_client=mock_workspace_client, + ) + + # Engine should be created twice (once per instance) + assert mock_create_engine.call_count == 2 + + # Sessions should have different engines + assert session1._engine is not session2._engine + + def test_different_engine_kwargs_get_different_engines( + self, mock_workspace_client, mock_event_listens_for + ): + """Test that same instance_name with different engine_kwargs get different engines.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=False, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, + ) + + # Engine should be created twice (different engine_kwargs) + assert mock_create_engine.call_count == 2 + + # Sessions should have different engines + assert session1._engine is not session2._engine + + def test_use_cached_engine_false_creates_new_engine( + self, mock_workspace_client, mock_event_listens_for + ): + """Test that use_cached_engine=False always creates a new engine.""" + engine1 = MagicMock() + engine1.sync_engine = MagicMock() + engine2 = MagicMock() + engine2.sync_engine = MagicMock() + + engines = [engine1, engine2] + engine_iter = iter(engines) + + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + side_effect=lambda *args, **kwargs: next(engine_iter), + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + use_cached_engine=False, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + use_cached_engine=False, + ) + + # Engine should be created twice despite same instance_name + assert mock_create_engine.call_count == 2 + + # Sessions should have different engines + assert session1._engine is not session2._engine + + def test_same_instance_and_kwargs_share_engine( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that same instance_name with same engine_kwargs reuse the cached engine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create_engine, + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + from databricks_openai.agents.session import AsyncDatabricksSession + + session1 = AsyncDatabricksSession( + session_id="session-1", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, + ) + session2 = AsyncDatabricksSession( + session_id="session-2", + instance_name="test-instance", + workspace_client=mock_workspace_client, + echo=True, + ) + + # Engine should only be created once + assert mock_create_engine.call_count == 1 + + # Both sessions should share the same engine + assert session1._engine is session2._engine + + +class TestAsyncDatabricksSessionAsyncOnly: + """Tests verifying AsyncDatabricksSession is async-only.""" + + def test_get_items_returns_coroutine_without_await( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that calling get_items() without await returns a coroutine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + import inspect + + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Calling without await returns a coroutine, not actual data + result = session.get_items() + assert inspect.iscoroutine(result) + + # Clean up the coroutine to avoid RuntimeWarning + result.close() + + def test_add_items_returns_coroutine_without_await( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that calling add_items() without await returns a coroutine.""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + import inspect + + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # Calling without await returns a coroutine, not actual data + result = session.add_items([{"role": "user", "content": "test"}]) + assert inspect.iscoroutine(result) + + # Clean up the coroutine to avoid RuntimeWarning + result.close() + + def test_methods_are_coroutine_functions( + self, mock_workspace_client, mock_engine, mock_event_listens_for + ): + """Test that all session methods are async (coroutine functions).""" + with ( + patch( + "databricks_ai_bridge.lakebase.WorkspaceClient", + return_value=mock_workspace_client, + ), + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=mock_event_listens_for, + ), + ): + import inspect + + from databricks_openai.agents.session import AsyncDatabricksSession + + session = AsyncDatabricksSession( + session_id="test-session-123", + instance_name="test-instance", + workspace_client=mock_workspace_client, + ) + + # All data methods should be async + assert inspect.iscoroutinefunction(session.get_items) + assert inspect.iscoroutinefunction(session.add_items) + assert inspect.iscoroutinefunction(session.pop_item) + assert inspect.iscoroutinefunction(session.clear_session) diff --git a/pyproject.toml b/pyproject.toml index 32ed1e2f8..044f436c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ file = "LICENSE.txt" [project.optional-dependencies] memory = [ "psycopg[binary,pool]>=3.2.10", + "sqlalchemy[asyncio]>=2.0.0", ] [dependency-groups] diff --git a/src/databricks_ai_bridge/lakebase.py b/src/databricks_ai_bridge/lakebase.py index 1cb0e83fa..f582044bd 100644 --- a/src/databricks_ai_bridge/lakebase.py +++ b/src/databricks_ai_bridge/lakebase.py @@ -6,11 +6,15 @@ import uuid from enum import Enum from threading import Lock -from typing import Any, List, Literal, Optional, Sequence +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence from databricks.sdk import WorkspaceClient from psycopg.rows import DictRow +if TYPE_CHECKING: + from sqlalchemy import URL + from sqlalchemy.ext.asyncio import AsyncEngine + try: import psycopg from psycopg import sql @@ -24,12 +28,16 @@ __all__ = [ "AsyncLakebasePool", + "AsyncLakebaseSQLAlchemy", "LakebasePool", ] logger = logging.getLogger(__name__) -DEFAULT_TOKEN_CACHE_DURATION_SECONDS = 50 * 60 # Cache token for 50 minutes +# Token cache duration based on Databricks Lakebase docs (15 minutes) +# https://docs.databricks.com/aws/en/oltp/projects/authentication?language=Python%3A+SQLAlchemy +DEFAULT_TOKEN_CACHE_DURATION_SECONDS = 15 * 60 # 15 minutes (900 seconds) +DEFAULT_POOL_RECYCLE_SECONDS = 14 * 60 # 14 minutes (before token cache expires) DEFAULT_MIN_SIZE = 1 DEFAULT_MAX_SIZE = 10 DEFAULT_TIMEOUT = 30.0 @@ -81,12 +89,12 @@ class SequencePrivilege(str, Enum): ALL = "ALL" # Renders as ALL PRIVILEGES -class _LakebasePoolBase: +class _LakebaseBase: """ - Base logic for Lakebase connection pools: resolve host, infer username, + Base class for Lakebase connections: resolve host, infer username, token cache + minting, and conninfo building. - Subclasses implement pool-specific initialization and lifecycle methods. + Subclasses implement specific initialization and lifecycle methods. """ def __init__( @@ -168,7 +176,7 @@ def _infer_username(self) -> str: raise ValueError("Unable to infer username for Lakebase connection.") -class LakebasePool(_LakebasePoolBase): +class LakebasePool(_LakebaseBase): """Sync Lakebase connection pool built on psycopg with rotating credentials. instance_name: Name of Lakebase Instance @@ -262,7 +270,7 @@ def close(self) -> None: self._pool.close() -class AsyncLakebasePool(_LakebasePoolBase): +class AsyncLakebasePool(_LakebaseBase): """Async Lakebase connection pool built on psycopg with rotating credentials. instance_name: Name of Lakebase Instance @@ -869,3 +877,129 @@ def grant_all_sequences_in_schema( schema, grantee, ) + + +# ============================================================================= +# AsyncLakebaseSQLAlchemy - SQLAlchemy async engine factory +# ============================================================================= + + +class AsyncLakebaseSQLAlchemy(_LakebaseBase): + """Async SQLAlchemy engine factory for Databricks Lakebase. + + Provides an AsyncEngine with automatic OAuth token injection via + SQLAlchemy's do_connect event. Tokens are cached and refreshed + every 15 minutes. + + Note: + This class is **async-only**. The engine uses SQLAlchemy's + async extension with the psycopg driver. + + Reference: + https://docs.databricks.com/aws/en/oltp/instances/authentication + + Example: + ```python + from databricks_ai_bridge.lakebase import AsyncLakebaseSQLAlchemy + + # Create once and reuse the engine + lakebase = AsyncLakebaseSQLAlchemy(instance_name="my-lakebase") + + async with lakebase.engine.connect() as conn: + result = await conn.execute(text("SELECT 1")) + ``` + """ + + def __init__( + self, + *, + instance_name: str, + workspace_client: WorkspaceClient | None = None, + token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS, + pool_recycle: int = DEFAULT_POOL_RECYCLE_SECONDS, + **engine_kwargs, + ) -> None: + """ + Initialize AsyncLakebaseSQLAlchemy for Databricks Lakebase. + + Args: + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + If not provided, a default client will be created. + token_cache_duration_seconds: How long to cache OAuth tokens. + Defaults to 15 minutes. + pool_recycle: Connection pool recycle time in seconds. + Defaults to 14 minutes (before token cache expires). + **engine_kwargs: Additional keyword arguments passed to + SQLAlchemy's create_async_engine(). + """ + super().__init__( + instance_name=instance_name, + workspace_client=workspace_client, + token_cache_duration_seconds=token_cache_duration_seconds, + ) + + # Thread-safe lock for token caching (do_connect is sync context) + self._cache_lock = Lock() + self._pool_recycle = pool_recycle + self._engine = self._create_engine(**engine_kwargs) + + logger.info( + "AsyncLakebaseSQLAlchemy initialized: instance=%s host=%s", + instance_name, + self.host, + ) + + @property + def engine(self) -> "AsyncEngine": + """The SQLAlchemy AsyncEngine.""" + return self._engine + + def get_token(self) -> str: + """Get cached token or mint a new one (thread-safe).""" + with self._cache_lock: + if cached := self._get_cached_token(): + return cached + token = self._mint_token() + self._cached_token = token + self._cache_ts = time.time() + return token + + def _create_url(self) -> "URL": + """Create SQLAlchemy URL for Lakebase connection.""" + from sqlalchemy import URL + + # Create engine without password - token injected via event listener + # Note: empty password in URL, actual token provided on connect + return URL.create( + drivername="postgresql+psycopg", + username=self.username, + host=self.host, + port=DEFAULT_PORT, + database=DEFAULT_DATABASE, + ) + + def _create_engine(self, **engine_kwargs) -> "AsyncEngine": + """Create AsyncEngine with do_connect event for token injection.""" + from sqlalchemy import event + from sqlalchemy.ext.asyncio import create_async_engine + + url = self._create_url() + + engine: AsyncEngine = create_async_engine( + url, + pool_recycle=self._pool_recycle, + connect_args={"sslmode": DEFAULT_SSLMODE}, + **engine_kwargs, + ) + + # AsyncEngine wraps a sync Engine internally - connection events like + # do_connect must be registered on sync_engine, not the async wrapper. + # https://docs.sqlalchemy.org/en/20/orm/extensions/asyncio.html#using-events-with-the-asyncio-extension + # Lakebase docs https://docs.databricks.com/aws/en/oltp/projects/authentication?language=Python%3A+SQLAlchemy + @event.listens_for(engine.sync_engine, "do_connect") + def inject_token(dialect, conn_rec, cargs, cparams): + cparams["password"] = self.get_token() + logger.debug("Injected Lakebase token for connection") + + return engine diff --git a/tests/databricks_ai_bridge/test_lakebase.py b/tests/databricks_ai_bridge/test_lakebase.py index b914f7217..bdc7ddcbd 100644 --- a/tests/databricks_ai_bridge/test_lakebase.py +++ b/tests/databricks_ai_bridge/test_lakebase.py @@ -12,6 +12,7 @@ import databricks_ai_bridge.lakebase as lakebase from databricks_ai_bridge.lakebase import ( AsyncLakebasePool, + AsyncLakebaseSQLAlchemy, LakebaseClient, LakebasePool, SchemaPrivilege, @@ -83,7 +84,7 @@ def test_lakebase_pool_logs_cache_seconds(monkeypatch, caplog): ) assert any( - record.levelno == logging.INFO and re.search(r"cache=3000s$", record.getMessage()) + record.levelno == logging.INFO and re.search(r"cache=900s$", record.getMessage()) for record in caplog.records ) @@ -255,7 +256,7 @@ async def test_async_lakebase_pool_logs_cache_seconds(monkeypatch, caplog): ) assert any( - record.levelno == logging.INFO and re.search(r"cache=3000s$", record.getMessage()) + record.levelno == logging.INFO and re.search(r"cache=900s$", record.getMessage()) for record in caplog.records ) @@ -1056,3 +1057,241 @@ def test_execute_grant_handles_insufficient_privilege(self): error_msg = str(exc_info.value) assert "Insufficient privileges" in error_msg assert "CAN MANAGE" in error_msg + + +# ============================================================================= +# AsyncLakebaseSQLAlchemy Tests +# ============================================================================= + +pytest.importorskip("sqlalchemy") + + +def _make_sqlalchemy_patches(workspace): + """Return a context manager that patches SQLAlchemy internals for AsyncLakebaseSQLAlchemy.""" + from unittest.mock import patch + + mock_engine = MagicMock(sync_engine=MagicMock()) + + return ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + return_value=lambda f: f, + ), + mock_engine, + ) + + +def test_async_lakebase_sqlalchemy_resolves_host(): + """Test that AsyncLakebaseSQLAlchemy resolves the Lakebase host from instance name.""" + workspace = _make_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert sa.host == "db.host" + workspace.database.get_database_instance.assert_called_once_with("lake-instance") + + +def test_async_lakebase_sqlalchemy_infers_username(): + """Test that AsyncLakebaseSQLAlchemy infers the username from workspace client.""" + workspace = _make_workspace(user_name="alice@databricks.com") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert sa.username == "alice@databricks.com" + + +def test_async_lakebase_sqlalchemy_engine_property(): + """Test that engine property returns the created AsyncEngine.""" + workspace = _make_workspace() + patch_engine, patch_event, mock_engine = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert sa.engine is mock_engine + + +def test_async_lakebase_sqlalchemy_creates_engine_with_correct_url(): + """Test that the engine is created with the correct SQLAlchemy URL.""" + from unittest.mock import patch + + workspace = _make_workspace() + mock_engine = MagicMock(sync_engine=MagicMock()) + + with ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, + patch( + "sqlalchemy.event.listens_for", + return_value=lambda f: f, + ), + ): + AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + url = mock_create.call_args[0][0] + assert url.drivername == "postgresql+psycopg" + assert url.username == "test@databricks.com" + assert url.host == "db.host" + assert url.port == 5432 + assert url.database == "databricks_postgres" + + +def test_async_lakebase_sqlalchemy_passes_extra_engine_kwargs(): + """Test that additional kwargs are forwarded to create_async_engine.""" + from unittest.mock import patch + + workspace = _make_workspace() + mock_engine = MagicMock(sync_engine=MagicMock()) + + with ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ) as mock_create, + patch( + "sqlalchemy.event.listens_for", + return_value=lambda f: f, + ), + ): + AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + echo=True, + pool_pre_ping=True, + ) + + call_kwargs = mock_create.call_args[1] + assert call_kwargs["echo"] is True + assert call_kwargs["pool_pre_ping"] is True + + +def test_async_lakebase_sqlalchemy_do_connect_injects_token(): + """Test that the do_connect handler injects the OAuth token into cparams.""" + from unittest.mock import patch + + workspace = _make_workspace(credential_token="my-secret-token") + mock_engine = MagicMock(sync_engine=MagicMock()) + + captured_handler = None + + def capture_handler(engine, event_name): + def decorator(fn): + nonlocal captured_handler + captured_handler = fn + return fn + + return decorator + + with ( + patch( + "sqlalchemy.ext.asyncio.create_async_engine", + return_value=mock_engine, + ), + patch( + "sqlalchemy.event.listens_for", + side_effect=capture_handler, + ), + ): + AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + assert captured_handler is not None + cparams = {} + captured_handler(None, None, None, cparams) + assert cparams["password"] == "my-secret-token" + + +def test_async_lakebase_sqlalchemy_get_token_caches(): + """Test that get_token returns cached token on repeated calls.""" + workspace = _make_workspace() + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + ) + + token1 = sa.get_token() + token2 = sa.get_token() + + assert token1 == token2 == "token-1" + assert workspace.database.generate_database_credential.call_count == 1 + + +def test_async_lakebase_sqlalchemy_get_token_refreshes_after_expiry(monkeypatch): + """Test that get_token mints a new token after cache expiry.""" + import time + + call_count = [] + + def mock_generate_credential(**kwargs): + call_count.append(1) + return MagicMock(token=f"token-{len(call_count)}") + + workspace = _make_workspace() + workspace.database.generate_database_credential = mock_generate_credential + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + sa = AsyncLakebaseSQLAlchemy( + instance_name="lake-instance", + workspace_client=workspace, + token_cache_duration_seconds=1, + ) + + test_time = [100.0] + monkeypatch.setattr(time, "time", lambda: test_time[0]) + + token1 = sa.get_token() + assert token1 == "token-1" + + # Within cache window + test_time[0] = 100.5 + token2 = sa.get_token() + assert token2 == "token-1" + assert len(call_count) == 1 + + # After cache expiry + test_time[0] = 101.5 + token3 = sa.get_token() + assert token3 == "token-2" + assert len(call_count) == 2 + + +def test_async_lakebase_sqlalchemy_invalid_instance_raises(): + """Test that an invalid instance name raises ValueError.""" + workspace = _make_workspace() + workspace.database.get_database_instance.side_effect = Exception("Not found") + patch_engine, patch_event, _ = _make_sqlalchemy_patches(workspace) + + with patch_engine, patch_event: + with pytest.raises(ValueError, match="Unable to resolve Lakebase instance"): + AsyncLakebaseSQLAlchemy( + instance_name="bad-instance", + workspace_client=workspace, + )