Skip to content
271 changes: 265 additions & 6 deletions src/rockfish_mcp/sdk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import base64
import io
import math
import re
from typing import Any, Dict, Optional, Tuple

# Set matplotlib to non-interactive backend BEFORE importing pyplot
Expand Down Expand Up @@ -218,11 +219,14 @@ async def collect_logs():
"generation_workflow_id": generation_workflow_id,
"status": status,
}
generated_dataset = await generation_workflow.datasets().last()
# TODO
generated_datasets = await generation_workflow.datasets().collect()
return {
"success": True,
"generation_workflow_id": generation_workflow_id,
"generated_dataset_id": generated_dataset.id,
"generated_dataset_id(s)": [
generated_dataset.id for generated_dataset in generated_datasets
],
}
elif tool_name == "plot_distribution":
dataset_ids = arguments["dataset_ids"]
Expand Down Expand Up @@ -388,6 +392,107 @@ async def collect_logs():
"changes_applied": changes_applied,
"train_config": config_dict,
}
elif tool_name == "validate_data_schema_config":
data_schema_config = arguments["data_schema_config"]
entity_labels = arguments.get("entity_labels")

# Validate schema (all validation handled by _validate_data_schema)
validation_result = _validate_data_schema(data_schema_config)

if not validation_result["valid"]:
# Return validation errors with enhanced details
error_response = {
"success": False,
"message": validation_result["summary"],
"error": validation_result["error_message"],
}

return error_response

# Validation passed - cache the config
config_id = f"schema_config_{uuid.uuid4()}"
cache_entry = {
"data_schema_config": data_schema_config,
}

if entity_labels:
cache_entry["entity_labels"] = entity_labels

self._cache[config_id] = cache_entry

# Generate summary
entities = data_schema_config.get("entities", [])
entity_names = [e.get("name") for e in entities]
total_columns = sum(len(e.get("columns", [])) for e in entities)
relationships_count = len(
data_schema_config.get("entity_relationships", [])
)

response = {
"success": True,
"schema_config_id": config_id,
"summary": {
"entities_count": len(entities),
"entities": entity_names,
"total_columns": total_columns,
"relationships_count": relationships_count,
},
"message": validation_result["summary"],
}

return response
elif tool_name == "start_data_schema_generation_workflow":
schema_config_id = arguments["schema_config_id"]

# Check cache
if schema_config_id not in self._cache:
return {
"success": False,
"message": f"Config ID '{schema_config_id}' not found in cache. It may have expired or already been used. Please call validate_data_schema_config again.",
"schema_config_id": schema_config_id,
}

# Retrieve and remove from cache
cache_entry = self._cache.pop(schema_config_id)
data_schema_dict = cache_entry["data_schema_config"]
entity_labels = cache_entry.get("entity_labels")

# Convert dict to DataSchema (should not fail since we already validated)
try:
data_schema = rf.converter.structure(
data_schema_dict, ra.ent.DataSchema
)
except Exception as e:
return {
"success": False,
"message": f"Failed to convert schema config to DataSchema: {str(e)}",
"schema_config_id": schema_config_id,
}

# Create GenerateFromDataSchema action
try:
if entity_labels:
generate_action = ra.ent.GenerateFromDataSchema(
schema=data_schema, entity_labels=entity_labels
)
else:
generate_action = ra.ent.GenerateFromDataSchema(schema=data_schema)

# Build and start workflow
builder = create_workflow([generate_action])
workflow = await builder.start(self._conn)

return {
"success": True,
"workflow_id": workflow.id(),
"message": f"Started data schema generation workflow: {workflow.id()}. Use get_workflow_logs to monitor progress.",
}
except Exception as e:
return {
"success": False,
"message": f"Failed to start workflow: {str(e)}",
"schema_config_id": schema_config_id,
}
else:
return {
"success": False,
Expand Down Expand Up @@ -427,11 +532,11 @@ def _fig_to_base64(fig):
plt.close(fig.fig) # Close the underlying matplotlib figure to free memory
return img_str

if len(dataset_ids) != 2:
raise ValueError("current only support 2 datasets for comparison plotting")
# Load dataset and convert to LocalDataset
dataset = await conn.get_dataset(dataset_ids[0])
dataset = await dataset.to_local(conn)
synthetic = await conn.get_dataset(dataset_ids[1])
synthetic = await synthetic.to_local(conn)
dataset = await get_local_dataset(conn, dataset_ids[0])
synthetic = await get_local_dataset(conn, dataset_ids[1])

table = dataset.table
field_type = table[column_name].type
Expand Down Expand Up @@ -497,3 +602,157 @@ def guess_tab_gan_train_config(dataset) -> Tuple[ra.TrainTabGAN.Config, dict]:
"high_cardinality_columns": high_cardinality_columns,
}
return train_config, column_metadata


# Entity Data Generator helpers
def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate DataSchema with enhanced error messages.

Validation is performed in two layers:
1. SDK validation via rf.converter.structure() - validates types, required fields, enums
2. validators.py validation - validates business logic (PARAM_TS_07, TS_08, R11, R12)

All errors are always shown (no limiting).

Returns:
Dict with validation results:
- If valid:
{
"valid": True,
"data_schema": DataSchema,
"summary": str
}
- If invalid:
{
"valid": False,
"error_count": int,
"error_message": str (numbered list of errors),
"summary": str
}
"""
try:
# SDK validates all levels automatically via __attrs_post_init__ methods
data_schema = rf.converter.structure(schema_dict, ra.ent.DataSchema)

return {
"valid": True,
"data_schema": data_schema,
"summary": f"Validation successful: {len(schema_dict.get('entities', []))} entities",
}
except Exception as e:
error_details = extract_structure_error_details(e)

# Format errors as numbered list
error_list = []
for idx, err in enumerate(error_details["errors"], 1):
error_list.append(f"{idx}. {err['location']}: {err['error_message']}")

return {
"valid": False,
"error_count": error_details["error_count"],
"error_message": "\n\n".join(error_list),
"summary": f"Validation failed: {error_details['summary']}",
}


def extract_structure_error_details(exc: Exception) -> Dict[str, Any]:
"""
Extract error message and location from StructureError exception chain.

Simplified version that extracts only essential information:
- error_message: The actual error text from the exception
- location: The JSON path where the error occurred

Args:
exc: The exception to parse (typically StructureError)

Returns:
{
"error_count": N,
"summary": "Found N validation error(s)",
"errors": [
{
"error_message": "spike_magnitude (5.0) must be in [0, 1]",
"location": "$.entities[0].columns[1]"
},
...
]
}

Example:
>>> try:
... rf.converter.structure(ts_dict, TimeseriesParams)
... except Exception as e:
... details = extract_structure_error_details(e)
... for err in details['errors']:
... print(f"{err['location']}: {err['error_message']}")
"""

errors = []

def collect_errors(current_exc, location="$"):
"""Recursively collect error messages and locations."""
# Check for sub-exceptions (IterableValidationError)
if hasattr(current_exc, "exceptions") and current_exc.exceptions:
# Extract location info from wrapper message
msg = str(current_exc)
index_match = re.search(r"@ index (\d+)", msg)
if index_match:
index = index_match.group(1)
if "list[Column]" in msg:
new_location = f"{location}.columns[{index}]"
elif "list[Entity]" in msg:
new_location = f"{location}.entities[{index}]"
else:
new_location = f"{location}[{index}]"
else:
new_location = location

for sub_exc in current_exc.exceptions:
collect_errors(sub_exc, new_location)

# Check for __cause__ chain
elif hasattr(current_exc, "__cause__") and current_exc.__cause__:
# Extract location from StructureError if present
if type(current_exc).__name__ == "StructureError":
msg = str(current_exc)
loc_match = re.search(r"@ (\$\.[^\s]+)", msg)
if loc_match:
location = loc_match.group(1)

collect_errors(current_exc.__cause__, location)

# Leaf error - add to list
else:
# Skip wrapper exceptions
if type(current_exc).__name__ not in (
"StructureError",
"ClassValidationError",
"IterableValidationError",
"ExceptionGroup",
):
errors.append({"error_message": str(current_exc), "location": location})

# Extract top-level location from StructureError
top_location = "$"
if type(exc).__name__ == "StructureError":
top_msg = str(exc)
loc_match = re.search(r"@ (\$\.[^\s]+)", top_msg)
if loc_match:
top_location = loc_match.group(1)

# Collect all errors
collect_errors(exc, top_location)

# If no errors found, use original exception
if not errors:
errors = [{"error_message": str(exc), "location": top_location}]

# Return consistent format
error_word = "error" if len(errors) == 1 else "errors"
return {
"error_count": len(errors),
"summary": f"Found {len(errors)} validation {error_word}",
"errors": errors,
}
74 changes: 74 additions & 0 deletions src/rockfish_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,78 @@ async def handle_list_tools() -> List[types.Tool]:
"required": ["dataset_ids"],
},
),
types.Tool(
name="validate_data_schema_config",
description="Validate DataSchema configuration for entity data generation and cache it. Validates column, entity, and schema levels. Returns validation errors or success with cached config_id.",
inputSchema={
"type": "object",
"properties": {
"data_schema_config": {
"type": "object",
"description": "DataSchema configuration with entities, columns, and relationships",
"properties": {
"entities": {
"type": "array",
"description": "List of entity specifications",
"items": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Entity name",
},
"cardinality": {
"type": "integer",
"minimum": 1,
"description": "Number of rows to generate for this entity",
},
"columns": {
"type": "array",
"description": "List of column specifications",
"items": {"type": "object"},
"minItems": 1,
},
},
"required": ["name", "cardinality", "columns"],
},
"minItems": 1,
},
"entity_relationships": {
"type": "array",
"description": "List of relationships between entities",
"items": {"type": "object"},
},
"global_timestamp": {
"type": "object",
"description": "Optional global timestamp configuration",
},
},
"required": ["entities"],
"additionalProperties": True,
},
"entity_labels": {
"type": "object",
"description": "Optional entity label mappings for generated datasets",
"additionalProperties": True,
},
},
"required": ["data_schema_config"],
},
),
types.Tool(
name="start_data_schema_generation_workflow",
description="Start entity data generation workflow using cached schema config. Converts config to rockfish.actions.ent dataclasses and starts workflow. Use after validate_data_schema_config.",
inputSchema={
"type": "object",
"properties": {
"schema_config_id": {
"type": "string",
"description": "Config ID from validate_data_schema_config (retrieves cached config)",
}
},
"required": ["schema_config_id"],
},
),
]
)

Expand All @@ -1011,6 +1083,8 @@ async def handle_call_tool(
"obtain_synthetic_dataset_id",
"plot_distribution",
"get_marginal_distribution_score",
"validate_data_schema_config",
"start_data_schema_generation_workflow",
]

if name in sdk_tools:
Expand Down