Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ optional-dependencies.all = [
"google-cloud-aiplatform[agent-engines]>=1.148.1,<2",
"google-cloud-bigquery>=2.2",
"google-cloud-bigquery-storage>=2",
"google-cloud-bigtable>=2.32",
"google-cloud-bigtable>=2.38.0",
"google-cloud-dataplex>=1.7,<3",
"google-cloud-discoveryengine>=0.13.12,<0.14",
"google-cloud-parametermanager>=0.4,<1",
Expand Down Expand Up @@ -161,7 +161,7 @@ optional-dependencies.gcp = [
"google-cloud-aiplatform[agent-engines]>=1.148.1,<2",
"google-cloud-bigquery>=2.2",
"google-cloud-bigquery-storage>=2",
"google-cloud-bigtable>=2.32",
"google-cloud-bigtable>=2.38.0",
"google-cloud-dataplex>=1.7,<3",
"google-cloud-discoveryengine>=0.13.12,<0.14",
"google-cloud-parametermanager>=0.4,<1",
Expand Down Expand Up @@ -202,7 +202,7 @@ optional-dependencies.test = [
"google-cloud-aiplatform[agent-engines,evaluation]>=1.148.1,<2",
"google-cloud-bigquery>=2.2",
"google-cloud-bigquery-storage>=2",
"google-cloud-bigtable>=2.32",
"google-cloud-bigtable>=2.38.0",
"google-cloud-dataplex>=1.7,<3",
"google-cloud-discoveryengine>=0.13.12,<0.14",
"google-cloud-firestore>=2.11,<3",
Expand Down
98 changes: 98 additions & 0 deletions src/google/adk/tools/bigtable/bigtable_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@

from __future__ import annotations

import inspect
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Union

from google.adk.agents.readonly_context import ReadonlyContext
from google.auth.credentials import Credentials
from pydantic import BaseModel
from typing_extensions import override

from . import metadata_tool
Expand All @@ -29,9 +34,93 @@
from ...tools.base_toolset import BaseToolset
from ...tools.base_toolset import ToolPredicate
from ...tools.google_tool import GoogleTool
from ..tool_context import ToolContext
from .bigtable_credentials import BigtableCredentialsConfig
from .settings import BigtableToolSettings


class BigtableParameterizedViewTool(GoogleTool):
"""Wrapper FunctionTool for Bigtable execute_sql query tool that passes view parameters.

This tool wraps the Bigtable query tool to automatically resolve and inject a
parameter from the ToolContext (e.g. user_id) into the query's
view_parameters. The parameter name to resolve is configured via
view_parameter_name.

Example:
If a parameterized view `purchase_history_pv` was created with the query:
`SELECT * FROM purchases WHERE user_id = VIEW_PARAMETERS('user_id')`

By configuring `view_parameter_name="user_id"`, the wrapper will resolve the
`user_id` value from the `tool_context.user_id` at runtime and pass it as
`view_parameters={"user_id": user_id}`. This securely restricts query execution to the
logged-in user's data without exposing the `user_id` parameter to the LLM.
"""

def __init__(
self,
func: Callable[..., Any],
*,
credentials_config: Optional[BigtableCredentialsConfig] = None,
tool_settings: Optional[BigtableToolSettings] = None,
view_parameter_name: Optional[str] = None,
):
"""Initializes the BigtableParameterizedViewTool.

Args:
func: The Bigtable query function to wrap.
credentials_config: The credentials configuration.
tool_settings: The tool settings.
view_parameter_name: The name of the parameter to resolve from
tool_context and pass into view_parameters. This is typically
configured on the toolset (BigtableToolset) and forwarded here.
"""
super().__init__(
func=func,
credentials_config=credentials_config,
tool_settings=tool_settings,
)
self.name = "execute_sql_parameterized"
self.description = (
"Execute a GoogleSQL query from a Bigtable table using parameterized views "
"to securely check permissions."
)
self._view_parameter_name = view_parameter_name
# Exclude from being parsed and exposed to the LLM when generating tool schemas
self._ignore_params.append("view_parameters")

@override
async def _run_async_with_credential(
self,
credentials: Credentials,
tool_settings: BaseModel,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:
args_to_call = args.copy()
signature = inspect.signature(self.func)
if "view_parameters" in signature.parameters and self._view_parameter_name:
view_params = {}
# 1. Check if it's a strongly-typed top-level property (like 'user_id')
if hasattr(tool_context, self._view_parameter_name):
view_params[self._view_parameter_name] = getattr(
tool_context, self._view_parameter_name
)
# 2. Fallback to checking application-level session state
elif (
tool_context.state
and self._view_parameter_name in tool_context.state
):
view_params[self._view_parameter_name] = tool_context.state[
self._view_parameter_name
]

args_to_call["view_parameters"] = view_params
return await super()._run_async_with_credential(
credentials, tool_settings, args_to_call, tool_context
)


DEFAULT_BIGTABLE_TOOL_NAME_PREFIX = "bigtable"


Expand All @@ -55,6 +144,7 @@ def __init__(
tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
credentials_config: Optional[BigtableCredentialsConfig] = None,
bigtable_tool_settings: Optional[BigtableToolSettings] = None,
view_parameter_name: Optional[str] = None,
):
super().__init__(
tool_filter=tool_filter,
Expand All @@ -66,6 +156,7 @@ def __init__(
if bigtable_tool_settings
else BigtableToolSettings()
)
self._view_parameter_name = view_parameter_name

def _is_tool_selected(
self, tool: BaseTool, readonly_context: ReadonlyContext
Expand Down Expand Up @@ -101,6 +192,13 @@ async def get_tools(
metadata_tool.get_cluster_info,
query_tool.execute_sql,
]
] + [
BigtableParameterizedViewTool(
func=query_tool.execute_sql,
credentials_config=self._credentials_config,
tool_settings=self._tool_settings,
view_parameter_name=self._view_parameter_name,
)
]
return [
tool
Expand Down
3 changes: 3 additions & 0 deletions src/google/adk/tools/bigtable/query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ async def execute_sql(
tool_context: ToolContext,
parameters: Dict[str, Any] | None = None,
parameter_types: Dict[str, Any] | None = None,
view_parameters: Dict[str, Any] | None = None,
) -> dict:
"""Execute a GoogleSQL query from a Bigtable table.

Expand All @@ -56,6 +57,7 @@ async def execute_sql(
parameters (dict): properties for parameter replacement. Keys must match
the names used in ``query``.
parameter_types (dict): maps explicit types for one or more param values.
view_parameters (dict): maps properties for parameterized authorized views.

Returns:
dict: Dictionary containing the status and the rows read.
Expand Down Expand Up @@ -91,6 +93,7 @@ def _execute_sql():
instance_id=instance_id,
parameters=parameters,
parameter_types=parameter_types,
view_parameters=view_parameters,
)

rows: List[Dict[str, Any]] = []
Expand Down
38 changes: 38 additions & 0 deletions tests/unittests/tools/bigtable/test_bigtable_query_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def raise_error():
instance_id=instance_id,
parameters=parameters,
parameter_types=parameter_types,
view_parameters=None,
)
mock_iterator.close.assert_called_once()

Expand Down Expand Up @@ -228,3 +229,40 @@ async def test_execute_sql_row_value_circular_reference_fallback():

assert result["status"] == "SUCCESS"
assert result["rows"][0]["col1"] == str(circular_value)


@pytest.mark.asyncio
async def test_execute_sql_with_view_parameters():
"""Test execute_sql with view_parameters passed."""
project = "my_project"
instance_id = "my_instance"
query = "SELECT * FROM my_table"
credentials = mock.create_autospec(Credentials, instance=True)
tool_context = mock.create_autospec(ToolContext, instance=True)
view_parameters = {"user_id": "test-user-123"}

with mock.patch.object(client, "get_bigtable_data_client") as mock_get_client:
mock_client = mock.MagicMock()
mock_get_client.return_value = mock_client
mock_iterator = mock.create_autospec(ExecuteQueryIterator, instance=True)
mock_client.execute_query.return_value = mock_iterator
mock_iterator.__iter__.return_value = []

result = await execute_sql(
project_id=project,
instance_id=instance_id,
credentials=credentials,
query=query,
settings=BigtableToolSettings(),
tool_context=tool_context,
view_parameters=view_parameters,
)

assert result["status"] == "SUCCESS"
mock_client.execute_query.assert_called_once_with(
query=query,
instance_id=instance_id,
parameters=None,
parameter_types=None,
view_parameters=view_parameters,
)
Loading
Loading