@model_serializer(mode="wrap")
def _filter_column_fields_by_context(
    self, serializer: Any, info: Any
) -> Dict[str, Any]:
    """Serialize the column, optionally trimmed to caller-selected fields.

    The default serializer output is returned untouched unless the
    serialization context is a dict carrying a non-empty
    ``column_fields`` entry. In that case only the named fields are
    kept, plus ``column_name``, which is always preserved. This keeps
    wide datasets small: a 50-column dataset does not ship 50 long
    descriptions when the caller only needs column_name + type.

    Args:
        serializer: The wrapped default serializer; called with ``self``.
        info: Serialization info object; only ``info.context`` is read.

    Returns:
        The serialized field mapping, filtered when ``column_fields``
        is present in the context.
    """
    data = serializer(self)

    context = info.context
    if not isinstance(context, dict):
        return data

    column_fields = context.get("column_fields")
    if not column_fields:
        return data

    # A bare string is a single field name. Passing it straight to set()
    # would split it into characters and silently drop every field except
    # column_name.
    if isinstance(column_fields, str):
        column_fields = [column_fields]

    # Always preserve column_name as the only required field
    requested = {"column_name", *column_fields}
    return {key: value for key, value in data.items() if key in requested}
@field_validator("select_columns", mode="before")
@classmethod
def _parse_select_columns(cls, value: Any) -> Any:
    """Normalize ``select_columns`` before validation.

    ``None`` and any input that parses to an empty result both resolve to
    a fresh copy of ``DEFAULT_GET_DATASET_INFO_COLUMNS``; anything else is
    returned as parsed by ``parse_json_or_list``.
    """
    from superset.mcp_service.utils.schema_utils import parse_json_or_list

    # An empty list must not act as an opt-out of size reduction: an empty
    # selection disables filtering downstream and reintroduces oversized
    # responses, so it is treated the same as "not provided".
    selected = None if value is None else parse_json_or_list(value, "select_columns")
    return selected or list(DEFAULT_GET_DATASET_INFO_COLUMNS)


@field_validator("column_fields", mode="before")
@classmethod
def _parse_column_fields(cls, value: Any) -> Any:
    """Normalize ``column_fields`` before validation.

    ``None`` and any input that parses to an empty result both resolve to
    a fresh copy of ``DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS``; anything
    else is returned as parsed by ``parse_json_or_list``.
    """
    from superset.mcp_service.utils.schema_utils import parse_json_or_list

    selected = None if value is None else parse_json_or_list(value, "column_fields")
    return selected or list(DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS)
When building chart configs, use saved_metric=true @@ -144,12 +151,24 @@ async def get_dataset_info( len(result.metrics) if result.metrics else 0, ) ) - else: - await ctx.warning( - "Dataset retrieval failed: error_type=%s, error=%s" - % (result.error_type, result.error) + await ctx.debug( + "Filtering response: select_columns=%s, column_fields=%s" + % (request.select_columns, request.column_fields) ) + with event_logger.log_context(action="mcp.get_dataset_info.serialization"): + return result.model_dump( + mode="json", + by_alias=True, + context={ + "select_columns": request.select_columns, + "column_fields": request.column_fields, + }, + ) + await ctx.warning( + "Dataset retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) return result except Exception as e: diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py index bfe8410dde9d..20bfa3f3ac5d 100644 --- a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py +++ b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py @@ -1200,6 +1200,193 @@ async def test_get_dataset_info_includes_columns_and_metrics(mock_info, mcp_serv assert data["metrics"][1]["metric_name"] == "count_orders" +def _build_full_dataset_mock(dataset_id: int = 1) -> MagicMock: + """Build a richly-populated dataset mock for response-size tests.""" + dataset = MagicMock() + dataset.id = dataset_id + dataset.table_name = "wide_table" + dataset.schema = "main" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.description = "long description " * 20 + dataset.certified_by = "team-data" + dataset.certification_details = "certified via review " * 10 + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + 
dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.uuid = "00000000-0000-0000-0000-000000000001" + dataset.schema_perm = "[examples].[main]" + dataset.url = "/explore/?datasource_type=table&datasource_id=1" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {"key_" + str(i): "value_" * 20 for i in range(20)} + dataset.template_params = {"tparam": "x" * 200} + dataset.extra = {"certification": {"details": "y" * 500}} + dataset.columns = [ + MagicMock( + column_name=f"col_{i}", + verbose_name=f"Verbose Column {i}", + type="VARCHAR", + is_dttm=False, + groupby=True, + filterable=True, + description="long column description " * 30, + ) + for i in range(50) + ] + dataset.metrics = [ + MagicMock( + metric_name="count", + verbose_name="Count", + expression="COUNT(*)", + description="row count", + d3format=None, + ), + ] + return dataset + + +@patch("superset.daos.dataset.DatasetDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_dataset_info_default_omits_verbose_fields( + mock_info, mcp_server +) -> None: + """Default response excludes verbose top-level fields and per-column + fields to keep payload size small for wide datasets. + + Regression test for SC-105681: 80KB+ responses caused eval timeouts. 
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dataset_info_respects_select_columns(mock_info, mcp_server) -> None:
    """An explicit select_columns list limits the response to those keys."""
    mock_info.return_value = _build_full_dataset_mock(dataset_id=2)
    tool_args = {
        "request": {
            "identifier": 2,
            "select_columns": ["id", "table_name", "columns"],
        }
    }

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_dataset_info", tool_args)
        data = json.loads(result.content[0].text)

    # Exactly the requested top-level keys come back
    assert set(data.keys()) == {"id", "table_name", "columns"}
    assert data["id"] == 2
    # Fields that are in the lean default set but were not requested are gone
    assert "metrics" not in data
    assert "description" not in data
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dataset_info_empty_lists_fall_back_to_defaults(
    mock_info, mcp_server
) -> None:
    """Empty select_columns / column_fields fall back to lean defaults.

    Without this, a caller passing column_fields=[] would silently
    re-enable verbose per-column fields and reintroduce oversized
    responses. Regression test for the review on apache/superset#39898.

    Both directions are checked: verbose fields must be absent AND the
    lean default fields must be present, so a response that was trimmed
    too far cannot pass.
    """
    mock_info.return_value = _build_full_dataset_mock(dataset_id=4)
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dataset_info",
            {
                "request": {
                    "identifier": 4,
                    "select_columns": [],
                    "column_fields": [],
                }
            },
        )
        data = json.loads(result.content[0].text)

    # Defaults applied for select_columns: lean top-level fields present
    assert data["id"] == 4
    assert data["table_name"] == "wide_table"
    assert "columns" in data
    assert "metrics" in data
    # Defaults applied for select_columns: verbose top-level fields excluded
    assert "params" not in data
    assert "template_params" not in data
    assert "extra" not in data
    assert "tags" not in data

    # Defaults applied for column_fields: lean per-column fields present
    first = data["columns"][0]
    assert first["column_name"] == "col_0"
    assert first["type"] == "VARCHAR"
    assert "is_dttm" in first
    # Defaults applied for column_fields: verbose per-column fields excluded
    assert "description" not in first
    assert "verbose_name" not in first
    assert "groupby" not in first
    assert "filterable" not in first
def test_get_dataset_info_request_default_select_columns() -> None:
    """A request built from only an identifier carries both default lists."""
    from superset.mcp_service.dataset.schemas import (
        DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS,
        DEFAULT_GET_DATASET_INFO_COLUMNS,
        GetDatasetInfoRequest,
    )

    request = GetDatasetInfoRequest(identifier=1)

    assert request.select_columns == DEFAULT_GET_DATASET_INFO_COLUMNS
    assert request.column_fields == DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS


def test_get_dataset_info_request_explicit_columns_override_defaults() -> None:
    """Lists supplied by the caller are kept as-is instead of the defaults."""
    from superset.mcp_service.dataset.schemas import GetDatasetInfoRequest

    wanted_top_level = ["id", "table_name"]
    wanted_per_column = ["column_name", "type"]

    request = GetDatasetInfoRequest(
        identifier="abc-uuid",
        select_columns=wanted_top_level,
        column_fields=wanted_per_column,
    )

    assert request.select_columns == ["id", "table_name"]
    assert request.column_fields == ["column_name", "type"]


def test_get_dataset_info_request_empty_lists_use_defaults() -> None:
    """Empty lists are coerced to the lean defaults, so a caller cannot
    accidentally disable size reduction."""
    from superset.mcp_service.dataset.schemas import (
        DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS,
        DEFAULT_GET_DATASET_INFO_COLUMNS,
        GetDatasetInfoRequest,
    )

    request = GetDatasetInfoRequest(identifier=1, select_columns=[], column_fields=[])

    assert request.select_columns == DEFAULT_GET_DATASET_INFO_COLUMNS
    assert request.column_fields == DEFAULT_GET_DATASET_INFO_COLUMN_FIELDS
test_get_dataset_info_request_accepts_json_string_lists() -> None: + """Field validators parse JSON-encoded lists for both fields.""" + from superset.mcp_service.dataset.schemas import GetDatasetInfoRequest + + req = GetDatasetInfoRequest( + identifier=1, + select_columns='["id", "columns"]', + column_fields='["column_name", "is_dttm"]', + ) + assert req.select_columns == ["id", "columns"] + assert req.column_fields == ["column_name", "is_dttm"] + + def test_create_virtual_dataset_request_valid() -> None: req = CreateVirtualDatasetRequest( database_id=1,