diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index 2d0766a..92a9b2c 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -2,6 +2,7 @@ import base64 import io import math +import re from typing import Any, Dict, Optional, Tuple # Set matplotlib to non-interactive backend BEFORE importing pyplot @@ -218,11 +219,14 @@ async def collect_logs(): "generation_workflow_id": generation_workflow_id, "status": status, } - generated_dataset = await generation_workflow.datasets().last() + # TODO + generated_datasets = await generation_workflow.datasets().collect() return { "success": True, "generation_workflow_id": generation_workflow_id, - "generated_dataset_id": generated_dataset.id, + "generated_dataset_id(s)": [ + generated_dataset.id for generated_dataset in generated_datasets + ], } elif tool_name == "plot_distribution": dataset_ids = arguments["dataset_ids"] @@ -388,6 +392,107 @@ async def collect_logs(): "changes_applied": changes_applied, "train_config": config_dict, } + elif tool_name == "validate_data_schema_config": + data_schema_config = arguments["data_schema_config"] + entity_labels = arguments.get("entity_labels") + + # Validate schema (all validation handled by _validate_data_schema) + validation_result = _validate_data_schema(data_schema_config) + + if not validation_result["valid"]: + # Return validation errors with enhanced details + error_response = { + "success": False, + "message": validation_result["summary"], + "error": validation_result["error_message"], + } + + return error_response + + # Validation passed - cache the config + config_id = f"schema_config_{uuid.uuid4()}" + cache_entry = { + "data_schema_config": data_schema_config, + } + + if entity_labels: + cache_entry["entity_labels"] = entity_labels + + self._cache[config_id] = cache_entry + + # Generate summary + entities = data_schema_config.get("entities", []) + entity_names = [e.get("name") for e in entities] + total_columns = sum(len(e.get("columns", [])) for e in entities) + relationships_count = len( + data_schema_config.get("entity_relationships", []) + ) + + response = { + "success": True, + "schema_config_id": config_id, + "summary": { + "entities_count": len(entities), + "entities": entity_names, + "total_columns": total_columns, + "relationships_count": relationships_count, + }, + "message": validation_result["summary"], + } + + return response + elif tool_name == "start_data_schema_generation_workflow": + schema_config_id = arguments["schema_config_id"] + + # Check cache + if schema_config_id not in self._cache: + return { + "success": False, + "message": f"Config ID '{schema_config_id}' not found in cache. It may have expired or already been used. Please call validate_data_schema_config again.", + "schema_config_id": schema_config_id, + } + + # Retrieve and remove from cache + cache_entry = self._cache.pop(schema_config_id) + data_schema_dict = cache_entry["data_schema_config"] + entity_labels = cache_entry.get("entity_labels") + + # Convert dict to DataSchema (should not fail since we already validated) + try: + data_schema = rf.converter.structure( + data_schema_dict, ra.ent.DataSchema + ) + except Exception as e: + return { + "success": False, + "message": f"Failed to convert schema config to DataSchema: {str(e)}", + "schema_config_id": schema_config_id, + } + + # Create GenerateFromDataSchema action + try: + if entity_labels: + generate_action = ra.ent.GenerateFromDataSchema( + schema=data_schema, entity_labels=entity_labels + ) + else: + generate_action = ra.ent.GenerateFromDataSchema(schema=data_schema) + + # Build and start workflow + builder = create_workflow([generate_action]) + workflow = await builder.start(self._conn) + + return { + "success": True, + "workflow_id": workflow.id(), + "message": f"Started data schema generation workflow: {workflow.id()}. Use get_workflow_logs to monitor progress.", + } + except Exception as e: + return { + "success": False, + "message": f"Failed to start workflow: {str(e)}", + "schema_config_id": schema_config_id, + } else: return { "success": False, @@ -427,11 +532,11 @@ def _fig_to_base64(fig): plt.close(fig.fig) # Close the underlying matplotlib figure to free memory return img_str + if len(dataset_ids) != 2: + raise ValueError("current only support 2 datasets for comparison plotting") # Load dataset and convert to LocalDataset - dataset = await conn.get_dataset(dataset_ids[0]) - dataset = await dataset.to_local(conn) - synthetic = await conn.get_dataset(dataset_ids[1]) - synthetic = await synthetic.to_local(conn) + dataset = await get_local_dataset(conn, dataset_ids[0]) + synthetic = await get_local_dataset(conn, dataset_ids[1]) table = dataset.table field_type = table[column_name].type @@ -497,3 +602,157 @@ def guess_tab_gan_train_config(dataset) -> Tuple[ra.TrainTabGAN.Config, dict]: "high_cardinality_columns": high_cardinality_columns, } return train_config, column_metadata + + +# Entity Data Generator helpers +def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate DataSchema with enhanced error messages. + + Validation is performed in two layers: + 1. SDK validation via rf.converter.structure() - validates types, required fields, enums + 2. validators.py validation - validates business logic (PARAM_TS_07, TS_08, R11, R12) + + All errors are always shown (no limiting). + + Returns: + Dict with validation results: + - If valid: + { + "valid": True, + "data_schema": DataSchema, + "summary": str + } + - If invalid: + { + "valid": False, + "error_count": int, + "error_message": str (numbered list of errors), + "summary": str + } + """ + try: + # SDK validates all levels automatically via __attrs_post_init__ methods + data_schema = rf.converter.structure(schema_dict, ra.ent.DataSchema) + + return { + "valid": True, + "data_schema": data_schema, + "summary": f"Validation successful: {len(schema_dict.get('entities', []))} entities", + } + except Exception as e: + error_details = extract_structure_error_details(e) + + # Format errors as numbered list + error_list = [] + for idx, err in enumerate(error_details["errors"], 1): + error_list.append(f"{idx}. {err['location']}: {err['error_message']}") + + return { + "valid": False, + "error_count": error_details["error_count"], + "error_message": "\n\n".join(error_list), + "summary": f"Validation failed: {error_details['summary']}", + } + + +def extract_structure_error_details(exc: Exception) -> Dict[str, Any]: + """ + Extract error message and location from StructureError exception chain. + + Simplified version that extracts only essential information: + - error_message: The actual error text from the exception + - location: The JSON path where the error occurred + + Args: + exc: The exception to parse (typically StructureError) + + Returns: + { + "error_count": N, + "summary": "Found N validation error(s)", + "errors": [ + { + "error_message": "spike_magnitude (5.0) must be in [0, 1]", + "location": "$.entities[0].columns[1]" + }, + ... + ] + } + + Example: + >>> try: + ... rf.converter.structure(ts_dict, TimeseriesParams) + ... except Exception as e: + ... details = extract_structure_error_details(e) + ... for err in details['errors']: + ... print(f"{err['location']}: {err['error_message']}") + """ + + errors = [] + + def collect_errors(current_exc, location="$"): + """Recursively collect error messages and locations.""" + # Check for sub-exceptions (IterableValidationError) + if hasattr(current_exc, "exceptions") and current_exc.exceptions: + # Extract location info from wrapper message + msg = str(current_exc) + index_match = re.search(r"@ index (\d+)", msg) + if index_match: + index = index_match.group(1) + if "list[Column]" in msg: + new_location = f"{location}.columns[{index}]" + elif "list[Entity]" in msg: + new_location = f"{location}.entities[{index}]" + else: + new_location = f"{location}[{index}]" + else: + new_location = location + + for sub_exc in current_exc.exceptions: + collect_errors(sub_exc, new_location) + + # Check for __cause__ chain + elif hasattr(current_exc, "__cause__") and current_exc.__cause__: + # Extract location from StructureError if present + if type(current_exc).__name__ == "StructureError": + msg = str(current_exc) + loc_match = re.search(r"@ (\$\.[^\s]+)", msg) + if loc_match: + location = loc_match.group(1) + + collect_errors(current_exc.__cause__, location) + + # Leaf error - add to list + else: + # Skip wrapper exceptions + if type(current_exc).__name__ not in ( + "StructureError", + "ClassValidationError", + "IterableValidationError", + "ExceptionGroup", + ): + errors.append({"error_message": str(current_exc), "location": location}) + + # Extract top-level location from StructureError + top_location = "$" + if type(exc).__name__ == "StructureError": + top_msg = str(exc) + loc_match = re.search(r"@ (\$\.[^\s]+)", top_msg) + if loc_match: + top_location = loc_match.group(1) + + # Collect all errors + collect_errors(exc, top_location) + + # If no errors found, use original exception + if not errors: + errors = [{"error_message": str(exc), "location": top_location}] + + # Return consistent format + error_word = "error" if len(errors) == 1 else "errors" + return { + "error_count": len(errors), + "summary": f"Found {len(errors)} validation {error_word}", + "errors": errors, + } diff --git a/src/rockfish_mcp/server.py b/src/rockfish_mcp/server.py index 34d0840..775e306 100644 --- a/src/rockfish_mcp/server.py +++ b/src/rockfish_mcp/server.py @@ -988,6 +988,78 @@ async def handle_list_tools() -> List[types.Tool]: "required": ["dataset_ids"], }, ), + types.Tool( + name="validate_data_schema_config", + description="Validate DataSchema configuration for entity data generation and cache it. Validates column, entity, and schema levels. Returns validation errors or success with cached config_id.", + inputSchema={ + "type": "object", + "properties": { + "data_schema_config": { + "type": "object", + "description": "DataSchema configuration with entities, columns, and relationships", + "properties": { + "entities": { + "type": "array", + "description": "List of entity specifications", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Entity name", + }, + "cardinality": { + "type": "integer", + "minimum": 1, + "description": "Number of rows to generate for this entity", + }, + "columns": { + "type": "array", + "description": "List of column specifications", + "items": {"type": "object"}, + "minItems": 1, + }, + }, + "required": ["name", "cardinality", "columns"], + }, + "minItems": 1, + }, + "entity_relationships": { + "type": "array", + "description": "List of relationships between entities", + "items": {"type": "object"}, + }, + "global_timestamp": { + "type": "object", + "description": "Optional global timestamp configuration", + }, + }, + "required": ["entities"], + "additionalProperties": True, + }, + "entity_labels": { + "type": "object", + "description": "Optional entity label mappings for generated datasets", + "additionalProperties": True, + }, + }, + "required": ["data_schema_config"], + }, + ), + types.Tool( + name="start_data_schema_generation_workflow", + description="Start entity data generation workflow using cached schema config. Converts config to rockfish.actions.ent dataclasses and starts workflow. Use after validate_data_schema_config.", + inputSchema={ + "type": "object", + "properties": { + "schema_config_id": { + "type": "string", + "description": "Config ID from validate_data_schema_config (retrieves cached config)", + } + }, + "required": ["schema_config_id"], + }, + ), ] ) @@ -1011,6 +1083,8 @@ async def handle_call_tool( "obtain_synthetic_dataset_id", "plot_distribution", "get_marginal_distribution_score", + "validate_data_schema_config", + "start_data_schema_generation_workflow", ] if name in sdk_tools: