diff --git a/autoflow/algorithms/__init__.py b/autoflow/algorithms/__init__.py index a67a1a7..1acae83 100755 --- a/autoflow/algorithms/__init__.py +++ b/autoflow/algorithms/__init__.py @@ -5,12 +5,15 @@ """ from .data import ( + LoadedCase, + LoaderCapabilities, _axis_pair, _need_flip, _permute_spatial, _flip_axes, reorient, _ensure_flow_mag_time_and_segmask, + normalize_loaded_case, _reorient_spatial_only, _compute_spatial_bbox, _target_bbox_to_source_slices, @@ -103,8 +106,11 @@ load_metrics_as_table, ) __all__ = [ + "LoadedCase", + "LoaderCapabilities", "reorient", "load_h5_data", + "normalize_loaded_case", "filter_segmask_labels", "binarize_segmask", "merge_segmask_to_3d", diff --git a/autoflow/algorithms/data.py b/autoflow/algorithms/data.py index de617ee..24e596b 100755 --- a/autoflow/algorithms/data.py +++ b/autoflow/algorithms/data.py @@ -1,7 +1,7 @@ import h5py import numpy as np -from .metrics import compute_tke_array_from_sigma +from ..case_types import LoadedCase, LoaderCapabilities def _axis_pair(a): @@ -128,6 +128,126 @@ def _ensure_flow_mag_time_and_segmask(flow, mag, segmask): ) +def _ensure_flow_mag_time(flow, mag): + flow = np.asarray(flow) + mag = np.asarray(mag) + if flow.ndim == 4 and flow.shape[-1] == 3: + flow = flow[..., np.newaxis, :] + if flow.ndim != 5 or flow.shape[-1] != 3: + raise ValueError(f"flow must be XYZTV or XYZV with 3 components, got {flow.shape}") + nt = int(flow.shape[3]) + if mag.ndim == 3: + mag = np.repeat(mag[..., np.newaxis], nt, axis=3) + elif mag.ndim == 4 and mag.shape[3] == 1 and nt > 1: + mag = np.repeat(mag, nt, axis=3) + elif mag.ndim != 4: + raise ValueError(f"mag must be XYZT or XYZ, got {mag.shape}") + if mag.shape[3] != nt: + if mag.shape[3] == 1: + mag = np.repeat(mag, nt, axis=3) + else: + raise ValueError(f"mag time dimension {mag.shape[3]} does not match flow {nt}") + return ( + np.ascontiguousarray(flow, dtype=np.float32), + np.ascontiguousarray(mag, dtype=np.float32), + ) + + +def _ensure_optional_time_volume(arr, nt, name, dtype): + if arr is None: + return None + arr = np.asarray(arr) + if arr.ndim == 3: + arr = np.repeat(arr[..., np.newaxis], nt, axis=3) + elif arr.ndim == 4 and arr.shape[3] == 1 and nt > 1: + arr = np.repeat(arr, nt, axis=3) + elif arr.ndim != 4: + raise ValueError(f"{name} must be XYZT or XYZ, got {arr.shape}") + if arr.shape[3] != nt: + if arr.shape[3] == 1: + arr = np.repeat(arr, nt, axis=3) + else: + raise ValueError(f"{name} time dimension {arr.shape[3]} does not match flow {nt}") + return np.ascontiguousarray(arr, dtype=dtype) + + +def _ensure_optional_sigma_time(sigma, nt): + if sigma is None: + return None + sigma = np.asarray(sigma) + if sigma.ndim == 4 and sigma.shape[-1] == 3: + sigma = sigma[..., np.newaxis, :] + elif sigma.ndim != 5 or sigma.shape[-1] != 3: + raise ValueError(f"sigma must be XYZTV or XYZV with 3 components, got {sigma.shape}") + if sigma.shape[3] != nt: + if sigma.shape[3] == 1: + sigma = np.repeat(sigma, nt, axis=3) + else: + raise ValueError(f"sigma time dimension {sigma.shape[3]} does not match flow {nt}") + return np.ascontiguousarray(sigma, dtype=np.float32) + + +def normalize_loaded_case( + *, + flow, + mag, + resolution, + origin, + venc, + rr, + segmentation=None, + tke_array=None, + sigma=None, + metadata=None, + source_format="", + source_group=None, + capabilities=None, +): + flow_out, mag_out = _ensure_flow_mag_time(flow, mag) + nt = int(flow_out.shape[3]) + seg_out = _ensure_optional_time_volume(segmentation, nt, "segmentation", np.int16) + tke_out = _ensure_optional_time_volume(tke_array, nt, "tke_array", np.float32) + sigma_out = _ensure_optional_sigma_time(sigma, nt) + if capabilities is None: + capabilities = LoaderCapabilities() + else: + capabilities = LoaderCapabilities( + has_segmentation=bool(capabilities.has_segmentation), + has_tke=bool(capabilities.has_tke), + has_complex_source=bool(capabilities.has_complex_source), + supports_wss=bool(capabilities.supports_wss), + supports_plane_metrics=bool(capabilities.supports_plane_metrics), + ) + capabilities.has_segmentation = seg_out is not None + capabilities.has_tke = bool(capabilities.has_tke or tke_out is not None) + + resolution = np.asarray(resolution, dtype=float).reshape(-1) + if resolution.size == 1: + resolution = np.repeat(resolution, 3) + origin = np.asarray(origin, dtype=float).reshape(-1) + if origin.size == 1: + origin = np.repeat(origin, 3) + venc = np.asarray(venc, dtype=float).reshape(-1) + if venc.size == 1: + venc = np.repeat(venc, 3) + + return LoadedCase( + flow=flow_out, + mag=mag_out, + segmentation=seg_out, + resolution=np.asarray(resolution[:3], dtype=float).reshape(3), + origin=np.asarray(origin[:3], dtype=float).reshape(3), + venc=np.asarray(venc[:3], dtype=float).reshape(3), + rr=float(rr), + sigma=sigma_out, + tke_array=tke_out, + metadata=dict(metadata or {}), + source_format=str(source_format or ""), + source_group=source_group, + capabilities=capabilities, + ) + + def _reorient_spatial_only(arr, spatial_order, target_spatial_order): arr_r, src_pos = _permute_spatial(arr, spatial_order, target_spatial_order, spatial_axes=(0, 1, 2)) flip_axes = [i for i in range(3) if _need_flip(spatial_order[src_pos[i]], target_spatial_order[i])] @@ -198,54 +318,88 @@ def load_h5_data(path): target_spatial_order = ("LR", "AP", "FH") target_venc_order = ("LR", "AP", "FH") with h5py.File(path, "r") as g: - if "img_complex" not in g or "segmask" not in g: - raise ValueError(f"h5 must contain img_complex and segmask: {path}") VENC = g["VENC"][:] if "VENC" in g else np.array([150, 150, 150], dtype=float) resolution = g["Resolution"][:] if "Resolution" in g else np.array([1, 1, 1], dtype=float) - origin = np.array([0.0, 0.0, 0.0], dtype=float) + origin = g["Origin"][:] if "Origin" in g else np.array([0.0, 0.0, 0.0], dtype=float) rr = float(g["RR"][()]) if "RR" in g else 1000.0 spatial_order = g["SpatialOrder"][:].astype(str) if "SpatialOrder" in g else np.array(["FH", "AP", "LR"]) venc_order = g["VENCOrder"][:].astype(str) if "VENCOrder" in g else np.array(["FH", "AP", "LR"]) - - segmask_ds = g["segmask"] - segmask_full = segmask_ds[:].astype(np.int16) - src_slices = _compute_spatial_bbox(segmask_full, pad=2) - segmask = segmask_ds[src_slices + (slice(None),) * (segmask_ds.ndim - 3)].astype(np.int16) - del segmask_full - - img_ds = g["img_complex"] - img_complex = np.asarray(img_ds[src_slices + (slice(None),) * (img_ds.ndim - 3)]) - - mag = np.abs(img_complex[..., 0]).astype(np.float32) - flow_raw = np.angle(img_complex[..., 1:4] * np.conj(img_complex[..., 0][..., None])).astype(np.float32) - sigma_raw = _sigma_from_complex(img_complex, VENC) - - flow, mag_out, seg_r, venc_new, res_new = reorient( - mag, flow_raw, segmask, venc=VENC, resolution=resolution, - spatial_order=spatial_order, venc_order=venc_order, - target_spatial_order=target_spatial_order, - target_venc_order=target_venc_order, - return_velocity=True, - ) - sigma = _reorient_component_abs( - sigma_raw, - spatial_order=spatial_order, - target_spatial_order=target_spatial_order, - venc_order=venc_order, - target_venc_order=target_venc_order, - ).astype(np.float32) - - flow, mag_out, seg_r = _ensure_flow_mag_time_and_segmask(flow, mag_out, seg_r) - tke_array = compute_tke_array_from_sigma(sigma, rho=1060.0) - - return { - "flow": flow, - "mag": mag_out, - "segmask": seg_r, - "resolution": np.asarray(res_new, dtype=float), - "origin": origin, - "venc": np.asarray(venc_new, dtype=float), - "rr": float(rr), - "sigma": sigma, - "tke_array": tke_array, - } + seg_name = "segmask" if "segmask" in g else "segmentation" if "segmentation" in g else None + + if "img_complex" in g: + img_ds = g["img_complex"] + if seg_name is not None: + segmask_ds = g[seg_name] + segmask_full = segmask_ds[:].astype(np.int16) + src_slices = _compute_spatial_bbox(segmask_full, pad=2) + segmask = segmask_ds[src_slices + (slice(None),) * (segmask_ds.ndim - 3)].astype(np.int16) + del segmask_full + else: + src_slices = tuple(slice(0, int(img_ds.shape[i])) for i in range(3)) + segmask = None + + img_complex = np.asarray(img_ds[src_slices + (slice(None),) * (img_ds.ndim - 3)]) + mag = np.abs(img_complex[..., 0]).astype(np.float32) + flow_raw = np.angle(img_complex[..., 1:4] * np.conj(img_complex[..., 0][..., None])).astype(np.float32) + sigma_raw = _sigma_from_complex(img_complex, VENC) + segmask_for_reorient = segmask if segmask is not None else np.zeros(mag.shape, dtype=np.int16) + + flow, mag_out, seg_r, venc_new, res_new = reorient( + mag, flow_raw, segmask_for_reorient, venc=VENC, resolution=resolution, + spatial_order=spatial_order, venc_order=venc_order, + target_spatial_order=target_spatial_order, + target_venc_order=target_venc_order, + return_velocity=True, + ) + sigma = _reorient_component_abs( + sigma_raw, + spatial_order=spatial_order, + target_spatial_order=target_spatial_order, + venc_order=venc_order, + target_venc_order=target_venc_order, + ).astype(np.float32) + return normalize_loaded_case( + flow=flow, + mag=mag_out, + segmentation=seg_r if segmask is not None else None, + resolution=np.asarray(res_new, dtype=float), + origin=origin, + venc=np.asarray(venc_new, dtype=float), + rr=float(rr), + sigma=sigma, + tke_array=None, + source_format="legacy_h5", + capabilities=LoaderCapabilities( + has_segmentation=segmask is not None, + has_tke=True, + has_complex_source=True, + supports_wss=True, + supports_plane_metrics=True, + ), + ) + + if "flow" in g and "mag" in g: + sigma = np.asarray(g["sigma"][:], dtype=np.float32) if "sigma" in g else None + tke_array = np.asarray(g["tke_array"][:], dtype=np.float32) if "tke_array" in g else None + segmentation = None if seg_name is None else np.asarray(g[seg_name][:], dtype=np.int16) + return normalize_loaded_case( + flow=np.asarray(g["flow"][:], dtype=np.float32), + mag=np.asarray(g["mag"][:], dtype=np.float32), + segmentation=segmentation, + resolution=resolution, + origin=origin, + venc=VENC, + rr=float(rr), + sigma=sigma, + tke_array=tke_array, + source_format="normalized_h5", + capabilities=LoaderCapabilities( + has_segmentation=segmentation is not None, + has_tke=tke_array is not None, + has_complex_source=False, + supports_wss=True, + supports_plane_metrics=True, + ), + ) + + raise ValueError(f"unsupported h5 layout: {path}") diff --git a/autoflow/algorithms/metrics.py b/autoflow/algorithms/metrics.py index f4636d9..7800fdf 100755 --- a/autoflow/algorithms/metrics.py +++ b/autoflow/algorithms/metrics.py @@ -276,9 +276,6 @@ def compute_derived_metrics(mask4d, flow, spacing, origin=(0, 0, 0), tube_radius=0.1, rho=1060.0, save_pixelwise=False, tke_array=None, sigma=None): mask4d = _ensure_mask4d(mask4d) - tke = compute_tke_metrics( - mask4d, spacing, origin=origin, tke_array=tke_array, sigma=sigma, rho=rho, - ) wss = compute_wss_metrics( mask4d, flow, spacing, origin=origin, smoothing_iteration=smoothing_iteration, @@ -287,26 +284,33 @@ def compute_derived_metrics(mask4d, flow, spacing, origin=(0, 0, 0), parabolic_fitting=parabolic_fitting, no_slip_condition=no_slip_condition, ) + tke = None + if tke_array is not None or sigma is not None: + tke = compute_tke_metrics( + mask4d, spacing, origin=origin, tke_array=tke_array, sigma=sigma, rho=rho, + ) spacing = np.asarray(spacing, dtype=float).reshape(3) origin = np.asarray(origin, dtype=float).reshape(3) result = { "wss_surfaces": wss["wss_surfaces"], "wss_volume": wss["wss_volume"], - "tke_volume": tke["tke_volume"], - "tke_array": tke["tke_array"], - "tke_peak": tke["tke_peak"], + "tke_volume": None if tke is None else tke["tke_volume"], + "tke_array": None if tke is None else tke["tke_array"], + "tke_peak": None if tke is None else tke["tke_peak"], "streamlines": [], "tube_radius": float(tube_radius), } if save_pixelwise: - result["pixelwise_export"] = { + pixelwise_export = { "wss": np.asarray(wss["wss_volume"], dtype=np.float32), - "tke": np.asarray(tke["tke_peak"], dtype=np.float32), - "tke_time": np.asarray(tke["tke_array"], dtype=np.float32), "spacing": np.asarray(spacing, dtype=np.float32), "origin": np.asarray(origin, dtype=np.float32), } + if tke is not None: + pixelwise_export["tke"] = np.asarray(tke["tke_peak"], dtype=np.float32) + pixelwise_export["tke_time"] = np.asarray(tke["tke_array"], dtype=np.float32) + result["pixelwise_export"] = pixelwise_export else: result["pixelwise_export"] = {} return result diff --git a/autoflow/case_types.py b/autoflow/case_types.py new file mode 100755 index 0000000..68ca80e --- /dev/null +++ b/autoflow/case_types.py @@ -0,0 +1,79 @@ +import copy +from dataclasses import dataclass, field +from typing import Any, Dict, Optional + +import numpy as np + + +@dataclass +class LoaderCapabilities: + has_segmentation: bool = False + has_tke: bool = False + has_complex_source: bool = False + supports_wss: bool = False + supports_plane_metrics: bool = False + + def to_dict(self): + return { + "has_segmentation": bool(self.has_segmentation), + "has_tke": bool(self.has_tke), + "has_complex_source": bool(self.has_complex_source), + "supports_wss": bool(self.supports_wss), + "supports_plane_metrics": bool(self.supports_plane_metrics), + } + + @staticmethod + def from_dict(d): + return LoaderCapabilities( + has_segmentation=bool(d.get("has_segmentation", False)), + has_tke=bool(d.get("has_tke", False)), + has_complex_source=bool(d.get("has_complex_source", False)), + supports_wss=bool(d.get("supports_wss", False)), + supports_plane_metrics=bool(d.get("supports_plane_metrics", False)), + ) + + +@dataclass +class LoadedCase: + mag: np.ndarray + flow: np.ndarray + resolution: np.ndarray + origin: np.ndarray + venc: np.ndarray + rr: float + segmentation: Optional[np.ndarray] = None + tke_array: Optional[np.ndarray] = None + sigma: Optional[np.ndarray] = None + metadata: Dict[str, Any] = field(default_factory=dict) + source_format: str = "" + source_group: Optional[str] = None + capabilities: LoaderCapabilities = field(default_factory=LoaderCapabilities) + + @property + def segmask(self): + return self.segmentation + + +@dataclass +class InputState: + source_format: str = "" + source_group: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + capabilities: LoaderCapabilities = field(default_factory=LoaderCapabilities) + + def to_dict(self): + return { + "source_format": self.source_format, + "source_group": self.source_group, + "metadata": copy.deepcopy(self.metadata), + "capabilities": self.capabilities.to_dict(), + } + + @staticmethod + def from_dict(d): + return InputState( + source_format=str(d.get("source_format", "")), + source_group=d.get("source_group"), + metadata=copy.deepcopy(d.get("metadata", {})), + capabilities=LoaderCapabilities.from_dict(d.get("capabilities", {})), + ) diff --git a/autoflow/core/__init__.py b/autoflow/core/__init__.py index e7ac33e..ce1f89d 100755 --- a/autoflow/core/__init__.py +++ b/autoflow/core/__init__.py @@ -1,4 +1,14 @@ -from .models import * -from .pipeline import PipelineEngine, StepResult +from .models import * # noqa: F401,F403 -__all__ = [name for name in globals() if not name.startswith("_")] +__all__ = [name for name in globals() if not name.startswith("_")] + ["PipelineEngine", "StepResult"] + + +def __getattr__(name): + if name in {"PipelineEngine", "StepResult"}: + from .pipeline import PipelineEngine, StepResult + + return { + "PipelineEngine": PipelineEngine, + "StepResult": StepResult, + }[name] + raise AttributeError(f"module 'autoflow.core' has no attribute {name!r}") diff --git a/autoflow/core/models.py b/autoflow/core/models.py index 2709cd5..ba0bf98 100755 --- a/autoflow/core/models.py +++ b/autoflow/core/models.py @@ -4,6 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple import numpy as np +from ..case_types import InputState, LoadedCase, LoaderCapabilities + class ObjectKind(Enum): SEGMENTATION = "Segmentation" @@ -296,6 +298,7 @@ class Workspace: plane_gen_params: PlaneGenerationParams = field(default_factory=PlaneGenerationParams) streamline_params: StreamlineParams = field(default_factory=StreamlineParams) derived_params: DerivedMetricsParams = field(default_factory=DerivedMetricsParams) + input_state: InputState = field(default_factory=InputState) resolution: np.ndarray = field(default_factory=lambda: np.array([1., 1., 1.])) origin: np.ndarray = field(default_factory=lambda: np.array([0., 0., 0.])) @@ -310,6 +313,8 @@ class Workspace: segmask_3d: Optional[np.ndarray] = None mag_raw: Optional[np.ndarray] = None + source_sigma: Optional[np.ndarray] = None + source_tke_array: Optional[np.ndarray] = None skeleton_points: Optional[np.ndarray] = None skeleton_mask: Optional[np.ndarray] = None @@ -349,6 +354,9 @@ def time_count(self): def has_flow(self): return self.flow_raw is not None + def has_segmentation(self): + return self.segmask_raw is not None + def unique_labels(self): if self.segmask_raw is None: return [] @@ -396,6 +404,7 @@ def reset_all(self): ("plane_gen_params", PlaneGenerationParams()), ("streamline_params", StreamlineParams()), ("derived_params", DerivedMetricsParams()), + ("input_state", InputState()), ]: setattr(self, attr, default) self.resolution = np.array([1., 1., 1.]) @@ -404,7 +413,7 @@ def reset_all(self): self.rr = 1000.0 for attr in ["segmask_raw", "segmask_labels", "segmask_binary", "segmask_3d", "skeleton_points", "skeleton_mask", "branch_labels", "flow_raw", - "streamline_seeds", "mag_raw"]: + "streamline_seeds", "mag_raw", "source_sigma", "source_tke_array"]: setattr(self, attr, None) self.graph = GraphData() self.centerline_paths = [] @@ -437,6 +446,7 @@ def arr(v): "plane_gen_params": self.plane_gen_params.to_dict(), "streamline_params": self.streamline_params.to_dict(), "derived_params": self.derived_params.to_dict(), + "input_state": self.input_state.to_dict(), "resolution": arr(self.resolution), "origin": arr(self.origin), "venc": arr(self.venc), @@ -446,6 +456,8 @@ def arr(v): "segmask_binary": arr(self.segmask_binary), "segmask_3d": arr(self.segmask_3d), "mag_raw": arr(self.mag_raw), + "source_sigma": arr(self.source_sigma), + "source_tke_array": arr(self.source_tke_array), "skeleton_points": arr(self.skeleton_points), "skeleton_mask": arr(self.skeleton_mask), "graph": {"points": arr(self.graph.points), "edges": arr(self.graph.edges)}, @@ -483,6 +495,7 @@ def restore_dict(self, d): self.plane_gen_params = PlaneGenerationParams.from_dict(d.get("plane_gen_params", {})) self.streamline_params = StreamlineParams.from_dict(d.get("streamline_params", {})) self.derived_params = DerivedMetricsParams.from_dict(d.get("derived_params", {})) + self.input_state = InputState.from_dict(d.get("input_state", {})) self.resolution = np.asarray(d.get("resolution", [1, 1, 1]), dtype=float) self.origin = np.array([0.0, 0.0, 0.0], dtype=float) self.venc = np.asarray(d.get("venc", [1, 1, 1]), dtype=float) @@ -497,6 +510,8 @@ def nparr(k, dt=np.float64): self.segmask_binary = None if d.get("segmask_binary") is None else np.asarray(d["segmask_binary"], dtype=bool) self.segmask_3d = None if d.get("segmask_3d") is None else np.asarray(d["segmask_3d"], dtype=bool) self.mag_raw = nparr("mag_raw") + self.source_sigma = nparr("source_sigma") + self.source_tke_array = nparr("source_tke_array") self.skeleton_points = nparr("skeleton_points") self.skeleton_mask = nparr("skeleton_mask") gd = d.get("graph", {}) diff --git a/autoflow/core/pipeline.py b/autoflow/core/pipeline.py index ffc6320..0f74bbd 100755 --- a/autoflow/core/pipeline.py +++ b/autoflow/core/pipeline.py @@ -26,6 +26,11 @@ def __init__(self, step, success=True, skipped=False, message="", outputs=None): class PipelineEngine: + def _missing_segmentation_message(self, ws, action): + if not ws.input_state.capabilities.has_segmentation: + return f"{action} skipped: input has no segmentation" + return f"{action} skipped: no segmentation available" + def _output_dir(self, ws): out_dir = getattr(ws.paths, "output_dir", "") or "" if out_dir: @@ -55,41 +60,46 @@ def load_data(self, ws, log): if not path: raise ValueError("data path is empty") data = load_h5_data(path) - flow = np.asarray(data["flow"], dtype=np.float32) - mag = np.asarray(data["mag"], dtype=np.float32) - seg = np.asarray(data["segmask"], dtype=np.int16) - if flow.ndim == 4 and flow.shape[-1] == 3: - flow = flow[..., np.newaxis, :] - if mag.ndim == 3: - mag = mag[..., np.newaxis] - if seg.ndim == 3: - seg = np.repeat(seg[..., np.newaxis], flow.shape[3], axis=3) - elif seg.ndim == 4 and seg.shape[3] == 1 and flow.shape[3] > 1: - seg = np.repeat(seg, flow.shape[3], axis=3) - if seg.shape[3] != flow.shape[3]: - raise ValueError(f"segmask time dimension {seg.shape[3]} != flow {flow.shape[3]}") + flow = np.asarray(data.flow, dtype=np.float32) + mag = np.asarray(data.mag, dtype=np.float32) + seg = None if data.segmentation is None else np.asarray(data.segmentation, dtype=np.int16) ws.segmask_raw = seg - ws.resolution = np.asarray(data["resolution"], dtype=float).reshape(3) - ws.origin = np.asarray(data.get("origin", [0.0, 0.0, 0.0]), dtype=float).reshape(3) - ws.venc = np.asarray(data["venc"], dtype=float).reshape(-1) - ws.rr = float(data.get("rr", 1000.0)) + ws.resolution = np.asarray(data.resolution, dtype=float).reshape(3) + ws.origin = np.asarray(data.origin, dtype=float).reshape(3) + ws.venc = np.asarray(data.venc, dtype=float).reshape(-1) + ws.rr = float(data.rr) ws.current_t = 0 ws.flow_raw = flow ws.mag_raw = mag - ws.derived.tke_array = np.asarray(data["tke_array"], dtype=np.float32) if "tke_array" in data else None + ws.input_state.source_format = str(data.source_format or "") + ws.input_state.source_group = data.source_group + ws.input_state.metadata = dict(data.metadata or {}) + ws.input_state.capabilities = data.capabilities + ws.source_sigma = None if data.sigma is None else np.asarray(data.sigma, dtype=np.float32) + ws.source_tke_array = None if data.tke_array is None else np.asarray(data.tke_array, dtype=np.float32) + ws.derived.tke_array = None + ws.derived.tke_volume = None + ws.derived.wss_surfaces = [] + ws.derived.wss_volume = None + ws.derived.pixelwise_export = {} ws.data_loaded = True - ws.remove_object_by_data_key("segmask_raw_surface") - ws.add_object(name="segmask_raw", kind=ObjectKind.SEGMENTATION, - data_key="segmask_raw_surface", visible=True, opacity=0.3, - scalars="label", cmap="tab10", dynamic=True, - show_scalar_bar=True, scalar_bar_title="Label") - - ulabels = ws.unique_labels() - msg = f"Loaded: segmask={ws.segmask_raw.shape} labels={ulabels} rr={ws.rr}" + for data_key in ["segmask_raw_surface", "segmask_pre_surface", "wss_surface_live", "tke_volume"]: + ws.remove_object_by_data_key(data_key) + if ws.segmask_raw is not None: + ws.add_object(name="segmask_raw", kind=ObjectKind.SEGMENTATION, + data_key="segmask_raw_surface", visible=True, opacity=0.3, + scalars="label", cmap="tab10", dynamic=True, + show_scalar_bar=True, scalar_bar_title="Label") + + seg_desc = "none" + if ws.segmask_raw is not None: + seg_desc = f"{ws.segmask_raw.shape} labels={ws.unique_labels()}" + msg = f"Loaded: segmask={seg_desc} rr={ws.rr}" msg += f" flow={ws.flow_raw.shape} mag={ws.mag_raw.shape}" msg += f" origin={ws.origin.tolist()}" + msg += f" caps={ws.input_state.capabilities.to_dict()}" log(msg) return msg @@ -122,6 +132,11 @@ def run_step(self, ws, step, log): return dispatch[step](ws) def _step_generate_skeleton(self, ws): + if ws.segmask_raw is None: + return StepResult( + StepId.GENERATE_SKELETON, True, True, + self._missing_segmentation_message(ws, "Skeleton"), + ) self.preprocess(ws) if ws.skeleton_params.remove_small_cc: from ..algorithms import remove_small_cc_from_binary_mask @@ -151,8 +166,15 @@ def _step_edit_skeleton(self, ws): return StepResult(StepId.EDIT_SKELETON, True, True, "Skeleton edit") def _step_generate_graph(self, ws): + if ws.segmask_raw is None: + return StepResult( + StepId.GENERATE_GRAPH, True, True, + self._missing_segmentation_message(ws, "Graph"), + ) if ws.skeleton_points is None or len(ws.skeleton_points) == 0: - self._step_generate_skeleton(ws) + skel_result = self._step_generate_skeleton(ws) + if skel_result.skipped or not skel_result.success: + return StepResult(StepId.GENERATE_GRAPH, skel_result.success, True, skel_result.message) graph = build_graph_from_points(ws.skeleton_points, ws.resolution) ws.graph = graph @@ -199,6 +221,8 @@ def _step_edit_graph(self, ws): def _compute_plane_metrics_internal(self, ws, save=True, use_multithread=False): if not ws.has_flow(): return [], {}, "Plane metrics skipped: no flow" + if ws.segmask_raw is None: + return [], {}, self._missing_segmentation_message(ws, "Plane metrics") if ws.segmask_binary is None: self.preprocess(ws) # Prefer the smoothed centerlines (better local tangents) but fall back @@ -259,8 +283,15 @@ def _save_planes_json(self, ws): return out_path def _step_generate_planes(self, ws): + if ws.segmask_raw is None: + return StepResult( + StepId.GENERATE_PLANES, True, True, + self._missing_segmentation_message(ws, "Planes"), + ) if ws.graph is None or len(ws.graph.points) == 0: - self._step_generate_graph(ws) + graph_result = self._step_generate_graph(ws) + if graph_result.skipped or not graph_result.success: + return StepResult(StepId.GENERATE_PLANES, graph_result.success, True, graph_result.message) if len(ws.centerline_paths) == 0: flow_for_orientation = None @@ -329,6 +360,11 @@ def _step_edit_planes(self, ws): return StepResult(StepId.EDIT_PLANES, True, True, "Plane edit") def _step_generate_streamlines(self, ws): + if ws.segmask_raw is None: + return StepResult( + StepId.GENERATE_STREAMLINES, True, True, + self._missing_segmentation_message(ws, "Streamlines"), + ) if ws.flow_raw is None or ws.segmask_3d is None: return StepResult(StepId.GENERATE_STREAMLINES, True, True, "Streamlines skipped: no flow or mask") self.preprocess(ws) @@ -355,6 +391,11 @@ def _step_generate_streamlines(self, ws): return StepResult(StepId.GENERATE_STREAMLINES, True, False, param_msg) def _step_plane_streamlines(self, ws): + if ws.segmask_raw is None: + return StepResult( + StepId.PLANE_STREAMLINES, True, True, + self._missing_segmentation_message(ws, "Plane streamlines"), + ) if ws.flow_raw is None or ws.segmask_3d is None: return StepResult(StepId.PLANE_STREAMLINES, True, True, "Plane streamlines skipped: no flow or mask") if len(ws.planes) == 0: @@ -376,8 +417,15 @@ def _step_plane_streamlines(self, ws): def _step_compute_plane_metrics(self, ws): if not ws.has_flow(): return StepResult(StepId.COMPUTE_PLANE_METRICS, True, True, "Plane metrics skipped: no flow") + if ws.segmask_raw is None: + return StepResult( + StepId.COMPUTE_PLANE_METRICS, True, True, + self._missing_segmentation_message(ws, "Plane metrics"), + ) if len(ws.planes) == 0: - self._step_generate_planes(ws) + plane_result = self._step_generate_planes(ws) + if plane_result.skipped or not plane_result.success: + return StepResult(StepId.COMPUTE_PLANE_METRICS, plane_result.success, True, plane_result.message) use_mt = getattr(ws.derived_params, "use_multithread", False) _, _, msg = self._compute_plane_metrics_internal(ws, save=True, use_multithread=use_mt) self._save_planes_json(ws) @@ -387,9 +435,15 @@ def _step_compute_plane_metrics(self, ws): def _step_compute_derived_metrics(self, ws): if not ws.has_flow(): return StepResult(StepId.COMPUTE_DERIVED_METRICS, True, True, "Derived metrics skipped: no flow") + if ws.segmask_raw is None: + return StepResult( + StepId.COMPUTE_DERIVED_METRICS, True, True, + self._missing_segmentation_message(ws, "Derived metrics"), + ) self.preprocess(ws) dp = ws.derived_params - loaded_tke = ws.derived.tke_array + source_tke = ws.source_tke_array + source_sigma = ws.source_sigma if ws.input_state.capabilities.has_complex_source else None result = compute_derived_metrics( flow=ws.flow_raw * ws.segmask_binary[..., None], mask4d=ws.segmask_binary, @@ -404,7 +458,8 @@ def _step_compute_derived_metrics(self, ws): tube_radius=dp.tube_radius, rho=dp.rho, save_pixelwise=False, - tke_array=loaded_tke, + tke_array=source_tke, + sigma=source_sigma, ) ws.derived.wss_surfaces = result["wss_surfaces"] ws.derived.wss_volume = result.get("wss_volume") @@ -420,10 +475,14 @@ def _step_compute_derived_metrics(self, ws): data_key="wss_surface_live", visible=False, opacity=1.0, scalars="wss", cmap="jet", clim=(0.0, wss_max if wss_max > 0 else 1.0), dynamic=True, show_scalar_bar=True, scalar_bar_title="WSS (Pa)") - ws.add_object(name="tke_volume", kind=ObjectKind.METRIC, - data_key="tke_volume", visible=False, opacity=0.5, - scalars="TKE", cmap="hot", clim=(0.0, tke_max if tke_max > 0 else 1.0), dynamic=True, - show_scalar_bar=True, scalar_bar_title="TKE (J/m³)") + has_tke = ws.derived.tke_array is not None or ws.derived.tke_volume is not None + if has_tke: + ws.add_object(name="tke_volume", kind=ObjectKind.METRIC, + data_key="tke_volume", visible=False, opacity=0.5, + scalars="TKE", cmap="hot", clim=(0.0, tke_max if tke_max > 0 else 1.0), dynamic=True, + show_scalar_bar=True, scalar_bar_title="TKE (J/m³)") msg = f"Derived: Nt={len(ws.derived.wss_surfaces)}" + if not has_tke: + msg += " tke=unavailable" ws.pipeline.mark_done(StepId.COMPUTE_DERIVED_METRICS) return StepResult(StepId.COMPUTE_DERIVED_METRICS, True, False, msg) diff --git a/autoflow/processing.py b/autoflow/processing.py index 9382e88..2056146 100755 --- a/autoflow/processing.py +++ b/autoflow/processing.py @@ -111,39 +111,45 @@ def process_single( pixelwise_result = {} if not skip_derived: - print("[7/7] Compute Derived Metrics (WSS/TKE)...") - dp = ws.derived_params - engine.preprocess(ws) - loaded_tke = ws.derived.tke_array - - derived = compute_derived_metrics( - flow=ws.flow_raw * ws.segmask_binary[..., None], - mask4d=ws.segmask_binary, - spacing=ws.resolution, - origin=ws.origin, - smoothing_iteration=dp.smoothing_iteration, - viscosity=dp.viscosity, - inward_distance=dp.inward_distance, - parabolic_fitting=dp.parabolic_fitting, - no_slip_condition=dp.no_slip_condition, - step_size=dp.step_size, - tube_radius=dp.tube_radius, - rho=dp.rho, - save_pixelwise=True, - tke_array=loaded_tke, - ) - ws.derived.wss_surfaces = derived["wss_surfaces"] - ws.derived.wss_volume = derived.get("wss_volume") - ws.derived.tke_volume = derived["tke_volume"] - ws.derived.tke_array = derived.get("tke_array") - ws.derived.pixelwise_export = derived.get("pixelwise_export", {}) - pixelwise_result = ws.derived.pixelwise_export - pixel_path = os.path.join(out_dir, "derived_metrics_pixelwise.npz") - if pixelwise_result: - np.savez_compressed(pixel_path, **pixelwise_result) - print(f" -> Saved pixelwise: {pixel_path}") - ws.pipeline.mark_done(StepId.COMPUTE_DERIVED_METRICS) - print(f" -> Derived: Nt={len(ws.derived.wss_surfaces)}") + if ws.segmask_raw is None: + print("[7/7] Skipped derived metrics (no segmentation)") + else: + print("[7/7] Compute Derived Metrics (WSS/TKE)...") + dp = ws.derived_params + engine.preprocess(ws) + source_tke = ws.source_tke_array + source_sigma = ws.source_sigma if ws.input_state.capabilities.has_complex_source else None + + derived = compute_derived_metrics( + flow=ws.flow_raw * ws.segmask_binary[..., None], + mask4d=ws.segmask_binary, + spacing=ws.resolution, + origin=ws.origin, + smoothing_iteration=dp.smoothing_iteration, + viscosity=dp.viscosity, + inward_distance=dp.inward_distance, + parabolic_fitting=dp.parabolic_fitting, + no_slip_condition=dp.no_slip_condition, + step_size=dp.step_size, + tube_radius=dp.tube_radius, + rho=dp.rho, + save_pixelwise=True, + tke_array=source_tke, + sigma=source_sigma, + ) + ws.derived.wss_surfaces = derived["wss_surfaces"] + ws.derived.wss_volume = derived.get("wss_volume") + ws.derived.tke_volume = derived["tke_volume"] + ws.derived.tke_array = derived.get("tke_array") + ws.derived.pixelwise_export = derived.get("pixelwise_export", {}) + pixelwise_result = ws.derived.pixelwise_export + pixel_path = os.path.join(out_dir, "derived_metrics_pixelwise.npz") + if pixelwise_result: + np.savez_compressed(pixel_path, **pixelwise_result) + print(f" -> Saved pixelwise: {pixel_path}") + ws.pipeline.mark_done(StepId.COMPUTE_DERIVED_METRICS) + tke_suffix = "" if ws.derived.tke_array is not None or ws.derived.tke_volume is not None else " tke=unavailable" + print(f" -> Derived: Nt={len(ws.derived.wss_surfaces)}{tke_suffix}") else: print("[7/7] Skipped derived metrics (WSS/TKE)") @@ -260,6 +266,9 @@ def process_single( "resolution": ws.resolution.tolist(), "origin": np.asarray(ws.origin, dtype=float).reshape(3).tolist(), "rr": ws.rr, + "source_format": ws.input_state.source_format, + "source_group": ws.input_state.source_group, + "capabilities": ws.input_state.capabilities.to_dict(), "total_time_sec": float(total_time_sec), "n_planes": len(ws.planes), "n_skeleton_pts": len(ws.skeleton_points) if ws.skeleton_points is not None else 0, diff --git a/tests/test_loader_capabilities.py b/tests/test_loader_capabilities.py new file mode 100755 index 0000000..d8607fc --- /dev/null +++ b/tests/test_loader_capabilities.py @@ -0,0 +1,222 @@ +from pathlib import Path + +import h5py +import numpy as np +import pytest + +import autoflow.algorithms.metrics as metrics_module +import autoflow.core.pipeline as pipeline_module +from autoflow.algorithms import load_h5_data, normalize_loaded_case +from autoflow.core.models import LoaderCapabilities, StepId, Workspace +from autoflow.core.pipeline import PipelineEngine + + +def _write_legacy_complex_h5(path: Path): + img_complex = np.ones((2, 2, 2, 1, 4), dtype=np.complex64) + img_complex[..., 0] = 1.0 + 0.0j + for idx, phase in enumerate((0.1, 0.2, 0.3), start=1): + img_complex[..., idx] = 0.5 * np.exp(1j * phase) + segmask = np.ones((2, 2, 2), dtype=np.int16) + + with h5py.File(path, "w") as f: + f["img_complex"] = img_complex + f["segmask"] = segmask + f["VENC"] = np.array([150.0, 160.0, 170.0], dtype=np.float32) + f["Resolution"] = np.array([1.1, 1.2, 1.3], dtype=np.float32) + f["RR"] = 920.0 + + +def _write_normalized_h5(path: Path): + flow = np.zeros((2, 2, 2, 3), dtype=np.float32) + flow[..., 0] = 15.0 + mag = np.ones((2, 2, 2), dtype=np.float32) + + with h5py.File(path, "w") as f: + f["flow"] = flow + f["mag"] = mag + f["VENC"] = np.array([120.0, 130.0, 140.0], dtype=np.float32) + f["Resolution"] = np.array([1.0, 1.5, 2.0], dtype=np.float32) + f["RR"] = 1000.0 + + +def test_load_h5_data_legacy_complex_is_normalized_but_tke_is_lazy(tmp_path): + path = tmp_path / "legacy_complex.h5" + _write_legacy_complex_h5(path) + + case = load_h5_data(str(path)) + + assert case.source_format == "legacy_h5" + assert case.flow.shape == (2, 2, 2, 1, 3) + assert case.mag.shape == (2, 2, 2, 1) + assert case.segmentation.shape == (2, 2, 2, 1) + assert case.sigma.shape == (2, 2, 2, 1, 3) + assert case.tke_array is None + assert case.capabilities.to_dict() == { + "has_segmentation": True, + "has_tke": True, + "has_complex_source": True, + "supports_wss": True, + "supports_plane_metrics": True, + } + + +def test_load_h5_data_normalized_flow_mag_only_keeps_optional_fields_empty(tmp_path): + path = tmp_path / "normalized_flow_only.h5" + _write_normalized_h5(path) + + case = load_h5_data(str(path)) + + assert case.source_format == "normalized_h5" + assert case.flow.shape == (2, 2, 2, 1, 3) + assert case.mag.shape == (2, 2, 2, 1) + assert case.segmentation is None + assert case.sigma is None + assert case.tke_array is None + assert case.capabilities.to_dict() == { + "has_segmentation": False, + "has_tke": False, + "has_complex_source": False, + "supports_wss": True, + "supports_plane_metrics": True, + } + + +def test_pipeline_load_data_keeps_tke_sources_out_of_derived_state(monkeypatch): + sigma = np.ones((2, 2, 2, 1, 3), dtype=np.float32) + case = normalize_loaded_case( + flow=np.zeros((2, 2, 2, 3), dtype=np.float32), + mag=np.ones((2, 2, 2), dtype=np.float32), + segmentation=np.ones((2, 2, 2), dtype=np.int16), + resolution=[1.0, 1.0, 1.0], + origin=[0.0, 0.0, 0.0], + venc=[150.0, 150.0, 150.0], + rr=1000.0, + sigma=sigma, + tke_array=None, + source_format="legacy_h5", + capabilities=LoaderCapabilities( + has_segmentation=True, + has_tke=True, + has_complex_source=True, + supports_wss=True, + supports_plane_metrics=True, + ), + ) + monkeypatch.setattr(pipeline_module, "load_h5_data", lambda path: case) + + ws = Workspace() + ws.paths.flow_path = "dummy.h5" + PipelineEngine().load_data(ws, lambda msg: None) + + assert ws.input_state.capabilities.has_tke is True + assert ws.source_sigma is not None + assert ws.source_tke_array is None + assert ws.derived.tke_array is None + assert ws.derived.tke_volume is None + assert all(obj.data_key != "tke_volume" for obj in ws.scene_objects.values()) + + +def test_pipeline_skips_segmentation_steps_when_input_has_no_segmentation(monkeypatch): + case = normalize_loaded_case( + flow=np.zeros((2, 2, 2, 3), dtype=np.float32), + mag=np.ones((2, 2, 2), dtype=np.float32), + segmentation=None, + resolution=[1.0, 1.0, 1.0], + origin=[0.0, 0.0, 0.0], + venc=[150.0, 150.0, 150.0], + rr=1000.0, + source_format="normalized_h5", + capabilities=LoaderCapabilities( + has_segmentation=False, + has_tke=False, + has_complex_source=False, + supports_wss=True, + supports_plane_metrics=True, + ), + ) + monkeypatch.setattr(pipeline_module, "load_h5_data", lambda path: case) + + ws = Workspace() + ws.paths.flow_path = "dummy.h5" + engine = PipelineEngine() + engine.load_data(ws, lambda msg: None) + + result = engine.run_step(ws, StepId.GENERATE_SKELETON, lambda msg: None) + + assert result.skipped is True + assert "no segmentation" in result.message + + +def test_pipeline_derived_metrics_does_not_use_sigma_without_complex_source(monkeypatch): + captured = {} + + def fake_compute_derived_metrics(**kwargs): + captured["sigma"] = kwargs.get("sigma") + captured["tke_array"] = kwargs.get("tke_array") + return { + "wss_surfaces": [object()], + "wss_volume": np.ones((2, 2, 2, 1), dtype=np.float32), + "tke_volume": None, + "tke_array": None, + "tke_peak": None, + "streamlines": [], + "tube_radius": kwargs["tube_radius"], + "pixelwise_export": {}, + } + + monkeypatch.setattr(pipeline_module, "compute_derived_metrics", fake_compute_derived_metrics) + + ws = Workspace() + ws.flow_raw = np.ones((2, 2, 2, 1, 3), dtype=np.float32) + ws.mag_raw = np.ones((2, 2, 2, 1), dtype=np.float32) + ws.segmask_raw = np.ones((2, 2, 2, 1), dtype=np.int16) + ws.resolution = np.array([1.0, 1.0, 1.0], dtype=float) + ws.origin = np.array([0.0, 0.0, 0.0], dtype=float) + ws.source_sigma = np.ones((2, 2, 2, 1, 3), dtype=np.float32) + ws.source_tke_array = None + ws.input_state.capabilities = LoaderCapabilities( + has_segmentation=True, + has_tke=False, + has_complex_source=False, + supports_wss=True, + supports_plane_metrics=True, + ) + + result = PipelineEngine()._step_compute_derived_metrics(ws) + + assert result.success is True + assert result.skipped is False + assert captured["sigma"] is None + assert captured["tke_array"] is None + assert ws.derived.tke_array is None + assert any(obj.data_key == "wss_surface_live" for obj in ws.scene_objects.values()) + assert all(obj.data_key != "tke_volume" for obj in ws.scene_objects.values()) + assert "tke=unavailable" in result.message + + +def test_compute_derived_metrics_allows_wss_only_when_tke_is_absent(monkeypatch): + def fake_compute_wss_metrics(*args, **kwargs): + return { + "wss_surfaces": [object()], + "wss_volume": np.ones((2, 2, 2, 1), dtype=np.float32), + } + + def fail_compute_tke_metrics(*args, **kwargs): + raise AssertionError("compute_tke_metrics should not be called without a TKE source") + + monkeypatch.setattr(metrics_module, "compute_wss_metrics", fake_compute_wss_metrics) + monkeypatch.setattr(metrics_module, "compute_tke_metrics", fail_compute_tke_metrics) + + result = metrics_module.compute_derived_metrics( + mask4d=np.ones((2, 2, 2, 1), dtype=bool), + flow=np.zeros((2, 2, 2, 1, 3), dtype=np.float32), + spacing=(1.0, 1.0, 1.0), + save_pixelwise=True, + tke_array=None, + sigma=None, + ) + + assert result["tke_array"] is None + assert result["tke_volume"] is None + assert result["tke_peak"] is None + assert set(result["pixelwise_export"]) == {"wss", "spacing", "origin"}