diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 3f79d8fadc..7c75f7add3 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -19,6 +19,7 @@ import copy from functools import cached_property import logging +import os import re from typing import Any from typing import AsyncGenerator @@ -54,6 +55,7 @@ _NEW_LINE = '\n' _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} _GOOGLE_API_VERSION_SUFFIX_PATTERN = re.compile(r'/?(v[0-9][a-z0-9.-]*)/?') +_API_VERSION_ENV_VARIABLE_NAME = 'GOOGLE_GENAI_API_VERSION' _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ @@ -123,6 +125,27 @@ def api_client(self) -> Client: base_url: Optional[str] = None """The base URL for the AI platform service endpoint.""" + api_version: Optional[str] = None + """The API version to use for the AI platform service endpoint. + + For the Vertex AI backend the google-genai SDK defaults to ``v1beta1``, which + exposes the latest preview features. Production deployments that require a + stable, SLA-eligible endpoint can set this to ``v1`` to use the GA Vertex AI + API. When unset, the ``GOOGLE_GENAI_API_VERSION`` environment variable is + consulted, and finally the SDK's own default is used so existing behavior is + unchanged. + + An API version embedded in ``base_url`` (e.g. + ``https://...googleapis.com/v1``) takes precedence over this field. + + Sample: + ```python + from google.adk.models import Gemini + + agent = Agent(model=Gemini(model="gemini-2.5-pro", api_version="v1")) + ``` + """ + speech_config: Optional[types.SpeechConfig] = None use_interactions_api: bool = False @@ -371,9 +394,29 @@ def _api_backend(self) -> GoogleLLMVariant: def _tracking_headers(self) -> dict[str, str]: return get_tracking_headers() + def _configured_api_version(self) -> Optional[str]: + """Returns the explicitly configured API version, if any. + + Resolution order: + 1. The ``api_version`` field set on this instance. + 2. The ``GOOGLE_GENAI_API_VERSION`` environment variable. + + Returns ``None`` when neither is set, in which case the google-genai SDK's + own default (``v1beta1`` for Vertex AI) applies, preserving existing + behavior. + """ + if self.api_version: + return self.api_version + return os.environ.get(_API_VERSION_ENV_VARIABLE_NAME) or None + @cached_property def _base_url_and_api_version(self) -> tuple[Optional[str], Optional[str]]: - return _normalize_base_url_and_api_version(self.base_url) + base_url, api_version = _normalize_base_url_and_api_version(self.base_url) + # A version embedded in the base URL wins; otherwise fall back to the + # explicitly configured api_version (field or environment variable). + if api_version is None: + api_version = self._configured_api_version() + return base_url, api_version @cached_property def _live_api_version(self) -> str: diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index caad24ed3f..413d1eb558 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -341,6 +341,61 @@ def test_api_client_preserves_custom_base_url_path(): assert client._api_client._http_options.api_version == "v1beta" +def test_api_client_default_api_version_unchanged(monkeypatch): + """Without configuration, ADK does not force an api_version (SDK default).""" + monkeypatch.delenv("GOOGLE_GENAI_API_VERSION", raising=False) + model = Gemini(model="gemini-2.5-flash") + + # ADK leaves api_version unset so the google-genai SDK applies its own + # default (v1beta1 for Vertex AI), preserving existing behavior. + assert model._base_url_and_api_version == (None, None) + + +def test_api_client_uses_api_version_field(): + """The api_version field flows into the constructed client's http_options.""" + model = Gemini(model="gemini-2.5-flash", api_version="v1") + + client = model.api_client + + assert client._api_client._http_options.api_version == "v1" + + +def test_api_client_uses_api_version_env_var(monkeypatch): + """The GOOGLE_GENAI_API_VERSION env var flows into http_options.""" + monkeypatch.setenv("GOOGLE_GENAI_API_VERSION", "v1") + model = Gemini(model="gemini-2.5-flash") + + client = model.api_client + + assert client._api_client._http_options.api_version == "v1" + + +def test_api_version_field_overrides_env_var(monkeypatch): + """The explicit api_version field takes precedence over the env var.""" + monkeypatch.setenv("GOOGLE_GENAI_API_VERSION", "v1beta1") + model = Gemini(model="gemini-2.5-flash", api_version="v1") + + client = model.api_client + + assert client._api_client._http_options.api_version == "v1" + + +def test_base_url_api_version_overrides_field(): + """A version embedded in base_url wins over the api_version field.""" + model = Gemini( + model="gemini-2.5-flash", + base_url="https://generativelanguage.googleapis.com/v1alpha", + api_version="v1", + ) + + client = model.api_client + + assert client._api_client._http_options.base_url == ( + "https://generativelanguage.googleapis.com/" + ) + assert client._api_client._http_options.api_version == "v1alpha" + + def test_maybe_append_user_content(gemini_llm, llm_request): # Test with user content already present gemini_llm._maybe_append_user_content(llm_request) @@ -766,6 +821,35 @@ async def mock_coro(): assert len(responses) == 2 if stream else 1 +@pytest.mark.asyncio +async def test_generate_content_async_patches_api_version_from_field( + llm_request, generate_content_response +): + """The configured api_version field is patched onto the request config.""" + gemini_llm = Gemini(model="gemini-2.5-flash", api_version="v1") + llm_request.config.http_options = types.HttpOptions( + headers={"custom-header": "custom-value"} + ) + + with mock.patch.object(gemini_llm, "api_client") as mock_client: + + async def mock_coro(): + return generate_content_response + + mock_client.aio.models.generate_content.return_value = mock_coro() + + _ = [ + resp + async for resp in gemini_llm.generate_content_async( + llm_request, stream=False + ) + ] + + call_args = mock_client.aio.models.generate_content.call_args + final_config = call_args.kwargs["config"] + assert final_config.http_options.api_version == "v1" + + def test_live_api_version_vertex_ai(gemini_llm): """Test that _live_api_version returns 'v1beta1' for Vertex AI backend.""" with mock.patch.object( @@ -774,6 +858,28 @@ def test_live_api_version_vertex_ai(gemini_llm): assert gemini_llm._live_api_version == "v1beta1" +def test_live_api_version_uses_configured_field(): + """Test that _live_api_version honors the configured api_version field.""" + gemini_llm = Gemini(model="gemini-2.5-flash", api_version="v1") + + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + assert gemini_llm._live_api_version == "v1" + + +def test_live_api_client_uses_configured_field(): + """Test that _live_api_client http_options honors the api_version field.""" + gemini_llm = Gemini(model="gemini-2.5-flash", api_version="v1") + + with mock.patch.object( + gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + ): + client = gemini_llm._live_api_client + + assert client._api_client._http_options.api_version == "v1" + + def test_live_api_version_uses_google_base_url_version(): gemini_llm = Gemini( model="gemini-2.5-flash",