Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 45 additions & 15 deletions materializationengine/celery_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,20 +116,31 @@ def celery_loggers(logger, *args, **kwargs):
@worker_process_init.connect
def configure_http_connection_pools(sender=None, **kwargs):
"""
Increase urllib3/requests HTTP connection pool sizes for each forked worker process.

The default pool_maxsize=10 per host is exhausted during parallel CloudVolume
supervoxel lookups (scattered_points makes many concurrent requests to
storage.googleapis.com). Each discarded connection requires a new TCP+TLS
handshake on the next request, adding latency to every supervoxel lookup.

We patch HTTPAdapter.__init__ so every Session created in this process
(including sessions created internally by cloud-files/cloudvolume) uses the
larger pool. Explicit callers that pass their own pool_maxsize are unaffected.

Tune with the GCS_CONNECTION_POOL_SIZE environment variable (default: 128).
Increase GCS urllib3 connection pool sizes for each forked worker process.

CloudVolume uses cloud-files with use_https=True, which converts gs:// paths to
https://storage.googleapis.com/... and routes them through HttpInterface.
HttpInterface holds a class-level HTTPAdapter (created at import time with the
default pool_maxsize=10) that is shared across ALL instances and sessions.
With CLOUDVOLUME_PARALLEL concurrent threads all funnelling through that one
adapter, the pool fills immediately → discarded connections → TCP+TLS handshake
on every request.

Three-part fix (all run once per forked worker process):
1. Patch HTTPAdapter.__init__ so any new Session/AuthorizedSession created after
this hook uses pool_maxsize=GCS_CONNECTION_POOL_SIZE (default: 128).
2. Replace HttpInterface.adaptor (the shared class-level adapter) with a fresh
HTTPAdapter that has the larger pool. This is the critical fix for use_https
paths because the old adapter was created before the patch could apply.
3. Reset cloud-files' GC_POOL and invalidate cloudvolume_cache so any gs://
connections inherited from the parent process are discarded; fresh ones pick
up the patched HTTPAdapter.

Tune with GCS_CONNECTION_POOL_SIZE environment variable (default: 128).
"""
from requests.adapters import HTTPAdapter
import cloudfiles.interfaces as cf_interfaces
from materializationengine.cloudvolume_gateway import cloudvolume_cache

pool_size = int(os.environ.get("GCS_CONNECTION_POOL_SIZE", "128"))
_orig_init = HTTPAdapter.__init__
Expand All @@ -138,10 +149,29 @@ def _patched_init(self, pool_connections=pool_size, pool_maxsize=pool_size, **kw
_orig_init(self, pool_connections=pool_connections, pool_maxsize=pool_maxsize, **kw)

HTTPAdapter.__init__ = _patched_init

# Replace the class-level adapter shared by all HttpInterface instances.
# This is the primary fix for use_https=True (https://storage.googleapis.com)
# paths: the old class-level adapter has pool_maxsize=10 and cannot be patched
# retroactively via HTTPAdapter.__init__.
cf_interfaces.HttpInterface.adaptor = HTTPAdapter(
pool_connections=pool_size, pool_maxsize=pool_size
)

# For gs:// paths (non-use_https): discard GCS bucket connections inherited
# from the parent process. reset_connection_pools() replaces the global
# GC_POOL with fresh empty queues; the next gs:// request creates a new
# google.cloud.storage.Client → AuthorizedSession → patched HTTPAdapter.
cf_interfaces.reset_connection_pools()

# Clear any CloudVolume client objects that hold references to old connections.
# They are re-populated lazily on first use in this worker process.
cloudvolume_cache.invalidate_cache()

celery_logger.info(
f"[worker_process_init] HTTP connection pool defaults set to "
f"pool_connections={pool_size}, pool_maxsize={pool_size} "
f"(GCS_CONNECTION_POOL_SIZE={pool_size})."
f"[worker_process_init] GCS connection pool reset: "
f"HTTPAdapter defaults patched to pool_maxsize={pool_size}, "
f"HttpInterface.adaptor replaced, GC_POOL reset, cloudvolume_cache invalidated."
)


Expand Down
7 changes: 3 additions & 4 deletions materializationengine/cloudvolume_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import os

# Number of parallel threads CloudVolume uses internally for fetching data.
# Matches the GCS connection pool size so all threads can make concurrent requests
# without exhausting the pool. Tune with CLOUDVOLUME_PARALLEL env var.
_CV_PARALLEL = int(os.environ.get("CLOUDVOLUME_PARALLEL", "32"))
# Default of 10 matches the urllib3 pool_maxsize (10) used by cloud-files/google-cloud-storage,
# avoiding connection pool overflow. Override with CLOUDVOLUME_PARALLEL env var.
_CV_PARALLEL = int(os.environ.get("CLOUDVOLUME_PARALLEL", "10"))


class CloudVolumeGateway:
Expand Down Expand Up @@ -57,7 +57,6 @@ def _get_cv_client(
cv_client = cloudvolume.CloudVolume(
seg_source,
mip=mip_level,
use_https=True,
bounded=False,
fill_missing=True,
green_threads=use_green_threads,
Expand Down
56 changes: 22 additions & 34 deletions materializationengine/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dynamicannotationdb import DynamicAnnotationInterface
from flask import current_app
from sqlalchemy import MetaData, create_engine, text
from sqlalchemy import MetaData, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.pool import QueuePool

Expand Down Expand Up @@ -59,41 +59,29 @@ def get_engine(self, database_name: str):
SQL_URI_CONFIG = current_app.config["SQLALCHEMY_DATABASE_URI"]
sql_base_uri = SQL_URI_CONFIG.rpartition("/")[0]
sql_uri = f"{sql_base_uri}/{database_name}"

pool_size = current_app.config.get("DB_CONNECTION_POOL_SIZE", 20)
max_overflow = current_app.config.get("DB_CONNECTION_MAX_OVERFLOW", 30)

try:
engine = create_engine(
sql_uri,
poolclass=QueuePool,
pool_size=pool_size,
max_overflow=max_overflow,
pool_timeout=30,
pool_recycle=1800, # Recycle connections after 30 minutes
pool_pre_ping=True, # Ensure connections are still valid
)

# Test the connection to make sure the database exists and is accessible
with engine.connect() as conn:
conn.execute(text("SELECT 1"))

# Only store engine if connection test passes
self._engines[database_name] = engine
celery_logger.info(f"Created new connection pool for {database_name} "
f"(size={pool_size}, max_overflow={max_overflow})")

except Exception as e:
# Clean up engine if it was created but connection failed
if 'engine' in locals():
engine.dispose()

celery_logger.error(f"Failed to create/connect to database {database_name}: {e}")
raise ConnectionError(f"Cannot connect to database '{database_name}'. "
f"Please check if the database exists and is accessible. "
f"Connection URI: {sql_uri}. "
f"Error: {e}")


engine = create_engine(
sql_uri,
poolclass=QueuePool,
pool_size=pool_size,
max_overflow=max_overflow,
pool_timeout=30,
pool_recycle=1800, # Recycle connections after 30 minutes
pool_pre_ping=True, # Test connections before use; reconnect if stale
)
# Cache immediately — pool_pre_ping handles connectivity on first checkout.
# Previously we ran a SELECT 1 test here and didn't cache on failure, which
# caused a transient error to permanently break this worker's engine cache,
# leading to repeated ConnectionError → _on_connection_error → consumer isolation.
self._engines[database_name] = engine
celery_logger.info(
f"Created engine for {database_name} "
f"(pool_size={pool_size}, max_overflow={max_overflow})"
)

return self._engines[database_name]

def get_session_factory(self, database_name: str):
Expand Down
21 changes: 21 additions & 0 deletions materializationengine/kvdb_gateway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

from kvdbclient.bigtable import BigTableConfig
from kvdbclient.bigtable.client import Client


class KVDBGateway:
    """Per-process cache of BigTable ``Client`` objects, one per table id.

    All clients share a single :class:`BigTableConfig`; each client is
    created lazily on the first request for its table id and memoized
    for subsequent calls.
    """

    def __init__(self, project: str, instance: str):
        # Immutable connection settings shared by every client this
        # gateway hands out (ADMIN=False: read/write data access only).
        self._config = BigTableConfig(PROJECT=project, INSTANCE=instance, ADMIN=False)
        # Memo of table_id -> Client, filled lazily by get_client().
        self._clients: dict = {}

    def get_client(self, table_id: str) -> Client:
        """Return the cached client for *table_id*, creating it on first use."""
        client = self._clients.get(table_id)
        if client is None:
            client = Client(table_id=table_id, config=self._config)
            self._clients[table_id] = client
        return client


# Module-level singleton shared by workflow code (e.g. spatial_lookup imports
# it to resolve root ids). Configured from the environment at import time:
#   BIGTABLE_PROJECT  - GCP project id. Defaults to "" — NOTE(review): an empty
#                       project presumably fails at client creation time; confirm
#                       the variable is always set in deployed environments.
#   BIGTABLE_INSTANCE - BigTable instance name (default: "pychunkedgraph").
kvdb_cache = KVDBGateway(
    project=os.environ.get("BIGTABLE_PROJECT", ""),
    instance=os.environ.get("BIGTABLE_INSTANCE", "pychunkedgraph"),
)
98 changes: 56 additions & 42 deletions materializationengine/workflows/spatial_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
RedisCheckpointManager,
)
from materializationengine.celery_init import celery
from materializationengine.chunkedgraph_gateway import chunkedgraph_cache
from materializationengine.kvdb_gateway import kvdb_cache
from materializationengine.cloudvolume_gateway import cloudvolume_cache
from materializationengine.database import db_manager, dynamic_annotation_cache
from materializationengine.index_manager import index_cache
Expand All @@ -55,7 +55,6 @@
)
from materializationengine.workflows.ingest_new_annotations import (
create_missing_segmentation_table,
get_root_ids,
)

Base = declarative_base()
Expand Down Expand Up @@ -337,7 +336,7 @@ def process_table_in_chunks(
database_name: str,
chunk_scale_factor: int,
supervoxel_batch_size: int,
batch_size_for_dispatch: int = 10,
batch_size_for_dispatch: int = 50,
prioritize_failed_chunks: bool = True,
initial_run: bool = False,
):
Expand Down Expand Up @@ -495,6 +494,22 @@ def process_table_in_chunks(
full_chain.apply_async()
return f"No chunks to process for {annotation_table_name}. Finalizing."

# Recover stale chunks on EVERY conductor invocation so that preempted
# sub-batch workers are detected within 600s regardless of whether there
# are still pending chunks in the queue to dispatch.
recovered_subtasks = checkpoint_manager.recover_stale_processing_subtasks(
workflow_name, stale_threshold_seconds=600
)
recovered_processing = checkpoint_manager.recover_stale_processing_chunks(
workflow_name, stale_threshold_seconds=600
)
recovered = recovered_subtasks + recovered_processing
if recovered:
celery_logger.info(
f"Recovered {recovered} stale chunk(s) for {workflow_name} "
f"({recovered_subtasks} PROCESSING_SUBTASKS, {recovered_processing} PROCESSING)."
)

chunk_indices_to_process, new_failed_cursor, new_pending_cursor = (
checkpoint_manager.get_chunks_to_process(
table_name=workflow_name,
Expand Down Expand Up @@ -562,27 +577,10 @@ def process_table_in_chunks(
full_chain.apply_async()
return f"All chunks processed for {annotation_table_name}. Finalizing."
else:
# Before sleeping, check whether any chunks stuck in PROCESSING or
# PROCESSING_SUBTASKS state have been there long enough to be treated
# as lost (pod killed, broker blip, etc.) and should be retried.
recovered_subtasks = checkpoint_manager.recover_stale_processing_subtasks(
workflow_name, stale_threshold_seconds=600
)
recovered_processing = checkpoint_manager.recover_stale_processing_chunks(
workflow_name, stale_threshold_seconds=600
)
recovered = recovered_subtasks + recovered_processing
if recovered:
celery_logger.info(
f"Recovered {recovered} stale chunk(s) for {workflow_name} "
f"({recovered_subtasks} PROCESSING_SUBTASKS, {recovered_processing} PROCESSING). "
f"Retrying dispatcher immediately to re-dispatch them."
)
raise self.retry(countdown=0)
celery_logger.info(
f"No chunks returned by get_chunks_to_process for {workflow_name}, but scan may not be exhausted or non-terminal chunks exist. Retrying dispatcher."
)
raise self.retry(countdown=30)
raise self.retry(countdown=30 if not recovered else 0)

processing_tasks = []
for chunk_idx_to_process in chunk_indices_to_process:
Expand Down Expand Up @@ -906,25 +904,37 @@ def process_chunk(

except ConnectionError as e:
# SQL proxy / infra failure — NOT a data error.
# Reset chunk to PENDING so the attempt budget is NOT consumed; a healthy
# pod will re-pick it up after the countdown.
celery_logger.warning(
f"{log_prefix} DB connection error (SQL proxy down?): {e}. "
f"Resetting chunk to PENDING and retrying in 5 minutes."
)
# Use a finite retry limit (5 × 300 s = 25 min max) so the outer chord
# eventually completes rather than blocking the conductor indefinitely.
# On exhaustion, mark FAILED_RETRYABLE and return normally so the chord
# body (next conductor invocation) can fire and stale detection re-queues
# the chunk.
_on_connection_error(database_name, self)
try:
checkpoint_manager.set_chunk_status(
workflow_name,
chunk_idx,
CHUNK_STATUS_PENDING,
{"message": f"Connection error — will retry: {str(e)[:300]}"},
)
except Exception as status_err:
celery_logger.warning(
f"{log_prefix} Could not reset chunk status to PENDING: {status_err}"
f"{log_prefix} DB connection error (SQL proxy down?): {e}. "
f"Retrying in 5 minutes (attempt {self.request.retries + 1}/5)."
)
_on_connection_error(database_name, self)
raise self.retry(exc=e, countdown=300, max_retries=None)
raise self.retry(exc=e, countdown=300, max_retries=5)
except MaxRetriesExceededError:
celery_logger.error(
f"{log_prefix} Max connection-error retries exceeded on process_chunk. "
f"Marking chunk FAILED_RETRYABLE so conductor can re-dispatch."
)
error_payload = {
"error_message": f"Connection error retries exhausted in process_chunk: {str(e)}",
"error_type": type(e).__name__,
"attempt_count": workflow_attempt_count + 1,
"celery_task_id": self.request.id,
}
checkpoint_manager.set_chunk_status(
workflow_name, chunk_idx, CHUNK_STATUS_FAILED_RETRYABLE, error_payload
)
return {
"status": "failed_connection_error_retries_exhausted",
"chunk_idx": chunk_idx,
"marked_retryable_in_checkpoint": True,
}

except (OperationalError, DisconnectionError) as e:
celery_logger.warning(
Expand Down Expand Up @@ -1125,7 +1135,7 @@ def get_root_ids_from_supervoxels(
"""
start_time = time.time()

pcg_table_name = mat_metadata.get("pcg_table_name")
pcg_table_name: str = mat_metadata["pcg_table_name"]
database = mat_metadata.get("database")

try:
Expand Down Expand Up @@ -1159,6 +1169,7 @@ def get_root_ids_from_supervoxels(
col for col in supervoxel_df.columns if col.endswith("supervoxel_id")
]


root_id_col_names = [
col.replace("supervoxel_id", "root_id") for col in supervoxel_col_names
]
Expand Down Expand Up @@ -1193,7 +1204,7 @@ def get_root_ids_from_supervoxels(
if existing_value and existing_value > 0:
root_ids_df.at[idx, root_col] = existing_value

cg_client = chunkedgraph_cache.init_pcg(pcg_table_name)
cg_client = kvdb_cache.get_client(pcg_table_name)

for sv_col in supervoxel_col_names:
root_col = sv_col.replace("supervoxel_id", "root_id")
Expand All @@ -1212,8 +1223,11 @@ def get_root_ids_from_supervoxels(

if not supervoxels_to_lookup.empty:
try:
root_ids = get_root_ids(
cg_client, supervoxels_to_lookup, materialization_time_stamp
root_ids = np.squeeze(
cg_client.root_ext.get_roots(
supervoxels_to_lookup.to_numpy(),
time_stamp=materialization_time_stamp,
)
)

root_ids_df.loc[sv_mask, root_col] = root_ids
Expand Down Expand Up @@ -1411,7 +1425,7 @@ def insert_segmentation_data(


@celery.task(
name="workflow:process_and_insert_sub_batch",
name="process:process_and_insert_sub_batch",
bind=True,
acks_late=True,
autoretry_for=(OperationalError, DisconnectionError, ChunkDataValidationError),
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dependencies = [
"Flask-Limiter[redis]",
"cryptography>=44.0.2",
"uwsgi>=2.0.30",
"kvdbclient[extensions]",
]
authors = [
{ name = "Forrest Collman", email = "forrestc@alleninstitute.org" },
Expand Down
Loading
Loading