From 7c3665784e681e1161627cace963cd9e20a70269 Mon Sep 17 00:00:00 2001 From: Fan Date: Wed, 26 Nov 2025 17:39:07 -0800 Subject: [PATCH 01/10] add for data schema generation --- src/rockfish_mcp/sdk_client.py | 531 ++++++++++++++++++++ src/rockfish_mcp/server.py | 79 +++ src/rockfish_mcp/validators.py | 878 +++++++++++++++++++++++++++++++++ 3 files changed, 1488 insertions(+) create mode 100644 src/rockfish_mcp/validators.py diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index 2d0766a..b9e5466 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -12,6 +12,7 @@ import rockfish as rf import rockfish.actions as ra import rockfish.labs as rl +from rockfish.converter import StructureError from rockfish.remote import glue matplotlib.use("Agg") @@ -22,6 +23,13 @@ logger = logging.getLogger(__name__) +# Import validators for comprehensive DataSchema validation +from rockfish_mcp.validators import ( + validate_dataschema_comprehensive, + ValidationError, + ValidationLevel, +) + class RockfishSDKClient: def __init__( @@ -388,6 +396,155 @@ async def collect_logs(): "changes_applied": changes_applied, "train_config": config_dict, } + elif tool_name == "validate_data_schema_config": + data_schema_config = arguments["data_schema_config"] + entity_labels = arguments.get("entity_labels") + show_all_errors = arguments.get("show_all_errors", False) + + # Python type checking (Layer 2 validation) + if not isinstance(data_schema_config, dict): + return { + "success": False, + "message": "data_schema_config must be a dictionary", + "error": f"Expected dict, got {type(data_schema_config).__name__}", + } + + if "entities" not in data_schema_config: + return { + "success": False, + "message": "data_schema_config must contain 'entities' field", + "error": "Missing required field: 'entities'", + } + + if not isinstance(data_schema_config["entities"], list): + return { + "success": False, + "message": "'entities' must be a list", + "error": f"Expected list, got {type(data_schema_config['entities']).__name__}", + } + + if len(data_schema_config["entities"]) == 0: + return { + "success": False, + "message": "'entities' list cannot be empty", + "error": "At least one entity is required", + } + + if entity_labels is not None and not isinstance(entity_labels, dict): + return { + "success": False, + "message": "entity_labels must be a dictionary", + "error": f"Expected dict, got {type(entity_labels).__name__}", + } + + # Validate schema using helper function (Layer 3 validation) + # Pass show_all_errors via schema_dict for _validate_data_schema + schema_dict_with_options = { + **data_schema_config, + "_show_all_errors": show_all_errors + } + validation_result = _validate_data_schema(schema_dict_with_options) + + if not validation_result["valid"]: + # Return validation errors with enhanced details + error_response = { + "success": False, + "message": validation_result["summary"], + "error": validation_result["error_message"], + } + # Add optional fields if present + if "suggestion" in validation_result and validation_result["suggestion"]: + error_response["suggestion"] = validation_result["suggestion"] + if "reference" in validation_result and validation_result["reference"]: + error_response["reference"] = validation_result["reference"] + + return error_response + + # Validation passed - cache the config + config_id = f"schema_config_{uuid.uuid4()}" + cache_entry = { + "data_schema_config": data_schema_config, + } + # TODO: entity_labels? + if entity_labels: + cache_entry["entity_labels"] = entity_labels + + self._cache[config_id] = cache_entry + + # Generate summary + entities = data_schema_config.get("entities", []) + entity_names = [e.get("name") for e in entities] + total_columns = sum(len(e.get("columns", [])) for e in entities) + relationships_count = len(data_schema_config.get("entity_relationships", [])) + + response = { + "success": True, + "schema_config_id": config_id, + "summary": { + "entities_count": len(entities), + "entities": entity_names, + "total_columns": total_columns, + "relationships_count": relationships_count, + }, + "message": validation_result["summary"], + } + + # Add warnings if present (from comprehensive validation) + if "warnings" in validation_result and validation_result["warnings"]: + response["warnings"] = validation_result["warnings"] + + return response + elif tool_name == "start_data_schema_generation_workflow": + schema_config_id = arguments["schema_config_id"] + + # Check cache + if schema_config_id not in self._cache: + return { + "success": False, + "message": f"Config ID '{schema_config_id}' not found in cache. It may have expired or already been used. Please call validate_data_schema_config again.", + "schema_config_id": schema_config_id, + } + + # Retrieve and remove from cache + cache_entry = self._cache.pop(schema_config_id) + data_schema_dict = cache_entry["data_schema_config"] + entity_labels = cache_entry.get("entity_labels") # TODO: make use of entity_labels in GenerateFromDataSchema + + # Convert dict to DataSchema (should not fail since we already validated) + try: + data_schema = rf.converter.structure(data_schema_dict, ra.ent.DataSchema) + except Exception as e: + return { + "success": False, + "message": f"Failed to convert schema config to DataSchema: {str(e)}", + "schema_config_id": schema_config_id, + } + + # Create GenerateFromDataSchema action + try: + if entity_labels: + generate_action = ra.ent.GenerateFromDataSchema( + schema=data_schema, + entity_labels=entity_labels + ) + else: + generate_action = ra.ent.GenerateFromDataSchema(schema=data_schema) + + # Build and start workflow + builder = create_workflow([generate_action]) + workflow = await builder.start(self._conn) + + return { + "success": True, + "workflow_id": workflow.id(), + "message": f"Started data schema generation workflow: {workflow.id()}. Use get_workflow_logs to monitor progress.", + } + except Exception as e: + return { + "success": False, + "message": f"Failed to start workflow: {str(e)}", + "schema_config_id": schema_config_id, + } else: return { "success": False, @@ -497,3 +654,377 @@ def guess_tab_gan_train_config(dataset) -> Tuple[ra.TrainTabGAN.Config, dict]: "high_cardinality_columns": high_cardinality_columns, } return train_config, column_metadata + + +# Entity Data Generator helpers +def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Validate DataSchema with enhanced error messages for common mistakes. + + Returns: + Dict with validation results: + { + "valid": bool, + "data_schema": DataSchema (if valid), + "error_message": str (if invalid), + "suggestion": str (if invalid, optional), + "reference": str (if invalid, optional), + "summary": str + } + """ + try: + # SDK validates all levels automatically via __attrs_post_init__ methods + data_schema = rf.converter.structure(schema_dict, ra.ent.DataSchema) + + # Run comprehensive validation (Layer 3) + comprehensive_errors = validate_dataschema_comprehensive(data_schema) + + if comprehensive_errors: + # Convert ValidationError list to response format + # show_all parameter passed from tool arguments + validation_result = _format_comprehensive_errors( + comprehensive_errors, + schema_dict, + show_all=schema_dict.get('_show_all_errors', False) + ) + + # If valid=True (warnings only), add data_schema for caching + if validation_result["valid"]: + validation_result["data_schema"] = data_schema + + return validation_result + + return { + "valid": True, + "data_schema": data_schema, + "summary": f"Validation successful: {len(schema_dict.get('entities', []))} entities" + } + except StructureError as e: + # Parse StructureError message with intelligent error detection + error_msg = str(e) + helpful_error = _detect_common_errors(error_msg, schema_dict) + + return { + "valid": False, + "error_message": helpful_error["detailed_message"], + "suggestion": helpful_error.get("fix_suggestion"), + "reference": helpful_error.get("docs_link"), + "summary": f"Validation failed: {helpful_error['summary']}" + } + except Exception as e: + # Catch other validation errors (from __attrs_post_init__) + error_msg = str(e) + helpful_error = _detect_common_errors(error_msg, schema_dict) + + return { + "valid": False, + "error_message": helpful_error["detailed_message"], + "suggestion": helpful_error.get("fix_suggestion"), + "reference": helpful_error.get("docs_link"), + "summary": f"Validation failed: {helpful_error['summary']}" + } + + +def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Detect common error patterns and provide actionable fix suggestions. + + Common patterns: + 1. Domain structure errors (missing "params" wrapper) + 2. Derivation structure errors (missing "params" wrapper) + 3. Column type/category mismatches + 4. Domain/Derivation mutual exclusivity violations + 5. Missing required fields + 6. Invalid enum values + + Args: + error_msg: The error message from SDK validation + schema_dict: The schema dictionary being validated + + Returns: + Dict with summary, detailed_message, fix_suggestion, and docs_link + """ + import re + + # Pattern 1: Domain validation error - "expected Optional @ $.entities[X].columns[Y].domain" + domain_error_pattern = r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.domain" + match = re.search(domain_error_pattern, error_msg) + + if match: + entity_idx = int(match.group(1)) + col_idx = int(match.group(2)) + + # Extract the problematic domain + try: + entities = schema_dict.get("entities", []) + if entity_idx < len(entities): + entity = entities[entity_idx] + columns = entity.get("columns", []) + if col_idx < len(columns): + column = columns[col_idx] + domain = column.get("domain", {}) + + # Check if domain has wrong structure (missing "params" wrapper) + if isinstance(domain, dict) and "type" in domain: + domain_type = domain.get("type") + # Check if there are extra keys that should be in "params" + extra_keys = [k for k in domain.keys() if k not in ["type", "params"]] + + if extra_keys and "params" not in domain: + params_str = ', '.join([f'"{k}"' for k in extra_keys]) + return { + "summary": f"Domain structure error in column '{column.get('name', 'unknown')}'", + "detailed_message": ( + f"Domain validation failed for entity '{entity.get('name', 'unknown')}', " + f"column '{column.get('name', 'unknown')}' (entities[{entity_idx}].columns[{col_idx}]).\n\n" + f"ERROR: Domain parameters must be nested under 'params' key.\n\n" + f"The parameters {params_str} should be inside a 'params' dictionary.\n\n" + f"Correct structure:\n" + f' "domain": {{\n' + f' "type": "{domain_type}",\n' + f' "params": {{ {params_str}: ... }}\n' + f' }}' + ), + "fix_suggestion": f"Move {params_str} inside 'params': {{\"type\": \"{domain_type}\", \"params\": {{...}}}}", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Domain" + } + except (IndexError, KeyError, TypeError): + pass # Fall through to generic error + + # Pattern 2: Derivation validation error - similar structure + derivation_error_pattern = r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.derivation" + match = re.search(derivation_error_pattern, error_msg) + + if match: + entity_idx = int(match.group(1)) + col_idx = int(match.group(2)) + + try: + entities = schema_dict.get("entities", []) + if entity_idx < len(entities): + entity = entities[entity_idx] + columns = entity.get("columns", []) + if col_idx < len(columns): + column = columns[col_idx] + derivation = column.get("derivation", {}) + + if isinstance(derivation, dict) and "function_type" in derivation: + func_type = derivation.get("function_type") + extra_keys = [k for k in derivation.keys() if k not in ["function_type", "params", "dependent_columns"]] + + if extra_keys and "params" not in derivation: + params_str = ', '.join([f'"{k}"' for k in extra_keys]) + return { + "summary": f"Derivation structure error in column '{column.get('name', 'unknown')}'", + "detailed_message": ( + f"Derivation validation failed for entity '{entity.get('name', 'unknown')}', " + f"column '{column.get('name', 'unknown')}' (entities[{entity_idx}].columns[{col_idx}]).\n\n" + f"ERROR: Derivation parameters must be nested under 'params' key.\n\n" + f"Correct structure:\n" + f' "derivation": {{\n' + f' "function_type": "{func_type}",\n' + f' "dependent_columns": [...],\n' + f' "params": {{ {params_str}: ... }}\n' + f' }}' + ), + "fix_suggestion": f"Move {params_str} inside 'params'", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Derivation" + } + except (IndexError, KeyError, TypeError): + pass + + # Pattern 3: Enum validation errors + if "expected ColumnType" in error_msg: + return { + "summary": "Invalid column_type value", + "detailed_message": ( + f"{error_msg}\n\n" + "Valid values: 'independent', 'stateful', 'derived', 'foreign_key'\n\n" + "Usage guide:\n" + "- 'independent': Columns with non-temporal domains (ID, categorical, distributions)\n" + "- 'stateful': Columns with temporal domains (timeseries, state_machine)\n" + "- 'derived': Columns computed from other columns\n" + "- 'foreign_key': Auto-generated foreign key references" + ), + "fix_suggestion": "Use one of: 'independent', 'stateful', 'derived', 'foreign_key'", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + } + + if "expected ColumnCategoryType" in error_msg or "expected CategoryType" in error_msg: + return { + "summary": "Invalid column_category_type value", + "detailed_message": ( + f"{error_msg}\n\n" + "Valid values: 'metadata', 'measurement'\n\n" + "IMPORTANT constraints:\n" + "- STATEFUL columns MUST use 'measurement' category\n" + "- FOREIGN_KEY columns MUST use 'metadata' category" + ), + "fix_suggestion": "Use 'metadata' or 'measurement' (stateful columns require 'measurement')", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + } + + # Pattern 4: Domain AND derivation error + if "cannot have both domain and derivation" in error_msg.lower(): + return { + "summary": "Column has both domain and derivation", + "detailed_message": ( + f"{error_msg}\n\n" + "RULE: Columns must have EITHER domain OR derivation, not both.\n\n" + "- INDEPENDENT/STATEFUL columns: Use 'domain' (no 'derivation')\n" + "- DERIVED columns: Use 'derivation' (no 'domain')\n" + "- FOREIGN_KEY columns: Neither (auto-generated)" + ), + "fix_suggestion": "Remove either 'domain' or 'derivation' based on your column_type", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + } + + # Pattern 5: Missing domain or derivation + if "must have domain" in error_msg.lower() or "must have derivation" in error_msg.lower(): + return { + "summary": "Missing required domain or derivation", + "detailed_message": ( + f"{error_msg}\n\n" + "Column requirements:\n" + "- INDEPENDENT columns: Must have 'domain'\n" + "- STATEFUL columns: Must have 'domain' (timeseries or state_machine)\n" + "- DERIVED columns: Must have 'derivation'\n" + "- FOREIGN_KEY columns: No domain/derivation needed (auto-generated)" + ), + "fix_suggestion": "Add appropriate 'domain' or 'derivation' based on column_type", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + } + + # Pattern 6: Missing global_timestamp + if "global_timestamp" in error_msg.lower() and "required" in error_msg.lower(): + return { + "summary": "Missing global_timestamp configuration", + "detailed_message": ( + f"{error_msg}\n\n" + "RULE: If any entity has a timestamp field, the DataSchema must include 'global_timestamp'.\n\n" + "Example:\n" + f' "global_timestamp": {{\n' + f' "start_timestamp": "2024-01-01T00:00:00",\n' + f' "cadence": {{"num_steps": 1000}}\n' + f' }}' + ), + "fix_suggestion": "Add 'global_timestamp' to your DataSchema with start_timestamp and cadence", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.GlobalTimestamp" + } + + # Pattern 7: Duplicate names + if "duplicate" in error_msg.lower() and "name" in error_msg.lower(): + return { + "summary": "Duplicate names detected", + "detailed_message": ( + f"{error_msg}\n\n" + "RULES:\n" + "- Entity names must be unique across the schema\n" + "- Column names must be unique within each entity" + ), + "fix_suggestion": "Ensure all entity names and column names are unique", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html" + } + + # Pattern 8: Missing required fields + if "missing" in error_msg.lower() or "required" in error_msg.lower(): + return { + "summary": "Missing required field", + "detailed_message": ( + f"{error_msg}\n\n" + "Required fields by type:\n" + "- Column: name, data_type, column_type, column_category_type\n" + "- Entity: name, cardinality, columns (at least one)\n" + "- DataSchema: entities (at least one)\n" + "- Domain: type, params\n" + "- Derivation: function_type, dependent_columns, params" + ), + "fix_suggestion": "Add the missing required field indicated in the error", + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html" + } + + # Generic error (no pattern matched) + return { + "summary": "Validation error", + "detailed_message": error_msg, + "reference": "See https://docs.rockfish.ai/sdk/actions-ent.html for complete schema documentation" + } + + +def _format_comprehensive_errors( + errors: list, schema_dict: Dict[str, Any], show_all: bool = False +) -> Dict[str, Any]: + """ + Convert ValidationError objects from validators.py to MCP tool response format. + + Only ERROR level blocks validation. WARNING/INFO are advisory. + + Args: + errors: List of ValidationError objects from validators.py + schema_dict: Original schema dictionary for context + show_all: If True, show all errors. If False, show first 5 (default) + + Returns: + Dict with keys: valid, error_message, suggestion, reference, summary + OR (if no ERROR level): valid, warnings, data_schema, summary + """ + # Separate by severity + error_level = [e for e in errors if e.level == ValidationLevel.ERROR] + warning_level = [e for e in errors if e.level == ValidationLevel.WARNING] + info_level = [e for e in errors if e.level == ValidationLevel.INFO] + + # Only ERROR level blocks validation + if not error_level: + # No errors - return success with warnings/info as advisory + advisory = [] + for w in warning_level + info_level: + advisory.append({ + "level": str(w.level), + "rule": w.rule, + "message": w.message, + "location": w.location, + "suggestion": w.suggestion + }) + + return { + "valid": True, + "warnings": advisory if advisory else [], + "summary": f"Validation passed with {len(warning_level)} warning(s) and {len(info_level)} info message(s)" + } + + # Has ERROR level - return error response + primary_error = error_level[0] + + # Determine how many errors to show + max_display = len(error_level) if show_all else min(5, len(error_level)) + + # Build error details + error_details = [] + for idx, err in enumerate(error_level[:max_display], 1): + error_details.append( + f"{idx}. [{err.level}] {err.rule}: {err.message}\n" + f" Location: {err.location}" + ) + + if len(error_level) > max_display: + error_details.append( + f"\n... and {len(error_level) - max_display} more error(s)" + ) + if not show_all: + error_details.append( + "Tip: Set show_all_errors=true in the tool call to see all errors" + ) + + detailed_message = "\n\n".join(error_details) + + summary = f"Validation failed: {len(error_level)} error(s)" + if warning_level or info_level: + summary += f" (+ {len(warning_level)} warning(s), {len(info_level)} info)" + + return { + "valid": False, + "error_message": detailed_message, + "suggestion": primary_error.suggestion if primary_error.suggestion else "", + "reference": "https://docs.rockfish.ai/sdk/actions-ent.html", + "summary": summary + } diff --git a/src/rockfish_mcp/server.py b/src/rockfish_mcp/server.py index 79ad5eb..4e8f2a3 100644 --- a/src/rockfish_mcp/server.py +++ b/src/rockfish_mcp/server.py @@ -1007,6 +1007,83 @@ async def handle_list_tools() -> List[types.Tool]: "required": ["dataset_ids"], }, ), + types.Tool( + name="validate_data_schema_config", + description="Validate DataSchema configuration for entity data generation and cache it. Validates column, entity, and schema levels. Returns validation errors or success with cached config_id.", + inputSchema={ + "type": "object", + "properties": { + "data_schema_config": { + "type": "object", + "description": "DataSchema configuration with entities, columns, and relationships", + "properties": { + "entities": { + "type": "array", + "description": "List of entity specifications", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Entity name" + }, + "cardinality": { + "type": "integer", + "minimum": 1, + "description": "Number of rows to generate for this entity" + }, + "columns": { + "type": "array", + "description": "List of column specifications", + "items": {"type": "object"}, + "minItems": 1 + } + }, + "required": ["name", "cardinality", "columns"] + }, + "minItems": 1 + }, + "entity_relationships": { + "type": "array", + "description": "List of relationships between entities", + "items": {"type": "object"} + }, + "global_timestamp": { + "type": "object", + "description": "Optional global timestamp configuration" + } + }, + "required": ["entities"], + "additionalProperties": True + }, + "entity_labels": { + "type": "object", + "description": "Optional entity label mappings for generated datasets", + "additionalProperties": True + }, + "show_all_errors": { + "type": "boolean", + "description": "If true, show all validation errors. If false, show first 5 (default: false)", + "default": False + } + }, + "required": ["data_schema_config"], + }, + ), + types.Tool( + name="start_data_schema_generation_workflow", + description="Start entity data generation workflow using cached schema config. Converts config to rockfish.actions.ent dataclasses and starts workflow. Use after validate_data_schema_config.", + inputSchema={ + "type": "object", + "properties": { + "schema_config_id": { + "type": "string", + "description": "Config ID from validate_data_schema_config (retrieves cached config)", + } + }, + "required": ["schema_config_id"], + }, + ), ] ) @@ -1030,6 +1107,8 @@ async def handle_call_tool( "obtain_synthetic_dataset_id", "plot_distribution", "get_marginal_distribution_score", + "validate_data_schema_config", + "start_data_schema_generation_workflow", ] if name in sdk_tools: diff --git a/src/rockfish_mcp/validators.py b/src/rockfish_mcp/validators.py new file mode 100644 index 0000000..3913500 --- /dev/null +++ b/src/rockfish_mcp/validators.py @@ -0,0 +1,878 @@ +""" +Comprehensive validation utilities for DataSchema configurations. + +This module provides detailed validation beyond what rf.converter.structure() performs, +checking business logic rules, parameter constraints, and semantic relationships. + +Validation Layers: +1. Structure validation (rf.converter.structure) - type checking, required fields, enums +2. Parameter validation (this module) - ranges, constraints, logical consistency +3. Business logic validation (this module) - R1-R10 rules +4. Graph validation (planner) - circular dependencies + +Usage: + from ent.validators import validate_dataschema_comprehensive + + errors = validate_dataschema_comprehensive(schema) + if errors: + for error in errors: + print(f"[{error.level}] {error.rule}: {error.message}") +""" + +from dataclasses import dataclass +from enum import Enum +from typing import Any + +from rockfish.actions.ent import ( + CategoricalParams, + Column, + ColumnCategoryType, + ColumnType, + DataSchema, + Derivation, + DerivationFunctionType, + Domain, + DomainType, + Entity, + EntityRelationship, + EntityRelationshipType, + ExponentialDistParams, + GlobalTimestamp, + IDParams, + MapValuesParams, + NormalDistParams, + StateMachineParams, + TimeseriesParams, + Transition, + UniformDistParams, +) + + +class ValidationLevel(str, Enum): + """Severity level of validation error.""" + + ERROR = "ERROR" # Must fix - will cause generation to fail + WARNING = "WARNING" # Should fix - may cause unexpected behavior + INFO = "INFO" # Informational - best practice suggestion + + +@dataclass +class ValidationError: + """Structured validation error.""" + + level: ValidationLevel + rule: str # e.g., "R1", "PARAM_UNIFORM_01", "COL_INDEPENDENT_01" + message: str + location: str # e.g., "entity 'users' > column 'age'" + suggestion: str = "" # Optional fix suggestion + + +class DataSchemaValidator: + """Comprehensive DataSchema validator.""" + + def __init__(self, schema: DataSchema): + self.schema = schema + self.errors: list[ValidationError] = [] + self.entity_map = {entity.name: entity for entity in schema.entities} + + def validate_all(self) -> list[ValidationError]: + """Run all validation checks and return errors.""" + self.errors = [] + + # Layer 2: Parameter validation + self._validate_domain_params() + self._validate_derivation_params() + + # Layer 3: Business logic (R1-R10) + self._validate_business_rules() + + # Additional checks + self._validate_entities() + self._validate_relationships() + self._validate_global_timestamp() + + return self.errors + + # ========================================================================= + # Domain Parameter Validation + # ========================================================================= + + def _validate_domain_params(self): + """Validate all domain parameters.""" + for entity in self.schema.entities: + for column in entity.columns: + if column.domain is None: + continue + + loc = f"entity '{entity.name}' > column '{column.name}'" + + if column.domain.type == DomainType.ID: + self._validate_id_params(column.domain.params, loc) + elif column.domain.type == DomainType.CATEGORICAL: + self._validate_categorical_params(column.domain.params, loc) + elif column.domain.type == DomainType.UNIFORM_DIST: + self._validate_uniform_params(column.domain.params, loc) + elif column.domain.type == DomainType.NORMAL_DIST: + self._validate_normal_params(column.domain.params, loc) + elif column.domain.type == DomainType.EXPONENTIAL_DIST: + self._validate_exponential_params(column.domain.params, loc) + elif column.domain.type == DomainType.TIMESERIES: + self._validate_timeseries_params(column.domain.params, loc) + elif column.domain.type == DomainType.STATE_MACHINE: + self._validate_state_machine_params(column.domain.params, loc) + + def _validate_id_params(self, params: IDParams, location: str): + """Validate IDParams: template must contain {id}.""" + if "{id}" not in params.template_str: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_ID_01", + message=f"IDParams template_str must contain '{{id}}' placeholder, got: '{params.template_str}'", + location=location, + suggestion="Use a template like 'USER_{{id}}' or 'item-{{id}}'", + ) + ) + + def _validate_categorical_params( + self, params: CategoricalParams, location: str + ): + """Validate CategoricalParams: values not empty, weights match length.""" + if not params.values: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_CAT_01", + message="CategoricalParams values list cannot be empty", + location=location, + suggestion="Provide at least one value in the values list", + ) + ) + + if params.weights is not None: + if len(params.weights) != len(params.values): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_CAT_02", + message=f"CategoricalParams weights length ({len(params.weights)}) must match values length ({len(params.values)})", + location=location, + suggestion="Either remove weights or ensure it has the same length as values", + ) + ) + + def _validate_uniform_params(self, params: UniformDistParams, location: str): + """Validate UniformDistParams: lower < upper.""" + if params.lower >= params.upper: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_UNIFORM_01", + message=f"UniformDistParams lower ({params.lower}) must be less than upper ({params.upper})", + location=location, + suggestion=f"Swap the values or use lower={params.upper}, upper={params.lower + (params.upper - params.lower) * 2}", + ) + ) + + def _validate_normal_params(self, params: NormalDistParams, location: str): + """Validate NormalDistParams: std > 0.""" + if params.std <= 0: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_NORMAL_01", + message=f"NormalDistParams std (standard deviation) must be positive, got: {params.std}", + location=location, + suggestion="Use a positive value like std=10.0 or std=1.5", + ) + ) + + def _validate_exponential_params( + self, params: ExponentialDistParams, location: str + ): + """Validate ExponentialDistParams: scale > 0.""" + if params.scale <= 0: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_EXP_01", + message=f"ExponentialDistParams scale must be positive, got: {params.scale}", + location=location, + suggestion="Use a positive value like scale=2.0", + ) + ) + + def _validate_timeseries_params( + self, params: TimeseriesParams, location: str + ): + """Validate TimeseriesParams: 6 range and probability checks.""" + # min_value < max_value + if params.min_value >= params.max_value: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_01", + message=f"TimeseriesParams min_value ({params.min_value}) must be less than max_value ({params.max_value})", + location=location, + suggestion="Ensure min_value < base_value < max_value", + ) + ) + + # peak_start_hour < peak_end_hour (only relevant for peak_offpeak) + if params.seasonality_type == "peak_offpeak": + if params.peak_start_hour >= params.peak_end_hour: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_02", + message=f"TimeseriesParams peak_start_hour ({params.peak_start_hour}) must be less than peak_end_hour ({params.peak_end_hour})", + location=location, + suggestion="Use values like peak_start_hour=8, peak_end_hour=22 for business hours", + ) + ) + + # seasonality_strength in [0, 1] + if not (0.0 <= params.seasonality_strength <= 1.0): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_03", + message=f"TimeseriesParams seasonality_strength must be in [0, 1], got: {params.seasonality_strength}", + location=location, + suggestion="Use a value between 0.0 (no seasonality) and 1.0 (strong seasonality)", + ) + ) + + # noise_level in [0, 1] + if not (0.0 <= params.noise_level <= 1.0): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_04", + message=f"TimeseriesParams noise_level must be in [0, 1], got: {params.noise_level}", + location=location, + suggestion="Use a value between 0.0 (no noise) and 1.0 (high noise)", + ) + ) + + # spike_probability in [0, 1] + if not (0.0 <= params.spike_probability <= 1.0): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_05", + message=f"TimeseriesParams spike_probability must be in [0, 1], got: {params.spike_probability}", + location=location, + suggestion="Use a value between 0.0 (no spikes) and 1.0 (frequent spikes)", + ) + ) + + # spike_magnitude in [0, 1] + if not (0.0 <= params.spike_magnitude <= 1.0): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_06", + message=f"TimeseriesParams spike_magnitude must be in [0, 1], got: {params.spike_magnitude}", + location=location, + suggestion="Use a value between 0.0 (small spikes) and 1.0 (large spikes)", + ) + ) + + def _validate_state_machine_params( + self, params: StateMachineParams, location: str + ): + """Validate StateMachineParams: states, transitions, context variables.""" + # initial_state in states + if params.initial_state not in params.states: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_SM_01", + message=f"StateMachineParams initial_state '{params.initial_state}' not in states list: {params.states}", + location=location, + suggestion=f"Add '{params.initial_state}' to states or use one of: {params.states}", + ) + ) + + # terminal_states all in states + for terminal in params.terminal_states: + if terminal not in params.states: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_SM_02", + message=f"StateMachineParams terminal_state '{terminal}' not in states list: {params.states}", + location=location, + suggestion=f"Add '{terminal}' to states or remove from terminal_states", + ) + ) + + # Validate each transition + for idx, trans in enumerate(params.transitions): + trans_loc = f"{location} > transition {idx}" + self._validate_transition(trans, params.states, params.context_variables, trans_loc) + + def _validate_transition( + self, + trans: Transition, + states: list[str], + context_vars: dict[str, bool], + location: str, + ): + """Validate a single transition.""" + # source in states + if trans.source not in states: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_SM_03", + message=f"Transition source '{trans.source}' not in states list: {states}", + location=location, + suggestion=f"Add '{trans.source}' to states or change transition source", + ) + ) + + # dest in states + if trans.dest not in states: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_SM_04", + message=f"Transition dest '{trans.dest}' not in states list: {states}", + location=location, + suggestion=f"Add '{trans.dest}' to states or change transition dest", + ) + ) + + # probability in (0, 1] + if not (0.0 < trans.probability <= 1.0): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_SM_05", + message=f"Transition probability must be > 0 and <= 1, got: {trans.probability}", + location=location, + suggestion="Use a probability value like 0.7 or 0.3", + ) + ) + + # conditions reference valid context vars + for cond in trans.conditions: + if cond not in context_vars: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_SM_06", + message=f"Transition condition '{cond}' not defined in context_variables: {list(context_vars.keys())}", + location=location, + suggestion=f"Add '{cond}' to context_variables or remove from conditions", + ) + ) + + # context_updates reference valid context vars + for key in trans.context_updates: + if key not in context_vars: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_SM_07", + message=f"Transition context_update key '{key}' not defined in context_variables: {list(context_vars.keys())}", + location=location, + suggestion=f"Add '{key}' to context_variables or remove from context_updates", + ) + ) + + # ========================================================================= + # Derivation Parameter Validation + # ========================================================================= + + def _validate_derivation_params(self): + """Validate all derivation parameters.""" + for entity in self.schema.entities: + for column in entity.columns: + if column.derivation is None: + continue + + loc = f"entity '{entity.name}' > column '{column.name}'" + + if column.derivation.function_type == DerivationFunctionType.MAP_VALUES: + self._validate_map_values_params( + column.derivation.params, loc + ) + + # Check for unsupported cross-category MEASUREMENT dependencies + self._validate_measurement_dependencies(entity, column, loc) + + def _validate_map_values_params( + self, params: MapValuesParams, location: str + ): + """Validate MapValuesParams: mapping not empty, rules have from/to.""" + if not params.mapping: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_MAP_01", + message="MapValuesParams mapping list cannot be empty", + location=location, + suggestion='Provide mapping rules like [{"from": "active", "to": "high"}]', + ) + ) + + for idx, rule in enumerate(params.mapping): + if "from" not in rule: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_MAP_02", + message=f"MapValuesParams mapping rule {idx} missing 'from' key", + location=location, + suggestion=f"Add 'from' key to rule: {rule}", + ) + ) + if "to" not in rule: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_MAP_03", + message=f"MapValuesParams mapping rule {idx} missing 'to' key", + location=location, + suggestion=f"Add 'to' key to rule: {rule}", + ) + ) + + def _validate_measurement_dependencies( + self, entity: Entity, column: Column, location: str + ): + """ + Validate that MEASUREMENT derived columns don't have same-entity MEASUREMENT dependencies. + + This is currently unsupported and will cause a KeyError at runtime because MEASUREMENT + columns are generated in an arbitrary order when they don't have explicit dependencies + tracked in the column graph. + """ + # Only check MEASUREMENT DERIVED columns + if column.column_category_type != ColumnCategoryType.MEASUREMENT: + return + if column.column_type != ColumnType.DERIVED: + return + if column.derivation is None: + return + + # Build a map of column names to their categories in this entity + entity_columns = {col.name: col for col in entity.columns} + + # Check each dependency + for dep_col_name in column.derivation.dependent_columns: + # Skip cross-entity dependencies (they're fine because dependent entity is generated first) + if "." in dep_col_name: + continue + + # Check if this is a same-entity dependency + dep_col = entity_columns.get(dep_col_name) + if dep_col is None: + # Dependency doesn't exist in this entity - will be caught by other validation + continue + + # Check if the dependency is also a MEASUREMENT column + if dep_col.column_category_type == ColumnCategoryType.MEASUREMENT: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="COL_DERIVED_01", + message=f"MEASUREMENT derived column '{column.name}' cannot depend on another MEASUREMENT column '{dep_col_name}' in the same entity (currently unsupported)", + location=location, + suggestion=f"Change '{dep_col_name}' to column_category_type='metadata', OR restructure to avoid MEASUREMENT->MEASUREMENT dependencies", + ) + ) + + # ========================================================================= + # Business Logic Validation (R1-R10) + # ========================================================================= + + def _validate_business_rules(self): + """Validate R1-R10 business logic rules.""" + # R1: If any entity has Timestamp → GlobalTimestamp required + entities_with_timestamps = [ + e.name for e in self.schema.entities if e.timestamp is not None + ] + if entities_with_timestamps and self.schema.global_timestamp is None: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R1", + message=f"Entities {entities_with_timestamps} have timestamp, but global_timestamp is not defined", + location="schema root", + suggestion="Add global_timestamp configuration with t_start, t_end, and time_interval", + ) + ) + + # R2-R6: Column-level rules (validated per column) + for entity in self.schema.entities: + for column in entity.columns: + loc = f"entity '{entity.name}' > column '{column.name}'" + self._validate_column_business_rules(column, loc) + + # R6: Entity with Timestamp → Must have ≥1 measurement column + if entity.timestamp is not None: + has_measurement = any( + col.column_category_type == ColumnCategoryType.MEASUREMENT + for col in entity.columns + ) + if not has_measurement: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R6", + message=f"Entity '{entity.name}' has timestamp but no measurement columns", + location=f"entity '{entity.name}'", + suggestion="Add at least one column with column_category_type='measurement'", + ) + ) + + # R7: ONE_TO_ONE relationship → from.cardinality ≤ to.cardinality + if self.schema.entity_relationships: + for rel in self.schema.entity_relationships: + if rel.relationship_type == EntityRelationshipType.ONE_TO_ONE: + from_entity = self.entity_map.get(rel.from_entity) + to_entity = self.entity_map.get(rel.to_entity) + if from_entity and to_entity: + if from_entity.cardinality > to_entity.cardinality: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R7", + message=f"ONE_TO_ONE relationship from '{rel.from_entity}' to '{rel.to_entity}': child cardinality ({from_entity.cardinality}) cannot exceed parent cardinality ({to_entity.cardinality})", + location=f"relationship {rel.from_entity} -> {rel.to_entity}", + suggestion=f"Either increase '{rel.to_entity}' cardinality or reduce '{rel.from_entity}' cardinality", + ) + ) + + # R8: Entity names must be unique + entity_names = [e.name for e in self.schema.entities] + duplicates = [ + name for name in set(entity_names) if entity_names.count(name) > 1 + ] + if duplicates: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R8", + message=f"Duplicate entity names found: {duplicates}", + location="schema root", + suggestion="Rename entities to have unique names", + ) + ) + + # R9: Column names unique within entity + for entity in self.schema.entities: + column_names = [c.name for c in entity.columns] + duplicates = [ + name + for name in set(column_names) + if column_names.count(name) > 1 + ] + if duplicates: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R9", + message=f"Duplicate column names in entity '{entity.name}': {duplicates}", + location=f"entity '{entity.name}'", + suggestion="Rename columns to have unique names within the entity", + ) + ) + + def _validate_column_business_rules(self, column: Column, location: str): + """Validate business rules R2-R5, R10 for a single column.""" + # R2: STATEFUL → must be MEASUREMENT + if column.column_type == ColumnType.STATEFUL: + if column.column_category_type != ColumnCategoryType.MEASUREMENT: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R2", + message=f"STATEFUL column must be MEASUREMENT category, got: {column.column_category_type}", + location=location, + suggestion="Change column_category_type to 'measurement'", + ) + ) + + # R3: STATEFUL → domain must be STATE_MACHINE or TIMESERIES + if column.column_type == ColumnType.STATEFUL: + if column.domain is None or column.domain.type not in ( + DomainType.STATE_MACHINE, + DomainType.TIMESERIES, + ): + domain_type = ( + column.domain.type if column.domain else "None" + ) + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R3", + message=f"STATEFUL column must have STATE_MACHINE or TIMESERIES domain, got: {domain_type}", + location=location, + suggestion="Use domain.type='timeseries' or domain.type='state_machine'", + ) + ) + + # R4: INDEPENDENT → domain CANNOT be STATE_MACHINE or TIMESERIES + if column.column_type == ColumnType.INDEPENDENT: + if column.domain and column.domain.type in ( + DomainType.STATE_MACHINE, + DomainType.TIMESERIES, + ): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R4", + message=f"INDEPENDENT column cannot have temporal domain ({column.domain.type})", + location=location, + suggestion="Use a non-temporal domain like 'categorical', 'uniform_dist', or 'id'", + ) + ) + + # R5: FOREIGN_KEY → must be METADATA + if column.column_type == ColumnType.FOREIGN_KEY: + if column.column_category_type != ColumnCategoryType.METADATA: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R5", + message=f"FOREIGN_KEY column must be METADATA category, got: {column.column_category_type}", + location=location, + suggestion="Change column_category_type to 'metadata'", + ) + ) + + # R10: Column has EXACTLY ONE: domain OR derivation OR neither (FK only) + has_domain = column.domain is not None + has_derivation = column.derivation is not None + + if column.column_type in (ColumnType.INDEPENDENT, ColumnType.STATEFUL): + if not has_domain: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R10", + message=f"{column.column_type} column must have domain", + location=location, + suggestion="Add a domain configuration for this column", + ) + ) + if has_derivation: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R10", + message=f"{column.column_type} column cannot have derivation", + location=location, + suggestion="Remove the derivation field", + ) + ) + + elif column.column_type == ColumnType.DERIVED: + if has_domain: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R10", + message="DERIVED column cannot have domain", + location=location, + suggestion="Remove the domain field", + ) + ) + if not has_derivation: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R10", + message="DERIVED column must have derivation", + location=location, + suggestion="Add a derivation configuration for this column", + ) + ) + + elif column.column_type == ColumnType.FOREIGN_KEY: + if has_domain or has_derivation: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R10", + message="FOREIGN_KEY column cannot have domain or derivation", + location=location, + suggestion="Remove domain and derivation fields (they are auto-populated)", + ) + ) + + # ========================================================================= + # Entity Validation + # ========================================================================= + + def _validate_entities(self): + """Validate entity-level constraints.""" + for entity in self.schema.entities: + loc = f"entity '{entity.name}'" + + # Cardinality must be positive + if entity.cardinality <= 0: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="ENT_01", + message=f"Entity cardinality must be positive, got: {entity.cardinality}", + location=loc, + suggestion="Use a positive integer like cardinality=100", + ) + ) + + # Must have at least one column + if not entity.columns: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="ENT_02", + message="Entity must have at least one column", + location=loc, + suggestion="Add column definitions to this entity", + ) + ) + + # ========================================================================= + # Relationship Validation + # ========================================================================= + + def _validate_relationships(self): + """Validate entity relationship constraints.""" + if not self.schema.entity_relationships: + return + + for rel in self.schema.entity_relationships: + loc = f"relationship {rel.from_entity} -> {rel.to_entity}" + + # join_columns cannot be empty + if not rel.join_columns: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="REL_01", + message="Relationship join_columns cannot be empty", + location=loc, + suggestion='Add join_columns like {"user_id": "user_id"}', + ) + ) + + # from_entity ≠ to_entity + if rel.from_entity == rel.to_entity: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="REL_02", + message=f"Relationship cannot be self-referential: '{rel.from_entity}'", + location=loc, + suggestion="Create relationships between different entities", + ) + ) + + # from_entity and to_entity must exist + if rel.from_entity not in self.entity_map: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="REL_03", + message=f"Relationship references unknown from_entity: '{rel.from_entity}'", + location=loc, + suggestion=f"Use one of: {list(self.entity_map.keys())}", + ) + ) + + if rel.to_entity not in self.entity_map: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="REL_04", + message=f"Relationship references unknown to_entity: '{rel.to_entity}'", + location=loc, + suggestion=f"Use one of: {list(self.entity_map.keys())}", + ) + ) + + # ========================================================================= + # GlobalTimestamp Validation + # ========================================================================= + + def _validate_global_timestamp(self): + """Validate GlobalTimestamp constraints.""" + if self.schema.global_timestamp is None: + return + + gt = self.schema.global_timestamp + loc = "global_timestamp" + + # time_interval format validation (already done in config.py __attrs_post_init__) + # Just add a reminder check + import re + + pattern = r"^\d+(min|hour|day|month)$" + if not re.match(pattern, gt.time_interval): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="GT_01", + message=f"GlobalTimestamp time_interval format invalid: '{gt.time_interval}'", + location=loc, + suggestion="Use format like '15min', '1hour', '1day', or '3months'", + ) + ) + + +# ============================================================================= +# Public API +# ============================================================================= + + +def validate_dataschema_comprehensive( + schema: DataSchema, +) -> list[ValidationError]: + """ + Run comprehensive validation on a DataSchema. + + This performs validation beyond rf.converter.structure(), checking: + - Parameter ranges and constraints (54 rules) + - Business logic rules (R1-R10) + - Semantic relationships + + Args: + schema: DataSchema object to validate + + Returns: + List of ValidationError objects (empty if valid) + + Example: + >>> errors = validate_dataschema_comprehensive(schema) + >>> if errors: + ... for err in errors: + ... print(f"[{err.rule}] {err.message}") + ... else: + ... print("Schema is valid!") + """ + validator = DataSchemaValidator(schema) + return validator.validate_all() + + +def format_validation_errors(errors: list[ValidationError]) -> str: + """Format validation errors as a readable report.""" + if not errors: + return "✅ Schema validation passed!" + + report = [f"❌ Found {len(errors)} validation error(s):\n"] + + for idx, err in enumerate(errors, 1): + report.append(f"{idx}. [{err.level}] {err.rule}: {err.message}") + report.append(f" Location: {err.location}") + if err.suggestion: + report.append(f" Suggestion: {err.suggestion}") + report.append("") + + return "\n".join(report) From 431a92bd417c9b3ea8e35fef4f0d8b6a8bbc6bb9 Mon Sep 17 00:00:00 2001 From: Fan Date: Wed, 26 Nov 2025 17:39:54 -0800 Subject: [PATCH 02/10] lint --- src/rockfish_mcp/sdk_client.py | 116 ++++++++++++++++++++------------- src/rockfish_mcp/server.py | 30 +++++---- src/rockfish_mcp/validators.py | 32 +++------ 3 files changed, 98 insertions(+), 80 deletions(-) diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index b9e5466..a0fefbb 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -25,9 +25,9 @@ # Import validators for comprehensive DataSchema validation from rockfish_mcp.validators import ( - validate_dataschema_comprehensive, ValidationError, ValidationLevel, + validate_dataschema_comprehensive, ) @@ -441,7 +441,7 @@ async def collect_logs(): # Pass show_all_errors via schema_dict for _validate_data_schema schema_dict_with_options = { **data_schema_config, - "_show_all_errors": show_all_errors + "_show_all_errors": show_all_errors, } validation_result = _validate_data_schema(schema_dict_with_options) @@ -453,7 +453,10 @@ async def collect_logs(): "error": validation_result["error_message"], } # Add optional fields if present - if "suggestion" in validation_result and validation_result["suggestion"]: + if ( + "suggestion" in validation_result + and validation_result["suggestion"] + ): error_response["suggestion"] = validation_result["suggestion"] if "reference" in validation_result and validation_result["reference"]: error_response["reference"] = validation_result["reference"] @@ -475,7 +478,9 @@ async def collect_logs(): entities = data_schema_config.get("entities", []) entity_names = [e.get("name") for e in entities] total_columns = sum(len(e.get("columns", [])) for e in entities) - relationships_count = len(data_schema_config.get("entity_relationships", [])) + relationships_count = len( + data_schema_config.get("entity_relationships", []) + ) response = { "success": True, @@ -508,11 +513,15 @@ async def collect_logs(): # Retrieve and remove from cache cache_entry = self._cache.pop(schema_config_id) data_schema_dict = cache_entry["data_schema_config"] - entity_labels = cache_entry.get("entity_labels") # TODO: make use of entity_labels in GenerateFromDataSchema + entity_labels = cache_entry.get( + "entity_labels" + ) # TODO: make use of entity_labels in GenerateFromDataSchema # Convert dict to DataSchema (should not fail since we already validated) try: - data_schema = rf.converter.structure(data_schema_dict, ra.ent.DataSchema) + data_schema = rf.converter.structure( + data_schema_dict, ra.ent.DataSchema + ) except Exception as e: return { "success": False, @@ -524,8 +533,7 @@ async def collect_logs(): try: if entity_labels: generate_action = ra.ent.GenerateFromDataSchema( - schema=data_schema, - entity_labels=entity_labels + schema=data_schema, entity_labels=entity_labels ) else: generate_action = ra.ent.GenerateFromDataSchema(schema=data_schema) @@ -685,7 +693,7 @@ def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: validation_result = _format_comprehensive_errors( comprehensive_errors, schema_dict, - show_all=schema_dict.get('_show_all_errors', False) + show_all=schema_dict.get("_show_all_errors", False), ) # If valid=True (warnings only), add data_schema for caching @@ -697,7 +705,7 @@ def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: return { "valid": True, "data_schema": data_schema, - "summary": f"Validation successful: {len(schema_dict.get('entities', []))} entities" + "summary": f"Validation successful: {len(schema_dict.get('entities', []))} entities", } except StructureError as e: # Parse StructureError message with intelligent error detection @@ -709,7 +717,7 @@ def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: "error_message": helpful_error["detailed_message"], "suggestion": helpful_error.get("fix_suggestion"), "reference": helpful_error.get("docs_link"), - "summary": f"Validation failed: {helpful_error['summary']}" + "summary": f"Validation failed: {helpful_error['summary']}", } except Exception as e: # Catch other validation errors (from __attrs_post_init__) @@ -721,11 +729,13 @@ def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: "error_message": helpful_error["detailed_message"], "suggestion": helpful_error.get("fix_suggestion"), "reference": helpful_error.get("docs_link"), - "summary": f"Validation failed: {helpful_error['summary']}" + "summary": f"Validation failed: {helpful_error['summary']}", } -def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[str, Any]: +def _detect_common_errors( + error_msg: str, schema_dict: Dict[str, Any] +) -> Dict[str, Any]: """ Detect common error patterns and provide actionable fix suggestions. @@ -747,7 +757,9 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s import re # Pattern 1: Domain validation error - "expected Optional @ $.entities[X].columns[Y].domain" - domain_error_pattern = r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.domain" + domain_error_pattern = ( + r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.domain" + ) match = re.search(domain_error_pattern, error_msg) if match: @@ -768,10 +780,12 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s if isinstance(domain, dict) and "type" in domain: domain_type = domain.get("type") # Check if there are extra keys that should be in "params" - extra_keys = [k for k in domain.keys() if k not in ["type", "params"]] + extra_keys = [ + k for k in domain.keys() if k not in ["type", "params"] + ] if extra_keys and "params" not in domain: - params_str = ', '.join([f'"{k}"' for k in extra_keys]) + params_str = ", ".join([f'"{k}"' for k in extra_keys]) return { "summary": f"Domain structure error in column '{column.get('name', 'unknown')}'", "detailed_message": ( @@ -783,16 +797,18 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s f' "domain": {{\n' f' "type": "{domain_type}",\n' f' "params": {{ {params_str}: ... }}\n' - f' }}' + f" }}" ), - "fix_suggestion": f"Move {params_str} inside 'params': {{\"type\": \"{domain_type}\", \"params\": {{...}}}}", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Domain" + "fix_suggestion": f'Move {params_str} inside \'params\': {{"type": "{domain_type}", "params": {{...}}}}', + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Domain", } except (IndexError, KeyError, TypeError): pass # Fall through to generic error # Pattern 2: Derivation validation error - similar structure - derivation_error_pattern = r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.derivation" + derivation_error_pattern = ( + r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.derivation" + ) match = re.search(derivation_error_pattern, error_msg) if match: @@ -810,10 +826,14 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s if isinstance(derivation, dict) and "function_type" in derivation: func_type = derivation.get("function_type") - extra_keys = [k for k in derivation.keys() if k not in ["function_type", "params", "dependent_columns"]] + extra_keys = [ + k + for k in derivation.keys() + if k not in ["function_type", "params", "dependent_columns"] + ] if extra_keys and "params" not in derivation: - params_str = ', '.join([f'"{k}"' for k in extra_keys]) + params_str = ", ".join([f'"{k}"' for k in extra_keys]) return { "summary": f"Derivation structure error in column '{column.get('name', 'unknown')}'", "detailed_message": ( @@ -825,10 +845,10 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s f' "function_type": "{func_type}",\n' f' "dependent_columns": [...],\n' f' "params": {{ {params_str}: ... }}\n' - f' }}' + f" }}" ), "fix_suggestion": f"Move {params_str} inside 'params'", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Derivation" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Derivation", } except (IndexError, KeyError, TypeError): pass @@ -847,10 +867,13 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s "- 'foreign_key': Auto-generated foreign key references" ), "fix_suggestion": "Use one of: 'independent', 'stateful', 'derived', 'foreign_key'", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", } - if "expected ColumnCategoryType" in error_msg or "expected CategoryType" in error_msg: + if ( + "expected ColumnCategoryType" in error_msg + or "expected CategoryType" in error_msg + ): return { "summary": "Invalid column_category_type value", "detailed_message": ( @@ -861,7 +884,7 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s "- FOREIGN_KEY columns MUST use 'metadata' category" ), "fix_suggestion": "Use 'metadata' or 'measurement' (stateful columns require 'measurement')", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", } # Pattern 4: Domain AND derivation error @@ -876,11 +899,14 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s "- FOREIGN_KEY columns: Neither (auto-generated)" ), "fix_suggestion": "Remove either 'domain' or 'derivation' based on your column_type", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", } # Pattern 5: Missing domain or derivation - if "must have domain" in error_msg.lower() or "must have derivation" in error_msg.lower(): + if ( + "must have domain" in error_msg.lower() + or "must have derivation" in error_msg.lower() + ): return { "summary": "Missing required domain or derivation", "detailed_message": ( @@ -892,7 +918,7 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s "- FOREIGN_KEY columns: No domain/derivation needed (auto-generated)" ), "fix_suggestion": "Add appropriate 'domain' or 'derivation' based on column_type", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", } # Pattern 6: Missing global_timestamp @@ -906,10 +932,10 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s f' "global_timestamp": {{\n' f' "start_timestamp": "2024-01-01T00:00:00",\n' f' "cadence": {{"num_steps": 1000}}\n' - f' }}' + f" }}" ), "fix_suggestion": "Add 'global_timestamp' to your DataSchema with start_timestamp and cadence", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.GlobalTimestamp" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.GlobalTimestamp", } # Pattern 7: Duplicate names @@ -923,7 +949,7 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s "- Column names must be unique within each entity" ), "fix_suggestion": "Ensure all entity names and column names are unique", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html", } # Pattern 8: Missing required fields @@ -940,14 +966,14 @@ def _detect_common_errors(error_msg: str, schema_dict: Dict[str, Any]) -> Dict[s "- Derivation: function_type, dependent_columns, params" ), "fix_suggestion": "Add the missing required field indicated in the error", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html" + "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html", } # Generic error (no pattern matched) return { "summary": "Validation error", "detailed_message": error_msg, - "reference": "See https://docs.rockfish.ai/sdk/actions-ent.html for complete schema documentation" + "reference": "See https://docs.rockfish.ai/sdk/actions-ent.html for complete schema documentation", } @@ -978,18 +1004,20 @@ def _format_comprehensive_errors( # No errors - return success with warnings/info as advisory advisory = [] for w in warning_level + info_level: - advisory.append({ - "level": str(w.level), - "rule": w.rule, - "message": w.message, - "location": w.location, - "suggestion": w.suggestion - }) + advisory.append( + { + "level": str(w.level), + "rule": w.rule, + "message": w.message, + "location": w.location, + "suggestion": w.suggestion, + } + ) return { "valid": True, "warnings": advisory if advisory else [], - "summary": f"Validation passed with {len(warning_level)} warning(s) and {len(info_level)} info message(s)" + "summary": f"Validation passed with {len(warning_level)} warning(s) and {len(info_level)} info message(s)", } # Has ERROR level - return error response @@ -1026,5 +1054,5 @@ def _format_comprehensive_errors( "error_message": detailed_message, "suggestion": primary_error.suggestion if primary_error.suggestion else "", "reference": "https://docs.rockfish.ai/sdk/actions-ent.html", - "summary": summary + "summary": summary, } diff --git a/src/rockfish_mcp/server.py b/src/rockfish_mcp/server.py index 4e8f2a3..63c632d 100644 --- a/src/rockfish_mcp/server.py +++ b/src/rockfish_mcp/server.py @@ -1025,47 +1025,47 @@ async def handle_list_tools() -> List[types.Tool]: "properties": { "name": { "type": "string", - "description": "Entity name" + "description": "Entity name", }, "cardinality": { "type": "integer", "minimum": 1, - "description": "Number of rows to generate for this entity" + "description": "Number of rows to generate for this entity", }, "columns": { "type": "array", "description": "List of column specifications", "items": {"type": "object"}, - "minItems": 1 - } + "minItems": 1, + }, }, - "required": ["name", "cardinality", "columns"] + "required": ["name", "cardinality", "columns"], }, - "minItems": 1 + "minItems": 1, }, "entity_relationships": { "type": "array", "description": "List of relationships between entities", - "items": {"type": "object"} + "items": {"type": "object"}, }, "global_timestamp": { "type": "object", - "description": "Optional global timestamp configuration" - } + "description": "Optional global timestamp configuration", + }, }, "required": ["entities"], - "additionalProperties": True + "additionalProperties": True, }, "entity_labels": { "type": "object", "description": "Optional entity label mappings for generated datasets", - "additionalProperties": True + "additionalProperties": True, }, "show_all_errors": { "type": "boolean", "description": "If true, show all validation errors. If false, show first 5 (default: false)", - "default": False - } + "default": False, + }, }, "required": ["data_schema_config"], }, @@ -1196,7 +1196,9 @@ async def main(): # Initialize Rockfish client api_key = os.getenv("ROCKFISH_API_KEY") # Support both new API_URL and legacy BASE_URL variable names for backwards compatibility - api_url = os.getenv("ROCKFISH_API_URL") or os.getenv("ROCKFISH_BASE_URL", "https://api.rockfish.ai") + api_url = os.getenv("ROCKFISH_API_URL") or os.getenv( + "ROCKFISH_BASE_URL", "https://api.rockfish.ai" + ) organization_id = os.getenv("ROCKFISH_ORGANIZATION_ID", None) project_id = os.getenv("ROCKFISH_PROJECT_ID", None) diff --git a/src/rockfish_mcp/validators.py b/src/rockfish_mcp/validators.py index 3913500..0f79ff6 100644 --- a/src/rockfish_mcp/validators.py +++ b/src/rockfish_mcp/validators.py @@ -134,9 +134,7 @@ def _validate_id_params(self, params: IDParams, location: str): ) ) - def _validate_categorical_params( - self, params: CategoricalParams, location: str - ): + def _validate_categorical_params(self, params: CategoricalParams, location: str): """Validate CategoricalParams: values not empty, weights match length.""" if not params.values: self.errors.append( @@ -202,9 +200,7 @@ def _validate_exponential_params( ) ) - def _validate_timeseries_params( - self, params: TimeseriesParams, location: str - ): + def _validate_timeseries_params(self, params: TimeseriesParams, location: str): """Validate TimeseriesParams: 6 range and probability checks.""" # min_value < max_value if params.min_value >= params.max_value: @@ -279,9 +275,7 @@ def _validate_timeseries_params( ) ) - def _validate_state_machine_params( - self, params: StateMachineParams, location: str - ): + def _validate_state_machine_params(self, params: StateMachineParams, location: str): """Validate StateMachineParams: states, transitions, context variables.""" # initial_state in states if params.initial_state not in params.states: @@ -311,7 +305,9 @@ def _validate_state_machine_params( # Validate each transition for idx, trans in enumerate(params.transitions): trans_loc = f"{location} > transition {idx}" - self._validate_transition(trans, params.states, params.context_variables, trans_loc) + self._validate_transition( + trans, params.states, params.context_variables, trans_loc + ) def _validate_transition( self, @@ -397,16 +393,12 @@ def _validate_derivation_params(self): loc = f"entity '{entity.name}' > column '{column.name}'" if column.derivation.function_type == DerivationFunctionType.MAP_VALUES: - self._validate_map_values_params( - column.derivation.params, loc - ) + self._validate_map_values_params(column.derivation.params, loc) # Check for unsupported cross-category MEASUREMENT dependencies self._validate_measurement_dependencies(entity, column, loc) - def _validate_map_values_params( - self, params: MapValuesParams, location: str - ): + def _validate_map_values_params(self, params: MapValuesParams, location: str): """Validate MapValuesParams: mapping not empty, rules have from/to.""" if not params.mapping: self.errors.append( @@ -568,9 +560,7 @@ def _validate_business_rules(self): for entity in self.schema.entities: column_names = [c.name for c in entity.columns] duplicates = [ - name - for name in set(column_names) - if column_names.count(name) > 1 + name for name in set(column_names) if column_names.count(name) > 1 ] if duplicates: self.errors.append( @@ -604,9 +594,7 @@ def _validate_column_business_rules(self, column: Column, location: str): DomainType.STATE_MACHINE, DomainType.TIMESERIES, ): - domain_type = ( - column.domain.type if column.domain else "None" - ) + domain_type = column.domain.type if column.domain else "None" self.errors.append( ValidationError( level=ValidationLevel.ERROR, From 168507927ad4e495016322fb4dc7e100d12cc9b9 Mon Sep 17 00:00:00 2001 From: Fan Date: Mon, 1 Dec 2025 14:55:58 -0800 Subject: [PATCH 03/10] update validators for independent measurement, global timestamp restrictions --- src/rockfish_mcp/validators.py | 69 +++++++++++++++++++++++++++++----- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/src/rockfish_mcp/validators.py b/src/rockfish_mcp/validators.py index 0f79ff6..41e34f3 100644 --- a/src/rockfish_mcp/validators.py +++ b/src/rockfish_mcp/validators.py @@ -7,7 +7,7 @@ Validation Layers: 1. Structure validation (rf.converter.structure) - type checking, required fields, enums 2. Parameter validation (this module) - ranges, constraints, logical consistency -3. Business logic validation (this module) - R1-R10 rules +3. Business logic validation (this module) - R1-R12 rules 4. Graph validation (planner) - circular dependencies Usage: @@ -130,7 +130,7 @@ def _validate_id_params(self, params: IDParams, location: str): rule="PARAM_ID_01", message=f"IDParams template_str must contain '{{id}}' placeholder, got: '{params.template_str}'", location=location, - suggestion="Use a template like 'USER_{{id}}' or 'item-{{id}}'", + suggestion="Use a template like 'USER_{id}' or 'item-{id}'", ) ) @@ -227,6 +227,30 @@ def _validate_timeseries_params(self, params: TimeseriesParams, location: str): ) ) + # PARAM_TS_07: peak_start_hour must be in [0, 23] + if not (0 <= params.peak_start_hour <= 23): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_07", + message=f"TimeseriesParams peak_start_hour must be in [0, 23], got: {params.peak_start_hour}", + location=location, + suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", + ) + ) + + # PARAM_TS_08: peak_end_hour must be in [0, 23] + if not (0 <= params.peak_end_hour <= 23): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="PARAM_TS_08", + message=f"TimeseriesParams peak_end_hour must be in [0, 23], got: {params.peak_end_hour}", + location=location, + suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", + ) + ) + # seasonality_strength in [0, 1] if not (0.0 <= params.seasonality_strength <= 1.0): self.errors.append( @@ -471,7 +495,7 @@ def _validate_measurement_dependencies( self.errors.append( ValidationError( level=ValidationLevel.ERROR, - rule="COL_DERIVED_01", + rule="R11", message=f"MEASUREMENT derived column '{column.name}' cannot depend on another MEASUREMENT column '{dep_col_name}' in the same entity (currently unsupported)", location=location, suggestion=f"Change '{dep_col_name}' to column_category_type='metadata', OR restructure to avoid MEASUREMENT->MEASUREMENT dependencies", @@ -479,11 +503,11 @@ def _validate_measurement_dependencies( ) # ========================================================================= - # Business Logic Validation (R1-R10) + # Business Logic Validation (R1-R12) # ========================================================================= def _validate_business_rules(self): - """Validate R1-R10 business logic rules.""" + """Validate R1-R12 business logic rules.""" # R1: If any entity has Timestamp → GlobalTimestamp required entities_with_timestamps = [ e.name for e in self.schema.entities if e.timestamp is not None @@ -499,7 +523,7 @@ def _validate_business_rules(self): ) ) - # R2-R6: Column-level rules (validated per column) + # R2-R5, R10-R12: Column-level rules (validated per column) for entity in self.schema.entities: for column in entity.columns: loc = f"entity '{entity.name}' > column '{column.name}'" @@ -574,7 +598,7 @@ def _validate_business_rules(self): ) def _validate_column_business_rules(self, column: Column, location: str): - """Validate business rules R2-R5, R10 for a single column.""" + """Validate business rules R2-R5, R10-R12 for a single column.""" # R2: STATEFUL → must be MEASUREMENT if column.column_type == ColumnType.STATEFUL: if column.column_category_type != ColumnCategoryType.MEASUREMENT: @@ -634,6 +658,21 @@ def _validate_column_business_rules(self, column: Column, location: str): ) ) + # R12: MEASUREMENT columns cannot be INDEPENDENT (currently unsupported) + if ( + column.column_category_type == ColumnCategoryType.MEASUREMENT + and column.column_type == ColumnType.INDEPENDENT + ): + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="R12", + message=f"MEASUREMENT column cannot be INDEPENDENT (currently unsupported)", + location=location, + suggestion="Change column_type to 'stateful' for time-varying measurements, or change column_category_type to 'metadata' for static values", + ) + ) + # R10: Column has EXACTLY ONE: domain OR derivation OR neither (FK only) has_domain = column.domain is not None has_derivation = column.derivation is not None @@ -814,6 +853,18 @@ def _validate_global_timestamp(self): ) ) + # GT_02: t_start < t_end + if gt.t_start >= gt.t_end: + self.errors.append( + ValidationError( + level=ValidationLevel.ERROR, + rule="GT_02", + message=f"GlobalTimestamp t_start ({gt.t_start}) must be less than t_end ({gt.t_end})", + location=loc, + suggestion="Ensure start time is before end time", + ) + ) + # ============================================================================= # Public API @@ -827,8 +878,8 @@ def validate_dataschema_comprehensive( Run comprehensive validation on a DataSchema. This performs validation beyond rf.converter.structure(), checking: - - Parameter ranges and constraints (54 rules) - - Business logic rules (R1-R10) + - Parameter ranges and constraints (30 rules) + - Business logic rules (R1-R12) - Semantic relationships Args: From 56f20418505e315cc3b1c91871921388edf2043b Mon Sep 17 00:00:00 2001 From: Fan Date: Tue, 2 Dec 2025 10:57:13 -0800 Subject: [PATCH 04/10] update MCP toop --- src/rockfish_mcp/sdk_client.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index a0fefbb..db1293c 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -226,11 +226,12 @@ async def collect_logs(): "generation_workflow_id": generation_workflow_id, "status": status, } - generated_dataset = await generation_workflow.datasets().last() + # TODO + generated_datasets = await generation_workflow.datasets().collect() return { "success": True, "generation_workflow_id": generation_workflow_id, - "generated_dataset_id": generated_dataset.id, + "generated_dataset_id(s)": [generated_dataset.id for generated_dataset in generated_datasets], } elif tool_name == "plot_distribution": dataset_ids = arguments["dataset_ids"] @@ -591,12 +592,11 @@ def _fig_to_base64(fig): buf.close() plt.close(fig.fig) # Close the underlying matplotlib figure to free memory return img_str - + if len(dataset_ids) != 2: + raise ValueError("current only support 2 datasets for comparison plotting") # Load dataset and convert to LocalDataset - dataset = await conn.get_dataset(dataset_ids[0]) - dataset = await dataset.to_local(conn) - synthetic = await conn.get_dataset(dataset_ids[1]) - synthetic = await synthetic.to_local(conn) + dataset = await get_local_dataset(conn, dataset_ids[0]) + synthetic = await get_local_dataset(conn, dataset_ids[1]) table = dataset.table field_type = table[column_name].type From 95c4fa20c095323b45cca856726a9b1b3bf6cef5 Mon Sep 17 00:00:00 2001 From: Fan Date: Tue, 2 Dec 2025 11:02:56 -0800 Subject: [PATCH 05/10] formatting --- src/rockfish_mcp/sdk_client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index db1293c..5b04ac2 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -231,7 +231,9 @@ async def collect_logs(): return { "success": True, "generation_workflow_id": generation_workflow_id, - "generated_dataset_id(s)": [generated_dataset.id for generated_dataset in generated_datasets], + "generated_dataset_id(s)": [ + generated_dataset.id for generated_dataset in generated_datasets + ], } elif tool_name == "plot_distribution": dataset_ids = arguments["dataset_ids"] @@ -592,6 +594,7 @@ def _fig_to_base64(fig): buf.close() plt.close(fig.fig) # Close the underlying matplotlib figure to free memory return img_str + if len(dataset_ids) != 2: raise ValueError("current only support 2 datasets for comparison plotting") # Load dataset and convert to LocalDataset From 04c6454e3106d469e30885094febd3af39be9779 Mon Sep 17 00:00:00 2001 From: Fan Date: Tue, 2 Dec 2025 11:57:15 -0800 Subject: [PATCH 06/10] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 3196018e1a0f605241175b9ea9a556bf8db4806c Author: Shane Duan <18065+wolfdancer@users.noreply.github.com> Date: Thu Nov 27 19:18:56 2025 -0800 Fix requirements.txt to use --find-links for Rockfish SDK installation (#14) Changed from --extra-index-url to --find-links in requirements.txt to properly install the rockfish SDK from the custom package repository at https://packages.rockfish.ai. Also updated CLAUDE.md to document this requirement and added .claude/ to .gitignore. Fixes #9 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Shane Co-authored-by: Claude commit 2c686cc0276003cb92b9ffd25b6605eebc676318 Author: Shane Duan <18065+wolfdancer@users.noreply.github.com> Date: Wed Nov 26 13:32:52 2025 -0800 Remove query_dataset tool from MCP server as the execute_query tool can do the same job and work with LLM much better.More specifically, the query_dataset requires the table name to be "my_table", which LLM is just not used to. Instead, execute_query uses the dataset id as the table name, and that works out of box with LLM. (#11) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Claude commit b22e44ba0402f47de24b1599750a15017b0968cd Merge: 0260809 b129acf Author: Shane Duan <18065+wolfdancer@users.noreply.github.com> Date: Tue Nov 25 19:34:01 2025 -0800 Merge pull request #7 from FanG-817/feature/workflow_rf_tab_gan Add SDK synthetic data generation workflow commit 026080922ca5d149a9f88926b88d9f5c2e41806f Merge: c5ba902 81d51be Author: Shane Duan <18065+wolfdancer@users.noreply.github.com> Date: Mon Nov 24 20:05:50 2025 -0800 Merge pull request #6 from FanG-817/refactor/standardize-api-url-naming update all "base_url" to "api_url" with backward compatibility as part of the deprecation process --- .gitignore | 1 + CLAUDE.md | 27 ++++++++++++++++++++++++++- requirements.txt | 2 +- src/rockfish_mcp/client.py | 22 ---------------------- src/rockfish_mcp/server.py | 19 ------------------- 5 files changed, 28 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index c15ad80..77f1956 100644 --- a/.gitignore +++ b/.gitignore @@ -39,6 +39,7 @@ venv.bak/ # IDE .vscode/ .idea/ +.claude/ *.swp *.swo *~ diff --git a/CLAUDE.md b/CLAUDE.md index b380be5..227572f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -10,9 +10,12 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co pip install -e . # Install from requirements.txt (for development) +# Note: The requirements.txt uses --find-links to access the Rockfish package repository pip install -r requirements.txt ``` +**Important**: The Rockfish SDK package is hosted on a custom package repository at `https://packages.rockfish.ai`. The [requirements.txt](requirements.txt) file uses `--find-links` directive to enable pip to discover and install packages from this repository. + ### Running the Server ```bash # Run the MCP server @@ -26,6 +29,8 @@ rockfish-mcp The application requires these environment variables: - `ROCKFISH_API_KEY`: Your Rockfish API key (required) - `ROCKFISH_API_URL`: API URL for Rockfish API (defaults to https://api.rockfish.ai) +- `ROCKFISH_ORGANIZATION_ID`: Organization ID (optional - uses default if not set) +- `ROCKFISH_PROJECT_ID`: Project ID (optional - uses default if not set) - `MANTA_API_URL`: API URL for Manta service (optional - Manta tools only appear if this is set) Create a `.env` file with these variables for local development: @@ -34,6 +39,26 @@ Create a `.env` file with these variables for local development: cp .env.example .env ``` +### Code Formatting +This project uses black and isort for code formatting: +```bash +# Format code before committing +isort src/rockfish_mcp/ +black src/rockfish_mcp/ + +# Check formatting without modifying files +isort --check-only src/rockfish_mcp/ +black --check src/rockfish_mcp/ +``` + +### Testing with MCP Inspector +Use the MCP Inspector to test the server before connecting to Claude Desktop: +```bash +# Start the inspector (replace with your actual Python path) +npx @modelcontextprotocol/inspector /path/to/.venv/bin/python -m rockfish_mcp.server +``` +The Inspector provides an interactive web interface to test all available tools. + ## Architecture Overview This is an MCP (Model Context Protocol) server that provides AI assistants access to the Rockfish machine learning platform API, the Manta dataset testing service, and the Rockfish SDK for synthetic data generation. The architecture consists of four main components in a simple, focused structure. @@ -52,7 +77,7 @@ src/rockfish_mcp/ **Server (`server.py`)**: The main MCP server that: - Defines tools across multiple resource categories - - Rockfish API: Databases, Worker Sets, Workflows, Models, Projects, Datasets (22 tools, always available) + - Rockfish API: Databases, Worker Sets, Workflows, Models, Projects, Datasets (21 tools, always available) - Manta Service: Prompt Management, Data Manipulation, LLM Processing (10 tools, conditional) - SDK Tools: Synthetic Data Generation workflow tools (9 tools, always available) - Conditionally loads Manta tools only when `MANTA_API_URL` environment variable is set diff --git a/requirements.txt b/requirements.txt index e4fe608..0456a78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ ---extra-index-url https://packages.rockfish.ai +--find-links https://packages.rockfish.ai black==24.10.0 httpx==0.27.0 isort==5.13.2 diff --git a/src/rockfish_mcp/client.py b/src/rockfish_mcp/client.py index 1393128..c9c9d04 100644 --- a/src/rockfish_mcp/client.py +++ b/src/rockfish_mcp/client.py @@ -201,27 +201,5 @@ async def call_endpoint( response.raise_for_status() return {"result": response.text} - elif tool_name == "query_dataset": - dataset_id = arguments["id"] - query = arguments["query"] - project_id = arguments.get("project_id") - - # Prepare headers for dataset query request - query_headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "text/plain", - } - if project_id: - query_headers["X-Project-ID"] = project_id - - # Make dataset query request with text/plain content - url = f"{self.api_url}/dataset/{dataset_id}/query" - async with httpx.AsyncClient() as client: - response = await client.request( - method="POST", url=url, headers=query_headers, content=query - ) - response.raise_for_status() - return {"result": response.text} - else: raise ValueError(f"Unknown tool: {tool_name}") diff --git a/src/rockfish_mcp/server.py b/src/rockfish_mcp/server.py index 63c632d..fb884a2 100644 --- a/src/rockfish_mcp/server.py +++ b/src/rockfish_mcp/server.py @@ -432,25 +432,6 @@ async def handle_list_tools() -> List[types.Tool]: "required": ["query"], }, ), - types.Tool( - name="query_dataset", - description="Execute a query against a specific dataset and return results in CSV format", - inputSchema={ - "type": "object", - "properties": { - "id": { - "type": "string", - "description": "Dataset ID to query against", - }, - "query": {"type": "string", "description": "The query to execute"}, - "project_id": { - "type": "string", - "description": "Optional project ID to execute the query in", - }, - }, - "required": ["id", "query"], - }, - ), ] # Add Manta tools only if Manta client is initialized From 631785adefcafae3e1e0e961449b4dc404e19828 Mon Sep 17 00:00:00 2001 From: Fan Date: Tue, 2 Dec 2025 15:57:38 -0800 Subject: [PATCH 07/10] extract and parse structure errors with detailed info; remove redundant checks --- src/rockfish_mcp/sdk_client.py | 531 +++++------------- src/rockfish_mcp/server.py | 5 - src/rockfish_mcp/validators.py | 968 ++++++--------------------------- 3 files changed, 294 insertions(+), 1210 deletions(-) diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index 5b04ac2..cd5cee2 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -2,6 +2,7 @@ import base64 import io import math +import re from typing import Any, Dict, Optional, Tuple # Set matplotlib to non-interactive backend BEFORE importing pyplot @@ -12,7 +13,6 @@ import rockfish as rf import rockfish.actions as ra import rockfish.labs as rl -from rockfish.converter import StructureError from rockfish.remote import glue matplotlib.use("Agg") @@ -24,11 +24,7 @@ logger = logging.getLogger(__name__) # Import validators for comprehensive DataSchema validation -from rockfish_mcp.validators import ( - ValidationError, - ValidationLevel, - validate_dataschema_comprehensive, -) +from rockfish_mcp.validators import ValidationError, validate_dataschema_comprehensive class RockfishSDKClient: @@ -402,51 +398,9 @@ async def collect_logs(): elif tool_name == "validate_data_schema_config": data_schema_config = arguments["data_schema_config"] entity_labels = arguments.get("entity_labels") - show_all_errors = arguments.get("show_all_errors", False) - # Python type checking (Layer 2 validation) - if not isinstance(data_schema_config, dict): - return { - "success": False, - "message": "data_schema_config must be a dictionary", - "error": f"Expected dict, got {type(data_schema_config).__name__}", - } - - if "entities" not in data_schema_config: - return { - "success": False, - "message": "data_schema_config must contain 'entities' field", - "error": "Missing required field: 'entities'", - } - - if not isinstance(data_schema_config["entities"], list): - return { - "success": False, - "message": "'entities' must be a list", - "error": f"Expected list, got {type(data_schema_config['entities']).__name__}", - } - - if len(data_schema_config["entities"]) == 0: - return { - "success": False, - "message": "'entities' list cannot be empty", - "error": "At least one entity is required", - } - - if entity_labels is not None and not isinstance(entity_labels, dict): - return { - "success": False, - "message": "entity_labels must be a dictionary", - "error": f"Expected dict, got {type(entity_labels).__name__}", - } - - # Validate schema using helper function (Layer 3 validation) - # Pass show_all_errors via schema_dict for _validate_data_schema - schema_dict_with_options = { - **data_schema_config, - "_show_all_errors": show_all_errors, - } - validation_result = _validate_data_schema(schema_dict_with_options) + # Validate schema (all validation handled by _validate_data_schema) + validation_result = _validate_data_schema(data_schema_config) if not validation_result["valid"]: # Return validation errors with enhanced details @@ -497,10 +451,6 @@ async def collect_logs(): "message": validation_result["summary"], } - # Add warnings if present (from comprehensive validation) - if "warnings" in validation_result and validation_result["warnings"]: - response["warnings"] = validation_result["warnings"] - return response elif tool_name == "start_data_schema_generation_workflow": schema_config_id = arguments["schema_config_id"] @@ -670,392 +620,169 @@ def guess_tab_gan_train_config(dataset) -> Tuple[ra.TrainTabGAN.Config, dict]: # Entity Data Generator helpers def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: """ - Validate DataSchema with enhanced error messages for common mistakes. + Validate DataSchema with enhanced error messages. + + Validation is performed in two layers: + 1. SDK validation via rf.converter.structure() - validates types, required fields, enums + 2. validators.py validation - validates business logic (PARAM_TS_07, TS_08, R11, R12) + + All errors are always shown (no limiting). Returns: Dict with validation results: - { - "valid": bool, - "data_schema": DataSchema (if valid), - "error_message": str (if invalid), - "suggestion": str (if invalid, optional), - "reference": str (if invalid, optional), - "summary": str - } + - If valid: + { + "valid": True, + "data_schema": DataSchema, + "summary": str + } + - If invalid: + { + "valid": False, + "error_count": int, + "error_message": str (numbered list of errors), + "summary": str + } """ try: # SDK validates all levels automatically via __attrs_post_init__ methods data_schema = rf.converter.structure(schema_dict, ra.ent.DataSchema) - # Run comprehensive validation (Layer 3) + # Run comprehensive validation via validators.py comprehensive_errors = validate_dataschema_comprehensive(data_schema) if comprehensive_errors: - # Convert ValidationError list to response format - # show_all parameter passed from tool arguments - validation_result = _format_comprehensive_errors( - comprehensive_errors, - schema_dict, - show_all=schema_dict.get("_show_all_errors", False), - ) - - # If valid=True (warnings only), add data_schema for caching - if validation_result["valid"]: - validation_result["data_schema"] = data_schema + # Format errors same way as structure error path + # Show ALL validator errors (no limiting) + error_list = [] + for idx, err in enumerate(comprehensive_errors, 1): + error_list.append(f"{idx}. {err.location}: {err.message}") - return validation_result + return { + "valid": False, + "error_count": len(comprehensive_errors), + "error_message": "\n\n".join(error_list), + "summary": f"Validation failed: {len(comprehensive_errors)} error(s)", + } return { "valid": True, "data_schema": data_schema, "summary": f"Validation successful: {len(schema_dict.get('entities', []))} entities", } - except StructureError as e: - # Parse StructureError message with intelligent error detection - error_msg = str(e) - helpful_error = _detect_common_errors(error_msg, schema_dict) - - return { - "valid": False, - "error_message": helpful_error["detailed_message"], - "suggestion": helpful_error.get("fix_suggestion"), - "reference": helpful_error.get("docs_link"), - "summary": f"Validation failed: {helpful_error['summary']}", - } except Exception as e: - # Catch other validation errors (from __attrs_post_init__) - error_msg = str(e) - helpful_error = _detect_common_errors(error_msg, schema_dict) - - return { - "valid": False, - "error_message": helpful_error["detailed_message"], - "suggestion": helpful_error.get("fix_suggestion"), - "reference": helpful_error.get("docs_link"), - "summary": f"Validation failed: {helpful_error['summary']}", - } - - -def _detect_common_errors( - error_msg: str, schema_dict: Dict[str, Any] -) -> Dict[str, Any]: - """ - Detect common error patterns and provide actionable fix suggestions. - - Common patterns: - 1. Domain structure errors (missing "params" wrapper) - 2. Derivation structure errors (missing "params" wrapper) - 3. Column type/category mismatches - 4. Domain/Derivation mutual exclusivity violations - 5. Missing required fields - 6. Invalid enum values - - Args: - error_msg: The error message from SDK validation - schema_dict: The schema dictionary being validated - - Returns: - Dict with summary, detailed_message, fix_suggestion, and docs_link - """ - import re - - # Pattern 1: Domain validation error - "expected Optional @ $.entities[X].columns[Y].domain" - domain_error_pattern = ( - r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.domain" - ) - match = re.search(domain_error_pattern, error_msg) - - if match: - entity_idx = int(match.group(1)) - col_idx = int(match.group(2)) - - # Extract the problematic domain - try: - entities = schema_dict.get("entities", []) - if entity_idx < len(entities): - entity = entities[entity_idx] - columns = entity.get("columns", []) - if col_idx < len(columns): - column = columns[col_idx] - domain = column.get("domain", {}) - - # Check if domain has wrong structure (missing "params" wrapper) - if isinstance(domain, dict) and "type" in domain: - domain_type = domain.get("type") - # Check if there are extra keys that should be in "params" - extra_keys = [ - k for k in domain.keys() if k not in ["type", "params"] - ] - - if extra_keys and "params" not in domain: - params_str = ", ".join([f'"{k}"' for k in extra_keys]) - return { - "summary": f"Domain structure error in column '{column.get('name', 'unknown')}'", - "detailed_message": ( - f"Domain validation failed for entity '{entity.get('name', 'unknown')}', " - f"column '{column.get('name', 'unknown')}' (entities[{entity_idx}].columns[{col_idx}]).\n\n" - f"ERROR: Domain parameters must be nested under 'params' key.\n\n" - f"The parameters {params_str} should be inside a 'params' dictionary.\n\n" - f"Correct structure:\n" - f' "domain": {{\n' - f' "type": "{domain_type}",\n' - f' "params": {{ {params_str}: ... }}\n' - f" }}" - ), - "fix_suggestion": f'Move {params_str} inside \'params\': {{"type": "{domain_type}", "params": {{...}}}}', - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Domain", - } - except (IndexError, KeyError, TypeError): - pass # Fall through to generic error - - # Pattern 2: Derivation validation error - similar structure - derivation_error_pattern = ( - r"expected Optional @ \$\.entities\[(\d+)\]\.columns\[(\d+)\]\.derivation" - ) - match = re.search(derivation_error_pattern, error_msg) - - if match: - entity_idx = int(match.group(1)) - col_idx = int(match.group(2)) - - try: - entities = schema_dict.get("entities", []) - if entity_idx < len(entities): - entity = entities[entity_idx] - columns = entity.get("columns", []) - if col_idx < len(columns): - column = columns[col_idx] - derivation = column.get("derivation", {}) - - if isinstance(derivation, dict) and "function_type" in derivation: - func_type = derivation.get("function_type") - extra_keys = [ - k - for k in derivation.keys() - if k not in ["function_type", "params", "dependent_columns"] - ] - - if extra_keys and "params" not in derivation: - params_str = ", ".join([f'"{k}"' for k in extra_keys]) - return { - "summary": f"Derivation structure error in column '{column.get('name', 'unknown')}'", - "detailed_message": ( - f"Derivation validation failed for entity '{entity.get('name', 'unknown')}', " - f"column '{column.get('name', 'unknown')}' (entities[{entity_idx}].columns[{col_idx}]).\n\n" - f"ERROR: Derivation parameters must be nested under 'params' key.\n\n" - f"Correct structure:\n" - f' "derivation": {{\n' - f' "function_type": "{func_type}",\n' - f' "dependent_columns": [...],\n' - f' "params": {{ {params_str}: ... }}\n' - f" }}" - ), - "fix_suggestion": f"Move {params_str} inside 'params'", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Derivation", - } - except (IndexError, KeyError, TypeError): - pass - - # Pattern 3: Enum validation errors - if "expected ColumnType" in error_msg: - return { - "summary": "Invalid column_type value", - "detailed_message": ( - f"{error_msg}\n\n" - "Valid values: 'independent', 'stateful', 'derived', 'foreign_key'\n\n" - "Usage guide:\n" - "- 'independent': Columns with non-temporal domains (ID, categorical, distributions)\n" - "- 'stateful': Columns with temporal domains (timeseries, state_machine)\n" - "- 'derived': Columns computed from other columns\n" - "- 'foreign_key': Auto-generated foreign key references" - ), - "fix_suggestion": "Use one of: 'independent', 'stateful', 'derived', 'foreign_key'", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", - } - - if ( - "expected ColumnCategoryType" in error_msg - or "expected CategoryType" in error_msg - ): - return { - "summary": "Invalid column_category_type value", - "detailed_message": ( - f"{error_msg}\n\n" - "Valid values: 'metadata', 'measurement'\n\n" - "IMPORTANT constraints:\n" - "- STATEFUL columns MUST use 'measurement' category\n" - "- FOREIGN_KEY columns MUST use 'metadata' category" - ), - "fix_suggestion": "Use 'metadata' or 'measurement' (stateful columns require 'measurement')", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", - } - - # Pattern 4: Domain AND derivation error - if "cannot have both domain and derivation" in error_msg.lower(): - return { - "summary": "Column has both domain and derivation", - "detailed_message": ( - f"{error_msg}\n\n" - "RULE: Columns must have EITHER domain OR derivation, not both.\n\n" - "- INDEPENDENT/STATEFUL columns: Use 'domain' (no 'derivation')\n" - "- DERIVED columns: Use 'derivation' (no 'domain')\n" - "- FOREIGN_KEY columns: Neither (auto-generated)" - ), - "fix_suggestion": "Remove either 'domain' or 'derivation' based on your column_type", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", - } - - # Pattern 5: Missing domain or derivation - if ( - "must have domain" in error_msg.lower() - or "must have derivation" in error_msg.lower() - ): - return { - "summary": "Missing required domain or derivation", - "detailed_message": ( - f"{error_msg}\n\n" - "Column requirements:\n" - "- INDEPENDENT columns: Must have 'domain'\n" - "- STATEFUL columns: Must have 'domain' (timeseries or state_machine)\n" - "- DERIVED columns: Must have 'derivation'\n" - "- FOREIGN_KEY columns: No domain/derivation needed (auto-generated)" - ), - "fix_suggestion": "Add appropriate 'domain' or 'derivation' based on column_type", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.Column", - } - - # Pattern 6: Missing global_timestamp - if "global_timestamp" in error_msg.lower() and "required" in error_msg.lower(): - return { - "summary": "Missing global_timestamp configuration", - "detailed_message": ( - f"{error_msg}\n\n" - "RULE: If any entity has a timestamp field, the DataSchema must include 'global_timestamp'.\n\n" - "Example:\n" - f' "global_timestamp": {{\n' - f' "start_timestamp": "2024-01-01T00:00:00",\n' - f' "cadence": {{"num_steps": 1000}}\n' - f" }}" - ), - "fix_suggestion": "Add 'global_timestamp' to your DataSchema with start_timestamp and cadence", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html#rockfish.actions.ent.GlobalTimestamp", - } + error_details = extract_structure_error_details(e) - # Pattern 7: Duplicate names - if "duplicate" in error_msg.lower() and "name" in error_msg.lower(): - return { - "summary": "Duplicate names detected", - "detailed_message": ( - f"{error_msg}\n\n" - "RULES:\n" - "- Entity names must be unique across the schema\n" - "- Column names must be unique within each entity" - ), - "fix_suggestion": "Ensure all entity names and column names are unique", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html", - } + # Format errors as numbered list + error_list = [] + for idx, err in enumerate(error_details["errors"], 1): + error_list.append(f"{idx}. {err['location']}: {err['error_message']}") - # Pattern 8: Missing required fields - if "missing" in error_msg.lower() or "required" in error_msg.lower(): return { - "summary": "Missing required field", - "detailed_message": ( - f"{error_msg}\n\n" - "Required fields by type:\n" - "- Column: name, data_type, column_type, column_category_type\n" - "- Entity: name, cardinality, columns (at least one)\n" - "- DataSchema: entities (at least one)\n" - "- Domain: type, params\n" - "- Derivation: function_type, dependent_columns, params" - ), - "fix_suggestion": "Add the missing required field indicated in the error", - "docs_link": "https://docs.rockfish.ai/sdk/actions-ent.html", + "valid": False, + "error_count": error_details["error_count"], + "error_message": "\n\n".join(error_list), + "summary": f"Validation failed: {error_details['summary']}", } - # Generic error (no pattern matched) - return { - "summary": "Validation error", - "detailed_message": error_msg, - "reference": "See https://docs.rockfish.ai/sdk/actions-ent.html for complete schema documentation", - } - -def _format_comprehensive_errors( - errors: list, schema_dict: Dict[str, Any], show_all: bool = False -) -> Dict[str, Any]: +def extract_structure_error_details(exc: Exception) -> Dict[str, Any]: """ - Convert ValidationError objects from validators.py to MCP tool response format. + Extract error message and location from StructureError exception chain. - Only ERROR level blocks validation. WARNING/INFO are advisory. + Simplified version that extracts only essential information: + - error_message: The actual error text from the exception + - location: The JSON path where the error occurred Args: - errors: List of ValidationError objects from validators.py - schema_dict: Original schema dictionary for context - show_all: If True, show all errors. If False, show first 5 (default) + exc: The exception to parse (typically StructureError) Returns: - Dict with keys: valid, error_message, suggestion, reference, summary - OR (if no ERROR level): valid, warnings, data_schema, summary - """ - # Separate by severity - error_level = [e for e in errors if e.level == ValidationLevel.ERROR] - warning_level = [e for e in errors if e.level == ValidationLevel.WARNING] - info_level = [e for e in errors if e.level == ValidationLevel.INFO] - - # Only ERROR level blocks validation - if not error_level: - # No errors - return success with warnings/info as advisory - advisory = [] - for w in warning_level + info_level: - advisory.append( + { + "error_count": N, + "summary": "Found N validation error(s)", + "errors": [ { - "level": str(w.level), - "rule": w.rule, - "message": w.message, - "location": w.location, - "suggestion": w.suggestion, - } - ) - - return { - "valid": True, - "warnings": advisory if advisory else [], - "summary": f"Validation passed with {len(warning_level)} warning(s) and {len(info_level)} info message(s)", + "error_message": "spike_magnitude (5.0) must be in [0, 1]", + "location": "$.entities[0].columns[1]" + }, + ... + ] } - # Has ERROR level - return error response - primary_error = error_level[0] - - # Determine how many errors to show - max_display = len(error_level) if show_all else min(5, len(error_level)) + Example: + >>> try: + ... rf.converter.structure(ts_dict, TimeseriesParams) + ... except Exception as e: + ... details = extract_structure_error_details(e) + ... for err in details['errors']: + ... print(f"{err['location']}: {err['error_message']}") + """ - # Build error details - error_details = [] - for idx, err in enumerate(error_level[:max_display], 1): - error_details.append( - f"{idx}. [{err.level}] {err.rule}: {err.message}\n" - f" Location: {err.location}" - ) + errors = [] + + def collect_errors(current_exc, location="$"): + """Recursively collect error messages and locations.""" + # Check for sub-exceptions (IterableValidationError) + if hasattr(current_exc, "exceptions") and current_exc.exceptions: + # Extract location info from wrapper message + msg = str(current_exc) + index_match = re.search(r"@ index (\d+)", msg) + if index_match: + index = index_match.group(1) + if "list[Column]" in msg: + new_location = f"{location}.columns[{index}]" + elif "list[Entity]" in msg: + new_location = f"{location}.entities[{index}]" + else: + new_location = f"{location}[{index}]" + else: + new_location = location - if len(error_level) > max_display: - error_details.append( - f"\n... and {len(error_level) - max_display} more error(s)" - ) - if not show_all: - error_details.append( - "Tip: Set show_all_errors=true in the tool call to see all errors" - ) + for sub_exc in current_exc.exceptions: + collect_errors(sub_exc, new_location) - detailed_message = "\n\n".join(error_details) + # Check for __cause__ chain + elif hasattr(current_exc, "__cause__") and current_exc.__cause__: + # Extract location from StructureError if present + if type(current_exc).__name__ == "StructureError": + msg = str(current_exc) + loc_match = re.search(r"@ (\$\.[^\s]+)", msg) + if loc_match: + location = loc_match.group(1) - summary = f"Validation failed: {len(error_level)} error(s)" - if warning_level or info_level: - summary += f" (+ {len(warning_level)} warning(s), {len(info_level)} info)" + collect_errors(current_exc.__cause__, location) + # Leaf error - add to list + else: + # Skip wrapper exceptions + if type(current_exc).__name__ not in ( + "StructureError", + "ClassValidationError", + "IterableValidationError", + "ExceptionGroup", + ): + errors.append({"error_message": str(current_exc), "location": location}) + + # Extract top-level location from StructureError + top_location = "$" + if type(exc).__name__ == "StructureError": + top_msg = str(exc) + loc_match = re.search(r"@ (\$\.[^\s]+)", top_msg) + if loc_match: + top_location = loc_match.group(1) + + # Collect all errors + collect_errors(exc, top_location) + + # If no errors found, use original exception + if not errors: + errors = [{"error_message": str(exc), "location": top_location}] + + # Return consistent format + error_word = "error" if len(errors) == 1 else "errors" return { - "valid": False, - "error_message": detailed_message, - "suggestion": primary_error.suggestion if primary_error.suggestion else "", - "reference": "https://docs.rockfish.ai/sdk/actions-ent.html", - "summary": summary, + "error_count": len(errors), + "summary": f"Found {len(errors)} validation {error_word}", + "errors": errors, } diff --git a/src/rockfish_mcp/server.py b/src/rockfish_mcp/server.py index fb884a2..775e306 100644 --- a/src/rockfish_mcp/server.py +++ b/src/rockfish_mcp/server.py @@ -1042,11 +1042,6 @@ async def handle_list_tools() -> List[types.Tool]: "description": "Optional entity label mappings for generated datasets", "additionalProperties": True, }, - "show_all_errors": { - "type": "boolean", - "description": "If true, show all validation errors. If false, show first 5 (default: false)", - "default": False, - }, }, "required": ["data_schema_config"], }, diff --git a/src/rockfish_mcp/validators.py b/src/rockfish_mcp/validators.py index 41e34f3..2ebd5d3 100644 --- a/src/rockfish_mcp/validators.py +++ b/src/rockfish_mcp/validators.py @@ -1,671 +1,212 @@ """ -Comprehensive validation utilities for DataSchema configurations. - -This module provides detailed validation beyond what rf.converter.structure() performs, -checking business logic rules, parameter constraints, and semantic relationships. - -Validation Layers: -1. Structure validation (rf.converter.structure) - type checking, required fields, enums -2. Parameter validation (this module) - ranges, constraints, logical consistency -3. Business logic validation (this module) - R1-R12 rules -4. Graph validation (planner) - circular dependencies - -Usage: - from ent.validators import validate_dataschema_comprehensive - - errors = validate_dataschema_comprehensive(schema) - if errors: - for error in errors: - print(f"[{error.level}] {error.rule}: {error.message}") +Validation layer for DataSchema configurations. + +This module provides business logic validation rules that are NOT covered by the +Rockfish SDK's built-in __attrs_post_init__ methods. These 4 validators catch +errors that would otherwise cause runtime failures during generation. + +The SDK already validates: +- Type checking (Layer 1) +- Required fields +- Enum values +- Basic parameter constraints (min < max, ranges [0,1], etc.) + +This module adds: +- PARAM_TS_07, TS_08: Peak hour range validation [0,23] +- R11: MEASUREMENT→MEASUREMENT dependency prevention +- R12: MEASUREMENT INDEPENDENT restriction """ from dataclasses import dataclass -from enum import Enum from typing import Any from rockfish.actions.ent import ( - CategoricalParams, Column, ColumnCategoryType, ColumnType, DataSchema, - Derivation, - DerivationFunctionType, - Domain, - DomainType, Entity, - EntityRelationship, - EntityRelationshipType, - ExponentialDistParams, - GlobalTimestamp, - IDParams, - MapValuesParams, - NormalDistParams, - StateMachineParams, - TimeseriesParams, - Transition, - UniformDistParams, ) +from rockfish.actions.ent.generate import Domain, DomainType, TimeseriesParams - -class ValidationLevel(str, Enum): - """Severity level of validation error.""" - - ERROR = "ERROR" # Must fix - will cause generation to fail - WARNING = "WARNING" # Should fix - may cause unexpected behavior - INFO = "INFO" # Informational - best practice suggestion +############################################################################### +# Core Validation Types +############################################################################### @dataclass class ValidationError: - """Structured validation error.""" + """Represents a validation error with context. + + All validation errors are treated as blocking errors (ERROR level). + """ - level: ValidationLevel - rule: str # e.g., "R1", "PARAM_UNIFORM_01", "COL_INDEPENDENT_01" + rule: str message: str - location: str # e.g., "entity 'users' > column 'age'" - suggestion: str = "" # Optional fix suggestion + location: str + suggestion: str = "" + + +############################################################################### +# Validator Class +############################################################################### class DataSchemaValidator: - """Comprehensive DataSchema validator.""" + """ + Validates DataSchema configurations with business logic rules. + + Only validates rules that are NOT already covered by the Rockfish SDK: + - PARAM_TS_07: peak_start_hour in [0,23] + - PARAM_TS_08: peak_end_hour in [0,23] + - R11: MEASUREMENT→MEASUREMENT dependency prevention + - R12: MEASUREMENT cannot be INDEPENDENT + """ def __init__(self, schema: DataSchema): self.schema = schema self.errors: list[ValidationError] = [] - self.entity_map = {entity.name: entity for entity in schema.entities} def validate_all(self) -> list[ValidationError]: - """Run all validation checks and return errors.""" - self.errors = [] + """ + Run all validation rules. - # Layer 2: Parameter validation + Returns: + List of ValidationError objects (empty if valid) + """ + # Validate TimeseriesParams peak hours (PARAM_TS_07, PARAM_TS_08) self._validate_domain_params() + + # Validate MEASUREMENT dependencies (R11) self._validate_derivation_params() - # Layer 3: Business logic (R1-R10) + # Validate column business rules (R12) self._validate_business_rules() - # Additional checks - self._validate_entities() - self._validate_relationships() - self._validate_global_timestamp() - return self.errors - # ========================================================================= - # Domain Parameter Validation - # ========================================================================= - def _validate_domain_params(self): - """Validate all domain parameters.""" - for entity in self.schema.entities: - for column in entity.columns: - if column.domain is None: - continue - - loc = f"entity '{entity.name}' > column '{column.name}'" - - if column.domain.type == DomainType.ID: - self._validate_id_params(column.domain.params, loc) - elif column.domain.type == DomainType.CATEGORICAL: - self._validate_categorical_params(column.domain.params, loc) - elif column.domain.type == DomainType.UNIFORM_DIST: - self._validate_uniform_params(column.domain.params, loc) - elif column.domain.type == DomainType.NORMAL_DIST: - self._validate_normal_params(column.domain.params, loc) - elif column.domain.type == DomainType.EXPONENTIAL_DIST: - self._validate_exponential_params(column.domain.params, loc) - elif column.domain.type == DomainType.TIMESERIES: - self._validate_timeseries_params(column.domain.params, loc) - elif column.domain.type == DomainType.STATE_MACHINE: - self._validate_state_machine_params(column.domain.params, loc) - - def _validate_id_params(self, params: IDParams, location: str): - """Validate IDParams: template must contain {id}.""" - if "{id}" not in params.template_str: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_ID_01", - message=f"IDParams template_str must contain '{{id}}' placeholder, got: '{params.template_str}'", - location=location, - suggestion="Use a template like 'USER_{id}' or 'item-{id}'", - ) - ) + """Validate Domain parameters for TimeseriesParams peak hours.""" + for entity_idx, entity in enumerate(self.schema.entities): + for column_idx, column in enumerate(entity.columns): + if column.domain: + location = f"$.entities[{entity_idx}] ({entity.name}) > columns[{column_idx}] ({column.name}) > domain" - def _validate_categorical_params(self, params: CategoricalParams, location: str): - """Validate CategoricalParams: values not empty, weights match length.""" - if not params.values: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_CAT_01", - message="CategoricalParams values list cannot be empty", - location=location, - suggestion="Provide at least one value in the values list", - ) - ) - - if params.weights is not None: - if len(params.weights) != len(params.values): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_CAT_02", - message=f"CategoricalParams weights length ({len(params.weights)}) must match values length ({len(params.values)})", - location=location, - suggestion="Either remove weights or ensure it has the same length as values", - ) - ) - - def _validate_uniform_params(self, params: UniformDistParams, location: str): - """Validate UniformDistParams: lower < upper.""" - if params.lower >= params.upper: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_UNIFORM_01", - message=f"UniformDistParams lower ({params.lower}) must be less than upper ({params.upper})", - location=location, - suggestion=f"Swap the values or use lower={params.upper}, upper={params.lower + (params.upper - params.lower) * 2}", - ) - ) - - def _validate_normal_params(self, params: NormalDistParams, location: str): - """Validate NormalDistParams: std > 0.""" - if params.std <= 0: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_NORMAL_01", - message=f"NormalDistParams std (standard deviation) must be positive, got: {params.std}", - location=location, - suggestion="Use a positive value like std=10.0 or std=1.5", - ) - ) - - def _validate_exponential_params( - self, params: ExponentialDistParams, location: str - ): - """Validate ExponentialDistParams: scale > 0.""" - if params.scale <= 0: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_EXP_01", - message=f"ExponentialDistParams scale must be positive, got: {params.scale}", - location=location, - suggestion="Use a positive value like scale=2.0", - ) - ) + # Only validate TimeseriesParams + if column.domain.type == DomainType.TIMESERIES: + self._validate_timeseries_params(column.domain.params, location) def _validate_timeseries_params(self, params: TimeseriesParams, location: str): - """Validate TimeseriesParams: 6 range and probability checks.""" - # min_value < max_value - if params.min_value >= params.max_value: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_01", - message=f"TimeseriesParams min_value ({params.min_value}) must be less than max_value ({params.max_value})", - location=location, - suggestion="Ensure min_value < base_value < max_value", - ) - ) - - # peak_start_hour < peak_end_hour (only relevant for peak_offpeak) - if params.seasonality_type == "peak_offpeak": - if params.peak_start_hour >= params.peak_end_hour: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_02", - message=f"TimeseriesParams peak_start_hour ({params.peak_start_hour}) must be less than peak_end_hour ({params.peak_end_hour})", - location=location, - suggestion="Use values like peak_start_hour=8, peak_end_hour=22 for business hours", - ) - ) - - # PARAM_TS_07: peak_start_hour must be in [0, 23] - if not (0 <= params.peak_start_hour <= 23): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_07", - message=f"TimeseriesParams peak_start_hour must be in [0, 23], got: {params.peak_start_hour}", - location=location, - suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", - ) - ) - - # PARAM_TS_08: peak_end_hour must be in [0, 23] - if not (0 <= params.peak_end_hour <= 23): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_08", - message=f"TimeseriesParams peak_end_hour must be in [0, 23], got: {params.peak_end_hour}", - location=location, - suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", - ) - ) - - # seasonality_strength in [0, 1] - if not (0.0 <= params.seasonality_strength <= 1.0): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_03", - message=f"TimeseriesParams seasonality_strength must be in [0, 1], got: {params.seasonality_strength}", - location=location, - suggestion="Use a value between 0.0 (no seasonality) and 1.0 (strong seasonality)", - ) - ) - - # noise_level in [0, 1] - if not (0.0 <= params.noise_level <= 1.0): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_04", - message=f"TimeseriesParams noise_level must be in [0, 1], got: {params.noise_level}", - location=location, - suggestion="Use a value between 0.0 (no noise) and 1.0 (high noise)", - ) - ) - - # spike_probability in [0, 1] - if not (0.0 <= params.spike_probability <= 1.0): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_05", - message=f"TimeseriesParams spike_probability must be in [0, 1], got: {params.spike_probability}", - location=location, - suggestion="Use a value between 0.0 (no spikes) and 1.0 (frequent spikes)", - ) - ) - - # spike_magnitude in [0, 1] - if not (0.0 <= params.spike_magnitude <= 1.0): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_TS_06", - message=f"TimeseriesParams spike_magnitude must be in [0, 1], got: {params.spike_magnitude}", - location=location, - suggestion="Use a value between 0.0 (small spikes) and 1.0 (large spikes)", - ) - ) - - def _validate_state_machine_params(self, params: StateMachineParams, location: str): - """Validate StateMachineParams: states, transitions, context variables.""" - # initial_state in states - if params.initial_state not in params.states: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_SM_01", - message=f"StateMachineParams initial_state '{params.initial_state}' not in states list: {params.states}", - location=location, - suggestion=f"Add '{params.initial_state}' to states or use one of: {params.states}", - ) - ) - - # terminal_states all in states - for terminal in params.terminal_states: - if terminal not in params.states: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_SM_02", - message=f"StateMachineParams terminal_state '{terminal}' not in states list: {params.states}", - location=location, - suggestion=f"Add '{terminal}' to states or remove from terminal_states", - ) - ) - - # Validate each transition - for idx, trans in enumerate(params.transitions): - trans_loc = f"{location} > transition {idx}" - self._validate_transition( - trans, params.states, params.context_variables, trans_loc - ) - - def _validate_transition( - self, - trans: Transition, - states: list[str], - context_vars: dict[str, bool], - location: str, - ): - """Validate a single transition.""" - # source in states - if trans.source not in states: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_SM_03", - message=f"Transition source '{trans.source}' not in states list: {states}", - location=location, - suggestion=f"Add '{trans.source}' to states or change transition source", - ) - ) - - # dest in states - if trans.dest not in states: + """ + Validate TimeseriesParams peak hour ranges (SDK doesn't validate these). + + The SDK validates: + - min_value < max_value + - peak_start_hour < peak_end_hour + - seasonality_strength in [0,1] + - noise_level in [0,1] + - spike_probability in [0,1] + - spike_magnitude in [0,1] + + This validator adds: + - PARAM_TS_07: peak_start_hour in [0,23] + - PARAM_TS_08: peak_end_hour in [0,23] + """ + # PARAM_TS_07: peak_start_hour must be in [0, 23] + if params.peak_start_hour is not None and not ( + 0 <= params.peak_start_hour <= 23 + ): self.errors.append( ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_SM_04", - message=f"Transition dest '{trans.dest}' not in states list: {states}", + rule="PARAM_TS_07", + message=f"TimeseriesParams peak_start_hour must be in [0, 23], got: {params.peak_start_hour}", location=location, - suggestion=f"Add '{trans.dest}' to states or change transition dest", + suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", ) ) - # probability in (0, 1] - if not (0.0 < trans.probability <= 1.0): + # PARAM_TS_08: peak_end_hour must be in [0, 23] + if params.peak_end_hour is not None and not (0 <= params.peak_end_hour <= 23): self.errors.append( ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_SM_05", - message=f"Transition probability must be > 0 and <= 1, got: {trans.probability}", + rule="PARAM_TS_08", + message=f"TimeseriesParams peak_end_hour must be in [0, 23], got: {params.peak_end_hour}", location=location, - suggestion="Use a probability value like 0.7 or 0.3", + suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", ) ) - # conditions reference valid context vars - for cond in trans.conditions: - if cond not in context_vars: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_SM_06", - message=f"Transition condition '{cond}' not defined in context_variables: {list(context_vars.keys())}", - location=location, - suggestion=f"Add '{cond}' to context_variables or remove from conditions", - ) - ) - - # context_updates reference valid context vars - for key in trans.context_updates: - if key not in context_vars: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_SM_07", - message=f"Transition context_update key '{key}' not defined in context_variables: {list(context_vars.keys())}", - location=location, - suggestion=f"Add '{key}' to context_variables or remove from context_updates", - ) - ) - - # ========================================================================= - # Derivation Parameter Validation - # ========================================================================= - def _validate_derivation_params(self): - """Validate all derivation parameters.""" - for entity in self.schema.entities: - for column in entity.columns: - if column.derivation is None: - continue - - loc = f"entity '{entity.name}' > column '{column.name}'" + """Validate Derivation parameters and dependencies.""" + for entity_idx, entity in enumerate(self.schema.entities): + for column_idx, column in enumerate(entity.columns): + if column.derivation: + location = f"$.entities[{entity_idx}] ({entity.name}) > columns[{column_idx}] ({column.name}) > derivation" - if column.derivation.function_type == DerivationFunctionType.MAP_VALUES: - self._validate_map_values_params(column.derivation.params, loc) - - # Check for unsupported cross-category MEASUREMENT dependencies - self._validate_measurement_dependencies(entity, column, loc) - - def _validate_map_values_params(self, params: MapValuesParams, location: str): - """Validate MapValuesParams: mapping not empty, rules have from/to.""" - if not params.mapping: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_MAP_01", - message="MapValuesParams mapping list cannot be empty", - location=location, - suggestion='Provide mapping rules like [{"from": "active", "to": "high"}]', - ) - ) - - for idx, rule in enumerate(params.mapping): - if "from" not in rule: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_MAP_02", - message=f"MapValuesParams mapping rule {idx} missing 'from' key", - location=location, - suggestion=f"Add 'from' key to rule: {rule}", - ) - ) - if "to" not in rule: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="PARAM_MAP_03", - message=f"MapValuesParams mapping rule {idx} missing 'to' key", - location=location, - suggestion=f"Add 'to' key to rule: {rule}", - ) - ) + # R11: Check for MEASUREMENT→MEASUREMENT dependencies - unsolved issue? + self._validate_measurement_dependencies(entity, column, location) def _validate_measurement_dependencies( self, entity: Entity, column: Column, location: str ): """ - Validate that MEASUREMENT derived columns don't have same-entity MEASUREMENT dependencies. + R11: MEASUREMENT derived columns cannot depend on same-entity MEASUREMENT columns. - This is currently unsupported and will cause a KeyError at runtime because MEASUREMENT - columns are generated in an arbitrary order when they don't have explicit dependencies - tracked in the column graph. - """ - # Only check MEASUREMENT DERIVED columns - if column.column_category_type != ColumnCategoryType.MEASUREMENT: - return - if column.column_type != ColumnType.DERIVED: - return - if column.derivation is None: - return - - # Build a map of column names to their categories in this entity - entity_columns = {col.name: col for col in entity.columns} - - # Check each dependency - for dep_col_name in column.derivation.dependent_columns: - # Skip cross-entity dependencies (they're fine because dependent entity is generated first) - if "." in dep_col_name: - continue - - # Check if this is a same-entity dependency - dep_col = entity_columns.get(dep_col_name) - if dep_col is None: - # Dependency doesn't exist in this entity - will be caught by other validation - continue - - # Check if the dependency is also a MEASUREMENT column - if dep_col.column_category_type == ColumnCategoryType.MEASUREMENT: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R11", - message=f"MEASUREMENT derived column '{column.name}' cannot depend on another MEASUREMENT column '{dep_col_name}' in the same entity (currently unsupported)", - location=location, - suggestion=f"Change '{dep_col_name}' to column_category_type='metadata', OR restructure to avoid MEASUREMENT->MEASUREMENT dependencies", - ) - ) + This is a critical validator that prevents runtime KeyError during generation. + MEASUREMENT columns are generated in arbitrary order, so dependencies between + them in the same entity will fail. - # ========================================================================= - # Business Logic Validation (R1-R12) - # ========================================================================= - - def _validate_business_rules(self): - """Validate R1-R12 business logic rules.""" - # R1: If any entity has Timestamp → GlobalTimestamp required - entities_with_timestamps = [ - e.name for e in self.schema.entities if e.timestamp is not None - ] - if entities_with_timestamps and self.schema.global_timestamp is None: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R1", - message=f"Entities {entities_with_timestamps} have timestamp, but global_timestamp is not defined", - location="schema root", - suggestion="Add global_timestamp configuration with t_start, t_end, and time_interval", + The SDK does NOT validate this. + """ + if ( + column.column_category_type == ColumnCategoryType.MEASUREMENT + and column.derivation + ): + # Check each dependency + for dep_col_name in column.derivation.dependencies: + # Find dependency column in same entity + dep_col = next( + (col for col in entity.columns if col.name == dep_col_name), None ) - ) - # R2-R5, R10-R12: Column-level rules (validated per column) - for entity in self.schema.entities: - for column in entity.columns: - loc = f"entity '{entity.name}' > column '{column.name}'" - self._validate_column_business_rules(column, loc) - - # R6: Entity with Timestamp → Must have ≥1 measurement column - if entity.timestamp is not None: - has_measurement = any( - col.column_category_type == ColumnCategoryType.MEASUREMENT - for col in entity.columns - ) - if not has_measurement: + if ( + dep_col + and dep_col.column_category_type == ColumnCategoryType.MEASUREMENT + ): self.errors.append( ValidationError( - level=ValidationLevel.ERROR, - rule="R6", - message=f"Entity '{entity.name}' has timestamp but no measurement columns", - location=f"entity '{entity.name}'", - suggestion="Add at least one column with column_category_type='measurement'", + rule="R11", + message=f"MEASUREMENT derived column '{column.name}' cannot depend on another MEASUREMENT column '{dep_col_name}' in the same entity (currently unsupported)", + location=location, + suggestion=f"Change '{dep_col_name}' to column_category_type='metadata', OR restructure to avoid MEASUREMENT->MEASUREMENT dependencies", ) ) - # R7: ONE_TO_ONE relationship → from.cardinality ≤ to.cardinality - if self.schema.entity_relationships: - for rel in self.schema.entity_relationships: - if rel.relationship_type == EntityRelationshipType.ONE_TO_ONE: - from_entity = self.entity_map.get(rel.from_entity) - to_entity = self.entity_map.get(rel.to_entity) - if from_entity and to_entity: - if from_entity.cardinality > to_entity.cardinality: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R7", - message=f"ONE_TO_ONE relationship from '{rel.from_entity}' to '{rel.to_entity}': child cardinality ({from_entity.cardinality}) cannot exceed parent cardinality ({to_entity.cardinality})", - location=f"relationship {rel.from_entity} -> {rel.to_entity}", - suggestion=f"Either increase '{rel.to_entity}' cardinality or reduce '{rel.from_entity}' cardinality", - ) - ) - - # R8: Entity names must be unique - entity_names = [e.name for e in self.schema.entities] - duplicates = [ - name for name in set(entity_names) if entity_names.count(name) > 1 - ] - if duplicates: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R8", - message=f"Duplicate entity names found: {duplicates}", - location="schema root", - suggestion="Rename entities to have unique names", - ) - ) + def _validate_business_rules(self): + """Validate business logic rules.""" + for entity_idx, entity in enumerate(self.schema.entities): + for column_idx, column in enumerate(entity.columns): + location = f"$.entities[{entity_idx}] ({entity.name}) > columns[{column_idx}] ({column.name})" - # R9: Column names unique within entity - for entity in self.schema.entities: - column_names = [c.name for c in entity.columns] - duplicates = [ - name for name in set(column_names) if column_names.count(name) > 1 - ] - if duplicates: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R9", - message=f"Duplicate column names in entity '{entity.name}': {duplicates}", - location=f"entity '{entity.name}'", - suggestion="Rename columns to have unique names within the entity", - ) - ) + # R12: MEASUREMENT cannot be INDEPENDENT - will added in the next release + self._validate_column_business_rules(column, location) def _validate_column_business_rules(self, column: Column, location: str): - """Validate business rules R2-R5, R10-R12 for a single column.""" - # R2: STATEFUL → must be MEASUREMENT - if column.column_type == ColumnType.STATEFUL: - if column.column_category_type != ColumnCategoryType.MEASUREMENT: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R2", - message=f"STATEFUL column must be MEASUREMENT category, got: {column.column_category_type}", - location=location, - suggestion="Change column_category_type to 'measurement'", - ) - ) - - # R3: STATEFUL → domain must be STATE_MACHINE or TIMESERIES - if column.column_type == ColumnType.STATEFUL: - if column.domain is None or column.domain.type not in ( - DomainType.STATE_MACHINE, - DomainType.TIMESERIES, - ): - domain_type = column.domain.type if column.domain else "None" - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R3", - message=f"STATEFUL column must have STATE_MACHINE or TIMESERIES domain, got: {domain_type}", - location=location, - suggestion="Use domain.type='timeseries' or domain.type='state_machine'", - ) - ) - - # R4: INDEPENDENT → domain CANNOT be STATE_MACHINE or TIMESERIES - if column.column_type == ColumnType.INDEPENDENT: - if column.domain and column.domain.type in ( - DomainType.STATE_MACHINE, - DomainType.TIMESERIES, - ): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R4", - message=f"INDEPENDENT column cannot have temporal domain ({column.domain.type})", - location=location, - suggestion="Use a non-temporal domain like 'categorical', 'uniform_dist', or 'id'", - ) - ) + """ + Validate column-level business rules. - # R5: FOREIGN_KEY → must be METADATA - if column.column_type == ColumnType.FOREIGN_KEY: - if column.column_category_type != ColumnCategoryType.METADATA: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R5", - message=f"FOREIGN_KEY column must be METADATA category, got: {column.column_category_type}", - location=location, - suggestion="Change column_category_type to 'metadata'", - ) - ) + R12: MEASUREMENT columns cannot be INDEPENDENT type. - # R12: MEASUREMENT columns cannot be INDEPENDENT (currently unsupported) + This combination is currently unsupported by generation but the SDK allows it. + This validator prevents invalid configurations. + """ + # R12: MEASUREMENT column cannot be INDEPENDENT if ( column.column_category_type == ColumnCategoryType.MEASUREMENT and column.column_type == ColumnType.INDEPENDENT ): self.errors.append( ValidationError( - level=ValidationLevel.ERROR, rule="R12", message=f"MEASUREMENT column cannot be INDEPENDENT (currently unsupported)", location=location, @@ -673,245 +214,66 @@ def _validate_column_business_rules(self, column: Column, location: str): ) ) - # R10: Column has EXACTLY ONE: domain OR derivation OR neither (FK only) - has_domain = column.domain is not None - has_derivation = column.derivation is not None - - if column.column_type in (ColumnType.INDEPENDENT, ColumnType.STATEFUL): - if not has_domain: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R10", - message=f"{column.column_type} column must have domain", - location=location, - suggestion="Add a domain configuration for this column", - ) - ) - if has_derivation: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R10", - message=f"{column.column_type} column cannot have derivation", - location=location, - suggestion="Remove the derivation field", - ) - ) - - elif column.column_type == ColumnType.DERIVED: - if has_domain: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R10", - message="DERIVED column cannot have domain", - location=location, - suggestion="Remove the domain field", - ) - ) - if not has_derivation: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R10", - message="DERIVED column must have derivation", - location=location, - suggestion="Add a derivation configuration for this column", - ) - ) - - elif column.column_type == ColumnType.FOREIGN_KEY: - if has_domain or has_derivation: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="R10", - message="FOREIGN_KEY column cannot have domain or derivation", - location=location, - suggestion="Remove domain and derivation fields (they are auto-populated)", - ) - ) - - # ========================================================================= - # Entity Validation - # ========================================================================= - - def _validate_entities(self): - """Validate entity-level constraints.""" - for entity in self.schema.entities: - loc = f"entity '{entity.name}'" - - # Cardinality must be positive - if entity.cardinality <= 0: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="ENT_01", - message=f"Entity cardinality must be positive, got: {entity.cardinality}", - location=loc, - suggestion="Use a positive integer like cardinality=100", - ) - ) - - # Must have at least one column - if not entity.columns: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="ENT_02", - message="Entity must have at least one column", - location=loc, - suggestion="Add column definitions to this entity", - ) - ) - - # ========================================================================= - # Relationship Validation - # ========================================================================= - - def _validate_relationships(self): - """Validate entity relationship constraints.""" - if not self.schema.entity_relationships: - return - - for rel in self.schema.entity_relationships: - loc = f"relationship {rel.from_entity} -> {rel.to_entity}" - - # join_columns cannot be empty - if not rel.join_columns: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="REL_01", - message="Relationship join_columns cannot be empty", - location=loc, - suggestion='Add join_columns like {"user_id": "user_id"}', - ) - ) - - # from_entity ≠ to_entity - if rel.from_entity == rel.to_entity: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="REL_02", - message=f"Relationship cannot be self-referential: '{rel.from_entity}'", - location=loc, - suggestion="Create relationships between different entities", - ) - ) - - # from_entity and to_entity must exist - if rel.from_entity not in self.entity_map: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="REL_03", - message=f"Relationship references unknown from_entity: '{rel.from_entity}'", - location=loc, - suggestion=f"Use one of: {list(self.entity_map.keys())}", - ) - ) - - if rel.to_entity not in self.entity_map: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="REL_04", - message=f"Relationship references unknown to_entity: '{rel.to_entity}'", - location=loc, - suggestion=f"Use one of: {list(self.entity_map.keys())}", - ) - ) - - # ========================================================================= - # GlobalTimestamp Validation - # ========================================================================= - - def _validate_global_timestamp(self): - """Validate GlobalTimestamp constraints.""" - if self.schema.global_timestamp is None: - return - - gt = self.schema.global_timestamp - loc = "global_timestamp" - - # time_interval format validation (already done in config.py __attrs_post_init__) - # Just add a reminder check - import re - - pattern = r"^\d+(min|hour|day|month)$" - if not re.match(pattern, gt.time_interval): - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="GT_01", - message=f"GlobalTimestamp time_interval format invalid: '{gt.time_interval}'", - location=loc, - suggestion="Use format like '15min', '1hour', '1day', or '3months'", - ) - ) - - # GT_02: t_start < t_end - if gt.t_start >= gt.t_end: - self.errors.append( - ValidationError( - level=ValidationLevel.ERROR, - rule="GT_02", - message=f"GlobalTimestamp t_start ({gt.t_start}) must be less than t_end ({gt.t_end})", - location=loc, - suggestion="Ensure start time is before end time", - ) - ) - -# ============================================================================= +############################################################################### # Public API -# ============================================================================= +############################################################################### -def validate_dataschema_comprehensive( - schema: DataSchema, -) -> list[ValidationError]: +def validate_dataschema_comprehensive(schema: DataSchema) -> list[ValidationError]: """ - Run comprehensive validation on a DataSchema. + Comprehensive validation of DataSchema configuration. + + This function validates business logic rules that are NOT covered by the + Rockfish SDK's built-in __attrs_post_init__ methods. - This performs validation beyond rf.converter.structure(), checking: - - Parameter ranges and constraints (30 rules) - - Business logic rules (R1-R12) - - Semantic relationships + Validates: + - PARAM_TS_07: peak_start_hour in [0,23] + - PARAM_TS_08: peak_end_hour in [0,23] + - R11: MEASUREMENT→MEASUREMENT dependency prevention (prevents runtime KeyError) + - R12: MEASUREMENT INDEPENDENT restriction (unsupported combination) + + The SDK already validates everything else (types, required fields, enums, + parameter ranges, entity constraints, etc.). Args: - schema: DataSchema object to validate + schema: DataSchema object to validate (already structured by SDK) Returns: List of ValidationError objects (empty if valid) Example: + >>> schema = rf.converter.structure(schema_dict, DataSchema) >>> errors = validate_dataschema_comprehensive(schema) >>> if errors: ... for err in errors: - ... print(f"[{err.rule}] {err.message}") - ... else: - ... print("Schema is valid!") + ... print(f"{err.rule}: {err.message}") """ validator = DataSchemaValidator(schema) return validator.validate_all() def format_validation_errors(errors: list[ValidationError]) -> str: - """Format validation errors as a readable report.""" - if not errors: - return "✅ Schema validation passed!" + """ + Format validation errors for display. - report = [f"❌ Found {len(errors)} validation error(s):\n"] + All errors are treated as ERROR level (blocking). + + Args: + errors: List of ValidationError objects + + Returns: + Formatted string with all errors + """ + if not errors: + return "No validation errors" + lines = [] for idx, err in enumerate(errors, 1): - report.append(f"{idx}. [{err.level}] {err.rule}: {err.message}") - report.append(f" Location: {err.location}") + lines.append( + f"{idx}. [ERROR] {err.rule}: {err.message}\n" f" Location: {err.location}" + ) if err.suggestion: - report.append(f" Suggestion: {err.suggestion}") - report.append("") + lines.append(f" Suggestion: {err.suggestion}") - return "\n".join(report) + return "\n\n".join(lines) From 1f4746b0016cf18d22977d47a4bef9cf82c28976 Mon Sep 17 00:00:00 2001 From: Fan Date: Fri, 5 Dec 2025 12:53:45 -0800 Subject: [PATCH 08/10] comment incorrect and unused functions for validators.py --- src/rockfish_mcp/validators.py | 78 +++++++++++++++++----------------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/src/rockfish_mcp/validators.py b/src/rockfish_mcp/validators.py index 2ebd5d3..d8869ca 100644 --- a/src/rockfish_mcp/validators.py +++ b/src/rockfish_mcp/validators.py @@ -78,7 +78,7 @@ def validate_all(self) -> list[ValidationError]: self._validate_domain_params() # Validate MEASUREMENT dependencies (R11) - self._validate_derivation_params() + # self._validate_derivation_params() # Validate column business rules (R12) self._validate_business_rules() @@ -96,6 +96,7 @@ def _validate_domain_params(self): if column.domain.type == DomainType.TIMESERIES: self._validate_timeseries_params(column.domain.params, location) + # Will be handled in https://github.com/Rockfish-Data/cuttlefish/pull/1102 def _validate_timeseries_params(self, params: TimeseriesParams, location: str): """ Validate TimeseriesParams peak hour ranges (SDK doesn't validate these). @@ -145,43 +146,44 @@ def _validate_derivation_params(self): # R11: Check for MEASUREMENT→MEASUREMENT dependencies - unsolved issue? self._validate_measurement_dependencies(entity, column, location) - - def _validate_measurement_dependencies( - self, entity: Entity, column: Column, location: str - ): - """ - R11: MEASUREMENT derived columns cannot depend on same-entity MEASUREMENT columns. - - This is a critical validator that prevents runtime KeyError during generation. - MEASUREMENT columns are generated in arbitrary order, so dependencies between - them in the same entity will fail. - - The SDK does NOT validate this. - """ - if ( - column.column_category_type == ColumnCategoryType.MEASUREMENT - and column.derivation - ): - # Check each dependency - for dep_col_name in column.derivation.dependencies: - # Find dependency column in same entity - dep_col = next( - (col for col in entity.columns if col.name == dep_col_name), None - ) - - if ( - dep_col - and dep_col.column_category_type == ColumnCategoryType.MEASUREMENT - ): - self.errors.append( - ValidationError( - rule="R11", - message=f"MEASUREMENT derived column '{column.name}' cannot depend on another MEASUREMENT column '{dep_col_name}' in the same entity (currently unsupported)", - location=location, - suggestion=f"Change '{dep_col_name}' to column_category_type='metadata', OR restructure to avoid MEASUREMENT->MEASUREMENT dependencies", - ) - ) - + # This is incorrect. The correct validation is added in https://github.com/Rockfish-Data/cuttlefish/pull/1099 + # def _validate_measurement_dependencies( + # self, entity: Entity, column: Column, location: str + # ): + # """ + # R11: MEASUREMENT derived columns cannot depend on same-entity MEASUREMENT columns. + + # This is a critical validator that prevents runtime KeyError during generation. + # MEASUREMENT columns are generated in arbitrary order, so dependencies between + # them in the same entity will fail. + + # The SDK does NOT validate this. + # """ + # if ( + # column.column_category_type == ColumnCategoryType.MEASUREMENT + # and column.derivation + # ): + # # Check each dependency + # for dep_col_name in column.derivation.dependencies: + # # Find dependency column in same entity + # dep_col = next( + # (col for col in entity.columns if col.name == dep_col_name), None + # ) + + # if ( + # dep_col + # and dep_col.column_category_type == ColumnCategoryType.MEASUREMENT + # ): + # self.errors.append( + # ValidationError( + # rule="R11", + # message=f"MEASUREMENT derived column '{column.name}' cannot depend on another MEASUREMENT column '{dep_col_name}' in the same entity (currently unsupported)", + # location=location, + # suggestion=f"Change '{dep_col_name}' to column_category_type='metadata', OR restructure to avoid MEASUREMENT->MEASUREMENT dependencies", + # ) + # ) + + # This is handled in https://github.com/Rockfish-Data/cuttlefish/pull/1097 def _validate_business_rules(self): """Validate business logic rules.""" for entity_idx, entity in enumerate(self.schema.entities): From 893f073951ddbf10eeddbc8fb25ff13f867f4853 Mon Sep 17 00:00:00 2001 From: Fan Date: Fri, 5 Dec 2025 17:43:44 -0800 Subject: [PATCH 09/10] remove validators.py- these extra validations should be added in rockfish sdk --- src/rockfish_mcp/sdk_client.py | 26 +-- src/rockfish_mcp/validators.py | 281 --------------------------------- 2 files changed, 2 insertions(+), 305 deletions(-) delete mode 100644 src/rockfish_mcp/validators.py diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index cd5cee2..ea9f28b 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -23,9 +23,6 @@ logger = logging.getLogger(__name__) -# Import validators for comprehensive DataSchema validation -from rockfish_mcp.validators import ValidationError, validate_dataschema_comprehensive - class RockfishSDKClient: def __init__( @@ -425,7 +422,7 @@ async def collect_logs(): cache_entry = { "data_schema_config": data_schema_config, } - # TODO: entity_labels? + if entity_labels: cache_entry["entity_labels"] = entity_labels @@ -466,9 +463,7 @@ async def collect_logs(): # Retrieve and remove from cache cache_entry = self._cache.pop(schema_config_id) data_schema_dict = cache_entry["data_schema_config"] - entity_labels = cache_entry.get( - "entity_labels" - ) # TODO: make use of entity_labels in GenerateFromDataSchema + entity_labels = cache_entry.get("entity_labels") # Convert dict to DataSchema (should not fail since we already validated) try: @@ -648,23 +643,6 @@ def _validate_data_schema(schema_dict: Dict[str, Any]) -> Dict[str, Any]: # SDK validates all levels automatically via __attrs_post_init__ methods data_schema = rf.converter.structure(schema_dict, ra.ent.DataSchema) - # Run comprehensive validation via validators.py - comprehensive_errors = validate_dataschema_comprehensive(data_schema) - - if comprehensive_errors: - # Format errors same way as structure error path - # Show ALL validator errors (no limiting) - error_list = [] - for idx, err in enumerate(comprehensive_errors, 1): - error_list.append(f"{idx}. {err.location}: {err.message}") - - return { - "valid": False, - "error_count": len(comprehensive_errors), - "error_message": "\n\n".join(error_list), - "summary": f"Validation failed: {len(comprehensive_errors)} error(s)", - } - return { "valid": True, "data_schema": data_schema, diff --git a/src/rockfish_mcp/validators.py b/src/rockfish_mcp/validators.py deleted file mode 100644 index d8869ca..0000000 --- a/src/rockfish_mcp/validators.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Validation layer for DataSchema configurations. - -This module provides business logic validation rules that are NOT covered by the -Rockfish SDK's built-in __attrs_post_init__ methods. These 4 validators catch -errors that would otherwise cause runtime failures during generation. - -The SDK already validates: -- Type checking (Layer 1) -- Required fields -- Enum values -- Basic parameter constraints (min < max, ranges [0,1], etc.) - -This module adds: -- PARAM_TS_07, TS_08: Peak hour range validation [0,23] -- R11: MEASUREMENT→MEASUREMENT dependency prevention -- R12: MEASUREMENT INDEPENDENT restriction -""" - -from dataclasses import dataclass -from typing import Any - -from rockfish.actions.ent import ( - Column, - ColumnCategoryType, - ColumnType, - DataSchema, - Entity, -) -from rockfish.actions.ent.generate import Domain, DomainType, TimeseriesParams - -############################################################################### -# Core Validation Types -############################################################################### - - -@dataclass -class ValidationError: - """Represents a validation error with context. - - All validation errors are treated as blocking errors (ERROR level). - """ - - rule: str - message: str - location: str - suggestion: str = "" - - -############################################################################### -# Validator Class -############################################################################### - - -class DataSchemaValidator: - """ - Validates DataSchema configurations with business logic rules. - - Only validates rules that are NOT already covered by the Rockfish SDK: - - PARAM_TS_07: peak_start_hour in [0,23] - - PARAM_TS_08: peak_end_hour in [0,23] - - R11: MEASUREMENT→MEASUREMENT dependency prevention - - R12: MEASUREMENT cannot be INDEPENDENT - """ - - def __init__(self, schema: DataSchema): - self.schema = schema - self.errors: list[ValidationError] = [] - - def validate_all(self) -> list[ValidationError]: - """ - Run all validation rules. - - Returns: - List of ValidationError objects (empty if valid) - """ - # Validate TimeseriesParams peak hours (PARAM_TS_07, PARAM_TS_08) - self._validate_domain_params() - - # Validate MEASUREMENT dependencies (R11) - # self._validate_derivation_params() - - # Validate column business rules (R12) - self._validate_business_rules() - - return self.errors - - def _validate_domain_params(self): - """Validate Domain parameters for TimeseriesParams peak hours.""" - for entity_idx, entity in enumerate(self.schema.entities): - for column_idx, column in enumerate(entity.columns): - if column.domain: - location = f"$.entities[{entity_idx}] ({entity.name}) > columns[{column_idx}] ({column.name}) > domain" - - # Only validate TimeseriesParams - if column.domain.type == DomainType.TIMESERIES: - self._validate_timeseries_params(column.domain.params, location) - - # Will be handled in https://github.com/Rockfish-Data/cuttlefish/pull/1102 - def _validate_timeseries_params(self, params: TimeseriesParams, location: str): - """ - Validate TimeseriesParams peak hour ranges (SDK doesn't validate these). - - The SDK validates: - - min_value < max_value - - peak_start_hour < peak_end_hour - - seasonality_strength in [0,1] - - noise_level in [0,1] - - spike_probability in [0,1] - - spike_magnitude in [0,1] - - This validator adds: - - PARAM_TS_07: peak_start_hour in [0,23] - - PARAM_TS_08: peak_end_hour in [0,23] - """ - # PARAM_TS_07: peak_start_hour must be in [0, 23] - if params.peak_start_hour is not None and not ( - 0 <= params.peak_start_hour <= 23 - ): - self.errors.append( - ValidationError( - rule="PARAM_TS_07", - message=f"TimeseriesParams peak_start_hour must be in [0, 23], got: {params.peak_start_hour}", - location=location, - suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", - ) - ) - - # PARAM_TS_08: peak_end_hour must be in [0, 23] - if params.peak_end_hour is not None and not (0 <= params.peak_end_hour <= 23): - self.errors.append( - ValidationError( - rule="PARAM_TS_08", - message=f"TimeseriesParams peak_end_hour must be in [0, 23], got: {params.peak_end_hour}", - location=location, - suggestion="Use a valid hour value between 0 (midnight) and 23 (11 PM)", - ) - ) - - def _validate_derivation_params(self): - """Validate Derivation parameters and dependencies.""" - for entity_idx, entity in enumerate(self.schema.entities): - for column_idx, column in enumerate(entity.columns): - if column.derivation: - location = f"$.entities[{entity_idx}] ({entity.name}) > columns[{column_idx}] ({column.name}) > derivation" - - # R11: Check for MEASUREMENT→MEASUREMENT dependencies - unsolved issue? - self._validate_measurement_dependencies(entity, column, location) - # This is incorrect. The correct validation is added in https://github.com/Rockfish-Data/cuttlefish/pull/1099 - # def _validate_measurement_dependencies( - # self, entity: Entity, column: Column, location: str - # ): - # """ - # R11: MEASUREMENT derived columns cannot depend on same-entity MEASUREMENT columns. - - # This is a critical validator that prevents runtime KeyError during generation. - # MEASUREMENT columns are generated in arbitrary order, so dependencies between - # them in the same entity will fail. - - # The SDK does NOT validate this. - # """ - # if ( - # column.column_category_type == ColumnCategoryType.MEASUREMENT - # and column.derivation - # ): - # # Check each dependency - # for dep_col_name in column.derivation.dependencies: - # # Find dependency column in same entity - # dep_col = next( - # (col for col in entity.columns if col.name == dep_col_name), None - # ) - - # if ( - # dep_col - # and dep_col.column_category_type == ColumnCategoryType.MEASUREMENT - # ): - # self.errors.append( - # ValidationError( - # rule="R11", - # message=f"MEASUREMENT derived column '{column.name}' cannot depend on another MEASUREMENT column '{dep_col_name}' in the same entity (currently unsupported)", - # location=location, - # suggestion=f"Change '{dep_col_name}' to column_category_type='metadata', OR restructure to avoid MEASUREMENT->MEASUREMENT dependencies", - # ) - # ) - - # This is handled in https://github.com/Rockfish-Data/cuttlefish/pull/1097 - def _validate_business_rules(self): - """Validate business logic rules.""" - for entity_idx, entity in enumerate(self.schema.entities): - for column_idx, column in enumerate(entity.columns): - location = f"$.entities[{entity_idx}] ({entity.name}) > columns[{column_idx}] ({column.name})" - - # R12: MEASUREMENT cannot be INDEPENDENT - will added in the next release - self._validate_column_business_rules(column, location) - - def _validate_column_business_rules(self, column: Column, location: str): - """ - Validate column-level business rules. - - R12: MEASUREMENT columns cannot be INDEPENDENT type. - - This combination is currently unsupported by generation but the SDK allows it. - This validator prevents invalid configurations. - """ - # R12: MEASUREMENT column cannot be INDEPENDENT - if ( - column.column_category_type == ColumnCategoryType.MEASUREMENT - and column.column_type == ColumnType.INDEPENDENT - ): - self.errors.append( - ValidationError( - rule="R12", - message=f"MEASUREMENT column cannot be INDEPENDENT (currently unsupported)", - location=location, - suggestion="Change column_type to 'stateful' for time-varying measurements, or change column_category_type to 'metadata' for static values", - ) - ) - - -############################################################################### -# Public API -############################################################################### - - -def validate_dataschema_comprehensive(schema: DataSchema) -> list[ValidationError]: - """ - Comprehensive validation of DataSchema configuration. - - This function validates business logic rules that are NOT covered by the - Rockfish SDK's built-in __attrs_post_init__ methods. - - Validates: - - PARAM_TS_07: peak_start_hour in [0,23] - - PARAM_TS_08: peak_end_hour in [0,23] - - R11: MEASUREMENT→MEASUREMENT dependency prevention (prevents runtime KeyError) - - R12: MEASUREMENT INDEPENDENT restriction (unsupported combination) - - The SDK already validates everything else (types, required fields, enums, - parameter ranges, entity constraints, etc.). - - Args: - schema: DataSchema object to validate (already structured by SDK) - - Returns: - List of ValidationError objects (empty if valid) - - Example: - >>> schema = rf.converter.structure(schema_dict, DataSchema) - >>> errors = validate_dataschema_comprehensive(schema) - >>> if errors: - ... for err in errors: - ... print(f"{err.rule}: {err.message}") - """ - validator = DataSchemaValidator(schema) - return validator.validate_all() - - -def format_validation_errors(errors: list[ValidationError]) -> str: - """ - Format validation errors for display. - - All errors are treated as ERROR level (blocking). - - Args: - errors: List of ValidationError objects - - Returns: - Formatted string with all errors - """ - if not errors: - return "No validation errors" - - lines = [] - for idx, err in enumerate(errors, 1): - lines.append( - f"{idx}. [ERROR] {err.rule}: {err.message}\n" f" Location: {err.location}" - ) - if err.suggestion: - lines.append(f" Suggestion: {err.suggestion}") - - return "\n\n".join(lines) From 26ef9312f533d8991913a2740ddc3dd253a15836 Mon Sep 17 00:00:00 2001 From: Fan Date: Wed, 10 Dec 2025 21:26:10 -0800 Subject: [PATCH 10/10] remove unused code after updates --- src/rockfish_mcp/sdk_client.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/rockfish_mcp/sdk_client.py b/src/rockfish_mcp/sdk_client.py index ea9f28b..92a9b2c 100644 --- a/src/rockfish_mcp/sdk_client.py +++ b/src/rockfish_mcp/sdk_client.py @@ -406,14 +406,6 @@ async def collect_logs(): "message": validation_result["summary"], "error": validation_result["error_message"], } - # Add optional fields if present - if ( - "suggestion" in validation_result - and validation_result["suggestion"] - ): - error_response["suggestion"] = validation_result["suggestion"] - if "reference" in validation_result and validation_result["reference"]: - error_response["reference"] = validation_result["reference"] return error_response