Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ jobs:
--exact --group tests --extra memory \
pytest tests/unit_tests/test_checkpoint.py \
tests/unit_tests/test_store.py
- name: Run databricks-openai Tests
working-directory: integrations/openai
run: |
uv run --resolution ${{ matrix.uv-resolution }} \
--exact --group tests --extra memory \
pytest tests/unit_tests/test_session.py

langchain_cross_version_test:
runs-on: ubuntu-latest
Expand Down
6 changes: 6 additions & 0 deletions integrations/openai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ dependencies = [
"openai-agents>=0.5.0"
]

[project.optional-dependencies]
memory = [
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we update the CI job for this memory extra too

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

"databricks-ai-bridge[memory]>=0.13.0",
]

[dependency-groups]
dev = [
"typing_extensions>=4.15.0",
Expand All @@ -34,6 +39,7 @@ tests = [
"pytest-asyncio>=1.3.0",
"pytest-timeout>=2.3.1",
"pytest-cov>=4.1.0",
"databricks-ai-bridge[memory]>=0.13.0",
]

[build-system]
Expand Down
3 changes: 2 additions & 1 deletion integrations/openai/src/databricks_openai/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from databricks_openai.agents.mcp_server import McpServer
from databricks_openai.agents.session import AsyncDatabricksSession

__all__ = ["McpServer"]
__all__ = ["AsyncDatabricksSession", "McpServer"]
201 changes: 201 additions & 0 deletions integrations/openai/src/databricks_openai/agents/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want this to be importable from databricks_openai.agents?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I'll make this importable so instead of:

from databricks_openai.agents.session import MemorySession

import path will look like:

from databricks_openai.agents import MemorySession

AsyncDatabricksSession - Async SQLAlchemy-based session storage for Databricks Lakebase.

This module provides an AsyncDatabricksSession class that subclasses OpenAI's SQLAlchemySession
to provide persistent conversation history storage in Databricks Lakebase.

Note:
This class is **async-only** as it follows the Session Protocol. Use within async context
https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.Session

Usage::

import asyncio
from databricks_openai.agents import AsyncDatabricksSession
from agents import Agent, Runner


async def main():
session = AsyncDatabricksSession(
session_id="user-123",
instance_name="my-lakebase-instance",
)

agent = Agent(name="Assistant")
result = await Runner.run(agent, "Hello!", session=session)


asyncio.run(main())
"""

from __future__ import annotations

import json
import logging
from threading import Lock
from typing import Any, Optional

try:
from agents.extensions.memory import SQLAlchemySession
from databricks.sdk import WorkspaceClient
from databricks_ai_bridge.lakebase import (
DEFAULT_POOL_RECYCLE_SECONDS,
DEFAULT_TOKEN_CACHE_DURATION_SECONDS,
AsyncLakebaseSQLAlchemy,
)

_session_imports_available = True
except ImportError:
SQLAlchemySession = object # type: ignore
_session_imports_available = False

logger = logging.getLogger(__name__)


class AsyncDatabricksSession(SQLAlchemySession):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from talking to the research team working on DBRA, they actually have a very similar snippet as us to manage a SQLAlchemy connection to lakebase: https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/blob/research/aroll/app/aroll_app/db/connection.py?L162-182

would it make sense for us to further abstract this by providing a similar AsyncLakebaseSQLAlchemy / LakebaseSQLAlchemy class?

Copy link
Copy Markdown
Contributor Author

@jennsun jennsun Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussed offline but I'll refactor such that:

this will create much cleaner separation of concerns for future frameworks to reuse any sqlalchemy engines etc!

"""
Async OpenAI Agents SDK Session implementation for Databricks Lakebase.
For more information on the Session protocol, see:
https://openai.github.io/openai-agents-python/ref/memory/session/

Note:
This class is **async-only**. All session methods (get_items, add_items,
clear_session, etc.) are coroutines and must be awaited.

The session stores conversation history in two tables:
- agent_sessions: Tracks session metadata (session_id, created_at, updated_at)
- agent_messages: Stores conversation items (id, session_id, message_data, created_at)

Example:
```python
import asyncio
from databricks_openai.agents import AsyncDatabricksSession
from agents import Agent, Runner


async def main():
session = AsyncDatabricksSession(
session_id="user-123",
instance_name="my-lakebase-instance",
)
agent = Agent(name="Assistant")
result = await Runner.run(agent, "Hello!", session=session)


asyncio.run(main())
```
"""

# Class-level cache for AsyncLakebaseSQLAlchemy instances keyed by
# (instance_name, engine_kwargs). This allows multiple sessions to share
# a single engine/connection pool when the configuration is identical.
_lakebase_sql_alchemy_cache: dict[str, AsyncLakebaseSQLAlchemy] = {}
_lakebase_sql_alchemy_cache_lock = Lock()
Copy link
Copy Markdown
Contributor Author

@jennsun jennsun Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thoughts on the class-level cache for AsyncLakebaseSQLAlchemy engines keyed by instance_name?

this is so we reuse a single SQLAlchemy engine / pool per Lakebase instance, avoiding repeated pool creation, TCP handshakes, and auth setup.

sessions are still created per Runner.run(), but engines are shared

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this approach looks good to me to minimize IO. two comments:

  • we may want to include a param for a func for customers to customize the cache key. currently, diff engine kwargs for the same instance name will be ignored
  • let's also call this out in the docstring and add a param to optionally disable this engine caching

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the best case would be include engine kwargs + instance name in the cache key

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good - going to create cache key that takes into consideration both instance name + engine kwards, as well as ability to not cache the engines (but defaults to caching)


def __init__(
self,
session_id: str,
*,
instance_name: str,
workspace_client: Optional[WorkspaceClient] = None,
token_cache_duration_seconds: int = DEFAULT_TOKEN_CACHE_DURATION_SECONDS,
create_tables: bool = True,
sessions_table: str = "agent_sessions",
messages_table: str = "agent_messages",
use_cached_engine: bool = True,
**engine_kwargs,
) -> None:
"""
Initialize an AsyncDatabricksSession for Databricks Lakebase.

Args:
session_id: Unique identifier for the conversation session.
instance_name: Name of the Lakebase instance.
workspace_client: Optional WorkspaceClient for authentication.
If not provided, a default client will be created.
token_cache_duration_seconds: How long to cache OAuth tokens.
Defaults to 15 minutes.
create_tables: Whether to auto-create tables on first use.
Defaults to True.
sessions_table: Name of the sessions table.
Defaults to "agent_sessions".
messages_table: Name of the messages table.
Defaults to "agent_messages".
use_cached_engine: Whether to reuse a cached engine for the same
instance_name and engine_kwargs combination. Set to False to
always create a new engine. Defaults to True.
**engine_kwargs: Additional keyword arguments passed to
SQLAlchemy's create_async_engine().
"""
if not _session_imports_available:
raise ImportError(
"AsyncDatabricksSession requires databricks-openai[memory]. "
"Please install with: pip install databricks-openai[memory]"
)

self._lakebase = self._get_or_create_lakebase(
instance_name=instance_name,
workspace_client=workspace_client,
token_cache_duration_seconds=token_cache_duration_seconds,
pool_recycle=engine_kwargs.pop("pool_recycle", DEFAULT_POOL_RECYCLE_SECONDS),
use_cached_engine=use_cached_engine,
**engine_kwargs,
)

# Initialize parent SQLAlchemySession - inherits all SQL logic
super().__init__(
session_id=session_id,
engine=self._lakebase.engine,
create_tables=create_tables,
sessions_table=sessions_table,
messages_table=messages_table,
)

logger.info(
"AsyncDatabricksSession initialized: instance=%s session_id=%s",
instance_name,
session_id,
)

@classmethod
def _build_cache_key(cls, instance_name: str, **engine_kwargs: Any) -> str:
"""Build a cache key from instance_name and engine_kwargs."""
# Sort kwargs for deterministic key; use JSON for serializable values
kwargs_key = json.dumps(engine_kwargs, sort_keys=True, default=str)
return f"{instance_name}::{kwargs_key}"

@classmethod
def _get_or_create_lakebase(
cls,
*,
instance_name: str,
workspace_client: Optional[WorkspaceClient],
token_cache_duration_seconds: int,
pool_recycle: int,
use_cached_engine: bool = True,
**engine_kwargs,
) -> AsyncLakebaseSQLAlchemy:
"""Get cached AsyncLakebaseSQLAlchemy or create a new one.
The cache key uses both instance_name and engine_kwargs
"""
cache_key = cls._build_cache_key(instance_name, pool_recycle=pool_recycle, **engine_kwargs)

if use_cached_engine:
with cls._lakebase_sql_alchemy_cache_lock:
if cache_key in cls._lakebase_sql_alchemy_cache:
logger.debug("Reusing cached engine for key=%s", cache_key)
return cls._lakebase_sql_alchemy_cache[cache_key]

lakebase = AsyncLakebaseSQLAlchemy(
instance_name=instance_name,
workspace_client=workspace_client,
token_cache_duration_seconds=token_cache_duration_seconds,
pool_recycle=pool_recycle,
**engine_kwargs,
)

if use_cached_engine:
with cls._lakebase_sql_alchemy_cache_lock:
cls._lakebase_sql_alchemy_cache[cache_key] = lakebase

return lakebase
10 changes: 6 additions & 4 deletions integrations/openai/src/databricks_openai/utils/clients.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generator
from typing import Any, Generator

from databricks.sdk import WorkspaceClient
from httpx import AsyncClient, Auth, Client, Request, Response
Expand All @@ -18,14 +18,16 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]:
yield request


def _strip_strict_from_tools(tools: list | None) -> list | None:
def _strip_strict_from_tools(tools: Any) -> Any:
"""Remove 'strict' field from tool function definitions.

Databricks model endpoints (except GPT) don't support the 'strict' field
in tool schemas, but openai-agents SDK v0.6.4+ includes it.
"""
if tools is None:
return None
# Handle None or OpenAI's NOT_GIVEN/Omit sentinel types (non-iterable placeholders).
# See https://deepwiki.com/openai/openai-python/5-data-types-and-models#special-types-and-sentinels
if not tools:
return tools
for tool in tools:
if isinstance(tool, dict) and "function" in tool:
tool.get("function", {}).pop("strict", None)
Expand Down
Loading
Loading