diff --git a/Dockerfile b/Dockerfile index 1e32046e..20f9552a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,13 @@ FROM tiangolo/uwsgi-nginx-flask:python3.12 AS builder -RUN apt-get update && apt-get install -y gcc +RUN apt-get update && apt-get install -y gcc curl ca-certificates gnupg lsb-release \ + && install -d /usr/share/postgresql-common/pgdg \ + && curl -fsSL https://www.postgresql.org/media/keys/ACCC4CF8.asc \ + -o /usr/share/postgresql-common/pgdg/apt.postgresql.org.asc \ + && sh -c 'echo "deb [signed-by=/usr/share/postgresql-common/pgdg/apt.postgresql.org.asc] \ + https://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" \ + > /etc/apt/sources.list.d/pgdg.list' \ + && apt-get update \ + && apt-get install -y postgresql-client-18 RUN pip install uv # Enable bytecode compilation ENV UV_COMPILE_BYTECODE=1 diff --git a/materializationengine/blueprints/upload/api.py b/materializationengine/blueprints/upload/api.py index 1a0c8d76..b3547e1d 100644 --- a/materializationengine/blueprints/upload/api.py +++ b/materializationengine/blueprints/upload/api.py @@ -48,7 +48,7 @@ get_job_status, process_and_upload, ) -from materializationengine.database import db_manager +from materializationengine.database import db_manager, dynamic_annotation_cache from materializationengine.info_client import get_datastack_info, get_datastacks from materializationengine.utils import get_config_param from materializationengine import __version__ @@ -297,25 +297,29 @@ def create_storage_service(): @upload_bp.route("/generate-presigned-url/", methods=["POST"]) @auth_requires_permission("edit", table_arg="datastack_name") def generate_presigned_url(datastack_name: str): - data = request.json - filename = data["filename"] - content_type = data["contentType"] - bucket_name = current_app.config.get("MATERIALIZATION_UPLOAD_BUCKET_PATH") - storage_client = storage.Client() - bucket = storage_client.bucket(bucket_name) - blob = bucket.blob(filename) - origin = 
request.headers.get("Origin") or current_app.config.get( - "LOCAL_SERVER_URL", "http://localhost:5000" - ) - try: + data = request.json + if not data: + return jsonify({"status": "error", "message": "Request body must be JSON"}), 400 + filename = data["filename"] + content_type = data["contentType"] + bucket_name = current_app.config.get("MATERIALIZATION_UPLOAD_BUCKET_PATH") + if not bucket_name: + return jsonify({"status": "error", "message": "Upload bucket is not configured"}), 500 + storage_client = storage.Client() + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(filename) + origin = request.headers.get("Origin") or current_app.config.get( + "LOCAL_SERVER_URL", "http://localhost:5000" + ) + try: resumable_url = blob.create_resumable_upload_session( content_type=content_type, - origin=origin, # Allow cross-origin requests for uploads - timeout=3600, # Set the session timeout to 1 hour + origin=origin, + timeout=3600, ) - return jsonify({"resumableUrl": resumable_url, "origin": origin}) + return jsonify({"resumableUrl": resumable_url, "origin": origin}) + except KeyError as e: + return jsonify({"status": "error", "message": f"Missing required field: {e}"}), 400 except google_exceptions.Forbidden as e: current_app.logger.error( f"GCS Forbidden error generating presigned URL: {str(e)}" ) @@ -552,6 +556,73 @@ def save_metadata(): 400, ) + # Check whether the table name already exists in production before + # allowing the user to proceed with the upload. + table_name = data["table_name"] + datastack_name = data["datastack_name"] + force_overwrite = data.get("force_overwrite", False) + try: + datastack_info = get_datastack_info(datastack_name) + production_db_name = datastack_info["aligned_volume"]["name"] + production_db_client = dynamic_annotation_cache.get_db(production_db_name) + existing_meta = production_db_client.database.get_table_metadata(table_name) + if existing_meta: + return ( + jsonify( + { + "status": "error", + "message": f"Table '{table_name}' already exists in production. 
" + f"Choose a different name.", + } + ), + 409, + ) + except Exception as check_err: + current_app.logger.warning( + f"Could not check production for existing table '{table_name}': {check_err}" + ) + + # Check whether the table already exists in staging (from a previous failed run). + # If so and the user has not explicitly requested an overwrite, surface the conflict + # so the frontend can ask for confirmation before wiping any existing staging data. + if not force_overwrite: + try: + staging_db_name = current_app.config.get("STAGING_DATABASE_NAME") + if staging_db_name: + staging_db_client = dynamic_annotation_cache.get_db(staging_db_name) + staging_meta = staging_db_client.database.get_table_metadata(table_name) + if staging_meta: + # Count existing rows so the UI can show a meaningful message. + staging_engine = db_manager.get_engine(staging_db_name) + try: + from sqlalchemy import text as _text + with staging_engine.connect() as conn: + row_count = conn.execute( + _text(f"SELECT COUNT(*) FROM {table_name}") + ).scalar() + except Exception: + row_count = None + return ( + jsonify( + { + "status": "staging_conflict", + "staging_exists": True, + "row_count": row_count, + "message": ( + f"Table '{table_name}' already exists in staging" + + (f" with {row_count:,} rows" if row_count is not None else "") + + " from a previous run. " + + "Do you want to clear it and restart the upload?" 
+ ), + } + ), + 409, + ) + except Exception as check_err: + current_app.logger.warning( + f"Could not check staging for existing table '{table_name}': {check_err}" + ) + success, result = storage.save_metadata( filename=data["table_name"], metadata=data ) @@ -671,19 +742,26 @@ def start_csv_processing(): @auth_required def check_processing_status(job_id): """Get processing job status""" - status = get_job_status(job_id) - if not status: - return jsonify({"status": "error", "message": "Job not found"}), 404 + try: + status = get_job_status(job_id) + if not status: + return jsonify({"status": "error", "message": "Job not found"}), 404 - if _check_authorization(status): - return jsonify({"status": "error", "message": "Forbidden"}), 403 + if _check_authorization(status): + return jsonify({"status": "error", "message": "Forbidden"}), 403 - _set_item_type(status) + _set_item_type(status) - if status.get("active_workflow_part") == "spatial_lookup": - _handle_spatial_lookup(status, job_id) + if status.get("active_workflow_part") == "spatial_lookup": + _handle_spatial_lookup(status, job_id) - return jsonify(status) + return jsonify(status) + except Exception as e: + current_app.logger.error( + f"Unexpected error in check_processing_status for job {job_id}: {e}", + exc_info=True, + ) + return jsonify({"status": "error", "message": "Internal error fetching job status"}), 500 def _check_authorization(status): diff --git a/materializationengine/blueprints/upload/checkpoint_manager.py b/materializationengine/blueprints/upload/checkpoint_manager.py index c8326cc1..fef7ba32 100644 --- a/materializationengine/blueprints/upload/checkpoint_manager.py +++ b/materializationengine/blueprints/upload/checkpoint_manager.py @@ -85,6 +85,7 @@ class WorkflowData: chunking_parameters: Optional[dict] = None current_pending_scan_cursor: Optional[int] = 0 submitted_chunks: int = 0 + recovery_attempts: int = 0 @property def progress(self) -> float: @@ -115,6 +116,15 @@ def 
_get_chunk_failed_details_key(self, table_name: str) -> str: def _get_retryable_chunks_set_key(self, table_name: str) -> str: return f"{self.workflow_prefix}{table_name}:failed_retryable_chunks" + def _get_processing_subtasks_timestamps_key(self, table_name: str) -> str: + return f"{self.workflow_prefix}{table_name}:processing_subtasks_timestamps" + + def _get_processing_timestamps_key(self, table_name: str) -> str: + return f"{self.workflow_prefix}{table_name}:processing_timestamps" + + def _get_dispatch_params_key(self, table_name: str) -> str: + return f"{self.workflow_prefix}{table_name}:dispatch_params" + def get_bbox_hash(self, bbox: Union[np.ndarray, List]) -> str: """Generate hash for bounding box.""" bbox_list = bbox.tolist() if isinstance(bbox, np.ndarray) else bbox @@ -288,13 +298,16 @@ def get_processing_rate(self, table_name: str) -> Optional[str]: return None def reset_chunk_statuses_and_details(self, table_name: str): - """Deletes chunk_statuses, chunk_failed_details, and failed_retryable_set keys for the table.""" + """Deletes chunk_statuses, chunk_failed_details, failed_retryable_set, and stale-recovery timestamp keys for the table.""" chunk_statuses_key = self._get_chunk_statuses_key(table_name) chunk_failed_details_key = self._get_chunk_failed_details_key(table_name) retryable_set_key = self._get_retryable_chunks_set_key(table_name) + processing_subtasks_ts_key = self._get_processing_subtasks_timestamps_key(table_name) + processing_ts_key = self._get_processing_timestamps_key(table_name) try: REDIS_CLIENT.delete( - chunk_statuses_key, chunk_failed_details_key, retryable_set_key + chunk_statuses_key, chunk_failed_details_key, retryable_set_key, + processing_subtasks_ts_key, processing_ts_key, ) celery_logger.info( f"Reset chunk statuses, details, and retryable set for table: {table_name}" @@ -342,11 +355,34 @@ def set_chunk_status( old_status_bytes.decode("utf-8") if old_status_bytes else None ) + processing_subtasks_ts_key = 
self._get_processing_subtasks_timestamps_key(table_name) + processing_ts_key = self._get_processing_timestamps_key(table_name) pipe.multi() pipe.hset(chunk_statuses_key, str(chunk_index), status) pipe.expire(chunk_statuses_key, self.expiry_time) + # Track when chunks enter PROCESSING for stale-chunk recovery + now_iso = datetime.datetime.now(datetime.timezone.utc).isoformat() + if status == CHUNK_STATUS_PROCESSING: + pipe.hset(processing_ts_key, str(chunk_index), now_iso) + pipe.expire(processing_ts_key, self.expiry_time) + else: + # Chunk left PROCESSING — remove its timestamp entry + pipe.hdel(processing_ts_key, str(chunk_index)) + + # Track when chunks enter PROCESSING_SUBTASKS for stale-chunk recovery + if status == CHUNK_STATUS_PROCESSING_SUBTASKS: + pipe.hset( + processing_subtasks_ts_key, + str(chunk_index), + now_iso, + ) + pipe.expire(processing_subtasks_ts_key, self.expiry_time) + else: + # Chunk left PROCESSING_SUBTASKS — remove its timestamp entry + pipe.hdel(processing_subtasks_ts_key, str(chunk_index)) + current_time_iso = datetime.datetime.now( datetime.timezone.utc ).isoformat() @@ -645,6 +681,191 @@ def get_chunks_to_process( return chunks_to_process, None, new_pending_cursor + def recover_stale_processing_subtasks( + self, table_name: str, stale_threshold_seconds: int = 600 + ) -> int: + """ + Scans chunks in PROCESSING_SUBTASKS and marks any that have been there + longer than stale_threshold_seconds as FAILED_RETRYABLE so the dispatcher + can re-dispatch them. + + Also recovers chunks in PROCESSING_SUBTASKS that have no timestamp entry + (e.g. from before timestamp tracking was deployed, or from a failed write) + — these are treated as immediately stale since their age is unknown. + + Returns the number of chunks recovered. 
+ """ + ts_key = self._get_processing_subtasks_timestamps_key(table_name) + retryable_set_key = self._get_retryable_chunks_set_key(table_name) + chunk_statuses_key = self._get_chunk_statuses_key(table_name) + recovered = 0 + try: + all_timestamps = REDIS_CLIENT.hgetall(ts_key) + all_statuses = REDIS_CLIENT.hgetall(chunk_statuses_key) + + now = datetime.datetime.now(datetime.timezone.utc) + cutoff = now - datetime.timedelta(seconds=stale_threshold_seconds) + + # First pass: recover chunks in PROCESSING_SUBTASKS with no timestamp entry. + # These have unknown age (e.g. entered before timestamp tracking was deployed) + # and should be treated as immediately stale. + for chunk_idx_bytes, status_bytes in all_statuses.items(): + if status_bytes.decode("utf-8") != CHUNK_STATUS_PROCESSING_SUBTASKS: + continue + if chunk_idx_bytes in all_timestamps: + continue # Has a timestamp — handled in second pass + chunk_idx_str = chunk_idx_bytes.decode("utf-8") + try: + celery_logger.warning( + f"Chunk {chunk_idx_str} for '{table_name}' is in PROCESSING_SUBTASKS " + f"with no timestamp recorded (unknown age). Marking FAILED_RETRYABLE." + ) + REDIS_CLIENT.hset(chunk_statuses_key, chunk_idx_str, CHUNK_STATUS_FAILED_RETRYABLE) + REDIS_CLIENT.sadd(retryable_set_key, chunk_idx_str) + recovered += 1 + except Exception as e_inner: + celery_logger.error( + f"Error recovering timestamp-less chunk {chunk_idx_str} for {table_name}: {e_inner}" + ) + + # Second pass: recover chunks whose timestamp has exceeded the stale threshold. 
+ for chunk_idx_bytes, ts_bytes in all_timestamps.items(): + chunk_idx_str = chunk_idx_bytes.decode("utf-8") + try: + entered_at = datetime.datetime.fromisoformat(ts_bytes.decode("utf-8")) + if entered_at > cutoff: + continue # Still within the grace period + + # Verify the chunk is still in PROCESSING_SUBTASKS (not already resolved) + current_status_bytes = REDIS_CLIENT.hget(chunk_statuses_key, chunk_idx_str) + if current_status_bytes is None: + REDIS_CLIENT.hdel(ts_key, chunk_idx_str) + continue + current_status = current_status_bytes.decode("utf-8") + if current_status != CHUNK_STATUS_PROCESSING_SUBTASKS: + REDIS_CLIENT.hdel(ts_key, chunk_idx_str) + continue + + age_seconds = (now - entered_at).total_seconds() + celery_logger.warning( + f"Chunk {chunk_idx_str} for '{table_name}' has been in " + f"PROCESSING_SUBTASKS for {age_seconds:.0f}s (threshold {stale_threshold_seconds}s). " + f"Marking FAILED_RETRYABLE for re-dispatch." + ) + REDIS_CLIENT.hset(chunk_statuses_key, chunk_idx_str, CHUNK_STATUS_FAILED_RETRYABLE) + REDIS_CLIENT.sadd(retryable_set_key, chunk_idx_str) + REDIS_CLIENT.hdel(ts_key, chunk_idx_str) + recovered += 1 + except Exception as e_inner: + celery_logger.error( + f"Error processing stale-chunk entry {chunk_idx_str} for {table_name}: {e_inner}" + ) + except Exception as e: + celery_logger.error( + f"Error in recover_stale_processing_subtasks for {table_name}: {e}" + ) + return recovered + + def recover_stale_processing_chunks( + self, table_name: str, stale_threshold_seconds: int = 600 + ) -> int: + """ + Scans chunks in PROCESSING state and marks any that have been there + longer than stale_threshold_seconds as FAILED_RETRYABLE so the dispatcher + can re-dispatch them. + + Also recovers chunks in PROCESSING that have no timestamp entry (e.g. from + before timestamp tracking was deployed) — treated as immediately stale. + + Returns the number of chunks recovered. 
+ """ + ts_key = self._get_processing_timestamps_key(table_name) + chunk_statuses_key = self._get_chunk_statuses_key(table_name) + retryable_set_key = self._get_retryable_chunks_set_key(table_name) + recovered = 0 + try: + all_timestamps = REDIS_CLIENT.hgetall(ts_key) + all_statuses = REDIS_CLIENT.hgetall(chunk_statuses_key) + + now = datetime.datetime.now(datetime.timezone.utc) + cutoff = now - datetime.timedelta(seconds=stale_threshold_seconds) + + # First pass: recover PROCESSING chunks with no timestamp (unknown age → stale). + for chunk_idx_bytes, status_bytes in all_statuses.items(): + if status_bytes.decode("utf-8") != CHUNK_STATUS_PROCESSING: + continue + if chunk_idx_bytes in all_timestamps: + continue + chunk_idx_str = chunk_idx_bytes.decode("utf-8") + try: + celery_logger.warning( + f"Chunk {chunk_idx_str} for '{table_name}' is in PROCESSING " + f"with no timestamp recorded (unknown age). Marking FAILED_RETRYABLE." + ) + REDIS_CLIENT.hset(chunk_statuses_key, chunk_idx_str, CHUNK_STATUS_FAILED_RETRYABLE) + REDIS_CLIENT.sadd(retryable_set_key, chunk_idx_str) + recovered += 1 + except Exception as e_inner: + celery_logger.error( + f"Error recovering no-timestamp PROCESSING chunk {chunk_idx_str} for {table_name}: {e_inner}" + ) + + # Second pass: recover PROCESSING chunks whose timestamp is past the cutoff. 
+ for chunk_idx_bytes, ts_bytes in all_timestamps.items(): + chunk_idx_str = chunk_idx_bytes.decode("utf-8") + try: + entered_at = datetime.datetime.fromisoformat(ts_bytes.decode("utf-8")) + if entered_at > cutoff: + continue + + current_status_bytes = REDIS_CLIENT.hget(chunk_statuses_key, chunk_idx_str) + if current_status_bytes is None: + REDIS_CLIENT.hdel(ts_key, chunk_idx_str) + continue + current_status = current_status_bytes.decode("utf-8") + if current_status != CHUNK_STATUS_PROCESSING: + REDIS_CLIENT.hdel(ts_key, chunk_idx_str) + continue + + age_seconds = (now - entered_at).total_seconds() + celery_logger.warning( + f"Chunk {chunk_idx_str} for '{table_name}' has been in " + f"PROCESSING for {age_seconds:.0f}s (threshold {stale_threshold_seconds}s). " + f"Marking FAILED_RETRYABLE for re-dispatch." + ) + REDIS_CLIENT.hset(chunk_statuses_key, chunk_idx_str, CHUNK_STATUS_FAILED_RETRYABLE) + REDIS_CLIENT.sadd(retryable_set_key, chunk_idx_str) + REDIS_CLIENT.hdel(ts_key, chunk_idx_str) + recovered += 1 + except Exception as e_inner: + celery_logger.error( + f"Error processing stale-processing entry {chunk_idx_str} for {table_name}: {e_inner}" + ) + except Exception as e: + celery_logger.error( + f"Error in recover_stale_processing_chunks for {table_name}: {e}" + ) + return recovered + + def set_dispatch_params(self, table_name: str, params: dict) -> None: + """Store the parameters needed to re-dispatch process_table_in_chunks for recovery.""" + key = self._get_dispatch_params_key(table_name) + try: + REDIS_CLIENT.set(key, json.dumps(params), ex=self.expiry_time) + except Exception as e: + celery_logger.error(f"Error storing dispatch params for {table_name}: {e}") + + def get_dispatch_params(self, table_name: str) -> Optional[dict]: + """Retrieve stored dispatch parameters for process_table_in_chunks.""" + key = self._get_dispatch_params_key(table_name) + try: + data = REDIS_CLIENT.get(key) + if data: + return json.loads(data) + except Exception as e: + 
celery_logger.error(f"Error retrieving dispatch params for {table_name}: {e}") + return None + def get_all_chunk_statuses(self, table_name: str) -> Optional[Dict[str, str]]: """Gets all chunk statuses for a table.""" chunk_statuses_key = self._get_chunk_statuses_key(table_name) diff --git a/materializationengine/blueprints/upload/gcs_processor.py b/materializationengine/blueprints/upload/gcs_processor.py index ec2937c6..79416282 100644 --- a/materializationengine/blueprints/upload/gcs_processor.py +++ b/materializationengine/blueprints/upload/gcs_processor.py @@ -145,7 +145,7 @@ def process_csv_in_chunks( except Exception as e: raise ValueError( f"Failed to process and upload file to GCS using blob.open: {str(e)}" - ) + ) from e def _download_in_chunks(self, blob: storage.Blob) -> Iterator[pd.DataFrame]: """ diff --git a/materializationengine/blueprints/upload/processor.py b/materializationengine/blueprints/upload/processor.py index b1fd86b4..96f49d48 100644 --- a/materializationengine/blueprints/upload/processor.py +++ b/materializationengine/blueprints/upload/processor.py @@ -28,6 +28,7 @@ def __init__( reference_table: str = None, column_mapping: Dict[str, str] = None, ignored_columns: List[str] = None, + id_counter_start: int = 0, ): """ Initialize processor with schema name and optional column mapping @@ -50,7 +51,7 @@ def __init__( self.reverse_mapping = {v: k for k, v in self.column_mapping.items()} self.ignored_columns = set(ignored_columns or []) self.generate_ids = "id" not in self.column_mapping - self._id_counter = 0 + self._id_counter = id_counter_start if self.is_reference and reference_table is None: raise ValueError( @@ -59,6 +60,7 @@ def __init__( self.spatial_points = self._get_spatial_point_fields() self.required_spatial_points = self._get_required_fields() + self.dropped_rows = 0 # cumulative count of rows dropped due to invalid/missing required fields table_metadata = ( {"reference_table": reference_table} if self.is_reference else None @@ 
-183,10 +185,16 @@ def process_spatial_point( else: # Handle separate x,y,z columns coords = [float(row[col]) for col in coordinate_cols] - + if len(coords) != 3: raise ValueError(f"Expected 3 coordinates, got {len(coords)}") - + + # NaN coordinates cannot be stored as a valid PointZ — Shapely/GEOS drops + # the Z dimension for all-NaN points, producing a 2D WKB that PostgreSQL + # rejects on a PointZ column. Treat missing coordinates as NULL instead. + if any(np.isnan(c) for c in coords): + return None + point = Point(coords) return create_wkt_element(point) @@ -221,42 +229,52 @@ def process_chunk(self, chunk: pd.DataFrame, timestamp: datetime) -> str: for field_name, coordinate_cols in self.spatial_points: if all(col in chunk.columns for col in coordinate_cols): - coordinates = chunk[coordinate_cols].astype(float).values - points = [Point(coords) for coords in coordinates] processed_data[f"{field_name}_position"] = [ - create_wkt_element(point) if not any(pd.isna(coords)) else "" - for point, coords in zip(points, coordinates) + self.process_spatial_point(row, coordinate_cols) + for _, row in chunk.iterrows() ] else: if field_name in self.required_spatial_points: raise ValueError( f"Missing coordinates for required spatial point: {field_name}" ) - processed_data[f"{field_name}_position"] = [""] * chunk_size + processed_data[f"{field_name}_position"] = [None] * chunk_size for field_name, field in self.schema._declared_fields.items(): if not isinstance(field, mm.fields.Nested): csv_col = self._get_mapped_column(field_name) if csv_col in chunk.columns and csv_col not in self.ignored_columns: if isinstance(field, mm.fields.Int): - processed_data[field_name] = ( - chunk[csv_col].fillna(0).astype(int) - ) + processed_data[field_name] = chunk[csv_col].astype("Int64") elif isinstance(field, mm.fields.Float): - processed_data[field_name] = ( - chunk[csv_col].fillna(0.0).astype(float) - ) + processed_data[field_name] = chunk[csv_col].astype(float) elif isinstance(field, 
mm.fields.Bool): - processed_data[field_name] = ( - chunk[csv_col].fillna(False).astype(bool) - ) + processed_data[field_name] = chunk[csv_col].astype("boolean") else: - processed_data[field_name] = chunk[csv_col].fillna("") + processed_data[field_name] = chunk[csv_col] df = pd.DataFrame(processed_data) for col in self.column_order: if col not in df.columns: - df[col] = [""] * chunk_size + df[col] = [None] * chunk_size + + # Drop rows where any required spatial point is NULL (missing or NaN coordinates). + # These rows cannot be meaningfully stored and would cause DB errors. + required_position_cols = [ + f"{field_name}_position" + for field_name in self.required_spatial_points + if f"{field_name}_position" in df.columns + ] + if required_position_cols: + invalid_mask = df[required_position_cols].isnull().any(axis=1) + n_dropped = int(invalid_mask.sum()) + if n_dropped: + logger.warning( + f"Dropping {n_dropped} row(s) with missing required spatial coordinates " + f"(columns: {required_position_cols})" + ) + self.dropped_rows += n_dropped + df = df[~invalid_mask] return df[self.column_order] diff --git a/materializationengine/blueprints/upload/tasks.py b/materializationengine/blueprints/upload/tasks.py index 34f5c85c..8a2856dd 100644 --- a/materializationengine/blueprints/upload/tasks.py +++ b/materializationengine/blueprints/upload/tasks.py @@ -2,7 +2,7 @@ import os import subprocess from datetime import datetime, timezone -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import shlex import pandas as pd @@ -13,6 +13,7 @@ from redis import Redis from sqlalchemy import text from dynamicannotationdb.key_utils import build_segmentation_table_name +from dynamicannotationdb.models import SegmentationMetadata from materializationengine.blueprints.upload.gcs_processor import GCSCsvProcessor from materializationengine.blueprints.upload.processor import SchemaProcessor @@ -178,6 +179,7 @@ def process_csv( reference_table: str = None, 
ignored_columns: List[str] = None, chunk_size: int = 10000, + id_counter_start: int = 0, ) -> Dict[str, Any]: """Process CSV file in chunks using GCSCsvProcessor""" try: @@ -198,6 +200,7 @@ def process_csv( reference_table, column_mapping=column_mapping, ignored_columns=ignored_columns, + id_counter_start=id_counter_start, ) bucket_name = current_app.config.get("MATERIALIZATION_UPLOAD_BUCKET_PATH") @@ -243,10 +246,28 @@ def progress_callback(progress_details: Dict[str, Any]): chunk_upload_callback=progress_callback, ) + last_assigned_id = schema_processor._id_counter + dropped_rows = schema_processor.dropped_rows + + status_update: Dict[str, Any] = {"last_assigned_id": last_assigned_id} + if dropped_rows: + celery_logger.warning( + f"CSV processing dropped {dropped_rows} row(s) with missing required " + f"spatial coordinates for job {job_id_for_status}." + ) + status_update["dropped_rows"] = dropped_rows + status_update["warning"] = ( + f"{dropped_rows:,} row(s) were skipped because their required spatial " + f"coordinates were missing or invalid." + ) + update_job_status(job_id_for_status, status_update) + return { "status": "completed_csv_processing", "output_path": f"{bucket_name}/{destination_blob_name}", "job_id_for_status": job_id_for_status, + "last_assigned_id": last_assigned_id, + "dropped_rows": dropped_rows, } except Exception as e: celery_logger.error( @@ -350,6 +371,7 @@ def upload_to_database( }, ) db_client = dynamic_annotation_cache.get_db(staging_database) + force_overwrite = file_metadata["metadata"].get("force_overwrite", False) try: db_client.annotation.create_table( @@ -369,6 +391,21 @@ def upload_to_database( except Exception as e: celery_logger.error(f"Error creating table: {str(e)}") + # If the user confirmed overwrite, clear any data left from a previous run + # before importing so we don't accumulate duplicate IDs. 
+ if force_overwrite: + try: + with db_client.database.engine.begin() as conn: + # CASCADE truncates any FK-dependent tables (e.g. the segmentation table) + conn.execute(text(f"TRUNCATE TABLE {table_name} CASCADE")) + celery_logger.info( + f"force_overwrite=True: truncated staging data for '{table_name}' (CASCADE)" + ) + except Exception as trunc_err: + celery_logger.warning( + f"Could not truncate staging table '{table_name}' for overwrite: {trunc_err}" + ) + update_job_status( job_id_for_status, {"phase": "Uploading to Database: Dropping Indices", "progress": 10}, @@ -530,7 +567,9 @@ def upload_to_database( }, ) - raise + raise RuntimeError( + f"gcloud sql import csv failed: {e.stderr or str(e)}" + ) from e except subprocess.TimeoutExpired as e: celery_logger.error(f"Subprocess timed out: {e}") update_job_status( @@ -591,6 +630,7 @@ def upload_to_database( "active_workflow_part": "spatial_lookup", "total_rows": total_rows_from_csv, "processed_rows": processed_rows_from_csv, + "spatial_lookup_config": spatial_lookup_config, }, ) @@ -617,7 +657,7 @@ def upload_to_database( raise @celery.task( - name="process:monitor_spatial_workflow_completion", bind=True, max_retries=None + name="orchestration:monitor_spatial_workflow_completion", bind=True, max_retries=None ) def monitor_spatial_workflow_completion( self, @@ -704,7 +744,7 @@ def monitor_spatial_workflow_completion( }, "job_id_for_status": job_id_for_status, } - elif current_workflow_status in {CHUNK_STATUS_ERROR, "failed"}: + elif current_workflow_status in {CHUNK_STATUS_ERROR, CHUNK_STATUS_FAILED_PERMANENT, "failed"}: err_msg = ( f"Spatial workflow '{workflow_to_monitor}' has terminally FAILED with status " f"'{current_workflow_status}'. Last Error recorded: {workflow_data.last_error}" @@ -720,6 +760,33 @@ def monitor_spatial_workflow_completion( ) raise Exception(err_msg) else: + # Refresh the Redis key TTL and update progress details. 
+ # Without this, long spatial lookups (>1 hour) cause the key to expire, + # making the status endpoint return 404 and breaking the UI poller. + try: + update_job_status( + job_id_for_status, + { + "status": "processing", + "phase": ( + f"Spatial Lookup: {current_workflow_status} " + f"({workflow_data.completed_chunks}/{workflow_data.total_chunks} chunks)" + ), + "progress": round(workflow_data.progress, 2), + "active_workflow_part": "spatial_lookup", + "spatial_lookup_config": { + "table_name": workflow_to_monitor, + "database_name": db_for_monitor, + "datastack_name": datastack_info.get("datastack", ""), + }, + "total_rows": workflow_data.total_chunks, + "processed_rows": workflow_data.completed_chunks, + }, + ) + except Exception as status_err: + celery_logger.warning( + f"{log_prefix} Failed to refresh job status during retry: {status_err}" + ) celery_logger.info( f"{log_prefix} Workflow '{workflow_to_monitor}' status is '{current_workflow_status}'. " @@ -743,18 +810,25 @@ def transfer_to_production( table_name_to_transfer = monitor_result["table_name"] materialization_time_stamp_str = monitor_result["materialization_time_stamp"] spatial_workflow_status = monitor_result.get("spatial_workflow_final_status", "UNKNOWN") - + job_id_for_ui = monitor_result.get("job_id_for_status") + try: materialization_time_stamp_dt = datetime.fromisoformat(materialization_time_stamp_str) except ValueError: materialization_time_stamp_dt = datetime.strptime(materialization_time_stamp_str, '%Y-%m-%d %H:%M:%S.%f') - + + if job_id_for_ui: + update_job_status(job_id_for_ui, { + "status": "processing", + "phase": "Preparing Transfer to Production", + "progress": 0, + }) celery_logger.info( f"Executing transfer_to_production for table: '{table_name_to_transfer}'. " f"Spatial workflow final status: '{spatial_workflow_status}'." 
) - if "workflow_details" in monitor_result: + if "workflow_details" in monitor_result: celery_logger.info(f"Spatial workflow details: {monitor_result['workflow_details']}") staging_schema_name = get_config_param("STAGING_DATABASE_NAME") @@ -777,21 +851,34 @@ def transfer_to_production( needs_segmentation_table = staging_db_client.schema.is_segmentation_table_required(schema_type) - production_table_exists = False + production_engine = db_manager.get_engine(production_schema_name) + staging_engine = db_manager.get_engine(staging_schema_name) + + if job_id_for_ui: + update_job_status(job_id_for_ui, { + "status": "processing", + "phase": "Creating Annotation Table in Production", + "progress": 5, + }) + + # get_table_metadata returns None (no exception) when the table is not found, + # so we must check the return value rather than relying on exception handling. try: - production_db_client.database.get_table_metadata(table_name_to_transfer) - production_table_exists = True - celery_logger.info( - f"Annotation table '{table_name_to_transfer}' already exists in production schema '{production_schema_name}'." + _meta = production_db_client.database.get_table_metadata(table_name_to_transfer) + metadata_exists = bool(_meta) + except Exception: + metadata_exists = False + + with production_engine.connect() as conn: + physical_table_exists = production_engine.dialect.has_table( + conn, table_name_to_transfer ) - except Exception: - production_table_exists = False + + if not metadata_exists: + # Fresh table — create both metadata record and physical table celery_logger.info( - f"Annotation table '{table_name_to_transfer}' does not exist in production schema '{production_schema_name}'. Will create." 
+ f"Creating annotation table '{table_name_to_transfer}' in production schema '{production_schema_name}'" ) - - if not production_table_exists: - celery_logger.info(f"Creating annotation table '{table_name_to_transfer}' in production schema '{production_schema_name}'") production_db_client.annotation.create_table( table_name=table_name_to_transfer, schema_type=schema_type, @@ -803,32 +890,60 @@ def transfer_to_production( table_metadata=table_metadata_from_staging.get("table_metadata"), flat_segmentation_source=table_metadata_from_staging.get("flat_segmentation_source"), write_permission=table_metadata_from_staging.get("write_permission", "PRIVATE"), - read_permission=table_metadata_from_staging.get("read_permission", "PRIVATE"), + read_permission=table_metadata_from_staging.get("read_permission", "PRIVATE"), notice_text=table_metadata_from_staging.get("notice_text"), ) + elif not physical_table_exists: + # Orphaned metadata from a previous failed run — recreate only the physical table + celery_logger.warning( + f"Metadata record exists for '{table_name_to_transfer}' in production but physical " + f"table is missing (likely from a previous failed transfer). Recreating physical table." + ) + model = production_db_client.schema.create_annotation_model( + table_name_to_transfer, schema_type + ) + model.__table__.create(production_engine, checkfirst=True) + else: + celery_logger.info( + f"Annotation table '{table_name_to_transfer}' already exists in production schema '{production_schema_name}'." 
+ ) - - production_engine = db_manager.get_engine(production_schema_name) - db_url_obj = production_engine.url + + db_url_obj = production_engine.url db_connection_info_for_cli = get_db_connection_info(db_url_obj) + if job_id_for_ui: + update_job_status(job_id_for_ui, { + "status": "processing", + "phase": "Transferring Annotation Data", + "progress": 10, + }) + celery_logger.info(f"Transferring data for annotation table '{table_name_to_transfer}'") annotation_rows_transferred = transfer_table_using_pg_dump( table_name=table_name_to_transfer, - source_db=staging_schema_name, - target_db=production_schema_name, - db_info=db_connection_info_for_cli, + source_db=staging_schema_name, + target_db=production_schema_name, + db_info=db_connection_info_for_cli, drop_indices=True, rebuild_indices=True, engine=production_engine, + source_engine=staging_engine, model_creator=lambda: production_db_client.schema.create_annotation_model( table_name_to_transfer, schema_type ), - job_id=monitor_result.get("job_id_for_status"), + job_id=job_id_for_ui, + table_label="Annotation Table", ) + if job_id_for_ui: + update_job_status(job_id_for_ui, { + "status": "processing", + "phase": f"Annotation Data Transferred ({annotation_rows_transferred:,} rows)", + "progress": 70, + }) segmentation_transfer_results = None if transfer_segmentation and needs_segmentation_table: @@ -846,23 +961,36 @@ def transfer_to_production( if staging_segmentation_exists: celery_logger.info(f"Preparing to transfer segmentation table '{segmentation_table_name}'") - + + if job_id_for_ui: + update_job_status(job_id_for_ui, { + "status": "processing", + "phase": "Creating Segmentation Table in Production", + "progress": 80, + }) + mat_metadata_for_segmentation_table = { "annotation_table_name": table_name_to_transfer, "segmentation_table_name": segmentation_table_name, - "schema_type": schema_type, - "database": production_schema_name, - "aligned_volume": production_schema_name, + "schema": schema_type, # 
create_segmentation_model reads "schema", not "schema_type" + "database": production_schema_name, + "aligned_volume": production_schema_name, "pcg_table_name": pcg_table_name, - "last_updated": materialization_time_stamp_dt, + "last_updated": materialization_time_stamp_dt, "voxel_resolution_x": table_metadata_from_staging.get("voxel_resolution_x", 1.0), "voxel_resolution_y": table_metadata_from_staging.get("voxel_resolution_y", 1.0), "voxel_resolution_z": table_metadata_from_staging.get("voxel_resolution_z", 1.0), } - - create_missing_segmentation_table(mat_metadata_for_segmentation_table, db_client=production_db_client) - + create_missing_segmentation_table(mat_metadata_for_segmentation_table) + + if job_id_for_ui: + update_job_status(job_id_for_ui, { + "status": "processing", + "phase": "Transferring Segmentation Data", + "progress": 85, + }) + celery_logger.info(f"Transferring data for segmentation table '{segmentation_table_name}'") segmentation_rows_transferred = transfer_table_using_pg_dump( table_name=segmentation_table_name, @@ -872,8 +1000,27 @@ def transfer_to_production( drop_indices=True, rebuild_indices=True, engine=production_engine, + source_engine=staging_engine, model_creator=lambda: create_segmentation_model(mat_metadata_for_segmentation_table), + job_id=job_id_for_ui, + table_label="Segmentation Table", ) + + # Record the materialization timestamp used for spatial root ID lookup. + # This is set AFTER transfer so it accurately reflects the state of the data. + # Also handles the case where the metadata record pre-existed (e.g. retry). 
+ with db_manager.session_scope(production_schema_name) as session: + seg_meta = ( + session.query(SegmentationMetadata) + .filter(SegmentationMetadata.table_name == segmentation_table_name) + .one() + ) + seg_meta.last_updated = materialization_time_stamp_dt + celery_logger.info( + f"Set last_updated={materialization_time_stamp_dt} on segmentation " + f"metadata for '{segmentation_table_name}'" + ) + segmentation_transfer_results = { "name": segmentation_table_name, "success": True, @@ -888,6 +1035,13 @@ def transfer_to_production( } + if job_id_for_ui: + update_job_status(job_id_for_ui, { + "status": "done", + "phase": "Transfer Complete", + "progress": 100, + }) + return { "status": "success", "message": f"Transfer completed for table '{table_name_to_transfer}'.", @@ -903,6 +1057,19 @@ def transfer_to_production( except Exception as e: celery_logger.error(f"Error during transfer_to_production for table '{monitor_result.get('table_name', 'UNKNOWN')}': {str(e)}", exc_info=True) + job_id_for_ui = monitor_result.get("job_id_for_status") + if job_id_for_ui: + try: + update_job_status( + job_id_for_ui, + { + "status": "error", + "phase": "Transfer to Production Failed", + "error": str(e), + }, + ) + except Exception as update_err: + celery_logger.error(f"Failed to update job status after transfer error: {update_err}") raise def get_db_connection_info(db_url): @@ -935,8 +1102,10 @@ def transfer_table_using_pg_dump( drop_indices: bool = True, rebuild_indices: bool = True, engine=None, + source_engine=None, model_creator=None, - job_id: str = None, + job_id: Optional[str] = None, + table_label: Optional[str] = None, ) -> int: """ Transfer a table using pg_dump and psql. 
@@ -948,7 +1117,8 @@ def transfer_table_using_pg_dump( db_info: Dictionary with database connection information drop_indices: Whether to drop indices before transfer rebuild_indices: Whether to rebuild indices after transfer - engine: SQLAlchemy engine (required if drop_indices or rebuild_indices is True) + engine: SQLAlchemy engine for the target DB (required if drop_indices or rebuild_indices is True) + source_engine: SQLAlchemy engine for the source DB (used for row count validation) model_creator: Function that returns the SQLAlchemy model (required if rebuild_indices is True) Returns: @@ -962,13 +1132,29 @@ def transfer_table_using_pg_dump( if rebuild_indices and model_creator is None: raise ValueError("model_creator is required when rebuild_indices is True") + label = table_label or table_name + if drop_indices: celery_logger.info(f"Dropping indexes on {table_name}") + if job_id: + update_job_status(job_id, {"status": "processing", "phase": f"Dropping Indices: {label}"}) index_cache.drop_table_indices(table_name, engine) + # Count source rows before transfer for post-transfer validation + source_row_count = None + if source_engine is not None: + with source_engine.connect() as conn: + source_row_count = conn.execute( + text(f"SELECT COUNT(*) FROM {table_name}") + ).scalar() + celery_logger.info(f"Source row count for {table_name}: {source_row_count}") + with engine.begin() as conn: conn.execute(text(f"TRUNCATE TABLE {table_name}")) + if job_id: + update_job_status(job_id, {"status": "processing", "phase": f"Copying Data: {label}"}) + # Build pg_dump and psql commands pg_dump_cmd = [ "pg_dump", @@ -986,6 +1172,7 @@ def transfer_table_using_pg_dump( f'--port={db_info["port"]}', f'--username={db_info["user"]}', f"--dbname={target_db}", + "--single-transaction", # makes the entire COPY import atomic; rolls back on failure ] pg_env = os.environ.copy() @@ -997,20 +1184,46 @@ def transfer_table_using_pg_dump( celery_logger.info(f"Running pg_dump command: 
{shlex.join(pg_dump_cmd)}") celery_logger.info(f"Running psql command: {shlex.join(psql_cmd)}") + pg_dump_stderr = "" with subprocess.Popen( pg_dump_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=pg_env ) as dump_proc: - result = subprocess.run( - psql_cmd, - stdin=dump_proc.stdout, - capture_output=True, - text=True, - timeout=1200, - env=pg_env, + try: + result = subprocess.run( + psql_cmd, + stdin=dump_proc.stdout, + capture_output=True, + text=True, + timeout=1200, + env=pg_env, + ) + # Drain pg_dump stderr and wait for it to exit cleanly + dump_proc.stdout.close() + pg_dump_stderr = dump_proc.stderr.read().decode("utf-8", errors="replace") + dump_proc.wait() + except Exception: + # Close stdout before killing so pg_dump doesn't block on + # a full pipe buffer, then kill to ensure it exits before + # Popen.__exit__ calls wait() — preventing a deadlock when + # pg_dump suppresses SIGPIPE (as libpq does). + dump_proc.stdout.close() + dump_proc.kill() + raise + + if dump_proc.returncode != 0: + raise RuntimeError( + f"pg_dump exited with code {dump_proc.returncode}: {pg_dump_stderr}" + ) + if pg_dump_stderr: + celery_logger.warning(f"pg_dump stderr: {pg_dump_stderr}") + + celery_logger.info(f"psql output: {result.stdout}") + if result.stderr: + celery_logger.warning(f"psql stderr: {result.stderr}") + if result.returncode != 0: + raise RuntimeError( + f"psql exited with code {result.returncode}: {result.stderr}" ) - celery_logger.info(f"psql output: {result.stdout}") - if result.stderr: - celery_logger.warning(f"psql stderr: {result.stderr}") except subprocess.CalledProcessError as e: celery_logger.error(f"pg_dump/psql error: {e}") if job_id: @@ -1038,11 +1251,23 @@ def transfer_table_using_pg_dump( ) raise + if job_id: + update_job_status(job_id, {"status": "processing", "phase": f"Verifying Row Count: {label}"}) + with engine.connect() as conn: row_count = conn.execute(text(f"SELECT COUNT(*) FROM {table_name}")).scalar() + 
celery_logger.info(f"Destination row count for {table_name}: {row_count}") + if source_row_count is not None and row_count != source_row_count: + raise RuntimeError( + f"Row count mismatch after transfer of '{table_name}': " + f"source had {source_row_count} rows but destination has {row_count} rows" + ) + if rebuild_indices: celery_logger.info(f"Rebuilding indexes on {table_name}") + if job_id: + update_job_status(job_id, {"status": "processing", "phase": f"Rebuilding Indices: {label}"}) model = model_creator() indices = index_cache.add_indices_sql_commands(table_name, model, engine) for index in indices: diff --git a/materializationengine/celery_worker.py b/materializationengine/celery_worker.py index e8ce9bf4..b0fcd2f6 100644 --- a/materializationengine/celery_worker.py +++ b/materializationengine/celery_worker.py @@ -8,7 +8,7 @@ import redis from celery.app.builtins import add_backend_cleanup_task from celery.schedules import crontab -from celery.signals import after_setup_logger +from celery.signals import after_setup_logger, worker_process_init from celery.utils.log import get_task_logger from dateutil import relativedelta from marshmallow import ValidationError @@ -111,6 +111,38 @@ def celery_loggers(logger, *args, **kwargs): Add stdout handler for Celery logger output. """ logger.addHandler(logging.StreamHandler(sys.stdout)) + + +@worker_process_init.connect +def configure_http_connection_pools(sender=None, **kwargs): + """ + Increase urllib3/requests HTTP connection pool sizes for each forked worker process. + + The default pool_maxsize=10 per host is exhausted during parallel CloudVolume + supervoxel lookups (scattered_points makes many concurrent requests to + storage.googleapis.com). Each discarded connection requires a new TCP+TLS + handshake on the next request, adding latency to every supervoxel lookup. 
+ + We patch HTTPAdapter.__init__ so every Session created in this process + (including sessions created internally by cloud-files/cloudvolume) uses the + larger pool. Explicit callers that pass their own pool_maxsize are unaffected. + + Tune with the GCS_CONNECTION_POOL_SIZE environment variable (default: 128). + """ + from requests.adapters import HTTPAdapter + + pool_size = int(os.environ.get("GCS_CONNECTION_POOL_SIZE", "128")) + _orig_init = HTTPAdapter.__init__ + + def _patched_init(self, pool_connections=pool_size, pool_maxsize=pool_size, **kw): + _orig_init(self, pool_connections=pool_connections, pool_maxsize=pool_maxsize, **kw) + + HTTPAdapter.__init__ = _patched_init + celery_logger.info( + f"[worker_process_init] HTTP connection pool defaults set to " + f"pool_connections={pool_size}, pool_maxsize={pool_size} " + f"(GCS_CONNECTION_POOL_SIZE={pool_size})." + ) def days_till_next_month(date): diff --git a/materializationengine/cloudvolume_gateway.py b/materializationengine/cloudvolume_gateway.py index 56455674..025a0d2f 100644 --- a/materializationengine/cloudvolume_gateway.py +++ b/materializationengine/cloudvolume_gateway.py @@ -1,6 +1,11 @@ import cloudvolume import os +# Number of parallel threads CloudVolume uses internally for fetching data. +# Matches the GCS connection pool size so all threads can make concurrent requests +# without exhausting the pool. Tune with CLOUDVOLUME_PARALLEL env var. 
+_CV_PARALLEL = int(os.environ.get("CLOUDVOLUME_PARALLEL", "32")) + class CloudVolumeGateway: """A class to manage cloudvolume clients and cache them for reuse.""" @@ -58,6 +63,7 @@ def _get_cv_client( green_threads=use_green_threads, lru_bytes=self._lru_bytes, lru_encoding='crackle', + parallel=_CV_PARALLEL, ) self._cv_clients[seg_source_key] = cv_client diff --git a/materializationengine/workflows/chunking.py b/materializationengine/workflows/chunking.py index e5aef510..64311bce 100644 --- a/materializationengine/workflows/chunking.py +++ b/materializationengine/workflows/chunking.py @@ -249,7 +249,8 @@ def select_strategy(self): data_chunking_info = self._create_data_specific_chunks() if data_chunking_info: - self._chunk_generator, self.total_chunks = data_chunking_info + self.data_chunk_bounds_list, self.total_chunks = data_chunking_info + self._chunk_generator = None # will be built lazily from data_chunk_bounds_list self.strategy_name = "data_chunks" celery_logger.info( f"Using data-specific chunking with {self.total_chunks} chunks of size {self.actual_chunk_size}" diff --git a/materializationengine/workflows/ingest_new_annotations.py b/materializationengine/workflows/ingest_new_annotations.py index e45ad4ad..9f3130ac 100644 --- a/materializationengine/workflows/ingest_new_annotations.py +++ b/materializationengine/workflows/ingest_new_annotations.py @@ -683,6 +683,7 @@ def create_missing_segmentation_table(mat_metadata: dict) -> bool: "valid": True, "created": creation_time, "pcg_table_name": mat_metadata.get("pcg_table_name"), + "last_updated": mat_metadata.get("last_updated"), } seg_metadata = SegmentationMetadata(**metadata_dict) diff --git a/materializationengine/workflows/spatial_lookup.py b/materializationengine/workflows/spatial_lookup.py index 0b3015d4..10fd8c80 100644 --- a/materializationengine/workflows/spatial_lookup.py +++ b/materializationengine/workflows/spatial_lookup.py @@ -1,11 +1,12 @@ import datetime +import threading import time from 
typing import Dict, List, Any import numpy as np import pandas as pd from celery import Task, chain, chord -from celery.exceptions import MaxRetriesExceededError +from celery.exceptions import MaxRetriesExceededError, Retry from celery.utils.log import get_task_logger from cloudvolume.lib import Vec from geoalchemy2 import Geometry @@ -25,6 +26,7 @@ CHUNK_STATUS_FAILED_PERMANENT, CHUNK_STATUS_FAILED_RETRYABLE, CHUNK_STATUS_PROCESSING, + CHUNK_STATUS_PENDING, CHUNK_STATUS_PROCESSING_SUBTASKS, CHUNK_STATUS_ERROR, RedisCheckpointManager, @@ -63,6 +65,106 @@ MAX_CHUNK_WORKFLOW_ATTEMPTS = 10 +# --------------------------------------------------------------------------- +# Worker self-isolation: when a pod's Cloud SQL proxy is down, repeated +# ConnectionErrors are tracked per-worker-process. After the threshold is +# reached the worker cancels its own queue consumer so healthy pods handle +# the backlog instead. A background thread polls for recovery and re-enables +# the consumer once the database is reachable again. +# --------------------------------------------------------------------------- +_infra_lock = threading.Lock() +_consecutive_infra_failures: int = 0 +_worker_isolated: bool = False +_INFRA_ISOLATION_THRESHOLD: int = 5 # consecutive failures before isolation +_RECOVERY_POLL_INTERVAL: int = 60 # seconds between DB recovery probes + + +def _on_connection_error(database_name: str, task_self: Task) -> None: + """Record a DB ConnectionError and isolate this worker if the threshold is exceeded.""" + global _consecutive_infra_failures, _worker_isolated + + with _infra_lock: + _consecutive_infra_failures += 1 + count = _consecutive_infra_failures + already_isolated = _worker_isolated + + celery_logger.warning( + f"DB connection failure #{count} for '{database_name}' on this worker " + f"(hostname={task_self.request.hostname})." 
+ ) + + if count >= _INFRA_ISOLATION_THRESHOLD and not already_isolated: + with _infra_lock: + _worker_isolated = True + hostname = task_self.request.hostname + celery_logger.critical( + f"Worker {hostname}: {count} consecutive DB connection failures for " + f"'{database_name}'. Pausing queue consumption so healthy pods can " + f"handle the work. (Cloud SQL proxy down?)" + ) + try: + task_self.app.control.cancel_consumer( + "celery", destination=[hostname], reply=False + ) + celery_logger.warning(f"Worker {hostname}: queue consumer paused.") + except Exception as cancel_err: + celery_logger.error( + f"Worker {hostname}: could not pause consumer: {cancel_err}" + ) + threading.Thread( + target=_db_recovery_watcher, + args=(database_name, hostname, task_self.app), + daemon=True, + ).start() + + +def _db_recovery_watcher(database_name: str, hostname: str, app) -> None: + """Background thread: poll DB until it becomes available, then re-enable consumer.""" + global _consecutive_infra_failures, _worker_isolated + + celery_logger.info( + f"[DBRecovery/{hostname}] Monitoring '{database_name}' for connectivity..." + ) + while True: + time.sleep(_RECOVERY_POLL_INTERVAL) + try: + db_manager.get_engine(database_name) + celery_logger.info( + f"[DBRecovery/{hostname}] DB '{database_name}' is reachable again. " + f"Re-enabling queue consumer." + ) + with _infra_lock: + _consecutive_infra_failures = 0 + _worker_isolated = False + try: + app.control.add_consumer( + "celery", destination=[hostname], reply=False + ) + celery_logger.info(f"[DBRecovery/{hostname}] Queue consumer re-enabled.") + except Exception as add_err: + celery_logger.error( + f"[DBRecovery/{hostname}] Failed to re-add consumer: {add_err}" + ) + return + except ConnectionError: + celery_logger.info( + f"[DBRecovery/{hostname}] DB '{database_name}' still unreachable. " + f"Retrying in {_RECOVERY_POLL_INTERVAL}s." 
+ ) + except Exception as probe_err: + celery_logger.warning( + f"[DBRecovery/{hostname}] Unexpected error probing DB: {probe_err}. " + f"Retrying in {_RECOVERY_POLL_INTERVAL}s." + ) + + +def _reset_infra_failure_count() -> None: + """Reset consecutive failure counter after a successful DB operation.""" + global _consecutive_infra_failures + with _infra_lock: + if _consecutive_infra_failures > 0: + _consecutive_infra_failures = 0 + class ChunkProcessingError(Exception): """Base class for errors during chunk processing.""" @@ -224,9 +326,7 @@ def update_workflow_status(database, table_name, status): name="workflow:process_table_in_chunks", bind=True, acks_late=True, - autoretry_for=(Exception,), - max_retries=3, - retry_backoff=True, + max_retries=480, # 480 × 30s = 4 hours of polling headroom ) def process_table_in_chunks( self, @@ -251,9 +351,14 @@ def process_table_in_chunks( 5. If no chunks remain, it triggers index rebuilding and final completion. """ checkpoint_manager = RedisCheckpointManager(database_name) - engine = db_manager.get_engine(database_name) try: + # NOTE: engine acquisition is inside the try so that ConnectionError from + # a broken SQL proxy is caught and retried rather than failing the task + # permanently (which would stall the entire workflow). + engine = db_manager.get_engine(database_name) + _reset_infra_failure_count() + workflow_data = checkpoint_manager.get_workflow_data(workflow_name) if not workflow_data: celery_logger.error( @@ -340,6 +445,23 @@ def process_table_in_chunks( celery_logger.info( f"Chunking strategy for {annotation_table_name} (workflow: {workflow_name}): {chunking.total_chunks} chunks. Params stored." ) + + if initial_run: + # Store the dispatch parameters so spatial_workflow_failed can + # re-dispatch this task for recovery if a chord failure occurs. 
+ checkpoint_manager.set_dispatch_params( + workflow_name, + { + "datastack_info": datastack_info, + "mat_metadata": mat_metadata, + "workflow_name": workflow_name, + "annotation_table_name": annotation_table_name, + "database_name": database_name, + "chunk_scale_factor": chunk_scale_factor, + "supervoxel_batch_size": supervoxel_batch_size, + "batch_size_for_dispatch": batch_size_for_dispatch, + }, + ) else: celery_logger.info( f"Using existing chunking strategy for {annotation_table_name} (workflow: {workflow_name}) from checkpoint." @@ -440,6 +562,23 @@ def process_table_in_chunks( full_chain.apply_async() return f"All chunks processed for {annotation_table_name}. Finalizing." else: + # Before sleeping, check whether any chunks stuck in PROCESSING or + # PROCESSING_SUBTASKS state have been there long enough to be treated + # as lost (pod killed, broker blip, etc.) and should be retried. + recovered_subtasks = checkpoint_manager.recover_stale_processing_subtasks( + workflow_name, stale_threshold_seconds=600 + ) + recovered_processing = checkpoint_manager.recover_stale_processing_chunks( + workflow_name, stale_threshold_seconds=600 + ) + recovered = recovered_subtasks + recovered_processing + if recovered: + celery_logger.info( + f"Recovered {recovered} stale chunk(s) for {workflow_name} " + f"({recovered_subtasks} PROCESSING_SUBTASKS, {recovered_processing} PROCESSING). " + f"Retrying dispatcher immediately to re-dispatch them." + ) + raise self.retry(countdown=0) celery_logger.info( f"No chunks returned by get_chunks_to_process for {workflow_name}, but scan may not be exhausted or non-terminal chunks exist. Retrying dispatcher." ) @@ -496,6 +635,19 @@ def process_table_in_chunks( return f"Dispatched batch of {len(chunk_indices_to_process)} chunks for {annotation_table_name} (workflow {workflow_name})." + except ConnectionError as e: + # SQL proxy / infra failure. Retry indefinitely with a long countdown so + # a healthy pod picks up the work. 
Do NOT mark the workflow as failed — + # this is a transient infrastructure issue, not a data problem. + celery_logger.warning( + f"DB connection error in process_table_in_chunks for {workflow_name}: {e}. " + f"Retrying in 5 minutes." + ) + _on_connection_error(database_name, self) + raise self.retry(exc=e, countdown=300, max_retries=None) + + except Retry: + raise except Exception as e: celery_logger.error( f"Critical error in process_table_in_chunks dispatcher for {workflow_name} (table {annotation_table_name}): {str(e)}", @@ -507,7 +659,9 @@ def process_table_in_chunks( last_error=f"Dispatcher critical error: {str(e)}", ) - raise self.retry(exc=e, countdown=int(2**self.request.retries)) + # Cap backoff at 5 minutes regardless of how many polling retries have occurred + error_backoff = min(300, int(2 ** min(self.request.retries, 8))) + raise self.retry(exc=e, countdown=error_backoff) @celery.task( @@ -750,6 +904,28 @@ def process_chunk( raise + except ConnectionError as e: + # SQL proxy / infra failure — NOT a data error. + # Reset chunk to PENDING so the attempt budget is NOT consumed; a healthy + # pod will re-pick it up after the countdown. + celery_logger.warning( + f"{log_prefix} DB connection error (SQL proxy down?): {e}. " + f"Resetting chunk to PENDING and retrying in 5 minutes." 
+ ) + try: + checkpoint_manager.set_chunk_status( + workflow_name, + chunk_idx, + CHUNK_STATUS_PENDING, + {"message": f"Connection error — will retry: {str(e)[:300]}"}, + ) + except Exception as status_err: + celery_logger.warning( + f"{log_prefix} Could not reset chunk status to PENDING: {status_err}" + ) + _on_connection_error(database_name, self) + raise self.retry(exc=e, countdown=300, max_retries=None) + except (OperationalError, DisconnectionError) as e: celery_logger.warning( f"{log_prefix} Transient DB/network error (Celery attempt {self.request.retries + 1}/{self.max_retries}): {e}" @@ -1379,6 +1555,15 @@ def process_and_insert_sub_batch( ) raise + except ConnectionError as e: + # SQL proxy / infra failure — retry indefinitely with a long countdown so + # a healthy pod picks up the sub-batch instead. + celery_logger.warning( + f"{log_prefix} DB connection error: {e}. Retrying in 5 minutes." + ) + _on_connection_error(database_name, self) + raise self.retry(exc=e, countdown=300, max_retries=None) + except (OperationalError, DisconnectionError) as e: celery_logger.warning( f"{log_prefix} Transient DB/network error: {e}. Celery will retry." @@ -1627,6 +1812,72 @@ def spatial_workflow_failed( if workflow_name and database_name: checkpoint_manager = RedisCheckpointManager(database_name) + + # --- Recovery attempt for chord failures --- + # A chord failure (first arg is a UUID string, no Exception) means one + # process_chunk task raised unexpectedly, killing the batch chord. This is + # usually caused by pod preemption or a transient broker/DB blip, not a + # genuine data error. Try to recover by resetting any stuck PROCESSING + # chunks and re-dispatching the dispatcher task (up to MAX_RECOVERY_ATTEMPTS). + MAX_RECOVERY_ATTEMPTS = 3 + is_chord_failure = isinstance(request_obj_uuid_or_exc, str) and not custom_message + # Don't try to recover the final completion chain (its workflow_name has a + # different prefix); only recover the chunk-processing loop. 
+ is_completion_chain = workflow_name.startswith("spatial_lookup_completion_") + if is_chord_failure and not is_completion_chain: + try: + workflow_data = checkpoint_manager.get_workflow_data(workflow_name) + recovery_count = (workflow_data.recovery_attempts if workflow_data else 0) + if recovery_count < MAX_RECOVERY_ATTEMPTS: + # Reset any chunks stuck in PROCESSING to FAILED_RETRYABLE. + recovered = checkpoint_manager.recover_stale_processing_chunks( + workflow_name, stale_threshold_seconds=0 # treat ALL PROCESSING as stale + ) + if recovered: + celery_logger.warning( + f"{log_message_prefix} Chord failure recovery: reset {recovered} " + f"PROCESSING chunk(s) to FAILED_RETRYABLE." + ) + # Increment recovery_attempts counter and set status back to processing_chunks. + checkpoint_manager.update_workflow( + table_name=workflow_name, + status="processing_chunks", + last_error=f"[recovery {recovery_count + 1}/{MAX_RECOVERY_ATTEMPTS}] {error_info}", + recovery_attempts=recovery_count + 1, + ) + # Re-dispatch the dispatcher task using the stored params. + dispatch_params = checkpoint_manager.get_dispatch_params(workflow_name) + if dispatch_params: + process_table_in_chunks.apply_async( + kwargs={ + **dispatch_params, + "prioritize_failed_chunks": True, + "initial_run": False, + }, + ) + celery_logger.warning( + f"{log_message_prefix} Chord failure recovery attempt " + f"{recovery_count + 1}/{MAX_RECOVERY_ATTEMPTS}: re-dispatched " + f"process_table_in_chunks. Failed task UUID: {request_obj_uuid_or_exc}" + ) + return # Do NOT mark FAILED_PERMANENT — recovery dispatched. + else: + celery_logger.error( + f"{log_message_prefix} No dispatch params found for recovery. " + f"Cannot re-dispatch. Marking FAILED_PERMANENT." + ) + else: + celery_logger.error( + f"{log_message_prefix} Chord failure recovery exhausted " + f"({recovery_count}/{MAX_RECOVERY_ATTEMPTS} attempts). Marking FAILED_PERMANENT." 
+ ) + except Exception as e_recovery: + celery_logger.error( + f"{log_message_prefix} Error during chord failure recovery: {e_recovery}. " + f"Falling through to FAILED_PERMANENT.", + exc_info=True, + ) + try: checkpoint_manager.update_workflow( table_name=workflow_name, status=final_status, last_error=error_info diff --git a/static/js/running_uploads.js b/static/js/running_uploads.js index e23538df..4eed00c7 100644 --- a/static/js/running_uploads.js +++ b/static/js/running_uploads.js @@ -325,6 +325,9 @@ function startPollingForJob(jobId) { if (runningJobsData.activePollers[jobId]) return; console.log(`[RunningUploads] Starting polling for job: ${jobId}`); + let consecutiveFailures = 0; + const MAX_CONSECUTIVE_FAILURES = 3; + runningJobsData.activePollers[jobId] = setInterval(async () => { try { const response = await fetch( @@ -344,23 +347,39 @@ function startPollingForJob(jobId) { stopPollingForJob(jobId); return; } + if (response.status === 401) { + updateJobInState(jobId, { + status: "error", + error: "Session expired. Please refresh the page to continue monitoring.", + phase: "Session Expired", + }); + stopPollingForJob(jobId); + return; + } throw new Error( `Failed to get status for ${jobId} (${response.status})` ); } const data = await response.json(); + consecutiveFailures = 0; updateJobInState(jobId, data); if (!isJobActive(data.status)) { stopPollingForJob(jobId); } } catch (error) { - console.error(`[RunningUploads] Error polling for job ${jobId}:`, error); - updateJobInState(jobId, { - status: "error", - error: "Polling failed. Could not retrieve latest status.", - phase: "Polling Error", - }); - stopPollingForJob(jobId); + consecutiveFailures++; + console.error( + `[RunningUploads] Error polling for job ${jobId} (failure ${consecutiveFailures}/${MAX_CONSECUTIVE_FAILURES}):`, + error + ); + if (consecutiveFailures >= MAX_CONSECUTIVE_FAILURES) { + updateJobInState(jobId, { + status: "error", + error: "Polling failed. 
Could not retrieve latest status.", + phase: "Polling Error", + }); + stopPollingForJob(jobId); + } } }, 5000); } diff --git a/static/js/step1.js b/static/js/step1.js index 12c35253..74965919 100644 --- a/static/js/step1.js +++ b/static/js/step1.js @@ -226,7 +226,12 @@ document.addEventListener("alpine:init", () => { if (!response.ok) { const errorText = await response.text(); console.error(`Failed response (${response.status}):`, errorText); - throw new Error(`Failed to get upload URL (${response.status})`); + let errorMessage = `HTTP ${response.status}`; + try { + const errorData = JSON.parse(errorText); + if (errorData.message) errorMessage = errorData.message; + } catch {} + throw new Error(`Failed to get upload URL: ${errorMessage}`); } const data = await response.json(); diff --git a/static/js/step3.js b/static/js/step3.js index 6448d2ed..336b8c6f 100644 --- a/static/js/step3.js +++ b/static/js/step3.js @@ -12,10 +12,11 @@ document.addEventListener("alpine:init", () => { voxel_resolution_nm_y: 1, voxel_resolution_nm_z: 1, write_permission: "PRIVATE", - read_permission: "PRIVATE", + read_permission: "PUBLIC", validationErrors: {}, isReferenceSchema: false, metadataSaved: false, + stagingConflict: null, // { message, row_count } when staging table exists }, init() { @@ -52,6 +53,9 @@ document.addEventListener("alpine:init", () => { saveState() { const stateToSave = { ...this.state }; delete stateToSave.validationErrors; + delete stateToSave.stagingConflict; + delete stateToSave.isReferenceSchema; + delete stateToSave.metadataSaved; localStorage.setItem("metadataStore", JSON.stringify(stateToSave)); }, @@ -101,35 +105,52 @@ document.addEventListener("alpine:init", () => { return Object.keys(errors).length === 0; }, - async saveMetadata() { + _buildPayload(forceOverwrite = false) { + return { + schema_type: this.state.schema_type, + datastack_name: this.state.datastack_name, + table_name: this.state.table_name, + description: this.state.description, + 
notice_text: this.state.notice_text, + reference_table: this.state.reference_table, + flat_segmentation_source: this.state.flat_segmentation_source, + voxel_resolution_nm_x: parseFloat(this.state.voxel_resolution_nm_x), + voxel_resolution_nm_y: parseFloat(this.state.voxel_resolution_nm_y), + voxel_resolution_nm_z: parseFloat(this.state.voxel_resolution_nm_z), + write_permission: this.state.write_permission, + read_permission: this.state.read_permission, + ...(forceOverwrite ? { force_overwrite: true } : {}), + }; + }, + + async saveMetadata(forceOverwrite = false) { if (!this.validateForm()) { return false; } + this.state.stagingConflict = null; + try { const response = await fetch("/materialize/upload/api/save-metadata", { method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - schema_type: this.state.schema_type, - datastack_name: this.state.datastack_name, - table_name: this.state.table_name, - description: this.state.description, - notice_text: this.state.notice_text, - reference_table: this.state.reference_table, - flat_segmentation_source: this.state.flat_segmentation_source, - voxel_resolution_nm_x: parseFloat(this.state.voxel_resolution_nm_x), - voxel_resolution_nm_y: parseFloat(this.state.voxel_resolution_nm_y), - voxel_resolution_nm_z: parseFloat(this.state.voxel_resolution_nm_z), - write_permission: this.state.write_permission, - read_permission: this.state.read_permission, - }), + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(this._buildPayload(forceOverwrite)), }); if (!response.ok) { - throw new Error("Failed to save metadata"); + let errorData = {}; + try { errorData = await response.json(); } catch (_) {} + + // Staging conflict: table exists from a previous run — ask user to confirm overwrite + if (response.status === 409 && errorData.staging_exists) { + this.state.stagingConflict = { + message: errorData.message, + row_count: errorData.row_count, + }; + return false; + } + + 
throw new Error(errorData.message || "Failed to save metadata"); } const data = await response.json(); @@ -143,6 +164,11 @@ document.addEventListener("alpine:init", () => { } }, + async confirmOverwrite() { + this.state.stagingConflict = null; + return await this.saveMetadata(true); + }, + isValid() { return this.validateForm(); }, diff --git a/static/js/step4.js b/static/js/step4.js index 893cc282..9d8cfcfe 100644 --- a/static/js/step4.js +++ b/static/js/step4.js @@ -124,9 +124,22 @@ document.addEventListener("alpine:init", () => { JSON.parse(schemaStoreData); const metadataFromStore = JSON.parse(metadataStoreData); + // Explicitly pick only the fields the backend schema expects, + // excluding UI-only state (validationErrors, metadataSaved, stagingConflict, etc.) const metadataPayload = { - ...metadataFromStore, schema_type: metadataFromStore.schema_type || selectedSchema, + datastack_name: metadataFromStore.datastack_name, + table_name: metadataFromStore.table_name, + description: metadataFromStore.description, + notice_text: metadataFromStore.notice_text, + reference_table: metadataFromStore.reference_table, + flat_segmentation_source: metadataFromStore.flat_segmentation_source, + voxel_resolution_nm_x: metadataFromStore.voxel_resolution_nm_x, + voxel_resolution_nm_y: metadataFromStore.voxel_resolution_nm_y, + voxel_resolution_nm_z: metadataFromStore.voxel_resolution_nm_z, + write_permission: metadataFromStore.write_permission, + read_permission: metadataFromStore.read_permission, + ...(metadataFromStore.force_overwrite ? 
{ force_overwrite: true } : {}), }; if (!this.state.inputFile) { @@ -156,12 +169,14 @@ document.addEventListener("alpine:init", () => { if (!response.ok) { let errorText = "Failed to start processing."; try { - const errorData = await response.json(); - errorText = - errorData.message || errorData.error || JSON.stringify(errorData); - } catch (e) { - errorText = (await response.text()) || errorText; - } + const rawText = await response.text(); + try { + const errorData = JSON.parse(rawText); + errorText = errorData.message || errorData.error || rawText; + } catch (_) { + errorText = rawText || errorText; + } + } catch (_) {} throw new Error(`Server error (${response.status}): ${errorText}`); } diff --git a/templates/upload/step3.html b/templates/upload/step3.html index 341698c5..3b27c946 100644 --- a/templates/upload/step3.html +++ b/templates/upload/step3.html @@ -6,6 +6,22 @@

Step 3: Annotation Table Metadata

+ +
+ Existing Staging Data Found +

+
+ + +
+
+