diff --git a/.gitignore b/.gitignore index acd9206..1ca7607 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +my_yamls/ +browser/ +.codex +.vite/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -162,4 +167,13 @@ cython_debug/ #.idea/ # Misc -.vscode/ \ No newline at end of file +.vscode/ + +# Project-specific +ignore/ +*.zarr/ +.claude/ +test_corrections.zarr/ +correction_slices/ +corrections/ +output/ \ No newline at end of file diff --git a/cellmap_flow/cli/server_cli.py b/cellmap_flow/cli/server_cli.py index 1bff004..d733c5b 100644 --- a/cellmap_flow/cli/server_cli.py +++ b/cellmap_flow/cli/server_cli.py @@ -23,7 +23,6 @@ from cellmap_flow.utils.plugin_manager import load_plugins -logging.basicConfig() logger = logging.getLogger(__name__) @@ -47,7 +46,7 @@ def cli(log_level): cellmap_flow_server script -s /path/to/script.py -d /path/to/data cellmap_flow_server cellmap -f /path/to/model -n mymodel -d /path/to/data """ - logging.basicConfig(level=getattr(logging, log_level.upper())) + logging.basicConfig(level=getattr(logging, log_level.upper()), force=True) @cli.command(name="list-models") diff --git a/cellmap_flow/cli/yaml_cli.py b/cellmap_flow/cli/yaml_cli.py index 1b829af..54211fb 100644 --- a/cellmap_flow/cli/yaml_cli.py +++ b/cellmap_flow/cli/yaml_cli.py @@ -58,16 +58,17 @@ def _submit_model(model): ) return model_name - with ThreadPoolExecutor(max_workers=len(models)) as executor: - futures = {executor.submit(_submit_model, model): model for model in models} - for future in as_completed(futures): - try: - name = future.result() - logger.info(f"Job for {name} is ready") - except Exception as e: - model = futures[future] - model_name = getattr(model, "name", None) or type(model).__name__ - logger.error(f"Failed to start job for {model_name}: {e}") + if models: + with ThreadPoolExecutor(max_workers=len(models)) as executor: + futures = {executor.submit(_submit_model, model): model for model in models} + for future in 
as_completed(futures): + try: + name = future.result() + logger.info(f"Job for {name} is ready") + except Exception as e: + model = futures[future] + model_name = getattr(model, "name", None) or type(model).__name__ + logger.error(f"Failed to start job for {model_name}: {e}") generate_neuroglancer_url(dataset_path,wrap_raw=wrap_raw) @@ -191,7 +192,11 @@ def main(config_path: str, log_level: str, list_types: bool, validate_only: bool # Build model configuration objects dynamically logger.info("Building model configurations...") - g.models_config = build_models(config["models"]) + if config["models"]: + g.models_config = build_models(config["models"]) + else: + g.models_config = [] + logger.info("No models configured — starting dashboard for interactive use") logger.info(f"Configured {len(g.models_config)} model(s):") for i, model in enumerate(g.models_config, 1): diff --git a/cellmap_flow/dashboard/app.py b/cellmap_flow/dashboard/app.py index 6210495..a19f61c 100644 --- a/cellmap_flow/dashboard/app.py +++ b/cellmap_flow/dashboard/app.py @@ -5,8 +5,7 @@ from flask import Flask from flask_cors import CORS -from cellmap_flow.dashboard import state -from cellmap_flow.dashboard.state import LogHandler +from cellmap_flow.globals import g, LogHandler from cellmap_flow.dashboard.routes.logging_routes import logging_bp from cellmap_flow.dashboard.routes.index_page import index_bp from cellmap_flow.dashboard.routes.pipeline_builder_page import pipeline_builder_bp @@ -14,6 +13,7 @@ from cellmap_flow.dashboard.routes.pipeline import pipeline_bp from cellmap_flow.dashboard.routes.blockwise import blockwise_bp from cellmap_flow.dashboard.routes.bbx_generator import bbx_bp +from cellmap_flow.dashboard.routes.finetune import finetune_bp logger = logging.getLogger(__name__) @@ -37,11 +37,12 @@ app.register_blueprint(pipeline_bp) app.register_blueprint(blockwise_bp) app.register_blueprint(bbx_bp) +app.register_blueprint(finetune_bp) def create_and_run_app(neuroglancer_url=None, 
inference_servers=None): - state.NEUROGLANCER_URL = neuroglancer_url - state.INFERENCE_SERVER = inference_servers + g.NEUROGLANCER_URL = neuroglancer_url + g.INFERENCE_SERVER = inference_servers hostname = socket.gethostname() port = 0 logger.warning(f"Host name: {hostname}") diff --git a/cellmap_flow/dashboard/finetune_utils.py b/cellmap_flow/dashboard/finetune_utils.py new file mode 100644 index 0000000..29ed04a --- /dev/null +++ b/cellmap_flow/dashboard/finetune_utils.py @@ -0,0 +1,1087 @@ +""" +Helper functions for finetuning annotation workflows. + +Handles MinIO server management, annotation zarr creation, and +periodic synchronization of annotations between MinIO and local disk. +""" + +import json +import os +import re +import socket +import subprocess +import time +import logging +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path + +import numpy as np +import s3fs +import zarr + +from cellmap_flow.globals import g + +minio_state = g.minio_state +annotation_volumes = g.annotation_volumes +output_sessions = g.output_sessions + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Session management +# --------------------------------------------------------------------------- + +def get_or_create_session_path(base_output_path: str) -> str: + """ + Get or create a timestamped session directory for the given base output path. + + If a session already exists for this base path, reuse it. + Otherwise, create a new timestamped subdirectory. 
+ + Args: + base_output_path: Base output directory (e.g., "output/to/here") + + Returns: + Timestamped session path (e.g., "output/to/here/20260213_123456") + """ + base_output_path = os.path.expanduser(base_output_path) + + if base_output_path not in output_sessions: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + session_path = os.path.join(base_output_path, timestamp) + output_sessions[base_output_path] = session_path + logger.info(f"Created new session path: {session_path}") + + return output_sessions[base_output_path] + + +# --------------------------------------------------------------------------- +# Network helpers +# --------------------------------------------------------------------------- + +def get_local_ip(): + """Get the local IP address for MinIO server.""" + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + local_ip = s.getsockname()[0] + s.close() + return local_ip + except Exception: + return "127.0.0.1" + + +def find_available_port(start_port=9000): + """Find an available port pair for MinIO server (API on port, console on port+1).""" + for port in range(start_port, start_port + 100): + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: + s1.bind(("", port)) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: + s2.bind(("", port + 1)) + return port + except OSError: + continue + raise RuntimeError("Could not find available port for MinIO") + + +# --------------------------------------------------------------------------- +# Zarr creation +# --------------------------------------------------------------------------- + +def create_correction_zarr( + zarr_path, + raw_crop_shape, + raw_voxel_size, + raw_offset, + annotation_crop_shape, + annotation_voxel_size, + annotation_offset, + dataset_path, + model_name, + output_channels, + raw_dtype="uint8", + create_mask=False, +): + """ + Create a correction zarr with OME-NGFF v0.4 metadata. 
+ + Structure: + crop_id.zarr/ + raw/s0/ (uint8, shape=raw_crop_shape) + annotation/s0/ (uint8, shape=annotation_crop_shape) + mask/s0/ (optional, uint8, shape=annotation_crop_shape) + .zattrs (metadata) + + Returns: + (success: bool, info: str) + """ + try: + def add_ome_ngff_metadata(group, name, voxel_size, translation_offset=None): + """Add OME-NGFF v0.4 metadata.""" + if translation_offset is not None: + physical_translation = [ + float(o * v) for o, v in zip(translation_offset, voxel_size) + ] + else: + physical_translation = [0.0, 0.0, 0.0] + + transforms = [{"type": "scale", "scale": [float(v) for v in voxel_size]}] + + if translation_offset is not None: + transforms.append( + {"type": "translation", "translation": physical_translation} + ) + + group.attrs["multiscales"] = [ + { + "version": "0.4", + "name": name, + "axes": [ + {"name": "z", "type": "space", "unit": "nanometer"}, + {"name": "y", "type": "space", "unit": "nanometer"}, + {"name": "x", "type": "space", "unit": "nanometer"}, + ], + "datasets": [ + {"path": "s0", "coordinateTransformations": transforms} + ], + } + ] + + root = zarr.open(zarr_path, mode="w") + + # Raw group + raw_group = root.create_group("raw") + raw_group.create_dataset( + "s0", + shape=tuple(raw_crop_shape), + chunks=(64, 64, 64), + dtype=raw_dtype, + compressor=zarr.Blosc(cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE), + fill_value=0, + ) + add_ome_ngff_metadata(raw_group, "raw", raw_voxel_size, raw_offset) + + # Annotation group + annotation_group = root.create_group("annotation") + annotation_group.create_dataset( + "s0", + shape=tuple(annotation_crop_shape), + chunks=(64, 64, 64), + dtype="uint8", + compressor=zarr.Blosc(cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE), + fill_value=0, + ) + add_ome_ngff_metadata( + annotation_group, "annotation", annotation_voxel_size, annotation_offset + ) + + # Optional mask group + if create_mask: + mask_group = root.create_group("mask") + mask_group.create_dataset( + "s0", + 
shape=tuple(annotation_crop_shape), + chunks=(64, 64, 64), + dtype="uint8", + compressor=zarr.Blosc( + cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE + ), + fill_value=0, + ) + add_ome_ngff_metadata( + mask_group, "mask", annotation_voxel_size, annotation_offset + ) + + # Root metadata + root.attrs["roi"] = { + "raw_offset": ( + raw_offset.tolist() + if hasattr(raw_offset, "tolist") + else list(raw_offset) + ), + "raw_shape": ( + raw_crop_shape.tolist() + if hasattr(raw_crop_shape, "tolist") + else list(raw_crop_shape) + ), + "annotation_offset": ( + annotation_offset.tolist() + if hasattr(annotation_offset, "tolist") + else list(annotation_offset) + ), + "annotation_shape": ( + annotation_crop_shape.tolist() + if hasattr(annotation_crop_shape, "tolist") + else list(annotation_crop_shape) + ), + } + root.attrs["raw_voxel_size"] = ( + raw_voxel_size.tolist() + if hasattr(raw_voxel_size, "tolist") + else list(raw_voxel_size) + ) + root.attrs["annotation_voxel_size"] = ( + annotation_voxel_size.tolist() + if hasattr(annotation_voxel_size, "tolist") + else list(annotation_voxel_size) + ) + root.attrs["model_name"] = model_name + root.attrs["dataset_path"] = dataset_path + root.attrs["created_at"] = datetime.now().isoformat() + + logger.info(f"Created correction zarr at {zarr_path}") + + return True, zarr_path + + except Exception as e: + logger.error(f"Error creating zarr: {e}") + return False, str(e) + + +def create_annotation_volume_zarr( + zarr_path, + dataset_shape_voxels, + output_voxel_size, + dataset_offset_nm, + chunk_size, + dataset_path, + model_name, + input_size, + input_voxel_size, + claimed_output_voxel_size=None, + claimed_input_voxel_size=None, + input_norm_config=None, +): + """ + Create a sparse annotation volume zarr covering the full dataset extent. + + The volume has chunk_size = model output_size so each chunk maps to one + training sample. Only metadata files are created (no chunk data), so the + zarr is tiny regardless of dataset size. 
+ + Label scheme: 0=unannotated (ignored), 1=background, 2=foreground. + + Args: + output_voxel_size, input_voxel_size: the EFFECTIVE voxel sizes used + for the actual grid alignment (typically the dataset's closest + available scale to the model's claimed voxel size). + claimed_output_voxel_size, claimed_input_voxel_size: optional — + the model's originally-declared voxel sizes, recorded for + provenance. + + Returns: + (success: bool, info: str) + """ + try: + root = zarr.open(zarr_path, mode="w") + + annotation_group = root.create_group("annotation") + annotation_group.create_dataset( + "s0", + shape=tuple(dataset_shape_voxels), + chunks=tuple(chunk_size), + dtype="uint8", + compressor=zarr.Blosc(cname="zstd", clevel=3, shuffle=zarr.Blosc.SHUFFLE), + fill_value=0, + ) + + # OME-NGFF v0.4 metadata with translation for dataset offset + physical_translation = [float(o) for o in dataset_offset_nm] + transforms = [ + {"type": "scale", "scale": [float(v) for v in output_voxel_size]}, + {"type": "translation", "translation": physical_translation}, + ] + annotation_group.attrs["multiscales"] = [ + { + "version": "0.4", + "name": "annotation", + "axes": [ + {"name": "z", "type": "space", "unit": "nanometer"}, + {"name": "y", "type": "space", "unit": "nanometer"}, + {"name": "x", "type": "space", "unit": "nanometer"}, + ], + "datasets": [ + {"path": "s0", "coordinateTransformations": transforms} + ], + } + ] + + # Root metadata + root.attrs["type"] = "annotation_volume" + root.attrs["model_name"] = model_name + root.attrs["dataset_path"] = dataset_path + root.attrs["chunk_size"] = ( + chunk_size.tolist() if hasattr(chunk_size, "tolist") else list(chunk_size) + ) + root.attrs["output_voxel_size"] = ( + output_voxel_size.tolist() + if hasattr(output_voxel_size, "tolist") + else list(output_voxel_size) + ) + root.attrs["input_size"] = ( + input_size.tolist() if hasattr(input_size, "tolist") else list(input_size) + ) + root.attrs["input_voxel_size"] = ( + 
input_voxel_size.tolist() + if hasattr(input_voxel_size, "tolist") + else list(input_voxel_size) + ) + root.attrs["dataset_offset_nm"] = ( + dataset_offset_nm.tolist() + if hasattr(dataset_offset_nm, "tolist") + else list(dataset_offset_nm) + ) + root.attrs["dataset_shape_voxels"] = ( + dataset_shape_voxels.tolist() + if hasattr(dataset_shape_voxels, "tolist") + else list(dataset_shape_voxels) + ) + # Record the model's originally-declared voxel sizes for provenance. + # These may differ from the active output_voxel_size/input_voxel_size + # above when we've snapped to the dataset's closest available scale. + if claimed_output_voxel_size is not None: + root.attrs["claimed_output_voxel_size"] = ( + claimed_output_voxel_size.tolist() + if hasattr(claimed_output_voxel_size, "tolist") + else list(claimed_output_voxel_size) + ) + if claimed_input_voxel_size is not None: + root.attrs["claimed_input_voxel_size"] = ( + claimed_input_voxel_size.tolist() + if hasattr(claimed_input_voxel_size, "tolist") + else list(claimed_input_voxel_size) + ) + # Snapshot of the dashboard's input_norm at volume-creation time. + # Used as the baseline for Resume Existing (the new session inherits + # this normalization). Stored as the raw YAML-style dict so it round- + # trips via json.load / yaml.safe_load without any extra parsing. 
+ if input_norm_config is not None: + root.attrs["input_norm"] = input_norm_config + root.attrs["created_at"] = datetime.now().isoformat() + + logger.info( + f"Created annotation volume zarr at {zarr_path} " + f"(shape={dataset_shape_voxels}, chunks={chunk_size})" + ) + + return True, zarr_path + + except Exception as e: + logger.error(f"Error creating annotation volume zarr: {e}") + return False, str(e) + + +# --------------------------------------------------------------------------- +# MinIO management +# --------------------------------------------------------------------------- + +def ensure_minio_serving(zarr_path, crop_id, output_base_dir=None): + """ + Ensure MinIO is running and upload zarr file. + + Args: + zarr_path: Path to zarr file to upload + crop_id: Unique identifier for the crop + output_base_dir: Base output directory (MinIO will use output_base_dir/.minio) + + Returns: + MinIO URL for the zarr file + """ + if minio_state["process"] is None or minio_state["process"].poll() is not None: + # Determine MinIO storage location + if output_base_dir: + minio_root = Path(output_base_dir) / ".minio" + minio_state["output_base"] = output_base_dir + else: + minio_root = Path("~/.minio-server").expanduser() + minio_state["output_base"] = None + + minio_root.mkdir(parents=True, exist_ok=True) + minio_state["minio_root"] = str(minio_root) + + ip = get_local_ip() + port = find_available_port() + + env = os.environ.copy() + env["MINIO_ROOT_USER"] = "minio" + env["MINIO_ROOT_PASSWORD"] = "minio123" + env["MINIO_API_CORS_ALLOW_ORIGIN"] = "*" + + minio_cmd = [ + "minio", + "server", + str(minio_root), + "--address", + f"{ip}:{port}", + "--console-address", + f"{ip}:{port+1}", + ] + + logger.info(f"Starting MinIO server at {ip}:{port}") + minio_proc = subprocess.Popen( + minio_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + time.sleep(3) + + if minio_proc.poll() is not None: + stderr = minio_proc.stderr.read().decode() if minio_proc.stderr else "" 
+ raise RuntimeError(f"MinIO failed to start: {stderr}") + + minio_state["process"] = minio_proc + minio_state["port"] = port + minio_state["ip"] = ip + + logger.info(f"MinIO started (PID: {minio_proc.pid})") + + # Configure mc client + subprocess.run( + [ + "mc", + "alias", + "set", + "myserver", + f"http://{ip}:{port}", + "minio", + "minio123", + ], + check=True, + capture_output=True, + ) + + # Create bucket if needed + result = subprocess.run( + ["mc", "mb", f"myserver/{minio_state['bucket']}"], + capture_output=True, + text=True, + ) + if result.returncode != 0 and "already" not in result.stderr.lower(): + logger.warning(f"Bucket creation returned: {result.stderr}") + + # Make bucket public + subprocess.run( + ["mc", "anonymous", "set", "public", f"myserver/{minio_state['bucket']}"], + check=True, + capture_output=True, + ) + + # Start periodic sync thread + start_periodic_sync() + + # Upload zarr file + zarr_name = Path(zarr_path).name + target = f"myserver/{minio_state['bucket']}/{zarr_name}" + + logger.info(f"Uploading {zarr_name} to MinIO") + result = subprocess.run( + ["mc", "mirror", "--overwrite", zarr_path, target], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + raise RuntimeError(f"Failed to upload to MinIO: {result.stderr}") + + logger.info(f"Uploaded {zarr_name} to MinIO") + + minio_url = ( + f"http://{minio_state['ip']}:{minio_state['port']}" + f"/{minio_state['bucket']}/{zarr_name}" + ) + return minio_url + + +# --------------------------------------------------------------------------- +# S3 / MinIO sync helpers +# --------------------------------------------------------------------------- + +def _safe_epoch_timestamp(value) -> float: + """Convert LastModified-like values to epoch seconds, best-effort.""" + if value is None: + return 0.0 + if isinstance(value, datetime): + return float(value.timestamp()) + if isinstance(value, (int, float)): + return float(value) + try: + parsed = datetime.fromisoformat(str(value)) + 
return float(parsed.timestamp()) + except Exception: + return 0.0 + + +def _get_sync_worker_count() -> int: + """ + Determine thread count for chunk sync. + + Prefer scheduler-provided CPU counts (e.g., LSF bsub -n), then fall back + to process CPU affinity / system CPU count. + """ + env_candidates = [ + "LSB_DJOB_NUMPROC", + "LSB_MAX_NUM_PROCESSORS", + "NSLOTS", + "SLURM_CPUS_PER_TASK", + "OMP_NUM_THREADS", + ] + for key in env_candidates: + raw = os.environ.get(key) + if not raw: + continue + try: + value = int(raw) + if value > 0: + return value + except ValueError: + continue + + try: + return max(1, len(os.sched_getaffinity(0))) + except Exception: + return max(1, os.cpu_count() or 1) + + +def _copy_chunks_parallel(s3, copy_pairs): + """ + Copy chunk files from MinIO in parallel. + + Args: + s3: s3fs filesystem instance + copy_pairs: list of (src_chunk_path, dst_chunk_path_str) + """ + if not copy_pairs: + return + + available_workers = _get_sync_worker_count() + workers = max(1, min(len(copy_pairs), available_workers)) + + def _copy_one(src_dst): + src_chunk_path, dst_chunk_path = src_dst + s3.get(src_chunk_path, dst_chunk_path) + return src_chunk_path + + with ThreadPoolExecutor(max_workers=workers) as executor: + futures = [executor.submit(_copy_one, pair) for pair in copy_pairs] + for fut in as_completed(futures): + try: + fut.result() + except Exception as e: + logger.debug(f"Error syncing chunk in parallel copy: {e}") + + +def _make_s3_filesystem(): + """Create an s3fs filesystem pointed at the local MinIO instance.""" + return s3fs.S3FileSystem( + anon=False, + key="minio", + secret="minio123", + client_kwargs={ + "endpoint_url": f"http://{minio_state['ip']}:{minio_state['port']}", + "region_name": "us-east-1", + }, + ) + + +def _sync_zarr_group_metadata(s3, src_path, dst_path): + """Sync zarr group structure and metadata from S3 to local disk. + + Ensures destination arrays exist with correct shape/dtype and copies attrs. 
+ """ + src_store = s3fs.S3Map(root=src_path, s3=s3) + src_group = zarr.open_group(store=src_store, mode="r") + + dst_store = zarr.DirectoryStore(str(dst_path)) + dst_group = zarr.open_group(store=dst_store, mode="a") + + for key in src_group.array_keys(): + src_array = src_group[key] + if key in dst_group: + dst_array = dst_group[key] + shape_mismatch = ( + tuple(dst_array.shape) != tuple(src_array.shape) + or tuple(dst_array.chunks) != tuple(src_array.chunks) + or dst_array.dtype != src_array.dtype + ) + else: + shape_mismatch = True + if shape_mismatch: + dst_group.create_dataset( + key, + shape=src_array.shape, + chunks=src_array.chunks, + dtype=src_array.dtype, + fill_value=0, + overwrite=True, + ) + dst_group[key].attrs.update(src_array.attrs) + + dst_group.attrs.update(src_group.attrs) + + +def _diff_and_sync_chunks(s3, s0_path, dst_s0_path, known_chunk_state, force=False): + """Diff remote vs known chunk state and pull changed chunks to local disk. + + Local disk is the source of truth — YAML imports are written locally + first and only later mirrored to MinIO; painted scribbles flow MinIO + → local through this function. We never delete on-disk chunks based + on remote state: an "absent" chunk on MinIO is almost always a + transient (paginated listing truncated, in-flight `mc mirror`, + server restart, network blip), not a real user erase. Painting BG + over a chunk in neuroglancer rewrites the chunk file, it does not + remove it. Treating remote-missing as "user erased it" once cost a + full session of training (3456 chunks wiped from disk after one bad + listing, FG index emptied, loss silently went to 0). + + Returns: + (changed_keys, removed_keys=[], remote_chunk_state) + ``removed_keys`` is always empty; the slot is preserved so + callers' tuple-unpacking keeps working. 
+ """ + try: + chunk_files = s3.ls(s0_path) + except FileNotFoundError: + # Remote bucket has no annotation/s0 yet (just created) — keep + # whatever we have locally and try again next cycle. + return [], [], dict(known_chunk_state) + except Exception as e: + logger.warning(f"_diff_and_sync_chunks: s3.ls({s0_path}) failed: {e}; " + "treating as transient, skipping sync this cycle.") + return [], [], dict(known_chunk_state) + + remote_chunk_state = {} + for chunk_file in chunk_files: + chunk_key = Path(chunk_file).name + if not re.match(r"^\d+\.\d+\.\d+$", chunk_key): + continue + try: + info = s3.info(chunk_file) + remote_chunk_state[chunk_key] = _safe_epoch_timestamp(info.get("LastModified")) + except Exception: + remote_chunk_state[chunk_key] = 0.0 + + if force: + changed_keys = list(remote_chunk_state.keys()) + else: + changed_keys = [k for k, v in remote_chunk_state.items() if known_chunk_state.get(k) != v] + + if not changed_keys: + return [], [], remote_chunk_state + + # Copy changed chunks. We never delete: known_chunk_state may shrink + # if remote drops keys, but the on-disk file stays. + dst_s0_path = Path(dst_s0_path) + dst_s0_path.mkdir(parents=True, exist_ok=True) + copy_pairs = [(f"{s0_path}/{k}", str(dst_s0_path / k)) for k in changed_keys] + _copy_chunks_parallel(s3, copy_pairs) + + return changed_keys, [], remote_chunk_state + + +# --------------------------------------------------------------------------- +# Annotation sync (crop-based) +# --------------------------------------------------------------------------- + +def sync_annotation_from_minio(crop_id, force=False): + """ + Sync a single annotation crop from MinIO to local filesystem. 
+ + Args: + crop_id: Crop ID to sync + force: Force sync even if not modified + + Returns: + bool: True if synced successfully + """ + if not minio_state["ip"] or not minio_state["port"] or not minio_state["output_base"]: + return False + + try: + s3 = _make_s3_filesystem() + + zarr_name = f"{crop_id}.zarr" + src_path = f"{minio_state['bucket']}/{zarr_name}/annotation" + dst_path = Path(minio_state["output_base"]) / zarr_name / "annotation" + + if not s3.exists(src_path): + return False + + known_chunk_state = minio_state["chunk_sync_state"].get(crop_id, {}) + s0_path = f"{src_path}/s0" + changed, removed, remote_chunk_state = _diff_and_sync_chunks( + s3, s0_path, dst_path / "s0", known_chunk_state, force=force + ) + + if not changed and not removed: + return False + + logger.info( + f"Syncing annotation for {crop_id} " + f"(changed={len(changed)}, removed={len(removed)})" + ) + + _sync_zarr_group_metadata(s3, src_path, dst_path) + + minio_state["last_sync"][crop_id] = datetime.now() + minio_state["chunk_sync_state"][crop_id] = remote_chunk_state + + logger.info(f"Successfully synced annotation for {crop_id}") + return True + + except Exception as e: + logger.error(f"Error syncing annotation for {crop_id}: {e}") + import traceback + logger.error(traceback.format_exc()) + return False + + +# --------------------------------------------------------------------------- +# Annotation sync (full-dataset sync) +# --------------------------------------------------------------------------- + +def sync_all_annotations_from_minio(force: bool = True): + """Sync all annotations from MinIO to local disk. + + Returns: + Number of annotations synced, or -1 if MinIO is not initialized. 
+ """ + if not minio_state.get("ip") or not minio_state.get("port"): + logger.info("MinIO not initialized, skipping annotation sync") + return -1 + + logger.info(f"Syncing all annotations from MinIO (force={force})...") + s3 = _make_s3_filesystem() + zarrs = s3.ls(minio_state["bucket"]) + zarr_ids = [Path(c).name.replace(".zarr", "") for c in zarrs if c.endswith(".zarr")] + synced = 0 + for zid in zarr_ids: + try: + zarr_name = f"{zid}.zarr" + attrs_path = f"{minio_state['bucket']}/{zarr_name}/.zattrs" + if s3.exists(attrs_path): + root_attrs = json.loads(s3.cat(attrs_path)) + if root_attrs.get("type") == "annotation_volume": + if sync_annotation_volume_from_minio(zid, force=force): + synced += 1 + continue + except Exception: + pass + if sync_annotation_from_minio(zid, force=force): + synced += 1 + logger.info(f"Synced {synced}/{len(zarr_ids)} annotations") + return synced + + +# --------------------------------------------------------------------------- +# Volume metadata helpers +# --------------------------------------------------------------------------- + +def _get_volume_metadata(volume_id, zarr_path=None): + """ + Get volume metadata from in-memory cache or reconstruct from zarr attrs. + + Used for server restart recovery -- if annotation_volumes dict was lost, + reconstruct metadata from the zarr's stored attributes. 
+ """ + if volume_id in annotation_volumes: + return annotation_volumes[volume_id] + + if zarr_path is None: + return None + + try: + root = zarr.open(zarr_path, mode="r") + attrs = dict(root.attrs) + if attrs.get("type") != "annotation_volume": + return None + + metadata = { + "zarr_path": zarr_path, + "model_name": attrs.get("model_name", ""), + "output_size": attrs.get("chunk_size", [56, 56, 56]), + "input_size": attrs.get("input_size", [178, 178, 178]), + "input_voxel_size": attrs.get("input_voxel_size", [16, 16, 16]), + "output_voxel_size": attrs.get("output_voxel_size", [16, 16, 16]), + "dataset_path": attrs.get("dataset_path", ""), + "dataset_offset_nm": attrs.get("dataset_offset_nm", [0, 0, 0]), + "corrections_dir": str(Path(zarr_path).parent), + "extracted_chunks": set(), + "chunk_sync_state": {}, + } + annotation_volumes[volume_id] = metadata + return metadata + except Exception as e: + logger.error(f"Error reconstructing volume metadata for {volume_id}: {e}") + return None + + +def extract_correction_from_chunk(volume_id, chunk_indices, volume_metadata): + """ + Extract a correction entry from a single annotated chunk in a sparse volume. + + Reads the annotation chunk, extracts raw data with context padding, and + creates a standard correction zarr entry compatible with CorrectionDataset. 
+ + Args: + volume_id: Volume identifier + chunk_indices: Tuple (cz, cy, cx) of chunk indices + volume_metadata: Volume metadata dict + + Returns: + bool: True if correction was created (chunk had annotations) + """ + from cellmap_flow.image_data_interface import ImageDataInterface + from funlib.geometry import Roi, Coordinate + + cz, cy, cx = chunk_indices + chunk_size = np.array(volume_metadata["output_size"]) + output_voxel_size = np.array(volume_metadata["output_voxel_size"]) + input_size = np.array(volume_metadata["input_size"]) + input_voxel_size = np.array(volume_metadata["input_voxel_size"]) + dataset_offset_nm = np.array(volume_metadata["dataset_offset_nm"]) + corrections_dir = volume_metadata["corrections_dir"] + + vol_zarr_path = volume_metadata["zarr_path"] + vol = zarr.open(vol_zarr_path, mode="r") + + z_start = cz * chunk_size[0] + y_start = cy * chunk_size[1] + x_start = cx * chunk_size[2] + + annotation_data = vol["annotation/s0"][ + z_start : z_start + chunk_size[0], + y_start : y_start + chunk_size[1], + x_start : x_start + chunk_size[2], + ] + + # Skip if all zeros (unannotated or erased) + if not np.any(annotation_data): + return False + + # Compute physical position of this chunk's center + chunk_offset_nm = dataset_offset_nm + np.array( + [z_start, y_start, x_start] + ) * output_voxel_size + chunk_center_nm = chunk_offset_nm + (chunk_size * output_voxel_size) / 2 + + # Extract raw data with full context padding + read_shape_nm = input_size * input_voxel_size + raw_roi = Roi( + offset=Coordinate(chunk_center_nm - read_shape_nm / 2), + shape=Coordinate(read_shape_nm), + ) + + logger.info( + f"Extracting raw for chunk ({cz},{cy},{cx}): " + f"ROI offset={raw_roi.offset}, shape={raw_roi.shape}" + ) + + idi = ImageDataInterface( + volume_metadata["dataset_path"], voxel_size=input_voxel_size + ) + raw_data = idi.to_ndarray_ts(raw_roi) + + # Create correction entry + correction_id = f"{volume_id}_chunk_{cz}_{cy}_{cx}" + correction_zarr_path = 
def sync_annotation_volume_from_minio(volume_id, force=False) -> bool:
    """
    Sync one annotation volume from MinIO and extract corrections for changed chunks.

    Steps:
    1. Mirror the ``annotation`` group metadata from the MinIO bucket to
       ``<output_base>/<volume_id>.zarr/annotation`` on local disk.
    2. Diff the remote ``annotation/s0`` chunks against the stored
       ``chunk_sync_state`` and sync whatever changed.
    3. Unless a virtual-sources manifest is present, call
       ``extract_correction_from_chunk`` for every changed chunk.

    Args:
        volume_id: Volume identifier; ``f"{volume_id}.zarr"`` is the zarr name
            looked up both locally and in the bucket.
        force: Forwarded unchanged to ``_diff_and_sync_chunks``.

    Returns:
        bool: True when at least one chunk was reported as changed or removed
        (whether or not a correction was extracted). False when MinIO is not
        configured, the volume has no metadata, the remote annotation group
        does not exist, nothing changed, or an exception was caught.
    """
    # All three settings are needed to locate the bucket and the local mirror.
    if not minio_state["ip"] or not minio_state["port"] or not minio_state["output_base"]:
        logger.warning("MinIO not initialized, skipping volume sync")
        return False

    try:
        zarr_name = f"{volume_id}.zarr"
        local_zarr_path = os.path.join(minio_state["output_base"], zarr_name)
        volume_meta = _get_volume_metadata(volume_id, local_zarr_path)

        if volume_meta is None:
            logger.warning(f"No metadata for volume {volume_id}, skipping")
            return False

        s3 = _make_s3_filesystem()

        bucket = minio_state["bucket"]
        src_annotation_path = f"{bucket}/{zarr_name}/annotation"

        # Nothing has been uploaded for this volume yet.
        if not s3.exists(src_annotation_path):
            return False

        # Sync zarr group metadata
        dst_annotation_path = Path(local_zarr_path) / "annotation"
        dst_annotation_path.mkdir(parents=True, exist_ok=True)
        _sync_zarr_group_metadata(s3, src_annotation_path, dst_annotation_path)

        # Diff and sync chunks
        s0_path = f"{bucket}/{zarr_name}/annotation/s0"
        known_chunk_state = volume_meta.get("chunk_sync_state", {})
        changed_chunk_keys, removed_chunk_keys, remote_chunk_state = _diff_and_sync_chunks(
            s3, s0_path, dst_annotation_path / "s0", known_chunk_state, force=force
        )

        # No differences: only record the sync time; the stored chunk state is
        # left untouched on this path.
        if not changed_chunk_keys and not removed_chunk_keys:
            minio_state["last_sync"][volume_id] = datetime.now()
            return False

        logger.info(
            f"Synced {len(changed_chunk_keys)} changed chunks for volume {volume_id}"
        )

        # Extract corrections for changed chunks. Skip entirely when a
        # virtual-sources manifest is present: the trainer reads the volume
        # zarr directly via VirtualPatchDataset and never touches per-chunk
        # extracts, so this loop just slowly fills disk with thousands of
        # 178**3 raw cubes that nothing reads. (See
        # cellmap_flow/finetune/virtual_dataset.py for the manifest format.)
        from cellmap_flow.finetune.virtual_dataset import read_manifest

        # Fall back to the directory holding the volume zarr when the metadata
        # does not name a corrections directory.
        corrections_dir = volume_meta.get("corrections_dir") or os.path.dirname(
            local_zarr_path
        )
        manifest = read_manifest(corrections_dir) if corrections_dir else None

        extracted_chunks = volume_meta.get("extracted_chunks", set())
        # Chunk keys are parsed as dot-separated integers, e.g. "3.0.12" -> (3, 0, 12).
        changed_chunk_indices = [
            tuple(map(int, k.split(".")))
            for k in changed_chunk_keys
        ]
        created_any = False

        if manifest is not None:
            logger.debug(
                f"Volume {volume_id}: skipping per-chunk extract (manifest present); "
                f"{len(changed_chunk_indices)} changed chunks ignored."
            )
        else:
            for chunk_idx in changed_chunk_indices:
                try:
                    created = extract_correction_from_chunk(
                        volume_id, chunk_idx, volume_meta
                    )
                    if created:
                        extracted_chunks.add(chunk_idx)
                        created_any = True
                    else:
                        # The extractor reported no correction for this chunk,
                        # so stop counting it as extracted.
                        extracted_chunks.discard(chunk_idx)
                except Exception as e:
                    # One failing chunk must not abort the remaining chunks.
                    logger.error(f"Error extracting correction for chunk {chunk_idx}: {e}")
                    import traceback
                    logger.error(traceback.format_exc())

        # Update tracked state
        # NOTE(review): the remote state is stored even for chunks whose
        # extraction raised above, so those chunks are not retried until they
        # change again - confirm this is intended.
        volume_meta["extracted_chunks"] = extracted_chunks
        volume_meta["chunk_sync_state"] = remote_chunk_state
        minio_state["last_sync"][volume_id] = datetime.now()

        # Always true here: the "nothing changed" case returned earlier.
        if created_any or changed_chunk_keys or removed_chunk_keys:
            logger.info(
                f"Volume {volume_id}: {len(extracted_chunks)} total chunks extracted"
            )

        return bool(created_any or changed_chunk_keys or removed_chunk_keys)

    except Exception as e:
        # Best-effort sync: log the full traceback and report "nothing synced".
        logger.error(f"Error syncing annotation volume {volume_id}: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return False
def start_periodic_sync():
    """Launch the background annotation-sync thread unless one is already alive.

    The thread runs ``periodic_sync_annotations`` as a daemon and is remembered
    in ``minio_state["sync_thread"]`` so later calls can tell it is running.
    """
    already_running = (
        minio_state["sync_thread"] is not None
        and minio_state["sync_thread"].is_alive()
    )
    if already_running:
        return

    worker = threading.Thread(target=periodic_sync_annotations, daemon=True)
    worker.start()
    minio_state["sync_thread"] = worker
    logger.info("Started periodic annotation sync thread")
get_blockwise_tasks_dir +from cellmap_flow.globals import get_blockwise_tasks_dir logger = logging.getLogger(__name__) diff --git a/cellmap_flow/dashboard/routes/finetune/__init__.py b/cellmap_flow/dashboard/routes/finetune/__init__.py new file mode 100644 index 0000000..e39c87c --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune/__init__.py @@ -0,0 +1,6 @@ +from cellmap_flow.dashboard.routes.finetune.annotation import ( + refresh_annotated_regions_layer, +) +from cellmap_flow.dashboard.routes.finetune.routes import finetune_bp + +__all__ = ["finetune_bp", "refresh_annotated_regions_layer"] diff --git a/cellmap_flow/dashboard/routes/finetune/annotation.py b/cellmap_flow/dashboard/routes/finetune/annotation.py new file mode 100644 index 0000000..2854d16 --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune/annotation.py @@ -0,0 +1,31 @@ +from cellmap_flow.dashboard.routes.finetune.annotation_core import ( + create_annotation_crop_response, + create_annotation_volume_response, + get_finetune_models_response, + get_user_prefs_response, + get_view_center_response, + set_user_prefs_response, +) +from cellmap_flow.dashboard.routes.finetune.annotation_sessions import ( + list_existing_sessions_response, + load_existing_volume_response, +) +from cellmap_flow.dashboard.routes.finetune.overlay import ( + add_crop_to_viewer_response, + refresh_annotated_regions_layer, + sync_annotations_manually_response, +) + +__all__ = [ + "add_crop_to_viewer_response", + "create_annotation_crop_response", + "create_annotation_volume_response", + "get_finetune_models_response", + "get_user_prefs_response", + "get_view_center_response", + "list_existing_sessions_response", + "load_existing_volume_response", + "refresh_annotated_regions_layer", + "set_user_prefs_response", + "sync_annotations_manually_response", +] diff --git a/cellmap_flow/dashboard/routes/finetune/annotation_core.py b/cellmap_flow/dashboard/routes/finetune/annotation_core.py new file mode 100644 index 
def _get_selected_model_config(model_name):
    """Look up the configuration of the model chosen in the dashboard.

    Returns:
        tuple: ``(model_config, None)`` on success, otherwise
        ``(None, (json_response, status))`` where the status is 400 when no
        models are loaded at all and 404 when ``model_name`` is not among them.
    """
    loaded_models = getattr(g, "models_config", None)
    if not loaded_models:
        no_models = jsonify({"success": False, "error": "No models loaded"})
        return None, (no_models, 400)

    selected = find_model_config(model_name)
    if selected is not None:
        return selected, None

    not_found = jsonify({"success": False, "error": f"Model {model_name} not found"})
    return None, (not_found, 404)


def _register_annotation_volume(volume_id, **volume_data):
    """Record a volume in ``g.annotation_volumes`` with empty sync bookkeeping.

    The stored entry is ``volume_data`` plus a fresh ``extracted_chunks`` set
    and an empty ``chunk_sync_state`` dict; the registry itself is created on
    first use.
    """
    try:
        registry = g.annotation_volumes
    except AttributeError:
        registry = {}
        g.annotation_volumes = registry

    entry = dict(volume_data)
    entry["extracted_chunks"] = set()
    entry["chunk_sync_state"] = {}
    registry[volume_id] = entry
def create_annotation_crop_response(data):
    """Create one annotation crop centred on the current viewer position.

    The crop zarr holds a raw array of the model's read shape and an empty
    annotation array of the model's write shape, both centred on the viewer
    position, and is then exposed through MinIO for painting.

    Args:
        data: Request payload; ``model_name`` selects the model and
            ``output_path`` (optional) selects where the session is stored.

    Returns:
        A Flask JSON response, or a ``(response, status)`` tuple on failure:
        400 for an uninitialised viewer, missing models or a missing dataset
        path; 404 for an unknown model; 500 for anything else.
    """
    try:
        from cellmap_flow.image_data_interface import ImageDataInterface
        from funlib.geometry import Coordinate, Roi

        model_name = data.get("model_name")
        output_path = data.get("output_path")

        position, viewer_scales_nm = viewer_position_and_scales()
        view_center = np.array(position)

        model_config, error_response = _get_selected_model_config(model_name)
        if error_response is not None:
            return error_response

        # Fail early with a clear 400, exactly like the annotation-volume
        # endpoint, instead of opening a placeholder path and surfacing an
        # opaque 500 after session directories were already created.
        dataset_path = getattr(g, "dataset_path", None)
        if not dataset_path:
            return jsonify({"success": False, "error": "No dataset path configured"}), 400

        config = model_config.config
        read_shape = np.array(config.read_shape)
        write_shape = np.array(config.write_shape)
        input_voxel_size = np.array(config.input_voxel_size)
        output_voxel_size = np.array(config.output_voxel_size)
        output_channels = config.output_channels

        if viewer_scales_nm is not None:
            # Viewer positions are in voxels of the viewer's coordinate space.
            view_center_nm = view_center * np.array(viewer_scales_nm)
        else:
            view_center_nm = view_center
            logger.warning("No viewer scales provided, assuming view center is already in nm")

        # Shapes and offsets in voxels; both crops share the same centre.
        raw_crop_shape_voxels = (read_shape / input_voxel_size).astype(int)
        annotation_crop_shape_voxels = (write_shape / output_voxel_size).astype(int)
        raw_crop_offset_voxels = (
            (view_center_nm - read_shape / 2) / input_voxel_size
        ).astype(int)
        annotation_crop_offset_voxels = (
            (view_center_nm - write_shape / 2) / output_voxel_size
        ).astype(int)

        crop_id = f"{uuid.uuid4().hex[:8]}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
        _, corrections_dir = ensure_corrections_storage(output_path)
        zarr_path = os.path.join(corrections_dir, f"{crop_id}.zarr")

        idi = ImageDataInterface(dataset_path, voxel_size=input_voxel_size)
        raw_dtype = str(idi.ts.dtype)

        success, zarr_info = create_correction_zarr(
            zarr_path=zarr_path,
            raw_crop_shape=raw_crop_shape_voxels,
            raw_voxel_size=input_voxel_size,
            raw_offset=raw_crop_offset_voxels,
            annotation_crop_shape=annotation_crop_shape_voxels,
            annotation_voxel_size=output_voxel_size,
            annotation_offset=annotation_crop_offset_voxels,
            dataset_path=dataset_path,
            model_name=model_name,
            output_channels=output_channels,
            raw_dtype=raw_dtype,
            create_mask=False,
        )
        if not success:
            # On failure the second element carries the error description.
            return jsonify({"success": False, "error": zarr_info}), 500

        # Fill the raw array with the context the model reads around the centre.
        roi = Roi(
            offset=Coordinate(view_center_nm - read_shape / 2),
            shape=Coordinate(read_shape),
        )
        raw_zarr = zarr.open(zarr_path, mode="r+")
        raw_zarr["raw/s0"][:] = idi.to_ndarray_ts(roi)

        minio_url = ensure_minio_serving(zarr_path, crop_id, output_base_dir=corrections_dir)
        return jsonify(
            {
                "success": True,
                "crop_id": crop_id,
                "zarr_path": zarr_path,
                "minio_url": minio_url,
                "neuroglancer_url": f"{minio_url}/annotation",
                "metadata": {
                    "center_position_nm": view_center_nm.tolist(),
                    "raw_crop_offset": raw_crop_offset_voxels.tolist(),
                    "raw_crop_shape": raw_crop_shape_voxels.tolist(),
                    "raw_voxel_size": input_voxel_size.tolist(),
                    "annotation_crop_offset": annotation_crop_offset_voxels.tolist(),
                    "annotation_crop_shape": annotation_crop_shape_voxels.tolist(),
                    "annotation_voxel_size": output_voxel_size.tolist(),
                },
            }
        )
    except ValueError as e:
        # Raised e.g. by viewer_position_and_scales when no viewer exists.
        return jsonify({"success": False, "error": str(e)}), 400
    except Exception as e:
        logger.error(f"Error creating annotation crop: {e}", exc_info=True)
        return jsonify({"success": False, "error": str(e)}), 500
+ except Exception: + effective_output_voxel_size = claimed_output_voxel_size + effective_input_voxel_size = claimed_input_voxel_size + + idi = ImageDataInterface(dataset_path, voxel_size=effective_output_voxel_size) + dataset_roi = idi.roi + dataset_offset_nm = np.array(dataset_roi.offset) + dataset_shape_nm = np.array(dataset_roi.shape) + dataset_shape_voxels = (dataset_shape_nm / effective_output_voxel_size).astype(int) + dataset_shape_voxels = np.ceil(dataset_shape_voxels / output_size).astype(int) * output_size + + volume_id = f"vol-{uuid.uuid4().hex[:8]}-{datetime.now().strftime('%Y%m%d-%H%M%S')}" + _, corrections_dir = ensure_corrections_storage(output_path) + zarr_path = os.path.join(corrections_dir, f"{volume_id}.zarr") + + success, zarr_info = create_annotation_volume_zarr( + zarr_path=zarr_path, + dataset_shape_voxels=dataset_shape_voxels, + output_voxel_size=effective_output_voxel_size, + dataset_offset_nm=dataset_offset_nm, + chunk_size=output_size, + dataset_path=dataset_path, + model_name=model_name, + input_size=input_size, + input_voxel_size=effective_input_voxel_size, + claimed_output_voxel_size=claimed_output_voxel_size, + claimed_input_voxel_size=claimed_input_voxel_size, + input_norm_config=current_input_norm_config(), + ) + if not success: + return jsonify({"success": False, "error": zarr_info}), 500 + + minio_url = ensure_minio_serving(zarr_path, volume_id, output_base_dir=corrections_dir) + _register_annotation_volume( + volume_id, + zarr_path=zarr_path, + model_name=model_name, + output_size=output_size.tolist(), + input_size=input_size.tolist(), + input_voxel_size=effective_input_voxel_size.tolist(), + output_voxel_size=effective_output_voxel_size.tolist(), + claimed_input_voxel_size=claimed_input_voxel_size.tolist(), + claimed_output_voxel_size=claimed_output_voxel_size.tolist(), + dataset_path=dataset_path, + dataset_offset_nm=dataset_offset_nm.tolist(), + corrections_dir=corrections_dir, + ) + refresh_annotated_regions_layer() + + 
def get_user_prefs_response():
    """Return the stored user preferences wrapped in a success payload."""
    payload = {"success": True, "prefs": load_user_prefs()}
    return jsonify(payload)


def set_user_prefs_response(data):
    """Merge the non-``None`` entries of ``data`` into the stored preferences.

    Returns the merged preferences on success, or a 500 response carrying the
    error text when anything raises.
    """
    try:
        merged = load_user_prefs()
        # Entries sent as None mean "leave the stored value alone".
        supplied = {}
        for name, setting in data.items():
            if setting is not None:
                supplied[name] = setting
        merged.update(supplied)
        save_user_prefs(merged)
        return jsonify({"success": True, "prefs": merged})
    except Exception as exc:
        return jsonify({"success": False, "error": str(exc)}), 500
# Progress of in-flight "resume" operations, keyed by the client-supplied
# load_id, so the dashboard can poll while the long copy is running.
_RESUME_PROGRESS: dict = {}
_RESUME_PROGRESS_LOCK = threading.Lock()
_RESUME_PROGRESS_TTL_SECONDS = 300


def _set_resume_progress(load_id, **fields):
    """Merge ``fields`` into the progress entry for ``load_id``.

    A falsy ``load_id`` is ignored. Every call also drops entries that have not
    been touched for more than ``_RESUME_PROGRESS_TTL_SECONDS``.
    """
    if not load_id:
        return

    with _RESUME_PROGRESS_LOCK:
        fresh_entry = {"created_at": time.time()}
        record = _RESUME_PROGRESS.setdefault(load_id, fresh_entry)
        record.update(fields)
        record["updated_at"] = time.time()

        # Opportunistic clean-up so abandoned operations do not accumulate.
        reference = time.time()
        expired = []
        for key, info in _RESUME_PROGRESS.items():
            last_touched = info.get("updated_at", info.get("created_at", reference))
            if reference - last_touched > _RESUME_PROGRESS_TTL_SECONDS:
                expired.append(key)
        for key in expired:
            _RESUME_PROGRESS.pop(key, None)


def get_resume_progress_response(load_id):
    """Return a copy of the progress entry for ``load_id`` as JSON.

    Responds with 400 when ``load_id`` is missing and 404 when it is unknown.
    """
    if not load_id:
        return jsonify({"success": False, "error": "Missing 'load_id' query param"}), 400

    with _RESUME_PROGRESS_LOCK:
        stored = _RESUME_PROGRESS.get(load_id)
        # Copy under the lock so the response cannot see a half-applied update.
        state_copy = dict(stored) if stored else None

    if state_copy is None:
        return jsonify({"success": False, "error": f"Unknown load_id {load_id}"}), 404
    return jsonify({"success": True, "progress": state_copy})
def _register_annotation_volume(volume_id, **volume_data):
    """Store ``volume_data`` for ``volume_id`` in ``g.annotation_volumes``.

    Each stored entry starts with an empty ``extracted_chunks`` set and an
    empty ``chunk_sync_state`` dict. The registry is created on first use.
    """
    if not hasattr(g, "annotation_volumes"):
        setattr(g, "annotation_volumes", {})

    record = dict(volume_data, extracted_chunks=set(), chunk_sync_state={})
    g.annotation_volumes[volume_id] = record
def load_existing_volume_response(data):
    """Resume annotating a volume from an earlier session.

    Copies the annotation volume zarr (and, when possible, the source
    session's ``.minio`` directory) into a fresh session under
    ``output_path``, records where it came from in ``loaded_from.json``,
    serves the copy through MinIO and registers it for syncing.

    Args:
        data: Request payload with ``source_session_path`` and ``output_path``
            (both required) and an optional ``load_id`` used to publish
            progress through ``_set_resume_progress``.

    Returns:
        A Flask JSON response, or a ``(response, status)`` tuple on failure:
        400 for missing arguments, 404 when the source has no corrections or
        no annotation volume, 500 for anything else.
    """
    # Bound before the try block so the except handler can always use it,
    # even when the very first statements raise (e.g. ``data`` is None).
    load_id = None

    def _fail(message, status):
        # Publish the failure as well, otherwise a polling client keeps seeing
        # phase="starting". _set_resume_progress ignores a falsy load_id.
        _set_resume_progress(load_id, phase="error", done=True, error=message)
        return jsonify({"success": False, "error": message}), status

    try:
        from cellmap_flow.dashboard.finetune_utils import minio_state

        source_session_path = data.get("source_session_path")
        output_path = data.get("output_path")
        load_id = data.get("load_id")
        _set_resume_progress(
            load_id,
            phase="starting",
            done=False,
            files_done=0,
            files_total=0,
            parent_done=0,
            parent_total=0,
        )
        if not source_session_path or not output_path:
            return _fail("source_session_path and output_path required", 400)

        source_session_path = os.path.expanduser(source_session_path)
        source_corrections = os.path.join(source_session_path, "corrections")
        if not os.path.isdir(source_corrections):
            return _fail(f"No corrections found in {source_session_path}", 404)

        # List the directory once. Sorting makes the chosen volume
        # deterministic; os.listdir returns entries in arbitrary order.
        all_zarr_entries = sorted(
            item for item in os.listdir(source_corrections) if item.endswith(".zarr")
        )
        # Only volume zarrs are copied: the trainer reads the volume zarr
        # directly via VirtualPatchDataset, so the per-chunk _chunk_*.zarr
        # extracts from the legacy materialize pipeline are dead weight (and
        # on big sessions can be thousands of files).
        zarr_entries = [item for item in all_zarr_entries if "_chunk_" not in item]
        if not zarr_entries:
            return _fail(f"No annotation volume found in {source_corrections}", 404)

        volume_dir = zarr_entries[0]
        # Strip only the trailing extension; str.replace would also remove
        # ".zarr" occurring inside the name.
        volume_id = volume_dir[: -len(".zarr")]
        new_session_path, new_corrections = ensure_corrections_storage(output_path)

        skipped_chunk_extracts = len(all_zarr_entries) - len(zarr_entries)
        if skipped_chunk_extracts:
            logger.info(
                f"Resume: skipping {skipped_chunk_extracts} legacy "
                f"_chunk_*.zarr extracts; trainer will read the volume "
                "zarr directly via the manifest."
            )

        copied = []
        for idx, item in enumerate(zarr_entries):
            src = os.path.join(source_corrections, item)
            dst = os.path.join(new_corrections, item)
            if os.path.exists(dst):
                logger.info(f"Skipping {item} (already exists in target)")
                continue
            _set_resume_progress(
                load_id,
                phase="copying",
                current=item,
                files_done=0,
                files_total=0,
                parent_done=idx,
                parent_total=len(zarr_entries),
                done=False,
            )
            _copytree_with_progress(
                src, dst, load_id, label=item,
                parent_done=idx, parent_total=len(zarr_entries),
            )
            copied.append(item)

        source_minio = os.path.join(source_corrections, ".minio")
        new_minio = os.path.join(new_corrections, ".minio")
        copied_minio = False
        if os.path.isdir(source_minio):
            minio_process = minio_state.get("process")
            if minio_process is not None and minio_process.poll() is None:
                # A live MinIO process cannot be pointed at the copied store.
                logger.warning(
                    "MinIO already running with a different output_base; cannot rebind. "
                    "Falling back to mc mirror upload - painted data may be incomplete "
                    "if the source had unsynced chunks."
                )
            elif not os.path.exists(new_minio):
                _set_resume_progress(
                    load_id,
                    phase="copying_minio",
                    current=".minio",
                    files_done=0,
                    files_total=0,
                    parent_done=len(zarr_entries),
                    parent_total=len(zarr_entries) + 1,
                    done=False,
                )
                _copytree_with_progress(
                    source_minio, new_minio, load_id, label=".minio",
                    parent_done=len(zarr_entries), parent_total=len(zarr_entries) + 1,
                )
                copied_minio = True

        _set_resume_progress(
            load_id,
            phase="mirroring_minio",
            current=volume_dir,
            done=False,
        )

        # Record where this session came from.
        lineage_file = os.path.join(new_session_path, "loaded_from.json")
        with open(lineage_file, "w") as f:
            json.dump(
                {
                    "source_session_path": source_session_path,
                    "loaded_at": datetime.now().isoformat(),
                    "copied_files": copied,
                },
                f,
                indent=2,
            )

        new_volume_path = os.path.join(new_corrections, volume_dir)
        zattrs_file = os.path.join(new_volume_path, ".zattrs")
        volume_meta = {}
        if os.path.exists(zattrs_file):
            with open(zattrs_file) as f:
                volume_meta = json.load(f)

        # Count the non-hidden entries of annotation/s0 for the response.
        s0_dir = os.path.join(new_volume_path, "annotation", "s0")
        s0_count = 0
        if os.path.isdir(s0_dir):
            s0_count = sum(1 for entry in os.listdir(s0_dir) if not entry.startswith("."))

        minio_url = ensure_minio_serving(new_volume_path, volume_id, output_base_dir=new_corrections)
        _register_annotation_volume(
            volume_id,
            zarr_path=new_volume_path,
            model_name=volume_meta.get("model_name"),
            output_size=volume_meta.get("chunk_size"),
            input_size=volume_meta.get("input_size"),
            input_voxel_size=volume_meta.get("input_voxel_size"),
            output_voxel_size=volume_meta.get("output_voxel_size"),
            dataset_path=volume_meta.get("dataset_path"),
            dataset_offset_nm=volume_meta.get("dataset_offset_nm"),
            corrections_dir=new_corrections,
        )
        refresh_annotated_regions_layer()

        _set_resume_progress(
            load_id,
            phase="done",
            done=True,
            volume_id=volume_id,
            copied_count=len(copied),
            painted_chunk_count=s0_count,
        )

        return jsonify(
            {
                "success": True,
                "volume_id": volume_id,
                "new_session_path": new_session_path,
                "zarr_path": new_volume_path,
                "minio_url": minio_url,
                "neuroglancer_url": f"{minio_url}/annotation",
                "copied_count": len(copied),
                "copied_minio": copied_minio,
                "painted_chunk_count": s0_count,
                "skipped_chunk_extracts": skipped_chunk_extracts,
                "metadata": volume_meta,
            }
        )
    except Exception as e:
        _set_resume_progress(load_id, phase="error", done=True, error=str(e))
        logger.error(f"Error loading existing volume: {e}", exc_info=True)
        return jsonify({"success": False, "error": str(e)}), 500
def viewer_position_and_scales():
    """Read the current position and per-axis scales from the viewer state.

    Returns:
        tuple: ``(position, scales_nm)``. ``position`` is converted to a plain
        list when it supports ``tolist`` or iteration. ``scales_nm`` is
        ``None`` when the viewer state exposes no dimension scales; scales
        given in metres are converted to nanometres and unknown units are
        assumed to be nanometres already.

    Raises:
        ValueError: If no viewer has been initialised.
    """
    if not hasattr(g, "viewer") or g.viewer is None:
        raise ValueError("Viewer not initialized")

    with g.viewer.txn() as s:
        position = s.position
        dimensions = s.dimensions
        scales_nm = None

        if dimensions and hasattr(dimensions, "scales"):
            scales_nm = list(dimensions.scales)
            if hasattr(dimensions, "units"):
                units = dimensions.units
                # A single unit string applies to every axis.
                if isinstance(units, str):
                    units = [units] * len(scales_nm)
                converted_scales = []
                for scale, unit in zip(scales_nm, units):
                    if unit == "m":
                        converted_scales.append(scale * 1e9)
                    elif unit == "nm":
                        converted_scales.append(scale)
                    else:
                        logger.warning(f"Unknown unit: {unit}, assuming nm")
                        converted_scales.append(scale)
                scales_nm = converted_scales

        if hasattr(position, "tolist"):
            position = position.tolist()
        elif hasattr(position, "__iter__"):
            position = list(position)

        return position, scales_nm


def ensure_corrections_storage(output_path):
    """Create the corrections directory and make sure it is a zarr group.

    Args:
        output_path: Base output directory. When falsy, the per-user default
            ``~/.cellmap_flow/corrections`` is used and no session is created.

    Returns:
        tuple: ``(session_path, corrections_dir)``; ``session_path`` is
        ``None`` for the per-user default location.
    """
    if output_path:
        session_path = get_or_create_session_path(output_path)
        corrections_dir = os.path.join(session_path, "corrections")
    else:
        session_path = None
        corrections_dir = os.path.expanduser("~/.cellmap_flow/corrections")

    os.makedirs(corrections_dir, exist_ok=True)
    zarr.open_group(corrections_dir, mode="a")
    return session_path, corrections_dir


def load_user_prefs():
    """Load the saved user preferences.

    Returns:
        dict: The stored preferences, or an empty dict when the file is
        missing, unreadable, not valid JSON, or does not contain a JSON object.
    """
    try:
        with open(USER_PREFS_FILE) as f:
            prefs = json.load(f)
    except FileNotFoundError:
        # No preferences saved yet: the normal first-run case.
        return {}
    except (OSError, ValueError) as e:
        # Still best-effort, but leave a trace instead of silently discarding
        # a corrupt or unreadable file (json.JSONDecodeError is a ValueError).
        logger.warning(f"Could not load user prefs: {e}")
        return {}

    if not isinstance(prefs, dict):
        # Callers merge into the result with dict.update.
        logger.warning(f"Ignoring user prefs in {USER_PREFS_FILE}: not a JSON object")
        return {}
    return prefs


def save_user_prefs(prefs):
    """Write ``prefs`` to the user preferences file as indented JSON.

    Failures are logged as warnings and not raised.
    """
    try:
        os.makedirs(os.path.dirname(USER_PREFS_FILE), exist_ok=True)
        with open(USER_PREFS_FILE, "w") as f:
            json.dump(prefs, f, indent=2)
    except Exception as e:
        logger.warning(f"Could not save user prefs: {e}")
base_corrections_path.name == "corrections" and base_corrections_path.exists(): + return base_corrections_path.parent, base_corrections_path + + session_path = Path(get_or_create_session_path(str(base_corrections_path))) + return session_path, session_path / "corrections" + + +def detect_sparse_annotations(corrections_path): + try: + for path in corrections_path.iterdir(): + if path.suffix == ".zarr" and (path / ".zattrs").exists(): + attrs = json.loads((path / ".zattrs").read_text()) + if attrs.get("source") == "sparse_volume": + return True + except Exception as e: + logger.warning(f"Error checking for sparse annotations: {e}") + return False + + +def autodetect_output_type(model_config, output_type, offsets): + from cellmap_flow.finetune.finetune_cli import _read_offsets_from_script + + resolved_output_type = output_type + resolved_offsets = offsets + + if resolved_output_type is None: + if hasattr(model_config, "script_path"): + script_offsets = _read_offsets_from_script(model_config.script_path) + if script_offsets is not None: + resolved_output_type = "affinities" + resolved_offsets = json.dumps(script_offsets) + logger.info( + f"Auto-detected output_type='affinities' with " + f"{len(script_offsets)} offsets from model script" + ) + + if resolved_output_type is None: + channels = None + try: + if hasattr(model_config, "_load_metadata"): + meta = model_config._load_metadata() + channels = meta.get("channels_names") + elif hasattr(model_config, "_config") and hasattr(model_config._config, "channels"): + channels = model_config._config.channels + except Exception: + pass + + if channels and any("_aff" in channel for channel in channels): + resolved_output_type = "affinities" + n_aff = sum(1 for channel in channels if "_aff" in channel) + default_offsets = [ + [1 if axis == index else 0 for axis in range(3)] + for index in range(min(n_aff, 3)) + ] + resolved_offsets = json.dumps(default_offsets) + logger.info( + f"Auto-detected output_type='affinities' from " + 
f"channel names: {channels}, offsets: {default_offsets}" + ) + + if resolved_output_type is None: + resolved_output_type = "binary" + + if resolved_output_type == "affinities" and resolved_offsets is None: + if hasattr(model_config, "script_path"): + resolved_offsets = _read_offsets_from_script(model_config.script_path) + if resolved_offsets is not None: + logger.info(f"Auto-detected {len(resolved_offsets)} offsets from model script") + resolved_offsets = json.dumps(resolved_offsets) + if resolved_offsets is None: + raise ValueError( + "output_type='affinities' requires offsets. " + "Define 'offsets' in the model script or pass them in the request." + ) + elif isinstance(resolved_offsets, list): + resolved_offsets = json.dumps(resolved_offsets) + + return resolved_output_type, resolved_offsets + + +def build_restart_params(data): + updated_params = {} + for key in RESTART_PASSTHROUGH_KEYS: + if key in data and data[key] is not None: + updated_params[key] = data[key] + + if "distillation_scope" in data and data["distillation_scope"] is not None: + scope = str(data["distillation_scope"]).lower() + if scope in {"all", "unlabeled"}: + updated_params["distillation_all_voxels"] = scope == "all" + else: + logger.warning(f"Ignoring invalid distillation_scope: {data['distillation_scope']}") + + return updated_params + + +def get_lsf_job_id(finetune_job): + if finetune_job.lsf_job: + if hasattr(finetune_job.lsf_job, "job_id"): + return finetune_job.lsf_job.job_id + if hasattr(finetune_job.lsf_job, "process"): + return f"PID:{finetune_job.lsf_job.process.pid}" + return None diff --git a/cellmap_flow/dashboard/routes/finetune/overlay.py b/cellmap_flow/dashboard/routes/finetune/overlay.py new file mode 100644 index 0000000..19f5b2c --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune/overlay.py @@ -0,0 +1,296 @@ +import json +import logging +import os +import re + +import neuroglancer +import numpy as np +from flask import jsonify + +from 
cellmap_flow.dashboard.finetune_utils import ( + sync_all_annotations_from_minio, + sync_annotation_from_minio, +) +from cellmap_flow.globals import g + +logger = logging.getLogger(__name__) + +_CHUNK_KEY_RE = re.compile(r"^\d+\.\d+\.\d+$") + + +def _chunk_outside_all_bboxes( + chunk_lo_voxels: np.ndarray, + chunk_hi_voxels: np.ndarray, + bbox_offsets: np.ndarray, + bbox_ends: np.ndarray, +) -> bool: + """Return True if the chunk is NOT fully contained in any + ``imported_crops`` bbox -- i.e. it represents painted-scribble + work that the per-import yellow boxes don't already cover. + + YAML imports write chunk-aligned slabs, so imported chunks land + fully inside an import bbox; painted-only chunks land fully outside. + A mixed chunk (rare; user paints over an import edge) reads as + "outside" by this rule, which is what we want -- it has painted + work the existing yellow box may not visually cue. + """ + if bbox_offsets.shape[0] == 0: + return True + fully_inside = np.all( + (chunk_lo_voxels >= bbox_offsets) & (chunk_hi_voxels <= bbox_ends), + axis=1, + ) + return not bool(fully_inside.any()) + + +def refresh_annotated_regions_layer(corrections_path=None): + if not hasattr(g, "viewer") or g.viewer is None: + return 0 + + scan_dirs = [] + if corrections_path: + scan_dirs.append(corrections_path) + else: + for volume in (getattr(g, "annotation_volumes", {}) or {}).values(): + corrections_dir = volume.get("corrections_dir") + if corrections_dir and corrections_dir not in scan_dirs: + scan_dirs.append(corrections_dir) + # Also scan corrections dirs from active output sessions so + # YAML-loaded crops show up even when no annotation_volume + # has been registered for the session. 
+ for session_path in (getattr(g, "output_sessions", {}) or {}).values(): + session_corrections = os.path.join(session_path, "corrections") + if session_corrections not in scan_dirs and os.path.isdir(session_corrections): + scan_dirs.append(session_corrections) + if not scan_dirs: + return 0 + + boxes = [] + for corrections_dir in scan_dirs: + if not os.path.isdir(corrections_dir): + continue + for entry in sorted(os.listdir(corrections_dir)): + # Per-painted-chunk small boxes (the existing behavior). + if "_chunk_" in entry and entry.endswith(".zarr"): + zattrs_file = os.path.join(corrections_dir, entry, ".zattrs") + if not os.path.exists(zattrs_file): + continue + try: + with open(zattrs_file) as f: + meta = json.load(f) + roi = meta.get("roi", {}) + offset_vox = roi.get("annotation_offset") + shape_vox = roi.get("annotation_shape") + voxel = meta.get("annotation_voxel_size") + if not (offset_vox and shape_vox and voxel): + continue + voxel_arr = np.array(voxel, dtype=np.float64) + lo = np.array(offset_vox, dtype=np.float64) * voxel_arr + hi = lo + np.array(shape_vox, dtype=np.float64) * voxel_arr + boxes.append({"label": entry, "lo": lo.tolist(), "hi": hi.tolist()}) + except Exception as e: + logger.warning(f"Could not read chunk metadata for {entry}: {e}") + continue + + # Per-imported-YAML-crop large boxes (one per crop, read from the + # annotation_volume.zarr's root attrs that the YAML loader writes) + # plus per-painted-chunk small boxes for any populated chunk that + # isn't already covered by an import bbox. 
+ if entry.endswith(".zarr"): + vol_attrs_file = os.path.join(corrections_dir, entry, ".zattrs") + if not os.path.exists(vol_attrs_file): + continue + try: + with open(vol_attrs_file) as f: + vol_meta = json.load(f) + if vol_meta.get("type") != "annotation_volume": + continue + voxel = vol_meta.get("output_voxel_size") + dataset_offset = vol_meta.get("dataset_offset_nm", [0, 0, 0]) + if not voxel: + continue + voxel_arr = np.array(voxel, dtype=np.float64) + dataset_offset_arr = np.array(dataset_offset, dtype=np.float64) + + # Pass 1: yellow boxes for each imported crop. + imported = vol_meta.get("imported_crops") or [] + bbox_off_list = [] + bbox_end_list = [] + for crop in imported: + offset_vox = crop.get("annotation_offset_voxels") + shape_vox = crop.get("annotation_shape_voxels") + if not (offset_vox and shape_vox): + continue + offset_arr = np.array(offset_vox, dtype=np.int64) + shape_arr = np.array(shape_vox, dtype=np.int64) + bbox_off_list.append(offset_arr) + bbox_end_list.append(offset_arr + shape_arr) + lo = ( + dataset_offset_arr + + offset_arr.astype(np.float64) * voxel_arr + ) + hi = lo + shape_arr.astype(np.float64) * voxel_arr + label = crop.get("name") or os.path.basename( + crop.get("path", "imported_crop").rstrip("/") + ) + boxes.append( + {"label": f"yaml_crop:{label}", "lo": lo.tolist(), "hi": hi.tolist()} + ) + bbox_offsets = ( + np.stack(bbox_off_list, axis=0) + if bbox_off_list + else np.zeros((0, 3), dtype=np.int64) + ) + bbox_ends = ( + np.stack(bbox_end_list, axis=0) + if bbox_end_list + else np.zeros((0, 3), dtype=np.int64) + ) + + # Pass 2: small boxes for painted-only chunks. Walk the + # volume zarr's annotation/s0/ chunk files and emit a box + # per chunk that isn't fully contained in any import bbox. + # Cheap: just lists chunk file names and compares spatial + # bbox to import bboxes -- never reads chunk contents. 
+ chunk_size = vol_meta.get("chunk_size") + if not chunk_size: + continue + chunk_size_arr = np.array(chunk_size, dtype=np.int64) + s0_path = os.path.join(corrections_dir, entry, "annotation", "s0") + if not os.path.isdir(s0_path): + continue + crop_label = ( + os.path.basename(crop.get("path", "")).rstrip("/") + if imported + else "painted" + ) + for chunk_name in os.listdir(s0_path): + if not _CHUNK_KEY_RE.match(chunk_name): + continue + cz, cy, cx = (int(s) for s in chunk_name.split(".")) + chunk_lo_vox = ( + np.array([cz, cy, cx], dtype=np.int64) * chunk_size_arr + ) + chunk_hi_vox = chunk_lo_vox + chunk_size_arr + if not _chunk_outside_all_bboxes( + chunk_lo_vox, chunk_hi_vox, bbox_offsets, bbox_ends + ): + continue + lo = ( + dataset_offset_arr + + chunk_lo_vox.astype(np.float64) * voxel_arr + ) + hi = ( + dataset_offset_arr + + chunk_hi_vox.astype(np.float64) * voxel_arr + ) + boxes.append( + { + "label": f"painted:{chunk_name}", + "lo": lo.tolist(), + "hi": hi.tolist(), + } + ) + except Exception as e: + logger.warning( + f"Could not read annotation_volume metadata for {entry}: {e}" + ) + + layer_name = "annotated_regions" + if not boxes: + try: + with g.viewer.txn() as s: + if layer_name in s.layers: + del s.layers[layer_name] + except Exception: + pass + return 0 + + axes_names = ["z", "y", "x"] + try: + if hasattr(g, "raw") and g.raw is not None: + source = getattr(g.raw, "source", None) + if source is not None and hasattr(source, "dimensions"): + axes_names = list(source.dimensions.names) + except Exception: + pass + + annotations = [ + neuroglancer.AxisAlignedBoundingBoxAnnotation( + point_a=box["lo"], + point_b=box["hi"], + id=str(index), + description=box["label"], + ) + for index, box in enumerate(boxes) + ] + + try: + with g.viewer.txn() as s: + s.layers[layer_name] = neuroglancer.LocalAnnotationLayer( + dimensions=neuroglancer.CoordinateSpace( + names=axes_names, + units="nm", + scales=[1, 1, 1], + ), + annotations=annotations, + ) + # 
Force-visible in case a prior toggle archived the layer. + try: + s.layers[layer_name].visible = True + except Exception: + pass + except Exception as e: + logger.warning(f"Could not update annotated_regions layer: {e}") + return 0 + + return len(boxes) + + +def add_crop_to_viewer_response(data): + try: + crop_id = data.get("crop_id") + minio_url = data.get("minio_url") + if not hasattr(g, "viewer") or g.viewer is None: + return jsonify({"success": False, "error": "Viewer not initialized"}), 400 + + with g.viewer.txn() as s: + layer_name = data.get("layer_name", f"annotation_{crop_id}") + source_config = { + "url": f"s3+{minio_url}", + "subsources": {"default": {"writingEnabled": True}, "bounds": {}}, + } + s.layers[layer_name] = neuroglancer.SegmentationLayer(source=source_config) + + return jsonify({"success": True, "message": "Layer added to viewer", "layer_name": layer_name}) + except Exception as e: + logger.error(f"Error adding layer to viewer: {e}", exc_info=True) + return jsonify({"success": False, "error": str(e)}), 500 + + +def sync_annotations_manually_response(data): + try: + crop_id = data.get("crop_id", None) + force = data.get("force", True) + + if crop_id: + success = sync_annotation_from_minio(crop_id, force=force) + refresh_annotated_regions_layer() + if success: + return jsonify({"success": True, "message": f"Synced annotation for {crop_id}"}) + return jsonify({"success": False, "message": f"No updates to sync for {crop_id}"}) + + synced = sync_all_annotations_from_minio(force=force) + refresh_annotated_regions_layer() + if synced == -1: + return jsonify({"success": False, "error": "MinIO not initialized"}), 400 + return jsonify( + { + "success": True, + "message": f"Synced {synced} annotations", + "synced_count": synced, + } + ) + except Exception as e: + logger.error(f"Error in sync endpoint: {e}", exc_info=True) + return jsonify({"success": False, "error": str(e)}), 500 diff --git a/cellmap_flow/dashboard/routes/finetune/routes.py 
b/cellmap_flow/dashboard/routes/finetune/routes.py new file mode 100644 index 0000000..6ea4fe9 --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune/routes.py @@ -0,0 +1,162 @@ +from flask import Blueprint, request + +from cellmap_flow.dashboard.routes.finetune.annotation import ( + add_crop_to_viewer_response, + create_annotation_crop_response, + create_annotation_volume_response, + get_finetune_models_response, + get_user_prefs_response, + get_view_center_response, + list_existing_sessions_response, + load_existing_volume_response, + refresh_annotated_regions_layer, + set_user_prefs_response, + sync_annotations_manually_response, +) +from cellmap_flow.dashboard.routes.finetune.annotation_sessions import ( + get_resume_progress_response, +) +from cellmap_flow.dashboard.routes.finetune.training import ( + cancel_job_response, + get_job_logs_response, + get_job_status_response, + get_inference_server_status_response, + list_finetuning_jobs_response, + restart_finetuning_job_response, + stop_training_early_response, + stream_job_logs_response, + submit_finetuning_response, +) +from cellmap_flow.dashboard.routes.finetune.viewer import ( + add_finetuned_layer_to_viewer_response, +) +from cellmap_flow.dashboard.routes.finetune.yaml_crops import ( + get_load_crops_progress_response, + load_crops_from_yaml_response, + read_yaml_file_response, +) + +finetune_bp = Blueprint("finetune", __name__) + + +@finetune_bp.route("/api/finetune/models", methods=["GET"]) +def get_finetune_models(): + return get_finetune_models_response() + + +@finetune_bp.route("/api/finetune/view-center", methods=["GET"]) +def get_view_center(): + return get_view_center_response() + + +@finetune_bp.route("/api/finetune/create-crop", methods=["POST"]) +def create_annotation_crop(): + return create_annotation_crop_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/create-volume", methods=["POST"]) +def create_annotation_volume(): + return 
create_annotation_volume_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/load-crops", methods=["POST"]) +def load_crops_from_yaml(): + return load_crops_from_yaml_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/read-yaml", methods=["GET"]) +def read_yaml_file(): + return read_yaml_file_response(request.args.get("path")) + + +@finetune_bp.route("/api/finetune/load-crops-progress", methods=["GET"]) +def get_load_crops_progress(): + return get_load_crops_progress_response(request.args.get("load_id")) + + +@finetune_bp.route("/api/finetune/user-prefs", methods=["GET"]) +def get_user_prefs(): + return get_user_prefs_response() + + +@finetune_bp.route("/api/finetune/user-prefs", methods=["POST"]) +def set_user_prefs(): + return set_user_prefs_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/list-existing-sessions", methods=["POST"]) +def list_existing_sessions(): + return list_existing_sessions_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/load-existing-volume", methods=["POST"]) +def load_existing_volume(): + return load_existing_volume_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/load-existing-volume-progress", methods=["GET"]) +def load_existing_volume_progress(): + return get_resume_progress_response(request.args.get("load_id")) + + +@finetune_bp.route("/api/finetune/add-to-viewer", methods=["POST"]) +def add_crop_to_viewer(): + return add_crop_to_viewer_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/sync-annotations", methods=["POST"]) +def sync_annotations_manually(): + return sync_annotations_manually_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/submit", methods=["POST"]) +def submit_finetuning(): + return submit_finetuning_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/jobs", methods=["GET"]) +def get_finetuning_jobs(): + return 
list_finetuning_jobs_response() + + +@finetune_bp.route("/api/finetune/job//status", methods=["GET"]) +def get_job_status(job_id): + return get_job_status_response(job_id) + + +@finetune_bp.route("/api/finetune/job//logs", methods=["GET"]) +def get_job_logs(job_id): + return get_job_logs_response(job_id) + + +@finetune_bp.route("/api/finetune/job//logs/stream", methods=["GET"]) +def stream_job_logs(job_id): + return stream_job_logs_response(job_id) + + +@finetune_bp.route("/api/finetune/job//cancel", methods=["POST"]) +def cancel_job(job_id): + return cancel_job_response(job_id) + + +@finetune_bp.route("/api/finetune/job//stop-early", methods=["POST"]) +def stop_training_early(job_id): + return stop_training_early_response(job_id) + + +@finetune_bp.route("/api/finetune/job//inference-server", methods=["GET"]) +def get_inference_server_status(job_id): + return get_inference_server_status_response(job_id) + + +@finetune_bp.route("/api/viewer/add-finetuned-layer", methods=["POST"]) +def add_finetuned_layer_to_viewer(): + return add_finetuned_layer_to_viewer_response(request.get_json() or {}) + + +@finetune_bp.route("/api/finetune/job//restart", methods=["POST"]) +def restart_finetuning_job(job_id): + return restart_finetuning_job_response(job_id, request.get_json() or {}) + + +__all__ = ["finetune_bp", "refresh_annotated_regions_layer"] diff --git a/cellmap_flow/dashboard/routes/finetune/service.py b/cellmap_flow/dashboard/routes/finetune/service.py new file mode 100644 index 0000000..9607796 --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune/service.py @@ -0,0 +1,57 @@ +from cellmap_flow.dashboard.routes.finetune.annotation import ( + add_crop_to_viewer_response, + create_annotation_crop_response, + create_annotation_volume_response, + get_finetune_models_response, + get_user_prefs_response, + get_view_center_response, + list_existing_sessions_response, + load_existing_volume_response, + refresh_annotated_regions_layer, + set_user_prefs_response, + 
sync_annotations_manually_response, +) +from cellmap_flow.dashboard.routes.finetune.common import ( + autodetect_output_type as _autodetect_output_type, + build_restart_params as _build_restart_params, +) +from cellmap_flow.dashboard.routes.finetune.training import ( + cancel_job_response, + get_job_logs_response, + get_job_status_response, + get_inference_server_status_response, + list_finetuning_jobs_response, + restart_finetuning_job_response, + stop_training_early_response, + stream_job_logs_response, + submit_finetuning_response, +) +from cellmap_flow.dashboard.routes.finetune.viewer import ( + add_finetuned_layer_to_viewer_response, +) + +__all__ = [ + "_autodetect_output_type", + "_build_restart_params", + "add_crop_to_viewer_response", + "add_finetuned_layer_to_viewer_response", + "cancel_job_response", + "get_job_logs_response", + "get_job_status_response", + "create_annotation_crop_response", + "create_annotation_volume_response", + "get_finetune_models_response", + "get_inference_server_status_response", + "get_user_prefs_response", + "get_view_center_response", + "list_finetuning_jobs_response", + "list_existing_sessions_response", + "load_existing_volume_response", + "refresh_annotated_regions_layer", + "restart_finetuning_job_response", + "set_user_prefs_response", + "stop_training_early_response", + "stream_job_logs_response", + "submit_finetuning_response", + "sync_annotations_manually_response", +] diff --git a/cellmap_flow/dashboard/routes/finetune/training.py b/cellmap_flow/dashboard/routes/finetune/training.py new file mode 100644 index 0000000..6fe7e99 --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune/training.py @@ -0,0 +1,477 @@ +import json +import logging +import re +import subprocess +import time +from datetime import datetime +from pathlib import Path + +from flask import Response, jsonify + +from cellmap_flow.dashboard.finetune_utils import sync_all_annotations_from_minio +from cellmap_flow.dashboard.routes.finetune.common 
def _parse_patches_per_epoch_override(data):
    """Return ``(provided, value)`` for the optional virtual-dataset override.

    ``0`` means "auto" (manifest ``None``); blank/missing means leave the
    existing manifest value untouched.

    Args:
        data: Mapping holding the request payload; only the
            ``"patches_per_epoch"`` key is read.

    Returns:
        ``(False, None)`` when the key is missing, ``None`` or ``""``;
        otherwise ``(True, value)`` where ``value`` is ``None`` for ``0`` and
        the parsed positive integer for anything else.

    Raises:
        ValueError: If the value is not a non-negative whole number.
    """
    message = "patches_per_epoch must be a non-negative integer"

    raw = data.get("patches_per_epoch")
    if raw is None or raw == "":
        return False, None

    # ``int()`` would accept these silently: ``True`` becomes 1 and ``2.5``
    # is truncated to 2, neither of which is what the caller asked for.
    if isinstance(raw, bool):
        raise ValueError(message)
    if isinstance(raw, float) and not raw.is_integer():
        raise ValueError(message)

    try:
        value = int(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(message) from exc

    if value < 0:
        raise ValueError(message)
    return True, (None if value == 0 else value)
def get_job_status_response(job_id):
    """Build the JSON response describing the status of ``job_id``.

    Returns:
        A 200 JSON payload of ``{"success": True, **status}`` when the job
        manager returns a status, a 404 payload when it returns ``None``, and
        a 500 payload carrying the error text when the lookup raises.
    """
    try:
        status = g.finetune_job_manager.get_job_status(job_id)
        if status is None:
            return jsonify({"success": False, "error": "Job not found"}), 404
        return jsonify({"success": True, **status})
    except Exception as e:
        # Keep the traceback in the log, as submit_finetuning_response does.
        logger.error(f"Error getting job status: {e}", exc_info=True)
        return jsonify({"success": False, "error": str(e)}), 500
Please create annotation crops first.", + } + ), 400 + + # Pre-training sync: only needed by the legacy CorrectionDataset path, + # which reads per-chunk _chunk_*.zarr extracts. The new VirtualPatchDataset + # reads the annotation_volume.zarr directly, so when a manifest is present + # the sync is wasted work and can hang submit for many minutes when the + # volume contains imported YAML data. + from cellmap_flow.finetune.virtual_dataset import read_manifest + + existing_manifest = read_manifest(str(actual_corrections_path)) + if existing_manifest is None: + try: + sync_all_annotations_from_minio(force=False) + except Exception as e: + logger.warning(f"Error syncing annotations before training: {e}") + else: + _refresh_virtual_manifest_for_training( + actual_corrections_path, existing_manifest, data, "submit" + ) + logger.info( + "Virtual sources manifest present; skipping pre-training MinIO sync." + ) + + loss_type = data.get("loss_type", "mse") + distillation_lambda = data.get("distillation_lambda", 0.0) + has_sparse = detect_sparse_annotations(actual_corrections_path) + sparse_auto_switched = False + if has_sparse and loss_type == "mse": + loss_type = "margin" + distillation_lambda = 0.5 + sparse_auto_switched = True + logger.info( + "Auto-switched to margin loss + distillation (lambda=0.5) for sparse annotations" + ) + + output_type, offsets = autodetect_output_type( + model_config, + data.get("output_type", None), + data.get("offsets", None), + ) + + finetune_job = g.finetune_job_manager.submit_finetuning_job( + model_config=model_config, + corrections_path=actual_corrections_path, + lora_r=data.get("lora_r", 8), + num_epochs=data.get("num_epochs", 10), + batch_size=data.get("batch_size", 2), + learning_rate=data.get("learning_rate", 1e-4), + output_base=Path(session_path), + checkpoint_path_override=( + Path(data["checkpoint_path"]) if data.get("checkpoint_path") else None + ), + auto_serve=data.get("auto_serve", True), + mask_unannotated=has_sparse, + 
def stream_job_logs_response(job_id):
    """Stream a finetuning job's log output as Server-Sent Events.

    The generator first emits whatever log text already exists, then polls
    while the job status is ``PENDING`` or ``RUNNING``. It reads from the
    job's log file when that exists and otherwise from ``bpeek`` for the LSF
    job. Lines matching ``LOG_FILTER_PATTERNS`` are dropped. A ``: ping``
    comment is sent roughly once per second as a heartbeat, and a final
    ``=== Training <status> ===`` event closes the stream.

    Returns:
        A Flask ``Response`` with mimetype ``text/event-stream``.
    """
    log_filters = [re.compile(pattern) for pattern in LOG_FILTER_PATTERNS]

    def iter_visible_lines(text):
        """Yield non-empty lines of ``text`` that match no filter pattern."""
        for line in text.splitlines():
            if line and not any(pattern.search(line) for pattern in log_filters):
                yield line

    def sse_data_block(lines):
        """Format ``lines`` as one SSE event, or return ``None`` if empty."""
        if not lines:
            return None
        payload = "\n".join(lines)
        return "data: " + payload.replace("\n", "\ndata: ") + "\n\n"

    def read_bpeek_content(lsf_job_id):
        """Return ``bpeek`` stdout for the LSF job, or ``None`` if the call fails."""
        try:
            result = subprocess.run(
                ["bpeek", str(lsf_job_id)],
                capture_output=True,
                text=True,
                timeout=2,
            )
        except Exception as e:
            logger.debug(f"bpeek call failed for job {lsf_job_id}: {e}")
            return None

        output = result.stdout or ""
        stderr = (result.stderr or "").strip()
        if stderr and "Not yet started" not in stderr:
            logger.debug(f"bpeek stderr for job {lsf_job_id}: {stderr}")
        return output

    def read_log_from(path, position):
        """Return ``(text, new_position)`` for log content after ``position``."""
        with open(path, "r") as f:
            f.seek(position)
            text = f.read()
            return text, f.tell()

    def generate():
        heartbeat_interval_s = 1.0
        last_heartbeat = time.perf_counter()

        fjm = g.finetune_job_manager
        if job_id not in fjm.jobs:
            yield f"data: Job {job_id} not found\n\n"
            return

        finetune_job = fjm.jobs[job_id]
        lsf_job_id = None
        if finetune_job.lsf_job and hasattr(finetune_job.lsf_job, "job_id"):
            lsf_job_id = finetune_job.lsf_job.job_id

        # Prefer the tee'd log file once it exists. LSF bpeek can buffer output
        # and then release several batch lines at once, which makes the
        # dashboard look stuck even while training is moving.
        use_bpeek = lsf_job_id is not None
        last_bpeek_line_count = 0
        last_bpeek_poll = 0.0
        bpeek_poll_interval_s = 1.0
        streamed_bpeek = False
        file_seen = finetune_job.log_file.exists()
        last_position = 0

        if file_seen:
            try:
                content, last_position = read_log_from(finetune_job.log_file, 0)
                block = sse_data_block(list(iter_visible_lines(content)))
                if block:
                    yield block
            except Exception as e:
                logger.error(f"Error reading log file: {e}")
                file_seen = False
        elif use_bpeek:
            initial = read_bpeek_content(lsf_job_id)
            if initial is None:
                use_bpeek = False
            else:
                last_bpeek_line_count = len(initial.splitlines())
                streamed_bpeek = bool(initial)
                block = sse_data_block(list(iter_visible_lines(initial)))
                if block:
                    yield block

        while finetune_job.status.value in ["PENDING", "RUNNING"]:
            try:
                now = time.perf_counter()

                if finetune_job.log_file.exists():
                    if not file_seen:
                        file_seen = True
                        # If bpeek already showed output, start at the end of
                        # the file so the same lines are not sent twice.
                        last_position = (
                            finetune_job.log_file.stat().st_size
                            if streamed_bpeek
                            else 0
                        )
                    new_content, last_position = read_log_from(
                        finetune_job.log_file, last_position
                    )
                    if new_content:
                        block = sse_data_block(list(iter_visible_lines(new_content)))
                        if block:
                            yield block
                elif use_bpeek and lsf_job_id and now - last_bpeek_poll >= bpeek_poll_interval_s:
                    last_bpeek_poll = now
                    content = read_bpeek_content(lsf_job_id)
                    if content is None:
                        use_bpeek = False
                    else:
                        current_lines = content.splitlines()
                        # A shorter snapshot than last time is treated as a
                        # fresh buffer and re-sent in full.
                        delta_lines = (
                            current_lines
                            if len(current_lines) < last_bpeek_line_count
                            else current_lines[last_bpeek_line_count:]
                        )
                        last_bpeek_line_count = len(current_lines)
                        if delta_lines:
                            streamed_bpeek = True
                            block = sse_data_block(
                                list(iter_visible_lines("\n".join(delta_lines)))
                            )
                            if block:
                                yield block

                if now - last_heartbeat >= heartbeat_interval_s:
                    yield ": ping\n\n"
                    last_heartbeat = now
                time.sleep(0.1)
            except Exception as e:
                logger.error(f"Error streaming logs: {e}")
                break

        # Final drain: anything written between the last poll and the status
        # leaving PENDING/RUNNING would otherwise never reach the client.
        # Skipped when the file was never tailed but bpeek already streamed
        # output, to avoid replaying the same lines.
        try:
            if finetune_job.log_file.exists() and (file_seen or not streamed_bpeek):
                tail, last_position = read_log_from(
                    finetune_job.log_file, last_position
                )
                block = sse_data_block(list(iter_visible_lines(tail)))
                if block:
                    yield block
        except Exception as e:
            logger.error(f"Error draining log file: {e}")

        yield f"data: === Training {finetune_job.status.value} ===\n\n"

    return Response(
        generate(),
        mimetype="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
def restart_finetuning_job_response(job_id, data):
    """Restart finetuning job ``job_id`` with parameters taken from ``data``.

    If a virtual-sources manifest can be read from the job's corrections
    directory, it is refreshed with the dashboard's training-time settings
    and the MinIO pre-sync is skipped. Otherwise annotations are synced from
    MinIO on a best-effort basis before the restart is requested.

    Returns:
        A 200 JSON payload with the restarted job id on success, a 400
        payload when a ``ValueError`` is raised (for example an invalid
        ``patches_per_epoch``), and a 500 payload for any other failure.
    """
    try:
        restart_t0 = time.perf_counter()

        # Pre-sync is only needed by the legacy CorrectionDataset path. With
        # a virtual-sources manifest the trainer reads the volume zarr
        # directly, so the sync would just download chunks the trainer never
        # touches — and on big sessions can hang Restart for minutes.
        from cellmap_flow.finetune.virtual_dataset import read_manifest

        jobs = getattr(g.finetune_job_manager, "jobs", {}) or {}
        job_record = jobs.get(job_id)
        corrections_dir = (
            str(getattr(job_record, "corrections_path", "") or "")
            if job_record is not None
            else ""
        )

        existing_manifest = (
            read_manifest(corrections_dir) if corrections_dir else None
        )
        if existing_manifest is not None:
            _refresh_virtual_manifest_for_training(
                corrections_dir, existing_manifest, data, "restart"
            )
            logger.info(
                f"Virtual sources manifest present for job {job_id}; "
                "skipping pre-restart MinIO sync."
            )
        else:
            try:
                sync_t0 = time.perf_counter()
                synced = sync_all_annotations_from_minio(force=False)
                sync_elapsed = time.perf_counter() - sync_t0
                logger.info(
                    f"Restart pre-sync complete for job {job_id}: synced={synced}, "
                    f"elapsed={sync_elapsed:.2f}s"
                )
            except Exception as e:
                # Best effort: a failed sync must not block the restart.
                logger.warning(f"Error syncing annotations before restart: {e}")

        job = g.finetune_job_manager.restart_finetuning_job(
            job_id=job_id,
            updated_params=build_restart_params(data),
        )
        total_elapsed = time.perf_counter() - restart_t0
        logger.info(f"Restart request processed for job {job_id}: total={total_elapsed:.2f}s")
        return jsonify(
            {
                "success": True,
                "job_id": job.job_id,
                "message": "Restart request sent. Training will restart on the same GPU.",
            }
        )
    except ValueError as e:
        # Same mapping as submit_finetuning_response: bad input is a client
        # error, not a server fault.
        logger.error(f"Validation error: {e}")
        return jsonify({"success": False, "error": str(e)}), 400
    except Exception as e:
        logger.error(f"Error restarting job: {e}", exc_info=True)
        return jsonify({"success": False, "error": str(e)}), 500
model_script_path and Path(model_script_path).exists(): + try: + model_config = load_safe_config(model_script_path) + if not hasattr(g, "models_config"): + g.models_config = [] + g.models_config = [ + mc + for mc in g.models_config + if not (hasattr(mc, "name") and mc.name.startswith(f"{base_model_name}_finetuned")) + ] + g.models_config.append(model_config) + except Exception as e: + logger.warning(f"Could not load model config: {e}") + + if not hasattr(g, "model_catalog"): + g.model_catalog = {} + if "Finetuned" not in g.model_catalog: + g.model_catalog["Finetuned"] = {} + g.model_catalog["Finetuned"] = { + name: path + for name, path in g.model_catalog["Finetuned"].items() + if not name.startswith(f"{base_model_name}_finetuned") + } + g.model_catalog["Finetuned"][model_name] = model_script_path if model_script_path else "" + + finetune_job = None + for ft_job in g.finetune_job_manager.jobs.values(): + if ft_job.finetuned_model_name == model_name: + finetune_job = ft_job + break + + if finetune_job and finetune_job.job_id: + inference_job = LSFJob(job_id=finetune_job.job_id, model_name=model_name) + inference_job.host = server_url + inference_job.status = finetune_job.status + g.jobs = [ + job + for job in g.jobs + if not ( + hasattr(job, "model_name") + and job.model_name + and job.model_name.startswith(f"{base_model_name}_finetuned") + ) + ] + g.jobs.append(inference_job) + else: + logger.warning(f"Could not find finetune job for {model_name}, Job object not created") + + with g.viewer.txn() as s: + if model_name in s.layers: + del s.layers[model_name] + + st_data = get_norms_post_args(g.input_norms, g.postprocess) + override_scales = None + try: + output_voxel_size = None + if finetune_job is not None and finetune_job.params: + output_voxel_size = tuple(finetune_job.params.get("output_voxel_size") or ()) + if not output_voxel_size: + for mc in getattr(g, "models_config", []) or []: + if mc.name == model_name: + output_voxel_size = 
tuple(mc.config.output_voxel_size) + break + dataset_path = getattr(g, "dataset_path", None) + if output_voxel_size and dataset_path: + closest = get_raw_closest_scale(dataset_path, output_voxel_size) + if closest is not None and tuple(closest) != tuple(output_voxel_size): + override_scales = closest + except Exception as e: + logger.warning(f"Could not compute override scales for finetuned '{model_name}': {e}") + + s.layers[model_name] = neuroglancer.ImageLayer( + source=build_prediction_source(server_url, model_name, st_data, override_scales), + shader="""#uicontrol invlerp normalized(range=[0, 255], window=[0, 255]); + #uicontrol vec3 color color(default="red"); + void main(){emitRGB(color * normalized());}""", + ) + + return jsonify( + { + "success": True, + "layer_name": model_name, + "model_name": model_name, + "reload_page": True, + } + ) + except Exception as e: + logger.error(f"Error adding finetuned layer: {e}", exc_info=True) + return jsonify({"success": False, "error": str(e)}), 500 diff --git a/cellmap_flow/dashboard/routes/finetune/yaml_crops.py b/cellmap_flow/dashboard/routes/finetune/yaml_crops.py new file mode 100644 index 0000000..eaecfd5 --- /dev/null +++ b/cellmap_flow/dashboard/routes/finetune/yaml_crops.py @@ -0,0 +1,542 @@ +"""Endpoint for bulk-loading externally annotated crops via a YAML manifest. + +Design +------ +A YAML manifest is conceptually a different way to **seed an annotation +volume**, alongside "New Volume" (empty) and "Resume Existing Volume" +(copy a prior session). Importing crops writes them straight into the +session's ``annotation_volume.zarr`` at their correct physical offsets, so +the result is identical in shape to a painted volume — one editable layer +in neuroglancer, served via MinIO, picked up by the existing periodic-sync +machinery, and consumed by training via :class:`VirtualPatchDataset`. + +Painted scribbles + imported GT crops therefore share one source of truth +(the volume zarr). 
# Module-level progress tracker, keyed by the load_id the client supplies.
# Each value holds the latest progress snapshot for that load. Entries that
# have not been touched for ``_PROGRESS_TTL_SECONDS`` are evicted on the next
# update so the dict cannot grow without bound.
_PROGRESS: dict = {}
_PROGRESS_LOCK = threading.Lock()
_PROGRESS_TTL_SECONDS = 300


def _set_progress(load_id, **fields):
    """Merge ``fields`` into the snapshot for ``load_id`` and evict stale ones.

    A falsy ``load_id`` is a no-op. The snapshot is created on first use with
    a ``created_at`` timestamp, and ``updated_at`` is refreshed on every call.
    All work happens under ``_PROGRESS_LOCK``.
    """
    if not load_id:
        return

    with _PROGRESS_LOCK:
        snapshot = _PROGRESS.setdefault(load_id, {"created_at": time.time()})
        snapshot.update(fields)
        snapshot["updated_at"] = time.time()

        # Collect expired keys first; the dict must not change size while it
        # is being iterated.
        reference = time.time()
        expired = []
        for key, value in _PROGRESS.items():
            last_touch = value.get("updated_at", value.get("created_at", reference))
            if reference - last_touch > _PROGRESS_TTL_SECONDS:
                expired.append(key)
        for key in expired:
            _PROGRESS.pop(key, None)
def _create_session_annotation_volume(
    *,
    raw_dataset_path,
    corrections_dir,
    model_name,
    config,
):
    """Create a fresh ``<volume_id>.zarr`` in ``corrections_dir`` and register it.

    Mirrors the body of ``create_annotation_volume_response`` minus the
    HTTP-shaped response wrapping.

    Args:
        raw_dataset_path: Path of the raw dataset the volume is laid over.
        corrections_dir: Directory the new zarr is written into.
        model_name: Model name stored with the volume.
        config: Object providing ``read_shape``, ``write_shape``,
            ``input_voxel_size`` and ``output_voxel_size``.

    Returns:
        ``(volume_id, meta)`` where ``meta`` is the entry registered in
        ``g.annotation_volumes``.

    Raises:
        RuntimeError: If ``create_annotation_volume_zarr`` reports failure.
    """
    from cellmap_flow.image_data_interface import ImageDataInterface
    from cellmap_flow.utils.neuroglancer_utils import get_raw_closest_scale

    read_shape = np.array(config.read_shape)
    write_shape = np.array(config.write_shape)
    claimed_input_voxel_size = np.array(config.input_voxel_size)
    claimed_output_voxel_size = np.array(config.output_voxel_size)
    output_size = (write_shape / claimed_output_voxel_size).astype(int)
    input_size = (read_shape / claimed_input_voxel_size).astype(int)

    try:
        eff_output_vs = np.array(
            get_raw_closest_scale(raw_dataset_path, tuple(claimed_output_voxel_size))
            or claimed_output_voxel_size
        )
        eff_input_vs = np.array(
            get_raw_closest_scale(raw_dataset_path, tuple(claimed_input_voxel_size))
            or claimed_input_voxel_size
        )
    except Exception as e:
        # Best effort: fall back to the model's claimed voxel sizes, but log
        # why, so a scale mismatch in the new volume can be traced back.
        logger.warning(
            f"Could not resolve closest raw scale for {raw_dataset_path}; "
            f"using claimed voxel sizes: {e}"
        )
        eff_output_vs = claimed_output_voxel_size
        eff_input_vs = claimed_input_voxel_size

    idi = ImageDataInterface(raw_dataset_path, voxel_size=eff_output_vs)
    dataset_offset_nm = np.array(idi.roi.offset)
    dataset_shape_nm = np.array(idi.roi.shape)
    dataset_shape_voxels = (dataset_shape_nm / eff_output_vs).astype(int)
    # Round the shape up to a whole number of output-size chunks.
    dataset_shape_voxels = (
        np.ceil(dataset_shape_voxels / output_size).astype(int) * output_size
    )

    volume_id = (
        f"vol-{uuid.uuid4().hex[:8]}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    )
    zarr_path = os.path.join(corrections_dir, f"{volume_id}.zarr")

    success, info = create_annotation_volume_zarr(
        zarr_path=zarr_path,
        dataset_shape_voxels=dataset_shape_voxels,
        output_voxel_size=eff_output_vs,
        dataset_offset_nm=dataset_offset_nm,
        chunk_size=output_size,
        dataset_path=raw_dataset_path,
        model_name=model_name,
        input_size=input_size,
        input_voxel_size=eff_input_vs,
        claimed_output_voxel_size=claimed_output_voxel_size,
        claimed_input_voxel_size=claimed_input_voxel_size,
        # Snapshot whatever input_norm the dashboard is currently using so
        # the trainer can reproduce inference-side normalization.
        input_norm_config=current_input_norm_config(),
    )
    if not success:
        raise RuntimeError(f"create_annotation_volume_zarr failed: {info}")

    minio_url = ensure_minio_serving(zarr_path, volume_id, output_base_dir=corrections_dir)
    _register_annotation_volume(
        volume_id,
        zarr_path=zarr_path,
        model_name=model_name,
        output_size=output_size.tolist(),
        input_size=input_size.tolist(),
        input_voxel_size=eff_input_vs.tolist(),
        output_voxel_size=eff_output_vs.tolist(),
        claimed_input_voxel_size=claimed_input_voxel_size.tolist(),
        claimed_output_voxel_size=claimed_output_voxel_size.tolist(),
        dataset_path=raw_dataset_path,
        dataset_offset_nm=dataset_offset_nm.tolist(),
        corrections_dir=corrections_dir,
        minio_url=minio_url,
    )
    meta = g.annotation_volumes[volume_id]
    return volume_id, meta
def _chunk_aligned_slabs(z0, sz, chunk_z, max_slabs):
    """Split ``[0, sz)`` into crop-relative Z ranges cut on the volume's chunk grid.

    ``z0`` is the crop's first Z index in the volume. Every interior cut falls
    on an absolute multiple of ``chunk_z``, so no two ranges touch the same
    chunk row even when ``z0`` itself is not chunk-aligned.
    """
    if sz <= 0:
        return []
    chunk_z = max(int(chunk_z), 1)
    first_row = z0 // chunk_z
    last_row = (z0 + sz - 1) // chunk_z
    n_rows = last_row - first_row + 1
    n_slabs = max(1, min(int(max_slabs), n_rows))
    rows_per_slab = -(-n_rows // n_slabs)  # ceiling division
    slabs = []
    for row in range(first_row, last_row + 1, rows_per_slab):
        start = max(row * chunk_z, z0) - z0
        stop = min((row + rows_per_slab) * chunk_z, z0 + sz) - z0
        slabs.append((start, stop))
    return slabs


def _write_crop_into_volume(volume_meta, entry, *, progress_callback=None):
    """Read a YAML crop's annotation, remap it, and write it into ``annotation/s0``.

    The crop is written at the voxel offset derived from its physical offset
    relative to the volume's ``dataset_offset_nm``, and the import is appended
    to the volume root's ``imported_crops`` attribute.

    Args:
        volume_meta: Mapping with ``zarr_path``, ``output_voxel_size`` and
            ``dataset_offset_nm``.
        entry: Crop entry providing ``path``, ``name``, ``fg_ids``,
            ``bg_ids``, ``mode`` and ``connected_components``.
        progress_callback: Optional ``callback(done, total)`` invoked after
            each slab is written.

    Returns:
        The number of voxels in the remapped crop whose value is ``>= 2``.

    Raises:
        ValueError: If the crop is not 3D or its write region falls outside
            the annotation volume.
    """
    t0 = time.time()
    sub, src_voxel_size_nm, src_offset_nm = _read_voxel_size_and_offset(entry.path)
    t_meta = time.time() - t0
    t1 = time.time()
    src_arr = _open_array(entry.path, sub)
    src_data = src_arr[:]
    t_read = time.time() - t1
    if src_data.ndim != 3:
        raise ValueError(
            f"Crop {entry.path}: expected 3D (z, y, x), got shape {src_data.shape}"
        )

    eff_output_vs = np.array(volume_meta["output_voxel_size"], dtype=float)
    if not np.allclose(src_voxel_size_nm, eff_output_vs):
        logger.warning(
            f"Crop {entry.path} voxel size {tuple(src_voxel_size_nm)} != "
            f"volume voxel size {tuple(eff_output_vs)}. Writing values as-is "
            "without resampling — caller should ensure scale compatibility."
        )

    t2 = time.time()
    remapped = remap_labels(
        src_data,
        fg_ids=entry.fg_ids,
        bg_ids=list(entry.bg_ids),
        mode=entry.mode,
        connected_components=entry.connected_components,
    )
    t_remap = time.time() - t2
    t3 = time.time()
    n_fg = int(np.count_nonzero(remapped >= 2))
    t_count = time.time() - t3
    logger.info(
        f"Crop {entry.path} prep: meta={t_meta:.2f}s read={t_read:.2f}s "
        f"({src_data.nbytes/1e6:.1f} MB, dtype={src_data.dtype}, shape={src_data.shape}) "
        f"remap={t_remap:.2f}s count_fg={t_count:.2f}s"
    )

    dataset_offset_nm = np.array(volume_meta["dataset_offset_nm"], dtype=float)
    write_voxel_offset = (
        (src_offset_nm - dataset_offset_nm) / eff_output_vs
    ).astype(int)
    z0, y0, x0 = write_voxel_offset.tolist()
    sz, sy, sx = remapped.shape

    vol = zarr.open(volume_meta["zarr_path"], mode="r+")
    arr = vol["annotation/s0"]
    if (
        z0 < 0 or y0 < 0 or x0 < 0
        or z0 + sz > arr.shape[0]
        or y0 + sy > arr.shape[1]
        or x0 + sx > arr.shape[2]
    ):
        raise ValueError(
            f"Crop {entry.path} write region [{z0}:{z0+sz}, {y0}:{y0+sy}, {x0}:{x0+sx}] "
            f"is outside annotation volume shape {arr.shape}. Check the source's "
            "OME-NGFF translation against the dataset offset."
        )

    # Slice the crop into Z slabs and write them in parallel. The cuts are
    # placed on the *volume's* chunk grid (absolute multiples of the Z chunk
    # size), not on multiples counted from the crop start. When z0 is not
    # chunk-aligned, crop-relative cuts would let two slabs update the same
    # chunk concurrently and one writer could overwrite the other's data.
    #
    # Slab count tracks the LSF slot allocation so we always fully use what
    # bsub gave us — capped by the number of chunk rows the crop touches.
    from cellmap_flow.dashboard.finetune_utils import _get_sync_worker_count

    slabs = _chunk_aligned_slabs(z0, sz, arr.chunks[0], _get_sync_worker_count())
    n_slabs = len(slabs)

    def _write_one(slab):
        a, b = slab
        arr[z0 + a : z0 + b, y0 : y0 + sy, x0 : x0 + sx] = remapped[a:b, :, :]

    t4 = time.time()
    written = 0
    n_workers = max(1, n_slabs)
    with ThreadPoolExecutor(max_workers=n_workers) as ex:
        futures = [ex.submit(_write_one, s) for s in slabs]
        for fut in as_completed(futures):
            fut.result()  # surface any per-slab exception
            written += 1
            if progress_callback is not None:
                progress_callback(written, n_slabs)
    t_write = time.time() - t4
    logger.info(
        f"Crop {entry.path} write: {n_slabs} slabs, {n_workers} workers, "
        f"{t_write:.2f}s total wall"
    )

    # Record this import in the volume's root attrs so the bounding-box
    # overlay can surface it as a single yellow box per crop (vs. the
    # per-chunk small boxes from painted scribbles).
    vol_root = zarr.open(volume_meta["zarr_path"], mode="r+")
    imported = list(vol_root.attrs.get("imported_crops", []))
    imported.append(
        {
            "path": entry.path,
            "name": entry.name,
            "annotation_offset_voxels": [int(z0), int(y0), int(x0)],
            "annotation_shape_voxels": [int(sz), int(sy), int(sx)],
            "n_fg_voxels": int(n_fg),
        }
    )
    vol_root.attrs["imported_crops"] = imported

    return n_fg
+ + Request JSON: + - ``model_name``: required + - ``output_path``: optional, base path for the session corrections dir + - ``yaml``: required, YAML text (or path to a YAML file) + - ``load_id``: optional UUID for live progress polling + """ + try: + model_name = data.get("model_name") + output_path = data.get("output_path") + yaml_input = data.get("yaml") + load_id = data.get("load_id") + if load_id: + _set_progress( + load_id, + phase="starting", + current_path="", + tile_done=0, + tile_total=0, + crop_index=0, + n_crops=0, + done=False, + ) + + if not yaml_input: + return jsonify({"success": False, "error": "Missing 'yaml' field"}), 400 + if not model_name: + return jsonify({"success": False, "error": "Missing 'model_name' field"}), 400 + + try: + crops_config = parse_crops_yaml(yaml_input) + except ValidationError as e: + return ( + jsonify({"success": False, "error": "YAML validation failed", "details": e.errors()}), + 400, + ) + except Exception as e: + return jsonify({"success": False, "error": f"YAML parse error: {e}"}), 400 + + if not crops_config.crops: + return jsonify({"success": False, "error": "No crops listed in YAML"}), 400 + + model_config, error_response = _get_selected_model_config(model_name) + if error_response is not None: + return error_response + + raw_dataset_path = getattr(g, "dataset_path", None) + if not raw_dataset_path: + return jsonify({"success": False, "error": "No raw dataset path configured"}), 400 + + _, corrections_dir = ensure_corrections_storage(output_path) + + # Reuse the session's annotation_volume if the user already created one + # (via "New Volume" or "Resume Existing"). Otherwise spin up a fresh one + # so the YAML import has a destination. 
+ volume_id, volume_meta = _find_session_annotation_volume(corrections_dir) + created_volume = False + if volume_meta is None: + volume_id, volume_meta = _create_session_annotation_volume( + raw_dataset_path=raw_dataset_path, + corrections_dir=corrections_dir, + model_name=model_name, + config=model_config.config, + ) + created_volume = True + _ensure_editable_layer(volume_id, volume_meta.get("minio_url")) + + n_crops = len(crops_config.crops) + errors = [] + total_fg_written = 0 + for crop_index, entry in enumerate(crops_config.crops): + if load_id: + _set_progress( + load_id, + phase="crop_start", + crop_index=crop_index, + n_crops=n_crops, + current_path=entry.path, + tile_done=0, + tile_total=0, + done=False, + ) + try: + def _cb(done, total, ci=crop_index, p=entry.path): + if load_id: + _set_progress( + load_id, + phase="tile", + crop_index=ci, + n_crops=n_crops, + current_path=p, + tile_done=int(done), + tile_total=int(total), + done=False, + ) + + n_fg = _write_crop_into_volume( + volume_meta, entry, progress_callback=_cb + ) + total_fg_written += n_fg + logger.info(f"Imported crop {entry.path}: {n_fg} FG voxels") + except Exception as e: + logger.exception(f"Failed to import crop {entry.path}") + errors.append({"path": entry.path, "error": str(e)}) + + # The MinIO bucket was mirrored once at volume-create time, when the + # zarr held only metadata. Re-mirror now that chunk data is written + # so neuroglancer can read the imported annotations from the + # editable layer. + try: + ensure_minio_serving( + volume_meta["zarr_path"], + volume_id, + output_base_dir=corrections_dir, + ) + except Exception as e: + logger.warning(f"MinIO re-mirror failed for {volume_id}: {e}") + + # Manifest: trainer reads from this single volume zarr. 
The + # ``input_norm`` block carries the dashboard's current normalization + # so VirtualPatchDataset (running in the LSF trainer process where + # g.input_norms is empty) can apply the same normalization the + # dashboard does at inference time. Without this the trainer feeds + # the model raw uint8 while inference feeds it [-1, 1] -- the + # trained adapter is then nonsense at inference time. + manifest = { + "kind": "volume_zarr_v1", + "volume_zarr_path": volume_meta["zarr_path"], + "raw_dataset_path": raw_dataset_path, + "input_size_voxels": list(volume_meta["input_size"]), + "output_size_voxels": list(volume_meta["output_size"]), + "input_voxel_size_nm": list(volume_meta["input_voxel_size"]), + "output_voxel_size_nm": list(volume_meta["output_voxel_size"]), + # patches_per_epoch=None tells VirtualPatchDataset to default to + # "one patch per populated chunk" (full coverage). Explicit ints + # in the YAML pass through verbatim. + "patches_per_epoch": crops_config.patches_per_epoch, + "jitter_voxels": crops_config.jitter_voxels, + "seed": crops_config.seed, + "input_norm": current_input_norm_config(), + # None → auto-balance dense vs sparse pools (50/50 when both + # exist, else use the surviving pool). 
+ "dense_to_sparse_ratio": crops_config.dense_to_sparse_ratio, + } + write_manifest(corrections_dir, manifest) + + try: + refresh_annotated_regions_layer(corrections_path=corrections_dir) + except Exception as e: + logger.warning(f"refresh_annotated_regions_layer failed: {e}") + + if load_id: + _set_progress( + load_id, + phase="done", + done=True, + n_crops_imported=n_crops - len(errors), + n_errors=len(errors), + volume_id=volume_id, + fg_voxels_written=total_fg_written, + ) + + return jsonify( + { + "success": True, + "n_crops_requested": n_crops, + "n_crops_imported": n_crops - len(errors), + "n_errors": len(errors), + "fg_voxels_written": total_fg_written, + "volume_id": volume_id, + "created_new_volume": created_volume, + "errors": errors, + } + ) + except Exception as e: + logger.exception("load_crops_from_yaml_response failed") + return jsonify({"success": False, "error": str(e)}), 500 + + +# --------------------------------------------------------------------------- +# Auxiliary endpoints (file read + progress polling) — unchanged behavior +# --------------------------------------------------------------------------- + +def get_load_crops_progress_response(load_id): + """Return current progress for an in-flight ``/api/finetune/load-crops`` call.""" + if not load_id: + return jsonify({"success": False, "error": "Missing 'load_id' query param"}), 400 + with _PROGRESS_LOCK: + snapshot = _PROGRESS.get(load_id) + snapshot = dict(snapshot) if snapshot else None + if snapshot is None: + return jsonify({"success": False, "error": f"Unknown load_id {load_id}"}), 404 + return jsonify({"success": True, "progress": snapshot}) + + +def read_yaml_file_response(path): + """Return the contents of a YAML file so the dashboard can preview/edit it.""" + if not path: + return jsonify({"success": False, "error": "Missing 'path' query param"}), 400 + if not os.path.exists(path): + return jsonify({"success": False, "error": f"File not found: {path}"}), 404 + if not 
os.path.isfile(path): + return jsonify({"success": False, "error": f"Not a file: {path}"}), 400 + if os.path.getsize(path) > 1_000_000: + return jsonify({"success": False, "error": "File exceeds 1 MB; paste it directly instead"}), 400 + try: + with open(path) as f: + text = f.read() + return jsonify({"success": True, "text": text}) + except Exception as e: + return jsonify({"success": False, "error": str(e)}), 500 diff --git a/cellmap_flow/dashboard/routes/index_page.py b/cellmap_flow/dashboard/routes/index_page.py index d44198d..9b0db21 100644 --- a/cellmap_flow/dashboard/routes/index_page.py +++ b/cellmap_flow/dashboard/routes/index_page.py @@ -39,8 +39,8 @@ def index(): return render_template( "index.html", - neuroglancer_url=state.NEUROGLANCER_URL, - inference_servers=state.INFERENCE_SERVER, + neuroglancer_url=g.NEUROGLANCER_URL, + inference_servers=g.INFERENCE_SERVER, input_normalizers=input_norms, output_postprocessors=output_postprocessors, model_mergers=model_mergers, diff --git a/cellmap_flow/dashboard/routes/logging_routes.py b/cellmap_flow/dashboard/routes/logging_routes.py index 73ff32f..711f115 100644 --- a/cellmap_flow/dashboard/routes/logging_routes.py +++ b/cellmap_flow/dashboard/routes/logging_routes.py @@ -4,7 +4,7 @@ from flask import Blueprint, Response -from cellmap_flow.dashboard.state import log_buffer, log_clients +from cellmap_flow.globals import g logger = logging.getLogger(__name__) @@ -16,12 +16,12 @@ def stream_logs(): """Stream logs via Server-Sent Events (SSE)""" def generate(): # Send existing log buffer first - for log_line in log_buffer: + for log_line in g.log_buffer: yield f"data: {log_line}\n\n" # Create a queue for this client client_queue = queue.Queue(maxsize=100) - log_clients.append(client_queue) + g.log_clients.append(client_queue) try: while True: @@ -33,8 +33,8 @@ def generate(): yield ": keepalive\n\n" finally: # Clean up when client disconnects - if client_queue in log_clients: - log_clients.remove(client_queue) + if 
client_queue in g.log_clients: + g.log_clients.remove(client_queue) return Response(generate(), mimetype="text/event-stream", headers={ "Cache-Control": "no-cache", diff --git a/cellmap_flow/dashboard/routes/pipeline.py b/cellmap_flow/dashboard/routes/pipeline.py index 1bb2f9e..5ed3468 100644 --- a/cellmap_flow/dashboard/routes/pipeline.py +++ b/cellmap_flow/dashboard/routes/pipeline.py @@ -16,7 +16,6 @@ from cellmap_flow.utils.load_py import load_safe_config from cellmap_flow.utils.scale_pyramid import get_raw_layer from cellmap_flow.utils.web_utils import encode_to_str, ARGS_KEY -from cellmap_flow.dashboard.state import CUSTOM_CODE_FOLDER logger = logging.getLogger(__name__) @@ -107,7 +106,15 @@ def process(): del data["custom_code"] logger.warning(f"Data received: {type(data)} - {data.keys()} -{data}") g.input_norms = get_normalizations(data["input_norm"]) + # Keep the raw, JSON-serializable input_norm dict around so downstream + # components (finetune submit/restart, manifest, generated yaml) can + # propagate the same normalization to the trainer process. Without this + # the trainer reads raw uint8 from /nrs while inference normalizes to + # the model's expected range -> trained model never sees inference-scale + # inputs. 
+ g.input_norm_config = data.get("input_norm", {}) or {} g.postprocess = get_postprocessors(data["postprocess"]) + g.postprocess_config = data.get("postprocess", {}) or {} # Save current shader state from viewer before refreshing layers _save_shaders_from_viewer() @@ -141,7 +148,7 @@ def process(): # Save custom code to a file with date and time timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"custom_code_{timestamp}.py" - filepath = os.path.join(CUSTOM_CODE_FOLDER, filename) + filepath = os.path.join(g.CUSTOM_CODE_FOLDER, filename) with open(filepath, "w") as file: file.write(custom_code) @@ -265,6 +272,10 @@ def apply_pipeline(): } logger.warning(f"\nNormalizers config dict: {input_norms_config}") g.input_norms = get_normalizations(input_norms_config) + # Mirror the JSON-serializable form so finetune submit/restart can + # propagate it to the trainer process (where g.input_norms can't be + # easily reconstructed across the LSF process boundary). + g.input_norm_config = input_norms_config or {} # Apply postprocessors postprocs_config = { @@ -272,6 +283,7 @@ def apply_pipeline(): } logger.warning(f"Postprocessors config dict: {postprocs_config}") g.postprocess = get_postprocessors(postprocs_config) + g.postprocess_config = postprocs_config or {} # Save complete pipeline visual state to globals g.pipeline_inputs = data.get("inputs", []) diff --git a/cellmap_flow/dashboard/routes/pipeline_builder_page.py b/cellmap_flow/dashboard/routes/pipeline_builder_page.py index eafa8a1..c45b45c 100644 --- a/cellmap_flow/dashboard/routes/pipeline_builder_page.py +++ b/cellmap_flow/dashboard/routes/pipeline_builder_page.py @@ -147,8 +147,10 @@ def pipeline_builder(): for model_name, model_data in available_models.items(): model_name_stripped = model_name.replace('_server', '') logger.warning(f" Checking available_models: {model_name} (stripped: {model_name_stripped}) vs job: {job.model_name} (stripped: {job_model_name})") - if model_name_stripped == job_model_name 
and isinstance(model_data, dict) and 'config' in model_data: - model_dict['config'] = model_data['config'] + if model_name_stripped == job_model_name and isinstance(model_data, dict): + # Models from g.models_config store to_dict() directly + # (no nested 'config' key); use the dict itself as config + model_dict['config'] = model_data.get('config', model_data) logger.warning(f" ✓ Config attached from available_models: {model_dict['config']}") config_found = True break diff --git a/cellmap_flow/dashboard/state.py b/cellmap_flow/dashboard/state.py index 07dd777..389c95b 100644 --- a/cellmap_flow/dashboard/state.py +++ b/cellmap_flow/dashboard/state.py @@ -1,53 +1,15 @@ -import os -import logging -import queue -from collections import deque - -logger = logging.getLogger(__name__) - -# Global log buffer for streaming to frontend -log_buffer = deque(maxlen=1000) # Keep last 1000 lines -log_clients = [] # List of queues for connected clients - - -# Custom handler to capture logs -class LogHandler(logging.Handler): - def emit(self, record): - log_entry = self.format(record) - log_buffer.append(log_entry) - # Send to all connected clients - for client_queue in log_clients: - try: - client_queue.put_nowait(log_entry) - except queue.Full: - pass - - -NEUROGLANCER_URL = None -INFERENCE_SERVER = None - -CUSTOM_CODE_FOLDER = os.path.expanduser( - os.environ.get( - "CUSTOM_CODE_FOLDER", - "~/Desktop/cellmap/cellmap-flow/example/example_norm", - ) -) - -# Blockwise task directory will be set from globals or use default -def get_blockwise_tasks_dir(): - from cellmap_flow.globals import g - tasks_dir = getattr(g, 'blockwise_tasks_dir', None) or os.path.expanduser("~/.cellmap_flow/blockwise_tasks") - os.makedirs(tasks_dir, exist_ok=True) - return tasks_dir - - -# Global state for BBX generator -bbx_generator_state = { - "dataset_path": None, - "num_boxes": 0, - "bounding_boxes": [], - "viewer": None, - "viewer_process": None, - "viewer_url": None, - "viewer_state": None -} +# 
Re-export all dashboard state from the globals singleton for backward compatibility. +# New code should import directly from cellmap_flow.globals. + +from cellmap_flow.globals import g, LogHandler, get_blockwise_tasks_dir # noqa: F401 + +log_buffer = g.log_buffer +log_clients = g.log_clients +NEUROGLANCER_URL = g.NEUROGLANCER_URL +INFERENCE_SERVER = g.INFERENCE_SERVER +CUSTOM_CODE_FOLDER = g.CUSTOM_CODE_FOLDER +bbx_generator_state = g.bbx_generator_state +finetune_job_manager = g.finetune_job_manager +minio_state = g.minio_state +annotation_volumes = g.annotation_volumes +output_sessions = g.output_sessions diff --git a/cellmap_flow/dashboard/static/css/dark.css b/cellmap_flow/dashboard/static/css/dark.css index bd454cd..92b1d0b 100644 --- a/cellmap_flow/dashboard/static/css/dark.css +++ b/cellmap_flow/dashboard/static/css/dark.css @@ -423,13 +423,26 @@ textarea.form-control[readonly] { } /* ---------- Modal ---------- */ -.modal-content.bg-dark { - background: var(--bg-surface) !important; +.modal-content { + background: var(--bg-surface); + color: var(--text-primary); border: 1px solid var(--border); border-radius: var(--radius-lg); box-shadow: var(--shadow-md); } +.modal-content.bg-dark { + background: var(--bg-surface) !important; +} + +.modal-header { + border-bottom-color: var(--border); +} + +.modal-footer { + border-top-color: var(--border); +} + .modal-header.border-secondary { border-color: var(--border) !important; padding: 16px 20px; @@ -449,6 +462,30 @@ textarea.form-control[readonly] { padding: 12px 20px; } +/* ---------- Alert ---------- */ +.alert-info { + background-color: var(--bg-muted); + border-color: var(--info); + color: var(--text-primary); +} + +/* ---------- Card ---------- */ +.card { + background-color: var(--bg-surface); + border-color: var(--border); + color: var(--text-primary); +} + +.card-header { + background-color: var(--bg-elevated); + border-bottom-color: var(--border); + color: var(--text-primary); +} + +.card-body { + color: 
var(--text-primary); +} + /* ---------- Editor ---------- */ .editor { width: 100%; diff --git a/cellmap_flow/dashboard/static/css/pipeline_builder.css b/cellmap_flow/dashboard/static/css/pipeline_builder.css index 34d2952..aba4ed9 100644 --- a/cellmap_flow/dashboard/static/css/pipeline_builder.css +++ b/cellmap_flow/dashboard/static/css/pipeline_builder.css @@ -229,8 +229,7 @@ .canvas-content { position: relative; - width: 100%; - height: 100%; + min-width: 100%; min-height: 600px; } diff --git a/cellmap_flow/dashboard/templates/_dashboard.html b/cellmap_flow/dashboard/templates/_dashboard.html index e141bc1..c252e09 100644 --- a/cellmap_flow/dashboard/templates/_dashboard.html +++ b/cellmap_flow/dashboard/templates/_dashboard.html @@ -44,6 +44,22 @@ Postprocess + + + @@ -72,5 +88,15 @@ > {% include "_output_tab.html" %} + + +
+ {% include "_finetune_tab.html" %} +
diff --git a/cellmap_flow/dashboard/templates/_finetune_tab.html b/cellmap_flow/dashboard/templates/_finetune_tab.html new file mode 100644 index 0000000..114773e --- /dev/null +++ b/cellmap_flow/dashboard/templates/_finetune_tab.html @@ -0,0 +1,2068 @@ + + + + + +
+ +
+ +
+ +
+ +
+ + +
+ + Crop will be sized to model's output inference size + +
+ + + + + +
+ + + Directory where annotation crops will be saved (must be accessible to MinIO). Crop will be created at current view center position. +
+ + +
+ +
+ + + +
+ + Sparse volume across full dataset extent (paint 1=background, 2=foreground). + Resuming copies the chosen session into a new one, leaving the original intact. + +
+ + +
+ + + Persists in-progress annotations from MinIO to disk. Painted regions are + automatically outlined as bounding boxes in the viewer (toggle via the + "annotated_regions" layer). + +
+ + + + + + + +
+ Advanced (deprecated): dense crop at view center +
+ + + Small dense region at current view center (paint 1=foreground). Rarely used. + +
+
+ + + + + +
+ + +
+
+
+ + +
+ + +
+
+
Training Configuration
+
+
+ +
+ + + + Override the base model checkpoint to finetune from. If left empty, the system will attempt to extract it from the model configuration or script. + +
+ +
+
+
+ + + Higher rank = more trainable parameters +
+
+
+
+ + + Typical range: 10-20 epochs +
+
+
+ +
+
+
+ + + Higher = faster but uses more GPU memory +
+
+
+
+ + + LoRA typically uses higher learning rates +
+
+
+ +
+
+
+ + + + Blank keeps the manifest value. 0 uses auto. A positive value + caps random samples per epoch. + +
+
+
+ +
+
+
+ + + Margin is recommended for sparse annotations +
+
+ + + + Smaller = stricter (more learning signal). Predictions outside + [margin, 1-margin] get zero gradient. + +
+
+
+
+ + + Keeps model close to original predictions +
+
+
+
+
+
+ + + Where to apply distillation loss +
+
+ + + + Softens hard 0/1 targets to [s/2, 1-s/2]. Default 0.1. + Set to 0 for clean annotations where you want sharp 0/1 outputs. + +
+
+
+
+
+ + +
+ Weight fg and bg equally in loss regardless of scribble ratio. Prevents foreground overprediction. +
+
+
+ +
+
+
+ + +
+
+
+ + +
+ + +
+ +
+ +
+
+
+ + + + + +
+
+
Training Logs
+
+ + +
+
+
+
+
+ Loss vs Epoch + No loss data yet +
+ +
+ +
+
+
+
+ + + diff --git a/cellmap_flow/dashboard/templates/pipeline_builder_v2.html b/cellmap_flow/dashboard/templates/pipeline_builder_v2.html index 2ce45da..502c4b2 100644 --- a/cellmap_flow/dashboard/templates/pipeline_builder_v2.html +++ b/cellmap_flow/dashboard/templates/pipeline_builder_v2.html @@ -13,13 +13,14 @@ ← Back

CellMapFlow Pipeline

+ - +
@@ -282,7 +283,7 @@

Create Custom Model Configuration

pipeline.outputs = (pipeline.outputs || []).map((n, i) => ({ id: n.id || `output-${i}-${Date.now()}`, name: 'OUTPUT', - params: {}, + params: n.params || {}, position: n.position || { x: 900, y: 20 + (i * 180) } })); pipeline.normalizers = (pipeline.normalizers || []).map((n, i) => ({ @@ -406,10 +407,146 @@

Create Custom Model Configuration

// Auto-layout nodes based on graph topology (edges), measuring actual DOM sizes.
// Must be called AFTER renderCanvas() so nodes are in the DOM.
//
// Flow nodes (inputs, normalizers, models, postprocessors, outputs) are placed
// in columns by longest-path layer; each column is vertically centered against
// the tallest column. Blockwise-config nodes are placed in a row below the flow.
// Mutates node.position, moves the matching DOM elements, and redraws edges.
function autoLayoutNodes() {
    const COL_GAP = 80;
    const ROW_SPACING = 40;
    const LEFT_MARGIN = 40;
    const TOP_MARGIN = 40;
    const FALLBACK_BLOCKWISE_WIDTH = 350;

    // Collect all connected nodes (exclude blockwise-config)
    const allNodes = [
        ...pipeline.inputs,
        ...pipeline.normalizers,
        ...pipeline.models,
        ...pipeline.postprocessors,
        ...pipeline.outputs
    ];
    if (allNodes.length === 0) return;

    const nodeById = {};
    allNodes.forEach(n => { nodeById[n.id] = n; });

    // Measure actual DOM sizes
    const nodeSize = {};
    allNodes.forEach(n => {
        const el = document.getElementById(`node-${n.id}`);
        nodeSize[n.id] = el
            ? { w: el.offsetWidth, h: el.offsetHeight }
            : { w: 300, h: 120 }; // fallback when the node is not rendered
    });

    // Build adjacency (edges touching unknown nodes are ignored)
    const incomingCount = {};
    const outgoing = {};
    allNodes.forEach(n => {
        incomingCount[n.id] = 0;
        outgoing[n.id] = [];
    });
    pipeline.edges.forEach(e => {
        if (nodeById[e.from] && nodeById[e.to]) {
            outgoing[e.from].push(e.to);
            incomingCount[e.to] = (incomingCount[e.to] || 0) + 1;
        }
    });

    // Assign layers via topological BFS (longest-path layering)
    const layer = {};
    allNodes.forEach(n => { layer[n.id] = 0; });
    const queue = allNodes.filter(n => incomingCount[n.id] === 0).map(n => n.id);
    const visited = new Set();
    while (queue.length > 0) {
        const nid = queue.shift();
        if (visited.has(nid)) continue;
        visited.add(nid);
        outgoing[nid].forEach(childId => {
            layer[childId] = Math.max(layer[childId], layer[nid] + 1);
            incomingCount[childId]--;
            if (incomingCount[childId] <= 0) {
                queue.push(childId);
            }
        });
    }

    // Nodes never dequeued still have unresolved incoming edges, i.e. they sit
    // on or behind a cycle. Fall back to layer 0 for them.
    allNodes.forEach(n => {
        if (!visited.has(n.id)) {
            layer[n.id] = 0;
            visited.add(n.id);
        }
    });

    // Group by layer
    const layers = {};
    allNodes.forEach(n => {
        const l = layer[n.id];
        if (!layers[l]) layers[l] = [];
        layers[l].push(n);
    });
    const sortedLayerKeys = Object.keys(layers).map(Number).sort((a, b) => a - b);

    // Compute column heights using actual measured node heights
    const colHeights = {};
    sortedLayerKeys.forEach(k => {
        const nodes = layers[k];
        colHeights[k] = nodes.reduce((sum, n) => sum + nodeSize[n.id].h, 0) + (nodes.length - 1) * ROW_SPACING;
    });
    const maxColHeight = Math.max(...Object.values(colHeights));

    // Compute x positions: each column starts after the widest node in the previous column
    const colX = {};
    let currentX = LEFT_MARGIN;
    sortedLayerKeys.forEach(k => {
        colX[k] = currentX;
        const maxWidth = Math.max(...layers[k].map(n => nodeSize[n.id].w));
        currentX += maxWidth + COL_GAP;
    });

    // Position each layer, centering shorter columns against the tallest one
    sortedLayerKeys.forEach(k => {
        const x = colX[k];
        let y = TOP_MARGIN + (maxColHeight - colHeights[k]) / 2;
        layers[k].forEach(node => {
            node.position = { x, y };
            // Update DOM element directly (no re-render needed)
            const el = document.getElementById(`node-${node.id}`);
            if (el) {
                el.style.left = x + 'px';
                el.style.top = y + 'px';
            }
            y += nodeSize[node.id].h + ROW_SPACING;
        });
    });

    // Position blockwise-config below the main flow. These nodes are not part
    // of allNodes, so their widths are measured here rather than read from
    // nodeSize (which would never contain them).
    if (pipeline.blockwise_config && pipeline.blockwise_config.length > 0) {
        const bottomY = TOP_MARGIN + maxColHeight + 60;
        let blockX = LEFT_MARGIN;
        pipeline.blockwise_config.forEach(node => {
            const el = document.getElementById(`node-${node.id}`);
            const width = (el && el.offsetWidth) || FALLBACK_BLOCKWISE_WIDTH;
            node.position = { x: blockX, y: bottomY };
            if (el) {
                el.style.left = node.position.x + 'px';
                el.style.top = node.position.y + 'px';
            }
            blockX += width + COL_GAP;
        });
    }

    // Redraw connections with new positions
    renderConnections();
}
(pipeline.inputs.length > 0 || pipeline.normalizers.length > 0 || pipeline.models.length > 0 || pipeline.postprocessors.length > 0) { autoConnectNodes(); + // autoLayoutNodes() is called after first renderCanvas() in initializeLibrary() } // Initialize library @@ -478,6 +615,8 @@

Create Custom Model Configuration

canvas.addEventListener('drop', handleDrop); renderCanvas(); + // Auto-layout after first render so we can measure actual DOM sizes + autoLayoutNodes(); } function toggleSection(btn, sectionId) { @@ -602,6 +741,11 @@

Create Custom Model Configuration

// Store the full config if available if (modelDef.config) { modelConfig = { ...modelDef.config }; + } else if (modelDef.type) { + // Models from g.models_config have to_dict() as the + // entry itself (no nested .config key). Use the whole + // definition as config so params auto-populate. + modelConfig = { ...modelDef }; } } } else if (type === 'postprocessor') { @@ -641,12 +785,61 @@

Create Custom Model Configuration

// Wire a freshly added node into the graph according to its type, leaving all
// existing edges untouched. Normalizers and postprocessors are never wired
// automatically; the user connects those by dragging.
function connectNewNode(node, type) {
    // Add from -> to unless either end is missing or the edge already exists.
    const linkOnce = (fromId, toId) => {
        if (!fromId || !toId) return;
        const alreadyLinked = pipeline.edges.some(e => e.from === fromId && e.to === toId);
        if (alreadyLinked) return;
        pipeline.edges.push({
            id: `edge-${Date.now()}-${Math.random()}`,
            from: fromId,
            to: toId
        });
    };

    switch (type) {
        case 'input': {
            // Feed the first normalizer if there is one, otherwise the first model
            const downstream = pipeline.normalizers[0] || pipeline.models[0];
            if (downstream) linkOnce(node.id, downstream.id);
            break;
        }
        case 'output': {
            // Take from the last postprocessor if there is one, otherwise the last model
            let upstream = null;
            if (pipeline.postprocessors.length > 0) {
                upstream = pipeline.postprocessors[pipeline.postprocessors.length - 1];
            } else if (pipeline.models.length > 0) {
                upstream = pipeline.models[pipeline.models.length - 1];
            }
            if (upstream) linkOnce(upstream.id, node.id);
            break;
        }
        case 'model': {
            // Incoming: last normalizer, otherwise the first input
            let upstream = null;
            if (pipeline.normalizers.length > 0) {
                upstream = pipeline.normalizers[pipeline.normalizers.length - 1];
            } else if (pipeline.inputs.length > 0) {
                upstream = pipeline.inputs[0];
            }
            if (upstream) linkOnce(upstream.id, node.id);

            // Outgoing: first postprocessor, otherwise the first output
            const downstream = pipeline.postprocessors[0]
                || (pipeline.outputs.length > 0 ? pipeline.outputs[0] : null);
            if (downstream) linkOnce(node.id, downstream.id);
            break;
        }
        default:
            // 'normalizer', 'postprocessor' and anything else: no auto-connect
            break;
    }
}

Create Custom Model Configuration

} else if (type === 'blockwise-config') { pipeline.blockwise_config = pipeline.blockwise_config.filter(c => c.id !== id); } - // Re-auto-connect remaining nodes - autoConnectNodes(); - renderCanvas(); + // Remove the node's DOM element and re-render connections + const nodeEl = document.getElementById(`node-${id}`); + if (nodeEl) nodeEl.remove(); + renderConnections(); + debouncedApply(); showMessage('Node removed', 'success'); } @@ -1443,7 +1638,7 @@

Create Custom Model Configuration

const checked = value ? 'checked' : ''; inputHTML = ` @@ -1452,13 +1647,13 @@

Create Custom Model Configuration

inputHTML = ` + oninput="handleParamChange(this, '${node.id}', '${node.type}')"> `; } else { inputHTML = ` + oninput="handleParamChange(this, '${node.id}', '${node.type}')"> `; } @@ -1483,7 +1678,7 @@

Create Custom Model Configuration

+ oninput="handleParamChange(this, '${node.id}', '${node.type}')"> `; @@ -1548,7 +1743,6 @@

Create Custom Model Configuration

${node.name}
-
@@ -1576,7 +1770,7 @@

Create Custom Model Configuration

+ oninput="handleParamChange(this, '${node.id}', 'blockwise-config')"> `; }); @@ -1587,7 +1781,6 @@

Create Custom Model Configuration

⚙️ Blockwise Config
-
@@ -1605,7 +1798,6 @@

Create Custom Model Configuration

${node.type}
-
@@ -1672,9 +1864,49 @@

Create Custom Model Configuration

nodeEl.style.zIndex = 999; // 999 so dragging can use 1000 } + // Debounced auto-apply to backend + let _applyTimer = null; + function debouncedApply() { + if (_applyTimer) clearTimeout(_applyTimer); + _applyTimer = setTimeout(() => applyPipeline(), 2000); + } + + // Always sync on page leave (covers back button, link clicks, etc.) + window.addEventListener('beforeunload', () => { + if (_applyTimer) clearTimeout(_applyTimer); + const payload = buildApplyPayload(); + navigator.sendBeacon('/api/pipeline/apply', new Blob([JSON.stringify(payload)], { type: 'application/json' })); + }); + function handleParamChange(input, nodeId, nodeType) { - const nodeBox = document.getElementById(`node-${nodeId}`); - nodeBox.classList.add('dirty'); + // Immediately update the pipeline object + const key = input.dataset.key; + const type = input.dataset.type || 'text'; + let value; + if (type === 'boolean') { + value = input.value === 'true'; + } else if (type === 'number') { + value = parseFloat(input.value); + if (isNaN(value)) return; + } else { + try { value = JSON.parse(input.value); } catch { value = input.value; } + } + + let node; + if (nodeType === 'input') node = pipeline.inputs.find(n => n.id === nodeId); + else if (nodeType === 'output') node = pipeline.outputs.find(n => n.id === nodeId); + else if (nodeType === 'normalizer') node = pipeline.normalizers.find(n => n.id === nodeId); + else if (nodeType === 'model') node = pipeline.models.find(m => m.id === nodeId); + else if (nodeType === 'postprocessor') node = pipeline.postprocessors.find(p => p.id === nodeId); + else if (nodeType === 'blockwise-config') node = pipeline.blockwise_config.find(c => c.id === nodeId); + + if (node) { + node.params[key] = value; + if (nodeType === 'model' && node.config) { + node.config[key] = value; + } + } + debouncedApply(); } function saveNode(nodeId, nodeType) { @@ -1829,9 +2061,8 @@

Create Custom Model Configuration

setTimeout(() => div.remove(), 5000); } - async function applyPipeline() { - // Transform pipeline data to backend format - include complete state for persistence - const payload = { + function buildApplyPayload() { + return { input_normalizers: pipeline.normalizers.map(n => ({ id: n.id, name: n.name, @@ -1852,10 +2083,12 @@

Create Custom Model Configuration

})), inputs: pipeline.inputs.map(i => ({ id: i.id, + params: i.params, position: i.position })), outputs: pipeline.outputs.map(o => ({ id: o.id, + params: o.params, position: o.position })), edges: pipeline.edges.map(e => ({ @@ -1864,22 +2097,22 @@

Create Custom Model Configuration

to: e.to })) }; + } + async function applyPipeline() { + const payload = buildApplyPayload(); try { const response = await fetch('/api/pipeline/apply', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(payload) }); - const result = await response.json(); - if (response.ok) { - showMessage('✓ Pipeline applied successfully!', 'success'); - } else { - showMessage('Error: ' + (result.error || 'Unknown error'), 'error'); + if (!response.ok) { + console.error('Pipeline sync error:', result.error || 'Unknown error'); } } catch (err) { - showMessage('Failed: ' + err.message, 'error'); + console.error('Pipeline sync failed:', err.message); } } @@ -2134,7 +2367,9 @@

Create Custom Model Configuration

}; console.log('Final pipeline:', pipeline); // Debug log + autoConnectNodes(); renderCanvas(); + autoLayoutNodes(); // If blockwise config is imported, sync it to backend globals if (pipeline.blockwise_config.length > 0) { @@ -2493,6 +2728,7 @@

Create Custom Model Configuration

to: toNodeId }); renderConnections(); + debouncedApply(); showMessage('Connection created', 'success'); } else { showMessage('Connection already exists', 'info'); @@ -2516,6 +2752,18 @@

Create Custom Model Configuration

// Return the center of a connection dot in canvas-content coordinates.
// Both rectangles come from getBoundingClientRect, so subtracting the
// container's origin from the dot's center gives an offset inside the canvas.
function getDotPosition(dot) {
    const canvasBox = document.getElementById('canvas-content').getBoundingClientRect();
    const dotBox = dot.getBoundingClientRect();
    const x = dotBox.left + dotBox.width / 2 - canvasBox.left;
    const y = dotBox.top + dotBox.height / 2 - canvasBox.top;
    return { x, y };
}

Create Custom Model Configuration

const existingPaths = svg.querySelectorAll('.connection-path:not(.dragging-connection)'); existingPaths.forEach(p => p.remove()); + // Compute extent of all nodes to size the SVG properly + const canvas = document.getElementById('canvas-content'); + const nodeBoxes = canvas.querySelectorAll('.node-box'); + let maxRight = 600, maxBottom = 600; + nodeBoxes.forEach(box => { + const right = box.offsetLeft + box.offsetWidth; + const bottom = box.offsetTop + box.offsetHeight; + if (right > maxRight) maxRight = right; + if (bottom > maxBottom) maxBottom = bottom; + }); + // Expand canvas-content to fit all nodes so SVG (100% of parent) covers everything + canvas.style.minWidth = (maxRight + 100) + 'px'; + canvas.style.minHeight = (maxBottom + 100) + 'px'; + // Render each edge pipeline.edges.forEach(edge => { const fromNodeEl = document.getElementById(`node-${edge.from}`); @@ -2536,18 +2798,12 @@

Create Custom Model Configuration

if (!fromDot || !toDot) return; - const canvasRect = document.getElementById('canvas-content').getBoundingClientRect(); - const fromRect = fromDot.getBoundingClientRect(); - const toRect = toDot.getBoundingClientRect(); - - const x1 = fromRect.left + fromRect.width / 2 - canvasRect.left; - const y1 = fromRect.top + fromRect.height / 2 - canvasRect.top; - const x2 = toRect.left + toRect.width / 2 - canvasRect.left; - const y2 = toRect.top + toRect.height / 2 - canvasRect.top; + const from = getDotPosition(fromDot); + const to = getDotPosition(toDot); const path = document.createElementNS('http://www.w3.org/2000/svg', 'path'); path.classList.add('connection-path'); - path.setAttribute('d', createBezierPath(x1, y1, x2, y2)); + path.setAttribute('d', createBezierPath(from.x, from.y, to.x, to.y)); path.dataset.edgeId = edge.id; path.style.pointerEvents = 'stroke'; @@ -2564,6 +2820,7 @@

Create Custom Model Configuration

// Drop the edge with the given id from pipeline.edges, redraw the
// connection paths, call debouncedApply(), and report the deletion.
function deleteConnection(edgeId) {
    const remainingEdges = pipeline.edges.filter((edge) => edge.id !== edgeId);
    pipeline.edges = remainingEdges;

    renderConnections();
    debouncedApply();
    showMessage('Connection deleted', 'success');
}

Create Custom Model Configuration

const finalY = Math.max(0, draggedNode.originalY + dy); draggedNode.node.position = { x: finalX, y: finalY }; + debouncedApply(); // Keep z-index at 999 since it's selected, instead of resetting to auto draggedNode.element.style.zIndex = 999; draggedNode.element.style.cursor = 'auto'; diff --git a/cellmap_flow/finetune/__init__.py b/cellmap_flow/finetune/__init__.py new file mode 100644 index 0000000..2077344 --- /dev/null +++ b/cellmap_flow/finetune/__init__.py @@ -0,0 +1,38 @@ +""" +Human-in-the-loop finetuning for CellMap-Flow models. + +This package provides lightweight LoRA-based finetuning for pre-trained models +using user corrections as training data. +""" + +from cellmap_flow.finetune.lora_wrapper import ( + detect_adaptable_layers, + wrap_model_with_lora, + print_lora_parameters, + load_lora_adapter, + save_lora_adapter, +) + +from cellmap_flow.finetune.correction_dataset import ( + CorrectionDataset, + create_dataloader, +) + +from cellmap_flow.finetune.lora_trainer import ( + LoRAFinetuner, + DiceLoss, + CombinedLoss, +) + +__all__ = [ + "detect_adaptable_layers", + "wrap_model_with_lora", + "print_lora_parameters", + "load_lora_adapter", + "save_lora_adapter", + "CorrectionDataset", + "create_dataloader", + "LoRAFinetuner", + "DiceLoss", + "CombinedLoss", +] diff --git a/cellmap_flow/finetune/correction_dataset.py b/cellmap_flow/finetune/correction_dataset.py new file mode 100644 index 0000000..680c8aa --- /dev/null +++ b/cellmap_flow/finetune/correction_dataset.py @@ -0,0 +1,377 @@ +""" +PyTorch Dataset for loading user corrections. + +This module provides a Dataset class that loads 3D EM data and correction +masks from Zarr files for training LoRA adapters. 
+""" + +import logging +from pathlib import Path +from typing import List, Tuple, Optional +import numpy as np +import zarr +import torch +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + + +class CorrectionDataset(Dataset): + """ + PyTorch Dataset for user corrections stored in Zarr format. + + Loads raw EM data and corrected masks from corrections.zarr/, with + optional 3D augmentation. + + Args: + corrections_zarr_path: Path to corrections.zarr directory + patch_shape: Shape of patches to extract (Z, Y, X) + If None, uses full correction size + augment: Whether to apply 3D augmentation + model_name: If specified, only load corrections for this model + + Examples: + >>> dataset = CorrectionDataset( + ... "test_corrections.zarr", + ... patch_shape=(64, 64, 64), + ... augment=True + ... ) + >>> print(f"Dataset size: {len(dataset)}") + >>> raw, target = dataset[0] + >>> print(f"Raw shape: {raw.shape}, Target shape: {target.shape}") + """ + + def __init__( + self, + corrections_zarr_path: str, + patch_shape: Optional[Tuple[int, int, int]] = None, + augment: bool = True, + model_name: Optional[str] = None, + ): + self.corrections_path = Path(corrections_zarr_path) + self.patch_shape = patch_shape + self.augment = augment + self.model_name = model_name + + # Load corrections + self.corrections = self._load_corrections() + + if len(self.corrections) == 0: + raise ValueError( + f"No corrections found in {corrections_zarr_path}. " + f"Generate corrections first." 
+ ) + + logger.info( + f"Loaded {len(self.corrections)} corrections from {corrections_zarr_path}" + ) + + def _load_corrections(self) -> List[dict]: + """Load correction metadata from Zarr.""" + corrections = [] + + logger.info(f"Loading corrections from: {self.corrections_path}") + + if not self.corrections_path.exists(): + logger.error(f"Corrections path does not exist: {self.corrections_path}") + return corrections + + path_str = str(self.corrections_path) + z = zarr.open_group(path_str, mode="r") + + for correction_id in z.keys(): + corr_group = z[correction_id] + + # Check if correction has required data + # Support both 'mask' (from test scripts) and 'annotation' (from dashboard) + has_raw = "raw" in corr_group + has_mask = "mask" in corr_group + has_annotation = "annotation" in corr_group + + has_raw_s0 = has_raw and "s0" in corr_group["raw"] + has_mask_s0 = has_mask and "s0" in corr_group["mask"] + has_annotation_s0 = has_annotation and "s0" in corr_group["annotation"] + + if not has_raw_s0 or not (has_mask_s0 or has_annotation_s0): + logger.warning( + f"Skipping {correction_id}: missing raw/s0 or mask|annotation/s0" + ) + continue + + # Use 'mask' if available, otherwise use 'annotation' + mask_key = "mask" if has_mask_s0 else "annotation" + + # Get metadata + attrs = dict(corr_group.attrs) + + # Filter by model name if specified + if self.model_name and attrs.get("model_name") != self.model_name: + continue + + raw_path = self.corrections_path / correction_id / "raw" / "s0" + mask_path = self.corrections_path / correction_id / mask_key / "s0" + + if not raw_path.exists() or not mask_path.exists(): + logger.warning( + f"Skipping {correction_id}: missing paths " + f"raw_path={raw_path} (exists={raw_path.exists()}), " + f"mask_path={mask_path} (exists={mask_path.exists()})" + ) + continue + + corrections.append( + { + "id": correction_id, + "raw_path": str(raw_path), + "mask_path": str(mask_path), + "metadata": attrs, + } + ) + + return corrections + + def 
__len__(self) -> int: + """Return number of corrections.""" + return len(self.corrections) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Load a correction pair (raw, target). + + Args: + idx: Index of correction + + Returns: + Tuple of (raw, target) tensors: + - raw: (1, Z, Y, X) float32 tensor + - target: (1, Z, Y, X) float32 tensor, values in [0, 1] + """ + correction = self.corrections[idx] + + # Load data using ImageDataInterface for consistent data loading + from cellmap_flow.image_data_interface import ImageDataInterface + + try: + raw = ImageDataInterface( + correction["raw_path"], normalize=False + ).to_ndarray_ts() + mask = ImageDataInterface( + correction["mask_path"], normalize=False + ).to_ndarray_ts() + except Exception as e: + raise FileNotFoundError( + f"Failed loading correction '{correction.get('id', idx)}' " + f"raw_path='{correction.get('raw_path')}' " + f"mask_path='{correction.get('mask_path')}': {e}" + ) from e + + # Convert to float + raw = raw.astype(np.float32) + mask = mask.astype(np.float32) + + + # For models with different input/output sizes, we keep raw at full size + # Patching is disabled for this case - use full corrections + # Apply augmentation (only if raw and mask have same shape) + if self.augment and raw.shape == mask.shape: + raw, mask = self._augment_3d(raw, mask) + elif self.augment and raw.shape != mask.shape: + logger.debug( + f"Skipping augmentation: raw {raw.shape} != mask {mask.shape}. " + "Augmentation requires matching sizes." + ) + + # Add channel dimension and convert to torch + raw = torch.from_numpy(raw[np.newaxis, ...]) # (1, Z, Y, X) + mask = torch.from_numpy(mask[np.newaxis, ...]) # (1, Z, Y, X) + + return raw, mask + + def _random_crop( + self, raw: np.ndarray, mask: np.ndarray, patch_shape: Tuple[int, int, int] + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Extract a random patch from the volumes. 
+ + Args: + raw: Raw data (Z, Y, X) + mask: Mask data (Z, Y, X) + patch_shape: Desired patch shape (Z, Y, X) + + Returns: + Cropped (raw, mask) pair + """ + z, y, x = raw.shape + pz, py, px = patch_shape + + # If volume is smaller than patch, pad it + if z < pz or y < py or x < px: + pad_z = max(0, pz - z) + pad_y = max(0, py - y) + pad_x = max(0, px - x) + + raw = np.pad(raw, ((0, pad_z), (0, pad_y), (0, pad_x)), mode="reflect") + mask = np.pad(mask, ((0, pad_z), (0, pad_y), (0, pad_x)), mode="reflect") + z, y, x = raw.shape + + # Random offset + z_offset = np.random.randint(0, max(1, z - pz + 1)) + y_offset = np.random.randint(0, max(1, y - py + 1)) + x_offset = np.random.randint(0, max(1, x - px + 1)) + + # Crop + raw_crop = raw[ + z_offset : z_offset + pz, y_offset : y_offset + py, x_offset : x_offset + px + ] + mask_crop = mask[ + z_offset : z_offset + pz, y_offset : y_offset + py, x_offset : x_offset + px + ] + + return raw_crop, mask_crop + + def _augment_3d( + self, raw: np.ndarray, mask: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Apply 3D augmentation to raw and mask. 
def create_dataloader(
    corrections_zarr_path: str,
    batch_size: int = 2,
    patch_shape: Optional[Tuple[int, int, int]] = None,
    augment: bool = True,
    num_workers: int = 4,
    shuffle: bool = True,
    model_name: Optional[str] = None,
) -> torch.utils.data.DataLoader:
    """
    Create a DataLoader for corrections.

    If a virtual-sources manifest is present in the corrections directory,
    the materialized-chunk dataset is bypassed and the loader is built from
    the manifest instead. In that case ``patch_shape``, ``augment``,
    ``shuffle`` and ``model_name`` are not used, and shuffling is disabled
    on the DataLoader.

    Args:
        corrections_zarr_path: Path to corrections.zarr directory
        batch_size: Batch size (2-4 recommended for 3D data). Clamped to the
            number of available samples.
        patch_shape: Shape of patches to extract (Z, Y, X)
        augment: Whether to apply augmentation
        num_workers: Number of data loading workers
        shuffle: Whether to shuffle data
        model_name: If specified, only load corrections for this model

    Returns:
        DataLoader instance

    Examples:
        >>> dataloader = create_dataloader(
        ...     "test_corrections.zarr",
        ...     batch_size=2,
        ...     patch_shape=(64, 64, 64)
        ... )
        >>> for raw, target in dataloader:
        ...     print(f"Batch: raw={raw.shape}, target={target.shape}")
        ...     break
        Batch: raw=torch.Size([2, 1, 64, 64, 64]), target=torch.Size([2, 1, 64, 64, 64])
    """
    # Imported lazily so the manifest machinery is only loaded when a
    # dataloader is actually requested.
    from cellmap_flow.finetune.virtual_dataset import (
        VIRTUAL_MANIFEST_FILENAME,
        dataset_from_manifest,
        read_manifest,
    )

    # The trainer entry point doesn't need to know which dataset is used:
    # it always passes the same corrections directory.
    manifest = read_manifest(corrections_zarr_path)
    if manifest is not None:
        logger.info(
            f"Found virtual sources manifest at "
            f"{Path(corrections_zarr_path) / VIRTUAL_MANIFEST_FILENAME}; "
            f"using VirtualPatchDataset."
        )
        dataset = dataset_from_manifest(manifest)
        actual_batch_size = max(1, min(batch_size, len(dataset)))
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=actual_batch_size,
            shuffle=False,  # virtual dataset already samples randomly
            num_workers=num_workers,
            pin_memory=True,
            persistent_workers=num_workers > 0,
            multiprocessing_context="spawn" if num_workers > 0 else None,
        )

    dataset = CorrectionDataset(
        corrections_zarr_path,
        patch_shape=patch_shape,
        augment=augment,
        model_name=model_name,
    )

    # Clamp batch size to number of samples so DataLoader doesn't error
    actual_batch_size = (
        min(batch_size, len(dataset)) if len(dataset) > 0 else batch_size
    )
    if actual_batch_size != batch_size:
        logger.info(
            f"Clamped batch_size from {batch_size} to {actual_batch_size} "
            f"(only {len(dataset)} samples available)"
        )

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=actual_batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,  # Faster GPU transfer
        persistent_workers=num_workers > 0,  # Keep workers alive between epochs
        # tensorstore threading incompatible with fork
        multiprocessing_context="spawn" if num_workers > 0 else None,
    )

    logger.info(
        f"Created DataLoader with {len(dataset)} samples, "
        f"batch_size={actual_batch_size}, num_workers={num_workers}"
    )

    return dataloader
class CropEntry(BaseModel):
    """Describe a single annotation crop to import.

    Only ``path`` is required; every other field has a default:
      - ``fg_ids=None``: every nonzero source value counts as foreground.
      - ``bg_ids=[]``: no explicit background ids; ``mode`` decides what
        unmatched values (including 0) mean.
      - ``mode='dense'``: unmatched voxels (incl. 0) are background.
      - ``mode='sparse'``: unmatched voxels are unannotated.
      - ``connected_components=False``: source ids are kept as instance ids.
        Set it to True together with a single-id ``fg_ids`` to split
        same-id blobs into per-instance ids for affinity-style training.

    Unknown keys are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    path: str
    name: Optional[str] = None
    fg_ids: Optional[List[int]] = None
    bg_ids: List[int] = Field(default_factory=list)
    mode: Literal["dense", "sparse"] = "dense"
    connected_components: bool = False

    @field_validator("fg_ids", "bg_ids")
    @classmethod
    def _no_zero_ids(cls, value):
        """Reject id lists that contain 0."""
        # ``None`` (fg_ids left unset) carries no ids to check.
        if value is None:
            return value
        if 0 in value:
            raise ValueError("fg_ids/bg_ids cannot include 0 (0 = unannotated sentinel)")
        return value

    @model_validator(mode="after")
    def _validate(self):
        """Cross-field checks that need the fully populated model."""
        if self.fg_ids is None:
            # Without explicit fg_ids there is nothing to overlap with
            # bg_ids; only the connected-components requirement applies.
            if self.connected_components:
                raise ValueError("connected_components=True requires fg_ids to be specified")
            return self

        shared = set(self.fg_ids).intersection(self.bg_ids)
        if shared:
            raise ValueError(f"fg_ids and bg_ids overlap: {sorted(shared)}")
        return self
def parse_crops_yaml(yaml_text_or_path: str) -> CropsConfig:
    """Parse a YAML string OR the path to a YAML file into a validated config.

    The argument is treated as a file path only when it contains no newline
    and names an existing path; otherwise it is parsed as YAML text. An
    empty document is validated as an empty mapping.

    Args:
        yaml_text_or_path: YAML text, or a path (``str`` or ``os.PathLike``)
            to a YAML file.

    Returns:
        The validated :class:`CropsConfig`.
    """
    # Accept pathlib.Path as well as str; a plain str passes through unchanged.
    text = os.fspath(yaml_text_or_path)
    if "\n" not in text and os.path.exists(text):
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(text, encoding="utf-8") as f:
            text = f.read()
    data = yaml.safe_load(text) or {}
    return CropsConfig.model_validate(data)
def remap_labels(
    source: np.ndarray,
    fg_ids: Optional[Iterable[int]],
    bg_ids: Iterable[int],
    mode: Literal["dense", "sparse"],
    connected_components: bool,
) -> np.ndarray:
    """Map source label values to ``0=unannotated, 1=BG, >=2=FG instance``.

    Mapping rules:
      - source value in ``fg_ids`` (or any nonzero, non-background value if
        ``fg_ids is None``) becomes an instance id >=2. If
        ``connected_components`` is True, each connected blob within an
        fg_id class gets its own instance. Otherwise, source ids map to
        consecutive 2,3,... in order.
      - source value in ``bg_ids`` -> 1 (background).
      - everything else -> 1 if ``mode='dense'``, else 0 (unannotated).

    Args:
        source: Array of source label values.
        fg_ids: Foreground source ids, or None for "all nonzero".
        bg_ids: Source ids that are explicit background.
        mode: ``'dense'`` or ``'sparse'`` (see rules above).
        connected_components: Split each foreground class into blobs.

    Returns:
        ``uint8`` array with the shape of ``source``. If the number of
        distinct instances would overflow ``uint8``, all FG voxels collapse
        to id=2 and a warning is emitted.
    """
    source = np.asarray(source)
    bg_set = {int(v) for v in bg_ids}

    if connected_components:
        return _remap_with_cc(source, fg_ids, bg_set, mode)

    unique_vals = _unique_label_values(source)

    if fg_ids is None:
        fg_classes = [int(v) for v in unique_vals if int(v) != 0 and int(v) not in bg_set]
    else:
        # Preserve caller order for deterministic instance ids.
        fg_classes = [int(v) for v in fg_ids]

    has_values = len(unique_vals) > 0
    src_max = int(unique_vals.max()) if has_values else 0
    src_min = int(unique_vals.min()) if has_values else 0

    # The lookup table below is indexed by the raw source values, which is
    # only valid for non-negative integers: negative values would silently
    # wrap around to the end of the table, and float/bool arrays cannot be
    # used as plain indices. The size cap keeps a pathological array from
    # allocating a huge table. Anything else takes the per-class path.
    indexable = np.issubdtype(source.dtype, np.integer) and src_min >= 0
    if not indexable or src_max > 8_000_000:
        return _remap_per_class(source, fg_classes, bg_set, mode)

    # Every value starts at the "unmatched" label, which already covers
    # source == 0 (1 for dense, 0 for sparse). uint32 so instance ids can be
    # held before the uint8 collapse check.
    default = 1 if mode == "dense" else 0
    lookup = np.full(src_max + 1, default, dtype=np.uint32)
    for bg in bg_set:
        if 0 <= bg <= src_max:
            lookup[bg] = 1
    next_instance_id = 2
    for cls in fg_classes:
        if 0 <= cls <= src_max:
            lookup[cls] = next_instance_id
            next_instance_id += 1

    # Single fancy-index pass applies the whole mapping.
    out = lookup[source]
    return _finalize_instances(out, next_instance_id)


def _unique_label_values(source):
    """Return the distinct values of *source*, using fastremap if it works."""
    try:
        import fastremap

        return fastremap.unique(source)
    except Exception:
        # fastremap is an optional accelerator; any failure (not installed,
        # unsupported dtype) falls back to NumPy.
        return np.unique(source)


def _fill_background(out, source, bg_set, mode):
    """Label non-foreground voxels of *out* in place (1 = BG, 0 = unannotated)."""
    if bg_set:
        # Foreground assignments win over explicit background ids.
        bg_mask = np.isin(source, list(bg_set))
        out[bg_mask & (out < 2)] = 1
    if mode == "dense":
        # Dense crops treat every still-unlabelled voxel as background.
        out[out == 0] = 1


def _finalize_instances(out, next_instance_id):
    """Collapse FG ids to 2 if they cannot fit uint8, then cast to uint8."""
    # Ids run 2..next_instance_id-1, so 256 is the first value that overflows.
    if next_instance_id > 256:
        logger.warning(
            f"Crop produced {next_instance_id - 2} instances; collapsing to single FG class "
            "to fit uint8. Affinities between distinct blobs may be inaccurate."
        )
        out[out >= 2] = 2
    return out.astype(np.uint8)


def _remap_with_cc(source, fg_ids, bg_set, mode):
    """Connected-components path: label each blob of every FG class separately.

    Used for ``connected_components=True``; produces a distinct instance id
    per connected blob rather than per source id.
    """
    from scipy.ndimage import label as cc_label

    if fg_ids is None:
        unique_vals = _unique_label_values(source)
        fg_classes = [int(v) for v in unique_vals if int(v) != 0 and int(v) not in bg_set]
    else:
        fg_classes = [int(v) for v in fg_ids]

    out = np.zeros(source.shape, dtype=np.uint32)
    next_instance_id = 2
    for cls in fg_classes:
        cls_mask = source == cls
        if not cls_mask.any():
            continue
        labeled, n = cc_label(cls_mask)
        # ``labeled`` is 1..n inside cls_mask, so one shifted assignment
        # gives every blob its id without rescanning the volume per blob.
        out[cls_mask] = labeled[cls_mask] + (next_instance_id - 1)
        next_instance_id += n

    _fill_background(out, source, bg_set, mode)
    return _finalize_instances(out, next_instance_id)


def _remap_per_class(source, fg_classes, bg_set, mode):
    """Per-class fallback for sources the lookup table cannot index.

    Only classes that actually occur in ``source`` consume an instance id.
    """
    out = np.zeros(source.shape, dtype=np.uint32)
    next_instance_id = 2
    for cls in fg_classes:
        cls_mask = source == cls
        if cls_mask.any():
            out[cls_mask] = next_instance_id
            next_instance_id += 1

    _fill_background(out, source, bg_set, mode)
    return _finalize_instances(out, next_instance_id)
class RestartController:
    """Hold at most one pending restart request in memory.

    ``request_restart`` stores a request and ``get_if_triggered`` hands it
    out once; a lock guards the stored request and an event flags that one
    is waiting.
    """

    def __init__(self):
        self._event = threading.Event()
        self._lock = threading.Lock()
        self._pending = None

    def request_restart(self, payload: Optional[dict]) -> bool:
        """Record a restart request built from *payload* and return True.

        A truthy ``payload["timestamp"]`` replaces the current time, and
        ``payload["params"]`` is used only when it is a dict. Any other
        payload yields a request with the current time and empty params.
        A newer request overwrites one that has not been collected yet.
        """
        timestamp = datetime.now().isoformat()
        params = {}
        if isinstance(payload, dict):
            timestamp = payload.get("timestamp") or timestamp
            candidate = payload.get("params")
            if isinstance(candidate, dict):
                params = candidate

        request = {"restart": True, "timestamp": timestamp, "params": params}
        with self._lock:
            self._pending = request
            self._event.set()
        return True

    def get_if_triggered(self) -> Optional[dict]:
        """Return the pending request and clear it, or None if none is set."""
        if self._event.is_set():
            with self._lock:
                request, self._pending = self._pending, None
                self._event.clear()
            return request
        return None
def _wait_for_restart_signal(
    signal_file: Optional[Path],
    check_interval: float = 1.0,
    restart_controller: Optional[RestartController] = None,
):
    """
    Block until a restart request arrives, then return its payload.

    Each polling round first asks the in-memory controller (if one was
    given) and only then looks for the legacy signal file on disk.

    Args:
        signal_file: Optional path of the legacy JSON signal file
        check_interval: Seconds to sleep between polling rounds
        restart_controller: Optional object whose ``get_if_triggered()``
            returns a payload once a restart was requested, else None

    Returns:
        The restart payload, or None if the signal file could not be
        read or parsed (the file is removed in that case as well)
    """
    logger.info(f"Watching for restart signal (controller + file fallback: {signal_file})")

    have_controller = restart_controller is not None
    while True:
        # In-memory requests win over the file-based fallback.
        triggered = restart_controller.get_if_triggered() if have_controller else None
        if triggered is not None:
            logger.info(f"Restart signal received via HTTP control endpoint: {triggered}")
            return triggered

        # Nothing on disk yet: sleep and poll again.
        if not (signal_file and signal_file.exists()):
            time.sleep(check_interval)
            continue

        try:
            with open(signal_file) as handle:
                payload = json.load(handle)
            # Consume the signal so it is not picked up a second time.
            signal_file.unlink()
            logger.info(f"Restart signal received: {payload}")
        except Exception as e:
            logger.error(f"Error reading restart signal: {e}")
            # Best effort: drop the unreadable file so it cannot re-trigger.
            try:
                signal_file.unlink()
            except OSError:
                pass
            return None
        else:
            return payload
def _generate_model_files(args, model_config, timestamp):
    """
    Write the YAML config that describes the finetuned model.

    Args:
        args: Command-line arguments (reads output_dir, corrections,
            auto_serve and serve_data_path)
        model_config: Model configuration (reads name, calls to_dict())
        timestamp: Timestamp string appended to the model name

    Returns:
        (finetuned_model_name, yaml_path) tuple
    """
    from cellmap_flow.finetune.finetuned_model_templates import (
        generate_finetuned_model_yaml
    )

    finetuned_model_name = f"{model_config.name}_finetuned_{timestamp}"

    # The models/ directory sits three levels above the job output dir.
    adapter_root = Path(args.output_dir)
    models_dir = adapter_root.parent.parent.parent / "models"
    models_dir.mkdir(exist_ok=True, parents=True)

    logger.info(f"Generating model config for {finetuned_model_name}...")

    # Look up the dataset path in the first correction's .zattrs, if any.
    corrections_dir = Path(args.corrections)
    data_path = None
    first_zarr = next(corrections_dir.glob("*.zarr"), None)
    if first_zarr is not None:
        attrs_path = first_zarr / ".zattrs"
        if attrs_path.exists():
            with open(attrs_path) as fh:
                data_path = json.load(fh).get("dataset_path")

    if not data_path:
        logger.warning("Could not extract data_path from corrections, using serve_data_path")
        if args.auto_serve:
            data_path = args.serve_data_path
        else:
            data_path = "/path/to/data.zarr"

    # Carry the training-time input_norm into the generated yaml so the
    # served model is queried with the normalization it was trained on.
    json_data = None
    try:
        from cellmap_flow.finetune.virtual_dataset import read_manifest

        train_input_norm = (read_manifest(str(corrections_dir)) or {}).get("input_norm")
        if train_input_norm:
            json_data = {"input_norm": train_input_norm, "postprocess": {}}
    except Exception as _e:
        logger.warning(
            f"Could not load training input_norm from manifest: {_e}. "
            "Generated finetuned yaml will lack normalization metadata."
        )

    yaml_path = generate_finetuned_model_yaml(
        lora_adapter_path=str(adapter_root / "lora_adapter"),
        base_model_dict=model_config.to_dict(),
        model_name=finetuned_model_name,
        output_path=models_dir / f"{finetuned_model_name}.yaml",
        data_path=data_path,
        json_data=json_data,
    )
    logger.info(f"Generated YAML: {yaml_path}")

    return finetuned_model_name, yaml_path
+ ) + + if len(offsets) < num_channels: + logger.info( + f"Model has {num_channels} output channels but only {len(offsets)} affinity offsets. " + f"Remaining {num_channels - len(offsets)} channels (e.g. LSDs) will be masked out." + ) + + logger.info(f"Using affinity target transform with {len(offsets)} offsets: {offsets}") + return AffinityTargetTransform(offsets, num_channels=num_channels) + + else: + raise ValueError(f"Unknown output type: {output_type}") + + +def _read_offsets_from_script(script_path): + """Try to read an 'offsets' variable from a model script via AST parsing.""" + import ast + + try: + with open(script_path, "r") as f: + tree = ast.parse(f.read()) + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "offsets": + return ast.literal_eval(node.value) + except Exception as e: + logger.debug(f"Could not read offsets from {script_path}: {e}") + + return None + + +def build_arg_parser(): + parser = argparse.ArgumentParser( + description="Finetune CellMap-Flow models with LoRA using user corrections" + ) + + # Model arguments + parser.add_argument( + "--model-type", + type=str, + default="fly", + choices=["fly", "dacapo", "huggingface", "script"], + help="Model type (fly, dacapo, huggingface, or script)" + ) + parser.add_argument( + "--model-checkpoint", + type=str, + required=False, + default=None, + help="Path to model checkpoint (optional - can train from scratch)" + ) + parser.add_argument( + "--model-script", + type=str, + required=False, + default=None, + help="Path to model script (alternative to checkpoint)" + ) + parser.add_argument( + "--repo", + type=str, + required=False, + default=None, + help="HuggingFace model repository (e.g., janelia-cellmap/mito_aff_unet_setup_16)" + ) + parser.add_argument( + "--revision", + type=str, + required=False, + default=None, + help="HuggingFace model revision (optional)" + ) + parser.add_argument( + 
"--model-name", + type=str, + default=None, + help="Model name (for filtering corrections)" + ) + parser.add_argument( + "--channels", + type=str, + nargs="+", + default=["mito"], + help="Model output channels" + ) + parser.add_argument( + "--input-voxel-size", + type=int, + nargs=3, + default=[16, 16, 16], + help="Input voxel size (Z Y X)" + ) + parser.add_argument( + "--output-voxel-size", + type=int, + nargs=3, + default=[16, 16, 16], + help="Output voxel size (Z Y X)" + ) + + # LoRA arguments + parser.add_argument( + "--lora-r", + type=int, + default=8, + help="LoRA rank (default: 8)" + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha scaling (default: 16)" + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.1, + help="LoRA dropout (default: 0.1)" + ) + + # Data arguments + parser.add_argument( + "--corrections", + type=str, + required=True, + help="Path to corrections.zarr directory" + ) + parser.add_argument( + "--patch-shape", + type=int, + nargs=3, + default=None, + help="Patch shape for training (Z Y X). 
Default: None (use full corrections)" + ) + parser.add_argument( + "--no-augment", + action="store_true", + help="Disable data augmentation" + ) + + # Training arguments + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Output directory for checkpoints and adapter" + ) + parser.add_argument( + "--batch-size", + type=int, + default=2, + help="Batch size (default: 2)" + ) + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of training epochs (default: 10)" + ) + parser.add_argument( + "--learning-rate", + type=float, + default=1e-4, + help="Learning rate (default: 1e-4)" + ) + parser.add_argument( + "--gradient-accumulation-steps", + type=int, + default=1, + help="Gradient accumulation steps (default: 1)" + ) + parser.add_argument( + "--loss-type", + type=str, + default="combined", + choices=["dice", "bce", "combined", "mse", "margin"], + help="Loss function (default: combined)" + ) + parser.add_argument( + "--label-smoothing", + type=float, + default=0.0, + help="Label smoothing factor (e.g., 0.1 maps targets from 0/1 to 0.05/0.95). " + "Helps preserve gradual distance-like outputs. (default: 0.0)" + ) + parser.add_argument( + "--distillation-lambda", + type=float, + default=0.0, + help="Teacher distillation weight. Keeps model close to base on unlabeled voxels. " + "0.0=disabled, try 0.5-1.0 for sparse scribbles. (default: 0.0)" + ) + parser.add_argument( + "--distillation-all-voxels", + action="store_true", + help="Apply distillation loss on all voxels instead of only unlabeled voxels. (default: unlabeled only)" + ) + parser.add_argument( + "--margin", + type=float, + default=0.3, + help="Margin threshold for margin loss. " + "Foreground must exceed 1-margin, background must stay below margin. (default: 0.3)" + ) + parser.add_argument( + "--balance-classes", + action="store_true", + help="Balance fg/bg loss contribution so each class is weighted equally, " + "regardless of scribble voxel counts. 
Helps prevent foreground overprediction. (default: off)" + ) + parser.add_argument( + "--no-mixed-precision", + action="store_true", + help="Disable mixed precision (FP16) training" + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="DataLoader num_workers (default: 4)" + ) + + # Resuming + parser.add_argument( + "--resume", + type=str, + default=None, + help="Path to checkpoint to resume from" + ) + + # Auto-serve arguments + parser.add_argument( + "--auto-serve", + action="store_true", + help="Automatically start inference server after training completes" + ) + parser.add_argument( + "--serve-data-path", + type=str, + default=None, + help="Dataset path for inference server (required if --auto-serve is used)" + ) + parser.add_argument( + "--serve-port", + type=int, + default=0, + help="Port for inference server (0 for auto-assignment)" + ) + parser.add_argument( + "--mask-unannotated", + action="store_true", + help="Enable masked loss for sparse annotations (0=ignore, 1=bg, 2+=fg)" + ) + + # Output type and target transform arguments + parser.add_argument( + "--output-type", + type=str, + default="binary", + choices=["binary", "binary_broadcast", "affinities"], + help="How to generate training targets from annotations. " + "'binary': single-channel fg/bg (use with --select-channel for multi-channel models). " + "'binary_broadcast': broadcast binary target to all output channels. " + "'affinities': compute affinity targets from instance labels (requires offsets). " + "(default: binary)" + ) + parser.add_argument( + "--select-channel", + type=int, + default=None, + help="Select a single channel from multi-channel model output for binary training. " + "Only used with --output-type binary. (default: None, use all channels)" + ) + parser.add_argument( + "--offsets", + type=str, + default=None, + help="JSON list of [dz,dy,dx] offsets for affinity target generation. " + "Example: '[[1,0,0],[0,1,0],[0,0,1]]'. 
def main():
    """Run LoRA finetuning from the command line, optionally serving and retraining.

    Steps, in order:
      1. Parse CLI arguments and log the main settings.
      2. Build the model config selected by --model-script / --model-type
         and take its model as the base model.
      3. If the base model is a TorchScript module, try to swap it for a
         trainable module obtained from ``cellmap_model.train()``.
      4. Wrap the model with LoRA and enter the training loop.

    Each loop iteration re-creates the dataloader, target transform and
    trainer, trains, saves the adapter and generates the model YAML. With
    --auto-serve the inference server is started once, after which the loop
    blocks until a restart signal arrives, applies the new parameters,
    re-creates the LoRA adapter and trains again.

    Returns:
        0 when training finished without --auto-serve, or when starting the
        inference server failed after a successful training run.
        1 when training diverged without --auto-serve, the restart signal
        was malformed, the user interrupted, or training raised.

    Raises:
        ValueError: If the model arguments are incomplete or the model type
            is unknown (raised before the training loop starts).
    """
    parser = build_arg_parser()

    args = parser.parse_args()

    # Print configuration
    logger.info("=" * 60)
    logger.info("LoRA Finetuning Configuration")
    logger.info("=" * 60)
    logger.info(f"Model type: {args.model_type}")
    logger.info(f"Model checkpoint: {args.model_checkpoint}")
    logger.info(f"Corrections: {args.corrections}")
    logger.info(f"Output directory: {args.output_dir}")
    logger.info(f"LoRA rank: {args.lora_r}")
    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Epochs: {args.num_epochs}")
    logger.info(f"Learning rate: {args.learning_rate}")
    logger.info("")

    # === Load model (once) ===
    logger.info("Loading model...")

    # --model-script takes precedence over --model-type: it is checked first.
    if args.model_script:
        from cellmap_flow.models.models_config import ScriptModelConfig
        logger.info(f"Using script-based model: {args.model_script}")
        model_config = ScriptModelConfig(
            script_path=args.model_script,
            name=args.model_name or "script_model"
        )
    elif args.model_type == "script":
        raise ValueError("For script models, --model-script is required")
    elif args.model_type == "fly":
        if not args.model_checkpoint:
            raise ValueError(
                "For fly models, either --model-checkpoint or --model-script must be provided"
            )
        model_config = FlyModelConfig(
            checkpoint_path=args.model_checkpoint,
            channels=args.channels,
            input_voxel_size=tuple(args.input_voxel_size),
            output_voxel_size=tuple(args.output_voxel_size),
            name=args.model_name,
        )
    elif args.model_type == "dacapo":
        if not args.model_checkpoint:
            raise ValueError("For dacapo models, --model-checkpoint is required")
        # The iteration is the integer after the last "_" in the checkpoint
        # file stem; the run name is the checkpoint's parent directory name.
        checkpoint_path = Path(args.model_checkpoint)
        iteration = int(checkpoint_path.stem.split('_')[-1])
        run_name = checkpoint_path.parent.name

        model_config = DaCapoModelConfig(
            run_name=run_name,
            iteration=iteration,
        )
    elif args.model_type == "huggingface":
        if not args.repo:
            raise ValueError("For huggingface models, --repo is required")
        model_config = HuggingFaceModelConfig(
            repo=args.repo,
            revision=args.revision,
            name=args.model_name,
        )
    else:
        raise ValueError(f"Unknown model type: {args.model_type}")

    base_model = model_config.config.model
    logger.info(f"Model loaded: {type(base_model).__name__}")

    # TorchScript models (RecursiveScriptModule) can't be used with LoRA.
    # Use cellmap_model.train() to get a trainable nn.Module via torch.export
    # unflatten — no fly_organelles dependency needed.
    if isinstance(base_model, torch.jit.ScriptModule):
        logger.info("TorchScript model detected — loading trainable model via cellmap_model.train()...")
        cellmap_model = None
        if args.model_type == "huggingface":
            from cellmap_models.model_export.cellmap_model import get_huggingface_model
            cellmap_model = get_huggingface_model(args.repo, args.revision)
        elif hasattr(model_config, 'cellmap_model'):
            cellmap_model = model_config.cellmap_model

        if cellmap_model is not None:
            trainable = cellmap_model.train()
            if trainable is not None:
                # UnflattenedModule (from torch.export) often has fixed batch=1.
                # Wrap it so the trainer can use any batch size.
                if type(trainable).__name__ == 'UnflattenedModule':
                    from cellmap_flow.finetune.lora_wrapper import BatchLoopWrapper
                    trainable = BatchLoopWrapper(trainable)
                    logger.info("Wrapped UnflattenedModule with BatchLoopWrapper for variable batch sizes")
                base_model = trainable
                logger.info(f"Trainable model loaded: {type(base_model).__name__}")
            else:
                logger.warning("cellmap_model.train() returned None — LoRA may fail")
        else:
            logger.warning("No CellmapModel available — LoRA may fail on TorchScript model")

    # === Wrap with LoRA (once - same object is reused across restarts) ===
    # NOTE(review): the restart branches below rebind lora_model to a newly
    # wrapped model, so "same object" presumably refers to the underlying
    # base module — confirm against wrap_model_with_lora / peft behaviour.
    logger.info(f"Wrapping model with LoRA (r={args.lora_r})...")
    lora_model = wrap_model_with_lora(
        base_model,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )

    # === Training loop (supports restart via signal file) ===
    server_started = False
    restart_controller = RestartController()
    # Counts training rounds; this rebinds the name also used for the
    # dacapo checkpoint iteration above, which is no longer needed here.
    iteration = 0

    while True:
        iteration += 1
        # A fresh timestamp per round makes each generated model name unique.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if iteration > 1:
            logger.info("")
            logger.info("=" * 60)
            logger.info(f"Training Iteration {iteration}")
            logger.info("=" * 60)

        # Create dataloader (re-created each iteration to pick up new annotations)
        if iteration > 1:
            # Status marker printed to stdout on restarts only.
            print("RESTART_STATUS: Loading corrections...", flush=True)
        logger.info(f"Loading corrections from {args.corrections}...")
        dataloader = create_dataloader(
            args.corrections,
            batch_size=args.batch_size,
            patch_shape=tuple(args.patch_shape) if args.patch_shape is not None else None,
            augment=not args.no_augment,
            num_workers=args.num_workers,
            shuffle=True,
            model_name=args.model_name,
        )
        logger.info(f"DataLoader created: {len(dataloader.dataset)} corrections")

        # Snapshot the active input_norm into metadata.json so any saved
        # checkpoint in this iteration is reproducible -- you can read
        # metadata.json next to the .pth and know exactly which
        # normalization was applied to the training data.
        # Best effort: any failure here is logged and training continues.
        try:
            from cellmap_flow.finetune.virtual_dataset import read_manifest

            manifest_norm = (read_manifest(args.corrections) or {}).get("input_norm")
            if manifest_norm is not None and args.output_dir:
                metadata_file = Path(args.output_dir) / "metadata.json"
                # Only an existing metadata.json is updated; none is created.
                if metadata_file.exists():
                    import json as json_mod
                    with open(metadata_file) as f:
                        md = json_mod.load(f)
                    md.setdefault("params", {})["input_norm"] = manifest_norm
                    with open(metadata_file, "w") as f:
                        json_mod.dump(md, f, indent=2)
                    logger.info(
                        f"Snapshot input_norm into {metadata_file} "
                        f"(keys: {list(manifest_norm.keys())})"
                    )
        except Exception as _e:
            logger.warning(f"Could not snapshot input_norm into metadata.json: {_e}")

        # Build target transform (re-built each iteration to pick up restart params)
        select_channel = args.select_channel
        target_transform = _build_target_transform(args, model_config)
        logger.info(f"output_type={args.output_type}, select_channel={select_channel}")

        # Create trainer (re-created each iteration for fresh optimizer/scheduler)
        if iteration > 1:
            print("RESTART_STATUS: Preparing trainer...", flush=True)
        logger.info("Creating trainer...")
        trainer = LoRAFinetuner(
            lora_model,
            dataloader,
            output_dir=args.output_dir,
            learning_rate=args.learning_rate,
            num_epochs=args.num_epochs,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            use_mixed_precision=not args.no_mixed_precision,
            loss_type=args.loss_type,
            select_channel=select_channel,
            mask_unannotated=args.mask_unannotated,
            label_smoothing=args.label_smoothing,
            distillation_lambda=args.distillation_lambda,
            distillation_all_voxels=args.distillation_all_voxels,
            margin=args.margin,
            balance_classes=args.balance_classes,
            target_transform=target_transform,
        )

        # Resume from checkpoint if specified (first iteration only)
        if args.resume and iteration == 1:
            logger.info(f"Resuming from checkpoint: {args.resume}")
            trainer.load_checkpoint(args.resume)

        # Train
        # The try block also spans the blocking wait for a restart signal,
        # so Ctrl-C while waiting reaches the KeyboardInterrupt handler too.
        try:
            if iteration > 1:
                print("RESTART_STATUS: Starting training...", flush=True)
            stats = trainer.train()

            # If training diverged (NaN/Inf), skip saving and wait for restart
            if stats.get('diverged'):
                logger.warning("Training diverged — skipping model save.")
                if args.auto_serve:
                    # Still wait for restart so the user can adjust params
                    signal_file = Path(args.output_dir) / "restart_signal.json"
                    restart_data = _wait_for_restart_signal(
                        signal_file=signal_file,
                        check_interval=1.0,
                        restart_controller=restart_controller,
                    )
                    if restart_data is None:
                        logger.error("Malformed restart signal, exiting")
                        return 1
                    _apply_restart_params(args, restart_data)

                    # Reset LoRA weights for fresh restart
                    logger.info("Resetting LoRA adapter weights for fresh restart...")
                    from peft import PeftModel
                    if isinstance(lora_model, PeftModel):
                        # Re-wrap using the (possibly updated) LoRA args.
                        base = lora_model.unload()
                        lora_model = wrap_model_with_lora(
                            base,
                            lora_r=args.lora_r,
                            lora_alpha=args.lora_alpha,
                            lora_dropout=args.lora_dropout,
                        )

                    lora_model.train()
                    torch.cuda.empty_cache()
                    gc.collect()
                    logger.info("Restarting training with fresh LoRA weights...")
                    # Marker printed to stdout when a new round begins.
                    print("RESTARTING_TRAINING", flush=True)
                    continue
                else:
                    # Diverged and nobody to ask for new parameters: fail.
                    return 1

            # Save final adapter
            logger.info("\nSaving LoRA adapter...")
            trainer.save_adapter()

            logger.info("\n" + "=" * 60)
            logger.info("Finetuning Complete!")
            logger.info(f"Best loss: {stats['best_loss']:.6f}")
            logger.info(f"Adapter saved to: {args.output_dir}/lora_adapter")
            logger.info("=" * 60)

            # Generate model files
            finetuned_model_name, _ = _generate_model_files(
                args, model_config, timestamp
            )

            # Print completion marker with timestamp (for job manager to detect)
            print(f"TRAINING_ITERATION_COMPLETE: {finetuned_model_name}", flush=True)

            # Auto-serve if requested
            if args.auto_serve:
                if not server_started:
                    # First time: start inference server in background thread
                    try:
                        _start_inference_server_background(
                            args, model_config, lora_model, restart_controller=restart_controller
                        )
                        server_started = True
                    except Exception as e:
                        logger.error(f"Failed to start inference server: {e}", exc_info=True)
                        print(f"INFERENCE_SERVER_FAILED: {e}", flush=True)
                        # Training itself succeeded, hence exit code 0.
                        return 0
                else:
                    # Server already running - just set model back to eval mode
                    # The server shares the same model object, so it automatically
                    # serves with the updated weights
                    lora_model.eval()
                    logger.info("Model updated and set to eval mode. Server continuing with new weights.")

                # Watch for restart signal
                signal_file = Path(args.output_dir) / "restart_signal.json"
                restart_data = _wait_for_restart_signal(
                    signal_file=signal_file,
                    check_interval=1.0,
                    restart_controller=restart_controller,
                )

                if restart_data is None:
                    logger.error("Malformed restart signal, exiting")
                    return 1

                # Apply updated parameters
                _apply_restart_params(args, restart_data)

                # Reset LoRA weights to initial state for a true restart.
                # Delete the current adapter and re-create it so training
                # starts from the frozen base model, not from the previous
                # finetuned weights.
                logger.info("Resetting LoRA adapter weights for fresh restart...")
                from peft import PeftModel
                if isinstance(lora_model, PeftModel):
                    base = lora_model.unload()
                    lora_model = wrap_model_with_lora(
                        base,
                        lora_r=args.lora_r,
                        lora_alpha=args.lora_alpha,
                        lora_dropout=args.lora_dropout,
                    )

                lora_model.train()
                torch.cuda.empty_cache()
                gc.collect()
                logger.info("Restarting training with fresh LoRA weights...")
                print("RESTARTING_TRAINING", flush=True)
                continue  # Loop back to retrain

            # No auto-serve: just exit after training
            return 0

        except KeyboardInterrupt:
            logger.info("\nTraining interrupted by user")
            logger.info("Saving current state...")
            trainer.save_checkpoint(is_best=False)
            return 1

        except Exception as e:
            logger.error(f"Training failed: {e}", exc_info=True)
            return 1


if __name__ == "__main__":
    sys.exit(main())
class JobStatus(Enum):
    """Lifecycle states of a finetuning job."""
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    CANCELLED = "CANCELLED"


@dataclass
class FinetuneJob:
    """Record of one finetuning job: identity, status and training progress.

    Also carries the inference-server state and the links to the previous
    and next job of a restart chain.
    """
    job_id: str
    # Scheduler/process handle; written as a forward reference so the class
    # body does not need LSFJob to be resolvable when it is defined.
    lsf_job: Optional["LSFJob"]
    model_name: str
    output_dir: Path
    params: Dict[str, Any]
    status: JobStatus
    created_at: datetime
    log_file: Path
    finetuned_model_name: Optional[str] = None
    model_script_path: Optional[Path] = None
    model_yaml_path: Optional[Path] = None
    current_epoch: int = 0
    total_epochs: int = 10
    latest_loss: Optional[float] = None
    inference_server_url: Optional[str] = None
    inference_server_ready: bool = False
    previous_job_id: Optional[str] = None
    next_job_id: Optional[str] = None
    # Not included in to_dict().
    _processed_iteration_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict describing this job."""
        # Prefer the handle's job_id; otherwise fall back to "PID:<pid>"
        # taken from an attached process.
        handle = self.lsf_job
        if handle and hasattr(handle, 'job_id'):
            scheduler_id = handle.job_id
        elif handle and hasattr(handle, 'process'):
            scheduler_id = f"PID:{handle.process.pid}"
        else:
            scheduler_id = None

        script_path = str(self.model_script_path) if self.model_script_path else None
        yaml_path = str(self.model_yaml_path) if self.model_yaml_path else None

        payload: Dict[str, Any] = {}
        payload["job_id"] = self.job_id
        payload["lsf_job_id"] = scheduler_id
        payload["model_name"] = self.model_name
        payload["output_dir"] = str(self.output_dir)
        payload["params"] = self.params
        payload["status"] = self.status.value
        payload["created_at"] = self.created_at.isoformat()
        payload["log_file"] = str(self.log_file)
        payload["finetuned_model_name"] = self.finetuned_model_name
        payload["model_script_path"] = script_path
        payload["model_yaml_path"] = yaml_path
        payload["current_epoch"] = self.current_epoch
        payload["total_epochs"] = self.total_epochs
        payload["latest_loss"] = self.latest_loss
        payload["inference_server_url"] = self.inference_server_url
        payload["inference_server_ready"] = self.inference_server_ready
        payload["previous_job_id"] = self.previous_job_id
        payload["next_job_id"] = self.next_job_id
        return payload
Job cancellation and cleanup + """ + + def __init__(self): + """Initialize the job manager.""" + self.jobs: Dict[str, FinetuneJob] = {} + self.logger = logging.getLogger(__name__) + self._monitor_threads: Dict[str, threading.Thread] = {} + + def _get_model_metadata(self, model_config, attr_name: str, default=None): + """ + Get metadata from model config, checking both direct attributes and loaded config. + + Args: + model_config: The model configuration object + attr_name: Name of the attribute to retrieve + default: Default value if attribute not found + + Returns: + The attribute value if found, otherwise the default value + """ + # First try direct attribute access + if hasattr(model_config, attr_name): + value = getattr(model_config, attr_name, None) + if value is not None: + return value + + # Then try loading config and checking there + try: + config = model_config.config + if hasattr(config, attr_name): + value = getattr(config, attr_name, None) + if value is not None: + return value + except Exception as e: + self.logger.debug(f"Could not load config to check for {attr_name}: {e}") + + return default + + def _extract_data_path_from_corrections(self, corrections_path: Path) -> str: + """Extract dataset path from corrections metadata.""" + # Look for first .zarr directory + zarr_dirs = list(corrections_path.glob("*.zarr")) + if not zarr_dirs: + raise ValueError("No .zarr directories found in corrections") + + # Read .zattrs + zattrs_file = zarr_dirs[0] / ".zattrs" + if not zattrs_file.exists(): + raise ValueError("No .zattrs metadata found in corrections") + + with open(zattrs_file) as f: + metadata = json.load(f) + + if "dataset_path" not in metadata: + raise ValueError("No 'dataset_path' found in corrections metadata") + + return metadata["dataset_path"] + + def _build_base_model_dict(self, finetune_job: FinetuneJob, metadata: dict) -> dict: + """Build base_model dict for FinetuneModelConfig from job metadata. 
+ + Reconstructs the dict that would come from model_config.to_dict(), + based on what was stored in metadata.json at job submission time. + """ + model_type = metadata.get("model_type", "fly") + + if model_type == "huggingface": + result = {"type": "huggingface", "repo": metadata["repo"]} + if metadata.get("revision"): + result["revision"] = metadata["revision"] + return result + + if model_type == "script" or metadata.get("model_script"): + return { + "type": "script", + "script_path": metadata["model_script"], + } + + # Default: fly model with checkpoint + result = { + "type": "fly", + "channels": finetune_job.params.get("channels", ["mito"]), + "input_voxel_size": finetune_job.params.get("input_voxel_size", [16, 16, 16]), + "output_voxel_size": finetune_job.params.get("output_voxel_size", [16, 16, 16]), + } + checkpoint = metadata.get("model_checkpoint") or finetune_job.params.get("model_checkpoint") + if checkpoint: + result["checkpoint_path"] = checkpoint + return result + + def _resolve_model_type(self, model_config) -> str: + """Infer the finetuning CLI model type from the model config.""" + model_type = getattr(type(model_config), "cli_name", "fly") + if model_type == "fly" and "dacapo" in model_config.name.lower(): + return "dacapo" + return model_type + + def _normalize_metadata_list(self, value, default): + """Return model metadata as a plain list for CLI serialization.""" + if value is None: + return list(default) + if isinstance(value, str): + return [value] + if isinstance(value, list): + return value + return list(value) + + def _build_finetune_command( + self, + *, + model_config, + model_type: str, + checkpoint_path: Optional[Path], + corrections_path: Path, + output_dir: Path, + log_file: Path, + channels: List[str], + input_voxel_size: List[int], + output_voxel_size: List[int], + lora_r: int, + num_epochs: int, + batch_size: int, + learning_rate: float, + loss_type: str, + label_smoothing: float, + distillation_lambda: float, + 
distillation_scope: str, + margin: float, + auto_serve: bool, + serve_data_path: Optional[str], + mask_unannotated: bool, + balance_classes: bool, + output_type: str, + select_channel: Optional[int], + offsets: Optional[str], + ) -> str: + """Build the shell command used to launch finetuning.""" + command_parts = [ + sys.executable, + "-m", + "cellmap_flow.finetune.finetune_cli", + "--model-type", model_type, + ] + + if model_type == "huggingface": + command_parts += ["--repo", str(model_config.repo)] + if getattr(model_config, "revision", None): + command_parts += ["--revision", str(model_config.revision)] + elif checkpoint_path: + command_parts += ["--model-checkpoint", str(checkpoint_path)] + elif hasattr(model_config, "script_path"): + command_parts += ["--model-script", str(model_config.script_path)] + + command_parts += [ + "--corrections", str(corrections_path), + "--output-dir", str(output_dir), + "--model-name", str(model_config.name), + "--channels", *map(str, channels), + "--input-voxel-size", *map(str, input_voxel_size), + "--output-voxel-size", *map(str, output_voxel_size), + "--lora-r", str(lora_r), + "--lora-alpha", str(lora_r * 2), + "--num-epochs", str(num_epochs), + "--batch-size", str(batch_size), + "--learning-rate", str(learning_rate), + "--loss-type", str(loss_type), + ] + + if label_smoothing > 0: + command_parts += ["--label-smoothing", str(label_smoothing)] + if distillation_lambda > 0: + command_parts += ["--distillation-lambda", str(distillation_lambda)] + if distillation_scope == "all": + command_parts.append("--distillation-all-voxels") + if loss_type == "margin": + command_parts += ["--margin", str(margin)] + if auto_serve and serve_data_path: + command_parts += ["--auto-serve", "--serve-data-path", str(serve_data_path)] + if mask_unannotated: + command_parts.append("--mask-unannotated") + if balance_classes: + command_parts.append("--balance-classes") + if output_type != "binary": + command_parts += ["--output-type", str(output_type)] 
+ if select_channel is not None: + command_parts += ["--select-channel", str(select_channel)] + if offsets is not None: + command_parts += ["--offsets", str(offsets)] + + command = " ".join(shlex.quote(part) for part in command_parts) + return f"stdbuf -oL {command} 2>&1 | tee {shlex.quote(str(log_file))}" + + def _build_submission_metadata( + self, + *, + model_config, + model_type: str, + checkpoint_path: Optional[Path], + corrections_path: Path, + num_corrections: int, + output_dir: Path, + lora_r: int, + num_epochs: int, + batch_size: int, + learning_rate: float, + loss_type: str, + label_smoothing: float, + distillation_lambda: float, + distillation_scope: str, + margin: float, + balance_classes: bool, + channels: List[str], + input_voxel_size: List[int], + output_voxel_size: List[int], + output_type: str, + queue: str, + charge_group: str, + command: str, + ) -> dict: + """Build metadata persisted for a submitted finetuning job.""" + return { + "job_id": str(uuid.uuid4()), + "model_name": model_config.name, + "model_type": model_type, + "model_checkpoint": str(checkpoint_path) if checkpoint_path else None, + "model_script": str(model_config.script_path) if hasattr(model_config, "script_path") else None, + "repo": model_config.repo if model_type == "huggingface" else None, + "revision": getattr(model_config, "revision", None) if model_type == "huggingface" else None, + "corrections_path": str(corrections_path), + "num_corrections": num_corrections, + "output_dir": str(output_dir), + "params": { + "model_checkpoint": str(checkpoint_path) if checkpoint_path else None, + "lora_r": lora_r, + "lora_alpha": lora_r * 2, + "num_epochs": num_epochs, + "batch_size": batch_size, + "learning_rate": learning_rate, + "loss_type": loss_type, + "label_smoothing": label_smoothing, + "distillation_lambda": distillation_lambda, + "distillation_scope": distillation_scope, + "margin": margin, + "balance_classes": balance_classes, + "channels": channels, + "input_voxel_size": 
def submit_finetuning_job(
    self,
    model_config,
    corrections_path: Path,
    lora_r: int = 8,
    num_epochs: int = 10,
    batch_size: int = 2,
    learning_rate: float = 1e-4,
    output_base: Optional[Path] = None,
    queue: str = "gpu_h100",
    charge_group: str = "cellmap",
    checkpoint_path_override: Optional[Path] = None,
    auto_serve: bool = True,
    mask_unannotated: bool = False,
    loss_type: str = "combined",
    label_smoothing: float = 0.0,
    distillation_lambda: float = 0.0,
    distillation_scope: str = "unlabeled",
    margin: float = 0.3,
    balance_classes: bool = False,
    output_type: str = "binary",
    select_channel: Optional[int] = None,
    offsets: Optional[str] = None,
) -> FinetuneJob:
    """
    Validate inputs, launch a finetuning run (LSF or local) and track it.

    Args:
        model_config: Model configuration object (FlyModelConfig, etc.)
        corrections_path: Path (or path string) to corrections.zarr directory
        lora_r: LoRA rank (default: 8); lora_alpha is always 2 * lora_r
        num_epochs: Number of training epochs (default: 10)
        batch_size: Training batch size (default: 2)
        learning_rate: Learning rate (default: 1e-4)
        output_base: Base directory for outputs (default: output/finetuning).
            Runs are created under ``<output_base>/finetuning/runs/``.
        queue: LSF queue name (default: gpu_h100)
        charge_group: LSF charge group (default: cellmap)
        checkpoint_path_override: Optional path to override checkpoint detection (default: None)
        auto_serve: Automatically start inference server after training (default: True).
            Disabled for this run when no dataset path can be extracted
            from the corrections.
        mask_unannotated: Forwarded to the training command builder
        loss_type: Forwarded to the training command builder and stored in the job params
        label_smoothing: Forwarded to the training command builder and stored in the job params
        distillation_lambda: Forwarded to the training command builder and stored in the job params
        distillation_scope: Forwarded to the training command builder and stored in the job params
        margin: Forwarded to the training command builder and stored in the job params
        balance_classes: Forwarded to the training command builder and stored in the job params
        output_type: Forwarded to the training command builder and stored in the job params
        select_channel: Forwarded to the training command builder
        offsets: Forwarded to the training command builder

    Returns:
        FinetuneJob object tracking the submitted job

    Raises:
        ValueError: If validation fails
        RuntimeError: If job submission fails (the original error is chained)
    """
    # === Validation ===

    # 1. Check model config
    if not model_config:
        raise ValueError("Model config is required")

    # 2. Get checkpoint path if available (optional)
    # For script models: we'll pass the script path instead
    # For fly/dacapo models: we need the checkpoint path
    checkpoint_path = None

    # Check for checkpoint override first
    if checkpoint_path_override:
        checkpoint_path = Path(checkpoint_path_override)
        self.logger.info(f"Using checkpoint path override: {checkpoint_path}")
    # For FlyModelConfig, get checkpoint_path attribute
    elif hasattr(model_config, 'checkpoint_path') and model_config.checkpoint_path:
        checkpoint_path = Path(model_config.checkpoint_path)
        self.logger.info(f"Found checkpoint_path: {checkpoint_path}")

    # Validate checkpoint exists if specified
    if checkpoint_path and not checkpoint_path.exists():
        raise ValueError(
            f"Model checkpoint not found: {checkpoint_path}\n"
            f"Please verify the path exists and is accessible."
        )

    # 3. Check corrections path exists.
    # Coerce first so a plain string fails validation with ValueError rather
    # than an AttributeError on ``.exists()``.
    corrections_path = Path(corrections_path)
    if not corrections_path.exists():
        raise ValueError(f"Corrections path does not exist: {corrections_path}")

    # 4. Count corrections (warn if few): a correction is a sub-directory
    # that contains a ".zattrs" file.
    num_corrections = sum(
        1 for d in corrections_path.glob("*/") if (d / ".zattrs").exists()
    )

    if num_corrections == 0:
        raise ValueError(f"No corrections found in {corrections_path}")

    if num_corrections < 5:
        self.logger.warning(
            f"Only {num_corrections} corrections found. "
            "Recommend at least 5-10 for meaningful finetuning."
        )

    self.logger.info(f"Found {num_corrections} corrections for training")

    # === Setup output directory ===

    if output_base is None:
        output_base = Path("output/finetuning")
    else:
        output_base = Path(output_base)

    # Create timestamped run directory inside finetuning subdirectory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_basename = model_config.name.replace("/", "_").replace(" ", "_")
    run_dir_name = f"{model_basename}_{timestamp}"
    output_dir = output_base / "finetuning" / "runs" / run_dir_name
    output_dir.mkdir(parents=True, exist_ok=True)

    log_file = output_dir / "training_log.txt"

    self.logger.info(f"Output directory: {output_dir}")

    # === Build training command ===

    # Get model type from the config class's cli_name (e.g., "fly", "dacapo", "huggingface")
    model_type = self._resolve_model_type(model_config)

    # Get channels - try multiple attribute names
    channels = None
    for attr_name in ["channels", "classes", "class_names"]:
        channels = self._get_model_metadata(model_config, attr_name, None)
        if channels:
            break
    if channels is None:
        channels = ["mito"]  # Default fallback
    channels = self._normalize_metadata_list(channels, ["mito"])

    # Get voxel sizes
    input_voxel_size = self._get_model_metadata(model_config, "input_voxel_size", [16, 16, 16])
    output_voxel_size = self._get_model_metadata(model_config, "output_voxel_size", [16, 16, 16])

    input_voxel_size = self._normalize_metadata_list(input_voxel_size, [16, 16, 16])
    output_voxel_size = self._normalize_metadata_list(output_voxel_size, [16, 16, 16])

    # Extract data path for inference server if auto-serve is enabled
    serve_data_path = None
    if auto_serve:
        try:
            serve_data_path = self._extract_data_path_from_corrections(corrections_path)
            self.logger.info(f"Extracted dataset path for inference: {serve_data_path}")
        except Exception as e:
            # Best effort: training still runs, only auto-serve is dropped.
            self.logger.warning(f"Could not extract dataset path from corrections: {e}")
            self.logger.warning("Auto-serve will be disabled")
            auto_serve = False

    cli_command = self._build_finetune_command(
        model_config=model_config,
        model_type=model_type,
        checkpoint_path=checkpoint_path,
        corrections_path=corrections_path,
        output_dir=output_dir,
        log_file=log_file,
        channels=channels,
        input_voxel_size=input_voxel_size,
        output_voxel_size=output_voxel_size,
        lora_r=lora_r,
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        loss_type=loss_type,
        label_smoothing=label_smoothing,
        distillation_lambda=distillation_lambda,
        distillation_scope=distillation_scope,
        margin=margin,
        auto_serve=auto_serve,
        serve_data_path=serve_data_path,
        mask_unannotated=mask_unannotated,
        balance_classes=balance_classes,
        output_type=output_type,
        select_channel=select_channel,
        offsets=offsets,
    )

    self.logger.info(f"Training command: {cli_command}")

    # === Save job metadata ===

    metadata = self._build_submission_metadata(
        model_config=model_config,
        model_type=model_type,
        checkpoint_path=checkpoint_path,
        corrections_path=corrections_path,
        num_corrections=num_corrections,
        output_dir=output_dir,
        lora_r=lora_r,
        num_epochs=num_epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        loss_type=loss_type,
        label_smoothing=label_smoothing,
        distillation_lambda=distillation_lambda,
        distillation_scope=distillation_scope,
        margin=margin,
        balance_classes=balance_classes,
        channels=channels,
        input_voxel_size=input_voxel_size,
        output_voxel_size=output_voxel_size,
        output_type=output_type,
        queue=queue,
        charge_group=charge_group,
        command=cli_command,
    )

    metadata_file = output_dir / "metadata.json"
    with open(metadata_file, "w") as f:
        json.dump(metadata, f, indent=2)

    self.logger.info(f"Saved metadata to {metadata_file}")

    # === Submit job (LSF or local) ===

    job_name = f"finetune_{model_basename}_{timestamp}"

    # Check if bsub is available
    if is_bsub_available():
        self.logger.info("Submitting to LSF cluster via bsub")
        try:
            lsf_job = submit_bsub_job(
                command=cli_command,
                queue=queue,
                charge_group=charge_group,
                job_name=job_name,
                num_gpus=1,
                num_cpus=4
            )
            self.logger.info(f"Submitted LSF job {lsf_job.job_id} for finetuning")
        except Exception as e:
            self.logger.error(f"Failed to submit job to LSF: {e}")
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Job submission to LSF failed: {e}") from e
    else:
        # Fallback to local execution
        self.logger.info("bsub not available - running finetuning locally")
        try:
            lsf_job = run_locally(
                command=cli_command,
                name=job_name
            )
            self.logger.info(f"Started local finetuning job (PID: {lsf_job.process.pid})")
        except Exception as e:
            self.logger.error(f"Failed to start local job: {e}")
            raise RuntimeError(f"Local job execution failed: {e}") from e

    # === Create FinetuneJob tracking object ===

    job_id = metadata["job_id"]

    finetune_job = FinetuneJob(
        job_id=job_id,
        lsf_job=lsf_job,
        model_name=model_config.name,
        output_dir=output_dir,
        params=metadata["params"],
        status=JobStatus.PENDING,
        created_at=datetime.now(),
        log_file=log_file,
        total_epochs=num_epochs
    )

    self.jobs[job_id] = finetune_job

    # === Start monitoring thread ===

    # Daemon thread: monitoring must not keep the process alive on shutdown.
    monitor_thread = threading.Thread(
        target=self.monitor_job,
        args=(finetune_job,),
        daemon=True
    )
    monitor_thread.start()
    self._monitor_threads[job_id] = monitor_thread

    self.logger.info(f"Started monitoring thread for job {job_id}")

    return finetune_job
def monitor_job(self, finetune_job: FinetuneJob):
    """
    Background thread for job monitoring.

    Polls LSF status and tails log file to track training progress.
    Triggers completion when job finishes.

    Each cycle first maps the scheduler status onto ``finetune_job.status``
    (leaving the loop on COMPLETED / FAILED / KILLED), then reads whatever
    was appended to the log since the previous cycle and feeds it to the
    progress, server-ready and restart parsers, then sleeps.

    Args:
        finetune_job: The FinetuneJob to monitor
    """
    job_id = finetune_job.job_id
    self.logger.info(f"Monitoring job {job_id}...")

    # Offset into the log file up to which content has already been parsed.
    last_log_position = 0
    check_interval = 3  # seconds

    try:
        while True:
            # === Check LSF job status ===

            if finetune_job.lsf_job:
                lsf_status = finetune_job.lsf_job.get_status()

                # Map LSF status to FinetuneJob status
                if lsf_status == LSFJobStatus.RUNNING:
                    if finetune_job.status == JobStatus.PENDING:
                        self.logger.info(f"Job {job_id} started running")
                    # NOTE(review): nesting reconstructed from a whitespace-collapsed
                    # diff; confirm this assignment sits outside the PENDING check.
                    finetune_job.status = JobStatus.RUNNING
                elif lsf_status == LSFJobStatus.PENDING:
                    finetune_job.status = JobStatus.PENDING
                elif lsf_status == LSFJobStatus.COMPLETED:
                    self.logger.info(f"Job {job_id} completed according to LSF")
                    finetune_job.status = JobStatus.COMPLETED
                    break
                elif lsf_status == LSFJobStatus.FAILED:
                    self.logger.error(f"Job {job_id} failed according to LSF")
                    finetune_job.status = JobStatus.FAILED
                    break
                elif lsf_status == LSFJobStatus.KILLED:
                    self.logger.warning(f"Job {job_id} was killed")
                    finetune_job.status = JobStatus.CANCELLED
                    break

            # === Tail log file for progress updates ===

            if finetune_job.log_file.exists():
                try:
                    # Check if file was truncated (e.g., during restart archival)
                    file_size = finetune_job.log_file.stat().st_size
                    if file_size < last_log_position:
                        self.logger.info(f"Log file truncated (size {file_size} < position {last_log_position}), resetting")
                        last_log_position = 0

                    with open(finetune_job.log_file, "r") as f:
                        # Seek to last read position
                        f.seek(last_log_position)
                        new_content = f.read()
                        last_log_position = f.tell()

                    if new_content:
                        # Parse for epoch and loss information
                        self._parse_training_progress(finetune_job, new_content)
                        # Parse for inference server ready marker
                        self._parse_inference_server_ready(finetune_job, new_content)

                    # Always check for restart/iteration markers (reads full log).
                    # This must run every cycle, not just when there's new content,
                    # because the marker may have been at the end of the previous
                    # chunk and we need to detect it even if no new output follows.
                    self._parse_training_restart(finetune_job, new_content if new_content else "")
                except Exception as e:
                    # Best effort: a failed read is retried on the next cycle.
                    self.logger.debug(f"Error reading log file: {e}")

            # Sleep before next check
            time.sleep(check_interval)

    except Exception as e:
        # Any error escaping the loop marks the job as failed.
        self.logger.error(f"Error monitoring job {job_id}: {e}")
        finetune_job.status = JobStatus.FAILED

    finally:
        # === Post-completion actions ===

        if finetune_job.status == JobStatus.COMPLETED:
            try:
                self.complete_job(finetune_job)
            except Exception as e:
                # A job whose post-processing fails is reported as FAILED.
                self.logger.error(f"Error in post-completion for job {job_id}: {e}")
                finetune_job.status = JobStatus.FAILED

        self.logger.info(f"Stopped monitoring job {job_id}. Final status: {finetune_job.status.value}")
+ # This must run every cycle, not just when there's new content, + # because the marker may have been at the end of the previous + # chunk and we need to detect it even if no new output follows. + self._parse_training_restart(finetune_job, new_content if new_content else "") + except Exception as e: + self.logger.debug(f"Error reading log file: {e}") + + # Sleep before next check + time.sleep(check_interval) + + except Exception as e: + self.logger.error(f"Error monitoring job {job_id}: {e}") + finetune_job.status = JobStatus.FAILED + + finally: + # === Post-completion actions === + + if finetune_job.status == JobStatus.COMPLETED: + try: + self.complete_job(finetune_job) + except Exception as e: + self.logger.error(f"Error in post-completion for job {job_id}: {e}") + finetune_job.status = JobStatus.FAILED + + self.logger.info(f"Stopped monitoring job {job_id}. Final status: {finetune_job.status.value}") + + def _parse_training_progress(self, finetune_job: FinetuneJob, log_content: str): + """ + Parse log content for training progress (epoch, loss). + + Args: + finetune_job: Job to update + log_content: New log content to parse + """ + # Pair epoch and loss from the per-epoch summary line + # ("Epoch X/Y - Loss: Z") so the loss is guaranteed to belong to the + # epoch reported alongside it. A previous version scanned epoch and + # loss with independent regexes; per-batch lines ("Batch X/N - Loss: Z") + # then bumped latest_loss inside epoch N+1 while current_epoch was + # still pinned to epoch N's summary, so the dashboard plot pinned + # epoch N's running batch loss onto epoch N-1. 
def _add_finetuned_neuroglancer_layer(self, finetune_job: FinetuneJob, model_name: str):
    """
    Add (or replace) the finetuned model's neuroglancer layer.

    Mirrors run_model() from cellmap_flow/models/run.py:
    1. Create/update Job object in g.jobs
    2. Add neuroglancer ImageLayer with pre/post processing args

    Note that the job entry in ``g.jobs`` is replaced before the viewer is
    checked, so when ``g.viewer`` is None the job is still registered but no
    layer is added and ``finetune_job.finetuned_model_name`` is left as is.

    Args:
        finetune_job: Job with inference_server_url set
        model_name: Layer name (e.g. "mito_finetuned_20240101_120000")
    """
    from cellmap_flow.globals import g
    from cellmap_flow.utils.web_utils import get_norms_post_args, ARGS_KEY
    import neuroglancer

    server_url = finetune_job.inference_server_url

    # Create a Job object for the running server
    inference_job = LSFJob(
        job_id=finetune_job.lsf_job.job_id if finetune_job.lsf_job else "local",
        model_name=model_name
    )
    inference_job.host = server_url
    inference_job.status = LSFJobStatus.RUNNING

    # Remove any old finetuned jobs for this base model
    g.jobs = [
        j for j in g.jobs
        if not (hasattr(j, 'model_name') and j.model_name
                and j.model_name.startswith(f"{finetune_job.model_name}_finetuned"))
    ]

    # Add to g.jobs
    g.jobs.append(inference_job)
    self.logger.info(f"Added finetuned job to g.jobs: {model_name}")

    # Get pre/post processing args (same hash as other models)
    st_data = get_norms_post_args(g.input_norms, g.postprocess)

    if g.viewer is None:
        self.logger.error("g.viewer is None - neuroglancer not initialized yet")
        return

    # Lie about the model's voxel size so the layer overlays the raw at
    # the closest available scale (e.g. trained at 16nm but raw is
    # multiscale 6/12/24 -> tell neuroglancer it's 12nm).
    from cellmap_flow.utils.neuroglancer_utils import (
        build_prediction_source,
        get_raw_closest_scale,
    )
    override_scales = None
    try:
        output_voxel_size = tuple(
            finetune_job.params.get("output_voxel_size") or ()
        )
        dataset_path = getattr(g, "dataset_path", None)
        if output_voxel_size and dataset_path:
            closest = get_raw_closest_scale(dataset_path, output_voxel_size)
            if closest is not None and tuple(closest) != tuple(output_voxel_size):
                override_scales = closest
                self.logger.info(
                    f"Finetuned model '{model_name}' output_voxel_size="
                    f"{output_voxel_size} overridden to closest raw scale "
                    f"{closest} for viewer overlay"
                )
    except Exception as e:
        # Best effort: without an override the layer uses the model's own scale.
        self.logger.warning(
            f"Could not compute override scales for finetuned '{model_name}': {e}"
        )

    source_spec = build_prediction_source(
        server_url, model_name, st_data, override_scales
    )
    self.logger.info(f"Adding neuroglancer layer: {model_name}")
    self.logger.info(f"  source: {source_spec}")

    with g.viewer.txn() as s:
        # Remove old finetuned layer if it exists (exact name match)
        old_layer_name = finetune_job.finetuned_model_name
        if old_layer_name and old_layer_name in s.layers:
            self.logger.info(f"Removing old finetuned layer: {old_layer_name}")
            del s.layers[old_layer_name]

        # Also remove by current name in case of re-add
        if model_name in s.layers:
            del s.layers[model_name]

        # Add new layer - exact same format as run_model()
        # NOTE(review): whitespace inside the shader literal was reconstructed
        # from a whitespace-collapsed diff; confirm against the real file.
        s.layers[model_name] = neuroglancer.ImageLayer(
            source=source_spec,
            shader=f"""#uicontrol invlerp normalized(range=[0, 255], window=[0, 255]);
                #uicontrol vec3 color color(default="red");
                void main(){{emitRGB(color * normalized());}}""",
        )

    # Update the stored name
    finetune_job.finetuned_model_name = model_name
    self.logger.info(f"Successfully added neuroglancer layer: {model_name}")
+ from cellmap_flow.utils.neuroglancer_utils import ( + build_prediction_source, + get_raw_closest_scale, + ) + override_scales = None + try: + output_voxel_size = tuple( + finetune_job.params.get("output_voxel_size") or () + ) + dataset_path = getattr(g, "dataset_path", None) + if output_voxel_size and dataset_path: + closest = get_raw_closest_scale(dataset_path, output_voxel_size) + if closest is not None and tuple(closest) != tuple(output_voxel_size): + override_scales = closest + self.logger.info( + f"Finetuned model '{model_name}' output_voxel_size=" + f"{output_voxel_size} overridden to closest raw scale " + f"{closest} for viewer overlay" + ) + except Exception as e: + self.logger.warning( + f"Could not compute override scales for finetuned '{model_name}': {e}" + ) + + source_spec = build_prediction_source( + server_url, model_name, st_data, override_scales + ) + self.logger.info(f"Adding neuroglancer layer: {model_name}") + self.logger.info(f" source: {source_spec}") + + with g.viewer.txn() as s: + # Remove old finetuned layer if it exists (exact name match) + old_layer_name = finetune_job.finetuned_model_name + if old_layer_name and old_layer_name in s.layers: + self.logger.info(f"Removing old finetuned layer: {old_layer_name}") + del s.layers[old_layer_name] + + # Also remove by current name in case of re-add + if model_name in s.layers: + del s.layers[model_name] + + # Add new layer - exact same format as run_model() + s.layers[model_name] = neuroglancer.ImageLayer( + source=source_spec, + shader=f"""#uicontrol invlerp normalized(range=[0, 255], window=[0, 255]); + #uicontrol vec3 color color(default="red"); + void main(){{emitRGB(color * normalized());}}""", + ) + + # Update the stored name + finetune_job.finetuned_model_name = model_name + self.logger.info(f"Successfully added neuroglancer layer: {model_name}") + + def _parse_inference_server_ready(self, finetune_job: FinetuneJob, log_content: str): + """ + Parse log for CELLMAP_FLOW_SERVER_IP marker 
def _register_finetune_model_config(
    self, finetune_job: FinetuneJob, finetuned_model_name: str
):
    """Register a FinetuneModelConfig in g.models_config so it appears
    in the pipeline builder with auto-populated parameters.

    The base model description is taken from the entry in ``g.models_config``
    whose name equals ``finetune_job.model_name``; when there is none it is
    rebuilt from the job params. Earlier finetuned versions of the same base
    model (names starting with ``"<base>_finetuned"``) are dropped before the
    new config is appended.

    Args:
        finetune_job: Job providing output_dir, params and the base model name
        finetuned_model_name: Name under which the new config is registered
    """
    from cellmap_flow.globals import g
    from cellmap_flow.models.models_config import FinetuneModelConfig

    adapter_path = str(finetune_job.output_dir / "lora_adapter")
    params = finetune_job.params
    base_model_name = finetune_job.model_name

    # Find the base model's to_dict() from g.models_config
    base_model_dict = None
    for mc in getattr(g, "models_config", None) or []:
        if getattr(mc, "name", None) == base_model_name:
            base_model_dict = mc.to_dict()
            break

    if base_model_dict is None:
        # Fallback: reconstruct from job params
        # NOTE(review): nesting reconstructed from a whitespace-collapsed diff;
        # the whole reconstruction is assumed to belong to this fallback.
        base_model_dict = {"type": "fly"}
        if params.get("model_checkpoint"):
            base_model_dict["checkpoint_path"] = params["model_checkpoint"]
        for key in ("channels", "input_voxel_size", "output_voxel_size"):
            if key in params:
                base_model_dict[key] = params[key]

    ft_config = FinetuneModelConfig(
        lora_adapter_path=adapter_path,
        base_model=base_model_dict,
        name=finetuned_model_name,
        scale=params.get("scale"),
    )

    if not hasattr(g, "models_config"):
        g.models_config = []

    # Remove any previous finetuned versions of the same base model.
    # Names are read defensively (missing or None counts as "no match"), in
    # line with the lookup above, so one unnamed config cannot abort this.
    stale_prefix = f"{base_model_name}_finetuned"
    g.models_config = [
        mc
        for mc in g.models_config
        if not (getattr(mc, "name", None) or "").startswith(stale_prefix)
    ]

    g.models_config.append(ft_config)
    self.logger.info(
        f"Registered FinetuneModelConfig: {finetuned_model_name}"
    )
finetune_job.params + + # Find the base model's to_dict() from g.models_config + base_model_dict = None + if hasattr(g, "models_config") and g.models_config: + for mc in g.models_config: + if getattr(mc, "name", None) == finetune_job.model_name: + base_model_dict = mc.to_dict() + break + + if base_model_dict is None: + # Fallback: reconstruct from job params + base_model_dict = {"type": "fly"} + if params.get("model_checkpoint"): + base_model_dict["checkpoint_path"] = params["model_checkpoint"] + for key in ("channels", "input_voxel_size", "output_voxel_size"): + if key in params: + base_model_dict[key] = params[key] + + ft_config = FinetuneModelConfig( + lora_adapter_path=adapter_path, + base_model=base_model_dict, + name=finetuned_model_name, + scale=params.get("scale"), + ) + + if not hasattr(g, "models_config"): + g.models_config = [] + + # Remove any previous finetuned versions of the same base model + base_model_name = finetune_job.model_name + g.models_config = [ + mc + for mc in g.models_config + if not ( + hasattr(mc, "name") + and mc.name.startswith(f"{base_model_name}_finetuned") + ) + ] + + g.models_config.append(ft_config) + self.logger.info( + f"Registered FinetuneModelConfig: {finetuned_model_name}" + ) + + def _parse_training_restart(self, finetune_job: FinetuneJob, log_content: str): + """ + Parse log for RESTARTING_TRAINING and TRAINING_ITERATION_COMPLETE markers + to handle iterative training restarts. + + On RESTARTING_TRAINING: reset training progress counters. + On TRAINING_ITERATION_COMPLETE: update the neuroglancer layer name with new timestamp. 
+ + Args: + finetune_job: Job to update + log_content: New log content to parse + """ + # Check for diverged marker - training produced NaN/Inf loss + if "TRAINING_DIVERGED" in log_content: + self.logger.warning(f"Training diverged for job {finetune_job.job_id}") + finetune_job.status = JobStatus.RUNNING # still alive, waiting for restart + finetune_job.latest_loss = None + + # Check for restart marker - reset progress + if "RESTARTING_TRAINING" in log_content: + self.logger.info(f"Training restart detected for job {finetune_job.job_id}") + finetune_job.current_epoch = 0 + finetune_job.latest_loss = None + finetune_job.status = JobStatus.RUNNING + finetune_job.inference_server_ready = False + + # Check for iteration complete marker - update neuroglancer layer. + # Read full log in case the marker was in a previous chunk. + iter_pattern = r"TRAINING_ITERATION_COMPLETE:\s+(\S+)" + try: + full_log = finetune_job.log_file.read_text() + except Exception: + full_log = log_content + iter_matches = re.findall(iter_pattern, full_log) + # Only process new iteration-complete markers (ignore ones already handled). + # After a restart, _processed_iteration_count stays at the old count so + # previously-seen markers don't re-trigger inference_server_ready or + # neuroglancer layer updates. + if len(iter_matches) > finetune_job._processed_iteration_count: + finetune_job._processed_iteration_count = len(iter_matches) + + # For in-process restarts, the inference server usually stays on the same + # URL and does not emit a fresh CELLMAP_FLOW_SERVER_IP marker. Mark the + # server as ready once we see a completed training iteration if URL exists. 
def complete_job(self, finetune_job: FinetuneJob):
    """
    Post-training actions after job completes successfully.

    1. Verify adapter files exist
    2. Generate the finetuned model YAML (skipped when it already exists)
    3. Update job attributes and the run's metadata.json

    Args:
        finetune_job: The completed job

    Raises:
        RuntimeError: If adapter files are missing, the models directory
            cannot be created, or the model YAML cannot be generated. The
            underlying exception is chained as ``__cause__`` where one exists.
    """
    job_id = finetune_job.job_id
    self.logger.info(f"Running post-completion for job {job_id}...")

    # === Verify adapter files exist ===

    adapter_path = finetune_job.output_dir / "lora_adapter"

    # Check for adapter model (supports both .bin and .safetensors formats)
    adapter_model_bin = adapter_path / "adapter_model.bin"
    adapter_model_safetensors = adapter_path / "adapter_model.safetensors"

    if not (adapter_model_bin.exists() or adapter_model_safetensors.exists()):
        raise RuntimeError(
            f"Training completed but adapter model not found. "
            f"Checked: {adapter_model_bin} and {adapter_model_safetensors}"
        )

    adapter_config_file = adapter_path / "adapter_config.json"
    if not adapter_config_file.exists():
        raise RuntimeError(
            f"Training completed but adapter config not found: {adapter_config_file}"
        )

    self.logger.info(f"Verified LoRA adapter files exist in {adapter_path}")

    # === Generate finetuned model name ===

    timestamp = finetune_job.created_at.strftime("%Y%m%d_%H%M%S")
    model_basename = finetune_job.model_name.replace("/", "_").replace(" ", "_")
    finetuned_model_name = f"{model_basename}_finetuned_{timestamp}"

    finetune_job.finetuned_model_name = finetuned_model_name

    self.logger.info(f"Generated finetuned model name: {finetuned_model_name}")

    # === Generate model YAML ===

    from cellmap_flow.finetune.finetuned_model_templates import (
        generate_finetuned_model_yaml
    )

    # Models output directory (at session level, not in finetuning subdirectory)
    # output_dir structure: session_path/finetuning/runs/model_timestamp/
    # So parent.parent.parent gets us to session_path
    models_dir = finetune_job.output_dir.parent.parent.parent / "models"

    try:
        models_dir.mkdir(parents=True, exist_ok=True)
        self.logger.info(f"Models directory ready: {models_dir}")
    except Exception as e:
        self.logger.error(f"Failed to create models directory {models_dir}: {e}")
        raise RuntimeError(f"Failed to create models directory: {e}") from e

    # Check if YAML already exists (generated by CLI with auto-serve)
    expected_yaml = models_dir / f"{finetuned_model_name}.yaml"

    if expected_yaml.exists():
        self.logger.info(f"Model YAML already generated by CLI, skipping generation")
        finetune_job.model_yaml_path = expected_yaml
        yaml_path = expected_yaml
    else:
        self.logger.info(f"Generating model config...")

        # Read metadata for base model info
        metadata_file = finetune_job.output_dir / "metadata.json"
        metadata = {}
        if metadata_file.exists():
            try:
                with open(metadata_file, "r") as f:
                    metadata = json.load(f)
            except Exception as e:
                # Best effort: generation continues with empty metadata.
                self.logger.warning(f"Could not read metadata: {e}")

        try:
            # Build base_model_dict from metadata
            base_model_dict = self._build_base_model_dict(finetune_job, metadata)
            self.logger.info(f"Base model dict: {base_model_dict}")

            # === Extract configuration from base model and corrections ===

            data_path = None
            json_data = None
            base_scale = "s0"  # Default scale (only safe default)

            # 1. Get dataset_path from corrections metadata (REQUIRED)
            corrections_dir = Path(metadata.get("corrections_path", ""))
            try:
                data_path = self._extract_data_path_from_corrections(corrections_dir)
                self.logger.info(f"Found dataset_path from corrections: {data_path}")
            except Exception as e:
                # Logged only; the missing data_path is rejected below.
                self.logger.error(f"Could not extract dataset_path: {e}")

            # 2. Get normalization and preprocessing from the running server's config
            from cellmap_flow.globals import g as g_globals
            from cellmap_flow.utils.serilization_utils import serialize_norms_posts_to_json
            if hasattr(g_globals, 'input_norms') and g_globals.input_norms:
                json_data = json.loads(serialize_norms_posts_to_json(
                    g_globals.input_norms, g_globals.postprocess
                ))
                self.logger.info(f"Found json_data from running server config")

            # 3. Validate we have required data (NO PLACEHOLDERS!)
            if not data_path:
                raise RuntimeError(
                    "Could not determine dataset_path for finetuned model. "
                    "Checked corrections metadata and base model YAML. "
                    "Cannot generate model config without actual dataset path."
                )

            if not json_data:
                self.logger.warning(
                    "No json_data (normalization/postprocessing) found. "
                    "Finetuned model may not work correctly without proper normalization. "
                    "Consider adding json_data to base model YAML."
                )

            # Generate .yaml config
            yaml_path = generate_finetuned_model_yaml(
                lora_adapter_path=str(adapter_path),
                base_model_dict=base_model_dict,
                model_name=finetuned_model_name,
                output_path=expected_yaml,
                data_path=data_path,
                queue=metadata.get("queue", "gpu_h100"),
                charge_group=metadata.get("charge_group", "cellmap"),
                json_data=json_data,
                scale=base_scale,
            )

            finetune_job.model_yaml_path = yaml_path
            self.logger.info(f"Generated model YAML: {yaml_path}")

        except Exception as e:
            import traceback
            self.logger.error(f"Error generating model files: {e}")
            self.logger.error(f"Traceback:\n{traceback.format_exc()}")
            # Chain the cause so callers can still inspect the real failure.
            raise RuntimeError(f"Failed to generate model files: {e}") from e

    # === Update metadata file with completion info ===

    metadata_file = finetune_job.output_dir / "metadata.json"
    if metadata_file.exists():
        with open(metadata_file, "r") as f:
            metadata = json.load(f)

        metadata["completed_at"] = datetime.now().isoformat()
        metadata["status"] = "COMPLETED"
        metadata["finetuned_model_name"] = finetuned_model_name
        metadata["model_yaml_path"] = str(yaml_path)
        metadata["final_epoch"] = finetune_job.current_epoch
        metadata["final_loss"] = finetune_job.latest_loss

        with open(metadata_file, "w") as f:
            json.dump(metadata, f, indent=2)

        self.logger.info(f"Updated metadata file: {metadata_file}")

    self.logger.info(f"Job {job_id} completed successfully!")
def cancel_job(self, job_id: str) -> bool:
    """
    Kill a job that has not finished yet.

    Args:
        job_id: Job ID to cancel

    Returns:
        True when the underlying job was killed and the job was marked
        CANCELLED. False when the job is unknown, already finished, has no
        underlying job, or the kill raised.
    """
    if job_id not in self.jobs:
        self.logger.error(f"Job {job_id} not found")
        return False

    finetune_job = self.jobs[job_id]

    finished_states = [JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED]
    if finetune_job.status in finished_states:
        self.logger.warning(f"Job {job_id} already finished with status {finetune_job.status}")
        return False

    self.logger.info(f"Cancelling job {job_id}...")

    # Guard clause: without an underlying job there is nothing to kill.
    if not finetune_job.lsf_job:
        self.logger.error(f"No LSF job associated with {job_id}")
        return False

    try:
        finetune_job.lsf_job.kill()
        finetune_job.status = JobStatus.CANCELLED
        self.logger.info(f"Successfully cancelled job {job_id}")
    except Exception as e:
        self.logger.error(f"Error cancelling job {job_id}: {e}")
        return False
    return True
def get_job_status(self, job_id: str) -> Optional[Dict[str, Any]]:
    """
    Describe one job as a plain dictionary.

    The dictionary is the job's ``to_dict()`` output with ``latest_loss``
    renamed to ``loss`` and a ``progress_percent`` entry added
    (``current_epoch / total_epochs * 100``, or 0 when ``total_epochs`` is
    not positive).

    Args:
        job_id: Job ID to query

    Returns:
        Dictionary with job status details, or None if not found
    """
    if job_id not in self.jobs:
        return None

    job = self.jobs[job_id]
    status = job.to_dict()
    status["loss"] = status.pop("latest_loss", None)

    if job.total_epochs > 0:
        percent = job.current_epoch / job.total_epochs * 100
    else:
        percent = 0
    status["progress_percent"] = percent
    return status

def list_jobs(self) -> List[Dict[str, Any]]:
    """
    Collect the status dictionary of every known job.

    Returns:
        List of job status dictionaries, in the order the jobs are stored
    """
    statuses = []
    for job_id in self.jobs.keys():
        statuses.append(self.get_job_status(job_id))
    return statuses

def get_job_logs(self, job_id: str) -> Optional[str]:
    """
    Return the complete training log of a job as text.

    Args:
        job_id: Job ID

    Returns:
        None for an unknown job; a placeholder message when the log file
        does not exist yet; an error message when reading fails; otherwise
        the log file content.
    """
    if job_id not in self.jobs:
        return None

    log_path = self.jobs[job_id].log_file

    if not log_path.exists():
        return "Log file not yet created..."

    try:
        with open(log_path, "r") as handle:
            content = handle.read()
    except Exception as e:
        self.logger.error(f"Error reading log file: {e}")
        return f"Error reading log file: {e}"
    return content
+ + try: + with open(finetune_job.log_file, "r") as f: + return f.read() + except Exception as e: + self.logger.error(f"Error reading log file: {e}") + return f"Error reading log file: {e}" + + def get_job(self, job_id: str) -> Optional[FinetuneJob]: + """ + Get a FinetuneJob object by ID. + + Args: + job_id: Job ID to retrieve + + Returns: + FinetuneJob object, or None if not found + """ + return self.jobs.get(job_id) + + def _archive_job_logs(self, job: FinetuneJob): + """ + Archive logs before restart. + + Args: + job: The job whose logs to archive + """ + log_file = job.log_file + metadata_file = job.output_dir / "metadata.json" + + # Find next archive number + archive_num = 1 + while (job.output_dir / f"training_log_{archive_num}.txt").exists(): + archive_num += 1 + + # Archive log (copy only - do NOT truncate, as tee still has an open file descriptor) + if log_file.exists(): + import shutil + archive_log = job.output_dir / f"training_log_{archive_num}.txt" + shutil.copy(log_file, archive_log) + self.logger.info(f"Archived log to {archive_log}") + + # Archive metadata + if metadata_file.exists(): + import shutil + archive_meta = job.output_dir / f"metadata_{archive_num}.json" + shutil.copy(metadata_file, archive_meta) + self.logger.info(f"Archived metadata to {archive_meta}") + + def restart_finetuning_job( + self, + job_id: str, + updated_params: Optional[Dict[str, Any]] = None + ) -> FinetuneJob: + """ + Restart training on the same GPU via control endpoint. + + Primary path sends an HTTP restart request to the running + inference server in the same process as the training loop. + Falls back to file signal if control endpoint is unavailable. 
+ + Args: + job_id: ID of job to restart + updated_params: Dict of updated training parameters + + Returns: + Same FinetuneJob object (updated in-place) + + Raises: + ValueError: If job not found or not in a restartable state + """ + restart_t0 = time.perf_counter() + + if job_id not in self.jobs: + raise ValueError(f"Job {job_id} not found") + + job = self.jobs[job_id] + + # Only allow restart if the job is running (serving after training) + if job.status not in [JobStatus.RUNNING, JobStatus.COMPLETED]: + raise ValueError( + f"Job {job_id} is in state {job.status.value} - " + f"can only restart jobs that are RUNNING (serving) or COMPLETED" + ) + + if not job.inference_server_ready: + raise ValueError( + f"Job {job_id} inference server not ready - " + f"training must complete and server must start before restarting" + ) + + # 1. Archive current logs + self.logger.info(f"Archiving logs for job {job_id}...") + archive_t0 = time.perf_counter() + self._archive_job_logs(job) + archive_elapsed = time.perf_counter() - archive_t0 + + signal_data = { + "restart": True, + "timestamp": datetime.now().isoformat(), + "params": updated_params or {} + } + + # 2. Send restart request to running inference server (primary path) + signal_write_mode = "http_control" + write_t0 = time.perf_counter() + http_error = None + if job.inference_server_url: + try: + control_url = job.inference_server_url.rstrip("/") + "/__control__/restart" + response = requests.post(control_url, json=signal_data, timeout=5) + response.raise_for_status() + data = response.json() + if not data.get("success", False): + raise RuntimeError(data.get("error", "Unknown restart control failure")) + self.logger.info(f"Sent restart request via HTTP control endpoint: {control_url}") + except Exception as e: + http_error = e + self.logger.warning(f"HTTP restart control failed for job {job_id}: {e}") + else: + http_error = RuntimeError("No inference_server_url for HTTP restart control") + + # 3. 
def generate_finetuned_model_yaml(
    lora_adapter_path: str,
    base_model_dict: dict,
    model_name: str,
    output_path: Path,
    data_path: str,
    queue: str = "gpu_h100",
    charge_group: str = "cellmap",
    json_data: dict = None,
    scale: str = "s0",
) -> Path:
    """
    Generate .yaml configuration for serving a finetuned model.

    The generated YAML contains a single model entry with ``type: finetune``
    that records the LoRA adapter path, the base model description and the
    scale, alongside the top-level ``data_path``, ``charge_group`` and
    ``queue`` settings. A comment header is written before the YAML body.

    Args:
        lora_adapter_path: Path to the saved LoRA adapter directory
        base_model_dict: Dict describing the base model (from model_config.to_dict())
        model_name: Name of the finetuned model
        output_path: Where to write the .yaml file (parent dirs are created)
        data_path: Path to actual dataset (REQUIRED - no placeholders)
        queue: LSF queue name
        charge_group: LSF charge group
        json_data: Optional dict with input_norm and postprocess from base model;
            when empty/None a warning is added to the file header instead
        scale: Scale level (e.g., "s0", "s1") from base model

    Returns:
        Path to the generated YAML file

    Raises:
        ValueError: If data_path is empty or the documentation placeholder
    """
    # Validate before importing or touching the filesystem so that bad input
    # fails fast and leaves no partial output behind.
    if not data_path or data_path == "/path/to/your/data.zarr":
        raise ValueError(
            "data_path is required and cannot be a placeholder. "
            "Must provide actual dataset path from training corrections."
        )

    import yaml as yaml_lib

    # Build the model entry
    model_entry = {
        "type": "finetune",
        "name": model_name,
        "lora_adapter_path": lora_adapter_path,
        "base_model": base_model_dict,
        "scale": scale,
    }

    # Build the full YAML structure
    yaml_dict = {
        "data_path": data_path,
        "charge_group": charge_group,
        "queue": queue,
        "models": [model_entry],
    }

    # Add json_data (normalization/postprocessing)
    if json_data:
        yaml_dict["json_data"] = json_data

    # Header comment
    header = (
        f"# Finetuned model configuration: {model_name}\n"
        f"# Auto-generated by CellMap-Flow finetuning workflow\n"
        f"#\n"
    )
    if not json_data:
        header += (
            "# WARNING: No normalization found in base model!\n"
            "# Model may not work correctly without proper normalization.\n"
            "#\n"
        )

    # Write to file
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    yaml_content = yaml_lib.dump(yaml_dict, default_flow_style=False, sort_keys=False)

    # Explicit UTF-8: the header embeds model_name verbatim, so relying on the
    # locale's default encoding could fail or produce a non-portable file.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(header)
        f.write(yaml_content)

    logger.info(f"Generated finetuned model YAML: {output_path}")

    return output_path
class MarginLoss(nn.Module):
    """
    Margin-based loss for sparse/scribble annotations.

    Only penalizes predictions on the wrong side of a margin threshold.
    For post-sigmoid outputs in [0, 1]:
    - Foreground (target=1): loss = relu(threshold - pred)^2, threshold = 1 - margin
    - Background (target=0): loss = relu(pred - margin)^2
    - No loss when prediction is already correct with sufficient confidence.

    Attributes are set in ``__init__``; ``apply_sigmoid`` can be switched off
    externally when the inputs are already probabilities.
    """

    def __init__(self, margin: float = 0.3, balance_classes: bool = False):
        """
        Args:
            margin: Distance from 0/1 inside which predictions are not penalized
            balance_classes: If True (and a mask is given), average foreground
                and background voxels separately and weight them equally
        """
        super().__init__()
        self.margin = margin
        self.balance_classes = balance_classes
        self.apply_sigmoid = True

    def forward(self, pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the margin loss.

        Args:
            pred: Predictions; passed through sigmoid when ``apply_sigmoid`` is True
            target: Targets with the same shape as ``pred`` (1 = foreground, 0 = background)
            mask: Optional mask broadcastable against ``pred``; only masked
                voxels contribute to the loss

        Returns:
            Scalar loss tensor
        """
        if self.apply_sigmoid:
            pred = torch.sigmoid(pred)

        threshold_high = 1.0 - self.margin  # e.g., 0.7
        threshold_low = self.margin  # e.g., 0.3

        # Foreground loss: penalize if pred < threshold_high
        fg_loss = torch.relu(threshold_high - pred) ** 2
        # Background loss: penalize if pred > threshold_low
        bg_loss = torch.relu(pred - threshold_low) ** 2

        if self.balance_classes and mask is not None:
            # Average each class separately so fg/bg contribute equally
            # regardless of how many scribble voxels each has. The products
            # below already carry the broadcast shape, so the counts match
            # the number of summed elements.
            fg_mask = target * mask
            bg_mask = (1.0 - target) * mask
            fg_count = fg_mask.sum().clamp(min=1)
            bg_count = bg_mask.sum().clamp(min=1)
            fg_contrib = (fg_loss * fg_mask).sum() / fg_count
            bg_contrib = (bg_loss * bg_mask).sum() / bg_count
            return (fg_contrib + bg_contrib) / 2.0

        # Blend by target: target=1 -> fg_loss, target=0 -> bg_loss
        loss = target * fg_loss + (1.0 - target) * bg_loss

        if mask is not None:
            loss = loss * mask
            # Count masked elements at the broadcast shape. A mask with a
            # singleton channel dim is applied to every channel of the loss,
            # so dividing by the unbroadcast mask.sum() would scale the
            # result by the number of channels.
            valid_count = mask.expand_as(loss).sum().clamp(min=1)
            return loss.sum() / valid_count
        return loss.mean()
+ + Features: + - Mixed precision (FP16) training for memory efficiency + - Gradient accumulation to simulate larger batch sizes + - Checkpointing with best model tracking + - Progress logging + - Partial annotation support (mask unannotated regions) + + Args: + model: PEFT model with LoRA adapters + dataloader: DataLoader for training data + output_dir: Directory to save checkpoints and logs + learning_rate: Learning rate (default: 1e-4) + num_epochs: Number of training epochs (default: 10) + gradient_accumulation_steps: Steps to accumulate gradients (default: 1) + use_mixed_precision: Enable FP16 training (default: True) + loss_type: Loss function ("dice", "bce", or "combined") + device: Training device ("cuda" or "cpu", auto-detected if None) + select_channel: Optional channel index to select from multi-channel output (default: None) + mask_unannotated: If True (default), only compute loss on annotated regions (target > 0). + Targets are shifted down by 1 (e.g., 1->0, 2->1) after masking. + This allows partial annotations where 0=unannotated, 1=background, 2=foreground, etc. + Ignored if target_transform is provided. + target_transform: Optional TargetTransform instance that converts raw annotations + to (target, mask) pairs. Overrides mask_unannotated when provided. + See cellmap_flow.finetune.target_transforms. + + Examples: + >>> lora_model = wrap_model_with_lora(model) + >>> dataloader = create_dataloader("corrections.zarr") + >>> trainer = LoRAFinetuner( + ... lora_model, + ... dataloader, + ... output_dir="output/fly_organelles_v1.1" + ... 
) + >>> trainer.train() + >>> trainer.save_adapter() + """ + + def __init__( + self, + model: nn.Module, + dataloader: DataLoader, + output_dir: str, + learning_rate: float = 1e-4, + num_epochs: int = 10, + gradient_accumulation_steps: int = 1, + use_mixed_precision: bool = True, + loss_type: str = "combined", + device: Optional[str] = None, + select_channel: Optional[int] = None, + mask_unannotated: bool = True, + label_smoothing: float = 0.0, + distillation_lambda: float = 0.0, + distillation_all_voxels: bool = False, + margin: float = 0.3, + balance_classes: bool = False, + target_transform=None, + ): + self.model = model + self.dataloader = dataloader + self.output_dir = Path(output_dir) + self.num_epochs = num_epochs + self.gradient_accumulation_steps = gradient_accumulation_steps + self.use_mixed_precision = use_mixed_precision + self.select_channel = select_channel + self.mask_unannotated = mask_unannotated + self.label_smoothing = label_smoothing + self.distillation_lambda = distillation_lambda + self.distillation_all_voxels = distillation_all_voxels + self.balance_classes = balance_classes + self.target_transform = target_transform + + # Create output directory + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Device + if device is None: + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.device = torch.device(device) + + logger.info(f"Using device: {self.device}") + + # Move model to device + self.model = self.model.to(self.device) + + # Optimizer (only LoRA parameters) + self.optimizer = AdamW( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=learning_rate, + ) + + # Loss function + self._use_bce = False + self._use_mse = False + if loss_type == "dice": + self.criterion = DiceLoss() + elif loss_type == "bce": + # Use reduction='none' so we can manually apply mask if needed + self.criterion = nn.BCEWithLogitsLoss(reduction='none') + self._use_bce = True + elif loss_type == "combined": + 
self.criterion = CombinedLoss() + elif loss_type == "mse": + self.criterion = nn.MSELoss(reduction='none') + self._use_mse = True + elif loss_type == "margin": + self.criterion = MarginLoss(margin=margin, balance_classes=balance_classes) + else: + raise ValueError(f"Unknown loss_type: {loss_type}") + + # Label smoothing is redundant with margin loss + if loss_type == "margin" and self.label_smoothing > 0: + logger.warning("Label smoothing is redundant with margin loss, setting to 0") + self.label_smoothing = 0.0 + + if self.balance_classes: + logger.info("Class balancing enabled: fg and bg scribble voxels weighted equally") + + logger.info(f"Using {loss_type} loss") + if self.label_smoothing > 0: + logger.info(f"Label smoothing: {self.label_smoothing} (targets: {self.label_smoothing/2:.3f} to {1-self.label_smoothing/2:.3f})") + if self.distillation_lambda > 0: + # FX-interpreted models (torch.export UnflattenedModule, often + # wrapped in BatchLoopWrapper) keep every intermediate tensor + # alive in the FX env, so distillation's two passes can OOM + # even on H100/A100 80GB. We don't auto-disable here — request + # a larger node if needed; the OOM handler in train() will + # disable it as a last resort if memory actually runs out. + inner = getattr(self.model, "model", self.model) + base = getattr(inner, "model", inner) + if type(base).__name__ in ("UnflattenedModule",) or type(inner).__name__ == "BatchLoopWrapper": + logger.warning( + "Distillation enabled with an FX-interpreted base model " + "(UnflattenedModule). This requires substantial GPU memory; " + "if you OOM, distillation will be disabled automatically as " + "a fallback in the OOM handler." 
+ ) + scope_str = "all voxels" if self.distillation_all_voxels else "unlabeled voxels only" + logger.info(f"Teacher distillation enabled: lambda={self.distillation_lambda} ({scope_str})") + + # Mixed precision scaler + self.scaler = GradScaler('cuda', enabled=use_mixed_precision) + + # Training state + self.current_epoch = 0 + self.global_step = 0 + self.best_loss = float('inf') + self.training_stats = [] + + def _fallback_to_fp32(self): + """Disable mixed precision training.""" + self.use_mixed_precision = False + self.scaler = GradScaler('cuda', enabled=False) + torch.cuda.empty_cache() + + def _reset_training_state(self): + """Reset LoRA weights, optimizer, and training counters for a fresh start.""" + from peft import PeftModel + if isinstance(self.model, PeftModel): + # Reset LoRA adapter weights to zero (equivalent to base model) + for name, param in self.model.named_parameters(): + if 'lora_' in name and param.requires_grad: + nn.init.zeros_(param) if 'lora_B' in name else nn.init.kaiming_uniform_(param, a=math.sqrt(5)) + self.optimizer = AdamW( + filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.optimizer.defaults['lr'], + ) + self.current_epoch = 0 + self.global_step = 0 + self.best_loss = float('inf') + self.training_stats = [] + + def _halve_batch_size(self): + """Halve batch size and double gradient accumulation to keep effective batch size. + + Returns True if batch size was reduced, False if already at 1. 
+ """ + old_bs = self.dataloader.batch_size + new_bs = max(1, old_bs // 2) + if new_bs >= old_bs: + return False + old_accum = self.gradient_accumulation_steps + self.gradient_accumulation_steps = old_accum * (old_bs // new_bs) + self.dataloader = DataLoader( + self.dataloader.dataset, + batch_size=new_bs, + shuffle=True, + num_workers=self.dataloader.num_workers, + pin_memory=self.dataloader.pin_memory, + multiprocessing_context=self.dataloader.multiprocessing_context, + ) + self._log_message( + f"Halved batch size {old_bs} → {new_bs}, " + f"gradient accumulation {old_accum} → {self.gradient_accumulation_steps} " + f"(effective batch size unchanged)" + ) + torch.cuda.empty_cache() + return True + + def _model_cache_targets(self): + """Return model objects that may survive LoRA unwrap/rewrap cycles.""" + targets = [] + seen = set() + stack = [self.model] + while stack: + obj = stack.pop(0) + if obj is None or id(obj) in seen: + continue + seen.add(id(obj)) + targets.append(obj) + for attr in ("base_model", "model", "module"): + child = getattr(obj, attr, None) + if child is not None and id(child) not in seen: + stack.append(child) + return targets + + def _sigmoid_cache_key(self): + if self.select_channel is not None: + return ("channel", self.select_channel) + return ("all", None) + + def _get_cached_model_has_sigmoid(self) -> Optional[bool]: + key = self._sigmoid_cache_key() + for target in self._model_cache_targets(): + cache = getattr(target, "_cellmap_flow_model_has_sigmoid_cache", None) + if isinstance(cache, dict) and key in cache: + return bool(cache[key]) + return None + + def _cache_model_has_sigmoid(self, value: bool): + key = self._sigmoid_cache_key() + for target in self._model_cache_targets(): + try: + cache = getattr( + target, "_cellmap_flow_model_has_sigmoid_cache", None + ) + if not isinstance(cache, dict): + cache = {} + setattr(target, "_cellmap_flow_model_has_sigmoid_cache", cache) + cache[key] = bool(value) + except Exception: + pass + + def 
def train(self) -> Dict[str, Any]:
    """
    Run the training loop.

    Before the first epoch two probes are run: an FP16 stability probe
    (falls back to FP32 on NaN/Inf, OOM or any other failure) and a
    built-in-sigmoid probe (switches the losses to probability mode when the
    model output is bounded to [0, 1]). Each epoch is retried after CUDA OOM
    with a halved batch size and, as a last resort, with distillation
    disabled. A ``stop_signal.json`` file in the output directory ends the
    loop early.

    Returns:
        Training statistics dictionary with:
            - final_loss: Loss of the last completed epoch (NaN if no epoch
              completed or training aborted on OOM)
            - best_loss: Best loss achieved
            - total_epochs: Number of epochs actually trained
            - total_steps: Total training steps
            - training_time: Wall-clock seconds spent in this call
            - diverged: Present and True only when training was aborted
    """
    # Create log file
    log_file = self.output_dir / "training_log.txt"

    def log_message(msg):
        """Log to console (tee handles writing to log file)."""
        print(msg, flush=True)

    log_message("="*60)
    log_message("Starting LoRA Finetuning")
    log_message("="*60)
    log_message(f"Epochs: {self.num_epochs}")
    log_message(f"Batches per epoch: {len(self.dataloader)}")
    log_message(f"Gradient accumulation: {self.gradient_accumulation_steps}")
    log_message(f"Effective batch size: {self.dataloader.batch_size * self.gradient_accumulation_steps}")
    log_message(f"Mixed precision: {self.use_mixed_precision}")
    log_message(f"Mask unannotated regions: {self.mask_unannotated}")
    log_message(f"Log file: {log_file}")
    log_message("")

    self.model.train()
    start_time = time.time()

    # Store log function for use in _train_epoch and helpers
    self._log_message = log_message

    # Probe for FP16 stability: run a single forward pass and check for NaN.
    # Some model+data combinations produce NaN under FP16 autocast.
    if self.use_mixed_precision:
        try:
            probe_raw, _ = next(iter(self.dataloader))
            probe_raw = probe_raw[:1]
            probe_raw = probe_raw.to(self.device)
            with torch.no_grad(), autocast('cuda', enabled=True):
                probe_out = self.model(probe_raw)
            if not torch.isfinite(probe_out).all():
                log_message("WARNING: Model produces NaN/Inf under FP16 — falling back to FP32.")
                self._fallback_to_fp32()
            del probe_raw, probe_out
            torch.cuda.empty_cache()
        except torch.cuda.OutOfMemoryError:
            log_message("WARNING: FP16 probe OOM — falling back to FP32 with smaller batch.")
            self._fallback_to_fp32()
            self._halve_batch_size()
        except Exception as e:
            log_message(f"WARNING: FP16 probe failed ({e}) — falling back to FP32.")
            self._fallback_to_fp32()

    # Probe for built-in sigmoid: if model outputs are bounded to [0,1]
    # even with extreme inputs, the model has sigmoid baked in.
    # In that case, switch BCEWithLogitsLoss to BCELoss to avoid double-sigmoid,
    # and tell DiceLoss/MarginLoss to skip their sigmoid.
    cached_model_has_sigmoid = self._get_cached_model_has_sigmoid()
    if cached_model_has_sigmoid is not None:
        model_has_sigmoid = cached_model_has_sigmoid
        if model_has_sigmoid:
            log_message("Using cached built-in sigmoid detection")
            self._apply_probability_output_mode(log_message)
    else:
        try:
            probe_raw, _ = next(iter(self.dataloader))
            probe_raw = probe_raw[:1]
            probe_extreme = torch.randn(
                probe_raw.shape,
                dtype=probe_raw.dtype,
                device=self.device,
            ) * 100
            with torch.no_grad(), autocast(
                'cuda', enabled=self.use_mixed_precision
            ):
                probe_out = self.model(probe_extreme)
            if self.select_channel is not None:
                probe_out = probe_out[
                    :,
                    self.select_channel : self.select_channel + 1,
                    :,
                    :,
                    :,
                ]
            model_has_sigmoid = bool(
                ((probe_out.min() >= 0) & (probe_out.max() <= 1)).item()
            )
            self._cache_model_has_sigmoid(model_has_sigmoid)
            if model_has_sigmoid:
                log_message("Detected built-in sigmoid in model output")
                self._apply_probability_output_mode(log_message)
            del probe_extreme, probe_out
            torch.cuda.empty_cache()
        except Exception as e:
            log_message(
                f"WARNING: Sigmoid probe failed ({e}) — assuming raw logits output."
            )
            torch.cuda.empty_cache()

    stop_signal_path = self.output_dir / "stop_signal.json"
    # Make sure no stale signal from a previous run lingers.
    try:
        if stop_signal_path.exists():
            stop_signal_path.unlink()
    except Exception:
        pass

    # Tracked outside the loop so the summary below is well-defined even when
    # no epoch body runs (num_epochs == 0 or a stop signal before epoch 1).
    final_loss = float('nan')
    epochs_completed = 0

    for epoch in range(self.num_epochs):
        self.current_epoch = epoch
        # User-requested graceful stop: drop out of the training loop so
        # the outer flow (inference server + wait for restart) kicks in.
        if stop_signal_path.exists():
            log_message(
                f"Stop signal received at epoch {epoch+1}/{self.num_epochs}; "
                f"exiting training loop early."
            )
            try:
                stop_signal_path.unlink()
            except Exception:
                pass
            break
        log_message(f"Starting epoch {epoch+1} of {self.num_epochs}...")
        # Mitigation loop: keep applying mitigations (halve batch, then
        # disable distillation) and retrying until the epoch succeeds or
        # there's nothing left to try. A single try/except wasn't enough —
        # users can restart with a much larger lora_r/batch_size combo and
        # the second attempt also OOMs, so we need to iterate.
        epoch_loss = None
        while True:
            try:
                epoch_loss = self._train_epoch()
                break
            except torch.cuda.OutOfMemoryError:
                torch.cuda.empty_cache()
                mitigated = False
                if self._halve_batch_size():
                    log_message(
                        f"OOM at epoch {epoch+1} — retrying with smaller batch size "
                        f"(now {self.dataloader.batch_size})."
                    )
                    mitigated = True
                elif self.distillation_lambda > 0:
                    log_message(
                        f"OOM at epoch {epoch+1} and batch already at 1 — "
                        f"disabling distillation (was lambda={self.distillation_lambda}) and retrying."
                    )
                    self.distillation_lambda = 0
                    mitigated = True
                if not mitigated:
                    log_message("ERROR: OOM at batch=1 with no distillation. Cannot continue.")
                    return {
                        'final_loss': float('nan'),
                        'best_loss': self.best_loss,
                        'total_epochs': epoch,
                        'total_steps': self.global_step,
                        'training_time': time.time() - start_time,
                        'diverged': True,
                    }
                # Reset optimizer state (accumulated grads are stale after OOM)
                self.optimizer.zero_grad(set_to_none=True)
                torch.cuda.empty_cache()

        # Handle NaN/Inf loss
        if not math.isfinite(epoch_loss):
            if self.use_mixed_precision:
                # NaN likely caused by FP16 overflow on specific data —
                # fall back to FP32 and restart training from scratch
                log_message(
                    f"WARNING: NaN loss at epoch {epoch+1} under FP16 — "
                    f"falling back to FP32 and restarting training."
                )
                self._fallback_to_fp32()
                self._reset_training_state()
                return self.train()

            self._log_message(
                f"ERROR: Loss is {epoch_loss} at epoch {epoch+1}. "
                f"Stopping training."
            )
            print("TRAINING_DIVERGED", flush=True)
            return {
                'final_loss': epoch_loss,
                'best_loss': self.best_loss,
                'total_epochs': epoch + 1,
                'total_steps': self.global_step,
                'training_time': time.time() - start_time,
                'diverged': True,
            }

        # Log epoch results
        self._log_message(
            f"Epoch {epoch+1}/{self.num_epochs} - "
            f"Loss: {epoch_loss:.6f} - "
            f"Best: {self.best_loss:.6f}"
        )

        # Save checkpoint if best
        if epoch_loss < self.best_loss:
            self.best_loss = epoch_loss
            self._log_message(" Saving best checkpoint...")
            self.save_checkpoint(is_best=True)
            self._log_message(f" → Saved best checkpoint")

        # Save regular checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            self.save_checkpoint(is_best=False)

        self.training_stats.append({
            'epoch': epoch + 1,
            'loss': epoch_loss,
            'best_loss': self.best_loss,
        })

        # Only epochs that reach this point count as trained.
        final_loss = epoch_loss
        epochs_completed = epoch + 1

    # Final checkpoint
    self.save_checkpoint(is_best=False)

    total_time = time.time() - start_time
    self._log_message("")
    self._log_message("="*60)
    self._log_message("Training Complete!")
    self._log_message(f"Total time: {total_time/60:.2f} minutes")
    self._log_message(f"Best loss: {self.best_loss:.6f}")
    self._log_message(f"Final loss: {final_loss:.6f}")
    self._log_message(f"Output directory: {self.output_dir}")
    self._log_message("="*60)

    return {
        'final_loss': final_loss,
        'best_loss': self.best_loss,
        'total_epochs': epochs_completed,
        'total_steps': self.global_step,
        'training_time': total_time,
    }
(per-param mean|grad|) + # - if yes for some, which ones? (zero-grad count + sample names) + diag_param_name = None + diag_param = None + for name, param in self.model.named_parameters(): + if param.requires_grad and "lora_B" in name: + diag_param_name = name + diag_param = param + break + diag_param_initial = ( + diag_param.detach().clone() if diag_param is not None else None + ) + diag_grad_abs_sum = 0.0 + diag_grad_count = 0 + + # Track zero-grad status across all trainable params for the LAST + # batch of the epoch (cumulative grad before zero_grad fires). + diag_param_grad_seen_nonzero: dict[str, bool] = {} + + for batch_idx, (raw, target) in enumerate(self.dataloader): + # Move to device + raw = raw.to(self.device, non_blocking=True) + target = target.to(self.device, non_blocking=True) + + # Handle partial annotations: create mask and shift labels + mask = None + if self.target_transform is not None: + target, mask = self.target_transform(target) + elif self.mask_unannotated: + # Legacy behavior: binary single-channel + mask = (target > 0).float() # (B, C, Z, Y, X) + # Shift labels down by 1 (but keep 0 as 0) + # e.g., 0->0 (unannotated), 1->0 (background), 2->1 (foreground) + target = torch.clamp(target - 1, min=0) + + # Apply label smoothing: 0 -> s/2, 1 -> 1-s/2 + # This prevents the model from being pushed to extreme 0/1 outputs, + # preserving gradual distance-like predictions + if self.label_smoothing > 0: + target = target * (1 - self.label_smoothing) + self.label_smoothing / 2 + + # Teacher forward pass for distillation (before student pass) + # Uses the base model without LoRA adapters as the teacher + teacher_pred = None + if self.distillation_lambda > 0: + with torch.no_grad(): + self.model.disable_adapter_layers() + try: + with autocast('cuda', enabled=self.use_mixed_precision): + teacher_pred = self.model(raw) + if self.select_channel is not None: + teacher_pred = teacher_pred[:, self.select_channel:self.select_channel+1, :, :, :] + teacher_pred 
= teacher_pred.detach() + finally: + self.model.enable_adapter_layers() + if not torch.isfinite(teacher_pred).all(): + logger.warning(f"NaN/Inf in teacher_pred! range=[{teacher_pred.min():.4f}, {teacher_pred.max():.4f}]") + + # Student forward pass with mixed precision + with autocast('cuda', enabled=self.use_mixed_precision): + pred = self.model(raw) + + if not torch.isfinite(pred).all(): + logger.warning(f"NaN/Inf in student pred! range=[{pred.min():.4f}, {pred.max():.4f}]") + + # Select specific channel if requested (e.g., mito = channel 2 from 8-channel output) + if self.select_channel is not None: + pred = pred[:, self.select_channel:self.select_channel+1, :, :, :] + + # Compute supervised loss with optional mask + if (self._use_bce or self._use_mse) and mask is not None: + # For per-element losses (BCE, MSE), manually apply mask + per_element_loss = self.criterion(pred, target) + if self.balance_classes: + # Average fg and bg separately so each contributes equally + fg_mask = target * mask + bg_mask = (1.0 - target) * mask + fg_count = fg_mask.sum().clamp(min=1) + bg_count = bg_mask.sum().clamp(min=1) + fg_contrib = (per_element_loss * fg_mask).sum() / fg_count + bg_contrib = (per_element_loss * bg_mask).sum() / bg_count + supervised_loss = (fg_contrib + bg_contrib) / 2.0 + else: + supervised_loss = (per_element_loss * mask).sum() / mask.sum().clamp(min=1) + elif hasattr(self.criterion, 'forward') and 'mask' in self.criterion.forward.__code__.co_varnames: + # For custom losses that support masking (DiceLoss, CombinedLoss, MarginLoss) + supervised_loss = self.criterion(pred, target, mask) + else: + # No masking needed + supervised_loss = self.criterion(pred, target) + if self._use_bce or self._use_mse: + supervised_loss = supervised_loss.mean() + + loss = supervised_loss + + if not torch.isfinite(supervised_loss): + logger.warning(f"NaN/Inf supervised_loss: {supervised_loss.item()}") + + # Compute distillation loss + distillation_loss = torch.tensor(0.0, 
device=self.device) + if self.distillation_lambda > 0 and teacher_pred is not None: + distill_loss_map = (pred - teacher_pred) ** 2 # per-element MSE + if self.distillation_all_voxels or mask is None: + # Apply on all voxels + distillation_loss = distill_loss_map.mean() + else: + # Apply only on unlabeled voxels. + # Cast to float32 before multiply/sum to avoid FP16 overflow + # when summing over many voxels (e.g., 13-channel models). + unlabeled_mask = (1.0 - mask).float() + distillation_loss = (distill_loss_map.float() * unlabeled_mask).sum() / unlabeled_mask.sum().clamp(min=1) + if not torch.isfinite(distillation_loss): + logger.warning(f"NaN/Inf distillation_loss: {distillation_loss.item()}") + loss = loss + self.distillation_lambda * distillation_loss + + # Scale loss for gradient accumulation + loss = loss / self.gradient_accumulation_steps + + # Backward pass + self.scaler.scale(loss).backward() + + # Diagnostic: capture grad on the watched LoRA param BEFORE the + # optimizer step (zero_grad clears it). Also note which trainable + # params have ever seen a nonzero gradient this epoch so we can + # report the dead ones at the end. + if diag_param is not None and diag_param.grad is not None: + diag_grad_abs_sum += diag_param.grad.detach().abs().mean().item() + diag_grad_count += 1 + for name, p in self.model.named_parameters(): + if not p.requires_grad: + continue + if diag_param_grad_seen_nonzero.get(name, False): + continue + if p.grad is not None and p.grad.detach().abs().sum().item() > 0: + diag_param_grad_seen_nonzero[name] = True + + # Update weights after accumulation + if (batch_idx + 1) % self.gradient_accumulation_steps == 0: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + self.global_step += 1 + + # Accumulate losses (unscaled) + batch_loss = loss.item() * self.gradient_accumulation_steps + if not math.isfinite(batch_loss): + logger.warning(f"NaN/Inf loss at epoch {self.current_epoch+1}, batch {batch_idx+1}. 
Aborting epoch.") + return float('nan') + epoch_loss += batch_loss + epoch_supervised_loss += supervised_loss.item() + epoch_distill_loss += distillation_loss.item() + + # Log progress every batch (since we have few batches) + avg_loss = epoch_loss / (batch_idx + 1) + if hasattr(self, '_log_message'): + if self.distillation_lambda > 0: + avg_sup = epoch_supervised_loss / (batch_idx + 1) + avg_distill = epoch_distill_loss / (batch_idx + 1) + self._log_message( + f" Batch {batch_idx+1}/{num_batches} - " + f"Loss: {avg_loss:.6f} (sup: {avg_sup:.6f}, distill: {avg_distill:.6f})" + ) + else: + self._log_message( + f" Batch {batch_idx+1}/{num_batches} - " + f"Loss: {avg_loss:.6f}" + ) + else: + # Fallback if _log_message not set + msg = f" Batch {batch_idx+1}/{num_batches} - Loss: {avg_loss:.6f}" + print(msg) + logger.info(msg) + + # Handle leftover accumulated gradients at end of epoch + # (in case num_batches is not divisible by gradient_accumulation_steps) + if num_batches % self.gradient_accumulation_steps != 0: + self.scaler.step(self.optimizer) + self.scaler.update() + self.optimizer.zero_grad() + self.global_step += 1 + + # Diagnostic summary: per-watched-param grad/delta + counts of + # trainable params that received any nonzero gradient this epoch. 
+ if diag_param is not None and hasattr(self, "_log_message"): + mean_grad = ( + diag_grad_abs_sum / diag_grad_count if diag_grad_count else 0.0 + ) + param_delta = ( + (diag_param.detach() - diag_param_initial).abs().mean().item() + if diag_param_initial is not None + else 0.0 + ) + self._log_message( + f" [diag] {diag_param_name}: " + f"mean|grad|={mean_grad:.3e} (over {diag_grad_count} batches), " + f"mean|param_delta|={param_delta:.3e} this epoch" + ) + + if hasattr(self, "_log_message"): + n_trainable = sum( + 1 for _, p in self.model.named_parameters() if p.requires_grad + ) + n_live = sum(1 for v in diag_param_grad_seen_nonzero.values() if v) + n_dead = n_trainable - n_live + dead_names = [ + name for name, p in self.model.named_parameters() + if p.requires_grad and not diag_param_grad_seen_nonzero.get(name) + ] + self._log_message( + f" [diag] gradient flow: {n_live}/{n_trainable} trainable " + f"params got nonzero grad; {n_dead} are dead. " + f"First 5 dead: {dead_names[:5]}" + ) + + return epoch_loss / num_batches + + def save_checkpoint(self, is_best: bool = False): + """ + Save training checkpoint. + + Args: + is_best: If True, saves as "best_model.pth" + """ + checkpoint_name = "best_checkpoint.pth" if is_best else f"checkpoint_epoch_{self.current_epoch+1}.pth" + checkpoint_path = self.output_dir / checkpoint_name + + # Save only trainable (LoRA) parameters to avoid writing the full + # 800M+ param base model to disk every checkpoint. 
+ trainable_keys = {n for n, p in self.model.named_parameters() if p.requires_grad} + trainable_state = {k: v for k, v in self.model.state_dict().items() if k in trainable_keys} + checkpoint = { + 'epoch': self.current_epoch, + 'global_step': self.global_step, + 'model_state_dict': trainable_state, + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scaler_state_dict': self.scaler.state_dict(), + 'best_loss': self.best_loss, + 'training_stats': self.training_stats, + 'lora_only': True, + } + + torch.save(checkpoint, checkpoint_path) + logger.debug(f"Checkpoint saved: {checkpoint_path}") + + def save_adapter(self, adapter_path: Optional[str] = None): + """ + Save only the LoRA adapter (not the full model). + + Automatically loads the best checkpoint weights before saving + so the exported adapter reflects the best training epoch. + + Args: + adapter_path: Path to save adapter. If None, uses output_dir/lora_adapter + """ + from cellmap_flow.finetune.lora_wrapper import save_lora_adapter + + if adapter_path is None: + adapter_path = str(self.output_dir / "lora_adapter") + + # Load best checkpoint weights before saving + best_ckpt = self.output_dir / "best_checkpoint.pth" + if best_ckpt.exists(): + checkpoint = torch.load(best_ckpt, map_location=self.device) + if checkpoint.get('lora_only', False): + self.model.load_state_dict(checkpoint['model_state_dict'], strict=False) + else: + self.model.load_state_dict(checkpoint['model_state_dict']) + logger.info(f"Loaded best checkpoint (epoch {checkpoint['epoch'] + 1}, loss {checkpoint['best_loss']:.6f}) before saving adapter") + else: + logger.warning("No best checkpoint found, saving adapter from final epoch weights") + + save_lora_adapter(self.model, adapter_path) + logger.info(f"LoRA adapter saved to: {adapter_path}") + + def load_checkpoint(self, checkpoint_path: str): + """ + Load training checkpoint to resume training. 
+ + Args: + checkpoint_path: Path to checkpoint file + """ + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + if checkpoint.get('lora_only', False): + # Checkpoint contains only trainable (LoRA) params — merge into full state + self.model.load_state_dict(checkpoint['model_state_dict'], strict=False) + else: + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.scaler.load_state_dict(checkpoint['scaler_state_dict']) + + self.current_epoch = checkpoint['epoch'] + self.global_step = checkpoint['global_step'] + self.best_loss = checkpoint['best_loss'] + self.training_stats = checkpoint.get('training_stats', []) + + logger.info(f"Checkpoint loaded from: {checkpoint_path}") + logger.info(f"Resuming from epoch {self.current_epoch+1}") diff --git a/cellmap_flow/finetune/lora_wrapper.py b/cellmap_flow/finetune/lora_wrapper.py new file mode 100644 index 0000000..3098d28 --- /dev/null +++ b/cellmap_flow/finetune/lora_wrapper.py @@ -0,0 +1,464 @@ +""" +Generic LoRA wrapper for PyTorch models. + +This module provides automatic detection of adaptable layers and wraps +PyTorch models with LoRA (Low-Rank Adaptation) adapters using the +HuggingFace PEFT library. + +LoRA enables efficient finetuning by training only a small number of +additional parameters (typically 1-2% of the original model) while +keeping the base model frozen. +""" + +import logging +from typing import List, Optional, Union +import torch +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +def detect_adaptable_layers( + model: nn.Module, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, +) -> List[str]: + """ + Automatically detect layers suitable for LoRA adaptation. + + Searches for Conv2d, Conv3d, and Linear layers, filtering by name patterns. + By default, only excludes batch/layer-norm style modules. 
Output/head + layers are deliberately INCLUDED so the model can fully adapt its + feature→output mapping for cross-domain finetuning. (Previously + 'final', 'head', 'output' were excluded; that left the output projection + frozen, which prevented learning when the base model's predictions on + the target dataset were poor.) + + Args: + model: PyTorch model to inspect + include_patterns: List of regex patterns for layer names to include + If None, includes all Conv/Linear layers + exclude_patterns: List of substrings for layer names to exclude + Default: ['bn', 'norm'] + + Returns: + List of layer names suitable for LoRA adaptation + """ + import re + + if exclude_patterns is None: + exclude_patterns = ['bn', 'norm'] + + adaptable = [] + + for name, module in model.named_modules(): + is_adaptable = isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)) + + # Fallback: detect by parameter shape (e.g. InterpreterModule from torch.export) + if not is_adaptable and hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor): + is_adaptable = module.weight.ndim >= 2 + + if not is_adaptable: + continue + + # Apply include patterns if specified + if include_patterns is not None: + if not any(re.match(pattern, name) for pattern in include_patterns): + continue + + # Apply exclude patterns + if any(exclude in name.lower() for exclude in exclude_patterns): + logger.debug(f"Excluding layer: {name} (matched exclude pattern)") + continue + + adaptable.append(name) + + logger.info(f"Detected {len(adaptable)} adaptable layers") + if len(adaptable) > 0: + logger.debug(f"Adaptable layers: {adaptable[:5]}..." if len(adaptable) > 5 else f"Adaptable layers: {adaptable}") + + return adaptable + + +def _replace_interpreter_modules(model: nn.Module) -> int: + """Replace non-standard leaf modules (e.g. InterpreterModule from torch.export + unflatten) with real nn.Conv*/nn.Linear that share the same weight/bias tensors. 
+ + PEFT's dispatch only accepts nn.Conv1d/2d/3d, nn.Linear, etc., so unflattened + modules need to be swapped before LoRA wrapping. The FX graph's call_module + will invoke whatever module is registered under the name, so the swap doesn't + break the forward pass. + + Returns the number of modules replaced. + """ + count = 0 + for name, module in list(model.named_modules()): + if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)): + continue + if not hasattr(module, 'weight') or not isinstance(module.weight, torch.Tensor): + continue + + w = module.weight + b = getattr(module, 'bias', None) + if b is not None and not isinstance(b, torch.Tensor): + b = None + + if w.ndim == 5: + out_c, in_c, kz, ky, kx = w.shape + new_mod = nn.Conv3d(in_c, out_c, (kz, ky, kx), padding=0, bias=(b is not None)) + elif w.ndim == 4: + out_c, in_c, ky, kx = w.shape + new_mod = nn.Conv2d(in_c, out_c, (ky, kx), padding=0, bias=(b is not None)) + elif w.ndim == 3: + out_c, in_c, k = w.shape + new_mod = nn.Conv1d(in_c, out_c, k, padding=0, bias=(b is not None)) + elif w.ndim == 2: + out_f, in_f = w.shape + new_mod = nn.Linear(in_f, out_f, bias=(b is not None)) + else: + continue + + new_mod.weight = nn.Parameter(w) + if b is not None: + new_mod.bias = nn.Parameter(b) + + parts = name.split('.') + parent = model + for p in parts[:-1]: + parent = getattr(parent, p) + setattr(parent, parts[-1], new_mod) + count += 1 + + if count > 0: + logger.info(f"Replaced {count} non-standard modules with nn.Conv/Linear for PEFT compatibility") + return count + + +class BatchLoopWrapper(nn.Module): + """Wraps a model with fixed batch_size=1 (e.g. UnflattenedModule from + torch.export without dynamic shapes) so it accepts arbitrary batch sizes + by looping over the batch dim. 
+ """ + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, x, *args, **kwargs): + if x.shape[0] == 1: + return self.model(x, *args, **kwargs) + outs = [self.model(x[i:i + 1], *args, **kwargs) for i in range(x.shape[0])] + return torch.cat(outs, dim=0) + + +class SequentialWrapper(nn.Module): + """ + Wrapper for Sequential models to make them compatible with PEFT. + + PEFT expects models to accept **kwargs, but Sequential only accepts + positional args. This wrapper provides that interface. + """ + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + def forward(self, x=None, input_ids=None, **kwargs): + # PEFT may pass input as 'input_ids' kwarg for transformers + # For vision models, we expect 'x' as positional or kwarg + if x is None and input_ids is not None: + x = input_ids + if x is None: + raise ValueError("Input tensor not provided") + # Ignore other kwargs and just pass x + return self.model(x) + + +def wrap_model_with_lora( + model: nn.Module, + target_modules: Optional[List[str]] = None, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.1, + modules_to_save: Optional[List[str]] = None, + task_type: Optional[str] = None, +) -> nn.Module: + """ + Wrap a PyTorch model with LoRA adapters using HuggingFace PEFT. + + This creates a PEFT model with LoRA adapters on specified layers. + The base model is frozen, and only LoRA parameters are trainable. + + Args: + model: PyTorch model to wrap (e.g., UNet, CNN) + target_modules: List of layer names to adapt. If None, auto-detects. 
+ lora_r: LoRA rank (number of low-rank dimensions) + Higher = more capacity, more parameters + Typical values: 4-32, default 8 + lora_alpha: LoRA alpha (scaling factor) + Controls strength of LoRA updates + Typical: 2*r, default 16 + lora_dropout: Dropout probability for LoRA layers (0.0-0.5, default 0.1) + modules_to_save: Additional modules to make trainable (e.g., final layer) + task_type: PEFT task type. Options: + - "FEATURE_EXTRACTION" (default, for general models) + - "SEQ_CLS" (sequence classification) + - "TOKEN_CLS" (token classification) + - "CAUSAL_LM" (causal language modeling) + + Returns: + PEFT model with LoRA adapters + + Raises: + ImportError: If peft library is not installed + ValueError: If no adaptable layers found + + Examples: + >>> # Auto-detect and wrap all Conv/Linear layers + >>> lora_model = wrap_model_with_lora(model, lora_r=8) + + >>> # Wrap specific layers with custom config + >>> lora_model = wrap_model_with_lora( + ... model, + ... target_modules=["encoder.conv1", "encoder.conv2"], + ... lora_r=16, + ... lora_alpha=32, + ... modules_to_save=["final_conv"] + ... ) + + >>> # Check trainable parameters + >>> print_lora_parameters(lora_model) + """ + try: + from peft import LoraConfig, get_peft_model, TaskType + except ImportError: + raise ImportError( + "peft library is required for LoRA finetuning. " + "Install with: pip install peft" + ) + + # Wrap Sequential models to make them compatible with PEFT + if isinstance(model, nn.Sequential): + logger.info("Wrapping Sequential model for PEFT compatibility") + model = SequentialWrapper(model) + + # Replace any non-standard leaf modules (e.g. InterpreterModule) with + # real nn.Conv*/Linear so PEFT's dispatch can wrap them. + _replace_interpreter_modules(model) + + # Auto-detect target modules if not specified + if target_modules is None: + target_modules = detect_adaptable_layers(model) + if len(target_modules) == 0: + raise ValueError( + "No adaptable layers found in model. 
" + "Specify target_modules manually or check model architecture." + ) + logger.info(f"Auto-detected {len(target_modules)} target modules for LoRA") + + # Map task type string to PEFT TaskType enum + # None means PEFT uses the base PeftModel with a clean forward() passthrough, + # which is correct for custom nn.Module models (not HuggingFace transformers). + task_type_map = { + "FEATURE_EXTRACTION": TaskType.FEATURE_EXTRACTION, + "SEQ_CLS": TaskType.SEQ_CLS, + "TOKEN_CLS": TaskType.TOKEN_CLS, + "CAUSAL_LM": TaskType.CAUSAL_LM, + } + + peft_task_type = task_type_map.get(task_type) if task_type else None + + # Create LoRA config + lora_config = LoraConfig( + task_type=peft_task_type, + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + target_modules=target_modules, + modules_to_save=modules_to_save, + bias="none", # Don't adapt bias terms + ) + + logger.info( + f"Creating LoRA model with r={lora_r}, alpha={lora_alpha}, " + f"dropout={lora_dropout}" + ) + + # Wrap model with PEFT + peft_model = get_peft_model(model, lora_config) + + logger.info("LoRA model created successfully") + print_lora_parameters(peft_model) + + return peft_model + + +def print_lora_parameters(model: nn.Module): + """ + Print statistics about trainable and total parameters in a LoRA model. 
+ + Args: + model: PEFT model with LoRA adapters + + Examples: + >>> lora_model = wrap_model_with_lora(model) + >>> print_lora_parameters(lora_model) + Trainable params: 294,912 (1.2% of total) + Total params: 24,567,890 + """ + try: + from peft import PeftModel + if isinstance(model, PeftModel): + model.print_trainable_parameters() + return + except ImportError: + pass + + # Fallback if not a PEFT model + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in model.parameters()) + + if total_params > 0: + percentage = 100 * trainable_params / total_params + logger.info( + f"Trainable params: {trainable_params:,} ({percentage:.2f}% of total)" + ) + logger.info(f"Total params: {total_params:,}") + else: + logger.warning("Model has no parameters") + + +def load_lora_adapter( + model: nn.Module, + adapter_path: str, + is_trainable: bool = False, +) -> nn.Module: + """ + Load a pretrained LoRA adapter into a base model. + + Args: + model: Base PyTorch model (without LoRA) + adapter_path: Path to saved LoRA adapter directory + is_trainable: If True, adapter parameters are trainable (for continued training) + If False, adapter parameters are frozen (for inference) + + Returns: + PEFT model with loaded adapter + + Examples: + >>> # Load adapter for inference + >>> model = load_lora_adapter( + ... base_model, + ... "models/fly_organelles/v1.1.0/lora_adapter" + ... ) + + >>> # Load adapter for continued training + >>> model = load_lora_adapter( + ... base_model, + ... "models/fly_organelles/v1.1.0/lora_adapter", + ... is_trainable=True + ... ) + """ + try: + from peft import PeftModel + except ImportError: + raise ImportError( + "peft library is required. 
Install with: pip install peft" + ) + + logger.info(f"Loading LoRA adapter from: {adapter_path}") + + # Wrap Sequential models to make them compatible with PEFT + if isinstance(model, nn.Sequential): + logger.info("Wrapping Sequential model for PEFT compatibility") + model = SequentialWrapper(model) + + peft_model = PeftModel.from_pretrained( + model, + adapter_path, + is_trainable=is_trainable, + ) + + if is_trainable: + logger.info("Adapter loaded in trainable mode") + else: + logger.info("Adapter loaded in inference mode (frozen)") + + print_lora_parameters(peft_model) + + return peft_model + + +def save_lora_adapter( + model: nn.Module, + output_path: str, +): + """ + Save only the LoRA adapter parameters (not the full model). + + This saves only the trained LoRA weights (~5-20 MB) rather than + the entire model (~200-500 MB). + + Args: + model: PEFT model with LoRA adapters + output_path: Directory to save adapter + + Examples: + >>> save_lora_adapter( + ... lora_model, + ... "models/fly_organelles/v1.1.0/lora_adapter" + ... ) + """ + try: + from peft import PeftModel + except ImportError: + raise ImportError( + "peft library is required. Install with: pip install peft" + ) + + if not isinstance(model, PeftModel): + raise ValueError( + "Model must be a PeftModel. Use wrap_model_with_lora() first." + ) + + logger.info(f"Saving LoRA adapter to: {output_path}") + model.save_pretrained(output_path) + logger.info("Adapter saved successfully") + + +def merge_lora_into_base(model: nn.Module) -> nn.Module: + """ + Merge LoRA weights back into the base model. + + This creates a standalone model with LoRA weights merged in, + removing the need for PEFT at inference time. + + Warning: This increases model size back to the full model size. + Only use if you need a standalone model without PEFT dependency. 
+ + Args: + model: PEFT model with LoRA adapters + + Returns: + Base model with merged weights + + Examples: + >>> merged_model = merge_lora_into_base(lora_model) + >>> torch.save(merged_model.state_dict(), "merged_model.pt") + """ + try: + from peft import PeftModel + except ImportError: + raise ImportError( + "peft library is required. Install with: pip install peft" + ) + + if not isinstance(model, PeftModel): + raise ValueError( + "Model must be a PeftModel to merge adapters" + ) + + logger.info("Merging LoRA adapters into base model") + merged_model = model.merge_and_unload() + logger.info("Adapters merged successfully") + + return merged_model diff --git a/cellmap_flow/finetune/target_transforms.py b/cellmap_flow/finetune/target_transforms.py new file mode 100644 index 0000000..4749821 --- /dev/null +++ b/cellmap_flow/finetune/target_transforms.py @@ -0,0 +1,134 @@ +""" +Target transforms for converting user annotations to training targets. + +Each transform takes a raw annotation tensor (B, 1, Z, Y, X) with values: + 0 = unannotated (ignored in loss) + 1 = background + 2 = first foreground object + 3 = second foreground object, etc. + +And produces: + target: (B, C, Z, Y, X) — training target matching model output channels + mask: (B, C, Z, Y, X) or (B, 1, Z, Y, X) — valid loss mask +""" + +from typing import List, Tuple + +import torch +from torch import Tensor + + +class TargetTransform: + """Base class for target transforms.""" + + def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]: + """Convert annotation to (target, mask) pair.""" + raise NotImplementedError + + +class BinaryTargetTransform(TargetTransform): + """Standard binary segmentation transform (current default behavior). + + Produces single-channel binary target: bg=0, fg=1. + Mask marks annotated regions. 
+ """ + + def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]: + mask = (annotation > 0).float() + target = torch.clamp(annotation - 1, min=0) + target = (target > 0).float() + return target, mask + + +class BroadcastBinaryTargetTransform(TargetTransform): + """Binary target broadcast to N channels. + + All output channels receive the same fg/bg target. + Useful for treating multi-channel models (affinities, distances) + as simple binary segmentation. + """ + + def __init__(self, num_channels: int): + self.num_channels = num_channels + + def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]: + mask = (annotation > 0).float() + target = (torch.clamp(annotation - 1, min=0) > 0).float() + # expand is lazy (no memory copy), contiguous() ensures safe downstream use + target = target.expand(-1, self.num_channels, -1, -1, -1).contiguous() + mask = mask.expand(-1, self.num_channels, -1, -1, -1).contiguous() + return target, mask + + +class AffinityTargetTransform(TargetTransform): + """Compute affinity targets from instance labels. + + For each offset, affinity is: + 1 if both voxels belong to the same foreground object (same label > 1) + 0 if different objects, or either is background + + The loss mask requires both voxels in each pair to be annotated (label > 0), + producing a per-channel mask since each offset shifts differently. + + Args: + offsets: List of [dz, dy, dx] offset tuples defining neighbor relationships. + num_channels: Total number of model output channels. If greater than + len(offsets), extra channels (e.g. LSDs) are masked out + (mask=0) so they receive no gradient. If None, defaults + to len(offsets). 
+ """ + + def __init__(self, offsets: List[List[int]], num_channels: int = None): + self.offsets = offsets + self.num_channels = num_channels if num_channels is not None else len(offsets) + + def __call__(self, annotation: Tensor) -> Tuple[Tensor, Tensor]: + B, _C, Z, Y, X = annotation.shape + # Allocate for all output channels; non-affinity channels stay zero (masked out) + target = torch.zeros(B, self.num_channels, Z, Y, X, device=annotation.device) + mask = torch.zeros(B, self.num_channels, Z, Y, X, device=annotation.device) + + labels = annotation[:, 0] # (B, Z, Y, X) + annotated = labels > 0 # bool + + for i, offset in enumerate(self.offsets): + dz, dy, dx = offset + src_slices, dst_slices = _offset_slices(Z, Y, X, dz, dy, dx) + + src_labels = labels[(slice(None), *src_slices)] + dst_labels = labels[(slice(None), *dst_slices)] + src_ann = annotated[(slice(None), *src_slices)] + dst_ann = annotated[(slice(None), *dst_slices)] + + # Affinity = 1 iff same foreground object + same_fg = (src_labels == dst_labels) & (src_labels > 1) + both_annotated = src_ann & dst_ann + + target[(slice(None), i, *src_slices)] = same_fg.float() + mask[(slice(None), i, *src_slices)] = both_annotated.float() + + return target, mask + + +def _offset_slices(Z, Y, X, dz, dy, dx): + """Compute source and destination slices for an offset. + + For a volume of shape (Z, Y, X) and offset (dz, dy, dx), + returns slices such that: + volume[src_slices] and volume[dst_slices] + are aligned views offset by (dz, dy, dx). 
+ """ + + def _dim_slices(size, d): + if d > 0: + return slice(None, size - d), slice(d, None) + elif d < 0: + return slice(-d, None), slice(None, size + d) + else: + return slice(None), slice(None) + + sz, dz_s = _dim_slices(Z, dz) + sy, dy_s = _dim_slices(Y, dy) + sx, dx_s = _dim_slices(X, dx) + + return (sz, sy, sx), (dz_s, dy_s, dx_s) diff --git a/cellmap_flow/finetune/virtual_dataset.py b/cellmap_flow/finetune/virtual_dataset.py new file mode 100644 index 0000000..beb05b7 --- /dev/null +++ b/cellmap_flow/finetune/virtual_dataset.py @@ -0,0 +1,523 @@ +""" +On-the-fly random-patch dataset for finetuning. + +Architecture +------------ +There is exactly one source of truth per session: an +``annotation_volume.zarr`` (sparse, full-dataset extent, OME-NGFF) that +holds **every** annotation — painted scribbles plus any imported YAML +crops, all merged at their physical offsets. This dataset reads patches +straight out of that single volume zarr; no per-tile materialization, no +parallel source list to keep in sync. + +Sampling rule +------------- +Two-pool stratified sampling. FG voxels are partitioned by membership in +the volume's ``imported_crops`` bbox list (recorded in the volume zattrs +when YAML crops are imported): + - **dense pool**: voxels inside any imported_crops bbox (abundant GT) + - **sparse pool**: voxels outside all bboxes (painted scribbles, by + construction always sparse and informative — the user paints there + because the base model failed) + +Each ``__getitem__`` picks a pool by ``dense_to_sparse_ratio`` (default +0.5/0.5 when both pools exist; auto-degrades to 1.0 when only one +exists), samples a random FG voxel from that pool, jitters the patch +center, and reads raw + annotation patches around it. + +Without stratification, voxel-uniform sampling buries scribbles: a +typical session has ~40M dense voxels vs ~10K painted, so 999/1000 +patches would be dense and the corrections you painted barely move the +gradient. 
Stratification guarantees scribbles get a defined share of +each epoch regardless of voxel count. + +Index construction reads only **populated** chunks of the sparse zarr +(walks ``annotation/s0/`` for files matching ``z.y.x``). For an empty +volume that's an empty index; for a fully painted region it's the FG +voxels of those chunks. + +Reviewer notes +-------------- +- Workers each rebuild the FG index on spawn (cheap — only populated + chunks are read). We don't pickle any open zarr/tensorstore handles. +- ``len(self)`` is ``patches_per_epoch``; it has no relationship to the + number of populated chunks. The trainer treats this as the epoch length. +- The dataset returns ``(raw, annotation)`` tensors with shape + ``(1, Z, Y, X)`` matching :class:`CorrectionDataset`'s contract. +""" + +from __future__ import annotations + +import json +import logging +import os +import re +from typing import List, Optional, Tuple + +import numpy as np +import torch +import zarr +from torch.utils.data import Dataset + +logger = logging.getLogger(__name__) + +_CHUNK_KEY_RE = re.compile(r"^\d+\.\d+\.\d+$") + + +def _voxels_inside_any_bbox( + voxels: np.ndarray, bbox_offsets: np.ndarray, bbox_ends: np.ndarray +) -> np.ndarray: + """Return a boolean mask: ``True`` where ``voxels[i]`` lies inside any + ``[bbox_offsets[j], bbox_ends[j])`` half-open box. + + voxels: (N, 3) int. bbox_offsets, bbox_ends: (M, 3) int. Vectorized + over both: builds an (N, M) inside-test matrix and reduces along M. + For typical M ~ 1-10 the temporary stays small. 
+ """ + if voxels.shape[0] == 0 or bbox_offsets.shape[0] == 0: + return np.zeros(voxels.shape[0], dtype=bool) + # (N, M, 3) broadcast: voxels[:, None, :] vs bbox_offsets[None, :, :] + ge = np.all(voxels[:, None, :] >= bbox_offsets[None, :, :], axis=-1) + lt = np.all(voxels[:, None, :] < bbox_ends[None, :, :], axis=-1) + return np.any(ge & lt, axis=-1) + + +class VirtualPatchDataset(Dataset): + """Yield random raw+annotation patches anchored on FG voxels in a volume zarr. + + Args: + volume_zarr_path: path to the session's ``annotation_volume.zarr``. + raw_dataset_path: path to the raw EM zarr the volume is aligned to. + input_size_voxels: shape (Z, Y, X) of the raw patch returned per + sample, in voxels at ``input_voxel_size_nm``. + output_size_voxels: shape (Z, Y, X) of the annotation patch, in + voxels at ``output_voxel_size_nm``. + input_voxel_size_nm: voxel size for raw patches (the dataset's + closest scale to the model's claimed input voxel size). + output_voxel_size_nm: voxel size for annotation patches. + patches_per_epoch: ``len(self)``; controls how many random patches + comprise one epoch. ``None`` (the default) means "auto: + substitute the total populated-chunk count" — every populated + chunk gets ~one patch per epoch on average. + jitter_voxels: half-range of the random offset applied to the patch + center, in **annotation voxels**. Defaults to + ``output_size_voxels // 4``. + seed: RNG seed; per-worker offset added so multi-worker dataloaders + sample distinct streams. + dense_to_sparse_ratio: fraction in [0, 1] of patches drawn from + the dense pool (FG voxels inside any imported_crops bbox). + ``None`` (default) means auto: 0.5 if both pools have voxels, + else 1.0 (use the non-empty pool exclusively). 
+ """ + + def __init__( + self, + volume_zarr_path: str, + raw_dataset_path: str, + input_size_voxels: Tuple[int, int, int], + output_size_voxels: Tuple[int, int, int], + input_voxel_size_nm: Tuple[float, float, float], + output_voxel_size_nm: Tuple[float, float, float], + patches_per_epoch: Optional[int] = None, + jitter_voxels: Optional[Tuple[int, int, int]] = None, + seed: int = 0, + input_norm_config: Optional[dict] = None, + dense_to_sparse_ratio: Optional[float] = None, + ): + self.volume_zarr_path = volume_zarr_path + self.raw_dataset_path = raw_dataset_path + self.input_size = np.array(input_size_voxels, dtype=int) + self.output_size = np.array(output_size_voxels, dtype=int) + self.input_voxel_size = np.array(input_voxel_size_nm, dtype=float) + self.output_voxel_size = np.array(output_voxel_size_nm, dtype=float) + # Resolved to an int by _build_index() once the populated-chunk + # count is known (when patches_per_epoch was passed as None). + self.patches_per_epoch: Optional[int] = ( + int(patches_per_epoch) if patches_per_epoch is not None else None + ) + self.jitter = ( + np.array(jitter_voxels, dtype=int) + if jitter_voxels is not None + else (self.output_size // 4) + ) + self.seed = int(seed) + self.dense_to_sparse_ratio = ( + float(dense_to_sparse_ratio) + if dense_to_sparse_ratio is not None + else None + ) + self._effective_dense_ratio: float = 0.0 # set in _build_index + + # Input normalization to apply to every raw patch the dataset emits. + # The dashboard's inference path normalizes raw via ``g.input_norms`` + # before feeding the model; the trainer (a separate LSF process) + # has an empty ``g.input_norms``, so without this the trainer would + # train on raw uint8 while inference sees normalized [-1, 1]. + # ``input_norm_config`` is the JSON-serializable dict from the YAML + # (e.g. {"MinMaxNormalizer": {...}, "LambdaNormalizer": {...}}). 
+ self.input_norm_config: dict = dict(input_norm_config or {}) + self._input_normalizers = self._build_input_normalizers(self.input_norm_config) + if not self._input_normalizers and self.input_norm_config: + logger.warning( + "input_norm_config provided but produced no normalizers; " + "raw patches will be returned unnormalized." + ) + if self._input_normalizers: + logger.info( + f"VirtualPatchDataset: applying {len(self._input_normalizers)} " + f"input normalizer(s) per patch: " + f"{[type(n).__name__ for n in self._input_normalizers]}" + ) + else: + logger.warning( + "VirtualPatchDataset: no input normalizers configured. " + "Raw patches will be returned in their native dtype/range. " + "If inference normalizes to [-1, 1] (typical), the trained " + "model will see different inputs at train vs inference time." + ) + + self.dataset_offset_nm: np.ndarray = np.zeros(3) + self.volume_shape_voxels: np.ndarray = np.zeros(3, dtype=int) + # Two-pool stratified sampling: dense FG voxels live inside any + # imported_crops bbox; sparse FG voxels are everywhere else + # (painted scribbles, by construction). Either may be empty. + self._fg_index_dense: Optional[np.ndarray] = None + self._fg_index_sparse: Optional[np.ndarray] = None + self._volume_arr = None # opened lazily after worker fork + self._raw_idi = None # opened lazily after worker fork + # Cached per-worker RNG. None until first __getitem__ (after fork/spawn). + # Without this cache, every __getitem__ would reseed and re-pick the + # very first integer of the same stream — producing the same patch + # forever and silently breaking training. + self._cached_rng: Optional[np.random.Generator] = None + + self._build_index() + + # ------------------------------------------------------------------ + # Index construction + # ------------------------------------------------------------------ + + def _build_index(self) -> None: + """Walk the volume's populated chunks and build dense + sparse FG indices. 
+ + We use the on-disk file layout (zarr v2 stores one file per chunk + named ``z.y.x``) to enumerate just the chunks that have been + written. Empty regions of the sparse volume produce no files and + cost us nothing. FG voxels are then partitioned by membership in + the volume's ``imported_crops`` bbox list. + """ + s0_path = os.path.join(self.volume_zarr_path, "annotation", "s0") + if not os.path.isdir(s0_path): + raise ValueError( + f"Volume zarr at {self.volume_zarr_path} has no annotation/s0/ " + "directory; was it created?" + ) + + # Pull volume-level metadata once so we can map voxel coords to nm + # and classify FG voxels as dense (inside an imported crop) or + # sparse (outside). + with open(os.path.join(self.volume_zarr_path, ".zattrs")) as f: + root_attrs = json.load(f) + self.dataset_offset_nm = np.array( + root_attrs.get("dataset_offset_nm", [0, 0, 0]), dtype=float + ) + imported = root_attrs.get("imported_crops", []) or [] + # Bbox list as two stacked (M, 3) arrays for vectorized membership + # tests below. Empty when no YAML crops were imported. + if imported: + bbox_offsets = np.array( + [c["annotation_offset_voxels"] for c in imported], dtype=np.int64 + ) + bbox_shapes = np.array( + [c["annotation_shape_voxels"] for c in imported], dtype=np.int64 + ) + bbox_ends = bbox_offsets + bbox_shapes + else: + bbox_offsets = np.zeros((0, 3), dtype=np.int64) + bbox_ends = np.zeros((0, 3), dtype=np.int64) + + arr = zarr.open(s0_path, mode="r") + self.volume_shape_voxels = np.array(arr.shape, dtype=int) + chunk_shape = np.array(arr.chunks, dtype=int) + + chunk_keys = [ + name for name in os.listdir(s0_path) + if _CHUNK_KEY_RE.match(name) + ] + if not chunk_keys: + raise ValueError( + f"Volume zarr at {self.volume_zarr_path} has no populated chunks. " + "Paint annotations or import crops first." 
+ ) + + dense_rows: List[np.ndarray] = [] + sparse_rows: List[np.ndarray] = [] + n_fg_chunks = 0 # chunks that actually contributed FG voxels + for key in chunk_keys: + cz, cy, cx = (int(s) for s in key.split(".")) + chunk_origin = np.array([cz, cy, cx], dtype=np.int64) * chunk_shape + chunk_data = arr.blocks[cz, cy, cx] + fg_local = np.argwhere(chunk_data >= 2).astype(np.int64) + if not fg_local.size: + # On-disk file exists (zarr writes fill chunks during slab + # writes) but contributes no FG; skip and don't count it. + continue + n_fg_chunks += 1 + fg_global = fg_local + chunk_origin + if bbox_offsets.shape[0] == 0: + # No imported crops → everything is sparse (painted). + sparse_rows.append(fg_global) + continue + in_dense = _voxels_inside_any_bbox(fg_global, bbox_offsets, bbox_ends) + if in_dense.any(): + dense_rows.append(fg_global[in_dense]) + if (~in_dense).any(): + sparse_rows.append(fg_global[~in_dense]) + + self._fg_index_dense = ( + np.concatenate(dense_rows, axis=0) if dense_rows else np.zeros((0, 3), dtype=np.int64) + ) + self._fg_index_sparse = ( + np.concatenate(sparse_rows, axis=0) if sparse_rows else np.zeros((0, 3), dtype=np.int64) + ) + n_dense = int(self._fg_index_dense.shape[0]) + n_sparse = int(self._fg_index_sparse.shape[0]) + + if n_dense == 0 and n_sparse == 0: + raise ValueError( + f"Volume zarr at {self.volume_zarr_path} has populated chunks " + "but no foreground voxels (>=2). Did you only paint background?" + ) + + # Resolve dense ratio: explicit value wins, else auto-balance to + # 0.5 when both pools have voxels, else fall back to whichever + # pool is non-empty so we still draw patches. 
+ if self.dense_to_sparse_ratio is None: + if n_dense > 0 and n_sparse > 0: + self._effective_dense_ratio = 0.5 + elif n_dense > 0: + self._effective_dense_ratio = 1.0 + else: + self._effective_dense_ratio = 0.0 + else: + ratio = max(0.0, min(1.0, self.dense_to_sparse_ratio)) + # Clamp away from a pool that's empty so __getitem__ never + # tries to sample from an empty index. + if n_dense == 0: + self._effective_dense_ratio = 0.0 + elif n_sparse == 0: + self._effective_dense_ratio = 1.0 + else: + self._effective_dense_ratio = ratio + + # Default patches_per_epoch = number of FG-bearing chunks: each + # such chunk gets ~1 patch per epoch on average. Cheap "auto cover + # everything" mode the user can override via YAML or UI. We count + # FG-bearing chunks (not all chunk files) because zarr writes + # empty fill chunks during slab writes -- those don't represent + # annotation work and shouldn't inflate epoch length. + if self.patches_per_epoch is None: + self.patches_per_epoch = max(1, n_fg_chunks) + + logger.info( + f"VirtualPatchDataset: built FG index with {n_dense + n_sparse} voxels " + f"(dense={n_dense}, sparse={n_sparse}) from {n_fg_chunks} FG-bearing " + f"chunk(s) ({len(chunk_keys)} chunk files on disk) of " + f"{self.volume_zarr_path}; " + f"patches_per_epoch={self.patches_per_epoch}, " + f"dense_ratio={self._effective_dense_ratio:.3f} " + f"({'auto' if self.dense_to_sparse_ratio is None else 'explicit'}), " + f"jitter={self.jitter.tolist()}" + ) + + # ------------------------------------------------------------------ + # Dataset protocol + # ------------------------------------------------------------------ + + def __len__(self) -> int: + # Resolved by _build_index() in __init__. + return int(self.patches_per_epoch or 0) + + def __getitem__(self, _idx: int): + rng = self._worker_rng() + # Pick a pool by the resolved dense ratio. Both indices may exist; + # _build_index guarantees we never end up with the chosen pool empty. 
+ use_dense = ( + self._effective_dense_ratio >= 1.0 + or (self._effective_dense_ratio > 0.0 and rng.random() < self._effective_dense_ratio) + ) + pool = self._fg_index_dense if use_dense else self._fg_index_sparse + anchor_zyx = pool[rng.integers(0, pool.shape[0])].astype(np.float64) + + jitter_offset = rng.integers( + low=-self.jitter, high=self.jitter + 1, size=3 + ).astype(np.float64) + ann_center_voxels = anchor_zyx + jitter_offset + + # Convert annotation-space voxel center to physical (nm) for the raw read. + ann_center_nm = ( + self.dataset_offset_nm + ann_center_voxels * self.output_voxel_size + ) + + ann_patch = self._read_annotation_patch(ann_center_voxels) + raw_patch = self._read_raw_patch(ann_center_nm) + + raw_t = torch.from_numpy(raw_patch.astype(np.float32)[np.newaxis, ...]) + ann_t = torch.from_numpy(ann_patch.astype(np.float32)[np.newaxis, ...]) + return raw_t, ann_t + + # ------------------------------------------------------------------ + # Patch reads + # ------------------------------------------------------------------ + + def _open_volume(self): + if self._volume_arr is None: + self._volume_arr = zarr.open( + os.path.join(self.volume_zarr_path, "annotation", "s0"), mode="r" + ) + return self._volume_arr + + def _read_annotation_patch(self, center_voxels: np.ndarray) -> np.ndarray: + """Crop a patch from the volume's annotation array. + + Out-of-bounds voxels are filled with 0 (= unannotated → masked + out by the trainer's loss when ``mask_unannotated=True``). 
+ """ + out_size = self.output_size + lo = (center_voxels - out_size / 2).astype(int) + hi = lo + out_size + + clip_lo = np.maximum(lo, 0) + clip_hi = np.minimum(hi, self.volume_shape_voxels) + valid = np.all(clip_hi > clip_lo) + + patch = np.zeros(out_size, dtype=np.uint8) + if valid: + arr = self._open_volume() + src_slices = tuple(slice(int(c), int(d)) for c, d in zip(clip_lo, clip_hi)) + dst_slices = tuple( + slice(int(c - l), int(d - l)) + for c, d, l in zip(clip_lo, clip_hi, lo) + ) + patch[dst_slices] = arr[src_slices] + return patch + + def _read_raw_patch(self, center_nm: np.ndarray) -> np.ndarray: + """Read an ``input_size`` patch from the raw dataset, centered at ``center_nm``. + + The raw read uses ``normalize=False`` because the trainer process's + global ``g.input_norms`` is empty -- the dashboard's normalization + config doesn't propagate across the LSF process boundary. We apply + the dashboard's normalizers explicitly here from + ``self._input_normalizers``, which is built from the manifest at + construction time. + """ + from cellmap_flow.image_data_interface import ImageDataInterface + from funlib.geometry import Coordinate, Roi + + if self._raw_idi is None: + self._raw_idi = ImageDataInterface( + self.raw_dataset_path, + voxel_size=self.input_voxel_size, + normalize=False, + ) + idi = self._raw_idi + read_shape_nm = self.input_size * self.input_voxel_size + roi = Roi( + offset=Coordinate(center_nm - read_shape_nm / 2), + shape=Coordinate(read_shape_nm), + ) + patch = idi.to_ndarray_ts(roi) + + # Apply the dashboard's normalizers locally (no global state). + # Each normalizer is callable and returns an ndarray; the chain + # mirrors what apply_norms() does inside the dashboard process. 
+ for norm in self._input_normalizers: + patch = norm(patch) + return patch + + @staticmethod + def _build_input_normalizers(input_norm_config: dict) -> list: + """Materialize the dict-form ``input_norm`` config into normalizer objects.""" + if not input_norm_config: + return [] + try: + from cellmap_flow.norm.input_normalize import get_normalizations + + return get_normalizations(input_norm_config) + except Exception as e: + logger.error( + f"Failed to build input normalizers from config " + f"{input_norm_config!r}: {e}. Patches will be unnormalized." + ) + return [] + + # ------------------------------------------------------------------ + # RNG plumbing + # ------------------------------------------------------------------ + + def _worker_rng(self) -> np.random.Generator: + # Cache the Generator on self so consecutive __getitem__ calls draw + # from the *advancing* state of the same RNG. Reseeding every call + # made every patch identical (the first integer pulled from a freshly + # seeded generator is deterministic). 
+ if self._cached_rng is None: + worker_info = torch.utils.data.get_worker_info() + worker_id = 0 if worker_info is None else worker_info.id + self._cached_rng = np.random.default_rng( + self.seed + worker_id * 1_000_003 + ) + return self._cached_rng + + +# --------------------------------------------------------------------------- +# Manifest helpers +# --------------------------------------------------------------------------- + +VIRTUAL_MANIFEST_FILENAME = "_virtual_sources.json" + + +def write_manifest(corrections_dir: str, manifest: dict) -> str: + """Persist a manifest sentinel that ``create_dataloader`` looks for.""" + os.makedirs(corrections_dir, exist_ok=True) + path = os.path.join(corrections_dir, VIRTUAL_MANIFEST_FILENAME) + with open(path, "w") as f: + json.dump(manifest, f, indent=2) + return path + + +def read_manifest(corrections_dir: str) -> Optional[dict]: + """Return the manifest if present, else ``None``.""" + path = os.path.join(corrections_dir, VIRTUAL_MANIFEST_FILENAME) + if not os.path.exists(path): + return None + with open(path) as f: + return json.load(f) + + +def dataset_from_manifest(manifest: dict) -> VirtualPatchDataset: + """Instantiate a :class:`VirtualPatchDataset` from a manifest dict. + + Recognized manifest kinds: + - ``volume_zarr_v1`` (current): trainer reads the session's + annotation_volume.zarr directly. Field ``volume_zarr_path``. + """ + kind = manifest.get("kind") + if kind != "volume_zarr_v1": + raise ValueError( + f"Unsupported manifest kind: {kind!r}. Expected 'volume_zarr_v1'." + ) + return VirtualPatchDataset( + volume_zarr_path=manifest["volume_zarr_path"], + raw_dataset_path=manifest["raw_dataset_path"], + input_size_voxels=tuple(manifest["input_size_voxels"]), + output_size_voxels=tuple(manifest["output_size_voxels"]), + input_voxel_size_nm=tuple(manifest["input_voxel_size_nm"]), + output_voxel_size_nm=tuple(manifest["output_voxel_size_nm"]), + # None defaults to "cover all populated chunks" inside the dataset. 
+ patches_per_epoch=manifest.get("patches_per_epoch"), + jitter_voxels=tuple(manifest["jitter_voxels"]) if manifest.get("jitter_voxels") else None, + seed=manifest.get("seed", 0), + input_norm_config=manifest.get("input_norm") or None, + dense_to_sparse_ratio=manifest.get("dense_to_sparse_ratio"), + ) diff --git a/cellmap_flow/globals.py b/cellmap_flow/globals.py index bc80c8a..9809e9a 100644 --- a/cellmap_flow/globals.py +++ b/cellmap_flow/globals.py @@ -2,9 +2,12 @@ from cellmap_flow.post.postprocessors import DefaultPostprocessor, ThresholdPostprocessor import os +import queue import yaml +import logging import threading import numpy as np +from collections import deque import logging from typing import Any, Dict, List, Optional @@ -77,6 +80,18 @@ class Flow: shader_controls: dict _server_config_cached: bool + # Dashboard state (moved from cellmap_flow.dashboard.state) + log_buffer: deque + log_clients: list + NEUROGLANCER_URL: Optional[str] + INFERENCE_SERVER: Optional[Any] + CUSTOM_CODE_FOLDER: str + bbx_generator_state: dict + finetune_job_manager: Any + minio_state: dict + annotation_volumes: dict + output_sessions: dict + def __new__(cls): if cls._instance is None: cls._instance = super(Flow, cls).__new__(cls) @@ -85,7 +100,19 @@ def __new__(cls): cls._instance.servers = [] cls._instance.raw = None cls._instance.input_norms = input_norms + # Raw JSON-serializable form of the dashboard's input_norm config. + # Populated by /api/run from the request payload; used by the + # finetune submit/restart flow so the trainer process applies the + # same normalization the dashboard uses at inference. + # + # NOTE: prefer ``current_input_norm_config()`` over reading this + # directly. Some startup paths (e.g. yaml_cli.py at server boot) + # populate ``input_norms`` from a YAML's ``json_data.input_norm`` + # but never touch ``input_norm_config``. The helper falls back to + # reconstructing the dict from the live normalizer instances. 
+ cls._instance.input_norm_config = {} cls._instance.postprocess = postprocess + cls._instance.postprocess_config = {} cls._instance.viewer = None cls._instance.dataset_path = None cls._instance.model_catalog = {} @@ -129,8 +156,55 @@ def __new__(cls): cls._instance.shaders = {} # ShaderControls state: key = layer name, value = shaderControls dict cls._instance.shader_controls = {} + + # Dashboard state (moved from cellmap_flow.dashboard.state) + cls._instance.log_buffer = deque(maxlen=1000) + cls._instance.log_clients = [] + cls._instance.NEUROGLANCER_URL = None + cls._instance.INFERENCE_SERVER = None + cls._instance.CUSTOM_CODE_FOLDER = os.path.expanduser( + os.environ.get( + "CUSTOM_CODE_FOLDER", + "~/Desktop/cellmap/cellmap-flow/example/example_norm", + ) + ) + cls._instance.bbx_generator_state = { + "dataset_path": None, + "num_boxes": 0, + "bounding_boxes": [], + "viewer": None, + "viewer_process": None, + "viewer_url": None, + "viewer_state": None, + } + cls._instance.minio_state = { + "process": None, + "port": None, + "ip": None, + "bucket": "annotations", + "minio_root": None, + "output_base": None, + "last_sync": {}, + "chunk_sync_state": {}, + "sync_thread": None, + } + cls._instance.annotation_volumes = {} + cls._instance.output_sessions = {} + cls._instance._finetune_job_manager = None + return cls._instance + @property + def finetune_job_manager(self): + if self._finetune_job_manager is None: + from cellmap_flow.finetune.finetune_job_manager import FinetuneJobManager + self._finetune_job_manager = FinetuneJobManager() + return self._finetune_job_manager + + @finetune_job_manager.setter + def finetune_job_manager(self, value): + self._finetune_job_manager = value + def to_dict(self): return self.__dict__.items() @@ -231,3 +305,49 @@ def delete(cls): g = Flow() + + +# Custom handler to capture logs into Flow singleton +class LogHandler(logging.Handler): + def emit(self, record): + log_entry = self.format(record) + g.log_buffer.append(log_entry) + # 
Send to all connected clients + for client_queue in g.log_clients: + try: + client_queue.put_nowait(log_entry) + except queue.Full: + pass + + +def current_input_norm_config() -> dict: + """Return the dashboard's current input_norm as a JSON-serializable dict. + + Reads ``g.input_norm_config`` if populated; otherwise reconstructs the + dict from the live ``g.input_norms`` instances via their ``.to_dict()``. + The fallback matters because some startup paths (yaml_cli) populate + ``g.input_norms`` from the YAML at server boot but never touch + ``input_norm_config`` -- if the user submits training without first + hitting /api/run, the manifest would otherwise be written empty. + """ + cfg = getattr(g, "input_norm_config", None) or {} + if cfg: + return cfg + norms = getattr(g, "input_norms", None) or [] + derived = {} + for n in norms: + try: + d = n.to_dict() + name = d.pop("name", type(n).__name__) + derived[name] = d + except Exception: + continue + return derived + + +def get_blockwise_tasks_dir(): + tasks_dir = g.blockwise_tasks_dir or os.path.expanduser( + "~/.cellmap_flow/blockwise_tasks" + ) + os.makedirs(tasks_dir, exist_ok=True) + return tasks_dir diff --git a/cellmap_flow/image_data_interface.py b/cellmap_flow/image_data_interface.py index 60a79d8..29df67e 100644 --- a/cellmap_flow/image_data_interface.py +++ b/cellmap_flow/image_data_interface.py @@ -1,6 +1,7 @@ -import os import zarr from cellmap_flow.utils.ds import ( + _join_path, + _open_zarr, find_closest_scale, get_ds_info, open_ds_tensorstore, @@ -23,13 +24,14 @@ def __init__( concurrency_limit=1, normalize=True, ): + dataset_path = dataset_path.replace("\\ ", " ") if not dataset_path.startswith("precomputed://"): try: - ds = zarr.open(dataset_path, mode="r") + ds = _open_zarr(dataset_path, mode="r") if isinstance(ds, zarr.hierarchy.Group): scale, _, _ = find_closest_scale(dataset_path, voxel_size) logger.info(f"found scale {scale} for voxel size {voxel_size}") - dataset_path = 
os.path.join(dataset_path, scale) + dataset_path = _join_path(dataset_path, scale) logger.info(f"using dataset path {dataset_path}") except Exception as e: logger.warning(f"could not open dataset {dataset_path} to find scale: {e}") diff --git a/cellmap_flow/models/model_registry.py b/cellmap_flow/models/model_registry.py index 230fe0f..c967b62 100644 --- a/cellmap_flow/models/model_registry.py +++ b/cellmap_flow/models/model_registry.py @@ -12,6 +12,7 @@ BioModelConfig, CellMapModelConfig, HuggingFaceModelConfig, + FinetuneModelConfig, ) @@ -23,7 +24,7 @@ 'BioModelConfig': BioModelConfig, 'CellMapModelConfig': CellMapModelConfig, 'HuggingFaceModelConfig': HuggingFaceModelConfig, - + 'FinetuneModelConfig': FinetuneModelConfig, } HUGGING_FACE_ORGS_NAME = "cellmap" @@ -119,6 +120,8 @@ def get_all_model_configs() -> Dict[str, Dict[str, Any]]: param_info['input_type'] = 'file' elif 'channels' in param_name.lower() or 'voxel_size' in param_name.lower(): param_info['input_type'] = 'textarea' # for multi-line JSON + elif param_info.get('type') in ('dict',): + param_info['input_type'] = 'textarea' # for JSON dicts elif param_name in ('input_size', 'output_size', 'edge_length_to_process', 'iteration'): param_info['input_type'] = 'number' else: @@ -186,6 +189,12 @@ def instantiate_model_config(class_name: str, params: Dict[str, Any]) -> Any: if annotation.__origin__ == tuple: value = tuple(value) + # Handle dict types (e.g., base_model JSON) + elif annotation == dict: + if isinstance(value, str): + import json + value = json.loads(value) + # Handle numeric types elif annotation in (int, float): value = annotation(value) diff --git a/cellmap_flow/models/models_config.py b/cellmap_flow/models/models_config.py index 9ea93e8..8688e9b 100644 --- a/cellmap_flow/models/models_config.py +++ b/cellmap_flow/models/models_config.py @@ -193,14 +193,35 @@ def _get_config(self): from cellmap_flow.utils.load_py import load_safe_config config = load_safe_config(self.script_path) - if not 
hasattr(config, "read_shape"): + + # Derive read_shape/write_shape from input_size/output_size or vice versa + has_input_size = hasattr(config, "input_size") + has_output_size = hasattr(config, "output_size") + has_read_shape = hasattr(config, "read_shape") + has_write_shape = hasattr(config, "write_shape") + + if not has_read_shape and has_input_size: config.read_shape = Coordinate(config.input_size) * Coordinate( config.input_voxel_size ) - if not hasattr(config, "write_shape"): + if not has_write_shape and has_output_size: config.write_shape = Coordinate(config.output_size) * Coordinate( config.output_voxel_size ) + # Reverse: derive input_size/output_size from read_shape/write_shape + if not has_input_size and has_read_shape: + config.input_size = tuple( + int(s) + for s in Coordinate(config.read_shape) + / Coordinate(config.input_voxel_size) + ) + if not has_output_size and has_write_shape: + config.output_size = tuple( + int(s) + for s in Coordinate(config.write_shape) + / Coordinate(config.output_voxel_size) + ) + if not hasattr(config, "block_shape"): config.block_shape = np.array( tuple(config.output_size) + (config.output_channels,) @@ -336,6 +357,12 @@ def load_eval_model(self, num_channels, checkpoint_path): if checkpoint_path.endswith(".ts"): model_backbone = torch.jit.load(checkpoint_path, map_location=device) + elif checkpoint_path.endswith("model.pt"): + # Load full model directly (for trusted fly_organelles models) + model = torch.load(checkpoint_path, weights_only=False, map_location=device) + model.to(device) + model.eval() + return model else: from fly_organelles.model import StandardUnet @@ -711,6 +738,142 @@ def to_dict(self): result["scale"] = self.scale return result +class FinetuneModelConfig(ModelConfig): + """Configuration class for a LoRA-finetuned model. + + Wraps any base ModelConfig with a LoRA adapter applied on top. + The base model is loaded via its own ModelConfig, then the adapter + is applied using PEFT. 
+ """ + + cli_name = "finetune" + + def __init__( + self, + lora_adapter_path: str, + base_model: dict, + name: str = None, + scale=None, + ): + """ + Args: + lora_adapter_path: Path to the saved LoRA adapter directory. + base_model: Dict describing the base model (same format as a YAML + model entry, e.g. {"type": "fly", "checkpoint_path": "...", ...}). + name: Display name for this model. + scale: Optional scale override. + """ + super().__init__() + self.lora_adapter_path = lora_adapter_path + self.base_model_dict = base_model + self.name = name + self.scale = scale + self._base_model_config = None + + @property + def base_model_config(self): + """Lazily build the base ModelConfig from the stored dict.""" + if self._base_model_config is None: + from cellmap_flow.utils.config_utils import build_model_from_entry + + base_name = self.base_model_dict.get("name", "base_model") + self._base_model_config = build_model_from_entry( + self.base_model_dict, model_name=base_name + ) + return self._base_model_config + + @property + def command(self): + return f"finetune --lora-adapter-path {self.lora_adapter_path}" + + def _get_config(self): + from cellmap_flow.finetune.lora_wrapper import load_lora_adapter + + # Get the fully-populated config from the base model + base_cfg = self.base_model_config.config + + # Apply LoRA adapter to the base model + base_model = base_cfg.model + + # TorchScript models can't be used with LoRA. Use cellmap_model.train() + # to get a trainable nn.Module via torch.export unflatten. 
+ if isinstance(base_model, torch.jit.ScriptModule): + base_type = self.base_model_dict.get("type", "") + cellmap_model = None + if base_type == "huggingface": + repo = self.base_model_dict.get("repo") + revision = self.base_model_dict.get("revision") + if repo: + cellmap_model = get_huggingface_model(repo, revision) + elif base_type == "cellmap": + folder_path = self.base_model_dict.get("folder_path") + if folder_path: + cellmap_model = CellmapModel(folder_path=folder_path) + + if cellmap_model is not None: + trainable = cellmap_model.train() + if trainable is not None: + if type(trainable).__name__ == 'UnflattenedModule': + from cellmap_flow.finetune.lora_wrapper import BatchLoopWrapper + trainable = BatchLoopWrapper(trainable) + base_model = trainable + + device = next(base_model.parameters()).device + model = load_lora_adapter(base_model, self.lora_adapter_path, is_trainable=False) + model.to(device) + model.eval() + + # Replace the model in the config, keep everything else + config = Config() + config.model = model + config.input_voxel_size = base_cfg.input_voxel_size + config.output_voxel_size = base_cfg.output_voxel_size + config.read_shape = base_cfg.read_shape + config.write_shape = base_cfg.write_shape + config.output_channels = base_cfg.output_channels + config.block_shape = base_cfg.block_shape + + # Copy optional attributes from base config + for attr in ("channels", "axes_names", "chunk_output_axes", "output_dtype"): + if hasattr(base_cfg, attr): + setattr(config, attr, getattr(base_cfg, attr)) + + return config + + def to_dict(self): + """Export configuration for use with build_model_from_entry. + + Surfaces key base model fields at the top level so the pipeline + builder UI can display them alongside the finetune-specific fields. 
+ """ + result = { + "type": "finetune", + "lora_adapter_path": self.lora_adapter_path, + "base_model": self.base_model_dict, + } + if self.name is not None: + result["name"] = self.name + if self.scale is not None: + result["scale"] = self.scale + + # Surface base model fields for UI display + base = self.base_model_dict + for key in ( + "channels", + "checkpoint_path", + "input_voxel_size", + "output_voxel_size", + "input_size", + "output_size", + ): + if key in base and key not in result: + result[key] = base[key] + if "type" in base: + result["base_type"] = base["type"] + + return result + + class HuggingFaceModelConfig(ModelConfig): """Configuration class for a Hugging Face model.""" diff --git a/cellmap_flow/norm/input_normalize.py b/cellmap_flow/norm/input_normalize.py index 0ba9c66..6d23232 100644 --- a/cellmap_flow/norm/input_normalize.py +++ b/cellmap_flow/norm/input_normalize.py @@ -59,6 +59,15 @@ class InputNormalizer(SerializableInterface): pass +class ChannelSelector(InputNormalizer): + def __init__(self, channel=0): + self.channel = int(channel) + + def _process(self, data) -> np.ndarray: + # No-op: channel selection is applied at the TensorStore level + return data + + class Dilate(InputNormalizer): def __init__(self, size=1): self.size = int(size) @@ -162,10 +171,29 @@ def _process(self, data) -> np.ndarray: class LambdaNormalizer(InputNormalizer): def __init__(self, expression: str): self.expression = expression - self._lambda = eval(f"lambda x: {expression}") + # ``_lambda`` is a Python ``lambda`` and not picklable, which breaks + # multiprocessing workers (e.g. PyTorch DataLoader with spawn). Don't + # store it on the instance; build it lazily in ``_process`` so it + # lives only in the worker that needs it. ``__getstate__``/ + # ``__setstate__`` further guarantee any older pickled instances + # don't try to round-trip the lambda. 
+ + def _get_lambda(self): + if not hasattr(self, "_lambda") or self._lambda is None: + self._lambda = eval(f"lambda x: {self.expression}") + return self._lambda + + def __getstate__(self): + state = self.__dict__.copy() + state.pop("_lambda", None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._lambda = None # rebuilt lazily on first call def _process(self, data) -> np.ndarray: - return self._lambda(data.astype(np.float32)) + return self._get_lambda()(data.astype(np.float32)) @property def dtype(self): diff --git a/cellmap_flow/post/postprocessors.py b/cellmap_flow/post/postprocessors.py index 34889fc..ae8a5c0 100644 --- a/cellmap_flow/post/postprocessors.py +++ b/cellmap_flow/post/postprocessors.py @@ -26,6 +26,17 @@ def is_segmentation(self): return None +class SigmoidPostprocessor(PostProcessor): + """Apply sigmoid activation to convert logits to probabilities.""" + + def _process(self, data): + return 1.0 / (1.0 + np.exp(-data.astype(np.float32))) + + @property + def dtype(self): + return np.float32 + + class DefaultPostprocessor(PostProcessor): def __init__( self, diff --git a/cellmap_flow/server.py b/cellmap_flow/server.py index 56204d4..eed7d35 100644 --- a/cellmap_flow/server.py +++ b/cellmap_flow/server.py @@ -3,7 +3,7 @@ from http import HTTPStatus import numpy as np import numcodecs -from flask import Flask, jsonify, redirect +from flask import Flask, jsonify, redirect, request from flask_cors import CORS from flasgger import Swagger from funlib.geometry import Roi @@ -33,7 +33,7 @@ class CellMapFlowServer: All routes are defined via Flask decorators for convenience. """ - def __init__(self, dataset_name: str, model_config: ModelConfig): + def __init__(self, dataset_name: str, model_config: ModelConfig, restart_callback=None): """ Initialize the server and set up routes via decorators. 
""" @@ -47,6 +47,7 @@ def __init__(self, dataset_name: str, model_config: ModelConfig): self.model_output_axes = model_config.chunk_output_axes self.inferencer = Inferencer(model_config) + self.restart_callback = restart_callback # Load or initialize your dataset self.idi_raw = ImageDataInterface( @@ -100,6 +101,20 @@ def __init__(self, dataset_name: str, model_config: ModelConfig): def home(): return redirect("/apidocs/") + @self.app.route("/__control__/restart", methods=["POST"]) + def control_restart(): + if self.restart_callback is None: + return jsonify({"success": False, "error": "Restart control not enabled"}), HTTPStatus.NOT_IMPLEMENTED + try: + payload = request.get_json(silent=True) or {} + accepted = self.restart_callback(payload) + if not accepted: + return jsonify({"success": False, "error": "Restart request rejected"}), HTTPStatus.CONFLICT + return jsonify({"success": True}), HTTPStatus.OK + except Exception as e: + logger.error(f"Failed to process restart control request: {e}", exc_info=True) + return jsonify({"success": False, "error": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR + @self.app.route("//.zattrs", methods=["GET"]) def top_level_attributes(dataset): self.refresh_dataset(dataset) diff --git a/cellmap_flow/utils/bsub_utils.py b/cellmap_flow/utils/bsub_utils.py index 796e3ee..ed0de42 100644 --- a/cellmap_flow/utils/bsub_utils.py +++ b/cellmap_flow/utils/bsub_utils.py @@ -121,13 +121,13 @@ def get_status(self) -> JobStatus: else: return JobStatus.FAILED - def wait_for_host(self, timeout: int = 60) -> Optional[str]: + def wait_for_host(self, timeout: int = 180) -> Optional[str]: """ Monitor process output for host information. 
- + Args: - timeout: Maximum time to wait in seconds - + timeout: Maximum time to wait in seconds (default 180s for model loading) + Returns: Host URL if found, None otherwise """ @@ -464,18 +464,19 @@ def run_locally(command: str, name: str) -> LocalJob: LocalJob object with process information """ logger.info(f"Running locally: {command}") - + try: process = subprocess.Popen( - command.split(), + command, + shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True ) - + local_job = LocalJob(process=process, model_name=name) return local_job - + except Exception as e: logger.error(f"Error starting local process: {e}") raise diff --git a/cellmap_flow/utils/config_utils.py b/cellmap_flow/utils/config_utils.py index 1821458..f3f5bba 100644 --- a/cellmap_flow/utils/config_utils.py +++ b/cellmap_flow/utils/config_utils.py @@ -65,21 +65,11 @@ def load_config(path: str) -> Dict[str, Any]: logger.warning(f"Missing 'queue' in YAML, using: {fallback}") config["queue"] = fallback - # Models must be present and non-empty (can be dict or list for backward compatibility) - if "models" not in config: - logger.error("YAML must contain 'models' field") - sys.exit(1) - - if isinstance(config["models"], dict): - if not config["models"]: - logger.error("YAML 'models' dict is empty") - sys.exit(1) - elif isinstance(config["models"], list): - if not config["models"]: - logger.error("YAML 'models' list is empty") - sys.exit(1) - # logger.warning("Using deprecated list format for models. 
Consider using dict format with model names as keys.") - else: + # Models field: must be a dict, list, or empty/missing (for dashboard-only mode) + if "models" not in config or config["models"] is None: + config["models"] = {} + + if not isinstance(config["models"], (dict, list)): logger.error("YAML 'models' must be either a dict or list") sys.exit(1) diff --git a/cellmap_flow/utils/ds.py b/cellmap_flow/utils/ds.py index f358332..0ea25bc 100644 --- a/cellmap_flow/utils/ds.py +++ b/cellmap_flow/utils/ds.py @@ -60,13 +60,36 @@ def generate_singlescale_metadata( def get_scale_info(zarr_grp): attrs = zarr_grp.attrs + ms = attrs["multiscales"][0] + + # Determine which axes are spatial so we can skip channel axes + axes = ms.get("axes", []) + spatial_indices = [i for i, a in enumerate(axes) if a.get("type") == "space"] + # If no axes metadata, assume all dimensions are spatial + if not spatial_indices: + spatial_indices = None + resolutions = {} offsets = {} shapes = {} - for scale in attrs["multiscales"][0]["datasets"]: - resolutions[scale["path"]] = scale["coordinateTransformations"][0]["scale"] - offsets[scale["path"]] = scale["coordinateTransformations"][1]["translation"] - shapes[scale["path"]] = zarr_grp[scale["path"]].shape + for scale in ms["datasets"]: + transforms = scale["coordinateTransformations"] + full_res = transforms[0]["scale"] + # Translation is optional (e.g. 
s0 often has only scale) + full_translation = next( + (t["translation"] for t in transforms if t["type"] == "translation"), + [0.0] * len(full_res), + ) + full_shape = zarr_grp[scale["path"]].shape + + if spatial_indices is not None: + resolutions[scale["path"]] = [full_res[i] for i in spatial_indices] + offsets[scale["path"]] = [full_translation[i] for i in spatial_indices] + shapes[scale["path"]] = tuple(full_shape[i] for i in spatial_indices) + else: + resolutions[scale["path"]] = full_res + offsets[scale["path"]] = full_translation + shapes[scale["path"]] = full_shape return offsets, resolutions, shapes @@ -76,18 +99,18 @@ def get_array_path_if_needed(zarr_grp_path, target_resolution): # If successful, it's a dataset path return zarr_grp_path except Exception as e: - if ".zarr" not in zarr_grp_path: + if ".zarr" not in zarr_grp_path and not _is_zarr_container(zarr_grp_path): raise RuntimeError( f"Failed to open dataset at {zarr_grp_path}: {e}\n Multiscale is only supported for zarr groups. Please provide a valid dataset path." 
) # Otherwise, it's a group path; find the appropriate scale target_scale, _, _ = find_target_scale(zarr_grp_path, target_resolution) - return os.path.join(zarr_grp_path, target_scale) + return _join_path(zarr_grp_path, target_scale) def find_target_scale(zarr_grp_path, target_resolution): try: - zarr_grp = zarr.open(zarr_grp_path, mode="r") + zarr_grp = _open_zarr(zarr_grp_path, mode="r") except Exception as e: raise RuntimeError(f"Failed to open zarr group at {zarr_grp_path}: {e}") offsets, resolutions, shapes = get_scale_info(zarr_grp) @@ -103,7 +126,7 @@ def find_target_scale(zarr_grp_path, target_resolution): def find_closest_scale(zarr_grp_path, target_resolution): - zarr_grp = zarr.open(zarr_grp_path, mode="r") + zarr_grp = _open_zarr(zarr_grp_path, mode="r") offsets, resolutions, shapes = get_scale_info(zarr_grp) target_scale = None last_scale = None @@ -137,6 +160,53 @@ def ends_with_scale(string): return bool(re.search(pattern, string)) +def _normalize_path(path: str) -> str: + """Remove shell-escape backslashes from a filesystem path. + + Users often copy-paste paths from a terminal where spaces are escaped + (e.g. ``/path/to/file\\ name.zarr``). YAML preserves the literal + backslashes, but the filesystem expects plain spaces. 
+ """ + if _is_remote_path(path): + return path + return path.replace("\\ ", " ") + + +def _is_remote_path(path: str) -> bool: + return path.startswith("http://") or path.startswith("https://") + + +def _open_zarr(path, mode="r"): + """Open a zarr dataset, handling HTTP/HTTPS URLs via fsspec.""" + path = _normalize_path(path) + if _is_remote_path(path): + import fsspec + + return zarr.open(fsspec.get_mapper(path), mode=mode) + return zarr.open(path, mode=mode) + + +def _join_path(base, *parts): + """Join path components, handling URLs correctly.""" + if _is_remote_path(base): + return "/".join([base.rstrip("/"), *parts]) + return os.path.join(base, *parts) + + +def _is_zarr_container(path: str) -> bool: + """Check if a local path is a zarr container by looking for zarr metadata files. + + Works for zarr directories that don't have a .zarr extension. + """ + if _is_remote_path(path): + return False + return os.path.isdir(path) and ( + os.path.exists(os.path.join(path, ".zgroup")) + or os.path.exists(os.path.join(path, ".zarray")) + or os.path.exists(os.path.join(path, ".zattrs")) + ) + + def split_dataset_path(dataset_path, scale=None) -> tuple[str, str]: """Split the dataset path into the filename and dataset @@ -148,19 +218,59 @@ def split_dataset_path(dataset_path, scale=None) -> tuple[str, str]: Tuple of filename and dataset """ - # split at .zarr or .n5, whichever comes last - splitter = ( - ".zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else ".n5" - ) + has_zarr = ".zarr" in dataset_path + has_n5 = ".n5" in dataset_path - filename, dataset = dataset_path.split(splitter) - if dataset.startswith("/"): - dataset = dataset[1:] - # include scale if present - if scale is not None: - dataset += f"/s{scale}" + if has_zarr or has_n5: + # split at .zarr or .n5, whichever comes last + splitter = ( + ".zarr" + if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") + else ".n5" + ) + + filename, dataset = dataset_path.rsplit(splitter, 1) + if 
dataset.startswith("/"): + dataset = dataset[1:] + # include scale if present + if scale is not None: + dataset += f"/s{scale}" - return filename + splitter, dataset + return filename + splitter, dataset + + # No .zarr or .n5 extension — walk up the path to find a zarr group container. + if _is_remote_path(dataset_path): + raise RuntimeError( + f"Remote URL must contain .zarr or .n5 in the path: {dataset_path}" + ) + # Prefer .zgroup (container root) over .zarray (leaf dataset). + path = os.path.normpath(dataset_path) + parts = [] + fallback = None # track first .zarray-only match as fallback + while path and path != os.path.dirname(path): + if os.path.isdir(path): + if os.path.exists(os.path.join(path, ".zgroup")): + dataset = "/".join(reversed(parts)) + if scale is not None: + dataset = f"{dataset}/s{scale}" if dataset else f"s{scale}" + return path, dataset + if fallback is None and os.path.exists( + os.path.join(path, ".zarray") + ): + fallback = (path, list(parts)) + path, part = os.path.split(path) + parts.append(part) + + if fallback is not None: + fb_path, fb_parts = fallback + dataset = "/".join(reversed(fb_parts)) + if scale is not None: + dataset = f"{dataset}/s{scale}" if dataset else f"s{scale}" + return fb_path, dataset + + raise RuntimeError( + f"Could not find a zarr or n5 container in path: {dataset_path}" + ) def apply_norms(data): @@ -190,13 +300,30 @@ def __getattr__(self, attr): return at +def _detect_filetype(dataset_path: str) -> str: + """Detect whether a dataset path is zarr or n5.""" + if ".zarr" in dataset_path or ".n5" in dataset_path: + return ( + "zarr" + if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") + else "n5" + ) + # No extension — check filesystem for zarr metadata + normalized = os.path.normpath(dataset_path) + path = normalized + while path and path != os.path.dirname(path): + if _is_zarr_container(path): + return "zarr" + path = os.path.dirname(path) + # Default to zarr + return "zarr" + + def open_ds_tensorstore( 
dataset_path: str, mode="r", concurrency_limit=None, normalize=True ): # open with zarr or n5 depending on extension - filetype = ( - "zarr" if dataset_path.rfind(".zarr") > dataset_path.rfind(".n5") else "n5" - ) + filetype = _detect_filetype(dataset_path) extra_args = {} if dataset_path.startswith("precomputed://"): @@ -213,12 +340,11 @@ def open_ds_tensorstore( "path": os.path.normpath(raw_path), } extra_args = {"scale_index": scale_index} - elif dataset_path.startswith("http://"): - path = dataset_path.split("http://")[1] + elif dataset_path.startswith("http://") or dataset_path.startswith("https://"): kvstore = { "driver": "http", - "base_url": "http://", - "path": path, + "base_url": dataset_path.rstrip("/"), + "path": "", } elif dataset_path.startswith("s3://"): kvstore = { @@ -263,11 +389,36 @@ def open_ds_tensorstore( else: dataset_future = ts.open(spec, read=False, write=True) - if dataset_path.startswith("gs://") or dataset_path.startswith("precomputed://"): - # NOTE: Currently a hack since google store / precomputed is stored as multichannel - ts_dataset = dataset_future.result()[ts.d["channel"][0]] - else: + try: ts_dataset = dataset_future.result() + if ts_dataset.ndim > 3: + from cellmap_flow.norm.input_normalize import ChannelSelector + + channel = 0 + for norm in g.input_norms: + if isinstance(norm, ChannelSelector): + channel = norm.channel + break + ts_dataset = ts_dataset[channel] + except ValueError as e: + if "extra members" in str(e) and filetype == "zarr": + # Some zarr files have extra fields (e.g. "checksum") in the + # compressor metadata that tensorstore doesn't recognize. + # Fix by providing the metadata explicitly without the extra fields. 
+ import json + zarray_path = os.path.join(os.path.normpath(dataset_path), ".zarray") + with open(zarray_path) as f: + zarray = json.load(f) + if "compressor" in zarray and isinstance(zarray["compressor"], dict): + zarray["compressor"].pop("checksum", None) + spec["metadata"] = zarray + if mode == "r": + dataset_future = ts.open(spec, read=True, write=False, assume_metadata=True) + else: + dataset_future = ts.open(spec, read=False, write=True, assume_metadata=True) + ts_dataset = dataset_future.result() + else: + raise # return ts_dataset if normalize: @@ -416,6 +567,18 @@ def separate_store_path(store, path): new_store, path_prefix = os.path.split(store) if ".zarr" in path_prefix or ".n5" in path_prefix: return store, path + # For extensionless zarr containers, check for zarr metadata on disk. + # Strip file:// protocol prefix for filesystem check. + local_path = store + if local_path.startswith("file://"): + local_path = local_path[len("file://"):] + if os.path.exists(os.path.join(local_path, ".zgroup")) or os.path.exists( + os.path.join(local_path, ".zarray") + ): + return store, path + if new_store == store: + # Reached the root without finding a container + raise RuntimeError(f"Could not find zarr/n5 container in path: {store}") return separate_store_path(new_store, os.path.join(path_prefix, path)) @@ -764,7 +927,7 @@ def get_ds_info(path: str, mode: str = "r"): filename: - The name of the container "file" (which is a directory for Zarr and + The name of the container "file" (which is a directory for zarr and N5). ds_name: @@ -776,6 +939,7 @@ def get_ds_info(path: str, mode: str = "r"): A :class:`Array` pointing to the dataset. 
""" + path = _normalize_path(path) axes_names = ["x", "y", "z"] if path.startswith("s3://"): ts_info = open_ds_tensorstore(path) @@ -829,15 +993,151 @@ def get_ds_info(path: str, mode: str = "r"): roi = Roi([0] * len(shape), Coordinate(shape) * voxel_size) return voxel_size, chunk_shape, shape, roi, axes_names, "precomputed" + if _is_remote_path(path): + ds = _open_zarr(path, mode="r") + + # If the URL points to a zarr Group (e.g. multiscale container), + # read OME-Zarr multiscales metadata and navigate into the first array. + if isinstance(ds, zarr.hierarchy.Group): + multiscales = ds.attrs.get("multiscales", None) + if multiscales: + ms = multiscales[0] + first_dataset = ms["datasets"][0] + ds = ds[first_dataset["path"]] + + # Extract spatial axes info (skip channel axes) + axes = ms.get("axes", []) + spatial_indices = [ + i for i, a in enumerate(axes) if a.get("type") == "space" + ] + axes_names = [axes[i]["name"] for i in spatial_indices] + + scale_transform = first_dataset["coordinateTransformations"][0]["scale"] + voxel_size = Coordinate(scale_transform[i] for i in spatial_indices) + + translation = next( + ( + t["translation"] + for t in first_dataset["coordinateTransformations"] + if t["type"] == "translation" + ), + [0.0] * len(scale_transform), + ) + offset = Coordinate(translation[i] for i in spatial_indices) + + shape = Coordinate(ds.shape[i] for i in spatial_indices) + chunk_shape = tuple(ds.chunks[i] for i in spatial_indices) + roi = Roi(offset, voxel_size * shape) + return voxel_size, chunk_shape, shape, roi, axes_names, "zarr" + else: + for key in sorted(ds.keys()): + if isinstance(ds[key], zarr.core.Array): + ds = ds[key] + break + + # The path points to a sub-array (e.g. .zarr/raw/s0). Try to read + # multiscale metadata from the parent zarr Group. 
+ if ".zarr" in path or ".n5" in path: + container, sub_path = split_dataset_path(path) + if sub_path: + try: + parent = _open_zarr(container, mode="r") + multiscales = parent.attrs.get("multiscales", None) + if multiscales: + ms = multiscales[0] + axes = ms.get("axes", []) + spatial_indices = [ + i + for i, a in enumerate(axes) + if a.get("type") == "space" + ] + if not spatial_indices: + spatial_indices = list(range(len(ds.shape))) + + # Find the matching dataset entry + dataset_entry = next( + ( + d + for d in ms["datasets"] + if d["path"] == sub_path + ), + ms["datasets"][0], + ) + scale_transform = dataset_entry[ + "coordinateTransformations" + ][0]["scale"] + voxel_size = Coordinate( + scale_transform[i] for i in spatial_indices + ) + translation = next( + ( + t["translation"] + for t in dataset_entry[ + "coordinateTransformations" + ] + if t["type"] == "translation" + ), + [0.0] * len(scale_transform), + ) + offset = Coordinate( + translation[i] for i in spatial_indices + ) + axes_names = [axes[i]["name"] for i in spatial_indices] if axes else ["z", "y", "x"] + shape = Coordinate( + ds.shape[i] for i in spatial_indices + ) + chunk_shape = tuple( + ds.chunks[i] for i in spatial_indices + ) + roi = Roi(offset, voxel_size * shape) + return ( + voxel_size, + chunk_shape, + shape, + roi, + axes_names, + "zarr", + ) + except Exception as e: + logger.warning( + "failed to read parent multiscale metadata for %s: %s" + % (path, e) + ) + + # Fallback for remote arrays without multiscales metadata + try: + order = ds.attrs["order"] + except KeyError: + try: + order = ds.order + except Exception: + logger.error("no order attribute found, set default C") + order = "C" + try: + voxel_size, offset = _read_voxel_size_offset(ds, order) + except Exception: + logger.error( + "failed to read voxel size and offset for %s, Will use default values" + % path + ) + voxel_size = Coordinate((1,) * 3) + offset = Coordinate((0,) * 3) + shape = Coordinate(ds.shape[-len(voxel_size) :]) + 
roi = Roi(offset, voxel_size * shape) + chunk_shape = ds.chunks + return voxel_size, chunk_shape, shape, roi, ["z", "y", "x"], "zarr" + filename, ds_name = split_dataset_path(path) - if filename.endswith(".zarr") or filename.endswith(".zip"): + if filename.endswith(".zarr") or filename.endswith(".zip") or _is_zarr_container(filename): assert ( not filename.endswith(".zip") or mode == "r" ), "Only reading supported for zarr ZipStore" logger.debug("opening zarr dataset %s in %s", ds_name, filename) try: - ds = zarr.open(filename, mode=mode)[ds_name] + ds = zarr.open(filename, mode=mode) + if ds_name: + ds = ds[ds_name] except Exception as e: logger.error("failed to open %s/%s" % (filename, ds_name)) raise e diff --git a/cellmap_flow/utils/load_py.py b/cellmap_flow/utils/load_py.py index deb8179..bea1842 100644 --- a/cellmap_flow/utils/load_py.py +++ b/cellmap_flow/utils/load_py.py @@ -42,12 +42,9 @@ def analyze_script(filepath): # If function is a direct name (e.g., `eval()`) if isinstance(node.func, ast.Name) and node.func.id in DISALLOWED_FUNCTIONS: issues.append(f"Disallowed function call detected: {node.func.id}") - # If function is an attribute call (e.g., `os.system()`) - elif ( - isinstance(node.func, ast.Attribute) - and node.func.attr in DISALLOWED_FUNCTIONS - ): - issues.append(f"Disallowed function call detected: {node.func.attr}") + # Note: We intentionally do NOT flag method calls like `model.eval()` here + # Method calls on objects (e.g., model.eval()) are safe - only direct calls + # to dangerous builtin functions (e.g., eval()) are a security risk # Return whether the script is safe (no issues found) and the list of issues is_safe = len(issues) == 0 @@ -96,6 +93,7 @@ def visit_Name(self, node): # Convert the modified AST back to source code code = ast.unparse(tree) + exec(code, config_namespace) # Extract the config object from the namespace config = Config(**config_namespace) diff --git a/cellmap_flow/utils/neuroglancer_utils.py 
b/cellmap_flow/utils/neuroglancer_utils.py index c688b3a..ac5c2ed 100644 --- a/cellmap_flow/utils/neuroglancer_utils.py +++ b/cellmap_flow/utils/neuroglancer_utils.py @@ -4,6 +4,7 @@ from cellmap_flow.dashboard.app import create_and_run_app from cellmap_flow.utils.scale_pyramid import get_raw_layer +from cellmap_flow.utils.ds import find_closest_scale, get_scale_info, _open_zarr from cellmap_flow.globals import g from cellmap_flow.utils.web_utils import ( @@ -17,11 +18,65 @@ neuroglancer.set_server_bind_address("0.0.0.0") +def get_raw_closest_scale(dataset_path, target_resolution): + """Return the raw multiscale scale (as a tuple of nm) closest to the + model's target resolution, or None if it can't be determined.""" + try: + zarr_grp = _open_zarr(dataset_path, mode="r") + _, resolutions, _ = get_scale_info(zarr_grp) + target_scale, _, _ = find_closest_scale(dataset_path, target_resolution) + return tuple(resolutions[target_scale]) + except Exception as e: + logger.warning( + f"Could not determine closest raw scale for {dataset_path} at " + f"target_resolution={target_resolution}: {e}" + ) + return None + + +def build_prediction_source(host, model, st_data, override_scales): + """Build a source spec for the prediction zarr that overrides the + source dimensions' scales so the layer overlays the raw at its native + resolution (e.g. claim a 16nm model output is actually at 12nm). + + The prediction zarr is 4D (z, y, x, c). We override the spatial scales + and leave the channel dim as a unitless dimension. + """ + url = f"zarr://{host}/{model}{ARGS_KEY}{st_data}{ARGS_KEY}" + if override_scales is None: + return url + sx, sy, sz = override_scales[0], override_scales[1], override_scales[2] + # Use a dict form so we can supply matching input/output dimensions + # of the same rank (neuroglancer requires equal rank on both sides). 
+ return { + "url": url, + "transform": { + "outputDimensions": { + "z": [sz * 1e-9, "m"], + "y": [sy * 1e-9, "m"], + "x": [sx * 1e-9, "m"], + "c^": [1, ""], + }, + "inputDimensions": { + "z": [sz * 1e-9, "m"], + "y": [sy * 1e-9, "m"], + "x": [sx * 1e-9, "m"], + "c^": [1, ""], + }, + }, + } + + def generate_neuroglancer_url(dataset_path,wrap_raw=True): g.viewer = neuroglancer.Viewer() g.dataset_path = dataset_path st_data = get_norms_post_args(g.input_norms, g.postprocess) + # Map model name -> ModelConfig for voxel-size lookups + model_configs_by_name = {} + for mc in getattr(g, "models_config", []) or []: + model_configs_by_name[mc.name] = mc + # Add a layer to the viewer with g.viewer.txn() as s: g.raw = get_raw_layer(dataset_path, wrap_raw=wrap_raw) @@ -47,8 +102,29 @@ def generate_neuroglancer_url(dataset_path,wrap_raw=True): shader = g.shaders.get(model, default_shader) if model not in g.shaders: g.shaders[model] = default_shader + + # Lie about the prediction's voxel size so it overlays the raw + # at the closest available scale (model trained at 16nm but raw + # is multiscale 6/12/24/...; we tell neuroglancer "treat the + # output as 12nm" so it lines up). 
+ override_scales = None + mc = model_configs_by_name.get(model) + if mc is not None: + try: + output_voxel_size = tuple(mc.config.output_voxel_size) + closest = get_raw_closest_scale(dataset_path, output_voxel_size) + if closest is not None and tuple(closest) != output_voxel_size: + override_scales = closest + logger.info( + f"Model '{model}' output_voxel_size={output_voxel_size} " + f"overridden to closest raw scale {closest} for viewer overlay" + ) + except Exception as e: + logger.warning(f"Could not compute override scales for '{model}': {e}") + + source = build_prediction_source(host, model, st_data, override_scales) layer_kwargs = { - "source": f"zarr://{host}/{model}{ARGS_KEY}{st_data}{ARGS_KEY}", + "source": source, "shader": shader, } shader_controls = g.shader_controls.get(model) diff --git a/cellmap_flow/utils/scale_pyramid.py b/cellmap_flow/utils/scale_pyramid.py index 9d403d7..c5f53fc 100644 --- a/cellmap_flow/utils/scale_pyramid.py +++ b/cellmap_flow/utils/scale_pyramid.py @@ -9,14 +9,22 @@ import zarr from cellmap_flow.image_data_interface import ImageDataInterface -from cellmap_flow.utils.ds import check_for_multiscale, get_ds_info +from cellmap_flow.utils.ds import ( + _is_remote_path, + _is_zarr_container, + _join_path, + _open_zarr, + check_for_multiscale, + get_ds_info, +) logger = logging.getLogger(__name__) def get_raw_layer(dataset_path, normalize=True, wrap_raw=True): + dataset_path = dataset_path.replace("\\ ", " ") + original_dataset_path = dataset_path is_precomputed = dataset_path.startswith("precomputed://") - # if multiscale dataset if is_precomputed: # precomputed format handles scales internally via tensorstore @@ -29,14 +37,14 @@ def get_raw_layer(dataset_path, normalize=True, wrap_raw=True): is_multiscale = True else: try: - is_multiscale = check_for_multiscale(zarr.open(dataset_path, mode="r"))[0] + is_multiscale = check_for_multiscale(_open_zarr(dataset_path, mode="r"))[0] except Exception as e: logger.error(e) is_multiscale = 
False if is_precomputed: filetype = "precomputed" - elif ".zarr" in dataset_path: + elif ".zarr" in dataset_path or _is_zarr_container(dataset_path): filetype = "zarr" elif ".n5" in dataset_path: filetype = "n5" @@ -58,13 +66,24 @@ def get_raw_layer(dataset_path, normalize=True, wrap_raw=True): if is_multiscale: try: - scales = [ - f for f in os.listdir(dataset_path) if f[0] == "s" and f[1:].isdigit() - ] - scales.sort(key=lambda x: int(x[1:])) + if _is_remote_path(dataset_path): + grp = _open_zarr(dataset_path, mode="r") + multiscales = grp.attrs.get("multiscales", None) + if multiscales: + scales = [d["path"] for d in multiscales[0]["datasets"]] + else: + scales = sorted( + [k for k in grp.keys() if k.startswith("s") and k[1:].isdigit()], + key=lambda x: int(x[1:]), + ) + else: + scales = [ + f for f in os.listdir(dataset_path) if f[0] == "s" and f[1:].isdigit() + ] + scales.sort(key=lambda x: int(x[1:])) for scale in scales: image = ImageDataInterface( - f"{os.path.join(dataset_path, scale)}", normalize=normalize + _join_path(dataset_path, scale), normalize=normalize ) # Use axes from the actual dataset - neuroglancer will use them as-is layers.append( @@ -87,7 +106,7 @@ def get_raw_layer(dataset_path, normalize=True, wrap_raw=True): is_multiscale = False if not is_multiscale: - image = ImageDataInterface(dataset_path) + image = ImageDataInterface(original_dataset_path) return neuroglancer.ImageLayer( source=neuroglancer.LocalVolume( data=image.ts, diff --git a/docs/finetuning.md b/docs/finetuning.md new file mode 100644 index 0000000..52722fc --- /dev/null +++ b/docs/finetuning.md @@ -0,0 +1,133 @@ +# Finetuning Guide + +This guide walks through the full finetuning workflow in CellMap-Flow: loading data, creating annotations, and training a finetuned model — all from the dashboard. 
+ +## Installation + +Create a conda environment with the finetuning dependencies: + +```bash +mamba create -n cellmap-flow-finetune python=3.11 minio-server minio-client -c conda-forge -y +mamba activate cellmap-flow-finetune +pip install git+https://github.com/briossant/neuroglancer@feature/voxel-annotation +pip install -e ".[finetune]" +``` + +This installs: +- **MinIO** (server + client) — local S3-compatible server for serving annotation zarr files to Neuroglancer +- **Neuroglancer** (voxel annotation branch) — adds voxel-level annotation tools needed for painting labels +- **LoRA/PEFT dependencies** — for parameter-efficient finetuning + +## 1. Launch the Dashboard + +Start by loading your data and model with a YAML configuration file: + +```bash +cellmap_flow_yaml my_yamls/jrc_c-elegans-bw-1_affinities.yaml +``` + +This starts the dashboard with your dataset and model loaded into the Neuroglancer viewer. + +## 2. Create or Resume an Annotation Volume +![Annotation Crops tab](screenshots/finetune_annotation_crops.png) + +Navigate to the **Finetune** tab in the dashboard. + +Under **Annotation Crops**, you will see your model configuration (name, output size, voxel size, crop shape, channels) along with controls for starting a new sparse annotation volume, resuming a previous session, and syncing annotations from MinIO back to disk. + +### Start a new volume + +1. Set the **Output Path for Zarr Files** to a directory where annotation data will be saved. This must be accessible to the MinIO server that the dashboard starts. +2. Click **New Volume**. +3. This creates a sparse annotation zarr covering the full dataset extent, where each chunk maps to one training sample. +4. A MinIO server will start automatically to serve the zarr for editing in Neuroglancer. + +### Resume an existing volume + +If you already have a prior annotation session: + +1. Set **Output Path for Zarr Files** to the root directory where you want the resumed session to be created. +2. 
Click **Resume Existing Volume**. +3. In the modal, scan the directory containing existing timestamped finetuning sessions. +4. Select a session and click **Load Selected**. + +This copies the chosen session into a new session directory, rather than editing the original in place. The copied session records its source in `loaded_from.json`. + +### Save annotations to disk + +While you are painting in Neuroglancer, edits are served through MinIO. Click **Save Annotations to Disk** to explicitly sync those in-progress annotations back to local storage. + +Painted regions are also shown in the viewer as bounding boxes through the `annotated_regions` layer. + + +## 3. Set Up Annotation Tools in Neuroglancer +![Draw tab with bound keys](screenshots/finetune_draw_tab.png) + +Once the annotation volume is created or resumed and added to the viewer: + +1. **Select the annotation layer** by right-clicking on it in the layer list (it will be named something like `sparse_annotation_vol-XXXX`). +2. Go to the **Draw** tab for that layer. +3. **Bind keyboard shortcuts** to the drawing tools: + - Click the small box next to each tool name (e.g. `[A] Brush`, `[S] Flood Fill`, `[D] Seg Picker`). + - Press the letter you want to assign to that tool. + - Once bound, activate a tool by pressing **Shift + the assigned letter**. + + + +## 4. Annotate + +When you start drawing, Neuroglancer will ask if you want to write to the file — click **Yes**. + +### Annotation label rules + +- **Paint Value 1** = **background** (this voxel is not the object of interest) +- **Paint Value 2** = **foreground** (this voxel is the object of interest) +- For **affinities models** with multiple object IDs, use higher paint values (3, 4, ...) for distinct object instances. The finetuning pipeline will automatically convert these instance IDs into affinity targets using the offsets defined in the model script. 
+- **Paint Value 0** = **unannotated / ignored** — these voxels are excluded from the loss during training. + +You can change the paint value in the Draw tab by editing the **Paint Value** field, or click **Random** next to **New Random Value** to pick a new instance ID. + +Annotate as many chunks as you like across the dataset. Only chunks with non-zero annotations will be used for training. + +### Deprecated dense crop workflow + +The **Create Annotation Crop** button is still available under the advanced section, but it is deprecated. It creates a small dense crop at the current view center and is rarely needed compared with the sparse full-volume workflow above. + +## 5. Training + +Switch to the **Training** tab in the Finetune section. + +![Training tab](screenshots/finetune_training_tab.png) + +### Training configuration options + +| Parameter | Description | +|---|---| +| **Checkpoint Path** | (Optional, Advanced) Override the base model checkpoint to finetune from. Leave empty to auto-detect from the model configuration or script. | +| **LoRA Rank** | Controls the number of trainable parameters. The current UI exposes `4`, `8`, `16`, and `64`. Higher rank = more capacity and more memory use. | +| **Number of Epochs** | How many passes over the training data. The UI currently defaults to `20`. | +| **Batch Size** | Number of samples per training step. The UI currently exposes `1`, `2`, `4`, `8`, `16`, and `32`. Higher = faster but uses more GPU memory. | +| **Learning Rate** | Step size for optimization. The UI currently exposes values from `1e-7` through `1e-1`, with `1e-4` as the standard default. | +| **Loss Function** | The training objective. The current UI exposes **Margin**, **MSE**, **BCE**, **Dice**, and **Combined (Dice + BCE)**. **Margin** is the default and is generally the best fit for sparse scribble-style annotations. | +| **Margin** | Only used when **Loss Function** is set to **Margin**. 
Controls how strict the margin loss is; smaller values provide more learning signal, while larger values create a wider no-gradient band. | +| **Distillation Weight** | Keeps the finetuned model close to the original model's predictions. The UI currently exposes `0`, `0.01`, `0.05`, `0.1`, `0.2`, `0.5`, `1.0`, `2.0`, `5.0`, and `10.0`, with `0.1` as the current default. Set to `0` to disable distillation. | +| **Distillation Scope** | (Advanced) Where to apply distillation loss — **Unlabeled** (only on unannotated voxels) or **All** (everywhere). | +| **Label Smoothing** | Softens hard `0/1` targets. Useful when annotations are noisy; set to `0` if you want sharp targets. | +| **Balance fg/bg classes** | Weights foreground and background equally in the loss regardless of how much of each you've annotated. Prevents the model from overpredicting whichever class dominates the scribbles. | +| **GPU Queue** | Which GPU queue to submit the training job to (e.g. H100, H200). | +| **Auto-load model after training** | When checked, the finetuned model will automatically start an inference server and be added to the Neuroglancer viewer once training completes. | + +### Start training + +Click **Start Finetuning** to submit the training job to the GPU cluster. You can monitor training progress via the live log stream in the Training tab. + +## 6. Iterative Refinement + +After reviewing the finetuned model's predictions in Neuroglancer: + +1. Add more annotations or correct existing ones in the annotation volume. +2. Go back to the **Training** tab. +3. Click **Restart Finetuning** — this retrains on the same GPU using your updated annotations, without needing to submit a new job. +4. Optionally, adjust the training parameters (epochs, learning rate, loss, etc.) before clicking **Restart Finetuning**. + +Repeat this annotate-train-review cycle until the model performs well on your data. 
diff --git a/docs/screenshots/finetune_annotation_crops.png b/docs/screenshots/finetune_annotation_crops.png new file mode 100644 index 0000000..e436693 Binary files /dev/null and b/docs/screenshots/finetune_annotation_crops.png differ diff --git a/docs/screenshots/finetune_draw_tab.png b/docs/screenshots/finetune_draw_tab.png new file mode 100644 index 0000000..a55ed1b Binary files /dev/null and b/docs/screenshots/finetune_draw_tab.png differ diff --git a/docs/screenshots/finetune_training_tab.png b/docs/screenshots/finetune_training_tab.png new file mode 100644 index 0000000..ad21af3 Binary files /dev/null and b/docs/screenshots/finetune_training_tab.png differ diff --git a/example/model_spec_affinities.py b/example/model_spec_affinities.py index 46c4e7b..1f5ee71 100644 --- a/example/model_spec_affinities.py +++ b/example/model_spec_affinities.py @@ -1,15 +1,17 @@ -#%% +# %% # pip install fly-organelles from funlib.geometry.coordinate import Coordinate import torch import funlib.learn.torch import numpy as np + voxel_size = (16, 16, 16) read_shape = Coordinate((178, 178, 178)) * Coordinate(voxel_size) write_shape = Coordinate((56, 56, 56)) * Coordinate(voxel_size) output_voxel_size = Coordinate((16, 16, 16)) -#%% + +# %% class StandardUnet(torch.nn.Module): def __init__( self, @@ -43,19 +45,23 @@ def __init__( constant_upsample=True, ) - self.final_conv = torch.nn.Conv3d(num_fmaps, out_channels, (1, 1, 1), padding="valid") + self.final_conv = torch.nn.Conv3d( + num_fmaps, out_channels, (1, 1, 1), padding="valid" + ) def forward(self, raw): x = self.unet_backbone(raw) return self.final_conv(x) -#%% + + +# %% def load_eval_model(num_labels, checkpoint_path): model_backbone = StandardUnet(num_labels) if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") - print("device:", device) + print("device:", device) checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device) 
model_backbone.load_state_dict(checkpoint["model_state_dict"]) model = torch.nn.Sequential(model_backbone, torch.nn.Sigmoid()) @@ -64,9 +70,11 @@ def load_eval_model(num_labels, checkpoint_path): return model -classes = ["mito",]*9 +classes = [ + "mito", +] * 9 CHECKPOINT_PATH = "/groups/cellmap/cellmap/zouinkhim/c-elegen/v2/train/fly_run/all/affinities/new/run04_mito/model_checkpoint_65000" output_channels = len(classes) model = load_eval_model(output_channels, CHECKPOINT_PATH) -block_shape = np.array((56, 56, 56,output_channels)) +block_shape = np.array((56, 56, 56, output_channels)) # %% diff --git a/pyproject.toml b/pyproject.toml index fb62df6..b362caa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,16 @@ postprocess = ["edt", "mwatershed", "funlib.math",] +finetune = [ + "peft>=0.7.0", # HuggingFace Parameter-Efficient Fine-Tuning + "transformers>=4.35.0", # Required by peft + "accelerate>=0.20.0", # Training utilities + # MinIO server and client are required for annotations but not available on PyPI. + # Install via conda-forge: mamba install minio-server minio-client -c conda-forge + # Neuroglancer with voxel annotation support is required but not yet on PyPI. 
+ # Install via: pip install git+https://github.com/briossant/neuroglancer@feature/voxel-annotation +] + [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" diff --git a/tests/finetune/__init__.py b/tests/finetune/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/finetune/test_finetune_cli.py b/tests/finetune/test_finetune_cli.py new file mode 100644 index 0000000..7dc2f94 --- /dev/null +++ b/tests/finetune/test_finetune_cli.py @@ -0,0 +1,30 @@ +"""Tests for finetuning CLI parsing.""" + +import unittest + +from cellmap_flow.finetune.finetune_cli import build_arg_parser + + +class FinetuneCliParserTests(unittest.TestCase): + def test_parser_accepts_script_model_type(self): + parser = build_arg_parser() + + args = parser.parse_args( + [ + "--model-type", + "script", + "--model-script", + "/tmp/model.py", + "--corrections", + "/tmp/corrections", + "--output-dir", + "/tmp/output", + ] + ) + + self.assertEqual(args.model_type, "script") + self.assertEqual(args.model_script, "/tmp/model.py") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/finetune/test_finetune_job_manager.py b/tests/finetune/test_finetune_job_manager.py new file mode 100644 index 0000000..413f637 --- /dev/null +++ b/tests/finetune/test_finetune_job_manager.py @@ -0,0 +1,135 @@ +"""Tests for finetuning job manager helpers and metadata.""" + +import json +import sys +import tempfile +import unittest +from datetime import datetime +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +from cellmap_flow.finetune.finetune_job_manager import ( + FinetuneJob, + FinetuneJobManager, + JobStatus, +) + + +class DummyScriptModelConfig: + cli_name = "script" + + def __init__(self): + self.name = "dummy_script_model" + self.script_path = "/tmp/dummy_model.py" + self.channels = ["mito"] + self.input_voxel_size = [8, 8, 8] + self.output_voxel_size = [8, 8, 8] + + +class DummyThread: + def 
__init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.started = False + + def start(self): + self.started = True + + +class FinetuneJobManagerTests(unittest.TestCase): + def test_submit_job_uses_console_script_and_preserves_scheduler_metadata(self): + manager = FinetuneJobManager() + model_config = DummyScriptModelConfig() + + with tempfile.TemporaryDirectory() as tmpdir: + corrections_dir = Path(tmpdir) / "corrections" + correction = corrections_dir / "crop_1.zarr" + correction.mkdir(parents=True) + (correction / ".zattrs").write_text(json.dumps({"dataset_path": "/data/raw.zarr"})) + + fake_job = SimpleNamespace(process=SimpleNamespace(pid=1234)) + + with patch( + "cellmap_flow.finetune.finetune_job_manager.is_bsub_available", + return_value=False, + ), patch( + "cellmap_flow.finetune.finetune_job_manager.run_locally", + return_value=fake_job, + ), patch( + "cellmap_flow.finetune.finetune_job_manager.threading.Thread", + DummyThread, + ): + job = manager.submit_finetuning_job( + model_config=model_config, + corrections_path=corrections_dir, + output_base=Path(tmpdir), + queue="gpu_a100", + charge_group="my_lab", + ) + + metadata = json.loads((job.output_dir / "metadata.json").read_text()) + command = metadata["command"] + + self.assertIn(sys.executable, command) + self.assertIn("-m cellmap_flow.finetune.finetune_cli", command) + self.assertNotIn( + "stdbuf -oL python -m cellmap_flow.finetune.finetune_cli", + command, + ) + self.assertIn("--model-type script", command) + self.assertIn("--model-script /tmp/dummy_model.py", command) + self.assertEqual(metadata["queue"], "gpu_a100") + self.assertEqual(metadata["charge_group"], "my_lab") + + def test_complete_job_uses_metadata_scheduler_settings_in_yaml(self): + manager = FinetuneJobManager() + + with tempfile.TemporaryDirectory() as tmpdir: + session_dir = Path(tmpdir) + output_dir = session_dir / "finetuning" / "runs" / "run_1" + output_dir.mkdir(parents=True) + + adapter_dir = output_dir 
/ "lora_adapter" + adapter_dir.mkdir() + (adapter_dir / "adapter_model.bin").write_bytes(b"adapter") + (adapter_dir / "adapter_config.json").write_text("{}") + + corrections_dir = session_dir / "corrections" + correction = corrections_dir / "crop_1.zarr" + correction.mkdir(parents=True) + (correction / ".zattrs").write_text(json.dumps({"dataset_path": "/data/raw.zarr"})) + + metadata = { + "corrections_path": str(corrections_dir), + "model_type": "script", + "model_script": "/tmp/dummy_model.py", + "queue": "gpu_l40s", + "charge_group": "cellmap-special", + } + (output_dir / "metadata.json").write_text(json.dumps(metadata)) + + job = FinetuneJob( + job_id="job-1", + lsf_job=None, + model_name="dummy_script_model", + output_dir=output_dir, + params={}, + status=JobStatus.COMPLETED, + created_at=datetime(2025, 1, 2, 3, 4, 5), + log_file=output_dir / "training_log.txt", + ) + job.current_epoch = 3 + job.latest_loss = 0.25 + + manager.complete_job(job) + + self.assertIsNotNone(job.model_yaml_path) + yaml_text = Path(job.model_yaml_path).read_text() + self.assertIn("queue: gpu_l40s", yaml_text) + self.assertIn("charge_group: cellmap-special", yaml_text) + self.assertIn("type: finetune", yaml_text) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/finetune/test_finetune_service.py b/tests/finetune/test_finetune_service.py new file mode 100644 index 0000000..1c8bfba --- /dev/null +++ b/tests/finetune/test_finetune_service.py @@ -0,0 +1,47 @@ +"""Tests for finetuning dashboard service helpers.""" + +import tempfile +import unittest +from pathlib import Path +from types import SimpleNamespace + +from cellmap_flow.dashboard.routes.finetune.service import ( + _autodetect_output_type, + _build_restart_params, +) + + +class FinetuneServiceHelperTests(unittest.TestCase): + def test_build_restart_params_maps_distillation_scope(self): + params = _build_restart_params( + { + "batch_size": 4, + "loss_type": "margin", + "distillation_scope": "all", + "offsets": 
[[1, 0, 0]], + } + ) + + self.assertEqual(params["batch_size"], 4) + self.assertEqual(params["loss_type"], "margin") + self.assertEqual(params["distillation_all_voxels"], True) + self.assertEqual(params["offsets"], [[1, 0, 0]]) + + def test_autodetect_output_type_reads_script_offsets(self): + with tempfile.TemporaryDirectory() as tmpdir: + script_path = Path(tmpdir) / "model.py" + script_path.write_text("offsets = [[1, 0, 0], [0, 1, 0]]\n") + model_config = SimpleNamespace(script_path=str(script_path)) + + output_type, offsets = _autodetect_output_type( + model_config, + output_type=None, + offsets=None, + ) + + self.assertEqual(output_type, "affinities") + self.assertEqual(offsets, "[[1, 0, 0], [0, 1, 0]]") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/finetune/test_lora_grad_flow.py b/tests/finetune/test_lora_grad_flow.py new file mode 100644 index 0000000..86793c0 --- /dev/null +++ b/tests/finetune/test_lora_grad_flow.py @@ -0,0 +1,495 @@ +"""Regression tests for LoRA gradient flow. + +Background +---------- +On 2026-04-28 a c-elegans script-path training run produced bit-for-bit +constant loss across many epochs. Diagnostic logging revealed +``mean|grad|=0.000e+00`` on the watched LoRA-B layer for every batch in +every epoch. Two compounding bugs were involved: + +1. ``VirtualPatchDataset._worker_rng`` reseeded on every ``__getitem__``, + so every patch was identical -- masked the symptom for a while. +2. The LoRA wrap on the script-path Sequential model produced trainable + parameters that received zero gradient for some configurations + (notably with distillation enabled, where the trainer toggles + ``disable_adapter_layers()`` / ``enable_adapter_layers()`` around the + teacher pass). + +These tests assert at the wrap layer that: + - Every ``lora_B`` weight gets a nonzero gradient after one forward + + backward through the wrapped model. 
+ - The toggle dance leaves the model in a state where ``lora_B`` still + receives gradient on the next forward. + +Tiny synthetic UNet-style model (a few Conv3d blocks) is used so the +tests run in seconds on CPU. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import pytest + + +def _tiny_sequential(): + """Mimic the script-path layout: nn.Sequential of 3D conv blocks.""" + return nn.Sequential( + nn.Conv3d(1, 4, kernel_size=3, padding=1), + nn.ReLU(inplace=False), + nn.Conv3d(4, 4, kernel_size=3, padding=1), + nn.ReLU(inplace=False), + nn.Conv3d(4, 1, kernel_size=1), + nn.Sigmoid(), + ) + + +def _gradient_summary(model: nn.Module) -> dict[str, dict[str, float]]: + """Return per-trainable-param mean|grad| (or None if grad is None).""" + summary = {} + for name, p in model.named_parameters(): + if not p.requires_grad: + continue + summary[name] = { + "mean_abs_grad": ( + None if p.grad is None else p.grad.detach().abs().mean().item() + ), + "numel": p.numel(), + } + return summary + + +def _assert_lora_b_grads_nonzero(model: nn.Module) -> None: + """Assert every ``lora_B`` weight got a nonzero gradient.""" + summary = _gradient_summary(model) + lora_b_grads = { + name: info for name, info in summary.items() if "lora_B" in name + } + assert lora_b_grads, ( + "Expected at least one trainable lora_B weight after wrap; got none. " + "wrap_model_with_lora may have failed to attach adapters." 
+ ) + zero = { + name: info for name, info in lora_b_grads.items() + if info["mean_abs_grad"] in (None, 0.0) + } + assert not zero, ( + "Some lora_B weights received no gradient after fwd+bwd; the LoRA " + "branch isn't on the autograd path for these layers:\n" + + "\n".join(f" {name}: mean|grad|={info['mean_abs_grad']!r}" + for name, info in zero.items()) + ) + + +def test_basic_lora_wrap_grad_flow(): + """One fwd+bwd through a tiny PEFT-wrapped Sequential should give every + lora_B nonzero gradient.""" + from cellmap_flow.finetune.lora_wrapper import wrap_model_with_lora + + base = _tiny_sequential() + peft = wrap_model_with_lora(base, lora_r=4, lora_alpha=8, lora_dropout=0.0) + peft.train() + + x = torch.randn(1, 1, 8, 8, 8) + y = peft(x) + loss = y.float().pow(2).mean() + loss.backward() + + _assert_lora_b_grads_nonzero(peft) + + +def test_lora_wrap_grad_flow_after_disable_enable_toggle(): + """The trainer's distillation pass calls ``disable_adapter_layers()`` + before the teacher forward and ``enable_adapter_layers()`` after. + + If ``enable_adapter_layers()`` doesn't fully restore state, the + student forward goes through the base model only and lora_B receives + no gradient. This test exercises that exact dance.""" + from cellmap_flow.finetune.lora_wrapper import wrap_model_with_lora + + base = _tiny_sequential() + peft = wrap_model_with_lora(base, lora_r=4, lora_alpha=8, lora_dropout=0.0) + peft.train() + + x = torch.randn(1, 1, 8, 8, 8) + + # Teacher pass with adapters disabled (mirrors trainer). + with torch.no_grad(): + peft.disable_adapter_layers() + try: + _teacher = peft(x) + finally: + peft.enable_adapter_layers() + + # Student pass — should activate LoRA branch and propagate gradient. 
+ y = peft(x) + loss = y.float().pow(2).mean() + loss.backward() + + _assert_lora_b_grads_nonzero(peft) + + +def test_lora_wrap_grad_flow_with_batch_loop_wrapper(): + """The trainer wraps UnflattenedModule with BatchLoopWrapper *before* + PEFT, so PEFT sees ``BatchLoopWrapper(model)``. Verify gradient flows + through BatchLoopWrapper at batch_size > 1 (where the loop actually + runs) -- a stale issue we suspected when the loop returns + ``torch.cat`` of N separate forward calls.""" + from cellmap_flow.finetune.lora_wrapper import ( + BatchLoopWrapper, + wrap_model_with_lora, + ) + + base = BatchLoopWrapper(_tiny_sequential()) + peft = wrap_model_with_lora(base, lora_r=4, lora_alpha=8, lora_dropout=0.0) + peft.train() + + # Batch size > 1 forces BatchLoopWrapper to actually iterate and cat. + x = torch.randn(3, 1, 8, 8, 8) + y = peft(x) + assert y.shape[0] == 3 + loss = y.float().pow(2).mean() + loss.backward() + + _assert_lora_b_grads_nonzero(peft) + + +def test_lambda_normalizer_is_picklable(): + """LambdaNormalizer used to store an eval()'d ``lambda`` on the + instance. PyTorch DataLoader workers spawned via ``multiprocessing_context + ='spawn'`` pickle the dataset (and therefore the normalizers) before + starting -- lambdas can't be pickled, which crashed training before any + batches ran. Regression test: a normalizer must round-trip through + pickle and still produce the right output.""" + import pickle + import numpy as np + + from cellmap_flow.norm.input_normalize import LambdaNormalizer + + n = LambdaNormalizer("x*2-1") + pickled = pickle.dumps(n) + n2 = pickle.loads(pickled) + out = n2._process(np.array([0.5, 1.0])) + assert out[0] == 0.0 and out[1] == 1.0, f"unexpected: {out}" + + +def test_virtual_patch_dataset_applies_input_norm(): + """Regression test for the train/inference normalization mismatch bug. + + The dashboard's inference path normalizes raw via ``g.input_norms`` + before feeding the model. 
The trainer is a separate LSF process where + ``g.input_norms`` is empty -- so without an explicit per-dataset + normalizer, the trainer trained the model on raw uint8 [0, 255] while + inference fed it [-1, 1]. The trained model was nonsense at inference + time. Asserts that VirtualPatchDataset, given an ``input_norm_config`` + matching the dashboard's typical config, returns raw patches in the + expected normalized range -- not raw uint8. + """ + import numpy as np + import zarr + import tempfile + import os + + from cellmap_flow.finetune.virtual_dataset import VirtualPatchDataset + + tmp = tempfile.mkdtemp() + raw_path = os.path.join(tmp, "raw.zarr") + g = zarr.open_group(raw_path, mode="w") + g.create_dataset("s0", shape=(32, 32, 32), dtype="uint8", chunks=(16, 16, 16)) + g["s0"][:] = np.full((32, 32, 32), 128, dtype=np.uint8) # constant + g.attrs["multiscales"] = [{ + "version": "0.4", + "axes": [{"name": a, "type": "space", "unit": "nanometer"} for a in "zyx"], + "datasets": [{"path": "s0", "coordinateTransformations": [ + {"type": "scale", "scale": [16.0, 16.0, 16.0]}, + {"type": "translation", "translation": [0.0, 0.0, 0.0]}, + ]}], + }] + + vol_path = os.path.join(tmp, "vol.zarr") + v = zarr.open_group(vol_path, mode="w") + v.create_group("annotation").create_dataset( + "s0", shape=(32, 32, 32), chunks=(16, 16, 16), dtype="uint8", fill_value=0 + ) + arr = v["annotation"]["s0"][:] + arr[4:28, 4:28, 4:28] = 2 + v["annotation"]["s0"][:] = arr + v.attrs["dataset_offset_nm"] = [0.0, 0.0, 0.0] + v["annotation"].attrs["multiscales"] = g.attrs["multiscales"] + + common = dict( + volume_zarr_path=vol_path, + raw_dataset_path=raw_path, + input_size_voxels=(8, 8, 8), + output_size_voxels=(4, 4, 4), + input_voxel_size_nm=(16, 16, 16), + output_voxel_size_nm=(16, 16, 16), + patches_per_epoch=4, + seed=0, + ) + + # Without input_norm: raw is returned as native uint8 ~128. 
+ raw_unnormalized, _ = VirtualPatchDataset(input_norm_config=None, **common)[0] + assert ( + 110 < float(raw_unnormalized.min()) < 140 + ), ( + "Without input_norm, raw should pass through ~uint8 (~128); got " + f"range [{raw_unnormalized.min()}, {raw_unnormalized.max()}]" + ) + + # With the dashboard's typical input_norm, raw should land in [-1, 1]. + # 128 / 255 * 2 - 1 = 0.0039. + raw_normalized, _ = VirtualPatchDataset( + input_norm_config={ + "MinMaxNormalizer": {"min_value": 0, "max_value": 255, "invert": False}, + "LambdaNormalizer": {"expression": "x*2-1"}, + }, + **common, + )[0] + rmin = float(raw_normalized.min()) + rmax = float(raw_normalized.max()) + assert -0.05 < rmin < 0.05 and -0.05 < rmax < 0.05, ( + "With input_norm, raw should be normalized to [-1, 1] range " + f"(expect ~0.004); got [{rmin}, {rmax}]" + ) + + +def test_virtual_patch_dataset_rng_advances(): + """Regression test for a bug where ``VirtualPatchDataset._worker_rng`` + reseeded on every ``__getitem__`` call -- making every patch identical + and silently breaking training. 
Two consecutive draws should yield + different RNG samples.""" + import numpy as np + import zarr + import tempfile + import os + + from cellmap_flow.finetune.virtual_dataset import VirtualPatchDataset + + tmp = tempfile.mkdtemp() + raw_path = os.path.join(tmp, "raw.zarr") + g = zarr.open_group(raw_path, mode="w") + g.create_dataset("s0", shape=(32, 32, 32), dtype="uint8", chunks=(16, 16, 16)) + g["s0"][:] = np.random.randint(0, 255, (32, 32, 32), dtype=np.uint8) + g.attrs["multiscales"] = [{ + "version": "0.4", + "axes": [{"name": a, "type": "space", "unit": "nanometer"} for a in "zyx"], + "datasets": [{"path": "s0", "coordinateTransformations": [ + {"type": "scale", "scale": [16.0, 16.0, 16.0]}, + {"type": "translation", "translation": [0.0, 0.0, 0.0]}, + ]}], + }] + + vol_path = os.path.join(tmp, "vol.zarr") + v = zarr.open_group(vol_path, mode="w") + v.create_group("annotation").create_dataset( + "s0", shape=(32, 32, 32), chunks=(16, 16, 16), dtype="uint8", fill_value=0 + ) + arr = v["annotation"]["s0"][:] + arr[4:28, 4:28, 4:28] = 2 + v["annotation"]["s0"][:] = arr + v.attrs["dataset_offset_nm"] = [0.0, 0.0, 0.0] + v["annotation"].attrs["multiscales"] = g.attrs["multiscales"] + + ds = VirtualPatchDataset( + volume_zarr_path=vol_path, + raw_dataset_path=raw_path, + input_size_voxels=(8, 8, 8), + output_size_voxels=(4, 4, 4), + input_voxel_size_nm=(16, 16, 16), + output_voxel_size_nm=(16, 16, 16), + patches_per_epoch=10, + seed=0, + ) + + # Pull raw patches from several draws — they should not all be identical. + raws = [ds[i][0].numpy().tobytes() for i in range(8)] + assert len(set(raws)) > 1, ( + "All draws produced identical raw patches; the per-worker RNG is " + "being re-seeded on every __getitem__ instead of advancing." + ) + + +def test_virtual_patch_dataset_stratified_sampling(): + """Regression test for two-pool stratified sampling. 
Without it, a + session with a 600^3 imported crop (~40M FG voxels) and a small + painted scribble (~hundreds of voxels) would draw 99.99% of patches + from the dense crop and effectively ignore the scribble. Stratified + sampling with default ratio=0.5 must give the sparse pool a real + share of the patches. + """ + import numpy as np + import zarr + import tempfile + import os + + from cellmap_flow.finetune.virtual_dataset import VirtualPatchDataset + + tmp = tempfile.mkdtemp() + raw_path = os.path.join(tmp, "raw.zarr") + g = zarr.open_group(raw_path, mode="w") + g.create_dataset("s0", shape=(64, 64, 64), dtype="uint8", chunks=(16, 16, 16)) + g["s0"][:] = np.full((64, 64, 64), 128, dtype=np.uint8) + g.attrs["multiscales"] = [{ + "version": "0.4", + "axes": [{"name": a, "type": "space", "unit": "nanometer"} for a in "zyx"], + "datasets": [{"path": "s0", "coordinateTransformations": [ + {"type": "scale", "scale": [16.0, 16.0, 16.0]}, + {"type": "translation", "translation": [0.0, 0.0, 0.0]}, + ]}], + }] + + # Volume zarr: simulate one big imported crop (large dense FG region) + # plus a tiny painted scribble outside its bbox. 
+ vol_path = os.path.join(tmp, "vol.zarr") + v = zarr.open_group(vol_path, mode="w") + v.create_group("annotation").create_dataset( + "s0", shape=(64, 64, 64), chunks=(16, 16, 16), dtype="uint8", fill_value=0 + ) + arr = v["annotation"]["s0"][:] + # Dense imported crop: 32^3 region (~32K FG voxels) + arr[0:32, 0:32, 0:32] = 2 + # Sparse scribble: 2^3 region outside the imported crop (~8 FG voxels) + arr[40:42, 40:42, 40:42] = 2 + v["annotation"]["s0"][:] = arr + v.attrs["dataset_offset_nm"] = [0.0, 0.0, 0.0] + v.attrs["imported_crops"] = [ + { + "path": "/fake/crop.zarr", + "name": None, + "annotation_offset_voxels": [0, 0, 0], + "annotation_shape_voxels": [32, 32, 32], + "n_fg_voxels": 32 ** 3, + } + ] + v["annotation"].attrs["multiscales"] = g.attrs["multiscales"] + + common = dict( + volume_zarr_path=vol_path, + raw_dataset_path=raw_path, + input_size_voxels=(8, 8, 8), + output_size_voxels=(4, 4, 4), + input_voxel_size_nm=(16, 16, 16), + output_voxel_size_nm=(16, 16, 16), + seed=0, + ) + + # Default (auto): both pools exist → ratio resolves to 0.5. Roughly + # half the patch anchors should land in the sparse region. + ds = VirtualPatchDataset(**common, patches_per_epoch=200) + assert abs(ds._effective_dense_ratio - 0.5) < 1e-9 + sparse_hits = 0 + dense_hits = 0 + for _ in range(200): + rng = ds._worker_rng() + use_dense = ( + ds._effective_dense_ratio >= 1.0 + or (ds._effective_dense_ratio > 0.0 and rng.random() < ds._effective_dense_ratio) + ) + pool = ds._fg_index_dense if use_dense else ds._fg_index_sparse + anchor = pool[rng.integers(0, pool.shape[0])] + # Voxel in [40, 42)^3 came from the sparse scribble; rest from dense. + if (anchor >= 40).all() and (anchor < 42).all(): + sparse_hits += 1 + else: + dense_hits += 1 + # With ratio=0.5 over 200 draws we expect ~100 sparse hits. Allow a + # wide band so the test isn't flaky; the failure mode we're guarding + # against (no stratification) would give 0-1 sparse hits. 
+ assert sparse_hits > 50, ( + f"Stratified sampling gave only {sparse_hits}/200 sparse hits; " + "expected ~100. Two-pool sampling is not active." + ) + assert dense_hits > 50, ( + f"Stratified sampling gave only {dense_hits}/200 dense hits; " + "expected ~100." + ) + + # Auto-degrade: explicit ratio=0.5 but only one pool populated. + # Build a volume with NO imported_crops → all FG goes to sparse pool; + # ratio should clamp to 0.0 so we don't try to sample an empty dense. + vol_no_crops = os.path.join(tmp, "vol_no_crops.zarr") + v2 = zarr.open_group(vol_no_crops, mode="w") + v2.create_group("annotation").create_dataset( + "s0", shape=(32, 32, 32), chunks=(16, 16, 16), dtype="uint8", fill_value=0 + ) + a2 = v2["annotation"]["s0"][:] + a2[8:24, 8:24, 8:24] = 2 + v2["annotation"]["s0"][:] = a2 + v2.attrs["dataset_offset_nm"] = [0.0, 0.0, 0.0] + v2["annotation"].attrs["multiscales"] = g.attrs["multiscales"] + ds2 = VirtualPatchDataset( + volume_zarr_path=vol_no_crops, + raw_dataset_path=raw_path, + input_size_voxels=(8, 8, 8), + output_size_voxels=(4, 4, 4), + input_voxel_size_nm=(16, 16, 16), + output_voxel_size_nm=(16, 16, 16), + patches_per_epoch=4, + dense_to_sparse_ratio=0.5, # explicit, but should clamp + seed=0, + ) + assert ds2._effective_dense_ratio == 0.0, ( + "With no imported_crops the dense pool is empty; ratio should " + f"clamp to 0.0, got {ds2._effective_dense_ratio}" + ) + assert ds2._fg_index_dense.shape[0] == 0 + assert ds2._fg_index_sparse.shape[0] > 0 + + +def test_virtual_patch_dataset_default_patches_per_epoch(): + """Regression test: ``patches_per_epoch=None`` (the new default) means + "cover every populated chunk roughly once per epoch" -- the dataset + substitutes the populated-chunk count at index build time. 
+ """ + import numpy as np + import zarr + import tempfile + import os + + from cellmap_flow.finetune.virtual_dataset import VirtualPatchDataset + + tmp = tempfile.mkdtemp() + raw_path = os.path.join(tmp, "raw.zarr") + g = zarr.open_group(raw_path, mode="w") + g.create_dataset("s0", shape=(48, 48, 48), dtype="uint8", chunks=(16, 16, 16)) + g["s0"][:] = np.full((48, 48, 48), 128, dtype=np.uint8) + g.attrs["multiscales"] = [{ + "version": "0.4", + "axes": [{"name": a, "type": "space", "unit": "nanometer"} for a in "zyx"], + "datasets": [{"path": "s0", "coordinateTransformations": [ + {"type": "scale", "scale": [16.0, 16.0, 16.0]}, + {"type": "translation", "translation": [0.0, 0.0, 0.0]}, + ]}], + }] + + vol_path = os.path.join(tmp, "vol.zarr") + v = zarr.open_group(vol_path, mode="w") + # 48/16 = 3 chunks per dim → 27 total chunks, but we'll only populate 3. + v.create_group("annotation").create_dataset( + "s0", shape=(48, 48, 48), chunks=(16, 16, 16), dtype="uint8", fill_value=0 + ) + arr = v["annotation"]["s0"][:] + # Three populated chunks (each chunk gets at least one FG voxel). 
+ arr[1, 1, 1] = 2 + arr[17, 17, 17] = 2 + arr[33, 33, 33] = 2 + v["annotation"]["s0"][:] = arr + v.attrs["dataset_offset_nm"] = [0.0, 0.0, 0.0] + v["annotation"].attrs["multiscales"] = g.attrs["multiscales"] + + ds = VirtualPatchDataset( + volume_zarr_path=vol_path, + raw_dataset_path=raw_path, + input_size_voxels=(8, 8, 8), + output_size_voxels=(4, 4, 4), + input_voxel_size_nm=(16, 16, 16), + output_voxel_size_nm=(16, 16, 16), + patches_per_epoch=None, # default → use populated chunk count + seed=0, + ) + assert ds.patches_per_epoch == 3, ( + f"Default patches_per_epoch should equal populated-chunk count " + f"(3), got {ds.patches_per_epoch}" + ) + assert len(ds) == 3 diff --git a/tests/finetune/test_target_transforms.py b/tests/finetune/test_target_transforms.py new file mode 100644 index 0000000..04e48f8 --- /dev/null +++ b/tests/finetune/test_target_transforms.py @@ -0,0 +1,200 @@ +"""Tests for target transforms.""" + +import torch +from cellmap_flow.finetune.target_transforms import ( + BinaryTargetTransform, + BroadcastBinaryTargetTransform, + AffinityTargetTransform, + _offset_slices, +) + + +def test_binary_transform_basic(): + """Test that BinaryTargetTransform produces correct targets and masks.""" + # annotation: 0=unannotated, 1=bg, 2=fg + annotation = torch.tensor([[[[[0, 1, 2, 0, 1]]]]]).float() # (1, 1, 1, 1, 5) + transform = BinaryTargetTransform() + target, mask = transform(annotation) + + # mask: 1 where annotated (>0) + assert mask.tolist() == [[[[[0, 1, 1, 0, 1]]]]] + # target: 0 for bg (was 1), 1 for fg (was 2), 0 for unannotated + assert target.tolist() == [[[[[0, 0, 1, 0, 0]]]]] + + +def test_binary_transform_multi_object(): + """Labels 2 and 3 both become foreground (1).""" + annotation = torch.tensor([[[[[1, 2, 3]]]]]).float() + transform = BinaryTargetTransform() + target, mask = transform(annotation) + + assert target.tolist() == [[[[[0, 1, 1]]]]] + assert mask.tolist() == [[[[[1, 1, 1]]]]] + + +def test_broadcast_transform(): + """Test 
broadcasting to multiple channels.""" + annotation = torch.tensor([[[[[0, 1, 2]]]]]).float() # (1, 1, 1, 1, 3) + transform = BroadcastBinaryTargetTransform(num_channels=3) + target, mask = transform(annotation) + + assert target.shape == (1, 3, 1, 1, 3) + assert mask.shape == (1, 3, 1, 1, 3) + # All channels should be identical + for c in range(3): + assert target[0, c].tolist() == [[[0, 0, 1]]] + assert mask[0, c].tolist() == [[[0, 1, 1]]] + + +def test_affinity_transform_same_object(): + """Two adjacent voxels of the same object should have affinity=1.""" + # 1D-like: [bg, obj2, obj2, bg] along X + annotation = torch.zeros(1, 1, 1, 1, 4) + annotation[0, 0, 0, 0, :] = torch.tensor([1, 2, 2, 1]).float() + + offsets = [[0, 0, 1]] # X offset + transform = AffinityTargetTransform(offsets) + target, mask = transform(annotation) + + # target shape: (1, 1, 1, 1, 4) + assert target.shape == (1, 1, 1, 1, 4) + + # Pairs (along X, offset +1): + # (0,1): bg-obj2 -> 0, both annotated -> mask=1 + # (1,2): obj2-obj2 -> 1, both annotated -> mask=1 + # (2,3): obj2-bg -> 0, both annotated -> mask=1 + # Position 3 has no pair (boundary) -> target=0, mask=0 + assert target[0, 0, 0, 0, :3].tolist() == [0, 1, 0] + assert mask[0, 0, 0, 0, :3].tolist() == [1, 1, 1] + assert mask[0, 0, 0, 0, 3].item() == 0 # no pair for last voxel + + +def test_affinity_transform_different_objects(): + """Adjacent voxels of different objects should have affinity=0.""" + annotation = torch.zeros(1, 1, 1, 1, 3) + annotation[0, 0, 0, 0, :] = torch.tensor([2, 3, 2]).float() + + offsets = [[0, 0, 1]] + transform = AffinityTargetTransform(offsets) + target, mask = transform(annotation) + + # (0,1): obj2-obj3 -> 0 + # (1,2): obj3-obj2 -> 0 + assert target[0, 0, 0, 0, :2].tolist() == [0, 0] + assert mask[0, 0, 0, 0, :2].tolist() == [1, 1] + + +def test_affinity_transform_unannotated_masking(): + """Unannotated voxels should produce mask=0.""" + annotation = torch.zeros(1, 1, 1, 1, 4) + annotation[0, 0, 0, 0, :] = 
torch.tensor([2, 0, 2, 1]).float() + + offsets = [[0, 0, 1]] + transform = AffinityTargetTransform(offsets) + target, mask = transform(annotation) + + # (0,1): obj2-unannotated -> mask=0 + # (1,2): unannotated-obj2 -> mask=0 + # (2,3): obj2-bg -> mask=1, target=0 + assert mask[0, 0, 0, 0, 0].item() == 0 + assert mask[0, 0, 0, 0, 1].item() == 0 + assert mask[0, 0, 0, 0, 2].item() == 1 + assert target[0, 0, 0, 0, 2].item() == 0 + + +def test_affinity_transform_multiple_offsets(): + """Test with Z, Y, X offsets.""" + annotation = torch.zeros(1, 1, 3, 3, 3) + # Fill with same object + annotation[:] = 2 + # Set corners to background + annotation[0, 0, 0, 0, 0] = 1 + + offsets = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + transform = AffinityTargetTransform(offsets) + target, mask = transform(annotation) + + assert target.shape == (1, 3, 3, 3, 3) + assert mask.shape == (1, 3, 3, 3, 3) + + # All annotated (>0), so mask should be 1 everywhere there's a valid pair + # Z offset channel: mask=1 for z=0,1 (pairs with z+1 exist), mask=0 for z=2 + assert mask[0, 0, 2, :, :].sum().item() == 0 # no z+1 for z=2 + assert mask[0, 0, 0, :, :].sum().item() == 9 # all y,x pairs valid + assert mask[0, 0, 1, :, :].sum().item() == 9 + + # Corner (0,0,0) is bg, (1,0,0) is fg -> Z-offset affinity at (0,0,0) = 0 + assert target[0, 0, 0, 0, 0].item() == 0 + # (1,0,0) and (2,0,0) both fg -> Z-offset affinity at (1,0,0) = 1 + assert target[0, 0, 1, 0, 0].item() == 1 + + +def test_affinity_transform_negative_offset(): + """Test that negative offsets work correctly.""" + annotation = torch.zeros(1, 1, 1, 1, 4) + annotation[0, 0, 0, 0, :] = torch.tensor([1, 2, 2, 1]).float() + + offsets = [[0, 0, -1]] # Negative X offset + transform = AffinityTargetTransform(offsets) + target, mask = transform(annotation) + + # With offset -1, source starts at index 1, dest starts at index 0 + # Pair (1,0): obj2-bg -> 0, both annotated -> mask=1 + # Pair (2,1): obj2-obj2 -> 1, both annotated -> mask=1 + # Pair (3,2): 
bg-obj2 -> 0, both annotated -> mask=1 + assert target[0, 0, 0, 0, 1].item() == 0 + assert target[0, 0, 0, 0, 2].item() == 1 + assert target[0, 0, 0, 0, 3].item() == 0 + assert mask[0, 0, 0, 0, 0].item() == 0 # no pair for index 0 + + +def test_offset_slices(): + """Test _offset_slices helper.""" + # Positive offset + src, dst = _offset_slices(10, 10, 10, 1, 0, 0) + assert src == (slice(None, 9), slice(None), slice(None)) + assert dst == (slice(1, None), slice(None), slice(None)) + + # Negative offset + src, dst = _offset_slices(10, 10, 10, 0, 0, -2) + assert src == (slice(None), slice(None), slice(2, None)) + assert dst == (slice(None), slice(None), slice(None, 8)) + + # Zero offset + src, dst = _offset_slices(10, 10, 10, 0, 0, 0) + assert src == (slice(None), slice(None), slice(None)) + assert dst == (slice(None), slice(None), slice(None)) + + +def test_affinity_transform_extra_channels_masked(): + """Extra channels (e.g. LSDs) should have mask=0.""" + annotation = torch.zeros(1, 1, 1, 1, 4) + annotation[0, 0, 0, 0, :] = torch.tensor([1, 2, 2, 1]).float() + + offsets = [[0, 0, 1]] # 1 affinity channel + transform = AffinityTargetTransform(offsets, num_channels=4) # 1 aff + 3 extra + target, mask = transform(annotation) + + assert target.shape == (1, 4, 1, 1, 4) + assert mask.shape == (1, 4, 1, 1, 4) + + # Channel 0 (affinity) should have valid mask + assert mask[0, 0, 0, 0, :3].sum().item() == 3 + # Channels 1-3 (extra, e.g. 
LSDs) should be fully masked out + assert mask[0, 1, :, :, :].sum().item() == 0 + assert mask[0, 2, :, :, :].sum().item() == 0 + assert mask[0, 3, :, :, :].sum().item() == 0 + + +if __name__ == "__main__": + test_binary_transform_basic() + test_binary_transform_multi_object() + test_broadcast_transform() + test_affinity_transform_same_object() + test_affinity_transform_different_objects() + test_affinity_transform_unannotated_masking() + test_affinity_transform_multiple_offsets() + test_affinity_transform_negative_offset() + test_offset_slices() + test_affinity_transform_extra_channels_masked() + print("All tests passed!")