Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions integrations/openai/src/databricks_openai/utils/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ class DatabricksOpenAI(OpenAI):
with ``base_url`` or ``use_ai_gateway``. Defaults to False.
use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route
requests through it using the MLflow API. Defaults to False.
**kwargs: Additional keyword arguments forwarded to the underlying ``openai.OpenAI``
client (e.g. ``timeout``, ``max_retries``, ``default_headers``).

Example - Query a serving or AI gateway endpoint:
>>> client = DatabricksOpenAI()
Expand Down Expand Up @@ -385,6 +387,7 @@ def __init__(
base_url: str | None = None,
use_ai_gateway_native_api: bool = False,
use_ai_gateway: bool = False,
**kwargs: Any,
):
if workspace_client is None:
workspace_client = WorkspaceClient()
Expand All @@ -400,6 +403,7 @@ def __init__(
base_url=target_base_url,
api_key=_get_openai_api_key(),
http_client=_get_authorized_http_client(workspace_client),
**kwargs,
)

@override
Expand Down Expand Up @@ -504,6 +508,8 @@ class AsyncDatabricksOpenAI(AsyncOpenAI):
with ``base_url`` or ``use_ai_gateway``. Defaults to False.
use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route
requests through it using the MLflow API. Defaults to False.
**kwargs: Additional keyword arguments forwarded to the underlying ``openai.AsyncOpenAI``
client (e.g. ``timeout``, ``max_retries``, ``default_headers``).

Example - Query a serving or AI gateway endpoint:
>>> client = AsyncDatabricksOpenAI()
Expand Down Expand Up @@ -546,6 +552,7 @@ def __init__(
base_url: str | None = None,
use_ai_gateway_native_api: bool = False,
use_ai_gateway: bool = False,
**kwargs: Any,
):
if workspace_client is None:
workspace_client = WorkspaceClient()
Expand All @@ -561,6 +568,7 @@ def __init__(
base_url=target_base_url,
api_key=_get_openai_api_key(),
http_client=_get_authorized_async_http_client(workspace_client),
**kwargs,
)

@property
Expand Down
24 changes: 24 additions & 0 deletions integrations/openai/tests/unit_tests/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,30 @@ def test_bearer_auth_flow(self, mock_workspace_client):
mock_workspace_client.config.authenticate.assert_called()


class TestDatabricksOpenAIKwargForwarding:
"""Forwarding of OpenAI kwargs through the wrappers (issue #423)."""

@pytest.mark.parametrize(
"client_cls", [DatabricksOpenAI, AsyncDatabricksOpenAI], ids=["sync", "async"]
)
def test_forwards_timeout_and_max_retries(self, client_cls, mock_workspace_client):
client = client_cls(workspace_client=mock_workspace_client, timeout=42.0, max_retries=7)
assert client.timeout == 42.0
assert client.max_retries == 7

Comment thread
kevinyan122 marked this conversation as resolved.
@pytest.mark.parametrize(
"client_cls", [DatabricksOpenAI, AsyncDatabricksOpenAI], ids=["sync", "async"]
)
@pytest.mark.parametrize("managed_kwarg", ["api_key", "http_client"])
def test_rejects_databricks_managed_auth_kwargs(
self, client_cls, managed_kwarg, mock_workspace_client
):
# api_key and http_client are hardcoded with Databricks auth; passing them through
# **kwargs would collide with the super().__init__() call and raise TypeError.
with pytest.raises(TypeError, match="got multiple values for keyword argument"):
client_cls(workspace_client=mock_workspace_client, **{managed_kwarg: "anything"})


class TestStrictFieldStripping:
"""Tests for strict field stripping helper functions."""

Expand Down
Loading