diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 6711b87f..28138bdc 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -343,6 +343,8 @@ class DatabricksOpenAI(OpenAI): with ``base_url`` or ``use_ai_gateway``. Defaults to False. use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route requests through it using the MLflow API. Defaults to False. + **kwargs: Additional keyword arguments forwarded to the underlying ``openai.OpenAI`` + client (e.g. ``timeout``, ``max_retries``, ``default_headers``). Example - Query a serving or AI gateway endpoint: >>> client = DatabricksOpenAI() @@ -385,6 +387,7 @@ def __init__( base_url: str | None = None, use_ai_gateway_native_api: bool = False, use_ai_gateway: bool = False, + **kwargs: Any, ): if workspace_client is None: workspace_client = WorkspaceClient() @@ -400,6 +403,7 @@ def __init__( base_url=target_base_url, api_key=_get_openai_api_key(), http_client=_get_authorized_http_client(workspace_client), + **kwargs, ) @override @@ -504,6 +508,8 @@ class AsyncDatabricksOpenAI(AsyncOpenAI): with ``base_url`` or ``use_ai_gateway``. Defaults to False. use_ai_gateway: If True, auto-detect AI Gateway V2 availability and route requests through it using the MLflow API. Defaults to False. + **kwargs: Additional keyword arguments forwarded to the underlying ``openai.AsyncOpenAI`` + client (e.g. ``timeout``, ``max_retries``, ``default_headers``). Example - Query a serving or AI gateway endpoint: >>> client = AsyncDatabricksOpenAI() @@ -546,6 +552,7 @@ def __init__( base_url: str | None = None, use_ai_gateway_native_api: bool = False, use_ai_gateway: bool = False, + **kwargs: Any, ): if workspace_client is None: workspace_client = WorkspaceClient() @@ -561,6 +568,7 @@ def __init__( base_url=target_base_url, api_key=_get_openai_api_key(), http_client=_get_authorized_async_http_client(workspace_client), + **kwargs, ) @property diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index bd3fbe53..bd23211a 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -168,6 +168,30 @@ def test_bearer_auth_flow(self, mock_workspace_client): mock_workspace_client.config.authenticate.assert_called() +class TestDatabricksOpenAIKwargForwarding: + """Forwarding of OpenAI kwargs through the wrappers (issue #423).""" + + @pytest.mark.parametrize( + "client_cls", [DatabricksOpenAI, AsyncDatabricksOpenAI], ids=["sync", "async"] + ) + def test_forwards_timeout_and_max_retries(self, client_cls, mock_workspace_client): + client = client_cls(workspace_client=mock_workspace_client, timeout=42.0, max_retries=7) + assert client.timeout == 42.0 + assert client.max_retries == 7 + + @pytest.mark.parametrize( + "client_cls", [DatabricksOpenAI, AsyncDatabricksOpenAI], ids=["sync", "async"] + ) + @pytest.mark.parametrize("managed_kwarg", ["api_key", "http_client"]) + def test_rejects_databricks_managed_auth_kwargs( + self, client_cls, managed_kwarg, mock_workspace_client + ): + # api_key and http_client are hardcoded with Databricks auth; passing them through + # **kwargs would collide with the super().__init__() call and raise TypeError. + with pytest.raises(TypeError, match="got multiple values for keyword argument"): + client_cls(workspace_client=mock_workspace_client, **{managed_kwarg: "anything"}) + + class TestStrictFieldStripping: """Tests for strict field stripping helper functions."""