diff --git a/autoflow/api.py b/autoflow/api.py index 1e5d689..f3e4bbf 100755 --- a/autoflow/api.py +++ b/autoflow/api.py @@ -6,8 +6,9 @@ import numpy as np -from .models import Workspace -from .utils import collect_h5_files, process_single, resolve_reuse_plane_file +from .core.models import Workspace +from .plane_io import resolve_reuse_plane_file +from .processing import collect_h5_files, process_single DEFAULT_WSS_BAR_CFG = { "position_x": 0.75, diff --git a/autoflow/app.py b/autoflow/app.py index ce03680..ffe31cd 100755 --- a/autoflow/app.py +++ b/autoflow/app.py @@ -1,1908 +1,14 @@ -import json -import os -import sys -import time -import traceback -from functools import partial +"""Compatibility re-exports for the main GUI application.""" -import numpy as np -import pyvista as pv -from PyQt5 import QtCore, QtWidgets -from pyvistaqt import QtInteractor -from pyvista import _vtk +__all__ = ["MainWindow", "main"] -from .models import Workspace, StepId, ObjectKind -from .pipeline import PipelineEngine -from .viewer import SceneController -from .editors import SkeletonEditor, PlaneEditor -from .ortho_viewer import OrthoViewer -from .algorithms import compute_plane_metrics, apply_internal_consistency_to_metrics, compute_plane_metrics_multithread +def __getattr__(name): + if name in {"MainWindow", "main"}: + from .ui.app import MainWindow, main -def _parse_plane_index(data_key): - if not isinstance(data_key, str) or not data_key.startswith("plane_"): - return None - suffix = data_key[len("plane_"):] - if suffix.isdigit(): - return int(suffix) - return None - - -def _parse_path_index(data_key): - if not data_key.startswith("smooth_path_"): - return None - suffix = data_key[len("smooth_path_"):] - try: - return int(suffix) - except ValueError: - return None - - -class MainWindow(QtWidgets.QMainWindow): - def __init__(self): - super().__init__() - self.setWindowTitle("AutoFlow") - self.resize(1800, 980) - self.workspace = Workspace() - self.pipeline = PipelineEngine() - self.scene = None - self._play_timer = QtCore.QTimer(self) - self._play_timer.timeout.connect(self._on_play_tick) - self._edit_mode = None - self._edit_points = None - self._edit_edges = None - self._edit_selected_idx = None - self._edit_edge_mode = False - self._edit_edge_src_idx = None - self._edit_selected_edge_idx = None - self._edit_sel_edge_poly = None - self._edit_sel_edge_actor = None - self._edit_poly = None - self._edit_actor = None - self._edit_edge_poly = None - self._edit_edge_actor = None - self._edit_sel_poly = None - self._edit_sel_actor = None - self._edit_widget = None - self._edit_pick_enabled = False - self._vtk_left_click_obs_id = None - self._vtk_keypress_obs_id = None - self._vtk_point_picker = None - self._edit_overlay_dialog = None - self._edit_info_label = None - self._edit_status_label = None - self._edit_btn_edge = None - self._edit_original_points = None - self._edit_original_edges = None - self._plane_drag_active = False - self._plane_drag_index = None - self._plane_widget_initializing = False - self._plane_drag_metrics_dirty = False - self._selected_plane_index = -1 - self._plane_drag_timer = QtCore.QTimer(self) - self._plane_drag_timer.setSingleShot(True) - self._plane_drag_timer.timeout.connect(lambda: self._recompute_dragged_plane_metrics(persist=False)) - self._build_ui() - self._bind_scene() - self._esc_shortcut = QtWidgets.QShortcut(QtCore.Qt.Key_Escape, self) - self._esc_shortcut.setContext(QtCore.Qt.ApplicationShortcut) - self._esc_shortcut.activated.connect(self._force_exit_edit) - QtCore.QTimer.singleShot(0, self._setup_focus_behavior) - self._refresh_all() - - def _setup_focus_behavior(self): - try: - self.plotter.setFocusPolicy(QtCore.Qt.ClickFocus) - except Exception: - pass - - def _build_ui(self): - self._build_menu() - central = QtWidgets.QWidget() - self.setCentralWidget(central) - root = QtWidgets.QVBoxLayout(central) - root.setContentsMargins(6, 6, 6, 6) - root.setSpacing(6) - splitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal) - root.addWidget(splitter, 1) - left = QtWidgets.QWidget() - left_lay = QtWidgets.QVBoxLayout(left) - left_lay.setContentsMargins(0, 0, 0, 0) - left_lay.setSpacing(6) - self._build_browser(left_lay) - mid = QtWidgets.QWidget() - mid_lay = QtWidgets.QVBoxLayout(mid) - mid_lay.setContentsMargins(0, 0, 0, 0) - mid_lay.setSpacing(6) - self.plotter = QtInteractor(self) - self.plotter.setFocusPolicy(QtCore.Qt.ClickFocus) - mid_splitter = QtWidgets.QSplitter(QtCore.Qt.Vertical) - mid_splitter.addWidget(self.plotter) - step_and_params = QtWidgets.QWidget() - sp_lay = QtWidgets.QVBoxLayout(step_and_params) - sp_lay.setContentsMargins(0, 0, 0, 0) - sp_lay.setSpacing(4) - self._build_step_buttons(sp_lay) - scroll = QtWidgets.QScrollArea() - scroll.setWidgetResizable(True) - params_widget = QtWidgets.QWidget() - self.params_layout = QtWidgets.QVBoxLayout(params_widget) - self.params_layout.setContentsMargins(4, 4, 4, 4) - self._build_preprocess_params() - self._build_skeleton_params() - self._build_plane_params() - self._build_streamline_params() - self._build_derived_params() - self.params_layout.addStretch() - scroll.setWidget(params_widget) - sp_lay.addWidget(scroll, 1) - mid_splitter.addWidget(step_and_params) - mid_splitter.setStretchFactor(0, 4) - mid_splitter.setStretchFactor(1, 2) - mid_lay.addWidget(mid_splitter, 1) - right = QtWidgets.QWidget() - right_lay = QtWidgets.QVBoxLayout(right) - right_lay.setContentsMargins(0, 0, 0, 0) - self.ortho_viewer = OrthoViewer(self.workspace, self) - right_lay.addWidget(self.ortho_viewer) - splitter.addWidget(left) - splitter.addWidget(mid) - splitter.addWidget(right) - splitter.setSizes([300, 850, 450]) - splitter.setStretchFactor(0, 1) - splitter.setStretchFactor(1, 4) - splitter.setStretchFactor(2, 2) - bot = QtWidgets.QWidget() - bot_lay = QtWidgets.QVBoxLayout(bot) - bot_lay.setContentsMargins(0, 0, 0, 0) - bot_lay.setSpacing(4) - bot_splitter = QtWidgets.QSplitter(QtCore.Qt.Vertical) - root.addWidget(bot_splitter, 0) - timeline_w = QtWidgets.QWidget() - tl_lay = QtWidgets.QVBoxLayout(timeline_w) - tl_lay.setContentsMargins(0, 0, 0, 0) - self._build_timeline(tl_lay) - bot_splitter.addWidget(timeline_w) - sel_w = QtWidgets.QWidget() - sel_lay = QtWidgets.QVBoxLayout(sel_w) - sel_lay.setContentsMargins(0, 0, 0, 0) - self._build_selection_info(sel_lay) - bot_splitter.addWidget(sel_w) - log_w = QtWidgets.QWidget() - log_lay = QtWidgets.QVBoxLayout(log_w) - log_lay.setContentsMargins(0, 0, 0, 0) - self._build_log(log_lay) - bot_splitter.addWidget(log_w) - bot_splitter.setSizes([40, 80, 60]) - - def _build_browser(self, parent): - grp = QtWidgets.QGroupBox("Browser") - lay = QtWidgets.QVBoxLayout(grp) - self.tree_objects = QtWidgets.QTreeWidget() - self.tree_objects.setHeaderLabels(["Name", "Kind", "Visible"]) - self.tree_objects.setColumnWidth(0, 200) - self.tree_objects.itemSelectionChanged.connect(self._on_browser_select) - self.tree_objects.itemChanged.connect(self._on_tree_item_changed) - self.tree_objects.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) - self.tree_objects.customContextMenuRequested.connect(self._on_browser_ctx_menu) - lay.addWidget(self.tree_objects) - row = QtWidgets.QHBoxLayout() - self.btn_delete_obj = QtWidgets.QPushButton("Delete Selected") - self.btn_delete_obj.clicked.connect(self._on_delete_object) - row.addWidget(self.btn_delete_obj) - row.addStretch() - lay.addLayout(row) - parent.addWidget(grp, 3) - - def _build_step_buttons(self, parent): - grp = QtWidgets.QGroupBox("Steps") - gl = QtWidgets.QGridLayout(grp) - self.step_buttons = {} - for row_idx, steps in enumerate([StepId.top_row_steps(), StepId.bottom_row_steps(), StepId.extra_row_steps()]): - for i, s in enumerate(steps): - b = QtWidgets.QPushButton(s.label) - b.clicked.connect(partial(self._run_single_step, s)) - self.step_buttons[s] = b - gl.addWidget(b, row_idx, i) - btn_run_all = QtWidgets.QPushButton("▶▶ Run All (Generate → Metrics → WSS/TKE)") - btn_run_all.setStyleSheet("QPushButton { background-color: #2a6; color: white; font-weight: bold; padding: 4px; }") - btn_run_all.clicked.connect(self._run_all_pipeline) - gl.addWidget(btn_run_all, 3, 0, 1, 4) - parent.addWidget(grp, 0) - - def _build_preprocess_params(self): - return - - def _build_skeleton_params(self): - grp = QtWidgets.QGroupBox("Generate Skeleton Parameters") - fl = QtWidgets.QFormLayout(grp) - self.chk_remove_small_cc = QtWidgets.QCheckBox() - self.chk_remove_small_cc.setChecked(False) - self.edit_min_cc_volume = QtWidgets.QLineEdit("50.0") - self.chk_closing = QtWidgets.QCheckBox() - self.chk_closing.setChecked(True) - self.chk_opening = QtWidgets.QCheckBox() - self.chk_gaussian = QtWidgets.QCheckBox() - self.chk_gaussian.setChecked(True) - self.edit_gauss_sigma = QtWidgets.QLineEdit("0.5") - fl.addRow("Remove Small CC", self.chk_remove_small_cc) - fl.addRow(u"Min Volume (mm\u00b3)", self.edit_min_cc_volume) - fl.addRow("Closing", self.chk_closing) - fl.addRow("Opening", self.chk_opening) - fl.addRow("Gaussian", self.chk_gaussian) - fl.addRow(u"Gauss \u03c3", self.edit_gauss_sigma) - self.params_layout.addWidget(grp) - - def _build_plane_params(self): - grp = QtWidgets.QGroupBox("Generate Planes Parameters") - fl = QtWidgets.QFormLayout(grp) - self.radio_plane_by_distance = QtWidgets.QRadioButton("By Distance") - self.radio_plane_center = QtWidgets.QRadioButton("Center of Path") - self.radio_plane_center.setChecked(True) - mode_row = QtWidgets.QHBoxLayout() - mode_row.addWidget(self.radio_plane_by_distance) - mode_row.addWidget(self.radio_plane_center) - mode_w = QtWidgets.QWidget() - mode_w.setLayout(mode_row) - self.edit_plane_dist = QtWidgets.QLineEdit("20.0") - self.edit_plane_start = QtWidgets.QLineEdit("5.0") - self.edit_plane_end = QtWidgets.QLineEdit("0.0") - self.edit_plane_smooth_win = QtWidgets.QLineEdit("15") - self.edit_plane_smooth_poly = QtWidgets.QLineEdit("2") - self.edit_plane_inter_time = QtWidgets.QLineEdit("10") - fl.addRow("Plane Mode", mode_w) - fl.addRow("Cross-section Distance (mm)", self.edit_plane_dist) - fl.addRow("Start Distance (mm)", self.edit_plane_start) - fl.addRow("End Distance (mm)", self.edit_plane_end) - fl.addRow("SavGol Window", self.edit_plane_smooth_win) - fl.addRow("SavGol Polyorder", self.edit_plane_smooth_poly) - fl.addRow("Inter-time", self.edit_plane_inter_time) - self.params_layout.addWidget(grp) - - def _build_streamline_params(self): - grp = QtWidgets.QGroupBox("Streamline Parameters") - fl = QtWidgets.QFormLayout(grp) - self.edit_sl_ratio = QtWidgets.QLineEdit("0.02") - self.edit_sl_maxsteps = QtWidgets.QLineEdit("2000") - self.edit_sl_terminal = QtWidgets.QLineEdit("0.01") - fl.addRow("Seed Ratio", self.edit_sl_ratio) - fl.addRow("Max Steps", self.edit_sl_maxsteps) - fl.addRow("Terminal Speed", self.edit_sl_terminal) - self.params_layout.addWidget(grp) - - def _build_derived_params(self): - grp_wss = QtWidgets.QGroupBox("WSS Parameters") - fl_wss = QtWidgets.QFormLayout(grp_wss) - self.edit_dm_smoothing = QtWidgets.QLineEdit("200") - self.edit_dm_viscosity = QtWidgets.QLineEdit("4.0") - self.edit_dm_inward = QtWidgets.QLineEdit("0.6") - self.chk_dm_parabolic = QtWidgets.QCheckBox() - self.chk_dm_parabolic.setChecked(True) - self.chk_dm_noslip = QtWidgets.QCheckBox() - self.chk_dm_noslip.setChecked(True) - fl_wss.addRow("Smoothing Iterations", self.edit_dm_smoothing) - fl_wss.addRow(u"Viscosity (mPa\u00b7s)", self.edit_dm_viscosity) - fl_wss.addRow("Inward Distance (mm)", self.edit_dm_inward) - fl_wss.addRow("Parabolic Fitting", self.chk_dm_parabolic) - fl_wss.addRow("No-Slip Condition", self.chk_dm_noslip) - self.params_layout.addWidget(grp_wss) - - grp_tke = QtWidgets.QGroupBox("TKE / Flow Parameters") - fl_tke = QtWidgets.QFormLayout(grp_tke) - self.edit_dm_rho = QtWidgets.QLineEdit("1060.0") - self.edit_dm_stepsize = QtWidgets.QLineEdit("5") - self.edit_dm_tube = QtWidgets.QLineEdit("0.1") - self.chk_dm_multithread = QtWidgets.QCheckBox() - self.chk_dm_multithread.setChecked(False) - fl_tke.addRow(u"Density \u03c1 (kg/m\u00b3)", self.edit_dm_rho) - fl_tke.addRow("Step Size", self.edit_dm_stepsize) - fl_tke.addRow("Tube Radius", self.edit_dm_tube) - fl_tke.addRow("Multi-thread Metrics", self.chk_dm_multithread) - self.params_layout.addWidget(grp_tke) - - def _build_timeline(self, parent): - grp = QtWidgets.QGroupBox("Timeline") - tl = QtWidgets.QHBoxLayout(grp) - self.btn_prev = QtWidgets.QPushButton(u"\u25c0") - self.btn_prev.clicked.connect(self._on_prev_frame) - self.btn_play = QtWidgets.QPushButton(u"\u25b6 Play") - self.btn_play.clicked.connect(self._on_play) - self.btn_pause = QtWidgets.QPushButton(u"\u23f8 Pause") - self.btn_pause.clicked.connect(self._on_pause) - self.btn_next = QtWidgets.QPushButton(u"\u25b6") - self.btn_next.clicked.connect(self._on_next_frame) - self.slider_t = QtWidgets.QSlider(QtCore.Qt.Horizontal) - self.slider_t.setRange(0, 0) - self.slider_t.valueChanged.connect(self._on_t_changed) - self.lab_t = QtWidgets.QLabel("0") - self.spin_interval = QtWidgets.QSpinBox() - self.spin_interval.setRange(10, 2000) - self.spin_interval.setValue(120) - self.spin_interval.setSuffix(" ms") - for w in [self.btn_prev, self.btn_play, self.btn_pause, self.btn_next]: - tl.addWidget(w) - tl.addWidget(self.slider_t, 1) - tl.addWidget(self.lab_t) - tl.addWidget(self.spin_interval) - parent.addWidget(grp) - - def _build_selection_info(self, parent): - grp = QtWidgets.QGroupBox("Selection") - lay = QtWidgets.QHBoxLayout(grp) - box_plane = QtWidgets.QGroupBox("Plane") - lay_plane = QtWidgets.QVBoxLayout(box_plane) - self.text_plane_info = QtWidgets.QPlainTextEdit() - self.text_plane_info.setReadOnly(True) - self.text_plane_info.setMaximumHeight(82) - lay_plane.addWidget(self.text_plane_info) - box_path = QtWidgets.QGroupBox("Path") - lay_path = QtWidgets.QVBoxLayout(box_path) - self.text_path_info = QtWidgets.QPlainTextEdit() - self.text_path_info.setReadOnly(True) - self.text_path_info.setMaximumHeight(82) - lay_path.addWidget(self.text_path_info) - lay.addWidget(box_plane, 1) - lay.addWidget(box_path, 1) - parent.addWidget(grp) - - def _build_log(self, parent): - grp = QtWidgets.QGroupBox("Log") - ll = QtWidgets.QVBoxLayout(grp) - self.console = QtWidgets.QTextEdit() - self.console.setReadOnly(True) - self.console.setMaximumHeight(80) - ll.addWidget(self.console) - parent.addWidget(grp) - - def _build_menu(self): - mb = self.menuBar() - mf = mb.addMenu("File") - for label, slot in [("Open Data", self._on_open_data), ("Clear Workspace", self._on_close_workspace), ("Exit", self.close)]: - a = QtWidgets.QAction(label, self) - a.triggered.connect(slot) - mf.addAction(a) - mv = mb.addMenu("View") - for label, slot in [("Reset Camera", lambda: self.scene.reset_camera()), ("Toggle Axes", lambda: self.scene.toggle_axes()), - ("White BG", lambda: self.scene.set_background("white")), ("Dark BG", lambda: self.scene.set_background("#202124"))]: - a = QtWidgets.QAction(label, self) - a.triggered.connect(slot) - mv.addAction(a) - - def _bind_scene(self): - self.scene = SceneController(self.plotter, self.workspace, self.log) - self.scene.initialize() - self.scene.enable_plane_picking(self._on_3d_plane_picked) - self.scene.enable_path_picking(self._on_3d_path_picked) - - def _on_3d_plane_picked(self, uid, plane_idx): - if self._edit_mode is not None: - return - if uid is None or plane_idx is None: - self.workspace.selected_path_index = -1 - self._selected_plane_index = -1 - self._clear_plane_drag_widgets() - self.scene.highlight_plane(None) - self.scene.highlight_path(None) - self.scene.show_forks_for_path(-1) - self.ortho_viewer.set_selected_plane(None) - self._clear_browser_selection() - self._set_plane_info_text("") - self._set_path_info_text("") - return - self.workspace.selected_path_index = -1 - self._selected_plane_index = int(plane_idx) - self.scene.highlight_path(None) - self.scene.show_forks_for_path(-1) - self.scene.highlight_plane(uid) - self.ortho_viewer.set_selected_plane(int(plane_idx)) - self._select_browser_item_by_uid(uid) - self._activate_plane_drag_widgets(int(plane_idx)) - self._set_path_info_text("") - self._log_selected_plane_metric(int(plane_idx)) - - def _on_3d_path_picked(self, uid, path_idx): - if self._edit_mode is not None: - return - if uid is None or path_idx is None: - self.workspace.selected_path_index = -1 - self._selected_plane_index = -1 - self._clear_plane_drag_widgets() - self.scene.highlight_plane(None) - self.scene.highlight_path(None) - self.scene.show_forks_for_path(-1) - self.ortho_viewer.set_selected_plane(None) - self._clear_browser_selection() - self._set_plane_info_text("") - self._set_path_info_text("") - return - self.workspace.selected_path_index = int(path_idx) - self._selected_plane_index = -1 - self._clear_plane_drag_widgets() - self.scene.highlight_plane(None) - self.scene.highlight_path(uid) - self.scene.show_forks_for_path(int(path_idx)) - self.ortho_viewer.set_selected_plane(None) - self._select_browser_item_by_uid(uid) - self._set_plane_info_text("") - self._log_selected_path_info(int(path_idx)) - - def _find_uid_by_data_key(self, data_key): - for uid, obj in self.workspace.scene_objects.items(): - if obj.data_key == data_key: - return uid - return None - - def _plane_widget_distance(self): - spacing = self._get_spacing_xyz_from_resolution() - return max(5.0, float(np.mean(spacing)) * 8.0) - - def _clear_plane_drag_widgets(self): - self._plane_drag_timer.stop() - self._plane_drag_active = False - self._plane_drag_index = None - self._plane_widget_initializing = False - self._plane_drag_metrics_dirty = False - if self._edit_mode is not None: - return - try: - if hasattr(self.plotter, "clear_sphere_widgets"): - self.plotter.clear_sphere_widgets() - except Exception: - pass - - def _update_plane_from_drag(self, plane_idx, center=None, normal=None): - if self._plane_widget_initializing: - return - if not (0 <= int(plane_idx) < len(self.workspace.planes)): - return - plane = self.workspace.planes[int(plane_idx)] - tol = max(1e-4, float(np.mean(self._get_spacing_xyz_from_resolution())) * 1e-3) - changed = False - if center is not None: - c = np.asarray(center, dtype=float).reshape(3) - if np.linalg.norm(c - np.asarray(plane.center, dtype=float).reshape(3)) > tol: - plane.center = c - changed = True - if normal is not None: - n = np.asarray(normal, dtype=float).reshape(3) - if np.linalg.norm(n) > 1e-12: - new_normal = n / np.linalg.norm(n) - old_normal = np.asarray(plane.normal, dtype=float).reshape(3) - if min(np.linalg.norm(new_normal - old_normal), np.linalg.norm(new_normal + old_normal)) > 1e-5: - plane.normal = new_normal - changed = True - if not changed: - return - self.scene.invalidate_cache("plane_") - uid = self._find_uid_by_data_key(f"plane_{int(plane_idx)}") - if uid is not None: - obj = self.workspace.scene_objects.get(uid) - if obj is not None: - self.scene.readd_object(obj) - self.scene.highlight_plane(uid) - else: - self.scene.sync_from_workspace() - self._selected_plane_index = int(plane_idx) - self.ortho_viewer._selected_plane_idx = int(plane_idx) - self.ortho_viewer.refresh() - self._plane_drag_index = int(plane_idx) - self._plane_drag_metrics_dirty = True - self._plane_drag_timer.start(250) - try: - self.plotter.render() - except Exception: - pass - - def _persist_plane_outputs(self): - out_dir = self.pipeline._output_dir(self.workspace) - metrics = self.workspace.derived.plane_metrics - qc = self.workspace.derived.plane_qc - if metrics and len(metrics) == len(self.workspace.planes): - with open(os.path.join(out_dir, "plane_metrics.json"), "w", encoding="utf-8") as f: - json.dump(metrics, f, ensure_ascii=False, indent=2) - if qc: - with open(os.path.join(out_dir, "plane_qc.json"), "w", encoding="utf-8") as f: - json.dump(qc, f, ensure_ascii=False, indent=2) - try: - self.pipeline._save_planes_json(self.workspace) - except Exception: - pass - - def _finalize_plane_drag(self, plane_idx): - if not (0 <= int(plane_idx) < len(self.workspace.planes)): - return - self._plane_drag_timer.stop() - self._plane_drag_index = int(plane_idx) - if self.workspace.flow_raw is None or self.workspace.segmask_binary is None: - self._persist_plane_outputs() - self._plane_drag_metrics_dirty = False - return - if self._plane_drag_metrics_dirty or len(self.workspace.derived.plane_metrics) != len(self.workspace.planes): - self._recompute_dragged_plane_metrics(persist=True) - else: - self._persist_plane_outputs() - - def _activate_plane_drag_widgets(self, plane_idx): - if self._edit_mode is not None or not (0 <= int(plane_idx) < len(self.workspace.planes)): - return - if self._plane_drag_active and int(self._plane_drag_index) == int(plane_idx): - return - self._clear_plane_drag_widgets() - plane = self.workspace.planes[int(plane_idx)] - center = np.asarray(plane.center, dtype=float).reshape(3) - normal = np.asarray(plane.normal, dtype=float).reshape(3) - if np.linalg.norm(normal) <= 1e-12: - normal = np.array([1.0, 0.0, 0.0], dtype=float) - normal = normal / np.linalg.norm(normal) - tip = center + normal * self._plane_widget_distance() - radius = self._edit_widget_radius() - - def _center_cb(new_center): - self._update_plane_from_drag(plane_idx, center=new_center) - - def _normal_cb(new_tip): - c = np.asarray(self.workspace.planes[int(plane_idx)].center, dtype=float).reshape(3) - tip_now = np.asarray(new_tip, dtype=float).reshape(3) - self._update_plane_from_drag(plane_idx, normal=(tip_now - c)) - - def _end_cb(_widget, _event): - self._finalize_plane_drag(plane_idx) - - try: - self._plane_widget_initializing = True - center_widget = self.plotter.add_sphere_widget( - callback=_center_cb, - center=tuple(center.tolist()), - radius=radius, - color="cyan", - interaction_event="always", - ) - normal_widget = self.plotter.add_sphere_widget( - callback=_normal_cb, - center=tuple(tip.tolist()), - radius=radius, - color="orange", - interaction_event="always", - ) - center_widget.AddObserver(_vtk.vtkCommand.EndInteractionEvent, _end_cb) - normal_widget.AddObserver(_vtk.vtkCommand.EndInteractionEvent, _end_cb) - self._plane_drag_active = True - self._plane_drag_index = int(plane_idx) - except Exception as e: - self._plane_drag_active = False - self._plane_drag_index = None - self.log(f"Plane drag widget error: {type(e).__name__}: {e}") - finally: - self._plane_widget_initializing = False - - def _recompute_dragged_plane_metrics(self, persist=False): - if self._plane_drag_index is None or not (0 <= int(self._plane_drag_index) < len(self.workspace.planes)): - return - if self.workspace.flow_raw is None or self.workspace.segmask_binary is None: - self.ortho_viewer.refresh() - if persist: - self._persist_plane_outputs() - self._plane_drag_metrics_dirty = False - return - plane_idx = int(self._plane_drag_index) - try: - if len(self.workspace.derived.plane_metrics) != len(self.workspace.planes): - self.pipeline._compute_plane_metrics_internal(self.workspace, save=persist) - self.scene.invalidate_cache("plane_") - self.scene.sync_from_workspace() - else: - paths_for_tangent = ( - self.workspace.centerline_paths_smooth - if len(self.workspace.centerline_paths_smooth) > 0 - else self.workspace.centerline_paths - ) - partial_metrics = compute_plane_metrics( - self.workspace.flow_raw, - self.workspace.segmask_binary, - self.workspace.resolution, - self.workspace.origin, - [self.workspace.planes[plane_idx]], - RR=self.workspace.rr, - branch_labels_3d=self.workspace.branch_labels, - path_info=self.workspace.path_info, - forks=self.workspace.forks, - paths=paths_for_tangent, - return_qc=False, - ) - if partial_metrics: - metrics = [dict(m) for m in self.workspace.derived.plane_metrics] - metrics[plane_idx] = dict(partial_metrics[0]) - metrics, qc = apply_internal_consistency_to_metrics(metrics, path_info=self.workspace.path_info, forks=self.workspace.forks) - self.workspace.derived.plane_metrics = metrics - self.workspace.derived.plane_qc = qc - for i, metric in enumerate(metrics): - if i < len(self.workspace.planes): - self.workspace.planes[i].metrics = dict(metric) - if persist: - self._persist_plane_outputs() - if persist: - self._persist_plane_outputs() - self._selected_plane_index = plane_idx - self.ortho_viewer._selected_plane_idx = plane_idx - self.ortho_viewer.refresh() - self._log_selected_plane_metric(plane_idx) - self._plane_drag_metrics_dirty = False - except Exception as e: - self.log(f"Plane metric update error: {type(e).__name__}: {e}") - self.log(traceback.format_exc()) - - def _log_selected_plane_metric(self, plane_idx): - if not (0 <= int(plane_idx) < len(self.workspace.planes)): - self._set_plane_info_text("") - return - plane = self.workspace.planes[int(plane_idx)] - metric = getattr(plane, "metrics", {}) or {} - t = int(np.clip(self.workspace.current_t, 0, max(0, self.workspace.time_count() - 1))) - fr = metric.get("flowrate_mL_s", []) - ar = metric.get("area_mm2", []) - mv = metric.get("meanv_cm_s_t", []) - flow_t = float(fr[t]) if len(fr) > t else 0.0 - area_t = float(ar[t]) if len(ar) > t else 0.0 - meanv_t = float(mv[t]) if len(mv) > t else float(metric.get("meanv_cm_s", 0.0)) - path_dir = metric.get("path_direction", "") - header = f"Plane {int(plane_idx)} | Path {int(metric.get('path_index', plane.path_index))}" - if path_dir: - header += f" {path_dir}" - text_block = ( - f"{header}\n" - f"t={t} Flow Rate={flow_t:.4f} mL/s Area={area_t:.3f} mm^2 Mean Velocity={meanv_t:.3f} cm/s\n" - f"Peak Velocity={float(metric.get('peakv_cm_s', 0.0)):.3f} cm/s Net Flow={float(metric.get('netflow_mL_beat', 0.0)):.4f} mL/beat IC={float(metric.get('path_ic', 1.0)):.3f}" - ) - self._set_plane_info_text(text_block) - - def _log_selected_path_info(self, path_idx): - if not (0 <= int(path_idx) < len(self.workspace.path_info)): - self._set_path_info_text("") - return - info = self.workspace.path_info[int(path_idx)] - incoming = [int(x) for x in info.get("incoming_path_ids", [])] - outgoing = [int(x) for x in info.get("outgoing_path_ids", [])] - forks = [] - for fork in self.workspace.forks: - if int(path_idx) in fork.get("left", []) or int(path_idx) in fork.get("right", []): - forks.append(f"node={int(fork.get('node', -1))} L={fork.get('left', [])} R={fork.get('right', [])}") - fork_txt = " ; ".join(forks) if forks else "none" - text_block = ( - f"Path {int(path_idx)} | dir={info.get('direction_text', '')}\n" - f"start_node={int(info.get('start_node', -1))} end_node={int(info.get('end_node', -1))}\n" - f"incoming: {incoming if incoming else 'none'} outgoing: {outgoing if outgoing else 'none'}\n" - f"forks: {fork_txt}" - ) - self._set_path_info_text(text_block) - - def _select_browser_item_by_uid(self, uid): - self.tree_objects.blockSignals(True) - for i in range(self.tree_objects.topLevelItemCount()): - top = self.tree_objects.topLevelItem(i) - for j in range(top.childCount()): - child = top.child(j) - if child.data(0, QtCore.Qt.UserRole) == uid: - self.tree_objects.setCurrentItem(child) - self.tree_objects.blockSignals(False) - return - self.tree_objects.blockSignals(False) - - def _clear_browser_selection(self): - self.tree_objects.blockSignals(True) - self.tree_objects.clearSelection() - self.tree_objects.blockSignals(False) - - def _set_plane_info_text(self, text): - msg = str(text).strip() if text else "No plane selected." - self.text_plane_info.setPlainText(msg) - - def _set_path_info_text(self, text): - msg = str(text).strip() if text else "No path selected." - self.text_path_info.setPlainText(msg) - - def _refresh_selection_info(self): - if not (0 <= int(self._selected_plane_index) < len(self.workspace.planes)): - self._selected_plane_index = -1 - self._set_plane_info_text("") - else: - self._log_selected_plane_metric(int(self._selected_plane_index)) - path_idx = int(getattr(self.workspace, "selected_path_index", -1)) - if not (0 <= path_idx < len(self.workspace.path_info)): - self.workspace.selected_path_index = -1 - self._set_path_info_text("") - else: - self._log_selected_path_info(path_idx) - - def log(self, text): - self.console.append(str(text)) - - def _float_from_text(self, text, default=0.0): - try: - return float(text) - except Exception: - return default - - def _int_from_text(self, text, default=0): - try: - return int(text) - except Exception: - return default - - def _parse_int_list(self, text): - r = [] - for tok in text.replace(";", ",").split(","): - tok = tok.strip() - if tok: - try: - r.append(int(tok)) - except ValueError: - pass - return r - - def _sync_params_to_ws(self): - ws = self.workspace - ws.skeleton_params.remove_small_cc = self.chk_remove_small_cc.isChecked() - ws.skeleton_params.min_cc_volume_mm3 = self._float_from_text(self.edit_min_cc_volume.text(), 50.0) - ws.skeleton_params.do_closing = self.chk_closing.isChecked() - ws.skeleton_params.do_opening = self.chk_opening.isChecked() - ws.skeleton_params.gaussian_enabled = self.chk_gaussian.isChecked() - ws.skeleton_params.gaussian_sigma = self._float_from_text(self.edit_gauss_sigma.text(), 0.5) - ws.plane_gen_params.use_center_plane = self.radio_plane_center.isChecked() - ws.plane_gen_params.cross_section_distance = self._float_from_text(self.edit_plane_dist.text(), 20.0) - ws.plane_gen_params.start_distance = self._float_from_text(self.edit_plane_start.text(), 5.0) - ws.plane_gen_params.end_distance = self._float_from_text(self.edit_plane_end.text(), 0.0) - ws.plane_gen_params.smoothing_window = self._int_from_text(self.edit_plane_smooth_win.text(), 15) - ws.plane_gen_params.smoothing_polyorder = self._int_from_text(self.edit_plane_smooth_poly.text(), 3) - ws.plane_gen_params.inter_time = self._int_from_text(self.edit_plane_inter_time.text(), 10) - ws.streamline_params.seed_ratio = min(max(self._float_from_text(self.edit_sl_ratio.text(), 0.02), 0.0001), 1.0) - ws.streamline_params.max_steps = min(max(self._int_from_text(self.edit_sl_maxsteps.text(), 2000), 1), 200000) - ws.streamline_params.terminal_speed = min(max(self._float_from_text(self.edit_sl_terminal.text(), 0.01), 0.0), 1e6) - ws.streamline_params.min_seeds = 50 - ws.derived_params.smoothing_iteration = max(self._int_from_text(self.edit_dm_smoothing.text(), 200), 0) - ws.derived_params.viscosity = max(self._float_from_text(self.edit_dm_viscosity.text(), 4.0), 0.0) - ws.derived_params.inward_distance = max(self._float_from_text(self.edit_dm_inward.text(), 0.6), 0.01) - ws.derived_params.parabolic_fitting = self.chk_dm_parabolic.isChecked() - ws.derived_params.no_slip_condition = self.chk_dm_noslip.isChecked() - ws.derived_params.rho = max(self._float_from_text(self.edit_dm_rho.text(), 1060.0), 1.0) - ws.derived_params.step_size = max(self._int_from_text(self.edit_dm_stepsize.text(), 5), 1) - ws.derived_params.tube_radius = max(self._float_from_text(self.edit_dm_tube.text(), 0.1), 0.0) - ws.derived_params.use_multithread = self.chk_dm_multithread.isChecked() - - def _sync_params_to_ui(self): - ws = self.workspace - self.chk_remove_small_cc.setChecked(ws.skeleton_params.remove_small_cc) - self.edit_min_cc_volume.setText(str(ws.skeleton_params.min_cc_volume_mm3)) - self.chk_closing.setChecked(ws.skeleton_params.do_closing) - self.chk_opening.setChecked(ws.skeleton_params.do_opening) - self.chk_gaussian.setChecked(ws.skeleton_params.gaussian_enabled) - self.edit_gauss_sigma.setText(str(ws.skeleton_params.gaussian_sigma)) - self.radio_plane_center.setChecked(ws.plane_gen_params.use_center_plane) - self.radio_plane_by_distance.setChecked(not ws.plane_gen_params.use_center_plane) - self.edit_plane_dist.setText(str(ws.plane_gen_params.cross_section_distance)) - self.edit_plane_start.setText(str(ws.plane_gen_params.start_distance)) - self.edit_plane_end.setText(str(ws.plane_gen_params.end_distance)) - self.edit_plane_smooth_win.setText(str(ws.plane_gen_params.smoothing_window)) - self.edit_plane_smooth_poly.setText(str(ws.plane_gen_params.smoothing_polyorder)) - self.edit_plane_inter_time.setText(str(ws.plane_gen_params.inter_time)) - self.edit_sl_ratio.setText(str(ws.streamline_params.seed_ratio)) - self.edit_sl_maxsteps.setText(str(ws.streamline_params.max_steps)) - self.edit_sl_terminal.setText(str(ws.streamline_params.terminal_speed)) - self.edit_dm_smoothing.setText(str(ws.derived_params.smoothing_iteration)) - self.edit_dm_viscosity.setText(str(ws.derived_params.viscosity)) - self.edit_dm_inward.setText(str(ws.derived_params.inward_distance)) - self.chk_dm_parabolic.setChecked(ws.derived_params.parabolic_fitting) - self.chk_dm_noslip.setChecked(ws.derived_params.no_slip_condition) - self.edit_dm_rho.setText(str(ws.derived_params.rho)) - self.edit_dm_stepsize.setText(str(ws.derived_params.step_size)) - self.edit_dm_tube.setText(str(ws.derived_params.tube_radius)) - self.chk_dm_multithread.setChecked(ws.derived_params.use_multithread) - - def _rebuild_plane_objects(self): - self._clear_plane_drag_widgets() - ws = self.workspace - ws.remove_objects_by_prefix("plane_") - for i in range(len(ws.planes)): - ws.add_object(name=f"Plane {i}", kind=ObjectKind.PLANE, - data_key=f"plane_{i}", visible=True, opacity=0.6, - color="yellow", line_width=2) - self.scene.invalidate_cache("plane_") - self.scene.sync_from_workspace() - - def _refresh_all(self): - self._refresh_browser() - self._refresh_timeline() - self._sync_params_to_ui() - self._refresh_selection_info() - self._refresh_scene() - - def _refresh_browser(self): - self.tree_objects.blockSignals(True) - self.tree_objects.clear() - groups = {} - for obj in self.workspace.scene_objects.values(): - if obj.data_key == "branch_surface": - continue - kn = obj.kind.value - if kn not in groups: - top = QtWidgets.QTreeWidgetItem([kn, "", ""]) - top.setFlags(top.flags() | QtCore.Qt.ItemIsUserCheckable) - top.setFlags(top.flags() & ~QtCore.Qt.ItemIsSelectable) - top.setCheckState(0, QtCore.Qt.Checked) - groups[kn] = top - self.tree_objects.addTopLevelItem(top) - it = QtWidgets.QTreeWidgetItem([obj.name, obj.kind.value, ""]) - it.setData(0, QtCore.Qt.UserRole, obj.uid) - it.setFlags(it.flags() | QtCore.Qt.ItemIsUserCheckable) - it.setCheckState(0, QtCore.Qt.Checked if obj.visible else QtCore.Qt.Unchecked) - groups[kn].addChild(it) - for kn, top in groups.items(): - vis_count = sum(1 for i in range(top.childCount()) if top.child(i).checkState(0) == QtCore.Qt.Checked) - total = top.childCount() - if vis_count == total: - top.setCheckState(0, QtCore.Qt.Checked) - elif vis_count == 0: - top.setCheckState(0, QtCore.Qt.Unchecked) - else: - top.setCheckState(0, QtCore.Qt.PartiallyChecked) - self.tree_objects.expandAll() - self.tree_objects.blockSignals(False) - - def _selected_uid(self): - items = self.tree_objects.selectedItems() - if not items: - return None - return items[0].data(0, QtCore.Qt.UserRole) - - def _on_browser_select(self): - uid = self._selected_uid() - if uid: - obj = self.workspace.scene_objects.get(uid) - if obj and obj.kind == ObjectKind.PLANE: - pidx = _parse_plane_index(obj.data_key) - self.workspace.selected_path_index = -1 - self.scene.highlight_path(None) - self.scene.show_forks_for_path(-1) - if pidx is not None: - self._selected_plane_index = int(pidx) - self.ortho_viewer.set_selected_plane(int(pidx)) - self._activate_plane_drag_widgets(int(pidx)) - self._set_path_info_text("") - self._log_selected_plane_metric(int(pidx)) - self.scene.highlight_plane(uid) - elif obj and obj.kind == ObjectKind.BRANCH: - pidx = _parse_path_index(obj.data_key) - self._selected_plane_index = -1 - self.scene.highlight_plane(None) - self._clear_plane_drag_widgets() - self.scene.highlight_path(uid) - self._set_plane_info_text("") - self.ortho_viewer.set_selected_plane(None) - if pidx is not None: - self.workspace.selected_path_index = int(pidx) - self.scene.show_forks_for_path(int(pidx)) - self._log_selected_path_info(int(pidx)) - else: - self.workspace.selected_path_index = -1 - self.scene.show_forks_for_path(-1) - else: - self._selected_plane_index = -1 - self.workspace.selected_path_index = -1 - self._clear_plane_drag_widgets() - self.scene.highlight_plane(None) - self.scene.highlight_path(None) - self.scene.show_forks_for_path(-1) - self.ortho_viewer.set_selected_plane(None) - self._refresh_selection_info() - else: - self._selected_plane_index = -1 - self.workspace.selected_path_index = -1 - self._clear_plane_drag_widgets() - self.scene.highlight_plane(None) - self.scene.highlight_path(None) - self.scene.show_forks_for_path(-1) - self.ortho_viewer.set_selected_plane(None) - self._refresh_selection_info() - - def _on_tree_item_changed(self, item, column): - uid = item.data(0, QtCore.Qt.UserRole) - if uid: - obj = self.workspace.scene_objects.get(uid) - if obj: - obj.visible = item.checkState(0) == QtCore.Qt.Checked - self.scene.apply_object_properties(obj) - else: - checked = item.checkState(0) != QtCore.Qt.Unchecked - self.tree_objects.blockSignals(True) - for i in range(item.childCount()): - child = item.child(i) - child.setCheckState(0, QtCore.Qt.Checked if checked else QtCore.Qt.Unchecked) - cuid = child.data(0, QtCore.Qt.UserRole) - if cuid: - obj = self.workspace.scene_objects.get(cuid) - if obj: - obj.visible = checked - self.scene.apply_object_properties(obj) - self.tree_objects.blockSignals(False) - self._refresh_scene() - - def _on_browser_ctx_menu(self, pos): - item = self.tree_objects.itemAt(pos) - if item is None: - return - uid = item.data(0, QtCore.Qt.UserRole) - menu = QtWidgets.QMenu(self) - if uid is None: - act_show = menu.addAction("Show All") - act_hide = menu.addAction("Hide All") - act_del_all = menu.addAction("Delete All") - action = menu.exec_(self.tree_objects.viewport().mapToGlobal(pos)) - if action == act_show: - self._set_group_vis(item, True) - elif action == act_hide: - self._set_group_vis(item, False) - elif action == act_del_all: - self.tree_objects.setCurrentItem(item) - self._on_delete_object() - else: - act_toggle = menu.addAction("Toggle Visibility") - act_del = menu.addAction("Delete") - obj = self.workspace.scene_objects.get(uid) - act_plane_sl = None - if obj and obj.kind == ObjectKind.PLANE: - act_plane_sl = menu.addAction("Streamlines from Plane") - action = menu.exec_(self.tree_objects.viewport().mapToGlobal(pos)) - if action == act_toggle: - if obj: - obj.visible = not obj.visible - self.scene.apply_object_properties(obj) - self._refresh_browser() - elif action == act_del: - self._on_delete_object() - elif act_plane_sl is not None and action == act_plane_sl: - pidx = _parse_plane_index(obj.data_key) - if pidx is not None: - self._trigger_plane_streamlines(pidx) - - def _trigger_plane_streamlines(self, plane_idx): - self._sync_params_to_ws() - self.pipeline.preprocess(self.workspace) - self.workspace.plane_streamline_plane_idx = plane_idx - self.scene.trigger_plane_streamlines(plane_idx) - self._refresh_browser() - self.scene.invalidate_cache() - self.scene.sync_from_workspace() - self._refresh_all() - self._log_selected_plane_metric(int(plane_idx)) - - def _set_group_vis(self, group_item, visible): - for i in range(group_item.childCount()): - uid = group_item.child(i).data(0, QtCore.Qt.UserRole) - if uid: - obj = self.workspace.scene_objects.get(uid) - if obj: - obj.visible = visible - self.scene.apply_object_properties(obj) - self._refresh_browser() - self._refresh_scene() - - def _on_delete_object(self): - items = self.tree_objects.selectedItems() - if not items: - return - item = items[0] - uid = item.data(0, QtCore.Qt.UserRole) - plane_indices_removed = [] - if uid: - obj = self.workspace.scene_objects.get(uid) - name = obj.name if obj else uid - if obj and obj.kind == ObjectKind.PLANE: - pidx = _parse_plane_index(obj.data_key) - if pidx is not None: - plane_indices_removed.append(pidx) - self.scene.remove_object(uid) - self.log(f"Deleted: {name}") - else: - count = item.childCount() - if count == 0: - return - kind_name = item.text(0) - uids = [] - for i in range(count): - cuid = item.child(i).data(0, QtCore.Qt.UserRole) - if cuid: - cobj = self.workspace.scene_objects.get(cuid) - if cobj and cobj.kind == ObjectKind.PLANE: - pidx = _parse_plane_index(cobj.data_key) - if pidx is not None: - plane_indices_removed.append(pidx) - uids.append(cuid) - for u in uids: - self.scene.remove_object(u) - self.log(f"Deleted section: {kind_name} ({len(uids)} objects)") - if plane_indices_removed: - self._clear_plane_drag_widgets() - for pidx in sorted(plane_indices_removed, reverse=True): - if 0 <= pidx < len(self.workspace.planes): - self.workspace.planes.pop(pidx) - self._rebuild_plane_objects() - self._selected_plane_index = -1 - self.workspace.selected_path_index = -1 - self.scene.highlight_plane(None) - self.scene.highlight_path(None) - self.scene.show_forks_for_path(-1) - self._refresh_browser() - self._refresh_selection_info() - - def _refresh_timeline(self): - T = max(1, self.workspace.time_count()) - self.slider_t.blockSignals(True) - self.slider_t.setMaximum(T - 1) - self.slider_t.setValue(self.workspace.current_t) - self.slider_t.blockSignals(False) - self.lab_t.setText(str(self.workspace.current_t)) - - def _on_t_changed(self, v): - self.workspace.current_t = int(v) - self.lab_t.setText(str(v)) - self.scene.update_time(int(v)) - self.ortho_viewer.refresh() - self._refresh_selection_info() - - def _on_prev_frame(self): - self.workspace.current_t = max(0, self.workspace.current_t - 1) - self._refresh_timeline() - self.scene.update_time(self.workspace.current_t) - self.ortho_viewer.refresh() - self._refresh_selection_info() - - def _on_next_frame(self): - T = self.workspace.time_count() - self.workspace.current_t = min(T - 1, self.workspace.current_t + 1) - self._refresh_timeline() - self.scene.update_time(self.workspace.current_t) - self.ortho_viewer.refresh() - self._refresh_selection_info() - - def _on_play(self): - self.scene.set_playback_active(True) - self._play_timer.start(self.spin_interval.value()) - - def _on_pause(self): - self._play_timer.stop() - self.scene.set_playback_active(False) - - def _on_play_tick(self): - T = self.workspace.time_count() - if T <= 1: - return - self.workspace.current_t = (self.workspace.current_t + 1) % T - self._refresh_timeline() - self.scene.update_time(self.workspace.current_t) - self.ortho_viewer.refresh() - self._refresh_selection_info() - - def _refresh_scene(self): - try: - self.scene.render_all() - except Exception as e: - self.log(f"VIEW ERROR: {type(e).__name__}: {e}") - - def _on_open_data(self): - path, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Data", "", "H5 (*.h5 *.hdf5);;All (*)") - if not path: - return - try: - if self._edit_mode is not None: - self._exit_interactive_edit(False) - self._clear_plane_drag_widgets() - self.workspace.reset_all() - self.workspace.paths.segmask_path = path - self.workspace.paths.flow_path = path - self.pipeline.load_data(self.workspace, self.log) - self.scene.workspace = self.workspace - self.scene.reset_scene() - self._refresh_all() - self.ortho_viewer.update_slider_ranges() - except Exception as e: - self.log(f"LOAD ERROR: {type(e).__name__}: {e}") - self.log(traceback.format_exc()) - - def _on_close_workspace(self): - if self._edit_mode is not None: - self._exit_interactive_edit(False) - self._clear_plane_drag_widgets() - self.workspace.reset_all() - self._selected_plane_index = -1 - self.scene.reset_scene() - self.ortho_viewer.reset_state() - self._refresh_all() - self.log("Workspace cleared") - - def _run_single_step(self, step): - if not self.workspace.data_loaded: - self.log("No data loaded. Use File > Open Data.") - return - if self._edit_mode is not None: - if step == StepId.EDIT_SKELETON and self._edit_mode == "skeleton": - self._exit_interactive_edit(True) - return - if step == StepId.EDIT_GRAPH and self._edit_mode == "graph": - self._exit_interactive_edit(True) - return - self.log("Finish current interactive edit first. Press ESC to force exit.") - return - try: - self._sync_params_to_ws() - self._clear_plane_drag_widgets() - self.setEnabled(False) - QtWidgets.QApplication.processEvents() - if step == StepId.EDIT_SKELETON: - self._start_skeleton_interactive_edit() - return - if step == StepId.EDIT_GRAPH: - self._start_graph_interactive_edit() - return - if step == StepId.GENERATE_STREAMLINES: - self.pipeline.preprocess(self.workspace) - self.scene.trigger_streamlines() - self._refresh_browser() - self.scene.invalidate_cache() - self.scene.sync_from_workspace() - self._refresh_all() - return - if step == StepId.PLANE_STREAMLINES: - self._on_plane_streamlines_step() - return - t0 = time.time() - result = self.pipeline.run_step(self.workspace, step, self.log) - elapsed = time.time() - t0 - self.log(f"[{step.label}] {elapsed:.2f}s - {result.message}") - self.scene.invalidate_cache() - self.scene.sync_from_workspace() - self._refresh_all() - self.ortho_viewer.refresh() - except Exception as e: - self.log(f"STEP ERROR: {type(e).__name__}: {e}") - self.log(traceback.format_exc()) - finally: - self.setEnabled(True) - - def _run_all_pipeline(self): - if not self.workspace.data_loaded: - self.log("No data loaded. Use File > Open Data.") - return - if self._edit_mode is not None: - self.log("Finish current interactive edit first.") - return - self._sync_params_to_ws() - self._clear_plane_drag_widgets() - self.setEnabled(False) - QtWidgets.QApplication.processEvents() - all_steps = [ - StepId.GENERATE_SKELETON, - StepId.GENERATE_GRAPH, - StepId.GENERATE_PLANES, - StepId.COMPUTE_PLANE_METRICS, - StepId.COMPUTE_DERIVED_METRICS, - ] - try: - t_total = time.time() - for step in all_steps: - t0 = time.time() - self.log(f"[Run All] Running {step.label}...") - QtWidgets.QApplication.processEvents() - result = self.pipeline.run_step(self.workspace, step, self.log) - elapsed = time.time() - t0 - self.log(f"[{step.label}] {elapsed:.2f}s - {result.message}") - self.scene.invalidate_cache() - self.scene.sync_from_workspace() - self._refresh_all() - self.ortho_viewer.refresh() - total_elapsed = time.time() - t_total - self.log(f"[Run All] Completed in {total_elapsed:.2f}s") - except Exception as e: - self.log(f"RUN ALL ERROR: {type(e).__name__}: {e}") - self.log(traceback.format_exc()) - finally: - self.setEnabled(True) - - def _on_plane_streamlines_step(self): - ws = self.workspace - if len(ws.planes) == 0: - self.log("No planes available for plane streamlines.") - return - uid = self._selected_uid() - pidx = 0 - if uid: - obj = ws.scene_objects.get(uid) - if obj and obj.kind == ObjectKind.PLANE: - parsed = _parse_plane_index(obj.data_key) - if parsed is not None: - pidx = parsed - self._trigger_plane_streamlines(pidx) - - def _get_spacing_xyz_from_resolution(self): - r = np.asarray(self.workspace.resolution, dtype=float).reshape(-1) - if r.size >= 3: - return np.array([float(r[0]), float(r[1]), float(r[2])], dtype=float) - return np.array([1.0, 1.0, 1.0], dtype=float) - - def _edit_widget_radius(self): - spacing = self._get_spacing_xyz_from_resolution() - return max(0.1, float(np.mean(spacing)) * 0.6) - - def _graph_polydata(self, points, edges): - points = np.asarray(points, dtype=float).reshape(-1, 3) - poly = pv.PolyData(points) - edges = np.asarray(edges, dtype=int).reshape(-1, 2) if len(edges) else np.empty((0, 2), dtype=int) - if len(edges) > 0: - cells = np.empty((len(edges), 3), dtype=np.int64) - cells[:, 0] = 2 - cells[:, 1] = edges[:, 0] - cells[:, 2] = edges[:, 1] - poly.lines = cells.ravel() - return poly - - def _cleanup_edit_actors(self): - for actor in [self._edit_actor, self._edit_edge_actor, self._edit_sel_actor, self._edit_sel_edge_actor]: - if actor is not None: - try: - self.plotter.remove_actor(actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(actor) - except Exception: - pass - self._edit_actor = None - self._edit_edge_actor = None - self._edit_sel_actor = None - self._edit_sel_edge_actor = None - self._edit_poly = None - self._edit_edge_poly = None - self._edit_sel_poly = None - self._edit_sel_edge_poly = None - - def _remove_edit_widget(self): - try: - if hasattr(self.plotter, "clear_sphere_widgets"): - self.plotter.clear_sphere_widgets() - except Exception: - pass - try: - if hasattr(self.plotter, "remove_widget") and self._edit_widget is not None: - try: - self.plotter.remove_widget(self._edit_widget) - except Exception: - pass - except Exception: - pass - self._edit_widget = None - - def _update_edit_labels(self): - if self._edit_info_label is None or self._edit_status_label is None: - return - mode = self._edit_mode or "-" - npts = 0 if self._edit_points is None else len(self._edit_points) - nedges = 0 if self._edit_edges is None else len(self._edit_edges) - sel = "-" if self._edit_selected_idx is None else str(int(self._edit_selected_idx)) - esel = "-" if self._edit_selected_edge_idx is None else str(int(self._edit_selected_edge_idx)) - self._edit_info_label.setText(f"Mode: {mode} Points: {npts} Edges: {nedges} Selected Node: {sel} Selected Edge: {esel}") - edge_state = "ON" if self._edit_edge_mode else "OFF" - src = "-" if self._edit_edge_src_idx is None else str(int(self._edit_edge_src_idx)) - self._edit_status_label.setText(f"Edge Mode: {edge_state} Edge Src: {src} Keys: Delete/Backspace delete, E toggle edge mode") - if self._edit_btn_edge is not None: - self._edit_btn_edge.setText("Edge Mode: ON" if self._edit_edge_mode else "Edge Mode: OFF") - - def _update_edit_points_actor(self): - if self._edit_points is None or len(self._edit_points) == 0: - if self._edit_actor is not None: - try: - self.plotter.remove_actor(self._edit_actor) - except Exception: - pass - self._edit_actor = None - self._edit_poly = None - return - if self._edit_poly is None: - self._edit_poly = pv.PolyData(np.asarray(self._edit_points, dtype=float)) - else: - self._edit_poly.points = np.asarray(self._edit_points, dtype=float) - color = "red" if self._edit_mode == "skeleton" else "deepskyblue" - if self._edit_mode == "plane": - color = "yellow" - if self._edit_actor is None: - self._edit_actor = self.plotter.add_mesh(self._edit_poly, color=color, point_size=10, render_points_as_spheres=True, name="interactive_edit_points") - else: - try: - self._edit_actor.GetMapper().SetInputData(self._edit_poly) - except Exception: - try: - self.plotter.remove_actor(self._edit_actor) - except Exception: - pass - self._edit_actor = self.plotter.add_mesh(self._edit_poly, color=color, point_size=10, render_points_as_spheres=True, name="interactive_edit_points") - - def _update_edit_edges_actor(self): - if self._edit_edges is None or len(self._edit_edges) == 0 or self._edit_points is None or len(self._edit_points) == 0: - if self._edit_edge_actor is not None: - try: - self.plotter.remove_actor(self._edit_edge_actor) - except Exception: - pass - self._edit_edge_actor = None - self._edit_edge_poly = None - return - poly = self._graph_polydata(self._edit_points, self._edit_edges) - self._edit_edge_poly = poly - color = "green" if self._edit_mode == "graph" else "orange" - if self._edit_mode == "plane": - color = "yellow" - if self._edit_edge_actor is None: - self._edit_edge_actor = self.plotter.add_mesh(poly, color=color, line_width=3, name="interactive_edit_edges") - else: - try: - self._edit_edge_actor.GetMapper().SetInputData(poly) - except Exception: - try: - self.plotter.remove_actor(self._edit_edge_actor) - except Exception: - pass - self._edit_edge_actor = self.plotter.add_mesh(poly, color=color, line_width=3, name="interactive_edit_edges") - - def _set_selected_idx(self, idx): - self._edit_selected_idx = None if idx is None else int(idx) - if self._edit_points is None or len(self._edit_points) == 0: - self._edit_selected_idx = None - elif self._edit_selected_idx is not None and not (0 <= self._edit_selected_idx < len(self._edit_points)): - self._edit_selected_idx = None - if self._edit_selected_idx is None: - if self._edit_sel_actor is not None: - try: - self.plotter.remove_actor(self._edit_sel_actor) - except Exception: - pass - self._edit_sel_actor = None - self._edit_sel_poly = None - self._remove_edit_widget() - self._update_edit_labels() - try: - self.plotter.render() - except Exception: - pass - return - p = np.asarray(self._edit_points[self._edit_selected_idx], dtype=float).reshape(1, 3) - if self._edit_sel_poly is None: - self._edit_sel_poly = pv.PolyData(p) - else: - self._edit_sel_poly.points = p - if self._edit_sel_actor is None: - self._edit_sel_actor = self.plotter.add_mesh(self._edit_sel_poly, color="yellow", point_size=16, render_points_as_spheres=True, name="interactive_edit_selected") - else: - try: - self._edit_sel_actor.GetMapper().SetInputData(self._edit_sel_poly) - except Exception: - try: - self.plotter.remove_actor(self._edit_sel_actor) - except Exception: - pass - self._edit_sel_actor = self.plotter.add_mesh(self._edit_sel_poly, color="yellow", point_size=16, render_points_as_spheres=True, name="interactive_edit_selected") - self._create_or_move_edit_widget(p[0]) - self._update_edit_labels() - try: - self.plotter.render() - except Exception: - pass - - def _create_or_move_edit_widget(self, center): - if self._edit_selected_idx is None or self._edit_points is None or len(self._edit_points) == 0: - return - self._remove_edit_widget() - radius = self._edit_widget_radius() - def _cb(new_center): - if self._edit_selected_idx is None or self._edit_points is None or len(self._edit_points) == 0: - return - c = np.asarray(new_center, dtype=float).reshape(3) - self._edit_points[self._edit_selected_idx, :] = c - if self._edit_poly is not None: - try: - self._edit_poly.points = self._edit_points - except Exception: - pass - if self._edit_sel_poly is not None: - try: - self._edit_sel_poly.points = np.asarray([c], dtype=float) - except Exception: - pass - if self._edit_edge_poly is not None: - try: - self._edit_edge_poly.points = self._edit_points - except Exception: - self._update_edit_edges_actor() - try: - self.plotter.render() - except Exception: - pass - self._edit_widget = self.plotter.add_sphere_widget(callback=_cb, center=tuple(np.asarray(center, dtype=float).tolist()), radius=radius, color="orange") - - def _enable_interactive_key_events(self, enable): - try: - iren = self.plotter.iren.interactor - except Exception: - iren = None - if not enable: - if iren is not None and self._vtk_keypress_obs_id is not None: - try: - iren.RemoveObserver(self._vtk_keypress_obs_id) - except Exception: - pass - self._vtk_keypress_obs_id = None - return - if iren is None: - self.log("WARNING: No VTK interactor available; cannot enable key events.") - return - def _on_keypress(obj, ev): - key = "" - try: - key = iren.GetKeySym() - except Exception: - return - if key == "Escape": - self._force_exit_edit() - return - if self._edit_mode is None: - return - if key in ("Delete", "BackSpace"): - if self._edit_selected_edge_idx is not None: - self._delete_selected_edge() - else: - self._delete_selected_interactive_point() - return - if key in ("e", "E"): - self._toggle_edge_mode() - return - if self._vtk_keypress_obs_id is not None: - try: - iren.RemoveObserver(self._vtk_keypress_obs_id) - except Exception: - pass - self._vtk_keypress_obs_id = None - self._vtk_keypress_obs_id = iren.AddObserver("KeyPressEvent", _on_keypress) - - def _enable_interactive_point_picking(self, enable): - self._edit_pick_enabled = bool(enable) - try: - iren = self.plotter.iren.interactor - except Exception: - iren = None - if not enable: - if iren is not None and self._vtk_left_click_obs_id is not None: - try: - iren.RemoveObserver(self._vtk_left_click_obs_id) - except Exception: - pass - self._vtk_left_click_obs_id = None - self._vtk_point_picker = None - return - if iren is None: - self.log("WARNING: No VTK interactor available; cannot enable picking.") - return - if self._vtk_point_picker is None: - self._vtk_point_picker = pv._vtk.vtkPointPicker() - self._vtk_point_picker.SetTolerance(0.02) - def _on_left_click(obj, ev): - if self._edit_mode is None: - try: - iren.GetInteractorStyle().OnLeftButtonDown() - except Exception: - pass - return - if self._edit_actor is None or self._edit_points is None or len(self._edit_points) == 0: - try: - iren.GetInteractorStyle().OnLeftButtonDown() - except Exception: - pass - return - try: - x, y = iren.GetEventPosition() - except Exception: - x, y = None, None - if x is None: - try: - iren.GetInteractorStyle().OnLeftButtonDown() - except Exception: - pass - return - try: - self._vtk_point_picker.InitializePickList() - self._vtk_point_picker.AddPickList(self._edit_actor) - if self._edit_edge_actor is not None: - self._vtk_point_picker.AddPickList(self._edit_edge_actor) - self._vtk_point_picker.PickFromListOn() - except Exception: - pass - try: - ren = self.plotter.renderer - ok = self._vtk_point_picker.Pick(float(x), float(y), 0.0, ren) - except Exception: - ok = 0 - if not ok: - try: - iren.GetInteractorStyle().OnLeftButtonDown() - except Exception: - pass - return - try: - p = np.asarray(self._vtk_point_picker.GetPickPosition(), dtype=float).reshape(3) - except Exception: - try: - iren.GetInteractorStyle().OnLeftButtonDown() - except Exception: - pass - return - pts = np.asarray(self._edit_points, dtype=float) - d2 = np.sum((pts - p.reshape(1, 3)) ** 2, axis=1) - node_idx = int(np.argmin(d2)) - node_dist = float(np.sqrt(d2[node_idx])) - edge_idx, edge_dist = self._find_closest_edge(p) - if self._edit_edge_mode and self._edit_mode == "graph": - if self._edit_edge_src_idx is None: - self._edit_edge_src_idx = node_idx - self._set_selected_idx(node_idx) - self._edit_selected_edge_idx = None - self._clear_edge_selection_actor() - self._update_edit_labels() - else: - src = self._edit_edge_src_idx - self._edit_edge_src_idx = None - if src != node_idx: - self._toggle_edge(src, node_idx) - self._update_edit_labels() - elif self._edit_mode == "graph" and edge_idx is not None and edge_dist < node_dist * 0.7: - self._set_selected_edge_idx(edge_idx) - else: - self._edit_selected_edge_idx = None - self._clear_edge_selection_actor() - self._set_selected_idx(node_idx) - try: - self.plotter.render() - except Exception: - pass - if self._vtk_left_click_obs_id is not None: - try: - iren.RemoveObserver(self._vtk_left_click_obs_id) - except Exception: - pass - self._vtk_left_click_obs_id = None - try: - self._vtk_left_click_obs_id = iren.AddObserver("LeftButtonPressEvent", _on_left_click) - except Exception as e: - self.log(f"WARNING: failed to add VTK observer for picking: {e}") - self._vtk_left_click_obs_id = None - - def _toggle_edge_mode(self): - if self._edit_mode != "graph": - return - self._edit_edge_mode = not self._edit_edge_mode - self._edit_edge_src_idx = None - self._update_edit_labels() - - def _toggle_edge(self, i, j): - if self._edit_edges is None: - self._edit_edges = np.empty((0, 2), dtype=int) - edges = np.asarray(self._edit_edges, dtype=int).reshape(-1, 2) - found = -1 - for k, (a, b) in enumerate(edges): - if (int(a) == i and int(b) == j) or (int(a) == j and int(b) == i): - found = k - break - if found >= 0: - self._edit_edges = np.delete(edges, found, axis=0) - self.log(f"Removed edge ({i}, {j})") - else: - self._edit_edges = np.vstack([edges, [i, j]]) if len(edges) > 0 else np.array([[i, j]], dtype=int) - self.log(f"Added edge ({i}, {j})") - self._edit_selected_edge_idx = None - self._clear_edge_selection_actor() - self._update_edit_edges_actor() - self._update_edit_labels() - try: - self.plotter.render() - except Exception: - pass - - def _clear_edge_selection_actor(self): - if self._edit_sel_edge_actor is not None: - try: - self.plotter.remove_actor(self._edit_sel_edge_actor) - except Exception: - pass - self._edit_sel_edge_actor = None - self._edit_sel_edge_poly = None - - def _set_selected_edge_idx(self, idx): - self._edit_selected_edge_idx = None if idx is None else int(idx) - if self._edit_edges is None or len(self._edit_edges) == 0: - self._edit_selected_edge_idx = None - elif self._edit_selected_edge_idx is not None and not (0 <= self._edit_selected_edge_idx < len(self._edit_edges)): - self._edit_selected_edge_idx = None - if self._edit_selected_edge_idx is None: - self._clear_edge_selection_actor() - self._update_edit_labels() - return - edge = self._edit_edges[self._edit_selected_edge_idx] - pts = self._edit_points[edge] - poly = pv.PolyData(pts) - poly.lines = np.array([2, 0, 1], dtype=np.int64) - self._edit_sel_edge_poly = poly - if self._edit_sel_edge_actor is None: - self._edit_sel_edge_actor = self.plotter.add_mesh(poly, color="yellow", line_width=6, name="interactive_edit_selected_edge") - else: - try: - self._edit_sel_edge_actor.GetMapper().SetInputData(poly) - except Exception: - try: - self.plotter.remove_actor(self._edit_sel_edge_actor) - except Exception: - pass - self._edit_sel_edge_actor = self.plotter.add_mesh(poly, color="yellow", line_width=6, name="interactive_edit_selected_edge") - self._set_selected_idx(None) - self._update_edit_labels() - try: - self.plotter.render() - except Exception: - pass - - def _delete_selected_edge(self): - if self._edit_selected_edge_idx is None or self._edit_edges is None or len(self._edit_edges) == 0: - return - idx = int(self._edit_selected_edge_idx) - self.log(f"Deleted edge: {idx} ({self._edit_edges[idx].tolist()})") - self._edit_edges = np.delete(self._edit_edges, idx, axis=0) - self._edit_selected_edge_idx = None - self._clear_edge_selection_actor() - self._update_edit_edges_actor() - self._update_edit_labels() - try: - self.plotter.render() - except Exception: - pass - - def _find_closest_edge(self, pick_pos): - if self._edit_edges is None or len(self._edit_edges) == 0 or self._edit_points is None: - return None, float("inf") - pts = np.asarray(self._edit_points, dtype=float) - p = np.asarray(pick_pos, dtype=float).reshape(3) - best_idx = None - best_dist = float("inf") - for k, (a, b) in enumerate(self._edit_edges): - a_pt = pts[int(a)] - b_pt = pts[int(b)] - ab = b_pt - a_pt - ab_len2 = np.dot(ab, ab) - if ab_len2 < 1e-24: - d = np.linalg.norm(p - a_pt) - else: - t = np.clip(np.dot(p - a_pt, ab) / ab_len2, 0.0, 1.0) - proj = a_pt + t * ab - d = np.linalg.norm(p - proj) - if d < best_dist: - best_dist = d - best_idx = k - return best_idx, best_dist - - def _delete_selected_interactive_point(self): - if self._edit_selected_idx is None or self._edit_points is None or len(self._edit_points) == 0: - return - idx = int(self._edit_selected_idx) - self._edit_points = np.delete(np.asarray(self._edit_points, dtype=float), idx, axis=0) - if self._edit_edges is not None and len(self._edit_edges) > 0: - keep_idx = [i for i in range(len(self._edit_points) + 1) if i != idx] - remap = {old: new for new, old in enumerate(keep_idx)} - new_edges = [] - for a, b in np.asarray(self._edit_edges, dtype=int): - a = int(a) - b = int(b) - if a in remap and b in remap: - new_edges.append([remap[a], remap[b]]) - self._edit_edges = np.asarray(new_edges, dtype=int).reshape(-1, 2) if new_edges else np.empty((0, 2), dtype=int) - self._update_edit_points_actor() - self._update_edit_edges_actor() - if len(self._edit_points) == 0: - self._set_selected_idx(None) - else: - self._set_selected_idx(min(idx, len(self._edit_points) - 1)) - self._update_edit_labels() - self.log(f"Deleted point: {idx}") - - def _add_interactive_point(self): - return - - def _show_interactive_overlay(self): - self._edit_overlay_dialog = None - self._edit_info_label = None - self._edit_status_label = None - self._edit_btn_edge = None - mode = self._edit_mode or "-" - if mode == "skeleton": - hint = "Edit Skeleton: drag sphere to move | Delete/Backspace to remove | ESC to cancel | click 'Edit Skeleton' again to apply" - elif mode == "graph": - hint = "Edit Graph: drag sphere to move node | Delete/Backspace to remove node/edge | E toggle edge mode | Click edge to select | ESC to cancel | click 'Edit Graph' again to apply" - else: - hint = f"Edit {mode}: ESC to cancel" - try: - self.statusBar().showMessage(hint) - except Exception: - pass - self.log(hint) - - def _close_interactive_overlay(self): - self._edit_overlay_dialog = None - self._edit_info_label = None - self._edit_status_label = None - self._edit_btn_edge = None - try: - self.statusBar().clearMessage() - except Exception: - pass - - def _enter_interactive_edit(self, mode, points, edges=None): - self.scene.invalidate_cache() - self.scene.sync_from_workspace() - self._cleanup_edit_actors() - self._remove_edit_widget() - self._close_interactive_overlay() - self._edit_mode = mode - self._edit_points = np.asarray(points, dtype=float).reshape(-1, 3).copy() - if edges is None: - self._edit_edges = np.empty((0, 2), dtype=int) - else: - arr = np.asarray(edges, dtype=int) - self._edit_edges = arr.reshape(-1, 2).copy() if len(arr) else np.empty((0, 2), dtype=int) - self._edit_original_points = self._edit_points.copy() - self._edit_original_edges = self._edit_edges.copy() - self._edit_selected_idx = None - self._edit_edge_mode = False - self._edit_edge_src_idx = None - self._edit_selected_edge_idx = None - self._update_edit_points_actor() - self._update_edit_edges_actor() - self._enable_interactive_key_events(True) - self._enable_interactive_point_picking(True) - self._show_interactive_overlay() - if len(self._edit_points) > 0: - self._set_selected_idx(0) - else: - self._set_selected_idx(None) - self.log(f"Interactive edit started: {mode}") - try: - self.plotter.render() - except Exception: - pass - - def _exit_interactive_edit(self, apply_changes): - mode = self._edit_mode - if mode is None: - return - try: - self._enable_interactive_key_events(False) - self._enable_interactive_point_picking(False) - self._remove_edit_widget() - if apply_changes: - if mode == "skeleton": - ed = SkeletonEditor(self.workspace) - ed.replace_points(self._edit_points) - self.workspace.remove_object_by_data_key("skeleton_points") - self.workspace.add_object(name="skeleton_points", kind=ObjectKind.SKELETON, data_key="skeleton_points", visible=True, opacity=1.0, color="red", point_size=8) - self.workspace.pipeline.mark_done(StepId.EDIT_SKELETON, skipped=False) - self.log(f"Skeleton edited: {len(self.workspace.skeleton_points)} points") - elif mode == "graph": - self.workspace.graph.points = np.asarray(self._edit_points, dtype=float).reshape(-1, 3) - self.workspace.graph.edges = np.asarray(self._edit_edges, dtype=int).reshape(-1, 2) if len(self._edit_edges) else np.empty((0, 2), dtype=int) - self.workspace.remove_object_by_data_key("graph_lines") - self.workspace.add_object(name="graph_lines", kind=ObjectKind.GRAPH, data_key="graph_lines", visible=True, opacity=1.0, color="blue", line_width=2) - self.workspace.pipeline.mark_done(StepId.EDIT_GRAPH, skipped=False) - self.log(f"Graph edited: {len(self.workspace.graph.points)} nodes, {len(self.workspace.graph.edges)} edges") - else: - self.log(f"Interactive edit cancelled: {mode}") - finally: - self._cleanup_edit_actors() - self._close_interactive_overlay() - self._edit_mode = None - self._edit_points = None - self._edit_edges = None - self._edit_selected_idx = None - self._edit_edge_mode = False - self._edit_edge_src_idx = None - self._edit_selected_edge_idx = None - self._edit_original_points = None - self._edit_original_edges = None - self.scene.invalidate_cache() - self.scene.sync_from_workspace() - self._refresh_all() - try: - self.plotter.render() - except Exception: - pass - - def _force_exit_edit(self): - if self._edit_mode is None: - return - self.log("ESC: force exit interactive edit") - try: - self._exit_interactive_edit(False) - except Exception as e: - self.log(f"Force exit cleanup error: {type(e).__name__}: {e}") - finally: - self._edit_mode = None - self._edit_points = None - self._edit_edges = None - self._edit_selected_idx = None - self._edit_edge_mode = False - self._edit_edge_src_idx = None - self._edit_selected_edge_idx = None - self._edit_original_points = None - self._edit_original_edges = None - self._edit_overlay_dialog = None - self._edit_info_label = None - self._edit_status_label = None - self._edit_btn_edge = None - try: - self.statusBar().clearMessage() - except Exception: - pass - try: - self._cleanup_edit_actors() - except Exception: - pass - try: - self._remove_edit_widget() - except Exception: - pass - try: - self._enable_interactive_key_events(False) - except Exception: - pass - try: - self._enable_interactive_point_picking(False) - except Exception: - pass - try: - self.setEnabled(True) - except Exception: - pass - try: - self.plotter.render() - except Exception: - pass - - def _start_skeleton_interactive_edit(self): - if self.workspace.skeleton_points is None or len(self.workspace.skeleton_points) == 0: - self.log("Edit Skeleton: no skeleton points.") - return - self._enter_interactive_edit("skeleton", self.workspace.skeleton_points, edges=None) - - def _start_graph_interactive_edit(self): - if self.workspace.graph is None or len(self.workspace.graph.points) == 0: - self.log("Edit Graph: no graph data.") - return - self._enter_interactive_edit("graph", self.workspace.graph.points, self.workspace.graph.edges) - - def closeEvent(self, event): - try: - if self._edit_mode is not None: - self._exit_interactive_edit(False) - self._clear_plane_drag_widgets() - except Exception: - pass - event.accept() - - -def main(): - app = QtWidgets.QApplication(sys.argv) - w = MainWindow() - w.show() - sys.exit(app.exec_()) - - -if __name__ == "__main__": - main() + return { + "MainWindow": MainWindow, + "main": main, + }[name] + raise AttributeError(f"module 'autoflow.app' has no attribute {name!r}") diff --git a/autoflow/core/__init__.py b/autoflow/core/__init__.py new file mode 100755 index 0000000..e7ac33e --- /dev/null +++ b/autoflow/core/__init__.py @@ -0,0 +1,4 @@ +from .models import * +from .pipeline import PipelineEngine, StepResult + +__all__ = [name for name in globals() if not name.startswith("_")] diff --git a/autoflow/core/models.py b/autoflow/core/models.py new file mode 100755 index 0000000..2709cd5 --- /dev/null +++ b/autoflow/core/models.py @@ -0,0 +1,540 @@ +import copy, json, uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple +import numpy as np + + +class ObjectKind(Enum): + SEGMENTATION = "Segmentation" + SKELETON = "Skeleton" + GRAPH = "Graph" + BRANCH = "Branch" + PLANE = "Plane" + FLOW = "Flow" + METRIC = "Metric" + AUX = "Aux" + + +class StepId(Enum): + GENERATE_SKELETON = "step_skeleton" + EDIT_SKELETON = "step_edit_skeleton" + GENERATE_GRAPH = "step_graph" + EDIT_GRAPH = "step_edit_graph" + GENERATE_PLANES = "step_planes" + EDIT_PLANES = "step_edit_planes" + GENERATE_STREAMLINES = "step_streamlines" + PLANE_STREAMLINES = "step_plane_streamlines" + COMPUTE_PLANE_METRICS = "step_plane_metrics" + COMPUTE_DERIVED_METRICS = "step_derived_metrics" + + @property + def label(self): + return { + StepId.GENERATE_SKELETON: "Generate Skeleton", + StepId.EDIT_SKELETON: "Edit Skeleton", + StepId.GENERATE_GRAPH: "Generate Graph", + StepId.EDIT_GRAPH: "Edit Graph", + StepId.GENERATE_PLANES: "Generate Planes", + StepId.EDIT_PLANES: "Edit Planes", + StepId.GENERATE_STREAMLINES: "Generate Streamlines", + StepId.PLANE_STREAMLINES: "Plane Streamlines", + StepId.COMPUTE_PLANE_METRICS: "Calculate && Save Metrics", + StepId.COMPUTE_DERIVED_METRICS: "WSS / TKE", + }[self] + + @staticmethod + def top_row_steps(): + return [ + StepId.GENERATE_SKELETON, + StepId.GENERATE_GRAPH, + StepId.GENERATE_PLANES, + StepId.COMPUTE_PLANE_METRICS, + ] + + @staticmethod + def bottom_row_steps(): + return [ + StepId.EDIT_SKELETON, + StepId.EDIT_GRAPH, + ] + + @staticmethod + def extra_row_steps(): + return [ + StepId.GENERATE_STREAMLINES, + StepId.PLANE_STREAMLINES, + StepId.COMPUTE_DERIVED_METRICS, + ] + + +@dataclass +class PreprocessParams: + def to_dict(self): + return {} + + @staticmethod + def from_dict(d): + return PreprocessParams() + + +@dataclass +class SkeletonParams: + remove_small_cc: bool = False + min_cc_volume_mm3: float = 50.0 + do_closing: bool = True + do_opening: bool = False + gaussian_sigma: float = 0.5 + gaussian_enabled: bool = True + + def to_dict(self): + return { + "remove_small_cc": self.remove_small_cc, + "min_cc_volume_mm3": self.min_cc_volume_mm3, + "do_closing": self.do_closing, + "do_opening": self.do_opening, + "gaussian_sigma": self.gaussian_sigma, + "gaussian_enabled": self.gaussian_enabled, + } + + @staticmethod + def from_dict(d): + if "keep_largest_cc" in d and "remove_small_cc" not in d: + d["remove_small_cc"] = d["keep_largest_cc"] + return SkeletonParams( + remove_small_cc=bool(d.get("remove_small_cc", False)), + min_cc_volume_mm3=float(d.get("min_cc_volume_mm3", 50.0)), + do_closing=bool(d.get("do_closing", True)), + do_opening=bool(d.get("do_opening", False)), + gaussian_sigma=float(d.get("gaussian_sigma", 0.5)), + gaussian_enabled=bool(d.get("gaussian_enabled", True))) + + +@dataclass +class PlaneGenerationParams: + use_center_plane: bool = True + cross_section_distance: float = 20.0 + start_distance: float = 5.0 + end_distance: float = 0.0 + smoothing_window: int = 15 + smoothing_polyorder: int = 2 + inter_time: int = 10 + + def to_dict(self): + return { + "use_center_plane": bool(self.use_center_plane), + "cross_section_distance": self.cross_section_distance, + "start_distance": self.start_distance, + "end_distance": self.end_distance, + "smoothing_window": int(self.smoothing_window), + "smoothing_polyorder": int(self.smoothing_polyorder), + "inter_time": int(self.inter_time), + } + + @staticmethod + def from_dict(d): + return PlaneGenerationParams( + use_center_plane=bool(d.get("use_center_plane", True)), + cross_section_distance=float(d.get("cross_section_distance", 20.0)), + start_distance=float(d.get("start_distance", 5.0)), + end_distance=float(d.get("end_distance", 0.0)), + smoothing_window=int(d.get("smoothing_window", 15)), + smoothing_polyorder=int(d.get("smoothing_polyorder", 3)), + inter_time=int(d.get("inter_time", 10)), + ) + + +@dataclass +class StreamlineParams: + seed_ratio: float = 0.02 + max_steps: int = 2000 + min_seeds: int = 50 + terminal_speed: float = 0.01 + rng_seed: int = 0 + + def to_dict(self): + return { + "seed_ratio": self.seed_ratio, + "max_steps": self.max_steps, + "min_seeds": self.min_seeds, + "terminal_speed": self.terminal_speed, + "rng_seed": self.rng_seed, + } + + @staticmethod + def from_dict(d): + return StreamlineParams( + seed_ratio=float(d.get("seed_ratio", 0.02)), + max_steps=int(d.get("max_steps", 2000)), + min_seeds=50, + terminal_speed=float(d.get("terminal_speed", 0.01)), + rng_seed=int(d.get("rng_seed", 0)), + ) + + +@dataclass +class DerivedMetricsParams: + smoothing_iteration: int = 200 + viscosity: float = 4.0 + inward_distance: float = 0.6 + parabolic_fitting: bool = True + no_slip_condition: bool = True + step_size: int = 5 + tube_radius: float = 0.1 + rho: float = 1060.0 + use_multithread: bool = False + + def to_dict(self): + return { + "smoothing_iteration": self.smoothing_iteration, + "viscosity": self.viscosity, + "inward_distance": self.inward_distance, + "parabolic_fitting": self.parabolic_fitting, + "no_slip_condition": self.no_slip_condition, + "step_size": self.step_size, + "tube_radius": self.tube_radius, + "rho": self.rho, + "use_multithread": self.use_multithread, + } + + @staticmethod + def from_dict(d): + return DerivedMetricsParams( + smoothing_iteration=int(d.get("smoothing_iteration", 200)), + viscosity=float(d.get("viscosity", 4.0)), + inward_distance=float(d.get("inward_distance", 0.6)), + parabolic_fitting=bool(d.get("parabolic_fitting", True)), + no_slip_condition=bool(d.get("no_slip_condition", True)), + step_size=int(d.get("step_size", 5)), + tube_radius=float(d.get("tube_radius", 0.1)), + rho=float(d.get("rho", 1060.0)), + use_multithread=bool(d.get("use_multithread", False)), + ) + + +@dataclass +class PathsState: + segmask_path: str = "" + flow_path: str = "" + workspace_path: str = "" + output_dir: str = "" + + +@dataclass +class PipelineFlags: + completed: Dict[str, bool] = field(default_factory=dict) + skipped: Dict[str, bool] = field(default_factory=dict) + + def mark_done(self, step: StepId, skipped: bool = False): + self.completed[step.value] = True + self.skipped[step.value] = bool(skipped) + + def is_done(self, step: StepId) -> bool: + return bool(self.completed.get(step.value, False)) + + def reset(self): + self.completed.clear() + self.skipped.clear() + + +@dataclass +class SceneObject: + uid: str + name: str + kind: ObjectKind + data_key: str + visible: bool = True + opacity: float = 1.0 + color: str = "white" + scalars: Optional[str] = None + cmap: str = "turbo" + clim: Optional[Tuple[float, float]] = None + point_size: int = 8 + line_width: int = 2 + tube_radius: float = 0.0 + show_scalar_bar: bool = False + scalar_bar_title: Optional[str] = None + dynamic: bool = False + actor: Any = None + label_actor: Any = None + + +@dataclass +class PlaneData: + center: np.ndarray + normal: np.ndarray + label: int = 1 + path_index: int = 0 + distance: float = 0.0 + metrics: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class GraphData: + points: np.ndarray = field(default_factory=lambda: np.empty((0, 3), dtype=float)) + edges: np.ndarray = field(default_factory=lambda: np.empty((0, 2), dtype=int)) + + +@dataclass +class DerivedResults: + plane_metrics: List[Dict[str, Any]] = field(default_factory=list) + plane_qc: Dict[str, Any] = field(default_factory=dict) + wss_surfaces: List[Any] = field(default_factory=list) + wss_volume: Optional[np.ndarray] = None + tke_volume: Any = None + tke_array: Optional[np.ndarray] = None + streamlines: List[Any] = field(default_factory=list) + pixelwise_export: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class Workspace: + paths: PathsState = field(default_factory=PathsState) + pipeline: PipelineFlags = field(default_factory=PipelineFlags) + preprocess_params: PreprocessParams = field(default_factory=PreprocessParams) + skeleton_params: SkeletonParams = field(default_factory=SkeletonParams) + plane_gen_params: PlaneGenerationParams = field(default_factory=PlaneGenerationParams) + streamline_params: StreamlineParams = field(default_factory=StreamlineParams) + derived_params: DerivedMetricsParams = field(default_factory=DerivedMetricsParams) + + resolution: np.ndarray = field(default_factory=lambda: np.array([1., 1., 1.])) + origin: np.ndarray = field(default_factory=lambda: np.array([0., 0., 0.])) + spatial_order: List[str] = field(default_factory=lambda: ["FH", "AP", "LR"]) + venc_order: List[str] = field(default_factory=lambda: ["FH", "AP", "LR"]) + venc: np.ndarray = field(default_factory=lambda: np.array([1., 1., 1.])) + rr: float = 1000.0 + + segmask_raw: Optional[np.ndarray] = None + segmask_labels: Optional[np.ndarray] = None + segmask_binary: Optional[np.ndarray] = None + segmask_3d: Optional[np.ndarray] = None + + mag_raw: Optional[np.ndarray] = None + + skeleton_points: Optional[np.ndarray] = None + skeleton_mask: Optional[np.ndarray] = None + + graph: GraphData = field(default_factory=GraphData) + branch_labels: Optional[np.ndarray] = None + centerline_paths: List[np.ndarray] = field(default_factory=list) + centerline_node_paths: List[List[int]] = field(default_factory=list) + centerline_paths_smooth: List[np.ndarray] = field(default_factory=list) + path_info: List[Dict[str, Any]] = field(default_factory=list) + forks: List[Dict[str, Any]] = field(default_factory=list) + planes: List[PlaneData] = field(default_factory=list) + + flow_raw: Optional[np.ndarray] = None + streamline_seeds: Optional[np.ndarray] = None + streamline_cache: Dict[int, Any] = field(default_factory=dict) + streamline_active: bool = False + plane_streamline_cache: Dict[int, Any] = field(default_factory=dict) + plane_streamline_active: bool = False + plane_streamline_plane_idx: int = -1 + derived: DerivedResults = field(default_factory=DerivedResults) + + scene_objects: Dict[str, SceneObject] = field(default_factory=dict) + current_t: int = 0 + data_loaded: bool = False + + ortho_cursor: np.ndarray = field(default_factory=lambda: np.array([0, 0, 0], dtype=int)) + selected_path_index: int = -1 + + def time_count(self): + if self.flow_raw is not None and self.flow_raw.ndim == 5: + return int(self.flow_raw.shape[3]) + if self.segmask_raw is not None and self.segmask_raw.ndim == 4: + return int(self.segmask_raw.shape[3]) + return 1 + + def has_flow(self): + return self.flow_raw is not None + + def unique_labels(self): + if self.segmask_raw is None: + return [] + return sorted(int(x) for x in np.unique(self.segmask_raw) if x != 0) + + def add_object(self, name, kind, data_key, **kw): + uid = str(uuid.uuid4()) + self.scene_objects[uid] = SceneObject(uid=uid, name=name, kind=kind, data_key=data_key, **kw) + return uid + + def remove_object(self, uid): + return self.scene_objects.pop(uid, None) + + def remove_object_by_data_key(self, data_key): + to_del = [u for u, o in self.scene_objects.items() if o.data_key == data_key] + return [self.scene_objects.pop(u) for u in to_del] + + def remove_objects_by_prefix(self, prefix): + to_del = [u for u, o in self.scene_objects.items() if o.data_key.startswith(prefix)] + return [self.scene_objects.pop(u) for u in to_del] + + def set_object_visible_by_data_key(self, data_key, visible): + for o in self.scene_objects.values(): + if o.data_key == data_key: + o.visible = visible + + def clear_streamlines(self): + self.streamline_seeds = None + self.streamline_cache.clear() + self.streamline_active = False + self.remove_object_by_data_key("streamlines_live") + + def clear_plane_streamlines(self): + self.plane_streamline_cache.clear() + self.plane_streamline_active = False + self.plane_streamline_plane_idx = -1 + self.remove_object_by_data_key("plane_streamlines_live") + + def reset_all(self): + for attr, default in [ + ("paths", PathsState()), + ("pipeline", PipelineFlags()), + ("preprocess_params", PreprocessParams()), + ("skeleton_params", SkeletonParams()), + ("plane_gen_params", PlaneGenerationParams()), + ("streamline_params", StreamlineParams()), + ("derived_params", DerivedMetricsParams()), + ]: + setattr(self, attr, default) + self.resolution = np.array([1., 1., 1.]) + self.origin = np.array([0., 0., 0.]) + self.venc = np.array([1., 1., 1.]) + self.rr = 1000.0 + for attr in ["segmask_raw", "segmask_labels", "segmask_binary", "segmask_3d", + "skeleton_points", "skeleton_mask", "branch_labels", "flow_raw", + "streamline_seeds", "mag_raw"]: + setattr(self, attr, None) + self.graph = GraphData() + self.centerline_paths = [] + self.centerline_node_paths = [] + self.centerline_paths_smooth = [] + self.path_info = [] + self.forks = [] + self.planes = [] + self.streamline_cache = {} + self.streamline_active = False + self.plane_streamline_cache = {} + self.plane_streamline_active = False + self.plane_streamline_plane_idx = -1 + self.derived = DerivedResults() + self.scene_objects = {} + self.current_t = 0 + self.data_loaded = False + self.ortho_cursor = np.array([0, 0, 0], dtype=int) + self.selected_path_index = -1 + + def snapshot_dict(self): + def arr(v): + return None if v is None else np.asarray(v).tolist() + return { + "paths": {"segmask_path": self.paths.segmask_path, "flow_path": self.paths.flow_path, + "workspace_path": self.paths.workspace_path, "output_dir": self.paths.output_dir}, + "pipeline": {"completed": dict(self.pipeline.completed), "skipped": dict(self.pipeline.skipped)}, + "preprocess_params": self.preprocess_params.to_dict(), + "skeleton_params": self.skeleton_params.to_dict(), + "plane_gen_params": self.plane_gen_params.to_dict(), + "streamline_params": self.streamline_params.to_dict(), + "derived_params": self.derived_params.to_dict(), + "resolution": arr(self.resolution), + "origin": arr(self.origin), + "venc": arr(self.venc), + "rr": float(self.rr), + "segmask_raw": arr(self.segmask_raw), + "segmask_labels": arr(self.segmask_labels), + "segmask_binary": arr(self.segmask_binary), + "segmask_3d": arr(self.segmask_3d), + "mag_raw": arr(self.mag_raw), + "skeleton_points": arr(self.skeleton_points), + "skeleton_mask": arr(self.skeleton_mask), + "graph": {"points": arr(self.graph.points), "edges": arr(self.graph.edges)}, + "branch_labels": arr(self.branch_labels), + "centerline_paths": [arr(x) for x in self.centerline_paths], + "centerline_node_paths": [list(map(int, x)) for x in self.centerline_node_paths], + "centerline_paths_smooth": [arr(x) for x in self.centerline_paths_smooth], + "path_info": copy.deepcopy(self.path_info), + "forks": copy.deepcopy(self.forks), + "planes": [{"center": arr(p.center), "normal": arr(p.normal), "label": int(p.label), + "path_index": int(p.path_index), "distance": float(p.distance), + "metrics": copy.deepcopy(p.metrics)} for p in self.planes], + "flow_raw": arr(self.flow_raw), + "streamline_seeds": arr(self.streamline_seeds), + "streamline_active": self.streamline_active, + "scene_objects": [ + {"uid": o.uid, "name": o.name, "kind": o.kind.value, "data_key": o.data_key, + "visible": o.visible, "opacity": o.opacity, "color": o.color, + "scalars": o.scalars, "cmap": o.cmap, + "clim": list(o.clim) if o.clim else None, + "point_size": o.point_size, "line_width": o.line_width, + "tube_radius": o.tube_radius, "show_scalar_bar": o.show_scalar_bar, + "scalar_bar_title": o.scalar_bar_title, "dynamic": o.dynamic} + for o in self.scene_objects.values()], + "current_t": self.current_t, "data_loaded": self.data_loaded, + "selected_path_index": int(self.selected_path_index), + } + + def restore_dict(self, d): + self.paths = PathsState(**{k: d.get("paths", {}).get(k, "") for k in ["segmask_path", "flow_path", "workspace_path", "output_dir"]}) + self.pipeline = PipelineFlags(completed=dict(d.get("pipeline", {}).get("completed", {})), + skipped=dict(d.get("pipeline", {}).get("skipped", {}))) + self.preprocess_params = PreprocessParams.from_dict(d.get("preprocess_params", {})) + self.skeleton_params = SkeletonParams.from_dict(d.get("skeleton_params", {})) + self.plane_gen_params = PlaneGenerationParams.from_dict(d.get("plane_gen_params", {})) + self.streamline_params = StreamlineParams.from_dict(d.get("streamline_params", {})) + self.derived_params = DerivedMetricsParams.from_dict(d.get("derived_params", {})) + self.resolution = np.asarray(d.get("resolution", [1, 1, 1]), dtype=float) + self.origin = np.array([0.0, 0.0, 0.0], dtype=float) + self.venc = np.asarray(d.get("venc", [1, 1, 1]), dtype=float) + self.rr = float(d.get("rr", 1000.0)) + + def nparr(k, dt=np.float64): + v = d.get(k) + return None if v is None else np.asarray(v, dtype=dt) + + self.segmask_raw = nparr("segmask_raw", np.int16) + self.segmask_labels = nparr("segmask_labels", np.int16) + self.segmask_binary = None if d.get("segmask_binary") is None else np.asarray(d["segmask_binary"], dtype=bool) + self.segmask_3d = None if d.get("segmask_3d") is None else np.asarray(d["segmask_3d"], dtype=bool) + self.mag_raw = nparr("mag_raw") + self.skeleton_points = nparr("skeleton_points") + self.skeleton_mask = nparr("skeleton_mask") + gd = d.get("graph", {}) + self.graph = GraphData( + points=np.asarray(gd.get("points", []), dtype=float).reshape(-1, 3) if gd.get("points") else np.empty((0, 3)), + edges=np.asarray(gd.get("edges", []), dtype=int).reshape(-1, 2) if gd.get("edges") else np.empty((0, 2), dtype=int)) + self.branch_labels = nparr("branch_labels") + self.centerline_paths = [np.asarray(x, dtype=float) for x in d.get("centerline_paths", [])] + self.centerline_node_paths = [list(map(int, x)) for x in d.get("centerline_node_paths", [])] + self.centerline_paths_smooth = [np.asarray(x, dtype=float) for x in d.get("centerline_paths_smooth", [])] + self.path_info = copy.deepcopy(d.get("path_info", [])) + self.forks = copy.deepcopy(d.get("forks", [])) + self.planes = [] + for p in d.get("planes", []): + self.planes.append(PlaneData( + center=np.asarray(p["center"], dtype=float), normal=np.asarray(p["normal"], dtype=float), + label=int(p.get("label", 1)), path_index=int(p.get("path_index", 0)), + distance=float(p.get("distance", 0.0)), metrics=copy.deepcopy(p.get("metrics", {})))) + self.flow_raw = nparr("flow_raw") + self.streamline_seeds = nparr("streamline_seeds") + self.streamline_cache = {} + self.streamline_active = bool(d.get("streamline_active", False)) + self.plane_streamline_cache = {} + self.plane_streamline_active = False + self.plane_streamline_plane_idx = -1 + self.scene_objects = {} + for it in d.get("scene_objects", []): + uid = it["uid"] + self.scene_objects[uid] = SceneObject( + uid=uid, name=it["name"], kind=ObjectKind(it["kind"]), data_key=it["data_key"], + visible=bool(it.get("visible", True)), opacity=float(it.get("opacity", 1.0)), + color=it.get("color", "white"), scalars=it.get("scalars"), + cmap=it.get("cmap", "turbo"), + clim=tuple(it["clim"]) if it.get("clim") else None, + point_size=int(it.get("point_size", 8)), line_width=int(it.get("line_width", 2)), + tube_radius=float(it.get("tube_radius", 0.0)), + show_scalar_bar=bool(it.get("show_scalar_bar", False)), + scalar_bar_title=it.get("scalar_bar_title"), dynamic=bool(it.get("dynamic", False))) + self.current_t = int(d.get("current_t", 0)) + self.data_loaded = bool(d.get("data_loaded", False)) + self.selected_path_index = int(d.get("selected_path_index", -1)) diff --git a/autoflow/core/pipeline.py b/autoflow/core/pipeline.py new file mode 100755 index 0000000..ffc6320 --- /dev/null +++ b/autoflow/core/pipeline.py @@ -0,0 +1,429 @@ +import json +import os +import numpy as np + +from .models import StepId, ObjectKind +from ..algorithms import ( + load_h5_data, + filter_segmask_labels, binarize_segmask, merge_segmask_to_3d, + preprocess_mask_for_skeleton, + generate_skeleton_from_mask3d, build_graph_from_points, + segment_vessels_from_graph_and_mask, + generate_planes_from_paths, + compute_plane_metrics, compute_derived_metrics, + compute_plane_metrics_multithread, + generate_seed_points, +) + + +class StepResult: + def __init__(self, step, success=True, skipped=False, message="", outputs=None): + self.step = step + self.success = success + self.skipped = skipped + self.message = message + self.outputs = outputs or [] + + +class PipelineEngine: + def _output_dir(self, ws): + out_dir = getattr(ws.paths, "output_dir", "") or "" + if out_dir: + os.makedirs(out_dir, exist_ok=True) + return out_dir + base = ws.paths.segmask_path or ws.paths.flow_path or "." + out_dir = os.path.dirname(base) or "." + os.makedirs(out_dir, exist_ok=True) + return out_dir + + def _json_safe(self, obj): + if isinstance(obj, np.floating): + val = float(obj) + return val if np.isfinite(val) else None + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, float): + return obj if np.isfinite(obj) else None + if isinstance(obj, dict): + return {k: self._json_safe(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [self._json_safe(v) for v in obj] + return obj + + def load_data(self, ws, log): + path = ws.paths.segmask_path or ws.paths.flow_path + if not path: + raise ValueError("data path is empty") + data = load_h5_data(path) + flow = np.asarray(data["flow"], dtype=np.float32) + mag = np.asarray(data["mag"], dtype=np.float32) + seg = np.asarray(data["segmask"], dtype=np.int16) + if flow.ndim == 4 and flow.shape[-1] == 3: + flow = flow[..., np.newaxis, :] + if mag.ndim == 3: + mag = mag[..., np.newaxis] + if seg.ndim == 3: + seg = np.repeat(seg[..., np.newaxis], flow.shape[3], axis=3) + elif seg.ndim == 4 and seg.shape[3] == 1 and flow.shape[3] > 1: + seg = np.repeat(seg, flow.shape[3], axis=3) + if seg.shape[3] != flow.shape[3]: + raise ValueError(f"segmask time dimension {seg.shape[3]} != flow {flow.shape[3]}") + + ws.segmask_raw = seg + ws.resolution = np.asarray(data["resolution"], dtype=float).reshape(3) + ws.origin = np.asarray(data.get("origin", [0.0, 0.0, 0.0]), dtype=float).reshape(3) + ws.venc = np.asarray(data["venc"], dtype=float).reshape(-1) + ws.rr = float(data.get("rr", 1000.0)) + ws.current_t = 0 + ws.flow_raw = flow + ws.mag_raw = mag + ws.derived.tke_array = np.asarray(data["tke_array"], dtype=np.float32) if "tke_array" in data else None + ws.data_loaded = True + + ws.remove_object_by_data_key("segmask_raw_surface") + ws.add_object(name="segmask_raw", kind=ObjectKind.SEGMENTATION, + data_key="segmask_raw_surface", visible=True, opacity=0.3, + scalars="label", cmap="tab10", dynamic=True, + show_scalar_bar=True, scalar_bar_title="Label") + + ulabels = ws.unique_labels() + msg = f"Loaded: segmask={ws.segmask_raw.shape} labels={ulabels} rr={ws.rr}" + msg += f" flow={ws.flow_raw.shape} mag={ws.mag_raw.shape}" + msg += f" origin={ws.origin.tolist()}" + log(msg) + return msg + + def preprocess(self, ws): + if ws.segmask_raw is None: + raise ValueError("segmask_raw is None") + ws.segmask_labels = filter_segmask_labels(ws.segmask_raw) + ws.segmask_binary = binarize_segmask(ws.segmask_labels) + ws.segmask_3d = merge_segmask_to_3d(ws.segmask_binary) + ws.set_object_visible_by_data_key("segmask_raw_surface", False) + ws.remove_object_by_data_key("segmask_pre_surface") + ws.add_object(name="segmask_preprocessed", kind=ObjectKind.SEGMENTATION, + data_key="segmask_pre_surface", visible=True, opacity=0.25, + scalars="label", cmap="tab10", dynamic=True, + show_scalar_bar=True, scalar_bar_title="Label") + + def run_step(self, ws, step, log): + dispatch = { + StepId.GENERATE_SKELETON: self._step_generate_skeleton, + StepId.EDIT_SKELETON: self._step_edit_skeleton, + StepId.GENERATE_GRAPH: self._step_generate_graph, + StepId.EDIT_GRAPH: self._step_edit_graph, + StepId.GENERATE_PLANES: self._step_generate_planes, + StepId.EDIT_PLANES: self._step_edit_planes, + StepId.GENERATE_STREAMLINES: self._step_generate_streamlines, + StepId.PLANE_STREAMLINES: self._step_plane_streamlines, + StepId.COMPUTE_PLANE_METRICS: self._step_compute_plane_metrics, + StepId.COMPUTE_DERIVED_METRICS: self._step_compute_derived_metrics, + } + return dispatch[step](ws) + + def _step_generate_skeleton(self, ws): + self.preprocess(ws) + if ws.skeleton_params.remove_small_cc: + from ..algorithms import remove_small_cc_from_binary_mask + ws.segmask_binary = remove_small_cc_from_binary_mask( + ws.segmask_binary, ws.resolution, ws.skeleton_params.min_cc_volume_mm3) + ws.segmask_3d = merge_segmask_to_3d(ws.segmask_binary) + processed = preprocess_mask_for_skeleton(ws.segmask_3d, ws.skeleton_params, resolution=ws.resolution) + pts, mask = generate_skeleton_from_mask3d(processed, ws.resolution) + ws.skeleton_points = pts + ws.skeleton_mask = mask + ws.remove_object_by_data_key("skeleton_points") + ws.remove_object_by_data_key("skeleton_mask_surface") + ws.remove_object_by_data_key("segmask_3d_surface") + ws.add_object(name="skeleton_points", kind=ObjectKind.SKELETON, + data_key="skeleton_points", visible=True, opacity=1.0, + color="red", point_size=8) + # ws.add_object(name="skeleton_mask", kind=ObjectKind.SKELETON, + # data_key="skeleton_mask_surface", visible=False, opacity=0.15, color="yellow") + ws.add_object(name="segmask_mesh", kind=ObjectKind.SEGMENTATION, + data_key="segmask_3d_surface", visible=True, opacity=0.15, + color="gray") + ws.pipeline.mark_done(StepId.GENERATE_SKELETON) + return StepResult(StepId.GENERATE_SKELETON, True, False, f"Skeleton: {len(pts)} points") + + def _step_edit_skeleton(self, ws): + ws.pipeline.mark_done(StepId.EDIT_SKELETON, skipped=True) + return StepResult(StepId.EDIT_SKELETON, True, True, "Skeleton edit") + + def _step_generate_graph(self, ws): + if ws.skeleton_points is None or len(ws.skeleton_points) == 0: + self._step_generate_skeleton(ws) + graph = build_graph_from_points(ws.skeleton_points, ws.resolution) + ws.graph = graph + + flow_for_orientation = None + if ws.flow_raw is not None and ws.segmask_binary is not None: + flow_for_orientation = ws.flow_raw * ws.segmask_binary[..., None] + labels, paths, node_paths, path_info, forks = segment_vessels_from_graph_and_mask( + ws.segmask_3d, ws.graph, ws.resolution, + flow_xyzt3=flow_for_orientation, + segmask_binary_4d=ws.segmask_binary, + origin=ws.origin, + ) + ws.branch_labels = labels + ws.centerline_paths = [np.asarray(p, dtype=float) for p in paths] + ws.centerline_node_paths = [list(map(int, p)) for p in node_paths] + ws.path_info = path_info + ws.forks = forks + ws.selected_path_index = -1 + + ws.remove_object_by_data_key("graph_lines") + ws.add_object(name="graph_lines", kind=ObjectKind.GRAPH, + data_key="graph_lines", visible=True, opacity=1.0, + color="blue", line_width=2) + + ws.remove_objects_by_prefix("path_") + ws.remove_objects_by_prefix("smooth_path_") + ws.remove_objects_by_prefix("path_arrow_") + ws.remove_object_by_data_key("fork_markers") + + if len(ws.forks) > 0: + ws.add_object(name="Forks", kind=ObjectKind.AUX, + data_key="fork_markers", visible=True, opacity=1.0, + color="magenta", point_size=12) + + ws.pipeline.mark_done(StepId.GENERATE_GRAPH) + return StepResult(StepId.GENERATE_GRAPH, True, False, + f"Graph: {len(graph.points)} nodes, {len(graph.edges)} edges | " + f"paths={len(ws.centerline_paths)} forks={len(ws.forks)}") + + def _step_edit_graph(self, ws): + ws.pipeline.mark_done(StepId.EDIT_GRAPH, skipped=True) + return StepResult(StepId.EDIT_GRAPH, True, True, "Graph edit") + + def _compute_plane_metrics_internal(self, ws, save=True, use_multithread=False): + if not ws.has_flow(): + return [], {}, "Plane metrics skipped: no flow" + if ws.segmask_binary is None: + self.preprocess(ws) + # Prefer the smoothed centerlines (better local tangents) but fall back + # to the raw ordered ones if the smoothing step hasn't been run yet. + paths_for_tangent = ws.centerline_paths_smooth if len(ws.centerline_paths_smooth) > 0 else ws.centerline_paths + if use_multithread: + metrics, qc = compute_plane_metrics_multithread( + ws.flow_raw, ws.segmask_binary, ws.resolution, ws.origin, ws.planes, + RR=ws.rr, branch_labels_3d=ws.branch_labels, + path_info=ws.path_info, forks=ws.forks, paths=paths_for_tangent, + return_qc=True) + else: + metrics, qc = compute_plane_metrics( + ws.flow_raw, ws.segmask_binary, ws.resolution, ws.origin, ws.planes, + RR=ws.rr, branch_labels_3d=ws.branch_labels, + path_info=ws.path_info, forks=ws.forks, paths=paths_for_tangent, + return_qc=True) + ws.derived.plane_metrics = metrics + ws.derived.plane_qc = qc + for i, metric in enumerate(metrics): + if i < len(ws.planes): + ws.planes[i].metrics = dict(metric) + msg = f"Plane metrics: {len(metrics)} paths={len(qc.get('path_ic', {}))} forks={len(qc.get('forks', []))}" + if save: + out_dir = self._output_dir(ws) + plane_metric_path = os.path.join(out_dir, "plane_metrics.json") + qc_path = os.path.join(out_dir, "plane_qc.json") + with open(plane_metric_path, "w", encoding="utf-8") as f: + json.dump(metrics, f, ensure_ascii=False, indent=2) + with open(qc_path, "w", encoding="utf-8") as f: + json.dump(qc, f, ensure_ascii=False, indent=2) + msg += f" saved={plane_metric_path} qc={qc_path}" + return metrics, qc, msg + + def _save_planes_json(self, ws): + out_dir = self._output_dir(ws) + out_path = os.path.join(out_dir, "planes.json") + payload = [] + origin = np.asarray(ws.origin, dtype=float).reshape(3) + for i, p in enumerate(ws.planes): + center_local = np.asarray(p.center, dtype=float).reshape(3) + item = { + "plane_index": int(i), + "center": center_local.tolist(), + "center_world": (center_local + origin).tolist(), + "normal": np.asarray(p.normal).tolist(), + "label": int(p.label), + "path_index": int(p.path_index), + "distance": float(p.distance), + } + if p.metrics: + item.update(json.loads(json.dumps(p.metrics, ensure_ascii=False))) + if 0 <= int(p.path_index) < len(ws.path_info): + item["path_info"] = ws.path_info[int(p.path_index)] + payload.append(item) + with open(out_path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + return out_path + + def _step_generate_planes(self, ws): + if ws.graph is None or len(ws.graph.points) == 0: + self._step_generate_graph(ws) + + if len(ws.centerline_paths) == 0: + flow_for_orientation = None + if ws.flow_raw is not None and ws.segmask_binary is not None: + flow_for_orientation = ws.flow_raw * ws.segmask_binary[..., None] + labels, paths, node_paths, path_info, forks = segment_vessels_from_graph_and_mask( + ws.segmask_3d, ws.graph, ws.resolution, + flow_xyzt3=flow_for_orientation, + segmask_binary_4d=ws.segmask_binary, + origin=ws.origin, + ) + ws.branch_labels = labels + ws.centerline_paths = [np.asarray(p, dtype=float) for p in paths] + ws.centerline_node_paths = [list(map(int, p)) for p in node_paths] + ws.path_info = path_info + ws.forks = forks + ws.selected_path_index = -1 + + ws.remove_objects_by_prefix("smooth_path_") + ws.remove_objects_by_prefix("path_arrow_") + ws.remove_object_by_data_key("fork_markers") + + pgp = ws.plane_gen_params + planes, smooth_paths = generate_planes_from_paths( + ws.centerline_paths, + cross_section_distance=pgp.cross_section_distance, + start_distance=pgp.start_distance, + end_distance=pgp.end_distance, + smoothing_window=pgp.smoothing_window * pgp.inter_time, + smoothing_polyorder=pgp.smoothing_polyorder, + inter_time=pgp.inter_time, + use_center_plane=pgp.use_center_plane, + ) + ws.planes = planes + ws.centerline_paths_smooth = smooth_paths + for i in range(len(ws.centerline_paths_smooth)): + direction_text = "" + if i < len(ws.path_info): + direction_text = ws.path_info[i].get("direction_text", "") + name = f"Path {i}" if not direction_text else f"Path {i} [{direction_text}]" + ws.add_object(name=name, kind=ObjectKind.BRANCH, + data_key=f"smooth_path_{i}", visible=True, opacity=1.0, + color="red", line_width=3) + # ws.add_object(name=f"Path {i} Arrow", kind=ObjectKind.AUX, + # data_key=f"path_arrow_{i}", visible=True, opacity=1.0, + # color="lime", line_width=2) + if len(ws.forks) > 0: + ws.remove_object_by_data_key("fork_markers") + ws.add_object(name="Forks", kind=ObjectKind.AUX, + data_key="fork_markers", visible=True, opacity=1.0, + color="magenta", point_size=12) + + ws.remove_objects_by_prefix("plane_") + for i in range(len(ws.planes)): + ws.add_object(name=f"Plane {i}", kind=ObjectKind.PLANE, + data_key=f"plane_{i}", visible=True, opacity=0.6, + color="yellow", line_width=2) + + planes_path = self._save_planes_json(ws) + ws.pipeline.mark_done(StepId.GENERATE_PLANES) + msg = f"Planes: {len(ws.planes)} paths={len(ws.centerline_paths_smooth)} forks={len(ws.forks)} saved={planes_path}" + return StepResult(StepId.GENERATE_PLANES, True, False, msg) + + def _step_edit_planes(self, ws): + ws.pipeline.mark_done(StepId.EDIT_PLANES, skipped=True) + return StepResult(StepId.EDIT_PLANES, True, True, "Plane edit") + + def _step_generate_streamlines(self, ws): + if ws.flow_raw is None or ws.segmask_3d is None: + return StepResult(StepId.GENERATE_STREAMLINES, True, True, "Streamlines skipped: no flow or mask") + self.preprocess(ws) + ws.streamline_seeds = generate_seed_points( + ws.segmask_3d, + ws.resolution, + ws.origin, + ratio=ws.streamline_params.seed_ratio, + rng_seed=ws.streamline_params.rng_seed, + min_seeds=ws.streamline_params.min_seeds, + ) + ws.streamline_cache.clear() + ws.streamline_active = True + ws.remove_object_by_data_key("streamlines_live") + ws.add_object( + name="streamlines", kind=ObjectKind.FLOW, + data_key="streamlines_live", visible=True, opacity=1.0, + scalars="Velocity", cmap="turbo", dynamic=True, + show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") + ws.pipeline.mark_done(StepId.GENERATE_STREAMLINES) + p = ws.streamline_params + param_msg = (f"Streamlines enabled: seed_ratio={p.seed_ratio} max_steps={p.max_steps} " + f"min_seeds={p.min_seeds} terminal_speed={p.terminal_speed} rng_seed={p.rng_seed}") + return StepResult(StepId.GENERATE_STREAMLINES, True, False, param_msg) + + def _step_plane_streamlines(self, ws): + if ws.flow_raw is None or ws.segmask_3d is None: + return StepResult(StepId.PLANE_STREAMLINES, True, True, "Plane streamlines skipped: no flow or mask") + if len(ws.planes) == 0: + return StepResult(StepId.PLANE_STREAMLINES, True, True, "Plane streamlines skipped: no planes") + self.preprocess(ws) + ws.plane_streamline_cache.clear() + ws.plane_streamline_active = True + ws.remove_object_by_data_key("plane_streamlines_live") + ws.add_object( + name="plane_streamlines", kind=ObjectKind.FLOW, + data_key="plane_streamlines_live", visible=True, opacity=1.0, + scalars="Velocity", cmap="turbo", dynamic=True, + show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") + ws.pipeline.mark_done(StepId.PLANE_STREAMLINES) + pidx = ws.plane_streamline_plane_idx + return StepResult(StepId.PLANE_STREAMLINES, True, False, + f"Plane streamlines enabled from plane {pidx}") + + def _step_compute_plane_metrics(self, ws): + if not ws.has_flow(): + return StepResult(StepId.COMPUTE_PLANE_METRICS, True, True, "Plane metrics skipped: no flow") + if len(ws.planes) == 0: + self._step_generate_planes(ws) + use_mt = getattr(ws.derived_params, "use_multithread", False) + _, _, msg = self._compute_plane_metrics_internal(ws, save=True, use_multithread=use_mt) + self._save_planes_json(ws) + ws.pipeline.mark_done(StepId.COMPUTE_PLANE_METRICS) + return StepResult(StepId.COMPUTE_PLANE_METRICS, True, False, msg) + + def _step_compute_derived_metrics(self, ws): + if not ws.has_flow(): + return StepResult(StepId.COMPUTE_DERIVED_METRICS, True, True, "Derived metrics skipped: no flow") + self.preprocess(ws) + dp = ws.derived_params + loaded_tke = ws.derived.tke_array + result = compute_derived_metrics( + flow=ws.flow_raw * ws.segmask_binary[..., None], + mask4d=ws.segmask_binary, + spacing=ws.resolution, + origin=ws.origin, + smoothing_iteration=dp.smoothing_iteration, + viscosity=dp.viscosity, + inward_distance=dp.inward_distance, + parabolic_fitting=dp.parabolic_fitting, + no_slip_condition=dp.no_slip_condition, + step_size=dp.step_size, + tube_radius=dp.tube_radius, + rho=dp.rho, + save_pixelwise=False, + tke_array=loaded_tke, + ) + ws.derived.wss_surfaces = result["wss_surfaces"] + ws.derived.wss_volume = result.get("wss_volume") + ws.derived.tke_volume = result["tke_volume"] + ws.derived.tke_array = result.get("tke_array") + ws.derived.streamlines = [] + ws.derived.pixelwise_export = result.get("pixelwise_export", {}) + for dk in ["wss_surface_live", "tke_volume"]: + ws.remove_object_by_data_key(dk) + wss_max = float(np.nanmax(ws.derived.wss_volume)) if ws.derived.wss_volume is not None and np.size(ws.derived.wss_volume) else 0.0 + tke_max = float(np.nanmax(ws.derived.tke_array)) if ws.derived.tke_array is not None and np.size(ws.derived.tke_array) else 0.0 + ws.add_object(name="wss_surface", kind=ObjectKind.METRIC, + data_key="wss_surface_live", visible=False, opacity=1.0, + scalars="wss", cmap="jet", clim=(0.0, wss_max if wss_max > 0 else 1.0), dynamic=True, + show_scalar_bar=True, scalar_bar_title="WSS (Pa)") + ws.add_object(name="tke_volume", kind=ObjectKind.METRIC, + data_key="tke_volume", visible=False, opacity=0.5, + scalars="TKE", cmap="hot", clim=(0.0, tke_max if tke_max > 0 else 1.0), dynamic=True, + show_scalar_bar=True, scalar_bar_title="TKE (J/m³)") + msg = f"Derived: Nt={len(ws.derived.wss_surfaces)}" + ws.pipeline.mark_done(StepId.COMPUTE_DERIVED_METRICS) + return StepResult(StepId.COMPUTE_DERIVED_METRICS, True, False, msg) diff --git a/autoflow/editors.py b/autoflow/editors.py index 61a25c9..b5114dd 100755 --- a/autoflow/editors.py +++ b/autoflow/editors.py @@ -1,108 +1,5 @@ -import numpy as np -from .models import PlaneData, GraphData +"""Compatibility re-exports for UI editors.""" +from .ui.editors import GraphEditor, PlaneEditor, SkeletonEditor -class SkeletonEditor: - def __init__(self, workspace): - self.workspace = workspace - - def remove_points_by_index(self, indices): - if self.workspace.skeleton_points is None: - return - pts = np.asarray(self.workspace.skeleton_points) - mask = np.ones(len(pts), dtype=bool) - mask[np.asarray(indices, dtype=int)] = False - self.workspace.skeleton_points = pts[mask] - - def append_points(self, points): - pts = np.asarray(points, dtype=float).reshape(-1, 3) - if self.workspace.skeleton_points is None or len(self.workspace.skeleton_points) == 0: - self.workspace.skeleton_points = pts - else: - self.workspace.skeleton_points = np.vstack([self.workspace.skeleton_points, pts]) - - def replace_points(self, points): - self.workspace.skeleton_points = np.asarray(points, dtype=float).reshape(-1, 3) - - -class GraphEditor: - def __init__(self, workspace): - self.workspace = workspace - - def remove_edges_by_index(self, indices): - edges = np.asarray(self.workspace.graph.edges, dtype=int) - if len(edges) == 0: - return - mask = np.ones(len(edges), dtype=bool) - mask[np.asarray(indices, dtype=int)] = False - self.workspace.graph = GraphData( - points=self.workspace.graph.points.copy(), edges=edges[mask]) - - def append_edges(self, edges): - e = np.asarray(edges, dtype=int).reshape(-1, 2) - if len(self.workspace.graph.edges) == 0: - new_edges = e - else: - new_edges = np.vstack([self.workspace.graph.edges, e]) - self.workspace.graph = GraphData( - points=self.workspace.graph.points.copy(), edges=new_edges) - - def remove_nodes_by_index(self, indices): - points = np.asarray(self.workspace.graph.points, dtype=float) - edges = np.asarray(self.workspace.graph.edges, dtype=int) - rm = set(int(i) for i in indices) - keep_idx = [i for i in range(len(points)) if i not in rm] - remap = {old: new for new, old in enumerate(keep_idx)} - new_points = points[keep_idx] - new_edges = [] - for a, b in edges: - if int(a) in remap and int(b) in remap: - new_edges.append([remap[int(a)], remap[int(b)]]) - self.workspace.graph = GraphData( - points=new_points, - edges=np.asarray(new_edges, dtype=int).reshape(-1, 2) if new_edges else np.empty((0, 2), dtype=int), - ) - - -class PlaneEditor: - def __init__(self, workspace): - self.workspace = workspace - - def add_plane(self, center, normal, label=1, path_index=0, distance=0.0): - n = np.asarray(normal, dtype=float).reshape(3) - n = n / (np.linalg.norm(n) + 1e-12) - self.workspace.planes.append(PlaneData( - center=np.asarray(center, dtype=float).reshape(3), - normal=n, - label=int(label), - path_index=int(path_index), - distance=float(distance), - )) - - def remove_planes_by_index(self, indices): - rm = set(int(i) for i in indices) - self.workspace.planes = [p for i, p in enumerate(self.workspace.planes) if i not in rm] - - def update_plane(self, index, center=None, normal=None, label=None): - p = self.workspace.planes[int(index)] - if center is not None: - p.center = np.asarray(center, dtype=float).reshape(3) - if normal is not None: - n = np.asarray(normal, dtype=float).reshape(3) - p.normal = n / (np.linalg.norm(n) + 1e-12) - if label is not None: - p.label = int(label) - - def replace_planes(self, planes): - out = [] - for p in planes: - n = np.asarray(p["normal"], dtype=float).reshape(3) - n = n / (np.linalg.norm(n) + 1e-12) - out.append(PlaneData( - center=np.asarray(p["center"], dtype=float).reshape(3), - normal=n, - label=int(p.get("label", 1)), - path_index=int(p.get("path_index", 0)), - distance=float(p.get("distance", 0.0)), - )) - self.workspace.planes = out +__all__ = ["GraphEditor", "PlaneEditor", "SkeletonEditor"] diff --git a/autoflow/gui.py b/autoflow/gui.py index 73839b1..1e71aa8 100755 --- a/autoflow/gui.py +++ b/autoflow/gui.py @@ -1,29 +1,5 @@ -import sys +"""Compatibility re-exports for the GUI launcher.""" +from .ui.launcher import launch_gui, main -_GUI_IMPORTS = {"PyQt5", "pyvistaqt", "matplotlib", "pyvista", "vtk"} - - -def launch_gui() -> None: - try: - from .app import main as app_main - except ModuleNotFoundError as exc: - module_name = exc.name or "" - base_name = module_name.split(".", 1)[0] - if base_name in _GUI_IMPORTS: - raise SystemExit( - "GUI dependencies are not installed. Run `pip install \".[gui]\"` in the repo root first." - ) from exc - raise - app_main() - - -def main() -> None: - try: - launch_gui() - except KeyboardInterrupt: - sys.exit(130) - - -if __name__ == "__main__": - main() +__all__ = ["launch_gui", "main"] diff --git a/autoflow/io_utils.py b/autoflow/io_utils.py index 889cdea..69dc281 100755 --- a/autoflow/io_utils.py +++ b/autoflow/io_utils.py @@ -1,5 +1,5 @@ import json -from .models import Workspace +from .core.models import Workspace def save_workspace_file(path, workspace): diff --git a/autoflow/models.py b/autoflow/models.py index 2709cd5..d841a49 100755 --- a/autoflow/models.py +++ b/autoflow/models.py @@ -1,540 +1,5 @@ -import copy, json, uuid -from dataclasses import dataclass, field -from enum import Enum -from typing import Any, Dict, List, Optional, Tuple -import numpy as np +"""Compatibility re-exports for core models.""" +from .core.models import * -class ObjectKind(Enum): - SEGMENTATION = "Segmentation" - SKELETON = "Skeleton" - GRAPH = "Graph" - BRANCH = "Branch" - PLANE = "Plane" - FLOW = "Flow" - METRIC = "Metric" - AUX = "Aux" - - -class StepId(Enum): - GENERATE_SKELETON = "step_skeleton" - EDIT_SKELETON = "step_edit_skeleton" - GENERATE_GRAPH = "step_graph" - EDIT_GRAPH = "step_edit_graph" - GENERATE_PLANES = "step_planes" - EDIT_PLANES = "step_edit_planes" - GENERATE_STREAMLINES = "step_streamlines" - PLANE_STREAMLINES = "step_plane_streamlines" - COMPUTE_PLANE_METRICS = "step_plane_metrics" - COMPUTE_DERIVED_METRICS = "step_derived_metrics" - - @property - def label(self): - return { - StepId.GENERATE_SKELETON: "Generate Skeleton", - StepId.EDIT_SKELETON: "Edit Skeleton", - StepId.GENERATE_GRAPH: "Generate Graph", - StepId.EDIT_GRAPH: "Edit Graph", - StepId.GENERATE_PLANES: "Generate Planes", - StepId.EDIT_PLANES: "Edit Planes", - StepId.GENERATE_STREAMLINES: "Generate Streamlines", - StepId.PLANE_STREAMLINES: "Plane Streamlines", - StepId.COMPUTE_PLANE_METRICS: "Calculate && Save Metrics", - StepId.COMPUTE_DERIVED_METRICS: "WSS / TKE", - }[self] - - @staticmethod - def top_row_steps(): - return [ - StepId.GENERATE_SKELETON, - StepId.GENERATE_GRAPH, - StepId.GENERATE_PLANES, - StepId.COMPUTE_PLANE_METRICS, - ] - - @staticmethod - def bottom_row_steps(): - return [ - StepId.EDIT_SKELETON, - StepId.EDIT_GRAPH, - ] - - @staticmethod - def extra_row_steps(): - return [ - StepId.GENERATE_STREAMLINES, - StepId.PLANE_STREAMLINES, - StepId.COMPUTE_DERIVED_METRICS, - ] - - -@dataclass -class PreprocessParams: - def to_dict(self): - return {} - - @staticmethod - def from_dict(d): - return PreprocessParams() - - -@dataclass -class SkeletonParams: - remove_small_cc: bool = False - min_cc_volume_mm3: float = 50.0 - do_closing: bool = True - do_opening: bool = False - gaussian_sigma: float = 0.5 - gaussian_enabled: bool = True - - def to_dict(self): - return { - "remove_small_cc": self.remove_small_cc, - "min_cc_volume_mm3": self.min_cc_volume_mm3, - "do_closing": self.do_closing, - "do_opening": self.do_opening, - "gaussian_sigma": self.gaussian_sigma, - "gaussian_enabled": self.gaussian_enabled, - } - - @staticmethod - def from_dict(d): - if "keep_largest_cc" in d and "remove_small_cc" not in d: - d["remove_small_cc"] = d["keep_largest_cc"] - return SkeletonParams( - remove_small_cc=bool(d.get("remove_small_cc", False)), - min_cc_volume_mm3=float(d.get("min_cc_volume_mm3", 50.0)), - do_closing=bool(d.get("do_closing", True)), - do_opening=bool(d.get("do_opening", False)), - gaussian_sigma=float(d.get("gaussian_sigma", 0.5)), - gaussian_enabled=bool(d.get("gaussian_enabled", True))) - - -@dataclass -class PlaneGenerationParams: - use_center_plane: bool = True - cross_section_distance: float = 20.0 - start_distance: float = 5.0 - end_distance: float = 0.0 - smoothing_window: int = 15 - smoothing_polyorder: int = 2 - inter_time: int = 10 - - def to_dict(self): - return { - "use_center_plane": bool(self.use_center_plane), - "cross_section_distance": self.cross_section_distance, - "start_distance": self.start_distance, - "end_distance": self.end_distance, - "smoothing_window": int(self.smoothing_window), - "smoothing_polyorder": int(self.smoothing_polyorder), - "inter_time": int(self.inter_time), - } - - @staticmethod - def from_dict(d): - return PlaneGenerationParams( - use_center_plane=bool(d.get("use_center_plane", True)), - cross_section_distance=float(d.get("cross_section_distance", 20.0)), - start_distance=float(d.get("start_distance", 5.0)), - end_distance=float(d.get("end_distance", 0.0)), - smoothing_window=int(d.get("smoothing_window", 15)), - smoothing_polyorder=int(d.get("smoothing_polyorder", 3)), - inter_time=int(d.get("inter_time", 10)), - ) - - -@dataclass -class StreamlineParams: - seed_ratio: float = 0.02 - max_steps: int = 2000 - min_seeds: int = 50 - terminal_speed: float = 0.01 - rng_seed: int = 0 - - def to_dict(self): - return { - "seed_ratio": self.seed_ratio, - "max_steps": self.max_steps, - "min_seeds": self.min_seeds, - "terminal_speed": self.terminal_speed, - "rng_seed": self.rng_seed, - } - - @staticmethod - def from_dict(d): - return StreamlineParams( - seed_ratio=float(d.get("seed_ratio", 0.02)), - max_steps=int(d.get("max_steps", 2000)), - min_seeds=50, - terminal_speed=float(d.get("terminal_speed", 0.01)), - rng_seed=int(d.get("rng_seed", 0)), - ) - - -@dataclass -class DerivedMetricsParams: - smoothing_iteration: int = 200 - viscosity: float = 4.0 - inward_distance: float = 0.6 - parabolic_fitting: bool = True - no_slip_condition: bool = True - step_size: int = 5 - tube_radius: float = 0.1 - rho: float = 1060.0 - use_multithread: bool = False - - def to_dict(self): - return { - "smoothing_iteration": self.smoothing_iteration, - "viscosity": self.viscosity, - "inward_distance": self.inward_distance, - "parabolic_fitting": self.parabolic_fitting, - "no_slip_condition": self.no_slip_condition, - "step_size": self.step_size, - "tube_radius": self.tube_radius, - "rho": self.rho, - "use_multithread": self.use_multithread, - } - - @staticmethod - def from_dict(d): - return DerivedMetricsParams( - smoothing_iteration=int(d.get("smoothing_iteration", 200)), - viscosity=float(d.get("viscosity", 4.0)), - inward_distance=float(d.get("inward_distance", 0.6)), - parabolic_fitting=bool(d.get("parabolic_fitting", True)), - no_slip_condition=bool(d.get("no_slip_condition", True)), - step_size=int(d.get("step_size", 5)), - tube_radius=float(d.get("tube_radius", 0.1)), - rho=float(d.get("rho", 1060.0)), - use_multithread=bool(d.get("use_multithread", False)), - ) - - -@dataclass -class PathsState: - segmask_path: str = "" - flow_path: str = "" - workspace_path: str = "" - output_dir: str = "" - - -@dataclass -class PipelineFlags: - completed: Dict[str, bool] = field(default_factory=dict) - skipped: Dict[str, bool] = field(default_factory=dict) - - def mark_done(self, step: StepId, skipped: bool = False): - self.completed[step.value] = True - self.skipped[step.value] = bool(skipped) - - def is_done(self, step: StepId) -> bool: - return bool(self.completed.get(step.value, False)) - - def reset(self): - self.completed.clear() - self.skipped.clear() - - -@dataclass -class SceneObject: - uid: str - name: str - kind: ObjectKind - data_key: str - visible: bool = True - opacity: float = 1.0 - color: str = "white" - scalars: Optional[str] = None - cmap: str = "turbo" - clim: Optional[Tuple[float, float]] = None - point_size: int = 8 - line_width: int = 2 - tube_radius: float = 0.0 - show_scalar_bar: bool = False - scalar_bar_title: Optional[str] = None - dynamic: bool = False - actor: Any = None - label_actor: Any = None - - -@dataclass -class PlaneData: - center: np.ndarray - normal: np.ndarray - label: int = 1 - path_index: int = 0 - distance: float = 0.0 - metrics: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class GraphData: - points: np.ndarray = field(default_factory=lambda: np.empty((0, 3), dtype=float)) - edges: np.ndarray = field(default_factory=lambda: np.empty((0, 2), dtype=int)) - - -@dataclass -class DerivedResults: - plane_metrics: List[Dict[str, Any]] = field(default_factory=list) - plane_qc: Dict[str, Any] = field(default_factory=dict) - wss_surfaces: List[Any] = field(default_factory=list) - wss_volume: Optional[np.ndarray] = None - tke_volume: Any = None - tke_array: Optional[np.ndarray] = None - streamlines: List[Any] = field(default_factory=list) - pixelwise_export: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class Workspace: - paths: PathsState = field(default_factory=PathsState) - pipeline: PipelineFlags = field(default_factory=PipelineFlags) - preprocess_params: PreprocessParams = field(default_factory=PreprocessParams) - skeleton_params: SkeletonParams = field(default_factory=SkeletonParams) - plane_gen_params: PlaneGenerationParams = field(default_factory=PlaneGenerationParams) - streamline_params: StreamlineParams = field(default_factory=StreamlineParams) - derived_params: DerivedMetricsParams = field(default_factory=DerivedMetricsParams) - - resolution: np.ndarray = field(default_factory=lambda: np.array([1., 1., 1.])) - origin: np.ndarray = field(default_factory=lambda: np.array([0., 0., 0.])) - spatial_order: List[str] = field(default_factory=lambda: ["FH", "AP", "LR"]) - venc_order: List[str] = field(default_factory=lambda: ["FH", "AP", "LR"]) - venc: np.ndarray = field(default_factory=lambda: np.array([1., 1., 1.])) - rr: float = 1000.0 - - segmask_raw: Optional[np.ndarray] = None - segmask_labels: Optional[np.ndarray] = None - segmask_binary: Optional[np.ndarray] = None - segmask_3d: Optional[np.ndarray] = None - - mag_raw: Optional[np.ndarray] = None - - skeleton_points: Optional[np.ndarray] = None - skeleton_mask: Optional[np.ndarray] = None - - graph: GraphData = field(default_factory=GraphData) - branch_labels: Optional[np.ndarray] = None - centerline_paths: List[np.ndarray] = field(default_factory=list) - centerline_node_paths: List[List[int]] = field(default_factory=list) - centerline_paths_smooth: List[np.ndarray] = field(default_factory=list) - path_info: List[Dict[str, Any]] = field(default_factory=list) - forks: List[Dict[str, Any]] = field(default_factory=list) - planes: List[PlaneData] = field(default_factory=list) - - flow_raw: Optional[np.ndarray] = None - streamline_seeds: Optional[np.ndarray] = None - streamline_cache: Dict[int, Any] = field(default_factory=dict) - streamline_active: bool = False - plane_streamline_cache: Dict[int, Any] = field(default_factory=dict) - plane_streamline_active: bool = False - plane_streamline_plane_idx: int = -1 - derived: DerivedResults = field(default_factory=DerivedResults) - - scene_objects: Dict[str, SceneObject] = field(default_factory=dict) - current_t: int = 0 - data_loaded: bool = False - - ortho_cursor: np.ndarray = field(default_factory=lambda: np.array([0, 0, 0], dtype=int)) - selected_path_index: int = -1 - - def time_count(self): - if self.flow_raw is not None and self.flow_raw.ndim == 5: - return int(self.flow_raw.shape[3]) - if self.segmask_raw is not None and self.segmask_raw.ndim == 4: - return int(self.segmask_raw.shape[3]) - return 1 - - def has_flow(self): - return self.flow_raw is not None - - def unique_labels(self): - if self.segmask_raw is None: - return [] - return sorted(int(x) for x in np.unique(self.segmask_raw) if x != 0) - - def add_object(self, name, kind, data_key, **kw): - uid = str(uuid.uuid4()) - self.scene_objects[uid] = SceneObject(uid=uid, name=name, kind=kind, data_key=data_key, **kw) - return uid - - def remove_object(self, uid): - return self.scene_objects.pop(uid, None) - - def remove_object_by_data_key(self, data_key): - to_del = [u for u, o in self.scene_objects.items() if o.data_key == data_key] - return [self.scene_objects.pop(u) for u in to_del] - - def remove_objects_by_prefix(self, prefix): - to_del = [u for u, o in self.scene_objects.items() if o.data_key.startswith(prefix)] - return [self.scene_objects.pop(u) for u in to_del] - - def set_object_visible_by_data_key(self, data_key, visible): - for o in self.scene_objects.values(): - if o.data_key == data_key: - o.visible = visible - - def clear_streamlines(self): - self.streamline_seeds = None - self.streamline_cache.clear() - self.streamline_active = False - self.remove_object_by_data_key("streamlines_live") - - def clear_plane_streamlines(self): - self.plane_streamline_cache.clear() - self.plane_streamline_active = False - self.plane_streamline_plane_idx = -1 - self.remove_object_by_data_key("plane_streamlines_live") - - def reset_all(self): - for attr, default in [ - ("paths", PathsState()), - ("pipeline", PipelineFlags()), - ("preprocess_params", PreprocessParams()), - ("skeleton_params", SkeletonParams()), - ("plane_gen_params", PlaneGenerationParams()), - ("streamline_params", StreamlineParams()), - ("derived_params", DerivedMetricsParams()), - ]: - setattr(self, attr, default) - self.resolution = np.array([1., 1., 1.]) - self.origin = np.array([0., 0., 0.]) - self.venc = np.array([1., 1., 1.]) - self.rr = 1000.0 - for attr in ["segmask_raw", "segmask_labels", "segmask_binary", "segmask_3d", - "skeleton_points", "skeleton_mask", "branch_labels", "flow_raw", - "streamline_seeds", "mag_raw"]: - setattr(self, attr, None) - self.graph = GraphData() - self.centerline_paths = [] - self.centerline_node_paths = [] - self.centerline_paths_smooth = [] - self.path_info = [] - self.forks = [] - self.planes = [] - self.streamline_cache = {} - self.streamline_active = False - self.plane_streamline_cache = {} - self.plane_streamline_active = False - self.plane_streamline_plane_idx = -1 - self.derived = DerivedResults() - self.scene_objects = {} - self.current_t = 0 - self.data_loaded = False - self.ortho_cursor = np.array([0, 0, 0], dtype=int) - self.selected_path_index = -1 - - def snapshot_dict(self): - def arr(v): - return None if v is None else np.asarray(v).tolist() - return { - "paths": {"segmask_path": self.paths.segmask_path, "flow_path": self.paths.flow_path, - "workspace_path": self.paths.workspace_path, "output_dir": self.paths.output_dir}, - "pipeline": {"completed": dict(self.pipeline.completed), "skipped": dict(self.pipeline.skipped)}, - "preprocess_params": self.preprocess_params.to_dict(), - "skeleton_params": self.skeleton_params.to_dict(), - "plane_gen_params": self.plane_gen_params.to_dict(), - "streamline_params": self.streamline_params.to_dict(), - "derived_params": self.derived_params.to_dict(), - "resolution": arr(self.resolution), - "origin": arr(self.origin), - "venc": arr(self.venc), - "rr": float(self.rr), - "segmask_raw": arr(self.segmask_raw), - "segmask_labels": arr(self.segmask_labels), - "segmask_binary": arr(self.segmask_binary), - "segmask_3d": arr(self.segmask_3d), - "mag_raw": arr(self.mag_raw), - "skeleton_points": arr(self.skeleton_points), - "skeleton_mask": arr(self.skeleton_mask), - "graph": {"points": arr(self.graph.points), "edges": arr(self.graph.edges)}, - "branch_labels": arr(self.branch_labels), - "centerline_paths": [arr(x) for x in self.centerline_paths], - "centerline_node_paths": [list(map(int, x)) for x in self.centerline_node_paths], - "centerline_paths_smooth": [arr(x) for x in self.centerline_paths_smooth], - "path_info": copy.deepcopy(self.path_info), - "forks": copy.deepcopy(self.forks), - "planes": [{"center": arr(p.center), "normal": arr(p.normal), "label": int(p.label), - "path_index": int(p.path_index), "distance": float(p.distance), - "metrics": copy.deepcopy(p.metrics)} for p in self.planes], - "flow_raw": arr(self.flow_raw), - "streamline_seeds": arr(self.streamline_seeds), - "streamline_active": self.streamline_active, - "scene_objects": [ - {"uid": o.uid, "name": o.name, "kind": o.kind.value, "data_key": o.data_key, - "visible": o.visible, "opacity": o.opacity, "color": o.color, - "scalars": o.scalars, "cmap": o.cmap, - "clim": list(o.clim) if o.clim else None, - "point_size": o.point_size, "line_width": o.line_width, - "tube_radius": o.tube_radius, "show_scalar_bar": o.show_scalar_bar, - "scalar_bar_title": o.scalar_bar_title, "dynamic": o.dynamic} - for o in self.scene_objects.values()], - "current_t": self.current_t, "data_loaded": self.data_loaded, - "selected_path_index": int(self.selected_path_index), - } - - def restore_dict(self, d): - self.paths = PathsState(**{k: d.get("paths", {}).get(k, "") for k in ["segmask_path", "flow_path", "workspace_path", "output_dir"]}) - self.pipeline = PipelineFlags(completed=dict(d.get("pipeline", {}).get("completed", {})), - skipped=dict(d.get("pipeline", {}).get("skipped", {}))) - self.preprocess_params = PreprocessParams.from_dict(d.get("preprocess_params", {})) - self.skeleton_params = SkeletonParams.from_dict(d.get("skeleton_params", {})) - self.plane_gen_params = PlaneGenerationParams.from_dict(d.get("plane_gen_params", {})) - self.streamline_params = StreamlineParams.from_dict(d.get("streamline_params", {})) - self.derived_params = DerivedMetricsParams.from_dict(d.get("derived_params", {})) - self.resolution = np.asarray(d.get("resolution", [1, 1, 1]), dtype=float) - self.origin = np.array([0.0, 0.0, 0.0], dtype=float) - self.venc = np.asarray(d.get("venc", [1, 1, 1]), dtype=float) - self.rr = float(d.get("rr", 1000.0)) - - def nparr(k, dt=np.float64): - v = d.get(k) - return None if v is None else np.asarray(v, dtype=dt) - - self.segmask_raw = nparr("segmask_raw", np.int16) - self.segmask_labels = nparr("segmask_labels", np.int16) - self.segmask_binary = None if d.get("segmask_binary") is None else np.asarray(d["segmask_binary"], dtype=bool) - self.segmask_3d = None if d.get("segmask_3d") is None else np.asarray(d["segmask_3d"], dtype=bool) - self.mag_raw = nparr("mag_raw") - self.skeleton_points = nparr("skeleton_points") - self.skeleton_mask = nparr("skeleton_mask") - gd = d.get("graph", {}) - self.graph = GraphData( - points=np.asarray(gd.get("points", []), dtype=float).reshape(-1, 3) if gd.get("points") else np.empty((0, 3)), - edges=np.asarray(gd.get("edges", []), dtype=int).reshape(-1, 2) if gd.get("edges") else np.empty((0, 2), dtype=int)) - self.branch_labels = nparr("branch_labels") - self.centerline_paths = [np.asarray(x, dtype=float) for x in d.get("centerline_paths", [])] - self.centerline_node_paths = [list(map(int, x)) for x in d.get("centerline_node_paths", [])] - self.centerline_paths_smooth = [np.asarray(x, dtype=float) for x in d.get("centerline_paths_smooth", [])] - self.path_info = copy.deepcopy(d.get("path_info", [])) - self.forks = copy.deepcopy(d.get("forks", [])) - self.planes = [] - for p in d.get("planes", []): - self.planes.append(PlaneData( - center=np.asarray(p["center"], dtype=float), normal=np.asarray(p["normal"], dtype=float), - label=int(p.get("label", 1)), path_index=int(p.get("path_index", 0)), - distance=float(p.get("distance", 0.0)), metrics=copy.deepcopy(p.get("metrics", {})))) - self.flow_raw = nparr("flow_raw") - self.streamline_seeds = nparr("streamline_seeds") - self.streamline_cache = {} - self.streamline_active = bool(d.get("streamline_active", False)) - self.plane_streamline_cache = {} - self.plane_streamline_active = False - self.plane_streamline_plane_idx = -1 - self.scene_objects = {} - for it in d.get("scene_objects", []): - uid = it["uid"] - self.scene_objects[uid] = SceneObject( - uid=uid, name=it["name"], kind=ObjectKind(it["kind"]), data_key=it["data_key"], - visible=bool(it.get("visible", True)), opacity=float(it.get("opacity", 1.0)), - color=it.get("color", "white"), scalars=it.get("scalars"), - cmap=it.get("cmap", "turbo"), - clim=tuple(it["clim"]) if it.get("clim") else None, - point_size=int(it.get("point_size", 8)), line_width=int(it.get("line_width", 2)), - tube_radius=float(it.get("tube_radius", 0.0)), - show_scalar_bar=bool(it.get("show_scalar_bar", False)), - scalar_bar_title=it.get("scalar_bar_title"), dynamic=bool(it.get("dynamic", False))) - self.current_t = int(d.get("current_t", 0)) - self.data_loaded = bool(d.get("data_loaded", False)) - self.selected_path_index = int(d.get("selected_path_index", -1)) +__all__ = [name for name in globals() if not name.startswith("_")] diff --git a/autoflow/ortho_viewer.py b/autoflow/ortho_viewer.py index 440a0a2..7348bab 100755 --- a/autoflow/ortho_viewer.py +++ b/autoflow/ortho_viewer.py @@ -1,555 +1,11 @@ -import numpy as np -from PyQt5 import QtWidgets, QtCore -from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas -from matplotlib.figure import Figure -from scipy.ndimage import map_coordinates +"""Compatibility re-exports for the ortho viewer.""" +__all__ = ["OrthoViewer"] -class OrthoViewer(QtWidgets.QWidget): - def __init__(self, workspace, parent=None): - super().__init__(parent) - self.workspace = workspace - self._selected_plane_idx = None - self._scalar_cbar = None - self._cache = {} - self._build_ui() - def _cached(self, group, key, builder, max_items=24): - bucket = self._cache.setdefault(group, {}) - if key in bucket: - return bucket[key] - value = builder() - if len(bucket) >= max_items: - bucket.clear() - bucket[key] = value - return value +def __getattr__(name): + if name == "OrthoViewer": + from .ui.ortho_viewer import OrthoViewer - def _build_ui(self): - layout = QtWidgets.QVBoxLayout(self) - layout.setContentsMargins(2, 2, 2, 2) - layout.setSpacing(2) - - ctrl = QtWidgets.QHBoxLayout() - self.combo_content = QtWidgets.QComboBox() - self.combo_content.addItems([ - "Flow X (cm/s)", "Flow Y (cm/s)", "Flow Z (cm/s)", - "Magnitude", "PC-MRA", "Speed (cm/s)", - "WSS (Pa)", "TKE (J/m³)" - ]) - self.combo_content.setCurrentIndex(4) - self.combo_content.currentIndexChanged.connect(self._on_content_changed) - ctrl.addWidget(QtWidgets.QLabel("Content:")) - ctrl.addWidget(self.combo_content) - ctrl.addStretch() - layout.addLayout(ctrl) - - slider_layout = QtWidgets.QHBoxLayout() - self.slider_x = QtWidgets.QSlider(QtCore.Qt.Horizontal) - self.slider_y = QtWidgets.QSlider(QtCore.Qt.Horizontal) - self.slider_z = QtWidgets.QSlider(QtCore.Qt.Horizontal) - self.label_x = QtWidgets.QLabel("X:0") - self.label_y = QtWidgets.QLabel("Y:0") - self.label_z = QtWidgets.QLabel("Z:0") - for lbl, sl in [(self.label_x, self.slider_x), (self.label_y, self.slider_y), (self.label_z, self.slider_z)]: - sl.setRange(0, 0) - sl.valueChanged.connect(self._on_slider_changed) - slider_layout.addWidget(lbl) - slider_layout.addWidget(sl) - layout.addLayout(slider_layout) - - self.label_value = QtWidgets.QLabel("Voxel: - Value: -") - self.label_plane_metric = QtWidgets.QLabel("Plane metrics: -") - self.label_value.setWordWrap(True) - self.label_plane_metric.setWordWrap(True) - layout.addWidget(self.label_value) - layout.addWidget(self.label_plane_metric) - - self.fig = Figure(figsize=(6.2, 6.6), dpi=80, facecolor="black") - self.canvas = FigureCanvas(self.fig) - self.canvas.setMinimumSize(300, 300) - self.ax_ax = self.fig.add_subplot(2, 2, 1) - self.ax_cor = self.fig.add_subplot(2, 2, 2) - self.ax_sag = self.fig.add_subplot(2, 2, 3) - self.ax_plane = self.fig.add_subplot(2, 2, 4) - for ax in [self.ax_ax, self.ax_cor, self.ax_sag, self.ax_plane]: - ax.set_facecolor("black") - ax.tick_params(colors="white", labelsize=6) - ax.set_xticks([]) - ax.set_yticks([]) - self.fig.subplots_adjust(left=0.03, right=0.96, top=0.96, bottom=0.03, wspace=0.14, hspace=0.24) - layout.addWidget(self.canvas, 1) - - self.canvas.mpl_connect("scroll_event", self._on_scroll) - self.canvas.mpl_connect("button_press_event", self._on_click) - - def _remove_colorbar(self): - if self._scalar_cbar is not None: - try: - self._scalar_cbar.remove() - except Exception: - pass - self._scalar_cbar = None - - def _on_scroll(self, event): - if event.inaxes is None: - return - ax = event.inaxes - factor = 0.8 if event.button == "up" else 1.25 - xlim = ax.get_xlim() - ylim = ax.get_ylim() - xdata = event.xdata if event.xdata is not None else (xlim[0] + xlim[1]) / 2 - ydata = event.ydata if event.ydata is not None else (ylim[0] + ylim[1]) / 2 - new_w = (xlim[1] - xlim[0]) * factor - new_h = (ylim[1] - ylim[0]) * factor - ax.set_xlim(xdata - new_w / 2, xdata + new_w / 2) - ax.set_ylim(ydata - new_h / 2, ydata + new_h / 2) - self.canvas.draw_idle() - - def _on_click(self, event): - if event.inaxes is None or event.xdata is None or event.ydata is None: - return - shape = self._get_volume_shape() - if shape is None: - return - x = int(np.clip(np.round(event.xdata), 0, shape[0] - 1)) - y = int(np.clip(np.round(event.ydata), 0, max(shape[1], shape[2]) - 1)) - cx, cy, cz = self.slider_x.value(), self.slider_y.value(), self.slider_z.value() - if event.inaxes == self.ax_ax: - self._set_cursor(x, int(np.clip(np.round(event.ydata), 0, shape[1] - 1)), cz) - elif event.inaxes == self.ax_cor: - self._set_cursor(x, cy, int(np.clip(np.round(event.ydata), 0, shape[2] - 1))) - elif event.inaxes == self.ax_sag: - self._set_cursor(cx, int(np.clip(np.round(event.xdata), 0, shape[1] - 1)), int(np.clip(np.round(event.ydata), 0, shape[2] - 1))) - - def _set_cursor(self, x, y, z): - self.slider_x.blockSignals(True) - self.slider_y.blockSignals(True) - self.slider_z.blockSignals(True) - self.slider_x.setValue(int(x)) - self.slider_y.setValue(int(y)) - self.slider_z.setValue(int(z)) - self.slider_x.blockSignals(False) - self.slider_y.blockSignals(False) - self.slider_z.blockSignals(False) - self._update_labels() - self.workspace.ortho_cursor = np.array([int(x), int(y), int(z)], dtype=int) - self.refresh() - - def update_slider_ranges(self): - shape = self._get_volume_shape() - if shape is None: - return - self.slider_x.blockSignals(True) - self.slider_y.blockSignals(True) - self.slider_z.blockSignals(True) - self.slider_x.setRange(0, max(0, shape[0] - 1)) - self.slider_y.setRange(0, max(0, shape[1] - 1)) - self.slider_z.setRange(0, max(0, shape[2] - 1)) - self.slider_x.setValue(min(shape[0] // 2, self.slider_x.maximum())) - self.slider_y.setValue(min(shape[1] // 2, self.slider_y.maximum())) - self.slider_z.setValue(min(shape[2] // 2, self.slider_z.maximum())) - self.slider_x.blockSignals(False) - self.slider_y.blockSignals(False) - self.slider_z.blockSignals(False) - self._update_labels() - self.workspace.ortho_cursor = np.array([self.slider_x.value(), self.slider_y.value(), self.slider_z.value()], dtype=int) - self.refresh() - - def _get_volume_shape(self): - ws = self.workspace - if ws.flow_raw is not None and ws.flow_raw.ndim == 5: - return ws.flow_raw.shape[:3] - if ws.mag_raw is not None and ws.mag_raw.ndim == 4: - return ws.mag_raw.shape[:3] - if ws.segmask_3d is not None: - return ws.segmask_3d.shape[:3] - return None - - def _update_labels(self): - self.label_x.setText(f"X:{self.slider_x.value()}") - self.label_y.setText(f"Y:{self.slider_y.value()}") - self.label_z.setText(f"Z:{self.slider_z.value()}") - - def _on_slider_changed(self, _): - self._update_labels() - self.workspace.ortho_cursor = np.array([self.slider_x.value(), self.slider_y.value(), self.slider_z.value()], dtype=int) - self.refresh() - - def _on_content_changed(self, _): - self.refresh() - - def set_selected_plane(self, idx): - self._selected_plane_idx = idx - self._move_to_plane_center(idx) - self.refresh() - - def _move_to_plane_center(self, idx): - ws = self.workspace - if idx is None or idx >= len(ws.planes): - return - plane = ws.planes[idx] - res = self._get_resolution() - center_vox = np.asarray(plane.center, dtype=float) / (res + 1e-12) - shape = self._get_volume_shape() - if shape is None: - return - self._set_cursor( - int(np.clip(np.round(center_vox[0]), 0, shape[0] - 1)), - int(np.clip(np.round(center_vox[1]), 0, shape[1] - 1)), - int(np.clip(np.round(center_vox[2]), 0, shape[2] - 1)), - ) - - def _get_resolution(self): - ws = self.workspace - if ws.resolution is not None and len(ws.resolution) >= 3: - r = np.asarray(ws.resolution, dtype=float).reshape(-1)[:3] - return np.where(r > 0, r, 1.0) - return np.array([1.0, 1.0, 1.0]) - - def _scene_style(self, data_key, default_cmap, default_clim=None): - for obj in self.workspace.scene_objects.values(): - if obj.data_key == data_key: - return obj.cmap or default_cmap, obj.clim if obj.clim else default_clim - return default_cmap, default_clim - - def _get_wss_volume(self, t): - ws = self.workspace - if ws.derived.wss_volume is not None: - cmap, clim = self._scene_style("wss_surface_live", "jet", None) - tidx = min(max(0, int(t)), ws.derived.wss_volume.shape[3] - 1) - return np.asarray(ws.derived.wss_volume[..., tidx], dtype=float), "WSS (Pa)", {"cmap": cmap, "clim": clim} - if not ws.derived.wss_surfaces: - return None, "WSS (no data)", {"cmap": "jet", "clim": None} - tidx = min(max(0, t), len(ws.derived.wss_surfaces) - 1) - surf = ws.derived.wss_surfaces[tidx] - if surf is None or "wss" not in surf.point_data: - return None, "WSS (no data)", {"cmap": "jet", "clim": None} - shape = self._get_volume_shape() - if shape is None: - return None, "WSS (no data)", {"cmap": "jet", "clim": None} - res = self._get_resolution() - key = ( - id(surf), - tuple(int(x) for x in shape), - tuple(np.round(res, 6).tolist()), - ) - def _build(): - vol = np.zeros(shape, dtype=float) - pts = np.asarray(surf.points, dtype=float) - vals = np.asarray(surf.point_data["wss"], dtype=float) - vox = np.rint(pts / (res.reshape(1, 3) + 1e-12)).astype(int) - for k in range(3): - vox[:, k] = np.clip(vox[:, k], 0, shape[k] - 1) - flat = np.ravel_multi_index((vox[:, 0], vox[:, 1], vox[:, 2]), shape) - tgt = vol.reshape(-1) - np.maximum.at(tgt, flat, vals) - return vol - vol = self._cached("wss_volume", key, _build) - cmap, clim = self._scene_style("wss_surface_live", "jet", None) - return vol, "WSS (Pa)", {"cmap": cmap, "clim": clim} - - def _get_tke_volume(self, t): - ws = self.workspace - if ws.derived.tke_array is not None: - arr = np.asarray(ws.derived.tke_array, dtype=float) - tidx = min(max(0, int(t)), arr.shape[3] - 1) if arr.ndim == 4 else 0 - mask_id = id(ws.segmask_binary) if ws.segmask_binary is not None else -1 - key = (id(ws.derived.tke_array), mask_id, int(tidx)) - def _build(): - if arr.ndim == 4: - vol = arr[..., tidx] - else: - vol = arr - if ws.segmask_binary is not None: - if ws.segmask_binary.ndim == 4: - mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] - else: - mask_t = ws.segmask_binary - vol = np.asarray(vol, dtype=float) * np.asarray(mask_t, dtype=float) - return np.asarray(vol, dtype=float) - vol = self._cached("tke_volume", key, _build) - cmap, clim = self._scene_style("tke_volume", "hot", None) - return vol, "TKE (J/m³)", {"cmap": cmap, "clim": clim} - if ws.derived.tke_volume is None: - return None, "TKE (no data)", {"cmap": "hot", "clim": None} - shape = self._get_volume_shape() - if shape is None: - return None, "TKE (no data)", {"cmap": "hot", "clim": None} - tke_mesh = ws.derived.tke_volume - if "TKE" not in tke_mesh.point_data and "TKE" not in tke_mesh.cell_data: - return None, "TKE (no data)", {"cmap": "hot", "clim": None} - res = self._get_resolution() - key = ( - id(tke_mesh), - tuple(int(x) for x in shape), - tuple(np.round(res, 6).tolist()), - ) - def _build(): - vol = np.zeros(shape, dtype=float) - if "TKE" in tke_mesh.cell_data: - pts = tke_mesh.cell_centers().points - vals = np.asarray(tke_mesh.cell_data["TKE"], dtype=float) - else: - pts = tke_mesh.points - vals = np.asarray(tke_mesh.point_data["TKE"], dtype=float) - vox = np.rint(pts / (res.reshape(1, 3) + 1e-12)).astype(int) - for k in range(3): - vox[:, k] = np.clip(vox[:, k], 0, shape[k] - 1) - flat = np.ravel_multi_index((vox[:, 0], vox[:, 1], vox[:, 2]), shape) - tgt = vol.reshape(-1) - np.maximum.at(tgt, flat, vals) - return vol - vol = self._cached("tke_mesh_volume", key, _build) - cmap, clim = self._scene_style("tke_volume", "hot", None) - return vol, "TKE (J/m³)", {"cmap": cmap, "clim": clim} - def _get_scalar_slice(self, t): - ws = self.workspace - content_idx = self.combo_content.currentIndex() - if content_idx == 0 and ws.flow_raw is not None: - vol = np.asarray(ws.flow_raw[..., t, 0], dtype=float) - vmax = max(abs(np.nanmin(vol)), abs(np.nanmax(vol)), 1e-6) - return vol, "Flow X (cm/s)", {"cmap": "RdBu_r", "clim": (-vmax, vmax)} - if content_idx == 1 and ws.flow_raw is not None: - vol = np.asarray(ws.flow_raw[..., t, 1], dtype=float) - vmax = max(abs(np.nanmin(vol)), abs(np.nanmax(vol)), 1e-6) - return vol, "Flow Y (cm/s)", {"cmap": "RdBu_r", "clim": (-vmax, vmax)} - if content_idx == 2 and ws.flow_raw is not None: - vol = np.asarray(ws.flow_raw[..., t, 2], dtype=float) - vmax = max(abs(np.nanmin(vol)), abs(np.nanmax(vol)), 1e-6) - return vol, "Flow Z (cm/s)", {"cmap": "RdBu_r", "clim": (-vmax, vmax)} - if content_idx == 3 and ws.mag_raw is not None: - vol = np.asarray(ws.mag_raw[..., t], dtype=float) - return vol, "Magnitude", {"cmap": "gray", "clim": (float(np.nanmin(vol)), float(np.nanmax(vol)))} - if content_idx == 4 and ws.mag_raw is not None and ws.flow_raw is not None: - key = (id(ws.mag_raw), id(ws.flow_raw), int(t)) - def _build(): - speed = np.sqrt(np.sum(ws.flow_raw[..., t, :] ** 2, axis=-1)) - return np.asarray(ws.mag_raw[..., t], dtype=float) * np.asarray(speed, dtype=float) - vol = self._cached("scalar_volume", ("pcmra",) + key, _build) - return vol, "PC-MRA", {"cmap": "gray", "clim": (float(np.nanmin(vol)), float(np.nanmax(vol)))} - if content_idx == 5 and ws.flow_raw is not None: - key = (id(ws.flow_raw), int(t)) - vol = self._cached( - "scalar_volume", - ("speed",) + key, - lambda: np.sqrt(np.sum(ws.flow_raw[..., t, :] ** 2, axis=-1)), - ) - return np.asarray(vol, dtype=float), "Speed (cm/s)", {"cmap": "turbo", "clim": (0.0, float(np.nanmax(vol)) if np.nanmax(vol) > 0 else 1.0)} - if content_idx == 6: - return self._get_wss_volume(t) - if content_idx == 7: - return self._get_tke_volume(t) - return None, "", {"cmap": "gray", "clim": None} - - def _get_mask_3d(self): - ws = self.workspace - if ws.segmask_3d is not None: - return ws.segmask_3d - if ws.segmask_binary is not None and ws.segmask_binary.ndim == 4: - key = (id(ws.segmask_binary), tuple(int(x) for x in ws.segmask_binary.shape)) - return self._cached("mask_3d", key, lambda: np.any(ws.segmask_binary, axis=3)) - return None - - def _update_value_label(self, vol, title): - if vol is None: - self.label_value.setText("Voxel: - Value: -") - return - x, y, z = self.slider_x.value(), self.slider_y.value(), self.slider_z.value() - try: - val = float(vol[x, y, z]) - self.label_value.setText(f"Voxel: ({x}, {y}, {z}) {title}: {val:.6g}") - except Exception: - self.label_value.setText(f"Voxel: ({x}, {y}, {z}) {title}: -") - - def _update_plane_metric_label(self): - ws = self.workspace - if self._selected_plane_idx is None or self._selected_plane_idx >= len(ws.planes): - self.label_plane_metric.setText("Plane metrics: -") - return - plane = ws.planes[self._selected_plane_idx] - metrics = plane.metrics or {} - t = int(ws.current_t) - fr = metrics.get("flowrate_mL_s", []) - ar = metrics.get("area_mm2", []) - mv = metrics.get("meanv_cm_s_t", []) - cur_fr = float(fr[t]) if t < len(fr) else 0.0 - cur_ar = float(ar[t]) if t < len(ar) else 0.0 - cur_mv = float(mv[t]) if t < len(mv) else metrics.get("meanv_cm_s", 0.0) - path_direction = metrics.get("path_direction", "") - path_ic = metrics.get("path_ic", None) - txt = f"Plane {self._selected_plane_idx}" - if path_direction: - txt += f" [{path_direction}]" - txt += f" t={t} Flow Rate={cur_fr:.4g} mL/s Area={cur_ar:.4g} mm² Mean Velocity={cur_mv:.4g} cm/s Peak Velocity={metrics.get('peakv_cm_s', 0.0):.4g} cm/s" - if path_ic is not None: - txt += f" Path IC={float(path_ic):.3f}" - self.label_plane_metric.setText(txt) - - def refresh(self): - ws = self.workspace - t = int(ws.current_t) - shape = self._get_volume_shape() - if shape is None: - self._remove_colorbar() - self.canvas.draw_idle() - return - - cx, cy, cz = self.slider_x.value(), self.slider_y.value(), self.slider_z.value() - vol, title, style = self._get_scalar_slice(t) - mask_3d = self._get_mask_3d() - res = self._get_resolution() - - for ax in [self.ax_ax, self.ax_cor, self.ax_sag]: - ax.clear() - ax.set_facecolor("black") - ax.set_xticks([]) - ax.set_yticks([]) - - self._remove_colorbar() - im = None - if vol is not None: - cmap = style.get("cmap", "gray") - clim = style.get("clim", None) - if clim is None: - clim = (float(np.nanmin(vol)), float(np.nanmax(vol))) - axial = vol[:, :, cz] - im = self.ax_ax.imshow(axial.T, origin="lower", cmap=cmap, vmin=clim[0], vmax=clim[1], aspect=float(res[1] / res[0])) - self.ax_ax.axhline(cy, color="lime", linewidth=0.5, alpha=0.5) - self.ax_ax.axvline(cx, color="lime", linewidth=0.5, alpha=0.5) - self.ax_ax.plot(cx, cy, "r+", markersize=8, markeredgewidth=1.5) - self.ax_ax.set_title(f"Axial Z={cz}", color="white", fontsize=8) - if mask_3d is not None and cz < mask_3d.shape[2]: - try: - self.ax_ax.contour(mask_3d[:, :, cz].astype(float).T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") - except Exception: - pass - - coronal = vol[:, cy, :] - self.ax_cor.imshow(coronal.T, origin="lower", cmap=cmap, vmin=clim[0], vmax=clim[1], aspect=float(res[2] / res[0])) - self.ax_cor.axhline(cz, color="lime", linewidth=0.5, alpha=0.5) - self.ax_cor.axvline(cx, color="lime", linewidth=0.5, alpha=0.5) - self.ax_cor.plot(cx, cz, "r+", markersize=8, markeredgewidth=1.5) - self.ax_cor.set_title(f"Coronal Y={cy}", color="white", fontsize=8) - if mask_3d is not None and cy < mask_3d.shape[1]: - try: - self.ax_cor.contour(mask_3d[:, cy, :].astype(float).T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") - except Exception: - pass - - sagittal = vol[cx, :, :] - self.ax_sag.imshow(sagittal.T, origin="lower", cmap=cmap, vmin=clim[0], vmax=clim[1], aspect=float(res[2] / res[1])) - self.ax_sag.axhline(cz, color="lime", linewidth=0.5, alpha=0.5) - self.ax_sag.axvline(cy, color="lime", linewidth=0.5, alpha=0.5) - self.ax_sag.plot(cy, cz, "r+", markersize=8, markeredgewidth=1.5) - self.ax_sag.set_title(f"Sagittal X={cx}", color="white", fontsize=8) - if mask_3d is not None and cx < mask_3d.shape[0]: - try: - self.ax_sag.contour(mask_3d[cx, :, :].astype(float).T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") - except Exception: - pass - - if self.combo_content.currentIndex() in (6, 7): - self._scalar_cbar = self.fig.colorbar(im, ax=[self.ax_ax, self.ax_cor, self.ax_sag], fraction=0.025, pad=0.01) - self._scalar_cbar.ax.tick_params(labelsize=6, colors="white") - self._scalar_cbar.set_label(title, color="white", fontsize=7) - try: - self._scalar_cbar.outline.set_edgecolor("white") - except Exception: - pass - - self._draw_plane_flow(t) - self._update_value_label(vol, title) - self._update_plane_metric_label() - self.fig.subplots_adjust(left=0.03, right=0.96, top=0.96, bottom=0.03, wspace=0.14, hspace=0.24) - self.canvas.draw_idle() - - def _resample_oblique(self, volume_3d, center_vox, normal, half_size=30): - normal = np.asarray(normal, dtype=float) - normal = normal / (np.linalg.norm(normal) + 1e-12) - up_hint = np.array([0.0, 1.0, 0.0]) if abs(normal[2]) > max(abs(normal[0]), abs(normal[1])) else np.array([0.0, 0.0, 1.0]) - u = np.cross(normal, up_hint) - u = u / (np.linalg.norm(u) + 1e-12) - v = np.cross(normal, u) - v = v / (np.linalg.norm(v) + 1e-12) - ii = np.arange(-half_size, half_size + 1, dtype=float) - jj = np.arange(-half_size, half_size + 1, dtype=float) - gi, gj = np.meshgrid(ii, jj, indexing="ij") - coords = center_vox.reshape(1, 1, 3) + gi[..., None] * u.reshape(1, 1, 3) + gj[..., None] * v.reshape(1, 1, 3) - sampled = map_coordinates(volume_3d, [coords[..., 0].ravel(), coords[..., 1].ravel(), coords[..., 2].ravel()], order=1, mode="constant", cval=0.0) - return sampled.reshape(len(ii), len(jj)) - - def _draw_plane_flow(self, t): - self.ax_plane.clear() - self.ax_plane.set_facecolor("black") - self.ax_plane.set_xticks([]) - self.ax_plane.set_yticks([]) - ws = self.workspace - if self._selected_plane_idx is None or self._selected_plane_idx >= len(ws.planes): - self.ax_plane.set_title("Plane Through-Plane Velocity (select a plane)", color="white", fontsize=8) - return - if ws.flow_raw is None: - self.ax_plane.set_title("Plane Through-Plane Velocity (no flow data)", color="white", fontsize=8) - return - plane = ws.planes[self._selected_plane_idx] - res = self._get_resolution() - center_vox = np.asarray(plane.center, dtype=float) / (res + 1e-12) - normal = np.asarray(plane.normal, dtype=float) - normal = normal / (np.linalg.norm(normal) + 1e-12) - flow_t = ws.flow_raw[..., t, :] - shape = self._get_volume_shape() - half_size = max(10, min(shape) // 2) - plane_key = ( - int(self._selected_plane_idx), - int(t), - tuple(np.round(center_vox, 4).tolist()), - tuple(np.round(normal, 6).tolist()), - int(half_size), - id(ws.flow_raw), - ) - def _build_plane_flow(): - proj = flow_t[..., 0] * normal[0] + flow_t[..., 1] * normal[1] + flow_t[..., 2] * normal[2] - return self._resample_oblique(proj, center_vox, normal, half_size=half_size) - sl = self._cached("plane_flow", plane_key, _build_plane_flow) - vmax = max(abs(np.nanmin(sl)), abs(np.nanmax(sl)), 1e-6) - self.ax_plane.imshow(sl.T, origin="lower", cmap="RdBu_r", vmin=-vmax, vmax=vmax, aspect=1.0) - self.ax_plane.plot(half_size, half_size, "r+", markersize=10, markeredgewidth=2) - mask_3d = self._get_mask_3d() - if mask_3d is not None: - mask_key = ( - int(self._selected_plane_idx), - tuple(np.round(center_vox, 4).tolist()), - tuple(np.round(normal, 6).tolist()), - int(half_size), - id(mask_3d), - ) - m_sl = self._cached( - "plane_mask", - mask_key, - lambda: self._resample_oblique(mask_3d.astype(float), center_vox, normal, half_size=half_size), - ) - try: - self.ax_plane.contour(m_sl.T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") - except Exception: - pass - metrics = plane.metrics or {} - txt = f"Plane {self._selected_plane_idx} Through-Plane Velocity [{-vmax:.2f}, {vmax:.2f}] cm/s" - if metrics: - fr = metrics.get("flowrate_mL_s", []) - ar = metrics.get("area_mm2", []) - flow_txt = float(fr[t]) if t < len(fr) else 0.0 - area_txt = float(ar[t]) if t < len(ar) else 0.0 - txt += f"\nFlow Rate={flow_txt:.4g} mL/s Area={area_txt:.4g} mm²" - self.ax_plane.set_title(txt, color="white", fontsize=7) - - def reset_state(self): - self._selected_plane_idx = None - self._remove_colorbar() - self._cache.clear() - self.label_value.setText("Voxel: - Value: -") - self.label_plane_metric.setText("Plane metrics: -") - for ax in [self.ax_ax, self.ax_cor, self.ax_sag, self.ax_plane]: - ax.clear() - ax.set_facecolor("black") - ax.set_xticks([]) - ax.set_yticks([]) - self.canvas.draw_idle() + return OrthoViewer + raise AttributeError(f"module 'autoflow.ortho_viewer' has no attribute {name!r}") diff --git a/autoflow/pipeline.py b/autoflow/pipeline.py index 304c24b..d7a95d4 100755 --- a/autoflow/pipeline.py +++ b/autoflow/pipeline.py @@ -1,429 +1,5 @@ -import json -import os -import numpy as np +"""Compatibility re-exports for the pipeline engine.""" -from .models import StepId, ObjectKind -from .algorithms import ( - load_h5_data, - filter_segmask_labels, binarize_segmask, merge_segmask_to_3d, - preprocess_mask_for_skeleton, - generate_skeleton_from_mask3d, build_graph_from_points, - segment_vessels_from_graph_and_mask, - generate_planes_from_paths, - compute_plane_metrics, compute_derived_metrics, - compute_plane_metrics_multithread, - generate_seed_points, -) +from .core.pipeline import PipelineEngine, StepResult - -class StepResult: - def __init__(self, step, success=True, skipped=False, message="", outputs=None): - self.step = step - self.success = success - self.skipped = skipped - self.message = message - self.outputs = outputs or [] - - -class PipelineEngine: - def _output_dir(self, ws): - out_dir = getattr(ws.paths, "output_dir", "") or "" - if out_dir: - os.makedirs(out_dir, exist_ok=True) - return out_dir - base = ws.paths.segmask_path or ws.paths.flow_path or "." - out_dir = os.path.dirname(base) or "." - os.makedirs(out_dir, exist_ok=True) - return out_dir - - def _json_safe(self, obj): - if isinstance(obj, np.floating): - val = float(obj) - return val if np.isfinite(val) else None - if isinstance(obj, np.integer): - return int(obj) - if isinstance(obj, float): - return obj if np.isfinite(obj) else None - if isinstance(obj, dict): - return {k: self._json_safe(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [self._json_safe(v) for v in obj] - return obj - - def load_data(self, ws, log): - path = ws.paths.segmask_path or ws.paths.flow_path - if not path: - raise ValueError("data path is empty") - data = load_h5_data(path) - flow = np.asarray(data["flow"], dtype=np.float32) - mag = np.asarray(data["mag"], dtype=np.float32) - seg = np.asarray(data["segmask"], dtype=np.int16) - if flow.ndim == 4 and flow.shape[-1] == 3: - flow = flow[..., np.newaxis, :] - if mag.ndim == 3: - mag = mag[..., np.newaxis] - if seg.ndim == 3: - seg = np.repeat(seg[..., np.newaxis], flow.shape[3], axis=3) - elif seg.ndim == 4 and seg.shape[3] == 1 and flow.shape[3] > 1: - seg = np.repeat(seg, flow.shape[3], axis=3) - if seg.shape[3] != flow.shape[3]: - raise ValueError(f"segmask time dimension {seg.shape[3]} != flow {flow.shape[3]}") - - ws.segmask_raw = seg - ws.resolution = np.asarray(data["resolution"], dtype=float).reshape(3) - ws.origin = np.asarray(data.get("origin", [0.0, 0.0, 0.0]), dtype=float).reshape(3) - ws.venc = np.asarray(data["venc"], dtype=float).reshape(-1) - ws.rr = float(data.get("rr", 1000.0)) - ws.current_t = 0 - ws.flow_raw = flow - ws.mag_raw = mag - ws.derived.tke_array = np.asarray(data["tke_array"], dtype=np.float32) if "tke_array" in data else None - ws.data_loaded = True - - ws.remove_object_by_data_key("segmask_raw_surface") - ws.add_object(name="segmask_raw", kind=ObjectKind.SEGMENTATION, - data_key="segmask_raw_surface", visible=True, opacity=0.3, - scalars="label", cmap="tab10", dynamic=True, - show_scalar_bar=True, scalar_bar_title="Label") - - ulabels = ws.unique_labels() - msg = f"Loaded: segmask={ws.segmask_raw.shape} labels={ulabels} rr={ws.rr}" - msg += f" flow={ws.flow_raw.shape} mag={ws.mag_raw.shape}" - msg += f" origin={ws.origin.tolist()}" - log(msg) - return msg - - def preprocess(self, ws): - if ws.segmask_raw is None: - raise ValueError("segmask_raw is None") - ws.segmask_labels = filter_segmask_labels(ws.segmask_raw) - ws.segmask_binary = binarize_segmask(ws.segmask_labels) - ws.segmask_3d = merge_segmask_to_3d(ws.segmask_binary) - ws.set_object_visible_by_data_key("segmask_raw_surface", False) - ws.remove_object_by_data_key("segmask_pre_surface") - ws.add_object(name="segmask_preprocessed", kind=ObjectKind.SEGMENTATION, - data_key="segmask_pre_surface", visible=True, opacity=0.25, - scalars="label", cmap="tab10", dynamic=True, - show_scalar_bar=True, scalar_bar_title="Label") - - def run_step(self, ws, step, log): - dispatch = { - StepId.GENERATE_SKELETON: self._step_generate_skeleton, - StepId.EDIT_SKELETON: self._step_edit_skeleton, - StepId.GENERATE_GRAPH: self._step_generate_graph, - StepId.EDIT_GRAPH: self._step_edit_graph, - StepId.GENERATE_PLANES: self._step_generate_planes, - StepId.EDIT_PLANES: self._step_edit_planes, - StepId.GENERATE_STREAMLINES: self._step_generate_streamlines, - StepId.PLANE_STREAMLINES: self._step_plane_streamlines, - StepId.COMPUTE_PLANE_METRICS: self._step_compute_plane_metrics, - StepId.COMPUTE_DERIVED_METRICS: self._step_compute_derived_metrics, - } - return dispatch[step](ws) - - def _step_generate_skeleton(self, ws): - self.preprocess(ws) - if ws.skeleton_params.remove_small_cc: - from .algorithms import remove_small_cc_from_binary_mask - ws.segmask_binary = remove_small_cc_from_binary_mask( - ws.segmask_binary, ws.resolution, ws.skeleton_params.min_cc_volume_mm3) - ws.segmask_3d = merge_segmask_to_3d(ws.segmask_binary) - processed = preprocess_mask_for_skeleton(ws.segmask_3d, ws.skeleton_params, resolution=ws.resolution) - pts, mask = generate_skeleton_from_mask3d(processed, ws.resolution) - ws.skeleton_points = pts - ws.skeleton_mask = mask - ws.remove_object_by_data_key("skeleton_points") - ws.remove_object_by_data_key("skeleton_mask_surface") - ws.remove_object_by_data_key("segmask_3d_surface") - ws.add_object(name="skeleton_points", kind=ObjectKind.SKELETON, - data_key="skeleton_points", visible=True, opacity=1.0, - color="red", point_size=8) - # ws.add_object(name="skeleton_mask", kind=ObjectKind.SKELETON, - # data_key="skeleton_mask_surface", visible=False, opacity=0.15, color="yellow") - ws.add_object(name="segmask_mesh", kind=ObjectKind.SEGMENTATION, - data_key="segmask_3d_surface", visible=True, opacity=0.15, - color="gray") - ws.pipeline.mark_done(StepId.GENERATE_SKELETON) - return StepResult(StepId.GENERATE_SKELETON, True, False, f"Skeleton: {len(pts)} points") - - def _step_edit_skeleton(self, ws): - ws.pipeline.mark_done(StepId.EDIT_SKELETON, skipped=True) - return StepResult(StepId.EDIT_SKELETON, True, True, "Skeleton edit") - - def _step_generate_graph(self, ws): - if ws.skeleton_points is None or len(ws.skeleton_points) == 0: - self._step_generate_skeleton(ws) - graph = build_graph_from_points(ws.skeleton_points, ws.resolution) - ws.graph = graph - - flow_for_orientation = None - if ws.flow_raw is not None and ws.segmask_binary is not None: - flow_for_orientation = ws.flow_raw * ws.segmask_binary[..., None] - labels, paths, node_paths, path_info, forks = segment_vessels_from_graph_and_mask( - ws.segmask_3d, ws.graph, ws.resolution, - flow_xyzt3=flow_for_orientation, - segmask_binary_4d=ws.segmask_binary, - origin=ws.origin, - ) - ws.branch_labels = labels - ws.centerline_paths = [np.asarray(p, dtype=float) for p in paths] - ws.centerline_node_paths = [list(map(int, p)) for p in node_paths] - ws.path_info = path_info - ws.forks = forks - ws.selected_path_index = -1 - - ws.remove_object_by_data_key("graph_lines") - ws.add_object(name="graph_lines", kind=ObjectKind.GRAPH, - data_key="graph_lines", visible=True, opacity=1.0, - color="blue", line_width=2) - - ws.remove_objects_by_prefix("path_") - ws.remove_objects_by_prefix("smooth_path_") - ws.remove_objects_by_prefix("path_arrow_") - ws.remove_object_by_data_key("fork_markers") - - if len(ws.forks) > 0: - ws.add_object(name="Forks", kind=ObjectKind.AUX, - data_key="fork_markers", visible=True, opacity=1.0, - color="magenta", point_size=12) - - ws.pipeline.mark_done(StepId.GENERATE_GRAPH) - return StepResult(StepId.GENERATE_GRAPH, True, False, - f"Graph: {len(graph.points)} nodes, {len(graph.edges)} edges | " - f"paths={len(ws.centerline_paths)} forks={len(ws.forks)}") - - def _step_edit_graph(self, ws): - ws.pipeline.mark_done(StepId.EDIT_GRAPH, skipped=True) - return StepResult(StepId.EDIT_GRAPH, True, True, "Graph edit") - - def _compute_plane_metrics_internal(self, ws, save=True, use_multithread=False): - if not ws.has_flow(): - return [], {}, "Plane metrics skipped: no flow" - if ws.segmask_binary is None: - self.preprocess(ws) - # Prefer the smoothed centerlines (better local tangents) but fall back - # to the raw ordered ones if the smoothing step hasn't been run yet. - paths_for_tangent = ws.centerline_paths_smooth if len(ws.centerline_paths_smooth) > 0 else ws.centerline_paths - if use_multithread: - metrics, qc = compute_plane_metrics_multithread( - ws.flow_raw, ws.segmask_binary, ws.resolution, ws.origin, ws.planes, - RR=ws.rr, branch_labels_3d=ws.branch_labels, - path_info=ws.path_info, forks=ws.forks, paths=paths_for_tangent, - return_qc=True) - else: - metrics, qc = compute_plane_metrics( - ws.flow_raw, ws.segmask_binary, ws.resolution, ws.origin, ws.planes, - RR=ws.rr, branch_labels_3d=ws.branch_labels, - path_info=ws.path_info, forks=ws.forks, paths=paths_for_tangent, - return_qc=True) - ws.derived.plane_metrics = metrics - ws.derived.plane_qc = qc - for i, metric in enumerate(metrics): - if i < len(ws.planes): - ws.planes[i].metrics = dict(metric) - msg = f"Plane metrics: {len(metrics)} paths={len(qc.get('path_ic', {}))} forks={len(qc.get('forks', []))}" - if save: - out_dir = self._output_dir(ws) - plane_metric_path = os.path.join(out_dir, "plane_metrics.json") - qc_path = os.path.join(out_dir, "plane_qc.json") - with open(plane_metric_path, "w", encoding="utf-8") as f: - json.dump(metrics, f, ensure_ascii=False, indent=2) - with open(qc_path, "w", encoding="utf-8") as f: - json.dump(qc, f, ensure_ascii=False, indent=2) - msg += f" saved={plane_metric_path} qc={qc_path}" - return metrics, qc, msg - - def _save_planes_json(self, ws): - out_dir = self._output_dir(ws) - out_path = os.path.join(out_dir, "planes.json") - payload = [] - origin = np.asarray(ws.origin, dtype=float).reshape(3) - for i, p in enumerate(ws.planes): - center_local = np.asarray(p.center, dtype=float).reshape(3) - item = { - "plane_index": int(i), - "center": center_local.tolist(), - "center_world": (center_local + origin).tolist(), - "normal": np.asarray(p.normal).tolist(), - "label": int(p.label), - "path_index": int(p.path_index), - "distance": float(p.distance), - } - if p.metrics: - item.update(json.loads(json.dumps(p.metrics, ensure_ascii=False))) - if 0 <= int(p.path_index) < len(ws.path_info): - item["path_info"] = ws.path_info[int(p.path_index)] - payload.append(item) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - return out_path - - def _step_generate_planes(self, ws): - if ws.graph is None or len(ws.graph.points) == 0: - self._step_generate_graph(ws) - - if len(ws.centerline_paths) == 0: - flow_for_orientation = None - if ws.flow_raw is not None and ws.segmask_binary is not None: - flow_for_orientation = ws.flow_raw * ws.segmask_binary[..., None] - labels, paths, node_paths, path_info, forks = segment_vessels_from_graph_and_mask( - ws.segmask_3d, ws.graph, ws.resolution, - flow_xyzt3=flow_for_orientation, - segmask_binary_4d=ws.segmask_binary, - origin=ws.origin, - ) - ws.branch_labels = labels - ws.centerline_paths = [np.asarray(p, dtype=float) for p in paths] - ws.centerline_node_paths = [list(map(int, p)) for p in node_paths] - ws.path_info = path_info - ws.forks = forks - ws.selected_path_index = -1 - - ws.remove_objects_by_prefix("smooth_path_") - ws.remove_objects_by_prefix("path_arrow_") - ws.remove_object_by_data_key("fork_markers") - - pgp = ws.plane_gen_params - planes, smooth_paths = generate_planes_from_paths( - ws.centerline_paths, - cross_section_distance=pgp.cross_section_distance, - start_distance=pgp.start_distance, - end_distance=pgp.end_distance, - smoothing_window=pgp.smoothing_window * pgp.inter_time, - smoothing_polyorder=pgp.smoothing_polyorder, - inter_time=pgp.inter_time, - use_center_plane=pgp.use_center_plane, - ) - ws.planes = planes - ws.centerline_paths_smooth = smooth_paths - for i in range(len(ws.centerline_paths_smooth)): - direction_text = "" - if i < len(ws.path_info): - direction_text = ws.path_info[i].get("direction_text", "") - name = f"Path {i}" if not direction_text else f"Path {i} [{direction_text}]" - ws.add_object(name=name, kind=ObjectKind.BRANCH, - data_key=f"smooth_path_{i}", visible=True, opacity=1.0, - color="red", line_width=3) - # ws.add_object(name=f"Path {i} Arrow", kind=ObjectKind.AUX, - # data_key=f"path_arrow_{i}", visible=True, opacity=1.0, - # color="lime", line_width=2) - if len(ws.forks) > 0: - ws.remove_object_by_data_key("fork_markers") - ws.add_object(name="Forks", kind=ObjectKind.AUX, - data_key="fork_markers", visible=True, opacity=1.0, - color="magenta", point_size=12) - - ws.remove_objects_by_prefix("plane_") - for i in range(len(ws.planes)): - ws.add_object(name=f"Plane {i}", kind=ObjectKind.PLANE, - data_key=f"plane_{i}", visible=True, opacity=0.6, - color="yellow", line_width=2) - - planes_path = self._save_planes_json(ws) - ws.pipeline.mark_done(StepId.GENERATE_PLANES) - msg = f"Planes: {len(ws.planes)} paths={len(ws.centerline_paths_smooth)} forks={len(ws.forks)} saved={planes_path}" - return StepResult(StepId.GENERATE_PLANES, True, False, msg) - - def _step_edit_planes(self, ws): - ws.pipeline.mark_done(StepId.EDIT_PLANES, skipped=True) - return StepResult(StepId.EDIT_PLANES, True, True, "Plane edit") - - def _step_generate_streamlines(self, ws): - if ws.flow_raw is None or ws.segmask_3d is None: - return StepResult(StepId.GENERATE_STREAMLINES, True, True, "Streamlines skipped: no flow or mask") - self.preprocess(ws) - ws.streamline_seeds = generate_seed_points( - ws.segmask_3d, - ws.resolution, - ws.origin, - ratio=ws.streamline_params.seed_ratio, - rng_seed=ws.streamline_params.rng_seed, - min_seeds=ws.streamline_params.min_seeds, - ) - ws.streamline_cache.clear() - ws.streamline_active = True - ws.remove_object_by_data_key("streamlines_live") - ws.add_object( - name="streamlines", kind=ObjectKind.FLOW, - data_key="streamlines_live", visible=True, opacity=1.0, - scalars="Velocity", cmap="turbo", dynamic=True, - show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") - ws.pipeline.mark_done(StepId.GENERATE_STREAMLINES) - p = ws.streamline_params - param_msg = (f"Streamlines enabled: seed_ratio={p.seed_ratio} max_steps={p.max_steps} " - f"min_seeds={p.min_seeds} terminal_speed={p.terminal_speed} rng_seed={p.rng_seed}") - return StepResult(StepId.GENERATE_STREAMLINES, True, False, param_msg) - - def _step_plane_streamlines(self, ws): - if ws.flow_raw is None or ws.segmask_3d is None: - return StepResult(StepId.PLANE_STREAMLINES, True, True, "Plane streamlines skipped: no flow or mask") - if len(ws.planes) == 0: - return StepResult(StepId.PLANE_STREAMLINES, True, True, "Plane streamlines skipped: no planes") - self.preprocess(ws) - ws.plane_streamline_cache.clear() - ws.plane_streamline_active = True - ws.remove_object_by_data_key("plane_streamlines_live") - ws.add_object( - name="plane_streamlines", kind=ObjectKind.FLOW, - data_key="plane_streamlines_live", visible=True, opacity=1.0, - scalars="Velocity", cmap="turbo", dynamic=True, - show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") - ws.pipeline.mark_done(StepId.PLANE_STREAMLINES) - pidx = ws.plane_streamline_plane_idx - return StepResult(StepId.PLANE_STREAMLINES, True, False, - f"Plane streamlines enabled from plane {pidx}") - - def _step_compute_plane_metrics(self, ws): - if not ws.has_flow(): - return StepResult(StepId.COMPUTE_PLANE_METRICS, True, True, "Plane metrics skipped: no flow") - if len(ws.planes) == 0: - self._step_generate_planes(ws) - use_mt = getattr(ws.derived_params, "use_multithread", False) - _, _, msg = self._compute_plane_metrics_internal(ws, save=True, use_multithread=use_mt) - self._save_planes_json(ws) - ws.pipeline.mark_done(StepId.COMPUTE_PLANE_METRICS) - return StepResult(StepId.COMPUTE_PLANE_METRICS, True, False, msg) - - def _step_compute_derived_metrics(self, ws): - if not ws.has_flow(): - return StepResult(StepId.COMPUTE_DERIVED_METRICS, True, True, "Derived metrics skipped: no flow") - self.preprocess(ws) - dp = ws.derived_params - loaded_tke = ws.derived.tke_array - result = compute_derived_metrics( - flow=ws.flow_raw * ws.segmask_binary[..., None], - mask4d=ws.segmask_binary, - spacing=ws.resolution, - origin=ws.origin, - smoothing_iteration=dp.smoothing_iteration, - viscosity=dp.viscosity, - inward_distance=dp.inward_distance, - parabolic_fitting=dp.parabolic_fitting, - no_slip_condition=dp.no_slip_condition, - step_size=dp.step_size, - tube_radius=dp.tube_radius, - rho=dp.rho, - save_pixelwise=False, - tke_array=loaded_tke, - ) - ws.derived.wss_surfaces = result["wss_surfaces"] - ws.derived.wss_volume = result.get("wss_volume") - ws.derived.tke_volume = result["tke_volume"] - ws.derived.tke_array = result.get("tke_array") - ws.derived.streamlines = [] - ws.derived.pixelwise_export = result.get("pixelwise_export", {}) - for dk in ["wss_surface_live", "tke_volume"]: - ws.remove_object_by_data_key(dk) - wss_max = float(np.nanmax(ws.derived.wss_volume)) if ws.derived.wss_volume is not None and np.size(ws.derived.wss_volume) else 0.0 - tke_max = float(np.nanmax(ws.derived.tke_array)) if ws.derived.tke_array is not None and np.size(ws.derived.tke_array) else 0.0 - ws.add_object(name="wss_surface", kind=ObjectKind.METRIC, - data_key="wss_surface_live", visible=False, opacity=1.0, - scalars="wss", cmap="jet", clim=(0.0, wss_max if wss_max > 0 else 1.0), dynamic=True, - show_scalar_bar=True, scalar_bar_title="WSS (Pa)") - ws.add_object(name="tke_volume", kind=ObjectKind.METRIC, - data_key="tke_volume", visible=False, opacity=0.5, - scalars="TKE", cmap="hot", clim=(0.0, tke_max if tke_max > 0 else 1.0), dynamic=True, - show_scalar_bar=True, scalar_bar_title="TKE (J/m³)") - msg = f"Derived: Nt={len(ws.derived.wss_surfaces)}" - ws.pipeline.mark_done(StepId.COMPUTE_DERIVED_METRICS) - return StepResult(StepId.COMPUTE_DERIVED_METRICS, True, False, msg) +__all__ = ["PipelineEngine", "StepResult"] diff --git a/autoflow/plane_io.py b/autoflow/plane_io.py new file mode 100755 index 0000000..82419ad --- /dev/null +++ b/autoflow/plane_io.py @@ -0,0 +1,149 @@ +import json +import os + +import numpy as np + +from .core.models import PlaneData + + +def _normalize(v): + arr = np.asarray(v, dtype=float).reshape(3) + n = np.linalg.norm(arr) + if n <= 1e-12: + return np.array([1.0, 0.0, 0.0], dtype=float) + return arr / n + + +def _path_cumdist(path): + pts = np.asarray(path, dtype=float).reshape(-1, 3) + if len(pts) <= 1: + return np.zeros(len(pts), dtype=float) + return np.concatenate([[0.0], np.cumsum(np.linalg.norm(np.diff(pts, axis=0), axis=1))]) + + +def _make_plane_payload(ws, source_path=""): + origin = np.asarray(ws.origin, dtype=float).reshape(3) + payload = { + "source": source_path, + "origin": origin.tolist(), + "resolution": np.asarray(ws.resolution, dtype=float).reshape(3).tolist(), + "planes": [], + } + for i, plane in enumerate(ws.planes): + center_local = np.asarray(plane.center, dtype=float).reshape(3) + payload["planes"].append( + { + "plane_index": int(i), + "center": center_local.tolist(), + "center_world": (center_local + origin).tolist(), + "normal": _normalize(plane.normal).tolist(), + "label": int(plane.label), + "path_index": int(plane.path_index), + "distance": float(plane.distance), + } + ) + return payload + + +def save_plane_positions(ws, out_path, source_path=""): + payload = _make_plane_payload(ws, source_path=source_path) + with open(out_path, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + return out_path + + +def load_plane_positions(path): + with open(path, "r", encoding="utf-8") as f: + payload = json.load(f) + if isinstance(payload, dict) and "planes" in payload: + return payload["planes"] + if isinstance(payload, list): + return payload + raise ValueError(f"Invalid plane position file: {path}") + + +def _nearest_path_info(center_world, paths_world): + best_dist = np.inf + best_path_idx = -1 + best_point_idx = -1 + best_distance = 0.0 + for path_idx, path in enumerate(paths_world): + pts = np.asarray(path, dtype=float).reshape(-1, 3) + if len(pts) == 0: + continue + d = np.linalg.norm(pts - center_world.reshape(1, 3), axis=1) + point_idx = int(np.argmin(d)) + dist = float(d[point_idx]) + if dist < best_dist: + cum = _path_cumdist(pts) + best_dist = dist + best_path_idx = int(path_idx) + best_point_idx = point_idx + best_distance = float(cum[point_idx]) if len(cum) > point_idx else 0.0 + return best_path_idx, best_point_idx, best_dist, best_distance + + +def _path_tangent(path_world, point_idx): + pts = np.asarray(path_world, dtype=float).reshape(-1, 3) + if len(pts) == 0: + return np.array([1.0, 0.0, 0.0], dtype=float) + i0 = max(0, int(point_idx) - 1) + i1 = min(len(pts) - 1, int(point_idx) + 1) + if i1 == i0: + i1 = min(len(pts) - 1, i0 + 1) + tangent = pts[i1] - pts[i0] + return _normalize(tangent) + + +def project_planes_to_workspace(plane_items, ws): + origin = np.asarray(ws.origin, dtype=float).reshape(3) + paths_local = ws.centerline_paths_smooth if len(ws.centerline_paths_smooth) > 0 else ws.centerline_paths + paths_world = [np.asarray(path, dtype=float).reshape(-1, 3) + origin.reshape(1, 3) for path in paths_local] + planes = [] + for item in plane_items: + if "center_world" in item: + center_world = np.asarray(item["center_world"], dtype=float).reshape(3) + elif "center" in item: + center_world = np.asarray(item["center"], dtype=float).reshape(3) + else: + continue + normal = _normalize(item.get("normal", [1.0, 0.0, 0.0])) + path_index = int(item.get("path_index", -1)) + distance = float(item.get("distance", 0.0)) + if paths_world: + nearest_path_idx, nearest_point_idx, _, nearest_distance = _nearest_path_info(center_world, paths_world) + if nearest_path_idx >= 0: + path_index = int(nearest_path_idx) + distance = float(nearest_distance) + if np.linalg.norm(normal) <= 1e-12: + normal = _path_tangent(paths_world[path_index], nearest_point_idx) + if path_index < 0: + path_index = 0 + planes.append( + PlaneData( + center=center_world - origin, + normal=_normalize(normal), + label=int(path_index) + 1, + path_index=int(path_index), + distance=float(distance), + ) + ) + return planes + + +def resolve_reuse_plane_file(reuse_spec, case_name): + if not reuse_spec: + return "" + if os.path.isfile(reuse_spec): + return reuse_spec + if os.path.isdir(reuse_spec): + candidates = [ + os.path.join(reuse_spec, case_name, "plane_positions.json"), + os.path.join(reuse_spec, case_name, "planes.json"), + os.path.join(reuse_spec, "plane_positions.json"), + os.path.join(reuse_spec, "planes.json"), + ] + for candidate in candidates: + if os.path.isfile(candidate): + return candidate + return reuse_spec diff --git a/autoflow/processing.py b/autoflow/processing.py new file mode 100755 index 0000000..9382e88 --- /dev/null +++ b/autoflow/processing.py @@ -0,0 +1,435 @@ +import copy +import glob +import json +import os +import traceback + +import numpy as np + +from .core.models import StepId, Workspace +from .core.pipeline import PipelineEngine +from .algorithms import compute_derived_metrics +from .plane_io import ( + load_plane_positions, + project_planes_to_workspace, + resolve_reuse_plane_file, + save_plane_positions, +) +from .rendering import ( + render_plane_rotation_video, + render_streamlines_video, + render_tke_video, + render_wss_video, +) +from .reporting import load_metrics_from_output, print_metrics_summary, print_qc_summary + + +def process_single( + h5_path, + out_dir, + workspace=None, + skip_derived=False, + skip_plane_metrics=False, + use_multithread=False, + reuse_planes_path="", + fps=24, + plane_rotation_frames=180, + rotate_dynamic_video=False, + dynamic_rotation_frames=180, + dynamic_rotation_elevation_deg=None, + make_plane_video=True, + make_wss_video=True, + make_streamlines_video=True, + make_tke_video=True, + camera_view="iso", + camera_distance_scale=1.0, + add_plane_idx=False, + add_path_idx=False, + wss_clim=None, + wss_bar_cfg=None, + tke_clim=None, + tke_bar_cfg=None, + streamline_clim=None, + streamline_bar_cfg=None, + dynamic_time_repeat=1, +): + print(f"\n{'=' * 60}") + print(f"Processing: {h5_path}") + print(f"Output dir: {out_dir}") + print(f"{'=' * 60}") + + os.makedirs(out_dir, exist_ok=True) + ws = copy.deepcopy(workspace) if workspace is not None else Workspace() + ws.paths.segmask_path = h5_path + ws.paths.flow_path = h5_path + ws.paths.output_dir = out_dir + ws.derived_params.use_multithread = use_multithread + engine = PipelineEngine() + logger = lambda msg: None + import time as _time + + t_total_start = _time.time() + + print("[1/7] Loading data...") + engine.load_data(ws, logger) + + print("[2/7] Generate Skeleton...") + result = engine.run_step(ws, StepId.GENERATE_SKELETON, logger) + print(f" -> {result.message}") + + print("[3/7] Generate Graph (+ branches/forks)...") + result = engine.run_step(ws, StepId.GENERATE_GRAPH, logger) + print(f" -> {result.message}") + + print("[4/7] Generate Planes...") + result = engine.run_step(ws, StepId.GENERATE_PLANES, logger) + print(f" -> {result.message}") + + if reuse_planes_path: + print(f"[5/7] Reuse Plane Positions: {reuse_planes_path}") + plane_items = load_plane_positions(reuse_planes_path) + ws.planes = project_planes_to_workspace(plane_items, ws) + planes_json = engine._save_planes_json(ws) + print(f" -> Reused {len(ws.planes)} planes saved={planes_json}") + else: + print("[5/7] Use generated planes") + + if skip_plane_metrics: + print("[6/7] Skipped plane metrics") + else: + print("[6/7] Calculate & Save Metrics...") + _, _, metric_msg = engine._compute_plane_metrics_internal( + ws, + save=True, + use_multithread=use_multithread, + ) + print(f" -> {metric_msg}") + try: + engine._save_planes_json(ws) + except Exception: + pass + + pixelwise_result = {} + if not skip_derived: + print("[7/7] Compute Derived Metrics (WSS/TKE)...") + dp = ws.derived_params + engine.preprocess(ws) + loaded_tke = ws.derived.tke_array + + derived = compute_derived_metrics( + flow=ws.flow_raw * ws.segmask_binary[..., None], + mask4d=ws.segmask_binary, + spacing=ws.resolution, + origin=ws.origin, + smoothing_iteration=dp.smoothing_iteration, + viscosity=dp.viscosity, + inward_distance=dp.inward_distance, + parabolic_fitting=dp.parabolic_fitting, + no_slip_condition=dp.no_slip_condition, + step_size=dp.step_size, + tube_radius=dp.tube_radius, + rho=dp.rho, + save_pixelwise=True, + tke_array=loaded_tke, + ) + ws.derived.wss_surfaces = derived["wss_surfaces"] + ws.derived.wss_volume = derived.get("wss_volume") + ws.derived.tke_volume = derived["tke_volume"] + ws.derived.tke_array = derived.get("tke_array") + ws.derived.pixelwise_export = derived.get("pixelwise_export", {}) + pixelwise_result = ws.derived.pixelwise_export + pixel_path = os.path.join(out_dir, "derived_metrics_pixelwise.npz") + if pixelwise_result: + np.savez_compressed(pixel_path, **pixelwise_result) + print(f" -> Saved pixelwise: {pixel_path}") + ws.pipeline.mark_done(StepId.COMPUTE_DERIVED_METRICS) + print(f" -> Derived: Nt={len(ws.derived.wss_surfaces)}") + else: + print("[7/7] Skipped derived metrics (WSS/TKE)") + + total_time_sec = _time.time() - t_total_start + print(f" => Total pipeline took {total_time_sec:.2f}s") + + plane_positions_path = save_plane_positions(ws, os.path.join(out_dir, "plane_positions.json"), source_path=h5_path) + print(f"Plane positions saved: {plane_positions_path}") + + video_paths = {} + if make_plane_video: + try: + video_paths["planes"] = render_plane_rotation_video( + ws, + out_dir, + fps=fps, + n_frames=plane_rotation_frames, + smoothing_iteration=ws.derived_params.smoothing_iteration, + distance_scale=camera_distance_scale, + add_plane_idx=add_plane_idx, + add_path_idx=add_path_idx, + ) + if video_paths["planes"]: + print(f"Plane video saved: {video_paths['planes']}") + except Exception: + print("[WARN] Plane video failed") + print(traceback.format_exc()) + video_paths["planes"] = "" + + if make_streamlines_video: + try: + video_paths["streamlines"] = render_streamlines_video( + ws, + out_dir, + fps=fps, + smoothing_iteration=ws.derived_params.smoothing_iteration, + view=camera_view, + distance_scale=camera_distance_scale, + streamline_clim=streamline_clim, + streamline_bar_cfg=streamline_bar_cfg, + rotate=rotate_dynamic_video, + rotation_frames=dynamic_rotation_frames, + elevation_deg=dynamic_rotation_elevation_deg, + time_repeat=dynamic_time_repeat, + ) + if video_paths["streamlines"]: + print(f"Streamlines video saved: {video_paths['streamlines']}") + except Exception: + print("[WARN] Streamlines video failed") + print(traceback.format_exc()) + video_paths["streamlines"] = "" + + if not skip_derived and make_wss_video: + try: + video_paths["wss"] = render_wss_video( + ws, + out_dir, + fps=fps, + smoothing_iteration=ws.derived_params.smoothing_iteration, + view=camera_view, + distance_scale=camera_distance_scale, + wss_clim=wss_clim, + wss_bar_cfg=wss_bar_cfg, + rotate=rotate_dynamic_video, + rotation_frames=dynamic_rotation_frames, + elevation_deg=dynamic_rotation_elevation_deg, + time_repeat=dynamic_time_repeat, + ) + if video_paths["wss"]: + print(f"WSS video saved: {video_paths['wss']}") + except Exception: + print("[WARN] WSS video failed") + print(traceback.format_exc()) + video_paths["wss"] = "" + + if not skip_derived and make_tke_video: + try: + video_paths["tke"] = render_tke_video( + ws, + out_dir, + fps=fps, + smoothing_iteration=ws.derived_params.smoothing_iteration, + view=camera_view, + distance_scale=camera_distance_scale, + tke_clim=tke_clim, + tke_bar_cfg=tke_bar_cfg, + rotate=rotate_dynamic_video, + rotation_frames=dynamic_rotation_frames, + elevation_deg=dynamic_rotation_elevation_deg, + time_repeat=dynamic_time_repeat, + ) + if video_paths["tke"]: + print(f"TKE video saved: {video_paths['tke']}") + except Exception: + print("[WARN] TKE video failed") + print(traceback.format_exc()) + video_paths["tke"] = "" + + table_rows, raw_metrics, qc_data = (None, None, None) + if not skip_plane_metrics: + table_rows, raw_metrics, qc_data = load_metrics_from_output(out_dir) + + if table_rows: + print("\n === Plane Metrics Summary ===") + print_metrics_summary(table_rows) + + if qc_data: + print("\n === Fork QC Summary ===") + print_qc_summary(qc_data, ws.forks) + + summary = { + "input": h5_path, + "output_dir": out_dir, + "resolution": ws.resolution.tolist(), + "origin": np.asarray(ws.origin, dtype=float).reshape(3).tolist(), + "rr": ws.rr, + "total_time_sec": float(total_time_sec), + "n_planes": len(ws.planes), + "n_skeleton_pts": len(ws.skeleton_points) if ws.skeleton_points is not None else 0, + "n_graph_nodes": len(ws.graph.points), + "n_graph_edges": len(ws.graph.edges), + "n_paths": len(ws.centerline_paths_smooth), + "n_forks": len(ws.forks), + "path_info": ws.path_info, + "forks": ws.forks, + "plane_metrics": ws.derived.plane_metrics, + "plane_qc": ws.derived.plane_qc, + "plane_positions_file": plane_positions_path, + "reused_planes_file": reuse_planes_path, + "videos": video_paths, + "pixelwise_export": {k: list(np.asarray(v).shape) for k, v in pixelwise_result.items()} if pixelwise_result else {}, + } + summary_path = os.path.join(out_dir, "summary.json") + with open(summary_path, "w", encoding="utf-8") as f: + json.dump(summary, f, ensure_ascii=False, indent=2) + print(f"\nSummary saved: {summary_path}") + return summary + + +def collect_h5_files(inputs): + files = [] + for inp in inputs: + if os.path.isfile(inp) and inp.lower().endswith((".h5", ".hdf5")): + files.append(inp) + elif os.path.isdir(inp): + files.extend(sorted(glob.glob(os.path.join(inp, "**", "*.h5"), recursive=True))) + files.extend(sorted(glob.glob(os.path.join(inp, "**", "*.hdf5"), recursive=True))) + return sorted(dict.fromkeys(files)) + + +def build_base_workspace(): + ws = Workspace() + ws.plane_gen_params.use_center_plane = globals().get("USE_CENTER_PLANE", True) + ws.plane_gen_params.cross_section_distance = globals().get("CROSS_SECTION_DIST", 5.0) + ws.plane_gen_params.start_distance = globals().get("START_DIST", 5.0) + ws.plane_gen_params.end_distance = globals().get("END_DIST", 0.0) + ws.skeleton_params.remove_small_cc = globals().get("REMOVE_SMALL_CC", True) + ws.skeleton_params.min_cc_volume_mm3 = globals().get("MIN_CC_VOLUME", 50.0) + ws.streamline_params.max_steps = 2000 + ws.streamline_params.min_seeds = 50 + ws.streamline_params.seed_ratio = globals().get("SEED_RATIO", 0.02) + ws.streamline_params.tube_radius = globals().get("TUBE_RADIUS", 0.05) + return ws + + +def run_batch(): + inputs = globals().get("INPUT", None) + if inputs is None: + raise ValueError("INPUT is not defined.") + dynamic_time_repeat = globals().get("DYNAMIC_TIME_REPEAT", 1) + output_dir = globals().get("OUTPUT_DIR", "./batch_output") + skip_derived = globals().get("SKIP_DERIVED", False) + use_multithread = globals().get("USE_MULTITHREAD", True) + reuse_planes = globals().get("REUSE_PLANES", "") + fps = globals().get("FPS", 12) + plane_rotation_frames = globals().get("PLANE_ROTATION_FRAMES", 180) + make_plane_video = globals().get("MAKE_PLANE_VIDEO", True) + make_wss_video = globals().get("MAKE_WSS_VIDEO", True) + make_streamlines_video = globals().get("MAKE_STREAMLINES_VIDEO", True) + make_tke_video = globals().get("MAKE_TKE_VIDEO", True) + camera_view = globals().get("CAMERA_VIEW", "posterior") + camera_distance_scale = globals().get("CAMERA_DISTANCE_SCALE", 1.5) + skip_plane_metrics = globals().get("SKIP_PLANE_METRICS", False) + rotate_dynamic_video = globals().get("ROTATE_DYNAMIC_VIDEO", False) + dynamic_rotation_frames = globals().get("DYNAMIC_ROTATION_FRAMES", 180) + dynamic_rotation_elevation_deg = globals().get("DYNAMIC_ROTATION_ELEVATION_DEG", None) + add_plane_idx = globals().get("ADD_PLANE_IDX", False) + add_path_idx = globals().get("ADD_PATH_IDX", False) + + wss_clim = globals().get("WSS_CLIM", (0, 5)) + wss_bar_cfg = globals().get( + "WSS_BAR_CFG", + {"position_x": 0.75, "position_y": 0.2, "height": 0.22, "width": 0.05, "title_font_size": 40, "label_font_size": 32}, + ) + tke_clim = globals().get("TKE_CLIM", (0, 2)) + tke_bar_cfg = globals().get( + "TKE_BAR_CFG", + {"position_x": 0.75, "position_y": 0.2, "height": 0.22, "width": 0.05, "title_font_size": 40, "label_font_size": 32}, + ) + streamline_clim = globals().get("STREAMLINE_CLIM", (0, 0.6)) + streamline_bar_cfg = globals().get( + "STREAMLINE_BAR_CFG", + {"position_x": 0.75, "position_y": 0.2, "height": 0.22, "width": 0.05, "title_font_size": 40, "label_font_size": 32}, + ) + + h5_files = collect_h5_files(inputs) + if not h5_files: + print("No H5 files found.") + return [], "" + + print(f"Found {len(h5_files)} file(s) to process.") + base_ws = build_base_workspace() + results = [] + case_out = "" + + for path in h5_files: + name = os.path.splitext(os.path.basename(path))[0] + case_out = os.path.join(output_dir, name) + reuse_file = resolve_reuse_plane_file(reuse_planes, name) + + if reuse_planes and not os.path.isfile(reuse_file): + results.append({"file": path, "status": "error", "error": f"reuse plane file not found: {reuse_planes}"}) + print(f"\n[ERROR] Reuse plane file not found: {reuse_planes}") + continue + + try: + summary = process_single( + path, + case_out, + workspace=base_ws, + skip_derived=skip_derived, + use_multithread=use_multithread, + reuse_planes_path=reuse_file, + fps=fps, + plane_rotation_frames=plane_rotation_frames, + make_plane_video=make_plane_video, + make_wss_video=make_wss_video, + make_streamlines_video=make_streamlines_video, + make_tke_video=make_tke_video, + camera_view=camera_view, + camera_distance_scale=camera_distance_scale, + add_plane_idx=add_plane_idx, + add_path_idx=add_path_idx, + wss_clim=wss_clim, + wss_bar_cfg=wss_bar_cfg, + tke_clim=tke_clim, + tke_bar_cfg=tke_bar_cfg, + streamline_clim=streamline_clim, + streamline_bar_cfg=streamline_bar_cfg, + skip_plane_metrics=skip_plane_metrics, + rotate_dynamic_video=rotate_dynamic_video, + dynamic_rotation_frames=dynamic_rotation_frames, + dynamic_rotation_elevation_deg=dynamic_rotation_elevation_deg, + dynamic_time_repeat=dynamic_time_repeat, + ) + results.append({"file": path, "status": "ok", "summary": summary}) + except Exception: + print(f"\n[ERROR] Failed: {path}") + print(traceback.format_exc()) + results.append({"file": path, "status": "error", "error": traceback.format_exc()}) + + os.makedirs(output_dir, exist_ok=True) + batch_report = os.path.join(output_dir, "batch_report.json") + with open(batch_report, "w", encoding="utf-8") as f: + json.dump(results, f, ensure_ascii=False, indent=2) + + n_ok = sum(1 for r in results if r["status"] == "ok") + + times_sec = [] + for r in results: + if r.get("status") == "ok": + total_time_sec = r.get("summary", {}).get("total_time_sec", None) + if total_time_sec is not None: + times_sec.append(float(total_time_sec)) + + if times_sec: + arr = np.asarray(times_sec, dtype=float) + mean_sec = arr.mean() + std_sec = arr.std(ddof=1) if len(arr) > 1 else 0.0 + time_text = f"{mean_sec:.2f} ± {std_sec:.2f} s" + print(f"\nCase time: {time_text}") + + with open(os.path.join(output_dir, "time_summary.txt"), "w", encoding="utf-8") as f: + f.write(f"n = {len(arr)}\n") + f.write(f"mean_sec = {mean_sec:.6f}\n") + f.write(f"std_sec = {std_sec:.6f}\n") + f.write(f"formatted = {time_text}\n") + print(f"\nDone: {n_ok}/{len(results)} succeeded. Report: {batch_report}") + return results, case_out diff --git a/autoflow/rendering/__init__.py b/autoflow/rendering/__init__.py new file mode 100755 index 0000000..fdc66a2 --- /dev/null +++ b/autoflow/rendering/__init__.py @@ -0,0 +1,19 @@ +from .videos import ( + CAMERA_PRESETS, + WINDOW_SIZE, + extract_frame, + render_plane_rotation_video, + render_streamlines_video, + render_tke_video, + render_wss_video, +) + +__all__ = [ + "CAMERA_PRESETS", + "WINDOW_SIZE", + "extract_frame", + "render_plane_rotation_video", + "render_streamlines_video", + "render_tke_video", + "render_wss_video", +] diff --git a/autoflow/rendering/videos.py b/autoflow/rendering/videos.py new file mode 100755 index 0000000..dd1ab59 --- /dev/null +++ b/autoflow/rendering/videos.py @@ -0,0 +1,674 @@ +import os + +import imageio.v2 as imageio +import numpy as np +import pyvista as pv +from PIL import Image + +from ..algorithms import create_uniform_grid, generate_seed_points, generate_streamlines_at_t + +WINDOW_SIZE = (1600, 1200) +_OFFSCREEN_BOOTSTRAPPED = False + +CAMERA_PRESETS = { + "iso": (35.0, 25.0), + "iso_back": (215.0, 25.0), + "right": (0.0, 0.0), + "left": (180.0, 0.0), + "anterior": (270.0, 0.0), + "posterior": (90.0, 0.0), + "superior": (0.0, 89.9), + "inferior": (0.0, -89.9), +} + + +def _offscreen_mode(): + mode = str(os.environ.get("AUTOFLOW_OFFSCREEN_MODE", "local")).strip().lower() + if mode in {"display", "x11", "onscreen"}: + return "display" + if mode in {"local", "headless", "xvfb"}: + return "local" + + display = str(os.environ.get("DISPLAY", "")).strip().lower() + if not display: + return "local" + if os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"): + if display.startswith("localhost:") or display.startswith("localhost/unix:") or display.startswith("127.0.0.1:"): + return "local" + return "display" + + +def _normalize(v): + arr = np.asarray(v, dtype=float).reshape(3) + n = np.linalg.norm(arr) + if n <= 1e-12: + return np.array([1.0, 0.0, 0.0], dtype=float) + return arr / n + + +def _path_polydata(path_world): + pts = np.asarray(path_world, dtype=float).reshape(-1, 3) + if len(pts) == 0: + return None + poly = pv.PolyData(pts) + if len(pts) >= 2: + cells = np.empty((len(pts) - 1, 3), dtype=np.int64) + cells[:, 0] = 2 + cells[:, 1] = np.arange(len(pts) - 1) + cells[:, 2] = np.arange(1, len(pts)) + poly.lines = cells.ravel() + return poly + + +def _plane_mesh(center_world, normal, size): + return pv.Plane( + center=np.asarray(center_world, dtype=float).reshape(3), + direction=_normalize(normal), + i_size=float(size), + j_size=float(size), + i_resolution=1, + j_resolution=1, + ) + + +def _scalar_bar_args(title, bar_cfg=None): + cfg = { + "title": title, + "vertical": True, + "position_x": 0.86, + "position_y": 0.1, + "height": 0.8, + "width": 0.08, + "title_font_size": 18, + "label_font_size": 14, + "n_labels": 5, + "fmt": "%.3g", + } + if bar_cfg: + cfg.update(bar_cfg) + return cfg + + +def _ensure_offscreen(): + global _OFFSCREEN_BOOTSTRAPPED + if _OFFSCREEN_BOOTSTRAPPED: + return + _OFFSCREEN_BOOTSTRAPPED = True + os.environ["PYVISTA_OFF_SCREEN"] = "true" + os.environ.pop("DISPLAY", None) + try: + if hasattr(pv, "start_xvfb"): + pv.start_xvfb() + except Exception: + pass + + +def _make_plotter(window_size=WINDOW_SIZE): + _ensure_offscreen() + plotter = pv.Plotter(off_screen=True, window_size=window_size) + plotter.set_background("white") + return plotter + + +def _write_video(frames, out_path, fps=24): + if not frames: + return None + out_path = os.path.splitext(out_path)[0] + ".mp4" + os.makedirs(os.path.dirname(out_path), exist_ok=True) + try: + with imageio.get_writer(out_path, fps=fps, codec="libx264", macro_block_size=None) as writer: + for frame in frames: + writer.append_data(np.asarray(frame)) + return out_path + except Exception: + gif_path = os.path.splitext(out_path)[0] + ".gif" + imageio.mimsave(gif_path, [np.asarray(frame) for frame in frames], duration=1.0 / max(int(fps), 1)) + return gif_path + + +def _surface_center_radius(poly): + if poly is None or poly.n_points == 0: + return np.zeros(3, dtype=float), 100.0 + bounds = np.array(poly.bounds, dtype=float).reshape(3, 2) + center = bounds.mean(axis=1) + extent = np.maximum(bounds[:, 1] - bounds[:, 0], 1.0) + radius = float(max(np.linalg.norm(extent) * 1.2, 50.0)) + return center, radius + + +def _resolve_view(view): + if view is None: + return CAMERA_PRESETS["iso"] + if isinstance(view, str): + if view not in CAMERA_PRESETS: + raise ValueError(f"unknown camera preset: {view}, options: {list(CAMERA_PRESETS)}") + return CAMERA_PRESETS[view] + azimuth_deg, elevation_deg = view + return float(azimuth_deg), float(elevation_deg) + + +def _camera_from_scene(poly, azimuth_deg=35.0, elevation_deg=25.0, distance_scale=1.0): + center, radius = _surface_center_radius(poly) + radius = radius * float(distance_scale) + azimuth = np.deg2rad(float(azimuth_deg)) + elevation = np.deg2rad(float(elevation_deg)) + pos = center + np.array( + [ + radius * np.cos(elevation) * np.cos(azimuth), + radius * np.cos(elevation) * np.sin(azimuth), + radius * np.sin(elevation), + ], + dtype=float, + ) + if abs(elevation_deg) > 80.0: + up = (0.0, 1.0, 0.0) + else: + up = (0.0, 0.0, 1.0) + return [tuple(pos.tolist()), tuple(center.tolist()), up] + + +def _orbit_camera(poly, azimuth_deg, elevation_deg=25.0, distance_scale=1.0): + return _camera_from_scene(poly, azimuth_deg, elevation_deg, distance_scale) + + +def _camera_from_view(poly, view, distance_scale=1.0): + azimuth_deg, elevation_deg = _resolve_view(view) + return _camera_from_scene(poly, azimuth_deg, elevation_deg, distance_scale) + + +def _time_and_azimuth(frame_idx, rotation_frames, n_time, time_repeat=1): + rotation_frames = int(max(rotation_frames, 1)) + n_time = int(max(n_time, 1)) + time_repeat = int(max(time_repeat, 1)) + + t = (frame_idx // time_repeat) % n_time + azimuth_deg = 360.0 * (frame_idx % rotation_frames) / rotation_frames + return t, azimuth_deg + + +def _build_union_surface(ws, smoothing_iteration=200): + if ws.segmask_binary is not None: + mask3d = np.any(np.asarray(ws.segmask_binary, dtype=bool), axis=3) + else: + mask3d = np.asarray(ws.segmask_3d, dtype=bool) + mesh = create_uniform_grid(mask3d, ws.resolution, origin=ws.origin) + mesh = mesh.threshold(0.1) + if mesh is None or mesh.n_cells == 0: + return None, None + surf = mesh.extract_surface() + if surf is not None and surf.n_points > 0 and int(smoothing_iteration) > 0: + surf = surf.smooth(n_iter=int(smoothing_iteration)) + return mesh, surf + + +def _plane_size_from_surface(surf): + if surf is None or surf.n_points == 0: + return 25.0 + bounds = np.array(surf.bounds, dtype=float).reshape(3, 2) + extent = bounds[:, 1] - bounds[:, 0] + return float(max(12.0, 0.12 * np.max(extent))) + + +def render_plane_rotation_video( + ws, + out_dir, + fps=24, + n_frames=180, + smoothing_iteration=200, + elevation_deg=0.0, + distance_scale=1.0, + add_plane_idx=False, + add_path_idx=False, +): + _, surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) + if surf is None or surf.n_points == 0: + return None + + plane_size = _plane_size_from_surface(surf) + origin = np.asarray(ws.origin, dtype=float).reshape(3) + plotter = _make_plotter() + plotter.add_mesh(surf, opacity=0.18, color="white") + + paths_world = [] + for path in ws.centerline_paths_smooth: + path_world = np.asarray(path, dtype=float) + origin.reshape(1, 3) + paths_world.append(path_world) + poly = _path_polydata(path_world) + if poly is not None and poly.n_points > 0: + plotter.add_mesh(poly, color="deepskyblue", line_width=5, render_lines_as_tubes=True) + + centers = [] + plane_labels = [] + for i, plane in enumerate(ws.planes): + center_world = np.asarray(plane.center, dtype=float).reshape(3) + origin + plane_mesh = _plane_mesh(center_world, plane.normal, plane_size) + plotter.add_mesh(plane_mesh, color="yellow", opacity=0.75, show_edges=True, edge_color="black", line_width=2) + centers.append(center_world) + plane_labels.append(f"Plane {i}") + + if add_plane_idx and centers: + plotter.add_point_labels( + np.asarray(centers, dtype=float), + plane_labels, + font_size=28, + bold=True, + text_color="black", + fill_shape=True, + shape="rounded_rect", + shape_color="yellow", + shape_opacity=0.85, + margin=5, + always_visible=True, + ) + + if add_path_idx and paths_world: + path_label_points = [] + path_label_texts = [] + offsets = [ + np.array([0, 0, 0]), + np.array([3, 0, 0]), + np.array([-3, 0, 0]), + np.array([0, 3, 0]), + np.array([0, -3, 0]), + ] + frac_choices = [0.25, 0.5, 0.75, 0.35, 0.65] + + for idx, path_world in enumerate(paths_world): + if path_world is None or len(path_world) == 0: + continue + n_points = len(path_world) + frac = frac_choices[idx % len(frac_choices)] + k = min(max(int(frac * (n_points - 1)), 0), n_points - 1) + anchor = np.asarray(path_world[k], dtype=float) + offsets[idx % len(offsets)] + path_label_points.append(anchor) + path_label_texts.append(f"Branch {idx}") + + if path_label_points: + plotter.add_point_labels( + np.asarray(path_label_points, dtype=float), + path_label_texts, + font_size=20, + bold=True, + text_color="black", + fill_shape=True, + shape="rounded_rect", + shape_color="deepskyblue", + shape_opacity=0.85, + margin=2, + always_visible=True, + ) + + frames = [] + for frame_idx in range(int(max(n_frames, 1))): + azimuth_deg = 360.0 * frame_idx / max(n_frames, 1) + plotter.camera_position = _orbit_camera(surf, azimuth_deg, elevation_deg, distance_scale) + plotter.add_text( + f"Rotating {frame_idx + 1}/{int(max(n_frames, 1))}", + position="upper_left", + font_size=14, + color="black", + name="frame_text", + ) + plotter.render() + frames.append(np.asarray(plotter.screenshot(return_img=True))) + try: + plotter.remove_actor("frame_text") + except Exception: + pass + plotter.close() + return _write_video(frames, os.path.join(out_dir, "planes_rotate.mp4"), fps=fps) + + +def render_wss_video( + ws, + out_dir, + fps=24, + smoothing_iteration=200, + view="iso", + distance_scale=1.0, + wss_clim=None, + wss_bar_cfg=None, + rotate=False, + rotation_frames=None, + elevation_deg=None, + time_repeat=1, +): + if not ws.derived.wss_surfaces: + return None + + _, context_surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) + if context_surf is None or context_surf.n_points == 0: + return None + + wss_max = 0.0 + for surf in ws.derived.wss_surfaces: + if surf is not None and surf.n_points > 0 and "wss" in surf.point_data: + vals = np.asarray(surf.point_data["wss"], dtype=float) + if vals.size: + wss_max = max(wss_max, float(np.nanmax(vals))) + wss_max = max(wss_max, 1e-6) + clim = wss_clim if wss_clim is not None else (0.0, wss_max) + + _, default_elevation_deg = _resolve_view(view) + if elevation_deg is None: + elevation_deg = default_elevation_deg + + n_time = int(max(ws.time_count(), 1)) + if rotate: + base_frames = n_time * int(max(time_repeat, 1)) + if rotation_frames is not None: + total_frames = max(int(rotation_frames), base_frames) + else: + total_frames = base_frames + else: + total_frames = n_time * int(max(time_repeat, 1)) + + plotter = _make_plotter() + frames = [] + + for frame_idx in range(total_frames): + if rotate: + t, azimuth_deg = _time_and_azimuth( + frame_idx, + rotation_frames=rotation_frames if rotation_frames is not None else total_frames, + n_time=n_time, + time_repeat=time_repeat, + ) + camera_position = _orbit_camera(context_surf, azimuth_deg, elevation_deg, distance_scale) + else: + t = min(frame_idx, n_time - 1) + camera_position = _camera_from_view(context_surf, view, distance_scale) + + plotter.clear() + plotter.set_background("white") + plotter.add_mesh(context_surf, opacity=0.08, color="white") + + surf = ws.derived.wss_surfaces[min(max(0, t), len(ws.derived.wss_surfaces) - 1)] + if surf is not None and surf.n_points > 0 and "wss" in surf.point_data: + plotter.add_mesh( + surf, + scalars="wss", + cmap="jet", + clim=clim, + show_scalar_bar=True, + scalar_bar_args=_scalar_bar_args("WSS (Pa)", wss_bar_cfg), + ) + + if rotate: + txt = f"t={t} | rot {frame_idx + 1}/{total_frames}" + else: + txt = f"t={t}" + + plotter.add_text(txt, position="upper_left", font_size=14, color="black") + plotter.camera_position = camera_position + plotter.render() + frames.append(np.asarray(plotter.screenshot(return_img=True))) + + plotter.close() + suffix = "rotate" if rotate else "video" + return _write_video(frames, os.path.join(out_dir, f"wss_{suffix}.mp4"), fps=fps) + + +def _streamline_speed_max(ws): + if ws.flow_raw is None: + return 1e-6 + speed = np.linalg.norm(np.asarray(ws.flow_raw, dtype=float) / 100.0, axis=-1) + if ws.segmask_binary is not None and np.any(ws.segmask_binary): + vals = speed[np.asarray(ws.segmask_binary, dtype=bool)] + if vals.size: + return max(float(np.nanmax(vals)), 1e-6) + return max(float(np.nanmax(speed)), 1e-6) + + +def _ensure_streamline_scalars(sl): + if sl is None: + return sl + if "Velocity" in sl.point_data or "Velocity" in sl.cell_data: + return sl + if "vector" in sl.point_data: + sl.point_data["Velocity"] = np.linalg.norm(np.asarray(sl.point_data["vector"], dtype=float), axis=1) + return sl + if "vector" in sl.cell_data: + sl.cell_data["Velocity"] = np.linalg.norm(np.asarray(sl.cell_data["vector"], dtype=float), axis=1) + return sl + return sl + + +def render_streamlines_video( + ws, + out_dir, + fps=24, + smoothing_iteration=200, + view="iso", + distance_scale=1.0, + streamline_clim=None, + streamline_bar_cfg=None, + rotate=False, + rotation_frames=None, + elevation_deg=None, + time_repeat=1, +): + if ws.flow_raw is None or ws.segmask_binary is None or ws.segmask_3d is None: + return None + + mesh, surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) + if mesh is None or surf is None or surf.n_points == 0: + return None + + seeds = generate_seed_points( + ws.segmask_3d, + ws.resolution, + ws.origin, + ratio=ws.streamline_params.seed_ratio, + rng_seed=ws.streamline_params.rng_seed, + min_seeds=50, + ) + + v_max = _streamline_speed_max(ws) + clim = streamline_clim if streamline_clim is not None else (0.0, v_max) + + _, default_elevation_deg = _resolve_view(view) + if elevation_deg is None: + elevation_deg = default_elevation_deg + + n_time = int(max(ws.time_count(), 1)) + if rotate: + base_frames = n_time * int(max(time_repeat, 1)) + if rotation_frames is not None: + total_frames = max(int(rotation_frames), base_frames) + else: + total_frames = base_frames + else: + total_frames = n_time * int(max(time_repeat, 1)) + + plotter = _make_plotter() + frames = [] + + for frame_idx in range(total_frames): + if rotate: + t, azimuth_deg = _time_and_azimuth( + frame_idx, + rotation_frames=rotation_frames if rotation_frames is not None else total_frames, + n_time=n_time, + time_repeat=time_repeat, + ) + camera_position = _orbit_camera(surf, azimuth_deg, elevation_deg, distance_scale) + else: + t = min(frame_idx, n_time - 1) + camera_position = _camera_from_view(surf, view, distance_scale) + + mask_t = np.asarray( + ws.segmask_binary[..., min(max(0, t), ws.segmask_binary.shape[3] - 1)], + dtype=bool, + ) + + sl = generate_streamlines_at_t( + ws.flow_raw, + t, + seeds, + ws.resolution, + ws.origin, + mask_3d=mask_t, + max_steps=ws.streamline_params.max_steps, + terminal_speed=ws.streamline_params.terminal_speed, + seed_ratio=ws.streamline_params.seed_ratio, + min_seeds=50, + rng_seed=ws.streamline_params.rng_seed, + ) + sl = _ensure_streamline_scalars(sl) + + plotter.clear() + plotter.set_background("white") + plotter.add_mesh(surf, opacity=0.18, color="lightgray") + + if sl is not None and sl.n_points > 0: + plotter.add_mesh( + sl, + scalars="Velocity", + cmap="turbo", + clim=clim, + show_scalar_bar=True, + scalar_bar_args=_scalar_bar_args("Velocity (m/s)", streamline_bar_cfg), + render_lines_as_tubes=True, + line_width=3, + ) + + if rotate: + txt = f"t={t} | rot {frame_idx + 1}/{total_frames}" + else: + txt = f"t={t}" + + plotter.add_text(txt, position="upper_left", font_size=14, color="black") + plotter.camera_position = camera_position + plotter.render() + frames.append(np.asarray(plotter.screenshot(return_img=True))) + + plotter.close() + suffix = "rotate" if rotate else "video" + return _write_video(frames, os.path.join(out_dir, f"streamlines_{suffix}.mp4"), fps=fps) + + +def _tke_max(ws): + if ws.derived.tke_array is not None: + return max(float(np.nanmax(np.asarray(ws.derived.tke_array, dtype=float))), 1e-6) + tke_mesh = ws.derived.tke_volume + if tke_mesh is None: + return 1e-6 + if "TKE" in tke_mesh.point_data: + return max(float(np.nanmax(np.asarray(tke_mesh.point_data["TKE"], dtype=float))), 1e-6) + if "TKE" in tke_mesh.cell_data: + return max(float(np.nanmax(np.asarray(tke_mesh.cell_data["TKE"], dtype=float))), 1e-6) + return 1e-6 + + +def render_tke_video( + ws, + out_dir, + fps=24, + smoothing_iteration=200, + view="iso", + distance_scale=1.0, + tke_clim=None, + tke_bar_cfg=None, + rotate=False, + rotation_frames=None, + elevation_deg=None, + time_repeat=1, +): + if ws.derived.tke_array is None and ws.derived.tke_volume is None: + return None + + _, surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) + if surf is None or surf.n_points == 0: + return None + + tke_max = _tke_max(ws) + clim = tke_clim if tke_clim is not None else (0.0, tke_max) + + _, default_elevation_deg = _resolve_view(view) + if elevation_deg is None: + elevation_deg = default_elevation_deg + + n_time = int(max(ws.time_count(), 1)) + if rotate: + base_frames = n_time * int(max(time_repeat, 1)) + if rotation_frames is not None: + total_frames = max(int(rotation_frames), base_frames) + else: + total_frames = base_frames + else: + total_frames = n_time * int(max(time_repeat, 1)) + + plotter = _make_plotter() + frames = [] + + for frame_idx in range(total_frames): + if rotate: + t, azimuth_deg = _time_and_azimuth( + frame_idx, + rotation_frames=rotation_frames if rotation_frames is not None else total_frames, + n_time=n_time, + time_repeat=time_repeat, + ) + camera_position = _orbit_camera(surf, azimuth_deg, elevation_deg, distance_scale) + else: + t = min(frame_idx, n_time - 1) + camera_position = _camera_from_view(surf, view, distance_scale) + + plotter.clear() + plotter.set_background("white") + plotter.add_mesh(surf, opacity=0.08, color="white") + + if ws.derived.tke_array is not None: + arr = np.asarray(ws.derived.tke_array, dtype=np.float32) + if arr.ndim == 4: + vol_t = arr[..., min(max(0, t), arr.shape[3] - 1)] + else: + vol_t = arr + tke_mesh = create_uniform_grid(vol_t, ws.resolution, origin=ws.origin, name="TKE") + mesh_union = create_uniform_grid( + np.max(ws.segmask_binary > 0, axis=-1), + ws.resolution, + origin=ws.origin, + ) + mesh_union = mesh_union.threshold(0.1) + tke_mesh = mesh_union.sample(tke_mesh) + plotter.add_mesh( + tke_mesh, + scalars="TKE", + cmap="hot", + clim=clim, + show_scalar_bar=True, + scalar_bar_args=_scalar_bar_args("TKE (J/m³)", tke_bar_cfg), + ) + else: + plotter.add_mesh( + ws.derived.tke_volume, + scalars="TKE", + cmap="hot", + clim=clim, + show_scalar_bar=True, + scalar_bar_args=_scalar_bar_args("TKE (J/m³)", tke_bar_cfg), + ) + + if rotate: + txt = f"t={t} | rot {frame_idx + 1}/{total_frames}" + else: + txt = f"t={t}" + + plotter.add_text(txt, position="upper_left", font_size=14, color="black") + plotter.camera_position = camera_position + plotter.render() + frames.append(np.asarray(plotter.screenshot(return_img=True))) + + plotter.close() + suffix = "rotate" if rotate else "video" + return _write_video(frames, os.path.join(out_dir, f"tke_{suffix}.mp4"), fps=fps) + + +def extract_frame(mp4_path, frame_index, out_png): + reader = imageio.get_reader(mp4_path, format="ffmpeg") + frame = reader.get_data(frame_index) + reader.close() + Image.fromarray(np.asarray(frame)).save(out_png, format="PNG", compress_level=0) + print(f"Saved frame {frame_index} -> {out_png}") diff --git a/autoflow/reporting.py b/autoflow/reporting.py new file mode 100755 index 0000000..34c54f6 --- /dev/null +++ b/autoflow/reporting.py @@ -0,0 +1,141 @@ +import json +import os + +import numpy as np + +from .algorithms import load_metrics_as_table + + +def load_metrics_from_output(out_dir): + metrics_path = os.path.join(out_dir, "plane_metrics.json") + qc_path = os.path.join(out_dir, "plane_qc.json") + if not os.path.isfile(metrics_path): + return None, None, None + qc_p = qc_path if os.path.isfile(qc_path) else None + table_rows, raw_metrics, qc_data = load_metrics_as_table(metrics_path, qc_p) + return table_rows, raw_metrics, qc_data + + +def print_metrics_summary(table_rows): + if not table_rows: + print(" No metrics to summarize.") + return + print( + f" {'Plane':>6} {'Path':>5} {'Net Flow(mL/beat)':>18} " + f"{'Peak Velocity(cm/s)':>20} {'Mean Velocity(cm/s)':>20} " + f"{'Reflux':>7} {'IC':>6}" + ) + print(f" {'-'*6} {'-'*5} {'-'*18} {'-'*20} {'-'*20} " f"{'-'*7} {'-'*6}") + for row in table_rows: + pidx = row.get("plane_index", "?") + path = row.get("path_index", "?") + nf = row.get("netflow_mL_beat", 0.0) + pv_ = row.get("peakv_cm_s", 0.0) + if "meanv_signed_cm_s" in row: + mv = row["meanv_signed_cm_s"] + else: + mv = row.get("meanv_cm_s", 0.0) + refl = row.get("reflux_fraction", 0.0) + ic = row.get("path_ic", 1.0) + print(f" {pidx:>6} {path:>5} {nf:>18.4f} " f"{pv_:>20.3f} {mv:>20.3f} {refl:>7.3f} {ic:>6.3f}") + + +def _format_path_group(v): + if v is None: + return "?" + if isinstance(v, dict): + if "paths" in v: + v = v["paths"] + elif "path_indices" in v: + v = v["path_indices"] + if isinstance(v, (int, np.integer)): + return f"Branch{int(v)}" + if isinstance(v, str): + return v + try: + vals = list(v) + except Exception: + return str(v) + if not vals: + return "-" + out = [] + for x in vals: + if isinstance(x, (int, np.integer)): + out.append(f"Branch{int(x)}") + else: + out.append(str(x)) + return "+".join(out) + + +def _fork_side_text(fork): + if not isinstance(fork, dict): + return "left=?", "right=?" + + left = ( + fork.get("left") + or fork.get("left_paths") + or fork.get("left_path_indices") + or fork.get("in_paths") + or fork.get("in_path_indices") + ) + right = ( + fork.get("right") + or fork.get("right_paths") + or fork.get("right_path_indices") + or fork.get("out_paths") + or fork.get("out_path_indices") + ) + + return f"{_format_path_group(left)}", f"{_format_path_group(right)}" + + +def print_qc_summary(qc_data, forks=None): + if not qc_data: + print(" No QC results to summarize.") + return + + items = [] + if isinstance(qc_data, list): + items = qc_data + elif isinstance(qc_data, dict): + if isinstance(qc_data.get("forks"), list): + items = qc_data["forks"] + elif isinstance(qc_data.get("fork_qc"), list): + items = qc_data["fork_qc"] + else: + for k, v in qc_data.items(): + if isinstance(v, dict): + row = dict(v) + row.setdefault("fork_index", k) + items.append(row) + + rows = [] + for i, item in enumerate(items): + if not isinstance(item, dict): + continue + fork_idx = item.get("fork_index", item.get("fork_id", item.get("fork", i))) + ic = item.get("internal_consistency", item.get("ic", item.get("path_ic", np.nan))) + + fork_obj = None + if isinstance(forks, list): + try: + if isinstance(fork_idx, (int, np.integer)) and 0 <= int(fork_idx) < len(forks): + fork_obj = forks[int(fork_idx)] + elif i < len(forks): + fork_obj = forks[i] + except Exception: + pass + + left_txt, right_txt = _fork_side_text(fork_obj) + rows.append((fork_idx, ic, left_txt, right_txt)) + + if not rows: + print(" No fork-level QC results found.") + print(json.dumps(qc_data, ensure_ascii=False, indent=2)) + return + + print(f" {'Fork':>6} {'Internal Consistency':>24} {'Left':>24} {'Right':>24}") + print(f" {'-'*6} {'-'*24} {'-'*24} {'-'*24}") + for fork_idx, ic, left_txt, right_txt in rows: + ic_str = f"{float(ic):.6f}" if np.isfinite(ic) else "nan" + print(f" {str(fork_idx):>6} {ic_str:>24} {left_txt:>24} {right_txt:>24}") diff --git a/autoflow/ui/__init__.py b/autoflow/ui/__init__.py new file mode 100755 index 0000000..562afdf --- /dev/null +++ b/autoflow/ui/__init__.py @@ -0,0 +1,41 @@ +__all__ = [ + "GraphEditor", + "MainWindow", + "OrthoViewer", + "PlaneEditor", + "SceneController", + "SkeletonEditor", + "launch_gui", + "main", +] + + +def __getattr__(name): + if name in {"MainWindow", "main"}: + from .app import MainWindow, main + + return { + "MainWindow": MainWindow, + "main": main, + }[name] + if name in {"GraphEditor", "PlaneEditor", "SkeletonEditor"}: + from .editors import GraphEditor, PlaneEditor, SkeletonEditor + + return { + "GraphEditor": GraphEditor, + "PlaneEditor": PlaneEditor, + "SkeletonEditor": SkeletonEditor, + }[name] + if name == "launch_gui": + from .launcher import launch_gui + + return launch_gui + if name == "OrthoViewer": + from .ortho_viewer import OrthoViewer + + return OrthoViewer + if name == "SceneController": + from .viewer import SceneController + + return SceneController + raise AttributeError(f"module 'autoflow.ui' has no attribute {name!r}") diff --git a/autoflow/ui/app.py b/autoflow/ui/app.py new file mode 100755 index 0000000..484e712 --- /dev/null +++ b/autoflow/ui/app.py @@ -0,0 +1,1908 @@ +import json +import os +import sys +import time +import traceback +from functools import partial + +import numpy as np +import pyvista as pv +from PyQt5 import QtCore, QtWidgets +from pyvistaqt import QtInteractor +from pyvista import _vtk + +from ..core.models import ObjectKind, StepId, Workspace +from ..core.pipeline import PipelineEngine +from ..algorithms import compute_plane_metrics, apply_internal_consistency_to_metrics, compute_plane_metrics_multithread +from .editors import PlaneEditor, SkeletonEditor +from .ortho_viewer import OrthoViewer +from .viewer import SceneController + + +def _parse_plane_index(data_key): + if not isinstance(data_key, str) or not data_key.startswith("plane_"): + return None + suffix = data_key[len("plane_"):] + if suffix.isdigit(): + return int(suffix) + return None + + +def _parse_path_index(data_key): + if not data_key.startswith("smooth_path_"): + return None + suffix = data_key[len("smooth_path_"):] + try: + return int(suffix) + except ValueError: + return None + + +class MainWindow(QtWidgets.QMainWindow): + def __init__(self): + super().__init__() + self.setWindowTitle("AutoFlow") + self.resize(1800, 980) + self.workspace = Workspace() + self.pipeline = PipelineEngine() + self.scene = None + self._play_timer = QtCore.QTimer(self) + self._play_timer.timeout.connect(self._on_play_tick) + self._edit_mode = None + self._edit_points = None + self._edit_edges = None + self._edit_selected_idx = None + self._edit_edge_mode = False + self._edit_edge_src_idx = None + self._edit_selected_edge_idx = None + self._edit_sel_edge_poly = None + self._edit_sel_edge_actor = None + self._edit_poly = None + self._edit_actor = None + self._edit_edge_poly = None + self._edit_edge_actor = None + self._edit_sel_poly = None + self._edit_sel_actor = None + self._edit_widget = None + self._edit_pick_enabled = False + self._vtk_left_click_obs_id = None + self._vtk_keypress_obs_id = None + self._vtk_point_picker = None + self._edit_overlay_dialog = None + self._edit_info_label = None + self._edit_status_label = None + self._edit_btn_edge = None + self._edit_original_points = None + self._edit_original_edges = None + self._plane_drag_active = False + self._plane_drag_index = None + self._plane_widget_initializing = False + self._plane_drag_metrics_dirty = False + self._selected_plane_index = -1 + self._plane_drag_timer = QtCore.QTimer(self) + self._plane_drag_timer.setSingleShot(True) + self._plane_drag_timer.timeout.connect(lambda: self._recompute_dragged_plane_metrics(persist=False)) + self._build_ui() + self._bind_scene() + self._esc_shortcut = QtWidgets.QShortcut(QtCore.Qt.Key_Escape, self) + self._esc_shortcut.setContext(QtCore.Qt.ApplicationShortcut) + self._esc_shortcut.activated.connect(self._force_exit_edit) + QtCore.QTimer.singleShot(0, self._setup_focus_behavior) + self._refresh_all() + + def _setup_focus_behavior(self): + try: + self.plotter.setFocusPolicy(QtCore.Qt.ClickFocus) + except Exception: + pass + + def _build_ui(self): + self._build_menu() + central = QtWidgets.QWidget() + self.setCentralWidget(central) + root = QtWidgets.QVBoxLayout(central) + root.setContentsMargins(6, 6, 6, 6) + root.setSpacing(6) + splitter = QtWidgets.QSplitter(QtCore.Qt.Horizontal) + root.addWidget(splitter, 1) + left = QtWidgets.QWidget() + left_lay = QtWidgets.QVBoxLayout(left) + left_lay.setContentsMargins(0, 0, 0, 0) + left_lay.setSpacing(6) + self._build_browser(left_lay) + mid = QtWidgets.QWidget() + mid_lay = QtWidgets.QVBoxLayout(mid) + mid_lay.setContentsMargins(0, 0, 0, 0) + mid_lay.setSpacing(6) + self.plotter = QtInteractor(self) + self.plotter.setFocusPolicy(QtCore.Qt.ClickFocus) + mid_splitter = QtWidgets.QSplitter(QtCore.Qt.Vertical) + mid_splitter.addWidget(self.plotter) + step_and_params = QtWidgets.QWidget() + sp_lay = QtWidgets.QVBoxLayout(step_and_params) + sp_lay.setContentsMargins(0, 0, 0, 0) + sp_lay.setSpacing(4) + self._build_step_buttons(sp_lay) + scroll = QtWidgets.QScrollArea() + scroll.setWidgetResizable(True) + params_widget = QtWidgets.QWidget() + self.params_layout = QtWidgets.QVBoxLayout(params_widget) + self.params_layout.setContentsMargins(4, 4, 4, 4) + self._build_preprocess_params() + self._build_skeleton_params() + self._build_plane_params() + self._build_streamline_params() + self._build_derived_params() + self.params_layout.addStretch() + scroll.setWidget(params_widget) + sp_lay.addWidget(scroll, 1) + mid_splitter.addWidget(step_and_params) + mid_splitter.setStretchFactor(0, 4) + mid_splitter.setStretchFactor(1, 2) + mid_lay.addWidget(mid_splitter, 1) + right = QtWidgets.QWidget() + right_lay = QtWidgets.QVBoxLayout(right) + right_lay.setContentsMargins(0, 0, 0, 0) + self.ortho_viewer = OrthoViewer(self.workspace, self) + right_lay.addWidget(self.ortho_viewer) + splitter.addWidget(left) + splitter.addWidget(mid) + splitter.addWidget(right) + splitter.setSizes([300, 850, 450]) + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 4) + splitter.setStretchFactor(2, 2) + bot = QtWidgets.QWidget() + bot_lay = QtWidgets.QVBoxLayout(bot) + bot_lay.setContentsMargins(0, 0, 0, 0) + bot_lay.setSpacing(4) + bot_splitter = QtWidgets.QSplitter(QtCore.Qt.Vertical) + root.addWidget(bot_splitter, 0) + timeline_w = QtWidgets.QWidget() + tl_lay = QtWidgets.QVBoxLayout(timeline_w) + tl_lay.setContentsMargins(0, 0, 0, 0) + self._build_timeline(tl_lay) + bot_splitter.addWidget(timeline_w) + sel_w = QtWidgets.QWidget() + sel_lay = QtWidgets.QVBoxLayout(sel_w) + sel_lay.setContentsMargins(0, 0, 0, 0) + self._build_selection_info(sel_lay) + bot_splitter.addWidget(sel_w) + log_w = QtWidgets.QWidget() + log_lay = QtWidgets.QVBoxLayout(log_w) + log_lay.setContentsMargins(0, 0, 0, 0) + self._build_log(log_lay) + bot_splitter.addWidget(log_w) + bot_splitter.setSizes([40, 80, 60]) + + def _build_browser(self, parent): + grp = QtWidgets.QGroupBox("Browser") + lay = QtWidgets.QVBoxLayout(grp) + self.tree_objects = QtWidgets.QTreeWidget() + self.tree_objects.setHeaderLabels(["Name", "Kind", "Visible"]) + self.tree_objects.setColumnWidth(0, 200) + self.tree_objects.itemSelectionChanged.connect(self._on_browser_select) + self.tree_objects.itemChanged.connect(self._on_tree_item_changed) + self.tree_objects.setContextMenuPolicy(QtCore.Qt.CustomContextMenu) + self.tree_objects.customContextMenuRequested.connect(self._on_browser_ctx_menu) + lay.addWidget(self.tree_objects) + row = QtWidgets.QHBoxLayout() + self.btn_delete_obj = QtWidgets.QPushButton("Delete Selected") + self.btn_delete_obj.clicked.connect(self._on_delete_object) + row.addWidget(self.btn_delete_obj) + row.addStretch() + lay.addLayout(row) + parent.addWidget(grp, 3) + + def _build_step_buttons(self, parent): + grp = QtWidgets.QGroupBox("Steps") + gl = QtWidgets.QGridLayout(grp) + self.step_buttons = {} + for row_idx, steps in enumerate([StepId.top_row_steps(), StepId.bottom_row_steps(), StepId.extra_row_steps()]): + for i, s in enumerate(steps): + b = QtWidgets.QPushButton(s.label) + b.clicked.connect(partial(self._run_single_step, s)) + self.step_buttons[s] = b + gl.addWidget(b, row_idx, i) + btn_run_all = QtWidgets.QPushButton("▶▶ Run All (Generate → Metrics → WSS/TKE)") + btn_run_all.setStyleSheet("QPushButton { background-color: #2a6; color: white; font-weight: bold; padding: 4px; }") + btn_run_all.clicked.connect(self._run_all_pipeline) + gl.addWidget(btn_run_all, 3, 0, 1, 4) + parent.addWidget(grp, 0) + + def _build_preprocess_params(self): + return + + def _build_skeleton_params(self): + grp = QtWidgets.QGroupBox("Generate Skeleton Parameters") + fl = QtWidgets.QFormLayout(grp) + self.chk_remove_small_cc = QtWidgets.QCheckBox() + self.chk_remove_small_cc.setChecked(False) + self.edit_min_cc_volume = QtWidgets.QLineEdit("50.0") + self.chk_closing = QtWidgets.QCheckBox() + self.chk_closing.setChecked(True) + self.chk_opening = QtWidgets.QCheckBox() + self.chk_gaussian = QtWidgets.QCheckBox() + self.chk_gaussian.setChecked(True) + self.edit_gauss_sigma = QtWidgets.QLineEdit("0.5") + fl.addRow("Remove Small CC", self.chk_remove_small_cc) + fl.addRow(u"Min Volume (mm\u00b3)", self.edit_min_cc_volume) + fl.addRow("Closing", self.chk_closing) + fl.addRow("Opening", self.chk_opening) + fl.addRow("Gaussian", self.chk_gaussian) + fl.addRow(u"Gauss \u03c3", self.edit_gauss_sigma) + self.params_layout.addWidget(grp) + + def _build_plane_params(self): + grp = QtWidgets.QGroupBox("Generate Planes Parameters") + fl = QtWidgets.QFormLayout(grp) + self.radio_plane_by_distance = QtWidgets.QRadioButton("By Distance") + self.radio_plane_center = QtWidgets.QRadioButton("Center of Path") + self.radio_plane_center.setChecked(True) + mode_row = QtWidgets.QHBoxLayout() + mode_row.addWidget(self.radio_plane_by_distance) + mode_row.addWidget(self.radio_plane_center) + mode_w = QtWidgets.QWidget() + mode_w.setLayout(mode_row) + self.edit_plane_dist = QtWidgets.QLineEdit("20.0") + self.edit_plane_start = QtWidgets.QLineEdit("5.0") + self.edit_plane_end = QtWidgets.QLineEdit("0.0") + self.edit_plane_smooth_win = QtWidgets.QLineEdit("15") + self.edit_plane_smooth_poly = QtWidgets.QLineEdit("2") + self.edit_plane_inter_time = QtWidgets.QLineEdit("10") + fl.addRow("Plane Mode", mode_w) + fl.addRow("Cross-section Distance (mm)", self.edit_plane_dist) + fl.addRow("Start Distance (mm)", self.edit_plane_start) + fl.addRow("End Distance (mm)", self.edit_plane_end) + fl.addRow("SavGol Window", self.edit_plane_smooth_win) + fl.addRow("SavGol Polyorder", self.edit_plane_smooth_poly) + fl.addRow("Inter-time", self.edit_plane_inter_time) + self.params_layout.addWidget(grp) + + def _build_streamline_params(self): + grp = QtWidgets.QGroupBox("Streamline Parameters") + fl = QtWidgets.QFormLayout(grp) + self.edit_sl_ratio = QtWidgets.QLineEdit("0.02") + self.edit_sl_maxsteps = QtWidgets.QLineEdit("2000") + self.edit_sl_terminal = QtWidgets.QLineEdit("0.01") + fl.addRow("Seed Ratio", self.edit_sl_ratio) + fl.addRow("Max Steps", self.edit_sl_maxsteps) + fl.addRow("Terminal Speed", self.edit_sl_terminal) + self.params_layout.addWidget(grp) + + def _build_derived_params(self): + grp_wss = QtWidgets.QGroupBox("WSS Parameters") + fl_wss = QtWidgets.QFormLayout(grp_wss) + self.edit_dm_smoothing = QtWidgets.QLineEdit("200") + self.edit_dm_viscosity = QtWidgets.QLineEdit("4.0") + self.edit_dm_inward = QtWidgets.QLineEdit("0.6") + self.chk_dm_parabolic = QtWidgets.QCheckBox() + self.chk_dm_parabolic.setChecked(True) + self.chk_dm_noslip = QtWidgets.QCheckBox() + self.chk_dm_noslip.setChecked(True) + fl_wss.addRow("Smoothing Iterations", self.edit_dm_smoothing) + fl_wss.addRow(u"Viscosity (mPa\u00b7s)", self.edit_dm_viscosity) + fl_wss.addRow("Inward Distance (mm)", self.edit_dm_inward) + fl_wss.addRow("Parabolic Fitting", self.chk_dm_parabolic) + fl_wss.addRow("No-Slip Condition", self.chk_dm_noslip) + self.params_layout.addWidget(grp_wss) + + grp_tke = QtWidgets.QGroupBox("TKE / Flow Parameters") + fl_tke = QtWidgets.QFormLayout(grp_tke) + self.edit_dm_rho = QtWidgets.QLineEdit("1060.0") + self.edit_dm_stepsize = QtWidgets.QLineEdit("5") + self.edit_dm_tube = QtWidgets.QLineEdit("0.1") + self.chk_dm_multithread = QtWidgets.QCheckBox() + self.chk_dm_multithread.setChecked(False) + fl_tke.addRow(u"Density \u03c1 (kg/m\u00b3)", self.edit_dm_rho) + fl_tke.addRow("Step Size", self.edit_dm_stepsize) + fl_tke.addRow("Tube Radius", self.edit_dm_tube) + fl_tke.addRow("Multi-thread Metrics", self.chk_dm_multithread) + self.params_layout.addWidget(grp_tke) + + def _build_timeline(self, parent): + grp = QtWidgets.QGroupBox("Timeline") + tl = QtWidgets.QHBoxLayout(grp) + self.btn_prev = QtWidgets.QPushButton(u"\u25c0") + self.btn_prev.clicked.connect(self._on_prev_frame) + self.btn_play = QtWidgets.QPushButton(u"\u25b6 Play") + self.btn_play.clicked.connect(self._on_play) + self.btn_pause = QtWidgets.QPushButton(u"\u23f8 Pause") + self.btn_pause.clicked.connect(self._on_pause) + self.btn_next = QtWidgets.QPushButton(u"\u25b6") + self.btn_next.clicked.connect(self._on_next_frame) + self.slider_t = QtWidgets.QSlider(QtCore.Qt.Horizontal) + self.slider_t.setRange(0, 0) + self.slider_t.valueChanged.connect(self._on_t_changed) + self.lab_t = QtWidgets.QLabel("0") + self.spin_interval = QtWidgets.QSpinBox() + self.spin_interval.setRange(10, 2000) + self.spin_interval.setValue(120) + self.spin_interval.setSuffix(" ms") + for w in [self.btn_prev, self.btn_play, self.btn_pause, self.btn_next]: + tl.addWidget(w) + tl.addWidget(self.slider_t, 1) + tl.addWidget(self.lab_t) + tl.addWidget(self.spin_interval) + parent.addWidget(grp) + + def _build_selection_info(self, parent): + grp = QtWidgets.QGroupBox("Selection") + lay = QtWidgets.QHBoxLayout(grp) + box_plane = QtWidgets.QGroupBox("Plane") + lay_plane = QtWidgets.QVBoxLayout(box_plane) + self.text_plane_info = QtWidgets.QPlainTextEdit() + self.text_plane_info.setReadOnly(True) + self.text_plane_info.setMaximumHeight(82) + lay_plane.addWidget(self.text_plane_info) + box_path = QtWidgets.QGroupBox("Path") + lay_path = QtWidgets.QVBoxLayout(box_path) + self.text_path_info = QtWidgets.QPlainTextEdit() + self.text_path_info.setReadOnly(True) + self.text_path_info.setMaximumHeight(82) + lay_path.addWidget(self.text_path_info) + lay.addWidget(box_plane, 1) + lay.addWidget(box_path, 1) + parent.addWidget(grp) + + def _build_log(self, parent): + grp = QtWidgets.QGroupBox("Log") + ll = QtWidgets.QVBoxLayout(grp) + self.console = QtWidgets.QTextEdit() + self.console.setReadOnly(True) + self.console.setMaximumHeight(80) + ll.addWidget(self.console) + parent.addWidget(grp) + + def _build_menu(self): + mb = self.menuBar() + mf = mb.addMenu("File") + for label, slot in [("Open Data", self._on_open_data), ("Clear Workspace", self._on_close_workspace), ("Exit", self.close)]: + a = QtWidgets.QAction(label, self) + a.triggered.connect(slot) + mf.addAction(a) + mv = mb.addMenu("View") + for label, slot in [("Reset Camera", lambda: self.scene.reset_camera()), ("Toggle Axes", lambda: self.scene.toggle_axes()), + ("White BG", lambda: self.scene.set_background("white")), ("Dark BG", lambda: self.scene.set_background("#202124"))]: + a = QtWidgets.QAction(label, self) + a.triggered.connect(slot) + mv.addAction(a) + + def _bind_scene(self): + self.scene = SceneController(self.plotter, self.workspace, self.log) + self.scene.initialize() + self.scene.enable_plane_picking(self._on_3d_plane_picked) + self.scene.enable_path_picking(self._on_3d_path_picked) + + def _on_3d_plane_picked(self, uid, plane_idx): + if self._edit_mode is not None: + return + if uid is None or plane_idx is None: + self.workspace.selected_path_index = -1 + self._selected_plane_index = -1 + self._clear_plane_drag_widgets() + self.scene.highlight_plane(None) + self.scene.highlight_path(None) + self.scene.show_forks_for_path(-1) + self.ortho_viewer.set_selected_plane(None) + self._clear_browser_selection() + self._set_plane_info_text("") + self._set_path_info_text("") + return + self.workspace.selected_path_index = -1 + self._selected_plane_index = int(plane_idx) + self.scene.highlight_path(None) + self.scene.show_forks_for_path(-1) + self.scene.highlight_plane(uid) + self.ortho_viewer.set_selected_plane(int(plane_idx)) + self._select_browser_item_by_uid(uid) + self._activate_plane_drag_widgets(int(plane_idx)) + self._set_path_info_text("") + self._log_selected_plane_metric(int(plane_idx)) + + def _on_3d_path_picked(self, uid, path_idx): + if self._edit_mode is not None: + return + if uid is None or path_idx is None: + self.workspace.selected_path_index = -1 + self._selected_plane_index = -1 + self._clear_plane_drag_widgets() + self.scene.highlight_plane(None) + self.scene.highlight_path(None) + self.scene.show_forks_for_path(-1) + self.ortho_viewer.set_selected_plane(None) + self._clear_browser_selection() + self._set_plane_info_text("") + self._set_path_info_text("") + return + self.workspace.selected_path_index = int(path_idx) + self._selected_plane_index = -1 + self._clear_plane_drag_widgets() + self.scene.highlight_plane(None) + self.scene.highlight_path(uid) + self.scene.show_forks_for_path(int(path_idx)) + self.ortho_viewer.set_selected_plane(None) + self._select_browser_item_by_uid(uid) + self._set_plane_info_text("") + self._log_selected_path_info(int(path_idx)) + + def _find_uid_by_data_key(self, data_key): + for uid, obj in self.workspace.scene_objects.items(): + if obj.data_key == data_key: + return uid + return None + + def _plane_widget_distance(self): + spacing = self._get_spacing_xyz_from_resolution() + return max(5.0, float(np.mean(spacing)) * 8.0) + + def _clear_plane_drag_widgets(self): + self._plane_drag_timer.stop() + self._plane_drag_active = False + self._plane_drag_index = None + self._plane_widget_initializing = False + self._plane_drag_metrics_dirty = False + if self._edit_mode is not None: + return + try: + if hasattr(self.plotter, "clear_sphere_widgets"): + self.plotter.clear_sphere_widgets() + except Exception: + pass + + def _update_plane_from_drag(self, plane_idx, center=None, normal=None): + if self._plane_widget_initializing: + return + if not (0 <= int(plane_idx) < len(self.workspace.planes)): + return + plane = self.workspace.planes[int(plane_idx)] + tol = max(1e-4, float(np.mean(self._get_spacing_xyz_from_resolution())) * 1e-3) + changed = False + if center is not None: + c = np.asarray(center, dtype=float).reshape(3) + if np.linalg.norm(c - np.asarray(plane.center, dtype=float).reshape(3)) > tol: + plane.center = c + changed = True + if normal is not None: + n = np.asarray(normal, dtype=float).reshape(3) + if np.linalg.norm(n) > 1e-12: + new_normal = n / np.linalg.norm(n) + old_normal = np.asarray(plane.normal, dtype=float).reshape(3) + if min(np.linalg.norm(new_normal - old_normal), np.linalg.norm(new_normal + old_normal)) > 1e-5: + plane.normal = new_normal + changed = True + if not changed: + return + self.scene.invalidate_cache("plane_") + uid = self._find_uid_by_data_key(f"plane_{int(plane_idx)}") + if uid is not None: + obj = self.workspace.scene_objects.get(uid) + if obj is not None: + self.scene.readd_object(obj) + self.scene.highlight_plane(uid) + else: + self.scene.sync_from_workspace() + self._selected_plane_index = int(plane_idx) + self.ortho_viewer._selected_plane_idx = int(plane_idx) + self.ortho_viewer.refresh() + self._plane_drag_index = int(plane_idx) + self._plane_drag_metrics_dirty = True + self._plane_drag_timer.start(250) + try: + self.plotter.render() + except Exception: + pass + + def _persist_plane_outputs(self): + out_dir = self.pipeline._output_dir(self.workspace) + metrics = self.workspace.derived.plane_metrics + qc = self.workspace.derived.plane_qc + if metrics and len(metrics) == len(self.workspace.planes): + with open(os.path.join(out_dir, "plane_metrics.json"), "w", encoding="utf-8") as f: + json.dump(metrics, f, ensure_ascii=False, indent=2) + if qc: + with open(os.path.join(out_dir, "plane_qc.json"), "w", encoding="utf-8") as f: + json.dump(qc, f, ensure_ascii=False, indent=2) + try: + self.pipeline._save_planes_json(self.workspace) + except Exception: + pass + + def _finalize_plane_drag(self, plane_idx): + if not (0 <= int(plane_idx) < len(self.workspace.planes)): + return + self._plane_drag_timer.stop() + self._plane_drag_index = int(plane_idx) + if self.workspace.flow_raw is None or self.workspace.segmask_binary is None: + self._persist_plane_outputs() + self._plane_drag_metrics_dirty = False + return + if self._plane_drag_metrics_dirty or len(self.workspace.derived.plane_metrics) != len(self.workspace.planes): + self._recompute_dragged_plane_metrics(persist=True) + else: + self._persist_plane_outputs() + + def _activate_plane_drag_widgets(self, plane_idx): + if self._edit_mode is not None or not (0 <= int(plane_idx) < len(self.workspace.planes)): + return + if self._plane_drag_active and int(self._plane_drag_index) == int(plane_idx): + return + self._clear_plane_drag_widgets() + plane = self.workspace.planes[int(plane_idx)] + center = np.asarray(plane.center, dtype=float).reshape(3) + normal = np.asarray(plane.normal, dtype=float).reshape(3) + if np.linalg.norm(normal) <= 1e-12: + normal = np.array([1.0, 0.0, 0.0], dtype=float) + normal = normal / np.linalg.norm(normal) + tip = center + normal * self._plane_widget_distance() + radius = self._edit_widget_radius() + + def _center_cb(new_center): + self._update_plane_from_drag(plane_idx, center=new_center) + + def _normal_cb(new_tip): + c = np.asarray(self.workspace.planes[int(plane_idx)].center, dtype=float).reshape(3) + tip_now = np.asarray(new_tip, dtype=float).reshape(3) + self._update_plane_from_drag(plane_idx, normal=(tip_now - c)) + + def _end_cb(_widget, _event): + self._finalize_plane_drag(plane_idx) + + try: + self._plane_widget_initializing = True + center_widget = self.plotter.add_sphere_widget( + callback=_center_cb, + center=tuple(center.tolist()), + radius=radius, + color="cyan", + interaction_event="always", + ) + normal_widget = self.plotter.add_sphere_widget( + callback=_normal_cb, + center=tuple(tip.tolist()), + radius=radius, + color="orange", + interaction_event="always", + ) + center_widget.AddObserver(_vtk.vtkCommand.EndInteractionEvent, _end_cb) + normal_widget.AddObserver(_vtk.vtkCommand.EndInteractionEvent, _end_cb) + self._plane_drag_active = True + self._plane_drag_index = int(plane_idx) + except Exception as e: + self._plane_drag_active = False + self._plane_drag_index = None + self.log(f"Plane drag widget error: {type(e).__name__}: {e}") + finally: + self._plane_widget_initializing = False + + def _recompute_dragged_plane_metrics(self, persist=False): + if self._plane_drag_index is None or not (0 <= int(self._plane_drag_index) < len(self.workspace.planes)): + return + if self.workspace.flow_raw is None or self.workspace.segmask_binary is None: + self.ortho_viewer.refresh() + if persist: + self._persist_plane_outputs() + self._plane_drag_metrics_dirty = False + return + plane_idx = int(self._plane_drag_index) + try: + if len(self.workspace.derived.plane_metrics) != len(self.workspace.planes): + self.pipeline._compute_plane_metrics_internal(self.workspace, save=persist) + self.scene.invalidate_cache("plane_") + self.scene.sync_from_workspace() + else: + paths_for_tangent = ( + self.workspace.centerline_paths_smooth + if len(self.workspace.centerline_paths_smooth) > 0 + else self.workspace.centerline_paths + ) + partial_metrics = compute_plane_metrics( + self.workspace.flow_raw, + self.workspace.segmask_binary, + self.workspace.resolution, + self.workspace.origin, + [self.workspace.planes[plane_idx]], + RR=self.workspace.rr, + branch_labels_3d=self.workspace.branch_labels, + path_info=self.workspace.path_info, + forks=self.workspace.forks, + paths=paths_for_tangent, + return_qc=False, + ) + if partial_metrics: + metrics = [dict(m) for m in self.workspace.derived.plane_metrics] + metrics[plane_idx] = dict(partial_metrics[0]) + metrics, qc = apply_internal_consistency_to_metrics(metrics, path_info=self.workspace.path_info, forks=self.workspace.forks) + self.workspace.derived.plane_metrics = metrics + self.workspace.derived.plane_qc = qc + for i, metric in enumerate(metrics): + if i < len(self.workspace.planes): + self.workspace.planes[i].metrics = dict(metric) + if persist: + self._persist_plane_outputs() + if persist: + self._persist_plane_outputs() + self._selected_plane_index = plane_idx + self.ortho_viewer._selected_plane_idx = plane_idx + self.ortho_viewer.refresh() + self._log_selected_plane_metric(plane_idx) + self._plane_drag_metrics_dirty = False + except Exception as e: + self.log(f"Plane metric update error: {type(e).__name__}: {e}") + self.log(traceback.format_exc()) + + def _log_selected_plane_metric(self, plane_idx): + if not (0 <= int(plane_idx) < len(self.workspace.planes)): + self._set_plane_info_text("") + return + plane = self.workspace.planes[int(plane_idx)] + metric = getattr(plane, "metrics", {}) or {} + t = int(np.clip(self.workspace.current_t, 0, max(0, self.workspace.time_count() - 1))) + fr = metric.get("flowrate_mL_s", []) + ar = metric.get("area_mm2", []) + mv = metric.get("meanv_cm_s_t", []) + flow_t = float(fr[t]) if len(fr) > t else 0.0 + area_t = float(ar[t]) if len(ar) > t else 0.0 + meanv_t = float(mv[t]) if len(mv) > t else float(metric.get("meanv_cm_s", 0.0)) + path_dir = metric.get("path_direction", "") + header = f"Plane {int(plane_idx)} | Path {int(metric.get('path_index', plane.path_index))}" + if path_dir: + header += f" {path_dir}" + text_block = ( + f"{header}\n" + f"t={t} Flow Rate={flow_t:.4f} mL/s Area={area_t:.3f} mm^2 Mean Velocity={meanv_t:.3f} cm/s\n" + f"Peak Velocity={float(metric.get('peakv_cm_s', 0.0)):.3f} cm/s Net Flow={float(metric.get('netflow_mL_beat', 0.0)):.4f} mL/beat IC={float(metric.get('path_ic', 1.0)):.3f}" + ) + self._set_plane_info_text(text_block) + + def _log_selected_path_info(self, path_idx): + if not (0 <= int(path_idx) < len(self.workspace.path_info)): + self._set_path_info_text("") + return + info = self.workspace.path_info[int(path_idx)] + incoming = [int(x) for x in info.get("incoming_path_ids", [])] + outgoing = [int(x) for x in info.get("outgoing_path_ids", [])] + forks = [] + for fork in self.workspace.forks: + if int(path_idx) in fork.get("left", []) or int(path_idx) in fork.get("right", []): + forks.append(f"node={int(fork.get('node', -1))} L={fork.get('left', [])} R={fork.get('right', [])}") + fork_txt = " ; ".join(forks) if forks else "none" + text_block = ( + f"Path {int(path_idx)} | dir={info.get('direction_text', '')}\n" + f"start_node={int(info.get('start_node', -1))} end_node={int(info.get('end_node', -1))}\n" + f"incoming: {incoming if incoming else 'none'} outgoing: {outgoing if outgoing else 'none'}\n" + f"forks: {fork_txt}" + ) + self._set_path_info_text(text_block) + + def _select_browser_item_by_uid(self, uid): + self.tree_objects.blockSignals(True) + for i in range(self.tree_objects.topLevelItemCount()): + top = self.tree_objects.topLevelItem(i) + for j in range(top.childCount()): + child = top.child(j) + if child.data(0, QtCore.Qt.UserRole) == uid: + self.tree_objects.setCurrentItem(child) + self.tree_objects.blockSignals(False) + return + self.tree_objects.blockSignals(False) + + def _clear_browser_selection(self): + self.tree_objects.blockSignals(True) + self.tree_objects.clearSelection() + self.tree_objects.blockSignals(False) + + def _set_plane_info_text(self, text): + msg = str(text).strip() if text else "No plane selected." + self.text_plane_info.setPlainText(msg) + + def _set_path_info_text(self, text): + msg = str(text).strip() if text else "No path selected." + self.text_path_info.setPlainText(msg) + + def _refresh_selection_info(self): + if not (0 <= int(self._selected_plane_index) < len(self.workspace.planes)): + self._selected_plane_index = -1 + self._set_plane_info_text("") + else: + self._log_selected_plane_metric(int(self._selected_plane_index)) + path_idx = int(getattr(self.workspace, "selected_path_index", -1)) + if not (0 <= path_idx < len(self.workspace.path_info)): + self.workspace.selected_path_index = -1 + self._set_path_info_text("") + else: + self._log_selected_path_info(path_idx) + + def log(self, text): + self.console.append(str(text)) + + def _float_from_text(self, text, default=0.0): + try: + return float(text) + except Exception: + return default + + def _int_from_text(self, text, default=0): + try: + return int(text) + except Exception: + return default + + def _parse_int_list(self, text): + r = [] + for tok in text.replace(";", ",").split(","): + tok = tok.strip() + if tok: + try: + r.append(int(tok)) + except ValueError: + pass + return r + + def _sync_params_to_ws(self): + ws = self.workspace + ws.skeleton_params.remove_small_cc = self.chk_remove_small_cc.isChecked() + ws.skeleton_params.min_cc_volume_mm3 = self._float_from_text(self.edit_min_cc_volume.text(), 50.0) + ws.skeleton_params.do_closing = self.chk_closing.isChecked() + ws.skeleton_params.do_opening = self.chk_opening.isChecked() + ws.skeleton_params.gaussian_enabled = self.chk_gaussian.isChecked() + ws.skeleton_params.gaussian_sigma = self._float_from_text(self.edit_gauss_sigma.text(), 0.5) + ws.plane_gen_params.use_center_plane = self.radio_plane_center.isChecked() + ws.plane_gen_params.cross_section_distance = self._float_from_text(self.edit_plane_dist.text(), 20.0) + ws.plane_gen_params.start_distance = self._float_from_text(self.edit_plane_start.text(), 5.0) + ws.plane_gen_params.end_distance = self._float_from_text(self.edit_plane_end.text(), 0.0) + ws.plane_gen_params.smoothing_window = self._int_from_text(self.edit_plane_smooth_win.text(), 15) + ws.plane_gen_params.smoothing_polyorder = self._int_from_text(self.edit_plane_smooth_poly.text(), 3) + ws.plane_gen_params.inter_time = self._int_from_text(self.edit_plane_inter_time.text(), 10) + ws.streamline_params.seed_ratio = min(max(self._float_from_text(self.edit_sl_ratio.text(), 0.02), 0.0001), 1.0) + ws.streamline_params.max_steps = min(max(self._int_from_text(self.edit_sl_maxsteps.text(), 2000), 1), 200000) + ws.streamline_params.terminal_speed = min(max(self._float_from_text(self.edit_sl_terminal.text(), 0.01), 0.0), 1e6) + ws.streamline_params.min_seeds = 50 + ws.derived_params.smoothing_iteration = max(self._int_from_text(self.edit_dm_smoothing.text(), 200), 0) + ws.derived_params.viscosity = max(self._float_from_text(self.edit_dm_viscosity.text(), 4.0), 0.0) + ws.derived_params.inward_distance = max(self._float_from_text(self.edit_dm_inward.text(), 0.6), 0.01) + ws.derived_params.parabolic_fitting = self.chk_dm_parabolic.isChecked() + ws.derived_params.no_slip_condition = self.chk_dm_noslip.isChecked() + ws.derived_params.rho = max(self._float_from_text(self.edit_dm_rho.text(), 1060.0), 1.0) + ws.derived_params.step_size = max(self._int_from_text(self.edit_dm_stepsize.text(), 5), 1) + ws.derived_params.tube_radius = max(self._float_from_text(self.edit_dm_tube.text(), 0.1), 0.0) + ws.derived_params.use_multithread = self.chk_dm_multithread.isChecked() + + def _sync_params_to_ui(self): + ws = self.workspace + self.chk_remove_small_cc.setChecked(ws.skeleton_params.remove_small_cc) + self.edit_min_cc_volume.setText(str(ws.skeleton_params.min_cc_volume_mm3)) + self.chk_closing.setChecked(ws.skeleton_params.do_closing) + self.chk_opening.setChecked(ws.skeleton_params.do_opening) + self.chk_gaussian.setChecked(ws.skeleton_params.gaussian_enabled) + self.edit_gauss_sigma.setText(str(ws.skeleton_params.gaussian_sigma)) + self.radio_plane_center.setChecked(ws.plane_gen_params.use_center_plane) + self.radio_plane_by_distance.setChecked(not ws.plane_gen_params.use_center_plane) + self.edit_plane_dist.setText(str(ws.plane_gen_params.cross_section_distance)) + self.edit_plane_start.setText(str(ws.plane_gen_params.start_distance)) + self.edit_plane_end.setText(str(ws.plane_gen_params.end_distance)) + self.edit_plane_smooth_win.setText(str(ws.plane_gen_params.smoothing_window)) + self.edit_plane_smooth_poly.setText(str(ws.plane_gen_params.smoothing_polyorder)) + self.edit_plane_inter_time.setText(str(ws.plane_gen_params.inter_time)) + self.edit_sl_ratio.setText(str(ws.streamline_params.seed_ratio)) + self.edit_sl_maxsteps.setText(str(ws.streamline_params.max_steps)) + self.edit_sl_terminal.setText(str(ws.streamline_params.terminal_speed)) + self.edit_dm_smoothing.setText(str(ws.derived_params.smoothing_iteration)) + self.edit_dm_viscosity.setText(str(ws.derived_params.viscosity)) + self.edit_dm_inward.setText(str(ws.derived_params.inward_distance)) + self.chk_dm_parabolic.setChecked(ws.derived_params.parabolic_fitting) + self.chk_dm_noslip.setChecked(ws.derived_params.no_slip_condition) + self.edit_dm_rho.setText(str(ws.derived_params.rho)) + self.edit_dm_stepsize.setText(str(ws.derived_params.step_size)) + self.edit_dm_tube.setText(str(ws.derived_params.tube_radius)) + self.chk_dm_multithread.setChecked(ws.derived_params.use_multithread) + + def _rebuild_plane_objects(self): + self._clear_plane_drag_widgets() + ws = self.workspace + ws.remove_objects_by_prefix("plane_") + for i in range(len(ws.planes)): + ws.add_object(name=f"Plane {i}", kind=ObjectKind.PLANE, + data_key=f"plane_{i}", visible=True, opacity=0.6, + color="yellow", line_width=2) + self.scene.invalidate_cache("plane_") + self.scene.sync_from_workspace() + + def _refresh_all(self): + self._refresh_browser() + self._refresh_timeline() + self._sync_params_to_ui() + self._refresh_selection_info() + self._refresh_scene() + + def _refresh_browser(self): + self.tree_objects.blockSignals(True) + self.tree_objects.clear() + groups = {} + for obj in self.workspace.scene_objects.values(): + if obj.data_key == "branch_surface": + continue + kn = obj.kind.value + if kn not in groups: + top = QtWidgets.QTreeWidgetItem([kn, "", ""]) + top.setFlags(top.flags() | QtCore.Qt.ItemIsUserCheckable) + top.setFlags(top.flags() & ~QtCore.Qt.ItemIsSelectable) + top.setCheckState(0, QtCore.Qt.Checked) + groups[kn] = top + self.tree_objects.addTopLevelItem(top) + it = QtWidgets.QTreeWidgetItem([obj.name, obj.kind.value, ""]) + it.setData(0, QtCore.Qt.UserRole, obj.uid) + it.setFlags(it.flags() | QtCore.Qt.ItemIsUserCheckable) + it.setCheckState(0, QtCore.Qt.Checked if obj.visible else QtCore.Qt.Unchecked) + groups[kn].addChild(it) + for kn, top in groups.items(): + vis_count = sum(1 for i in range(top.childCount()) if top.child(i).checkState(0) == QtCore.Qt.Checked) + total = top.childCount() + if vis_count == total: + top.setCheckState(0, QtCore.Qt.Checked) + elif vis_count == 0: + top.setCheckState(0, QtCore.Qt.Unchecked) + else: + top.setCheckState(0, QtCore.Qt.PartiallyChecked) + self.tree_objects.expandAll() + self.tree_objects.blockSignals(False) + + def _selected_uid(self): + items = self.tree_objects.selectedItems() + if not items: + return None + return items[0].data(0, QtCore.Qt.UserRole) + + def _on_browser_select(self): + uid = self._selected_uid() + if uid: + obj = self.workspace.scene_objects.get(uid) + if obj and obj.kind == ObjectKind.PLANE: + pidx = _parse_plane_index(obj.data_key) + self.workspace.selected_path_index = -1 + self.scene.highlight_path(None) + self.scene.show_forks_for_path(-1) + if pidx is not None: + self._selected_plane_index = int(pidx) + self.ortho_viewer.set_selected_plane(int(pidx)) + self._activate_plane_drag_widgets(int(pidx)) + self._set_path_info_text("") + self._log_selected_plane_metric(int(pidx)) + self.scene.highlight_plane(uid) + elif obj and obj.kind == ObjectKind.BRANCH: + pidx = _parse_path_index(obj.data_key) + self._selected_plane_index = -1 + self.scene.highlight_plane(None) + self._clear_plane_drag_widgets() + self.scene.highlight_path(uid) + self._set_plane_info_text("") + self.ortho_viewer.set_selected_plane(None) + if pidx is not None: + self.workspace.selected_path_index = int(pidx) + self.scene.show_forks_for_path(int(pidx)) + self._log_selected_path_info(int(pidx)) + else: + self.workspace.selected_path_index = -1 + self.scene.show_forks_for_path(-1) + else: + self._selected_plane_index = -1 + self.workspace.selected_path_index = -1 + self._clear_plane_drag_widgets() + self.scene.highlight_plane(None) + self.scene.highlight_path(None) + self.scene.show_forks_for_path(-1) + self.ortho_viewer.set_selected_plane(None) + self._refresh_selection_info() + else: + self._selected_plane_index = -1 + self.workspace.selected_path_index = -1 + self._clear_plane_drag_widgets() + self.scene.highlight_plane(None) + self.scene.highlight_path(None) + self.scene.show_forks_for_path(-1) + self.ortho_viewer.set_selected_plane(None) + self._refresh_selection_info() + + def _on_tree_item_changed(self, item, column): + uid = item.data(0, QtCore.Qt.UserRole) + if uid: + obj = self.workspace.scene_objects.get(uid) + if obj: + obj.visible = item.checkState(0) == QtCore.Qt.Checked + self.scene.apply_object_properties(obj) + else: + checked = item.checkState(0) != QtCore.Qt.Unchecked + self.tree_objects.blockSignals(True) + for i in range(item.childCount()): + child = item.child(i) + child.setCheckState(0, QtCore.Qt.Checked if checked else QtCore.Qt.Unchecked) + cuid = child.data(0, QtCore.Qt.UserRole) + if cuid: + obj = self.workspace.scene_objects.get(cuid) + if obj: + obj.visible = checked + self.scene.apply_object_properties(obj) + self.tree_objects.blockSignals(False) + self._refresh_scene() + + def _on_browser_ctx_menu(self, pos): + item = self.tree_objects.itemAt(pos) + if item is None: + return + uid = item.data(0, QtCore.Qt.UserRole) + menu = QtWidgets.QMenu(self) + if uid is None: + act_show = menu.addAction("Show All") + act_hide = menu.addAction("Hide All") + act_del_all = menu.addAction("Delete All") + action = menu.exec_(self.tree_objects.viewport().mapToGlobal(pos)) + if action == act_show: + self._set_group_vis(item, True) + elif action == act_hide: + self._set_group_vis(item, False) + elif action == act_del_all: + self.tree_objects.setCurrentItem(item) + self._on_delete_object() + else: + act_toggle = menu.addAction("Toggle Visibility") + act_del = menu.addAction("Delete") + obj = self.workspace.scene_objects.get(uid) + act_plane_sl = None + if obj and obj.kind == ObjectKind.PLANE: + act_plane_sl = menu.addAction("Streamlines from Plane") + action = menu.exec_(self.tree_objects.viewport().mapToGlobal(pos)) + if action == act_toggle: + if obj: + obj.visible = not obj.visible + self.scene.apply_object_properties(obj) + self._refresh_browser() + elif action == act_del: + self._on_delete_object() + elif act_plane_sl is not None and action == act_plane_sl: + pidx = _parse_plane_index(obj.data_key) + if pidx is not None: + self._trigger_plane_streamlines(pidx) + + def _trigger_plane_streamlines(self, plane_idx): + self._sync_params_to_ws() + self.pipeline.preprocess(self.workspace) + self.workspace.plane_streamline_plane_idx = plane_idx + self.scene.trigger_plane_streamlines(plane_idx) + self._refresh_browser() + self.scene.invalidate_cache() + self.scene.sync_from_workspace() + self._refresh_all() + self._log_selected_plane_metric(int(plane_idx)) + + def _set_group_vis(self, group_item, visible): + for i in range(group_item.childCount()): + uid = group_item.child(i).data(0, QtCore.Qt.UserRole) + if uid: + obj = self.workspace.scene_objects.get(uid) + if obj: + obj.visible = visible + self.scene.apply_object_properties(obj) + self._refresh_browser() + self._refresh_scene() + + def _on_delete_object(self): + items = self.tree_objects.selectedItems() + if not items: + return + item = items[0] + uid = item.data(0, QtCore.Qt.UserRole) + plane_indices_removed = [] + if uid: + obj = self.workspace.scene_objects.get(uid) + name = obj.name if obj else uid + if obj and obj.kind == ObjectKind.PLANE: + pidx = _parse_plane_index(obj.data_key) + if pidx is not None: + plane_indices_removed.append(pidx) + self.scene.remove_object(uid) + self.log(f"Deleted: {name}") + else: + count = item.childCount() + if count == 0: + return + kind_name = item.text(0) + uids = [] + for i in range(count): + cuid = item.child(i).data(0, QtCore.Qt.UserRole) + if cuid: + cobj = self.workspace.scene_objects.get(cuid) + if cobj and cobj.kind == ObjectKind.PLANE: + pidx = _parse_plane_index(cobj.data_key) + if pidx is not None: + plane_indices_removed.append(pidx) + uids.append(cuid) + for u in uids: + self.scene.remove_object(u) + self.log(f"Deleted section: {kind_name} ({len(uids)} objects)") + if plane_indices_removed: + self._clear_plane_drag_widgets() + for pidx in sorted(plane_indices_removed, reverse=True): + if 0 <= pidx < len(self.workspace.planes): + self.workspace.planes.pop(pidx) + self._rebuild_plane_objects() + self._selected_plane_index = -1 + self.workspace.selected_path_index = -1 + self.scene.highlight_plane(None) + self.scene.highlight_path(None) + self.scene.show_forks_for_path(-1) + self._refresh_browser() + self._refresh_selection_info() + + def _refresh_timeline(self): + T = max(1, self.workspace.time_count()) + self.slider_t.blockSignals(True) + self.slider_t.setMaximum(T - 1) + self.slider_t.setValue(self.workspace.current_t) + self.slider_t.blockSignals(False) + self.lab_t.setText(str(self.workspace.current_t)) + + def _on_t_changed(self, v): + self.workspace.current_t = int(v) + self.lab_t.setText(str(v)) + self.scene.update_time(int(v)) + self.ortho_viewer.refresh() + self._refresh_selection_info() + + def _on_prev_frame(self): + self.workspace.current_t = max(0, self.workspace.current_t - 1) + self._refresh_timeline() + self.scene.update_time(self.workspace.current_t) + self.ortho_viewer.refresh() + self._refresh_selection_info() + + def _on_next_frame(self): + T = self.workspace.time_count() + self.workspace.current_t = min(T - 1, self.workspace.current_t + 1) + self._refresh_timeline() + self.scene.update_time(self.workspace.current_t) + self.ortho_viewer.refresh() + self._refresh_selection_info() + + def _on_play(self): + self.scene.set_playback_active(True) + self._play_timer.start(self.spin_interval.value()) + + def _on_pause(self): + self._play_timer.stop() + self.scene.set_playback_active(False) + + def _on_play_tick(self): + T = self.workspace.time_count() + if T <= 1: + return + self.workspace.current_t = (self.workspace.current_t + 1) % T + self._refresh_timeline() + self.scene.update_time(self.workspace.current_t) + self.ortho_viewer.refresh() + self._refresh_selection_info() + + def _refresh_scene(self): + try: + self.scene.render_all() + except Exception as e: + self.log(f"VIEW ERROR: {type(e).__name__}: {e}") + + def _on_open_data(self): + path, _ = QtWidgets.QFileDialog.getOpenFileName(self, "Open Data", "", "H5 (*.h5 *.hdf5);;All (*)") + if not path: + return + try: + if self._edit_mode is not None: + self._exit_interactive_edit(False) + self._clear_plane_drag_widgets() + self.workspace.reset_all() + self.workspace.paths.segmask_path = path + self.workspace.paths.flow_path = path + self.pipeline.load_data(self.workspace, self.log) + self.scene.workspace = self.workspace + self.scene.reset_scene() + self._refresh_all() + self.ortho_viewer.update_slider_ranges() + except Exception as e: + self.log(f"LOAD ERROR: {type(e).__name__}: {e}") + self.log(traceback.format_exc()) + + def _on_close_workspace(self): + if self._edit_mode is not None: + self._exit_interactive_edit(False) + self._clear_plane_drag_widgets() + self.workspace.reset_all() + self._selected_plane_index = -1 + self.scene.reset_scene() + self.ortho_viewer.reset_state() + self._refresh_all() + self.log("Workspace cleared") + + def _run_single_step(self, step): + if not self.workspace.data_loaded: + self.log("No data loaded. Use File > Open Data.") + return + if self._edit_mode is not None: + if step == StepId.EDIT_SKELETON and self._edit_mode == "skeleton": + self._exit_interactive_edit(True) + return + if step == StepId.EDIT_GRAPH and self._edit_mode == "graph": + self._exit_interactive_edit(True) + return + self.log("Finish current interactive edit first. Press ESC to force exit.") + return + try: + self._sync_params_to_ws() + self._clear_plane_drag_widgets() + self.setEnabled(False) + QtWidgets.QApplication.processEvents() + if step == StepId.EDIT_SKELETON: + self._start_skeleton_interactive_edit() + return + if step == StepId.EDIT_GRAPH: + self._start_graph_interactive_edit() + return + if step == StepId.GENERATE_STREAMLINES: + self.pipeline.preprocess(self.workspace) + self.scene.trigger_streamlines() + self._refresh_browser() + self.scene.invalidate_cache() + self.scene.sync_from_workspace() + self._refresh_all() + return + if step == StepId.PLANE_STREAMLINES: + self._on_plane_streamlines_step() + return + t0 = time.time() + result = self.pipeline.run_step(self.workspace, step, self.log) + elapsed = time.time() - t0 + self.log(f"[{step.label}] {elapsed:.2f}s - {result.message}") + self.scene.invalidate_cache() + self.scene.sync_from_workspace() + self._refresh_all() + self.ortho_viewer.refresh() + except Exception as e: + self.log(f"STEP ERROR: {type(e).__name__}: {e}") + self.log(traceback.format_exc()) + finally: + self.setEnabled(True) + + def _run_all_pipeline(self): + if not self.workspace.data_loaded: + self.log("No data loaded. Use File > Open Data.") + return + if self._edit_mode is not None: + self.log("Finish current interactive edit first.") + return + self._sync_params_to_ws() + self._clear_plane_drag_widgets() + self.setEnabled(False) + QtWidgets.QApplication.processEvents() + all_steps = [ + StepId.GENERATE_SKELETON, + StepId.GENERATE_GRAPH, + StepId.GENERATE_PLANES, + StepId.COMPUTE_PLANE_METRICS, + StepId.COMPUTE_DERIVED_METRICS, + ] + try: + t_total = time.time() + for step in all_steps: + t0 = time.time() + self.log(f"[Run All] Running {step.label}...") + QtWidgets.QApplication.processEvents() + result = self.pipeline.run_step(self.workspace, step, self.log) + elapsed = time.time() - t0 + self.log(f"[{step.label}] {elapsed:.2f}s - {result.message}") + self.scene.invalidate_cache() + self.scene.sync_from_workspace() + self._refresh_all() + self.ortho_viewer.refresh() + total_elapsed = time.time() - t_total + self.log(f"[Run All] Completed in {total_elapsed:.2f}s") + except Exception as e: + self.log(f"RUN ALL ERROR: {type(e).__name__}: {e}") + self.log(traceback.format_exc()) + finally: + self.setEnabled(True) + + def _on_plane_streamlines_step(self): + ws = self.workspace + if len(ws.planes) == 0: + self.log("No planes available for plane streamlines.") + return + uid = self._selected_uid() + pidx = 0 + if uid: + obj = ws.scene_objects.get(uid) + if obj and obj.kind == ObjectKind.PLANE: + parsed = _parse_plane_index(obj.data_key) + if parsed is not None: + pidx = parsed + self._trigger_plane_streamlines(pidx) + + def _get_spacing_xyz_from_resolution(self): + r = np.asarray(self.workspace.resolution, dtype=float).reshape(-1) + if r.size >= 3: + return np.array([float(r[0]), float(r[1]), float(r[2])], dtype=float) + return np.array([1.0, 1.0, 1.0], dtype=float) + + def _edit_widget_radius(self): + spacing = self._get_spacing_xyz_from_resolution() + return max(0.1, float(np.mean(spacing)) * 0.6) + + def _graph_polydata(self, points, edges): + points = np.asarray(points, dtype=float).reshape(-1, 3) + poly = pv.PolyData(points) + edges = np.asarray(edges, dtype=int).reshape(-1, 2) if len(edges) else np.empty((0, 2), dtype=int) + if len(edges) > 0: + cells = np.empty((len(edges), 3), dtype=np.int64) + cells[:, 0] = 2 + cells[:, 1] = edges[:, 0] + cells[:, 2] = edges[:, 1] + poly.lines = cells.ravel() + return poly + + def _cleanup_edit_actors(self): + for actor in [self._edit_actor, self._edit_edge_actor, self._edit_sel_actor, self._edit_sel_edge_actor]: + if actor is not None: + try: + self.plotter.remove_actor(actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(actor) + except Exception: + pass + self._edit_actor = None + self._edit_edge_actor = None + self._edit_sel_actor = None + self._edit_sel_edge_actor = None + self._edit_poly = None + self._edit_edge_poly = None + self._edit_sel_poly = None + self._edit_sel_edge_poly = None + + def _remove_edit_widget(self): + try: + if hasattr(self.plotter, "clear_sphere_widgets"): + self.plotter.clear_sphere_widgets() + except Exception: + pass + try: + if hasattr(self.plotter, "remove_widget") and self._edit_widget is not None: + try: + self.plotter.remove_widget(self._edit_widget) + except Exception: + pass + except Exception: + pass + self._edit_widget = None + + def _update_edit_labels(self): + if self._edit_info_label is None or self._edit_status_label is None: + return + mode = self._edit_mode or "-" + npts = 0 if self._edit_points is None else len(self._edit_points) + nedges = 0 if self._edit_edges is None else len(self._edit_edges) + sel = "-" if self._edit_selected_idx is None else str(int(self._edit_selected_idx)) + esel = "-" if self._edit_selected_edge_idx is None else str(int(self._edit_selected_edge_idx)) + self._edit_info_label.setText(f"Mode: {mode} Points: {npts} Edges: {nedges} Selected Node: {sel} Selected Edge: {esel}") + edge_state = "ON" if self._edit_edge_mode else "OFF" + src = "-" if self._edit_edge_src_idx is None else str(int(self._edit_edge_src_idx)) + self._edit_status_label.setText(f"Edge Mode: {edge_state} Edge Src: {src} Keys: Delete/Backspace delete, E toggle edge mode") + if self._edit_btn_edge is not None: + self._edit_btn_edge.setText("Edge Mode: ON" if self._edit_edge_mode else "Edge Mode: OFF") + + def _update_edit_points_actor(self): + if self._edit_points is None or len(self._edit_points) == 0: + if self._edit_actor is not None: + try: + self.plotter.remove_actor(self._edit_actor) + except Exception: + pass + self._edit_actor = None + self._edit_poly = None + return + if self._edit_poly is None: + self._edit_poly = pv.PolyData(np.asarray(self._edit_points, dtype=float)) + else: + self._edit_poly.points = np.asarray(self._edit_points, dtype=float) + color = "red" if self._edit_mode == "skeleton" else "deepskyblue" + if self._edit_mode == "plane": + color = "yellow" + if self._edit_actor is None: + self._edit_actor = self.plotter.add_mesh(self._edit_poly, color=color, point_size=10, render_points_as_spheres=True, name="interactive_edit_points") + else: + try: + self._edit_actor.GetMapper().SetInputData(self._edit_poly) + except Exception: + try: + self.plotter.remove_actor(self._edit_actor) + except Exception: + pass + self._edit_actor = self.plotter.add_mesh(self._edit_poly, color=color, point_size=10, render_points_as_spheres=True, name="interactive_edit_points") + + def _update_edit_edges_actor(self): + if self._edit_edges is None or len(self._edit_edges) == 0 or self._edit_points is None or len(self._edit_points) == 0: + if self._edit_edge_actor is not None: + try: + self.plotter.remove_actor(self._edit_edge_actor) + except Exception: + pass + self._edit_edge_actor = None + self._edit_edge_poly = None + return + poly = self._graph_polydata(self._edit_points, self._edit_edges) + self._edit_edge_poly = poly + color = "green" if self._edit_mode == "graph" else "orange" + if self._edit_mode == "plane": + color = "yellow" + if self._edit_edge_actor is None: + self._edit_edge_actor = self.plotter.add_mesh(poly, color=color, line_width=3, name="interactive_edit_edges") + else: + try: + self._edit_edge_actor.GetMapper().SetInputData(poly) + except Exception: + try: + self.plotter.remove_actor(self._edit_edge_actor) + except Exception: + pass + self._edit_edge_actor = self.plotter.add_mesh(poly, color=color, line_width=3, name="interactive_edit_edges") + + def _set_selected_idx(self, idx): + self._edit_selected_idx = None if idx is None else int(idx) + if self._edit_points is None or len(self._edit_points) == 0: + self._edit_selected_idx = None + elif self._edit_selected_idx is not None and not (0 <= self._edit_selected_idx < len(self._edit_points)): + self._edit_selected_idx = None + if self._edit_selected_idx is None: + if self._edit_sel_actor is not None: + try: + self.plotter.remove_actor(self._edit_sel_actor) + except Exception: + pass + self._edit_sel_actor = None + self._edit_sel_poly = None + self._remove_edit_widget() + self._update_edit_labels() + try: + self.plotter.render() + except Exception: + pass + return + p = np.asarray(self._edit_points[self._edit_selected_idx], dtype=float).reshape(1, 3) + if self._edit_sel_poly is None: + self._edit_sel_poly = pv.PolyData(p) + else: + self._edit_sel_poly.points = p + if self._edit_sel_actor is None: + self._edit_sel_actor = self.plotter.add_mesh(self._edit_sel_poly, color="yellow", point_size=16, render_points_as_spheres=True, name="interactive_edit_selected") + else: + try: + self._edit_sel_actor.GetMapper().SetInputData(self._edit_sel_poly) + except Exception: + try: + self.plotter.remove_actor(self._edit_sel_actor) + except Exception: + pass + self._edit_sel_actor = self.plotter.add_mesh(self._edit_sel_poly, color="yellow", point_size=16, render_points_as_spheres=True, name="interactive_edit_selected") + self._create_or_move_edit_widget(p[0]) + self._update_edit_labels() + try: + self.plotter.render() + except Exception: + pass + + def _create_or_move_edit_widget(self, center): + if self._edit_selected_idx is None or self._edit_points is None or len(self._edit_points) == 0: + return + self._remove_edit_widget() + radius = self._edit_widget_radius() + def _cb(new_center): + if self._edit_selected_idx is None or self._edit_points is None or len(self._edit_points) == 0: + return + c = np.asarray(new_center, dtype=float).reshape(3) + self._edit_points[self._edit_selected_idx, :] = c + if self._edit_poly is not None: + try: + self._edit_poly.points = self._edit_points + except Exception: + pass + if self._edit_sel_poly is not None: + try: + self._edit_sel_poly.points = np.asarray([c], dtype=float) + except Exception: + pass + if self._edit_edge_poly is not None: + try: + self._edit_edge_poly.points = self._edit_points + except Exception: + self._update_edit_edges_actor() + try: + self.plotter.render() + except Exception: + pass + self._edit_widget = self.plotter.add_sphere_widget(callback=_cb, center=tuple(np.asarray(center, dtype=float).tolist()), radius=radius, color="orange") + + def _enable_interactive_key_events(self, enable): + try: + iren = self.plotter.iren.interactor + except Exception: + iren = None + if not enable: + if iren is not None and self._vtk_keypress_obs_id is not None: + try: + iren.RemoveObserver(self._vtk_keypress_obs_id) + except Exception: + pass + self._vtk_keypress_obs_id = None + return + if iren is None: + self.log("WARNING: No VTK interactor available; cannot enable key events.") + return + def _on_keypress(obj, ev): + key = "" + try: + key = iren.GetKeySym() + except Exception: + return + if key == "Escape": + self._force_exit_edit() + return + if self._edit_mode is None: + return + if key in ("Delete", "BackSpace"): + if self._edit_selected_edge_idx is not None: + self._delete_selected_edge() + else: + self._delete_selected_interactive_point() + return + if key in ("e", "E"): + self._toggle_edge_mode() + return + if self._vtk_keypress_obs_id is not None: + try: + iren.RemoveObserver(self._vtk_keypress_obs_id) + except Exception: + pass + self._vtk_keypress_obs_id = None + self._vtk_keypress_obs_id = iren.AddObserver("KeyPressEvent", _on_keypress) + + def _enable_interactive_point_picking(self, enable): + self._edit_pick_enabled = bool(enable) + try: + iren = self.plotter.iren.interactor + except Exception: + iren = None + if not enable: + if iren is not None and self._vtk_left_click_obs_id is not None: + try: + iren.RemoveObserver(self._vtk_left_click_obs_id) + except Exception: + pass + self._vtk_left_click_obs_id = None + self._vtk_point_picker = None + return + if iren is None: + self.log("WARNING: No VTK interactor available; cannot enable picking.") + return + if self._vtk_point_picker is None: + self._vtk_point_picker = pv._vtk.vtkPointPicker() + self._vtk_point_picker.SetTolerance(0.02) + def _on_left_click(obj, ev): + if self._edit_mode is None: + try: + iren.GetInteractorStyle().OnLeftButtonDown() + except Exception: + pass + return + if self._edit_actor is None or self._edit_points is None or len(self._edit_points) == 0: + try: + iren.GetInteractorStyle().OnLeftButtonDown() + except Exception: + pass + return + try: + x, y = iren.GetEventPosition() + except Exception: + x, y = None, None + if x is None: + try: + iren.GetInteractorStyle().OnLeftButtonDown() + except Exception: + pass + return + try: + self._vtk_point_picker.InitializePickList() + self._vtk_point_picker.AddPickList(self._edit_actor) + if self._edit_edge_actor is not None: + self._vtk_point_picker.AddPickList(self._edit_edge_actor) + self._vtk_point_picker.PickFromListOn() + except Exception: + pass + try: + ren = self.plotter.renderer + ok = self._vtk_point_picker.Pick(float(x), float(y), 0.0, ren) + except Exception: + ok = 0 + if not ok: + try: + iren.GetInteractorStyle().OnLeftButtonDown() + except Exception: + pass + return + try: + p = np.asarray(self._vtk_point_picker.GetPickPosition(), dtype=float).reshape(3) + except Exception: + try: + iren.GetInteractorStyle().OnLeftButtonDown() + except Exception: + pass + return + pts = np.asarray(self._edit_points, dtype=float) + d2 = np.sum((pts - p.reshape(1, 3)) ** 2, axis=1) + node_idx = int(np.argmin(d2)) + node_dist = float(np.sqrt(d2[node_idx])) + edge_idx, edge_dist = self._find_closest_edge(p) + if self._edit_edge_mode and self._edit_mode == "graph": + if self._edit_edge_src_idx is None: + self._edit_edge_src_idx = node_idx + self._set_selected_idx(node_idx) + self._edit_selected_edge_idx = None + self._clear_edge_selection_actor() + self._update_edit_labels() + else: + src = self._edit_edge_src_idx + self._edit_edge_src_idx = None + if src != node_idx: + self._toggle_edge(src, node_idx) + self._update_edit_labels() + elif self._edit_mode == "graph" and edge_idx is not None and edge_dist < node_dist * 0.7: + self._set_selected_edge_idx(edge_idx) + else: + self._edit_selected_edge_idx = None + self._clear_edge_selection_actor() + self._set_selected_idx(node_idx) + try: + self.plotter.render() + except Exception: + pass + if self._vtk_left_click_obs_id is not None: + try: + iren.RemoveObserver(self._vtk_left_click_obs_id) + except Exception: + pass + self._vtk_left_click_obs_id = None + try: + self._vtk_left_click_obs_id = iren.AddObserver("LeftButtonPressEvent", _on_left_click) + except Exception as e: + self.log(f"WARNING: failed to add VTK observer for picking: {e}") + self._vtk_left_click_obs_id = None + + def _toggle_edge_mode(self): + if self._edit_mode != "graph": + return + self._edit_edge_mode = not self._edit_edge_mode + self._edit_edge_src_idx = None + self._update_edit_labels() + + def _toggle_edge(self, i, j): + if self._edit_edges is None: + self._edit_edges = np.empty((0, 2), dtype=int) + edges = np.asarray(self._edit_edges, dtype=int).reshape(-1, 2) + found = -1 + for k, (a, b) in enumerate(edges): + if (int(a) == i and int(b) == j) or (int(a) == j and int(b) == i): + found = k + break + if found >= 0: + self._edit_edges = np.delete(edges, found, axis=0) + self.log(f"Removed edge ({i}, {j})") + else: + self._edit_edges = np.vstack([edges, [i, j]]) if len(edges) > 0 else np.array([[i, j]], dtype=int) + self.log(f"Added edge ({i}, {j})") + self._edit_selected_edge_idx = None + self._clear_edge_selection_actor() + self._update_edit_edges_actor() + self._update_edit_labels() + try: + self.plotter.render() + except Exception: + pass + + def _clear_edge_selection_actor(self): + if self._edit_sel_edge_actor is not None: + try: + self.plotter.remove_actor(self._edit_sel_edge_actor) + except Exception: + pass + self._edit_sel_edge_actor = None + self._edit_sel_edge_poly = None + + def _set_selected_edge_idx(self, idx): + self._edit_selected_edge_idx = None if idx is None else int(idx) + if self._edit_edges is None or len(self._edit_edges) == 0: + self._edit_selected_edge_idx = None + elif self._edit_selected_edge_idx is not None and not (0 <= self._edit_selected_edge_idx < len(self._edit_edges)): + self._edit_selected_edge_idx = None + if self._edit_selected_edge_idx is None: + self._clear_edge_selection_actor() + self._update_edit_labels() + return + edge = self._edit_edges[self._edit_selected_edge_idx] + pts = self._edit_points[edge] + poly = pv.PolyData(pts) + poly.lines = np.array([2, 0, 1], dtype=np.int64) + self._edit_sel_edge_poly = poly + if self._edit_sel_edge_actor is None: + self._edit_sel_edge_actor = self.plotter.add_mesh(poly, color="yellow", line_width=6, name="interactive_edit_selected_edge") + else: + try: + self._edit_sel_edge_actor.GetMapper().SetInputData(poly) + except Exception: + try: + self.plotter.remove_actor(self._edit_sel_edge_actor) + except Exception: + pass + self._edit_sel_edge_actor = self.plotter.add_mesh(poly, color="yellow", line_width=6, name="interactive_edit_selected_edge") + self._set_selected_idx(None) + self._update_edit_labels() + try: + self.plotter.render() + except Exception: + pass + + def _delete_selected_edge(self): + if self._edit_selected_edge_idx is None or self._edit_edges is None or len(self._edit_edges) == 0: + return + idx = int(self._edit_selected_edge_idx) + self.log(f"Deleted edge: {idx} ({self._edit_edges[idx].tolist()})") + self._edit_edges = np.delete(self._edit_edges, idx, axis=0) + self._edit_selected_edge_idx = None + self._clear_edge_selection_actor() + self._update_edit_edges_actor() + self._update_edit_labels() + try: + self.plotter.render() + except Exception: + pass + + def _find_closest_edge(self, pick_pos): + if self._edit_edges is None or len(self._edit_edges) == 0 or self._edit_points is None: + return None, float("inf") + pts = np.asarray(self._edit_points, dtype=float) + p = np.asarray(pick_pos, dtype=float).reshape(3) + best_idx = None + best_dist = float("inf") + for k, (a, b) in enumerate(self._edit_edges): + a_pt = pts[int(a)] + b_pt = pts[int(b)] + ab = b_pt - a_pt + ab_len2 = np.dot(ab, ab) + if ab_len2 < 1e-24: + d = np.linalg.norm(p - a_pt) + else: + t = np.clip(np.dot(p - a_pt, ab) / ab_len2, 0.0, 1.0) + proj = a_pt + t * ab + d = np.linalg.norm(p - proj) + if d < best_dist: + best_dist = d + best_idx = k + return best_idx, best_dist + + def _delete_selected_interactive_point(self): + if self._edit_selected_idx is None or self._edit_points is None or len(self._edit_points) == 0: + return + idx = int(self._edit_selected_idx) + self._edit_points = np.delete(np.asarray(self._edit_points, dtype=float), idx, axis=0) + if self._edit_edges is not None and len(self._edit_edges) > 0: + keep_idx = [i for i in range(len(self._edit_points) + 1) if i != idx] + remap = {old: new for new, old in enumerate(keep_idx)} + new_edges = [] + for a, b in np.asarray(self._edit_edges, dtype=int): + a = int(a) + b = int(b) + if a in remap and b in remap: + new_edges.append([remap[a], remap[b]]) + self._edit_edges = np.asarray(new_edges, dtype=int).reshape(-1, 2) if new_edges else np.empty((0, 2), dtype=int) + self._update_edit_points_actor() + self._update_edit_edges_actor() + if len(self._edit_points) == 0: + self._set_selected_idx(None) + else: + self._set_selected_idx(min(idx, len(self._edit_points) - 1)) + self._update_edit_labels() + self.log(f"Deleted point: {idx}") + + def _add_interactive_point(self): + return + + def _show_interactive_overlay(self): + self._edit_overlay_dialog = None + self._edit_info_label = None + self._edit_status_label = None + self._edit_btn_edge = None + mode = self._edit_mode or "-" + if mode == "skeleton": + hint = "Edit Skeleton: drag sphere to move | Delete/Backspace to remove | ESC to cancel | click 'Edit Skeleton' again to apply" + elif mode == "graph": + hint = "Edit Graph: drag sphere to move node | Delete/Backspace to remove node/edge | E toggle edge mode | Click edge to select | ESC to cancel | click 'Edit Graph' again to apply" + else: + hint = f"Edit {mode}: ESC to cancel" + try: + self.statusBar().showMessage(hint) + except Exception: + pass + self.log(hint) + + def _close_interactive_overlay(self): + self._edit_overlay_dialog = None + self._edit_info_label = None + self._edit_status_label = None + self._edit_btn_edge = None + try: + self.statusBar().clearMessage() + except Exception: + pass + + def _enter_interactive_edit(self, mode, points, edges=None): + self.scene.invalidate_cache() + self.scene.sync_from_workspace() + self._cleanup_edit_actors() + self._remove_edit_widget() + self._close_interactive_overlay() + self._edit_mode = mode + self._edit_points = np.asarray(points, dtype=float).reshape(-1, 3).copy() + if edges is None: + self._edit_edges = np.empty((0, 2), dtype=int) + else: + arr = np.asarray(edges, dtype=int) + self._edit_edges = arr.reshape(-1, 2).copy() if len(arr) else np.empty((0, 2), dtype=int) + self._edit_original_points = self._edit_points.copy() + self._edit_original_edges = self._edit_edges.copy() + self._edit_selected_idx = None + self._edit_edge_mode = False + self._edit_edge_src_idx = None + self._edit_selected_edge_idx = None + self._update_edit_points_actor() + self._update_edit_edges_actor() + self._enable_interactive_key_events(True) + self._enable_interactive_point_picking(True) + self._show_interactive_overlay() + if len(self._edit_points) > 0: + self._set_selected_idx(0) + else: + self._set_selected_idx(None) + self.log(f"Interactive edit started: {mode}") + try: + self.plotter.render() + except Exception: + pass + + def _exit_interactive_edit(self, apply_changes): + mode = self._edit_mode + if mode is None: + return + try: + self._enable_interactive_key_events(False) + self._enable_interactive_point_picking(False) + self._remove_edit_widget() + if apply_changes: + if mode == "skeleton": + ed = SkeletonEditor(self.workspace) + ed.replace_points(self._edit_points) + self.workspace.remove_object_by_data_key("skeleton_points") + self.workspace.add_object(name="skeleton_points", kind=ObjectKind.SKELETON, data_key="skeleton_points", visible=True, opacity=1.0, color="red", point_size=8) + self.workspace.pipeline.mark_done(StepId.EDIT_SKELETON, skipped=False) + self.log(f"Skeleton edited: {len(self.workspace.skeleton_points)} points") + elif mode == "graph": + self.workspace.graph.points = np.asarray(self._edit_points, dtype=float).reshape(-1, 3) + self.workspace.graph.edges = np.asarray(self._edit_edges, dtype=int).reshape(-1, 2) if len(self._edit_edges) else np.empty((0, 2), dtype=int) + self.workspace.remove_object_by_data_key("graph_lines") + self.workspace.add_object(name="graph_lines", kind=ObjectKind.GRAPH, data_key="graph_lines", visible=True, opacity=1.0, color="blue", line_width=2) + self.workspace.pipeline.mark_done(StepId.EDIT_GRAPH, skipped=False) + self.log(f"Graph edited: {len(self.workspace.graph.points)} nodes, {len(self.workspace.graph.edges)} edges") + else: + self.log(f"Interactive edit cancelled: {mode}") + finally: + self._cleanup_edit_actors() + self._close_interactive_overlay() + self._edit_mode = None + self._edit_points = None + self._edit_edges = None + self._edit_selected_idx = None + self._edit_edge_mode = False + self._edit_edge_src_idx = None + self._edit_selected_edge_idx = None + self._edit_original_points = None + self._edit_original_edges = None + self.scene.invalidate_cache() + self.scene.sync_from_workspace() + self._refresh_all() + try: + self.plotter.render() + except Exception: + pass + + def _force_exit_edit(self): + if self._edit_mode is None: + return + self.log("ESC: force exit interactive edit") + try: + self._exit_interactive_edit(False) + except Exception as e: + self.log(f"Force exit cleanup error: {type(e).__name__}: {e}") + finally: + self._edit_mode = None + self._edit_points = None + self._edit_edges = None + self._edit_selected_idx = None + self._edit_edge_mode = False + self._edit_edge_src_idx = None + self._edit_selected_edge_idx = None + self._edit_original_points = None + self._edit_original_edges = None + self._edit_overlay_dialog = None + self._edit_info_label = None + self._edit_status_label = None + self._edit_btn_edge = None + try: + self.statusBar().clearMessage() + except Exception: + pass + try: + self._cleanup_edit_actors() + except Exception: + pass + try: + self._remove_edit_widget() + except Exception: + pass + try: + self._enable_interactive_key_events(False) + except Exception: + pass + try: + self._enable_interactive_point_picking(False) + except Exception: + pass + try: + self.setEnabled(True) + except Exception: + pass + try: + self.plotter.render() + except Exception: + pass + + def _start_skeleton_interactive_edit(self): + if self.workspace.skeleton_points is None or len(self.workspace.skeleton_points) == 0: + self.log("Edit Skeleton: no skeleton points.") + return + self._enter_interactive_edit("skeleton", self.workspace.skeleton_points, edges=None) + + def _start_graph_interactive_edit(self): + if self.workspace.graph is None or len(self.workspace.graph.points) == 0: + self.log("Edit Graph: no graph data.") + return + self._enter_interactive_edit("graph", self.workspace.graph.points, self.workspace.graph.edges) + + def closeEvent(self, event): + try: + if self._edit_mode is not None: + self._exit_interactive_edit(False) + self._clear_plane_drag_widgets() + except Exception: + pass + event.accept() + + +def main(): + app = QtWidgets.QApplication(sys.argv) + w = MainWindow() + w.show() + sys.exit(app.exec_()) + + +if __name__ == "__main__": + main() diff --git a/autoflow/ui/editors.py b/autoflow/ui/editors.py new file mode 100755 index 0000000..9c8252b --- /dev/null +++ b/autoflow/ui/editors.py @@ -0,0 +1,109 @@ +import numpy as np + +from ..core.models import GraphData, PlaneData + + +class SkeletonEditor: + def __init__(self, workspace): + self.workspace = workspace + + def remove_points_by_index(self, indices): + if self.workspace.skeleton_points is None: + return + pts = np.asarray(self.workspace.skeleton_points) + mask = np.ones(len(pts), dtype=bool) + mask[np.asarray(indices, dtype=int)] = False + self.workspace.skeleton_points = pts[mask] + + def append_points(self, points): + pts = np.asarray(points, dtype=float).reshape(-1, 3) + if self.workspace.skeleton_points is None or len(self.workspace.skeleton_points) == 0: + self.workspace.skeleton_points = pts + else: + self.workspace.skeleton_points = np.vstack([self.workspace.skeleton_points, pts]) + + def replace_points(self, points): + self.workspace.skeleton_points = np.asarray(points, dtype=float).reshape(-1, 3) + + +class GraphEditor: + def __init__(self, workspace): + self.workspace = workspace + + def remove_edges_by_index(self, indices): + edges = np.asarray(self.workspace.graph.edges, dtype=int) + if len(edges) == 0: + return + mask = np.ones(len(edges), dtype=bool) + mask[np.asarray(indices, dtype=int)] = False + self.workspace.graph = GraphData( + points=self.workspace.graph.points.copy(), edges=edges[mask]) + + def append_edges(self, edges): + e = np.asarray(edges, dtype=int).reshape(-1, 2) + if len(self.workspace.graph.edges) == 0: + new_edges = e + else: + new_edges = np.vstack([self.workspace.graph.edges, e]) + self.workspace.graph = GraphData( + points=self.workspace.graph.points.copy(), edges=new_edges) + + def remove_nodes_by_index(self, indices): + points = np.asarray(self.workspace.graph.points, dtype=float) + edges = np.asarray(self.workspace.graph.edges, dtype=int) + rm = set(int(i) for i in indices) + keep_idx = [i for i in range(len(points)) if i not in rm] + remap = {old: new for new, old in enumerate(keep_idx)} + new_points = points[keep_idx] + new_edges = [] + for a, b in edges: + if int(a) in remap and int(b) in remap: + new_edges.append([remap[int(a)], remap[int(b)]]) + self.workspace.graph = GraphData( + points=new_points, + edges=np.asarray(new_edges, dtype=int).reshape(-1, 2) if new_edges else np.empty((0, 2), dtype=int), + ) + + +class PlaneEditor: + def __init__(self, workspace): + self.workspace = workspace + + def add_plane(self, center, normal, label=1, path_index=0, distance=0.0): + n = np.asarray(normal, dtype=float).reshape(3) + n = n / (np.linalg.norm(n) + 1e-12) + self.workspace.planes.append(PlaneData( + center=np.asarray(center, dtype=float).reshape(3), + normal=n, + label=int(label), + path_index=int(path_index), + distance=float(distance), + )) + + def remove_planes_by_index(self, indices): + rm = set(int(i) for i in indices) + self.workspace.planes = [p for i, p in enumerate(self.workspace.planes) if i not in rm] + + def update_plane(self, index, center=None, normal=None, label=None): + p = self.workspace.planes[int(index)] + if center is not None: + p.center = np.asarray(center, dtype=float).reshape(3) + if normal is not None: + n = np.asarray(normal, dtype=float).reshape(3) + p.normal = n / (np.linalg.norm(n) + 1e-12) + if label is not None: + p.label = int(label) + + def replace_planes(self, planes): + out = [] + for p in planes: + n = np.asarray(p["normal"], dtype=float).reshape(3) + n = n / (np.linalg.norm(n) + 1e-12) + out.append(PlaneData( + center=np.asarray(p["center"], dtype=float).reshape(3), + normal=n, + label=int(p.get("label", 1)), + path_index=int(p.get("path_index", 0)), + distance=float(p.get("distance", 0.0)), + )) + self.workspace.planes = out diff --git a/autoflow/ui/launcher.py b/autoflow/ui/launcher.py new file mode 100755 index 0000000..73839b1 --- /dev/null +++ b/autoflow/ui/launcher.py @@ -0,0 +1,29 @@ +import sys + + +_GUI_IMPORTS = {"PyQt5", "pyvistaqt", "matplotlib", "pyvista", "vtk"} + + +def launch_gui() -> None: + try: + from .app import main as app_main + except ModuleNotFoundError as exc: + module_name = exc.name or "" + base_name = module_name.split(".", 1)[0] + if base_name in _GUI_IMPORTS: + raise SystemExit( + "GUI dependencies are not installed. Run `pip install \".[gui]\"` in the repo root first." + ) from exc + raise + app_main() + + +def main() -> None: + try: + launch_gui() + except KeyboardInterrupt: + sys.exit(130) + + +if __name__ == "__main__": + main() diff --git a/autoflow/ui/ortho_viewer.py b/autoflow/ui/ortho_viewer.py new file mode 100755 index 0000000..440a0a2 --- /dev/null +++ b/autoflow/ui/ortho_viewer.py @@ -0,0 +1,555 @@ +import numpy as np +from PyQt5 import QtWidgets, QtCore +from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas +from matplotlib.figure import Figure +from scipy.ndimage import map_coordinates + + +class OrthoViewer(QtWidgets.QWidget): + def __init__(self, workspace, parent=None): + super().__init__(parent) + self.workspace = workspace + self._selected_plane_idx = None + self._scalar_cbar = None + self._cache = {} + self._build_ui() + + def _cached(self, group, key, builder, max_items=24): + bucket = self._cache.setdefault(group, {}) + if key in bucket: + return bucket[key] + value = builder() + if len(bucket) >= max_items: + bucket.clear() + bucket[key] = value + return value + + def _build_ui(self): + layout = QtWidgets.QVBoxLayout(self) + layout.setContentsMargins(2, 2, 2, 2) + layout.setSpacing(2) + + ctrl = QtWidgets.QHBoxLayout() + self.combo_content = QtWidgets.QComboBox() + self.combo_content.addItems([ + "Flow X (cm/s)", "Flow Y (cm/s)", "Flow Z (cm/s)", + "Magnitude", "PC-MRA", "Speed (cm/s)", + "WSS (Pa)", "TKE (J/m³)" + ]) + self.combo_content.setCurrentIndex(4) + self.combo_content.currentIndexChanged.connect(self._on_content_changed) + ctrl.addWidget(QtWidgets.QLabel("Content:")) + ctrl.addWidget(self.combo_content) + ctrl.addStretch() + layout.addLayout(ctrl) + + slider_layout = QtWidgets.QHBoxLayout() + self.slider_x = QtWidgets.QSlider(QtCore.Qt.Horizontal) + self.slider_y = QtWidgets.QSlider(QtCore.Qt.Horizontal) + self.slider_z = QtWidgets.QSlider(QtCore.Qt.Horizontal) + self.label_x = QtWidgets.QLabel("X:0") + self.label_y = QtWidgets.QLabel("Y:0") + self.label_z = QtWidgets.QLabel("Z:0") + for lbl, sl in [(self.label_x, self.slider_x), (self.label_y, self.slider_y), (self.label_z, self.slider_z)]: + sl.setRange(0, 0) + sl.valueChanged.connect(self._on_slider_changed) + slider_layout.addWidget(lbl) + slider_layout.addWidget(sl) + layout.addLayout(slider_layout) + + self.label_value = QtWidgets.QLabel("Voxel: - Value: -") + self.label_plane_metric = QtWidgets.QLabel("Plane metrics: -") + self.label_value.setWordWrap(True) + self.label_plane_metric.setWordWrap(True) + layout.addWidget(self.label_value) + layout.addWidget(self.label_plane_metric) + + self.fig = Figure(figsize=(6.2, 6.6), dpi=80, facecolor="black") + self.canvas = FigureCanvas(self.fig) + self.canvas.setMinimumSize(300, 300) + self.ax_ax = self.fig.add_subplot(2, 2, 1) + self.ax_cor = self.fig.add_subplot(2, 2, 2) + self.ax_sag = self.fig.add_subplot(2, 2, 3) + self.ax_plane = self.fig.add_subplot(2, 2, 4) + for ax in [self.ax_ax, self.ax_cor, self.ax_sag, self.ax_plane]: + ax.set_facecolor("black") + ax.tick_params(colors="white", labelsize=6) + ax.set_xticks([]) + ax.set_yticks([]) + self.fig.subplots_adjust(left=0.03, right=0.96, top=0.96, bottom=0.03, wspace=0.14, hspace=0.24) + layout.addWidget(self.canvas, 1) + + self.canvas.mpl_connect("scroll_event", self._on_scroll) + self.canvas.mpl_connect("button_press_event", self._on_click) + + def _remove_colorbar(self): + if self._scalar_cbar is not None: + try: + self._scalar_cbar.remove() + except Exception: + pass + self._scalar_cbar = None + + def _on_scroll(self, event): + if event.inaxes is None: + return + ax = event.inaxes + factor = 0.8 if event.button == "up" else 1.25 + xlim = ax.get_xlim() + ylim = ax.get_ylim() + xdata = event.xdata if event.xdata is not None else (xlim[0] + xlim[1]) / 2 + ydata = event.ydata if event.ydata is not None else (ylim[0] + ylim[1]) / 2 + new_w = (xlim[1] - xlim[0]) * factor + new_h = (ylim[1] - ylim[0]) * factor + ax.set_xlim(xdata - new_w / 2, xdata + new_w / 2) + ax.set_ylim(ydata - new_h / 2, ydata + new_h / 2) + self.canvas.draw_idle() + + def _on_click(self, event): + if event.inaxes is None or event.xdata is None or event.ydata is None: + return + shape = self._get_volume_shape() + if shape is None: + return + x = int(np.clip(np.round(event.xdata), 0, shape[0] - 1)) + y = int(np.clip(np.round(event.ydata), 0, max(shape[1], shape[2]) - 1)) + cx, cy, cz = self.slider_x.value(), self.slider_y.value(), self.slider_z.value() + if event.inaxes == self.ax_ax: + self._set_cursor(x, int(np.clip(np.round(event.ydata), 0, shape[1] - 1)), cz) + elif event.inaxes == self.ax_cor: + self._set_cursor(x, cy, int(np.clip(np.round(event.ydata), 0, shape[2] - 1))) + elif event.inaxes == self.ax_sag: + self._set_cursor(cx, int(np.clip(np.round(event.xdata), 0, shape[1] - 1)), int(np.clip(np.round(event.ydata), 0, shape[2] - 1))) + + def _set_cursor(self, x, y, z): + self.slider_x.blockSignals(True) + self.slider_y.blockSignals(True) + self.slider_z.blockSignals(True) + self.slider_x.setValue(int(x)) + self.slider_y.setValue(int(y)) + self.slider_z.setValue(int(z)) + self.slider_x.blockSignals(False) + self.slider_y.blockSignals(False) + self.slider_z.blockSignals(False) + self._update_labels() + self.workspace.ortho_cursor = np.array([int(x), int(y), int(z)], dtype=int) + self.refresh() + + def update_slider_ranges(self): + shape = self._get_volume_shape() + if shape is None: + return + self.slider_x.blockSignals(True) + self.slider_y.blockSignals(True) + self.slider_z.blockSignals(True) + self.slider_x.setRange(0, max(0, shape[0] - 1)) + self.slider_y.setRange(0, max(0, shape[1] - 1)) + self.slider_z.setRange(0, max(0, shape[2] - 1)) + self.slider_x.setValue(min(shape[0] // 2, self.slider_x.maximum())) + self.slider_y.setValue(min(shape[1] // 2, self.slider_y.maximum())) + self.slider_z.setValue(min(shape[2] // 2, self.slider_z.maximum())) + self.slider_x.blockSignals(False) + self.slider_y.blockSignals(False) + self.slider_z.blockSignals(False) + self._update_labels() + self.workspace.ortho_cursor = np.array([self.slider_x.value(), self.slider_y.value(), self.slider_z.value()], dtype=int) + self.refresh() + + def _get_volume_shape(self): + ws = self.workspace + if ws.flow_raw is not None and ws.flow_raw.ndim == 5: + return ws.flow_raw.shape[:3] + if ws.mag_raw is not None and ws.mag_raw.ndim == 4: + return ws.mag_raw.shape[:3] + if ws.segmask_3d is not None: + return ws.segmask_3d.shape[:3] + return None + + def _update_labels(self): + self.label_x.setText(f"X:{self.slider_x.value()}") + self.label_y.setText(f"Y:{self.slider_y.value()}") + self.label_z.setText(f"Z:{self.slider_z.value()}") + + def _on_slider_changed(self, _): + self._update_labels() + self.workspace.ortho_cursor = np.array([self.slider_x.value(), self.slider_y.value(), self.slider_z.value()], dtype=int) + self.refresh() + + def _on_content_changed(self, _): + self.refresh() + + def set_selected_plane(self, idx): + self._selected_plane_idx = idx + self._move_to_plane_center(idx) + self.refresh() + + def _move_to_plane_center(self, idx): + ws = self.workspace + if idx is None or idx >= len(ws.planes): + return + plane = ws.planes[idx] + res = self._get_resolution() + center_vox = np.asarray(plane.center, dtype=float) / (res + 1e-12) + shape = self._get_volume_shape() + if shape is None: + return + self._set_cursor( + int(np.clip(np.round(center_vox[0]), 0, shape[0] - 1)), + int(np.clip(np.round(center_vox[1]), 0, shape[1] - 1)), + int(np.clip(np.round(center_vox[2]), 0, shape[2] - 1)), + ) + + def _get_resolution(self): + ws = self.workspace + if ws.resolution is not None and len(ws.resolution) >= 3: + r = np.asarray(ws.resolution, dtype=float).reshape(-1)[:3] + return np.where(r > 0, r, 1.0) + return np.array([1.0, 1.0, 1.0]) + + def _scene_style(self, data_key, default_cmap, default_clim=None): + for obj in self.workspace.scene_objects.values(): + if obj.data_key == data_key: + return obj.cmap or default_cmap, obj.clim if obj.clim else default_clim + return default_cmap, default_clim + + def _get_wss_volume(self, t): + ws = self.workspace + if ws.derived.wss_volume is not None: + cmap, clim = self._scene_style("wss_surface_live", "jet", None) + tidx = min(max(0, int(t)), ws.derived.wss_volume.shape[3] - 1) + return np.asarray(ws.derived.wss_volume[..., tidx], dtype=float), "WSS (Pa)", {"cmap": cmap, "clim": clim} + if not ws.derived.wss_surfaces: + return None, "WSS (no data)", {"cmap": "jet", "clim": None} + tidx = min(max(0, t), len(ws.derived.wss_surfaces) - 1) + surf = ws.derived.wss_surfaces[tidx] + if surf is None or "wss" not in surf.point_data: + return None, "WSS (no data)", {"cmap": "jet", "clim": None} + shape = self._get_volume_shape() + if shape is None: + return None, "WSS (no data)", {"cmap": "jet", "clim": None} + res = self._get_resolution() + key = ( + id(surf), + tuple(int(x) for x in shape), + tuple(np.round(res, 6).tolist()), + ) + def _build(): + vol = np.zeros(shape, dtype=float) + pts = np.asarray(surf.points, dtype=float) + vals = np.asarray(surf.point_data["wss"], dtype=float) + vox = np.rint(pts / (res.reshape(1, 3) + 1e-12)).astype(int) + for k in range(3): + vox[:, k] = np.clip(vox[:, k], 0, shape[k] - 1) + flat = np.ravel_multi_index((vox[:, 0], vox[:, 1], vox[:, 2]), shape) + tgt = vol.reshape(-1) + np.maximum.at(tgt, flat, vals) + return vol + vol = self._cached("wss_volume", key, _build) + cmap, clim = self._scene_style("wss_surface_live", "jet", None) + return vol, "WSS (Pa)", {"cmap": cmap, "clim": clim} + + def _get_tke_volume(self, t): + ws = self.workspace + if ws.derived.tke_array is not None: + arr = np.asarray(ws.derived.tke_array, dtype=float) + tidx = min(max(0, int(t)), arr.shape[3] - 1) if arr.ndim == 4 else 0 + mask_id = id(ws.segmask_binary) if ws.segmask_binary is not None else -1 + key = (id(ws.derived.tke_array), mask_id, int(tidx)) + def _build(): + if arr.ndim == 4: + vol = arr[..., tidx] + else: + vol = arr + if ws.segmask_binary is not None: + if ws.segmask_binary.ndim == 4: + mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] + else: + mask_t = ws.segmask_binary + vol = np.asarray(vol, dtype=float) * np.asarray(mask_t, dtype=float) + return np.asarray(vol, dtype=float) + vol = self._cached("tke_volume", key, _build) + cmap, clim = self._scene_style("tke_volume", "hot", None) + return vol, "TKE (J/m³)", {"cmap": cmap, "clim": clim} + if ws.derived.tke_volume is None: + return None, "TKE (no data)", {"cmap": "hot", "clim": None} + shape = self._get_volume_shape() + if shape is None: + return None, "TKE (no data)", {"cmap": "hot", "clim": None} + tke_mesh = ws.derived.tke_volume + if "TKE" not in tke_mesh.point_data and "TKE" not in tke_mesh.cell_data: + return None, "TKE (no data)", {"cmap": "hot", "clim": None} + res = self._get_resolution() + key = ( + id(tke_mesh), + tuple(int(x) for x in shape), + tuple(np.round(res, 6).tolist()), + ) + def _build(): + vol = np.zeros(shape, dtype=float) + if "TKE" in tke_mesh.cell_data: + pts = tke_mesh.cell_centers().points + vals = np.asarray(tke_mesh.cell_data["TKE"], dtype=float) + else: + pts = tke_mesh.points + vals = np.asarray(tke_mesh.point_data["TKE"], dtype=float) + vox = np.rint(pts / (res.reshape(1, 3) + 1e-12)).astype(int) + for k in range(3): + vox[:, k] = np.clip(vox[:, k], 0, shape[k] - 1) + flat = np.ravel_multi_index((vox[:, 0], vox[:, 1], vox[:, 2]), shape) + tgt = vol.reshape(-1) + np.maximum.at(tgt, flat, vals) + return vol + vol = self._cached("tke_mesh_volume", key, _build) + cmap, clim = self._scene_style("tke_volume", "hot", None) + return vol, "TKE (J/m³)", {"cmap": cmap, "clim": clim} + def _get_scalar_slice(self, t): + ws = self.workspace + content_idx = self.combo_content.currentIndex() + if content_idx == 0 and ws.flow_raw is not None: + vol = np.asarray(ws.flow_raw[..., t, 0], dtype=float) + vmax = max(abs(np.nanmin(vol)), abs(np.nanmax(vol)), 1e-6) + return vol, "Flow X (cm/s)", {"cmap": "RdBu_r", "clim": (-vmax, vmax)} + if content_idx == 1 and ws.flow_raw is not None: + vol = np.asarray(ws.flow_raw[..., t, 1], dtype=float) + vmax = max(abs(np.nanmin(vol)), abs(np.nanmax(vol)), 1e-6) + return vol, "Flow Y (cm/s)", {"cmap": "RdBu_r", "clim": (-vmax, vmax)} + if content_idx == 2 and ws.flow_raw is not None: + vol = np.asarray(ws.flow_raw[..., t, 2], dtype=float) + vmax = max(abs(np.nanmin(vol)), abs(np.nanmax(vol)), 1e-6) + return vol, "Flow Z (cm/s)", {"cmap": "RdBu_r", "clim": (-vmax, vmax)} + if content_idx == 3 and ws.mag_raw is not None: + vol = np.asarray(ws.mag_raw[..., t], dtype=float) + return vol, "Magnitude", {"cmap": "gray", "clim": (float(np.nanmin(vol)), float(np.nanmax(vol)))} + if content_idx == 4 and ws.mag_raw is not None and ws.flow_raw is not None: + key = (id(ws.mag_raw), id(ws.flow_raw), int(t)) + def _build(): + speed = np.sqrt(np.sum(ws.flow_raw[..., t, :] ** 2, axis=-1)) + return np.asarray(ws.mag_raw[..., t], dtype=float) * np.asarray(speed, dtype=float) + vol = self._cached("scalar_volume", ("pcmra",) + key, _build) + return vol, "PC-MRA", {"cmap": "gray", "clim": (float(np.nanmin(vol)), float(np.nanmax(vol)))} + if content_idx == 5 and ws.flow_raw is not None: + key = (id(ws.flow_raw), int(t)) + vol = self._cached( + "scalar_volume", + ("speed",) + key, + lambda: np.sqrt(np.sum(ws.flow_raw[..., t, :] ** 2, axis=-1)), + ) + return np.asarray(vol, dtype=float), "Speed (cm/s)", {"cmap": "turbo", "clim": (0.0, float(np.nanmax(vol)) if np.nanmax(vol) > 0 else 1.0)} + if content_idx == 6: + return self._get_wss_volume(t) + if content_idx == 7: + return self._get_tke_volume(t) + return None, "", {"cmap": "gray", "clim": None} + + def _get_mask_3d(self): + ws = self.workspace + if ws.segmask_3d is not None: + return ws.segmask_3d + if ws.segmask_binary is not None and ws.segmask_binary.ndim == 4: + key = (id(ws.segmask_binary), tuple(int(x) for x in ws.segmask_binary.shape)) + return self._cached("mask_3d", key, lambda: np.any(ws.segmask_binary, axis=3)) + return None + + def _update_value_label(self, vol, title): + if vol is None: + self.label_value.setText("Voxel: - Value: -") + return + x, y, z = self.slider_x.value(), self.slider_y.value(), self.slider_z.value() + try: + val = float(vol[x, y, z]) + self.label_value.setText(f"Voxel: ({x}, {y}, {z}) {title}: {val:.6g}") + except Exception: + self.label_value.setText(f"Voxel: ({x}, {y}, {z}) {title}: -") + + def _update_plane_metric_label(self): + ws = self.workspace + if self._selected_plane_idx is None or self._selected_plane_idx >= len(ws.planes): + self.label_plane_metric.setText("Plane metrics: -") + return + plane = ws.planes[self._selected_plane_idx] + metrics = plane.metrics or {} + t = int(ws.current_t) + fr = metrics.get("flowrate_mL_s", []) + ar = metrics.get("area_mm2", []) + mv = metrics.get("meanv_cm_s_t", []) + cur_fr = float(fr[t]) if t < len(fr) else 0.0 + cur_ar = float(ar[t]) if t < len(ar) else 0.0 + cur_mv = float(mv[t]) if t < len(mv) else metrics.get("meanv_cm_s", 0.0) + path_direction = metrics.get("path_direction", "") + path_ic = metrics.get("path_ic", None) + txt = f"Plane {self._selected_plane_idx}" + if path_direction: + txt += f" [{path_direction}]" + txt += f" t={t} Flow Rate={cur_fr:.4g} mL/s Area={cur_ar:.4g} mm² Mean Velocity={cur_mv:.4g} cm/s Peak Velocity={metrics.get('peakv_cm_s', 0.0):.4g} cm/s" + if path_ic is not None: + txt += f" Path IC={float(path_ic):.3f}" + self.label_plane_metric.setText(txt) + + def refresh(self): + ws = self.workspace + t = int(ws.current_t) + shape = self._get_volume_shape() + if shape is None: + self._remove_colorbar() + self.canvas.draw_idle() + return + + cx, cy, cz = self.slider_x.value(), self.slider_y.value(), self.slider_z.value() + vol, title, style = self._get_scalar_slice(t) + mask_3d = self._get_mask_3d() + res = self._get_resolution() + + for ax in [self.ax_ax, self.ax_cor, self.ax_sag]: + ax.clear() + ax.set_facecolor("black") + ax.set_xticks([]) + ax.set_yticks([]) + + self._remove_colorbar() + im = None + if vol is not None: + cmap = style.get("cmap", "gray") + clim = style.get("clim", None) + if clim is None: + clim = (float(np.nanmin(vol)), float(np.nanmax(vol))) + axial = vol[:, :, cz] + im = self.ax_ax.imshow(axial.T, origin="lower", cmap=cmap, vmin=clim[0], vmax=clim[1], aspect=float(res[1] / res[0])) + self.ax_ax.axhline(cy, color="lime", linewidth=0.5, alpha=0.5) + self.ax_ax.axvline(cx, color="lime", linewidth=0.5, alpha=0.5) + self.ax_ax.plot(cx, cy, "r+", markersize=8, markeredgewidth=1.5) + self.ax_ax.set_title(f"Axial Z={cz}", color="white", fontsize=8) + if mask_3d is not None and cz < mask_3d.shape[2]: + try: + self.ax_ax.contour(mask_3d[:, :, cz].astype(float).T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") + except Exception: + pass + + coronal = vol[:, cy, :] + self.ax_cor.imshow(coronal.T, origin="lower", cmap=cmap, vmin=clim[0], vmax=clim[1], aspect=float(res[2] / res[0])) + self.ax_cor.axhline(cz, color="lime", linewidth=0.5, alpha=0.5) + self.ax_cor.axvline(cx, color="lime", linewidth=0.5, alpha=0.5) + self.ax_cor.plot(cx, cz, "r+", markersize=8, markeredgewidth=1.5) + self.ax_cor.set_title(f"Coronal Y={cy}", color="white", fontsize=8) + if mask_3d is not None and cy < mask_3d.shape[1]: + try: + self.ax_cor.contour(mask_3d[:, cy, :].astype(float).T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") + except Exception: + pass + + sagittal = vol[cx, :, :] + self.ax_sag.imshow(sagittal.T, origin="lower", cmap=cmap, vmin=clim[0], vmax=clim[1], aspect=float(res[2] / res[1])) + self.ax_sag.axhline(cz, color="lime", linewidth=0.5, alpha=0.5) + self.ax_sag.axvline(cy, color="lime", linewidth=0.5, alpha=0.5) + self.ax_sag.plot(cy, cz, "r+", markersize=8, markeredgewidth=1.5) + self.ax_sag.set_title(f"Sagittal X={cx}", color="white", fontsize=8) + if mask_3d is not None and cx < mask_3d.shape[0]: + try: + self.ax_sag.contour(mask_3d[cx, :, :].astype(float).T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") + except Exception: + pass + + if self.combo_content.currentIndex() in (6, 7): + self._scalar_cbar = self.fig.colorbar(im, ax=[self.ax_ax, self.ax_cor, self.ax_sag], fraction=0.025, pad=0.01) + self._scalar_cbar.ax.tick_params(labelsize=6, colors="white") + self._scalar_cbar.set_label(title, color="white", fontsize=7) + try: + self._scalar_cbar.outline.set_edgecolor("white") + except Exception: + pass + + self._draw_plane_flow(t) + self._update_value_label(vol, title) + self._update_plane_metric_label() + self.fig.subplots_adjust(left=0.03, right=0.96, top=0.96, bottom=0.03, wspace=0.14, hspace=0.24) + self.canvas.draw_idle() + + def _resample_oblique(self, volume_3d, center_vox, normal, half_size=30): + normal = np.asarray(normal, dtype=float) + normal = normal / (np.linalg.norm(normal) + 1e-12) + up_hint = np.array([0.0, 1.0, 0.0]) if abs(normal[2]) > max(abs(normal[0]), abs(normal[1])) else np.array([0.0, 0.0, 1.0]) + u = np.cross(normal, up_hint) + u = u / (np.linalg.norm(u) + 1e-12) + v = np.cross(normal, u) + v = v / (np.linalg.norm(v) + 1e-12) + ii = np.arange(-half_size, half_size + 1, dtype=float) + jj = np.arange(-half_size, half_size + 1, dtype=float) + gi, gj = np.meshgrid(ii, jj, indexing="ij") + coords = center_vox.reshape(1, 1, 3) + gi[..., None] * u.reshape(1, 1, 3) + gj[..., None] * v.reshape(1, 1, 3) + sampled = map_coordinates(volume_3d, [coords[..., 0].ravel(), coords[..., 1].ravel(), coords[..., 2].ravel()], order=1, mode="constant", cval=0.0) + return sampled.reshape(len(ii), len(jj)) + + def _draw_plane_flow(self, t): + self.ax_plane.clear() + self.ax_plane.set_facecolor("black") + self.ax_plane.set_xticks([]) + self.ax_plane.set_yticks([]) + ws = self.workspace + if self._selected_plane_idx is None or self._selected_plane_idx >= len(ws.planes): + self.ax_plane.set_title("Plane Through-Plane Velocity (select a plane)", color="white", fontsize=8) + return + if ws.flow_raw is None: + self.ax_plane.set_title("Plane Through-Plane Velocity (no flow data)", color="white", fontsize=8) + return + plane = ws.planes[self._selected_plane_idx] + res = self._get_resolution() + center_vox = np.asarray(plane.center, dtype=float) / (res + 1e-12) + normal = np.asarray(plane.normal, dtype=float) + normal = normal / (np.linalg.norm(normal) + 1e-12) + flow_t = ws.flow_raw[..., t, :] + shape = self._get_volume_shape() + half_size = max(10, min(shape) // 2) + plane_key = ( + int(self._selected_plane_idx), + int(t), + tuple(np.round(center_vox, 4).tolist()), + tuple(np.round(normal, 6).tolist()), + int(half_size), + id(ws.flow_raw), + ) + def _build_plane_flow(): + proj = flow_t[..., 0] * normal[0] + flow_t[..., 1] * normal[1] + flow_t[..., 2] * normal[2] + return self._resample_oblique(proj, center_vox, normal, half_size=half_size) + sl = self._cached("plane_flow", plane_key, _build_plane_flow) + vmax = max(abs(np.nanmin(sl)), abs(np.nanmax(sl)), 1e-6) + self.ax_plane.imshow(sl.T, origin="lower", cmap="RdBu_r", vmin=-vmax, vmax=vmax, aspect=1.0) + self.ax_plane.plot(half_size, half_size, "r+", markersize=10, markeredgewidth=2) + mask_3d = self._get_mask_3d() + if mask_3d is not None: + mask_key = ( + int(self._selected_plane_idx), + tuple(np.round(center_vox, 4).tolist()), + tuple(np.round(normal, 6).tolist()), + int(half_size), + id(mask_3d), + ) + m_sl = self._cached( + "plane_mask", + mask_key, + lambda: self._resample_oblique(mask_3d.astype(float), center_vox, normal, half_size=half_size), + ) + try: + self.ax_plane.contour(m_sl.T, levels=[0.5], colors="cyan", linewidths=0.5, origin="lower") + except Exception: + pass + metrics = plane.metrics or {} + txt = f"Plane {self._selected_plane_idx} Through-Plane Velocity [{-vmax:.2f}, {vmax:.2f}] cm/s" + if metrics: + fr = metrics.get("flowrate_mL_s", []) + ar = metrics.get("area_mm2", []) + flow_txt = float(fr[t]) if t < len(fr) else 0.0 + area_txt = float(ar[t]) if t < len(ar) else 0.0 + txt += f"\nFlow Rate={flow_txt:.4g} mL/s Area={area_txt:.4g} mm²" + self.ax_plane.set_title(txt, color="white", fontsize=7) + + def reset_state(self): + self._selected_plane_idx = None + self._remove_colorbar() + self._cache.clear() + self.label_value.setText("Voxel: - Value: -") + self.label_plane_metric.setText("Plane metrics: -") + for ax in [self.ax_ax, self.ax_cor, self.ax_sag, self.ax_plane]: + ax.clear() + ax.set_facecolor("black") + ax.set_xticks([]) + ax.set_yticks([]) + self.canvas.draw_idle() diff --git a/autoflow/ui/viewer.py b/autoflow/ui/viewer.py new file mode 100755 index 0000000..2b3e036 --- /dev/null +++ b/autoflow/ui/viewer.py @@ -0,0 +1,866 @@ +import numpy as np +import pyvista as pv + +from ..core.models import ObjectKind +from ..algorithms import ( + build_multilabel_surface_t, + build_surface_from_mask3d, + graph_to_polydata, + generate_seed_points, + generate_streamlines_at_t, + generate_streamlines_from_plane_at_t, + create_uniform_grid, +) + + +def _parse_indexed_data_key(data_key, prefix): + token = f"{prefix}_" + if not isinstance(data_key, str) or not data_key.startswith(token): + return None + suffix = data_key[len(token):] + if suffix.isdigit(): + return int(suffix) + return None + + +def _path_polydata(path, origin): + pts = np.asarray(path, dtype=float) + if len(pts) == 0: + return None + poly = pv.PolyData(pts + np.asarray(origin, dtype=float).reshape(1, 3)) + if len(pts) >= 2: + cells = np.empty((len(pts) - 1, 3), dtype=np.int64) + cells[:, 0] = 2 + cells[:, 1] = np.arange(len(pts) - 1) + cells[:, 2] = np.arange(1, len(pts)) + poly.lines = cells.ravel() + return poly + + +class SceneController: + def __init__(self, plotter, workspace, logger): + self.plotter = plotter + self.workspace = workspace + self.logger = logger + self._axes_shown = True + self._mesh_cache = {} + self._tracked_actors = {} + self._saved_camera = None + self._playback_active = False + self._highlight_plane_uid = None + self._highlight_plane_actor = None + self._highlight_path_uid = None + self._highlight_path_actor = None + self._context_path_actors = [] + self._highlight_fork_actor = None + self._plane_pick_obs_id = None + self._path_pick_obs_id = None + self._plane_pick_callback = None + self._path_pick_callback = None + self._shared_pick_obs_id = None + + def initialize(self): + self.plotter.set_background("white") + self.plotter.add_axes(line_width=2) + self.plotter.reset_camera() + + def reset_scene(self): + try: + self.plotter.clear() + except Exception: + try: + self.plotter.renderer.RemoveAllViewProps() + except Exception: + pass + for obj in self.workspace.scene_objects.values(): + obj.actor = None + obj.label_actor = None + self._tracked_actors.clear() + self._mesh_cache.clear() + self._remove_plane_highlight() + self._remove_path_highlight() + self.initialize() + + def invalidate_cache(self, prefix=None): + if prefix is None: + self._mesh_cache.clear() + else: + self._mesh_cache = {k: v for k, v in self._mesh_cache.items() if not k[0].startswith(prefix)} + + def set_background(self, color): + self.plotter.set_background(color) + self.render_all() + + def toggle_axes(self): + self._axes_shown = not self._axes_shown + self.reset_scene() + if not self._axes_shown: + try: + self.plotter.hide_axes() + except Exception: + pass + self.render_all() + + def reset_camera(self): + try: + self.plotter.reset_camera() + self.plotter.render() + except Exception: + pass + + def save_camera(self): + try: + self._saved_camera = self.plotter.camera_position + except Exception: + self._saved_camera = None + + def restore_camera(self): + if self._saved_camera is not None: + try: + self.plotter.camera_position = self._saved_camera + except Exception: + pass + + def set_playback_active(self, active): + self._playback_active = active + if active: + self.save_camera() + + def sync_from_workspace(self): + current_uids = set(self.workspace.scene_objects.keys()) + stale = set(self._tracked_actors.keys()) - current_uids + for uid in stale: + actor = self._tracked_actors.pop(uid, None) + if actor is not None: + try: + self.plotter.remove_actor(actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(actor) + except Exception: + pass + self.render_all() + + def remove_object(self, uid): + obj = self.workspace.scene_objects.get(uid) + if obj is not None: + self._remove_actor(obj) + del self.workspace.scene_objects[uid] + actor = self._tracked_actors.pop(uid, None) + if actor is not None: + try: + self.plotter.remove_actor(actor) + except Exception: + pass + if self._highlight_plane_uid == uid: + self._remove_plane_highlight() + if self._highlight_path_uid == uid: + self._remove_path_highlight() + + def render_all(self): + for obj in self.workspace.scene_objects.values(): + self._render_object(obj) + try: + self.plotter.render() + except Exception: + pass + + def update_time(self, t): + self.workspace.current_t = int(t) + cam_before = None + if self._playback_active: + try: + cam_before = self.plotter.camera_position + except Exception: + cam_before = None + for obj in self.workspace.scene_objects.values(): + if obj.dynamic: + self.readd_object(obj) + if self._playback_active and cam_before is not None: + try: + self.plotter.camera_position = cam_before + except Exception: + pass + try: + self.plotter.render() + except Exception: + pass + + def rebuild_dynamic(self): + for obj in self.workspace.scene_objects.values(): + if obj.dynamic: + self.readd_object(obj) + + def readd_object(self, obj): + self._remove_actor(obj) + self._render_object(obj) + + def apply_object_properties(self, obj): + if obj.actor is None: + self._render_object(obj) + return + try: + obj.actor.SetVisibility(1 if obj.visible else 0) + except Exception: + pass + try: + prop = obj.actor.GetProperty() + prop.SetOpacity(float(obj.opacity)) + prop.SetLineWidth(float(obj.line_width)) + prop.SetPointSize(float(obj.point_size)) + except Exception: + pass + if obj.visible: + self.readd_object(obj) + return + try: + self.plotter.render() + except Exception: + pass + + def highlight_plane(self, uid): + self._remove_plane_highlight() + self._highlight_plane_uid = uid + if uid is None: + try: + self.plotter.render() + except Exception: + pass + return + obj = self.workspace.scene_objects.get(uid) + if obj is None or obj.kind != ObjectKind.PLANE: + self._highlight_plane_uid = None + return + data = self._build_dataset(obj.data_key) + if data is None: + return + try: + self._highlight_plane_actor = self.plotter.add_mesh( + data, color="magenta", opacity=0.9, line_width=4, + style="wireframe", name="__plane_highlight__") + self._promote_overlay_actor(self._highlight_plane_actor) + except Exception: + self._highlight_plane_actor = None + try: + self.plotter.render() + except Exception: + pass + + def highlight_path(self, uid): + self._remove_path_highlight() + self._highlight_path_uid = uid + if uid is None: + try: + self.plotter.render() + except Exception: + pass + return + obj = self.workspace.scene_objects.get(uid) + if obj is None or obj.kind != ObjectKind.BRANCH: + self._highlight_path_uid = None + return + data = self._build_dataset(obj.data_key) + if data is None: + return + try: + self._highlight_path_actor = self.plotter.add_mesh( + data, color="magenta", opacity=1.0, line_width=8, + render_lines_as_tubes=True, + name="__path_highlight__") + self._promote_overlay_actor(self._highlight_path_actor) + except Exception: + self._highlight_path_actor = None + try: + self.plotter.render() + except Exception: + pass + + def show_forks_for_path(self, path_idx): + self._clear_fork_and_context_actors() + if int(path_idx) < 0: + try: + self.plotter.render() + except Exception: + pass + return + org = np.asarray(self.workspace.origin, dtype=float).reshape(3) + pts = [] + incoming_ids = set() + outgoing_ids = set() + if 0 <= int(path_idx) < len(self.workspace.path_info): + info = self.workspace.path_info[int(path_idx)] + incoming_ids.update(int(x) for x in info.get("incoming_path_ids", [])) + outgoing_ids.update(int(x) for x in info.get("outgoing_path_ids", [])) + for fork in self.workspace.forks: + if int(path_idx) in fork.get("left", []) or int(path_idx) in fork.get("right", []): + pts.append(np.asarray(fork.get("crosspoint", [0.0, 0.0, 0.0]), dtype=float) + org) + incoming_ids.update(int(x) for x in fork.get("left", []) if int(x) != int(path_idx)) + outgoing_ids.update(int(x) for x in fork.get("right", []) if int(x) != int(path_idx)) + incoming_ids.discard(int(path_idx)) + outgoing_ids.discard(int(path_idx)) + for pid, color in [(sorted(incoming_ids), "deepskyblue"), (sorted(outgoing_ids), "orange")]: + for idx in pid: + if not (0 <= int(idx) < len(self.workspace.centerline_paths_smooth)): + continue + poly = _path_polydata(self.workspace.centerline_paths_smooth[int(idx)], org) + if poly is None: + continue + try: + actor = self.plotter.add_mesh( + poly, color=color, opacity=1.0, line_width=8, + render_lines_as_tubes=True, + name=f"__path_context_{color}_{int(idx)}__") + self._promote_overlay_actor(actor) + self._context_path_actors.append(actor) + except Exception: + pass + if pts: + try: + poly = pv.PolyData(np.asarray(pts, dtype=float).reshape(-1, 3)) + self._highlight_fork_actor = self.plotter.add_mesh( + poly, color="magenta", point_size=22, render_points_as_spheres=True, + name="__fork_highlight__") + self._promote_overlay_actor(self._highlight_fork_actor) + except Exception: + self._highlight_fork_actor = None + try: + self.plotter.render() + except Exception: + pass + + def _remove_plane_highlight(self): + if self._highlight_plane_actor is not None: + try: + self.plotter.remove_actor(self._highlight_plane_actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(self._highlight_plane_actor) + except Exception: + pass + self._highlight_plane_actor = None + self._highlight_plane_uid = None + + def _remove_path_highlight(self): + if self._highlight_path_actor is not None: + try: + self.plotter.remove_actor(self._highlight_path_actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(self._highlight_path_actor) + except Exception: + pass + self._highlight_path_actor = None + self._highlight_path_uid = None + self._clear_fork_and_context_actors() + + def _clear_fork_and_context_actors(self): + if self._highlight_fork_actor is not None: + try: + self.plotter.remove_actor(self._highlight_fork_actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(self._highlight_fork_actor) + except Exception: + pass + self._highlight_fork_actor = None + for actor in list(self._context_path_actors): + if actor is not None: + try: + self.plotter.remove_actor(actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(actor) + except Exception: + pass + self._context_path_actors = [] + try: + self.plotter.remove_actor("__fork_highlight__") + except Exception: + pass + try: + renderer = self.plotter.renderer + actors_to_remove = [] + it = renderer.GetActors() + it.InitTraversal() + for _ in range(it.GetNumberOfItems()): + a = it.GetNextItem() + if a is not None: + try: + name = a.GetObjectName() if hasattr(a, "GetObjectName") else "" + if name and ("__path_context_" in name or "__fork_highlight__" in name): + actors_to_remove.append(a) + except Exception: + pass + for a in actors_to_remove: + try: + renderer.RemoveActor(a) + except Exception: + pass + except Exception: + pass + + + def _promote_overlay_actor(self, actor): + if actor is None: + return + try: + actor.PickableOff() + except Exception: + pass + try: + prop = actor.GetProperty() + prop.SetLighting(False) + except Exception: + pass + + def refresh_plane_labels(self): + pass + + def remove_all_plane_labels(self): + pass + + def _remove_actor(self, obj): + if obj.actor is not None: + try: + self.plotter.remove_actor(obj.actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(obj.actor) + except Exception: + pass + if getattr(obj, "label_actor", None) is not None: + try: + self.plotter.remove_actor(obj.label_actor) + except Exception: + try: + self.plotter.renderer.RemoveActor(obj.label_actor) + except Exception: + pass + self._tracked_actors.pop(obj.uid, None) + obj.actor = None + obj.label_actor = None + + def _render_object(self, obj): + if not obj.visible: + if obj.actor is not None: + try: + obj.actor.SetVisibility(0) + except Exception: + pass + return + data = self._build_dataset(obj.data_key) + if data is None: + self._remove_actor(obj) + return + if obj.actor is not None: + self._remove_actor(obj) + kwargs = self._mesh_kwargs(obj, data) + try: + if obj.tube_radius > 0 and hasattr(data, "tube") and obj.kind.value in ("Graph", "Branch", "Flow", "Metric", "Skeleton"): + data_show = data.tube(radius=float(obj.tube_radius)) + else: + data_show = data + obj.actor = self.plotter.add_mesh(data_show, name=obj.uid, **kwargs) + self._tracked_actors[obj.uid] = obj.actor + self._apply_basic_properties_only(obj) + except Exception as e: + self.logger(f"Render failed: {obj.name}: {type(e).__name__}: {e}") + + def _apply_basic_properties_only(self, obj): + try: + obj.actor.SetVisibility(1 if obj.visible else 0) + except Exception: + pass + try: + prop = obj.actor.GetProperty() + prop.SetOpacity(float(obj.opacity)) + prop.SetLineWidth(float(obj.line_width)) + prop.SetPointSize(float(obj.point_size)) + except Exception: + pass + + def _mesh_kwargs(self, obj, data): + kw = {"opacity": float(obj.opacity), "show_scalar_bar": bool(obj.show_scalar_bar)} + use_scalars = False + if obj.scalars: + if hasattr(data, "point_data") and obj.scalars in data.point_data: + use_scalars = True + if hasattr(data, "cell_data") and obj.scalars in data.cell_data: + use_scalars = True + if use_scalars: + kw["scalars"] = obj.scalars + kw["cmap"] = obj.cmap + if obj.clim: + kw["clim"] = obj.clim + if obj.scalar_bar_title: + kw["scalar_bar_args"] = { + "title": obj.scalar_bar_title, + "vertical": True, + "title_font_size": 14, + "label_font_size": 12, + "n_labels": 5, + "fmt": "%.3g", + } + else: + kw["color"] = obj.color + if obj.kind.value in ("Skeleton", "Aux"): + kw["render_points_as_spheres"] = True + kw["point_size"] = obj.point_size + if obj.kind.value in ("Graph", "Branch", "Flow", "Aux"): + kw["line_width"] = obj.line_width + kw["render_lines_as_tubes"] = True + if obj.kind == ObjectKind.PLANE: + kw["show_edges"] = True + kw["edge_color"] = "black" + kw["line_width"] = max(float(obj.line_width), 2.0) + return kw + + def _build_dataset(self, data_key): + ws = self.workspace + t = ws.current_t + sp = ws.resolution + org = ws.origin + + if data_key == "segmask_raw_surface": + if ws.segmask_raw is None: + return None + return self._cached(data_key, t, lambda: build_multilabel_surface_t(ws.segmask_raw, t, sp, org)) + + if data_key == "segmask_pre_surface": + if ws.segmask_labels is None: + return None + return self._cached(data_key, t, lambda: build_multilabel_surface_t(ws.segmask_labels, t, sp, org)) + + if data_key == "segmask_3d_surface": + if ws.segmask_3d is None: + return None + return self._cached(data_key, 0, lambda: build_surface_from_mask3d(ws.segmask_3d, sp, org, smooth_iter=1000)) + + if data_key == "skeleton_points": + if ws.skeleton_points is None or len(ws.skeleton_points) == 0: + return None + return pv.PolyData(np.asarray(ws.skeleton_points, dtype=float) + np.asarray(org, dtype=float).reshape(1, 3)) + + if data_key == "skeleton_mask_surface": + if ws.skeleton_mask is None: + return None + return self._cached(data_key, 0, lambda: build_surface_from_mask3d(ws.skeleton_mask, sp, org, smooth_iter=1000)) + + if data_key == "graph_lines": + if ws.graph is None or len(ws.graph.points) == 0: + return None + return graph_to_polydata(np.asarray(ws.graph.points) + np.asarray(org).reshape(1, 3), ws.graph.edges) + + if data_key == "streamlines_live": + return self._get_streamline_mesh(t) + + if data_key == "plane_streamlines_live": + return self._get_plane_streamline_mesh(t) + + if data_key == "wss_surface_live": + if not ws.derived.wss_surfaces: + return None + return ws.derived.wss_surfaces[min(max(0, t), len(ws.derived.wss_surfaces) - 1)] + + if data_key == "tke_volume": + if ws.derived.tke_array is not None: + def _build_tke_t(): + arr = np.asarray(ws.derived.tke_array, dtype=np.float32) + if arr.ndim == 4: + vol_t = arr[..., min(max(0, int(t)), arr.shape[3] - 1)] + else: + vol_t = arr + if ws.segmask_binary is not None: + if ws.segmask_binary.ndim == 4: + mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] + else: + mask_t = ws.segmask_binary + elif ws.segmask_3d is not None: + mask_t = ws.segmask_3d + else: + mask_t = np.ones(vol_t.shape, dtype=bool) + vol_t = vol_t * np.asarray(mask_t, dtype=np.float32) + + tke_grid = create_uniform_grid(vol_t, sp, origin=org, name="TKE") + mask_grid = create_uniform_grid(np.asarray(mask_t, dtype=np.float32), sp, origin=org, name="mask") + mask_mesh = mask_grid.threshold(0.1, scalars="mask") + if mask_mesh is None or mask_mesh.n_cells == 0: + return None + return mask_mesh.sample(tke_grid) + return self._cached(data_key, t, _build_tke_t) + return ws.derived.tke_volume + + if data_key == "derived_streamlines_live": + if not ws.derived.streamlines: + return None + return ws.derived.streamlines[min(max(0, t), len(ws.derived.streamlines) - 1)] + + idx = _parse_indexed_data_key(data_key, "smooth_path") + if idx is not None: + if idx >= len(ws.centerline_paths_smooth): + return None + path = np.asarray(ws.centerline_paths_smooth[idx], dtype=float) + if len(path) == 0: + return None + return _path_polydata(path, org) + + idx = _parse_indexed_data_key(data_key, "path_arrow") + if idx is not None: + if idx >= len(ws.centerline_paths_smooth): + return None + path = np.asarray(ws.centerline_paths_smooth[idx], dtype=float) + if len(path) < 2: + return None + org_r = np.asarray(org, dtype=float).reshape(3) + seglens = np.linalg.norm(np.diff(path, axis=0), axis=1) + total = float(np.sum(seglens)) + if total < 1e-6: + return None + overall = path[-1] - path[0] + n = np.linalg.norm(overall) + if n < 1e-12: + return None + overall = overall / n + mid = 0.5 * (path[0] + path[-1]) + arrow_len = max(2.0, total * 0.45) + shaft_r = max(0.25, arrow_len * 0.05) + tip_r = max(0.6, arrow_len * 0.12) + tip_l = max(2.0, arrow_len * 0.25) + start = mid + org_r - overall * (arrow_len * 0.5) + return pv.Arrow( + start=start, + direction=overall * arrow_len, + shaft_radius=shaft_r, + tip_radius=tip_r, + tip_length=tip_l, + ) + + idx = _parse_indexed_data_key(data_key, "path") + if idx is not None: + if idx >= len(ws.centerline_paths): + return None + path = np.asarray(ws.centerline_paths[idx], dtype=float) + if len(path) == 0: + return None + return _path_polydata(path, org) + + if data_key == "fork_markers": + pts = [np.asarray(f.get("crosspoint", [0.0, 0.0, 0.0]), dtype=float) + np.asarray(org, dtype=float).reshape(3) for f in ws.forks] + if not pts: + return None + return pv.PolyData(np.asarray(pts, dtype=float).reshape(-1, 3)) + + idx = _parse_indexed_data_key(data_key, "plane") + if idx is not None: + if idx >= len(ws.planes): + return None + p = ws.planes[idx] + return pv.Plane(center=np.asarray(p.center) + np.asarray(org), direction=np.asarray(p.normal), i_size=25, j_size=25) + + return None + + def _cached(self, data_key, t, builder): + key = (data_key, t) + if key in self._mesh_cache: + return self._mesh_cache[key] + mesh = builder() + if mesh is not None: + self._mesh_cache[key] = mesh + return mesh + + def _get_streamline_mesh(self, t): + ws = self.workspace + if not ws.streamline_active: + return None + if t in ws.streamline_cache: + return ws.streamline_cache[t] + if ws.flow_raw is None or ws.segmask_binary is None: + return None + p = ws.streamline_params + mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] + sl = generate_streamlines_at_t( + ws.flow_raw, t, ws.streamline_seeds, ws.resolution, ws.origin, + mask_3d=mask_t, + max_steps=p.max_steps, + terminal_speed=p.terminal_speed, + seed_ratio=p.seed_ratio, + min_seeds=p.min_seeds, + rng_seed=p.rng_seed, + ) + ws.streamline_cache[t] = sl + return sl + + def _get_plane_streamline_mesh(self, t): + ws = self.workspace + if not ws.plane_streamline_active: + return None + if t in ws.plane_streamline_cache: + return ws.plane_streamline_cache[t] + if ws.flow_raw is None or ws.segmask_binary is None: + return None + pidx = ws.plane_streamline_plane_idx + if pidx < 0 or pidx >= len(ws.planes): + return None + plane = ws.planes[pidx] + p = ws.streamline_params + mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] + sl = generate_streamlines_from_plane_at_t( + ws.flow_raw, t, plane, ws.resolution, ws.origin, + mask_3d=mask_t, + max_steps=p.max_steps, + terminal_speed=p.terminal_speed, + seed_ratio=p.seed_ratio, + min_seeds=p.min_seeds, + rng_seed=p.rng_seed, + branch_labels_3d=ws.branch_labels, + ) + ws.plane_streamline_cache[t] = sl + return sl + + def trigger_streamlines(self): + ws = self.workspace + if ws.flow_raw is None or ws.segmask_3d is None: + self.logger("Cannot generate streamlines: need flow + segmask_3d") + return + ws.streamline_seeds = generate_seed_points( + ws.segmask_3d, + ws.resolution, + ws.origin, + ratio=ws.streamline_params.seed_ratio, + rng_seed=ws.streamline_params.rng_seed, + min_seeds=ws.streamline_params.min_seeds, + ) + ws.streamline_cache.clear() + ws.streamline_active = True + p = ws.streamline_params + self.logger(f"Streamlines enabled: seed_ratio={p.seed_ratio} max_steps={p.max_steps} min_seeds={p.min_seeds} terminal_speed={p.terminal_speed} rng_seed={p.rng_seed}") + ws.remove_object_by_data_key("streamlines_live") + ws.add_object(name="streamlines", kind=ObjectKind.FLOW, + data_key="streamlines_live", visible=True, opacity=1.0, + scalars="Velocity", cmap="turbo", dynamic=True, + show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") + self.sync_from_workspace() + + def trigger_plane_streamlines(self, plane_idx): + ws = self.workspace + if ws.flow_raw is None or ws.segmask_3d is None: + self.logger("Cannot generate plane streamlines: need flow + segmask_3d") + return + if plane_idx < 0 or plane_idx >= len(ws.planes): + self.logger(f"Invalid plane index: {plane_idx}") + return + ws.plane_streamline_cache.clear() + ws.plane_streamline_active = True + ws.plane_streamline_plane_idx = plane_idx + p = ws.streamline_params + self.logger(f"Plane streamlines enabled from plane {plane_idx}: seed_ratio={p.seed_ratio} min_seeds={p.min_seeds} max_steps={p.max_steps} terminal_speed={p.terminal_speed} rng_seed={p.rng_seed}") + ws.remove_object_by_data_key("plane_streamlines_live") + ws.add_object(name="plane_streamlines", kind=ObjectKind.FLOW, + data_key="plane_streamlines_live", visible=True, opacity=1.0, + scalars="Velocity", cmap="turbo", dynamic=True, + show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") + self.sync_from_workspace() + + def clear_streamlines(self): + self.workspace.clear_streamlines() + self.invalidate_cache("streamlines") + self.sync_from_workspace() + self.logger("Streamlines cleared") + + def clear_plane_streamlines(self): + self.workspace.clear_plane_streamlines() + self.invalidate_cache("plane_streamlines") + self.sync_from_workspace() + self.logger("Plane streamlines cleared") + + def find_plane_uid_at_position(self, picked_point): + ws = self.workspace + if picked_point is None: + return None, None + picked = np.asarray(picked_point, dtype=float).reshape(3) + best_uid, best_idx, best_dist = None, None, float("inf") + org = np.asarray(ws.origin, dtype=float).reshape(3) + for uid, obj in ws.scene_objects.items(): + if obj.kind != ObjectKind.PLANE: + continue + pidx = _parse_indexed_data_key(obj.data_key, "plane") + if pidx is None: + continue + if pidx >= len(ws.planes): + continue + center = np.asarray(ws.planes[pidx].center, dtype=float) + org + d = float(np.linalg.norm(picked - center)) + if d < best_dist: + best_uid, best_idx, best_dist = uid, pidx, d + return (best_uid, best_idx) if best_dist <= 30.0 else (None, None) + + def find_path_uid_at_position(self, picked_point): + ws = self.workspace + if picked_point is None: + return None, None + picked = np.asarray(picked_point, dtype=float).reshape(3) + best_uid, best_idx, best_dist = None, None, float("inf") + org = np.asarray(ws.origin, dtype=float).reshape(3) + for uid, obj in ws.scene_objects.items(): + if obj.kind != ObjectKind.BRANCH: + continue + if not obj.data_key.startswith("smooth_path_"): + continue + try: + pidx = int(obj.data_key.split("_")[2]) + except Exception: + continue + if pidx >= len(ws.centerline_paths_smooth): + continue + path = np.asarray(ws.centerline_paths_smooth[pidx], dtype=float) + org.reshape(1, 3) + if len(path) == 0: + continue + d = float(np.min(np.linalg.norm(path - picked.reshape(1, 3), axis=1))) + if d < best_dist: + best_uid, best_idx, best_dist = uid, pidx, d + return (best_uid, best_idx) if best_dist <= 15.0 else (None, None) + + def _ensure_shared_right_click_picking(self): + if self._shared_pick_obs_id is not None: + return + try: + iren = self.plotter.iren.interactor + except Exception: + return + picker = pv._vtk.vtkCellPicker() + picker.SetTolerance(0.005) + + def _on_right_click(obj, ev): + try: + x, y = iren.GetEventPosition() + except Exception: + return + ren = self.plotter.renderer + ok = picker.Pick(float(x), float(y), 0.0, ren) + pos = picker.GetPickPosition() if ok else None + plane_uid, plane_idx = self.find_plane_uid_at_position(pos) if pos is not None else (None, None) + if plane_uid is not None and plane_idx is not None: + if self._plane_pick_callback is not None: + self._plane_pick_callback(plane_uid, plane_idx) + return + path_uid, path_idx = self.find_path_uid_at_position(pos) if pos is not None else (None, None) + if path_uid is not None and path_idx is not None: + if self._path_pick_callback is not None: + self._path_pick_callback(path_uid, path_idx) + return + if self._plane_pick_callback is not None: + self._plane_pick_callback(None, None) + if self._path_pick_callback is not None: + self._path_pick_callback(None, None) + + self._shared_pick_obs_id = iren.AddObserver("RightButtonPressEvent", _on_right_click) + self._plane_pick_obs_id = self._shared_pick_obs_id + self._path_pick_obs_id = self._shared_pick_obs_id + + def enable_plane_picking(self, callback): + self._plane_pick_callback = callback + self._ensure_shared_right_click_picking() + + def enable_path_picking(self, callback): + self._path_pick_callback = callback + self._ensure_shared_right_click_picking() diff --git a/autoflow/utils.py b/autoflow/utils.py index d2f5990..8fc23f5 100755 --- a/autoflow/utils.py +++ b/autoflow/utils.py @@ -1,1383 +1,35 @@ -import sys -import os -import glob -import json -import traceback -import copy - -import numpy as np -import pyvista as pv -import imageio.v2 as imageio -from PIL import Image - -from .models import Workspace, StepId, PlaneData -from .pipeline import PipelineEngine -from .algorithms import ( - load_metrics_as_table, - compute_derived_metrics, - create_uniform_grid, - generate_seed_points, - generate_streamlines_at_t, +"""Compatibility re-exports for legacy utility imports.""" + +from .plane_io import load_plane_positions, project_planes_to_workspace, resolve_reuse_plane_file, save_plane_positions +from .processing import build_base_workspace, collect_h5_files, process_single, run_batch +from .rendering import ( + CAMERA_PRESETS, + WINDOW_SIZE, + extract_frame, + render_plane_rotation_video, + render_streamlines_video, + render_tke_video, + render_wss_video, ) - -WINDOW_SIZE = (1600, 1200) -_OFFSCREEN_BOOTSTRAPPED = False - - -def _offscreen_mode(): - mode = str(os.environ.get("AUTOFLOW_OFFSCREEN_MODE", "local")).strip().lower() - if mode in {"display", "x11", "onscreen"}: - return "display" - if mode in {"local", "headless", "xvfb"}: - return "local" - - display = str(os.environ.get("DISPLAY", "")).strip().lower() - if not display: - return "local" - if os.environ.get("SSH_CONNECTION") or os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"): - if display.startswith("localhost:") or display.startswith("localhost/unix:") or display.startswith("127.0.0.1:"): - return "local" - return "display" - - -def load_metrics_from_output(out_dir): - metrics_path = os.path.join(out_dir, "plane_metrics.json") - qc_path = os.path.join(out_dir, "plane_qc.json") - if not os.path.isfile(metrics_path): - return None, None, None - qc_p = qc_path if os.path.isfile(qc_path) else None - table_rows, raw_metrics, qc_data = load_metrics_as_table(metrics_path, qc_p) - return table_rows, raw_metrics, qc_data - - -def print_metrics_summary(table_rows): - if not table_rows: - print(" No metrics to summarize.") - return - print(f" {'Plane':>6} {'Path':>5} {'Net Flow(mL/beat)':>18} " - f"{'Peak Velocity(cm/s)':>20} {'Mean Velocity(cm/s)':>20} " - f"{'Reflux':>7} {'IC':>6}") - print(f" {'-'*6} {'-'*5} {'-'*18} {'-'*20} {'-'*20} " - f"{'-'*7} {'-'*6}") - for row in table_rows: - pidx = row.get("plane_index", "?") - path = row.get("path_index", "?") - nf = row.get("netflow_mL_beat", 0.0) - pv_ = row.get("peakv_cm_s", 0.0) - if "meanv_signed_cm_s" in row: - mv = row["meanv_signed_cm_s"] - else: - mv = row.get("meanv_cm_s", 0.0) - refl = row.get("reflux_fraction", 0.0) - ic = row.get("path_ic", 1.0) - print(f" {pidx:>6} {path:>5} {nf:>18.4f} " - f"{pv_:>20.3f} {mv:>20.3f} {refl:>7.3f} {ic:>6.3f}") - - -def _normalize(v): - arr = np.asarray(v, dtype=float).reshape(3) - n = np.linalg.norm(arr) - if n <= 1e-12: - return np.array([1.0, 0.0, 0.0], dtype=float) - return arr / n - - -def _path_cumdist(path): - pts = np.asarray(path, dtype=float).reshape(-1, 3) - if len(pts) <= 1: - return np.zeros(len(pts), dtype=float) - return np.concatenate([[0.0], np.cumsum(np.linalg.norm(np.diff(pts, axis=0), axis=1))]) - - -def _path_polydata(path_world): - pts = np.asarray(path_world, dtype=float).reshape(-1, 3) - if len(pts) == 0: - return None - poly = pv.PolyData(pts) - if len(pts) >= 2: - cells = np.empty((len(pts) - 1, 3), dtype=np.int64) - cells[:, 0] = 2 - cells[:, 1] = np.arange(len(pts) - 1) - cells[:, 2] = np.arange(1, len(pts)) - poly.lines = cells.ravel() - return poly - - -def _path_label_anchor(path_world, all_paths_world=None): - pts = np.asarray(path_world, dtype=float).reshape(-1, 3) - if len(pts) == 0: - return None - mid_idx = len(pts) // 2 - anchor = pts[mid_idx].copy() - if len(pts) >= 2: - i0 = max(0, mid_idx - 1) - i1 = min(len(pts) - 1, mid_idx + 1) - tangent = pts[i1] - pts[i0] - tangent = _normalize(tangent) - ref = np.array([0.0, 0.0, 1.0], dtype=float) - if abs(np.dot(tangent, ref)) > 0.9: - ref = np.array([0.0, 1.0, 0.0], dtype=float) - normal = np.cross(tangent, ref) - normal = _normalize(normal) - if all_paths_world: - all_pts = [] - for p in all_paths_world: - arr = np.asarray(p, dtype=float).reshape(-1, 3) - if len(arr) > 0: - all_pts.append(arr) - if all_pts: - all_pts = np.concatenate(all_pts, axis=0) - bmin = all_pts.min(axis=0) - bmax = all_pts.max(axis=0) - diag = np.linalg.norm(bmax - bmin) - offset = max(2.0, 0.02 * diag) - anchor = anchor + normal * offset - return anchor - - -def _plane_mesh(center_world, normal, size): - return pv.Plane( - center=np.asarray(center_world, dtype=float).reshape(3), - direction=_normalize(normal), - i_size=float(size), - j_size=float(size), - i_resolution=1, - j_resolution=1, - ) - - -def _scalar_bar_args(title, bar_cfg=None): - cfg = { - "title": title, - "vertical": True, - "position_x": 0.86, - "position_y": 0.1, - "height": 0.8, - "width": 0.08, - "title_font_size": 18, - "label_font_size": 14, - "n_labels": 5, - "fmt": "%.3g", - } - if bar_cfg: - cfg.update(bar_cfg) - return cfg - - -def _ensure_offscreen(): - global _OFFSCREEN_BOOTSTRAPPED - if _OFFSCREEN_BOOTSTRAPPED: - return - _OFFSCREEN_BOOTSTRAPPED = True - os.environ["PYVISTA_OFF_SCREEN"] = "true" - os.environ.pop("DISPLAY", None) - try: - if hasattr(pv, "start_xvfb"): - pv.start_xvfb() - except Exception: - pass - - -def _make_plotter(window_size=WINDOW_SIZE): - _ensure_offscreen() - p = pv.Plotter(off_screen=True, window_size=window_size) - p.set_background("white") - return p - - -def _write_video(frames, out_path, fps=24): - if not frames: - return None - out_path = os.path.splitext(out_path)[0] + ".mp4" - os.makedirs(os.path.dirname(out_path), exist_ok=True) - try: - with imageio.get_writer(out_path, fps=fps, codec="libx264", macro_block_size=None) as writer: - for frame in frames: - writer.append_data(np.asarray(frame)) - return out_path - except Exception: - gif_path = os.path.splitext(out_path)[0] + ".gif" - imageio.mimsave(gif_path, [np.asarray(frame) for frame in frames], duration=1.0 / max(int(fps), 1)) - return gif_path - - -def _surface_center_radius(poly): - if poly is None or poly.n_points == 0: - return np.zeros(3, dtype=float), 100.0 - b = np.array(poly.bounds, dtype=float).reshape(3, 2) - center = b.mean(axis=1) - extent = np.maximum(b[:, 1] - b[:, 0], 1.0) - radius = float(max(np.linalg.norm(extent) * 1.2, 50.0)) - return center, radius - - -CAMERA_PRESETS = { - "iso": (35.0, 25.0), - "iso_back": (215.0, 25.0), - "right": (0.0, 0.0), - "left": (180.0, 0.0), - "anterior": (270.0, 0.0), - "posterior": (90.0, 0.0), - "superior": (0.0, 89.9), - "inferior": (0.0, -89.9), -} - - -def _resolve_view(view): - if view is None: - return CAMERA_PRESETS["iso"] - if isinstance(view, str): - if view not in CAMERA_PRESETS: - raise ValueError(f"unknown camera preset: {view}, options: {list(CAMERA_PRESETS)}") - return CAMERA_PRESETS[view] - az, el = view - return float(az), float(el) - - -def _camera_from_scene(poly, azimuth_deg=35.0, elevation_deg=25.0, distance_scale=1.0): - center, radius = _surface_center_radius(poly) - radius = radius * float(distance_scale) - az = np.deg2rad(float(azimuth_deg)) - el = np.deg2rad(float(elevation_deg)) - pos = center + np.array([ - radius * np.cos(el) * np.cos(az), - radius * np.cos(el) * np.sin(az), - radius * np.sin(el), - ], dtype=float) - if abs(elevation_deg) > 80.0: - up = (0.0, 1.0, 0.0) - else: - up = (0.0, 0.0, 1.0) - return [tuple(pos.tolist()), tuple(center.tolist()), up] - - -def _orbit_camera(poly, azimuth_deg, elevation_deg=25.0, distance_scale=1.0): - return _camera_from_scene(poly, azimuth_deg, elevation_deg, distance_scale) - - -def _camera_from_view(poly, view, distance_scale=1.0): - az, el = _resolve_view(view) - return _camera_from_scene(poly, az, el, distance_scale) - -def _time_and_azimuth(frame_idx, rotation_frames, n_time, time_repeat=1): - rotation_frames = int(max(rotation_frames, 1)) - n_time = int(max(n_time, 1)) - time_repeat = int(max(time_repeat, 1)) - - t = (frame_idx // time_repeat) % n_time - az = 360.0 * (frame_idx % rotation_frames) / rotation_frames - return t, az -def _build_union_surface(ws, smoothing_iteration=200): - if ws.segmask_binary is not None: - mask3d = np.any(np.asarray(ws.segmask_binary, dtype=bool), axis=3) - else: - mask3d = np.asarray(ws.segmask_3d, dtype=bool) - mesh = create_uniform_grid(mask3d, ws.resolution, origin=ws.origin) - mesh = mesh.threshold(0.1) - if mesh is None or mesh.n_cells == 0: - return None, None - surf = mesh.extract_surface() - if surf is not None and surf.n_points > 0 and int(smoothing_iteration) > 0: - surf = surf.smooth(n_iter=int(smoothing_iteration)) - return mesh, surf - - -def _plane_size_from_surface(surf): - if surf is None or surf.n_points == 0: - return 25.0 - b = np.array(surf.bounds, dtype=float).reshape(3, 2) - extent = b[:, 1] - b[:, 0] - return float(max(12.0, 0.12 * np.max(extent))) - - -def _make_plane_payload(ws, source_path=""): - origin = np.asarray(ws.origin, dtype=float).reshape(3) - payload = { - "source": source_path, - "origin": origin.tolist(), - "resolution": np.asarray(ws.resolution, dtype=float).reshape(3).tolist(), - "planes": [], - } - for i, plane in enumerate(ws.planes): - center_local = np.asarray(plane.center, dtype=float).reshape(3) - payload["planes"].append({ - "plane_index": int(i), - "center": center_local.tolist(), - "center_world": (center_local + origin).tolist(), - "normal": _normalize(plane.normal).tolist(), - "label": int(plane.label), - "path_index": int(plane.path_index), - "distance": float(plane.distance), - }) - return payload - - -def save_plane_positions(ws, out_path, source_path=""): - payload = _make_plane_payload(ws, source_path=source_path) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - return out_path - - -def load_plane_positions(path): - with open(path, "r", encoding="utf-8") as f: - payload = json.load(f) - if isinstance(payload, dict) and "planes" in payload: - return payload["planes"] - if isinstance(payload, list): - return payload - raise ValueError(f"Invalid plane position file: {path}") - - -def _nearest_path_info(center_world, paths_world): - best_dist = np.inf - best_path_idx = -1 - best_point_idx = -1 - best_distance = 0.0 - for path_idx, path in enumerate(paths_world): - pts = np.asarray(path, dtype=float).reshape(-1, 3) - if len(pts) == 0: - continue - d = np.linalg.norm(pts - center_world.reshape(1, 3), axis=1) - point_idx = int(np.argmin(d)) - dist = float(d[point_idx]) - if dist < best_dist: - cum = _path_cumdist(pts) - best_dist = dist - best_path_idx = int(path_idx) - best_point_idx = point_idx - best_distance = float(cum[point_idx]) if len(cum) > point_idx else 0.0 - return best_path_idx, best_point_idx, best_dist, best_distance - - -def _path_tangent(path_world, point_idx): - pts = np.asarray(path_world, dtype=float).reshape(-1, 3) - if len(pts) == 0: - return np.array([1.0, 0.0, 0.0], dtype=float) - i0 = max(0, int(point_idx) - 1) - i1 = min(len(pts) - 1, int(point_idx) + 1) - if i1 == i0: - i1 = min(len(pts) - 1, i0 + 1) - tangent = pts[i1] - pts[i0] - return _normalize(tangent) - - -def project_planes_to_workspace(plane_items, ws): - origin = np.asarray(ws.origin, dtype=float).reshape(3) - paths_local = ws.centerline_paths_smooth if len(ws.centerline_paths_smooth) > 0 else ws.centerline_paths - paths_world = [np.asarray(path, dtype=float).reshape(-1, 3) + origin.reshape(1, 3) for path in paths_local] - planes = [] - for i, item in enumerate(plane_items): - if "center_world" in item: - center_world = np.asarray(item["center_world"], dtype=float).reshape(3) - elif "center" in item: - center_world = np.asarray(item["center"], dtype=float).reshape(3) - else: - continue - normal = _normalize(item.get("normal", [1.0, 0.0, 0.0])) - path_index = int(item.get("path_index", -1)) - distance = float(item.get("distance", 0.0)) - if paths_world: - nearest_path_idx, nearest_point_idx, _, nearest_distance = _nearest_path_info(center_world, paths_world) - if nearest_path_idx >= 0: - path_index = int(nearest_path_idx) - distance = float(nearest_distance) - if np.linalg.norm(normal) <= 1e-12: - normal = _path_tangent(paths_world[path_index], nearest_point_idx) - if path_index < 0: - path_index = 0 - planes.append(PlaneData( - center=center_world - origin, - normal=_normalize(normal), - label=int(path_index) + 1, - path_index=int(path_index), - distance=float(distance), - )) - return planes - - -def resolve_reuse_plane_file(reuse_spec, case_name): - if not reuse_spec: - return "" - if os.path.isfile(reuse_spec): - return reuse_spec - if os.path.isdir(reuse_spec): - candidates = [ - os.path.join(reuse_spec, case_name, "plane_positions.json"), - os.path.join(reuse_spec, case_name, "planes.json"), - os.path.join(reuse_spec, "plane_positions.json"), - os.path.join(reuse_spec, "planes.json"), - ] - for candidate in candidates: - if os.path.isfile(candidate): - return candidate - return reuse_spec - - -def render_plane_rotation_video( - ws, - out_dir, - fps=24, - n_frames=180, - smoothing_iteration=200, - elevation_deg=0.0, - distance_scale=1.0, - add_plane_idx=False, - add_path_idx=False, -): - _, surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) - if surf is None or surf.n_points == 0: - return None - - plane_size = _plane_size_from_surface(surf) - origin = np.asarray(ws.origin, dtype=float).reshape(3) - p = _make_plotter() - p.add_mesh(surf, opacity=0.18, color="white") - - paths_world = [] - for path in ws.centerline_paths_smooth: - path_world = np.asarray(path, dtype=float) + origin.reshape(1, 3) - paths_world.append(path_world) - poly = _path_polydata(path_world) - if poly is not None and poly.n_points > 0: - p.add_mesh(poly, color="deepskyblue", line_width=5, render_lines_as_tubes=True) - - centers = [] - plane_labels = [] - for i, plane in enumerate(ws.planes): - center_world = np.asarray(plane.center, dtype=float).reshape(3) + origin - pm = _plane_mesh(center_world, plane.normal, plane_size) - p.add_mesh(pm, color="yellow", opacity=0.75, show_edges=True, edge_color="black", line_width=2) - centers.append(center_world) - plane_labels.append(f"Plane {i}") - - if add_plane_idx and len(centers) > 0: - p.add_point_labels( - np.asarray(centers, dtype=float), - plane_labels, - font_size=28, - bold=True, - text_color="black", - fill_shape=True, - shape="rounded_rect", - shape_color="yellow", - shape_opacity=0.85, - margin=5, - always_visible=True, - ) - - if add_path_idx and len(paths_world) > 0: - path_label_points = [] - path_label_texts = [] - offsets = [ - np.array([0, 0, 0]), - np.array([3, 0, 0]), - np.array([-3, 0, 0]), - np.array([0, 3, 0]), - np.array([0, -3, 0]), - ] - frac_choices = [0.25, 0.5, 0.75, 0.35, 0.65] - - for idx, path_world in enumerate(paths_world): - if path_world is None or len(path_world) == 0: - continue - n = len(path_world) - frac = frac_choices[idx % len(frac_choices)] - k = min(max(int(frac * (n - 1)), 0), n - 1) - anchor = np.asarray(path_world[k], dtype=float) + offsets[idx % len(offsets)] - path_label_points.append(anchor) - path_label_texts.append(f"Branch {idx}") - - if len(path_label_points) > 0: - p.add_point_labels( - np.asarray(path_label_points, dtype=float), - path_label_texts, - font_size=20, - bold=True, - text_color="black", - fill_shape=True, - shape="rounded_rect", - shape_color="deepskyblue", - shape_opacity=0.85, - margin=2, - always_visible=True, - ) - - frames = [] - for frame_idx in range(int(max(n_frames, 1))): - az = 360.0 * frame_idx / max(n_frames, 1) - p.camera_position = _orbit_camera(surf, az, elevation_deg, distance_scale) - p.add_text( - f"Rotating {frame_idx + 1}/{int(max(n_frames, 1))}", - position="upper_left", - font_size=14, - color="black", - name="frame_text", - ) - p.render() - frames.append(np.asarray(p.screenshot(return_img=True))) - try: - p.remove_actor("frame_text") - except Exception: - pass - p.close() - return _write_video(frames, os.path.join(out_dir, "planes_rotate.mp4"), fps=fps) - - -def render_wss_video( - ws, - out_dir, - fps=24, - smoothing_iteration=200, - view="iso", - distance_scale=1.0, - wss_clim=None, - wss_bar_cfg=None, - rotate=False, - rotation_frames=None, - elevation_deg=None, - time_repeat=1 -): - if not ws.derived.wss_surfaces: - return None - - _, context_surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) - if context_surf is None or context_surf.n_points == 0: - return None - - wss_max = 0.0 - for surf in ws.derived.wss_surfaces: - if surf is not None and surf.n_points > 0 and "wss" in surf.point_data: - vals = np.asarray(surf.point_data["wss"], dtype=float) - if vals.size: - wss_max = max(wss_max, float(np.nanmax(vals))) - wss_max = max(wss_max, 1e-6) - clim = wss_clim if wss_clim is not None else (0.0, wss_max) - - az0, el0 = _resolve_view(view) - if elevation_deg is None: - elevation_deg = el0 - - n_time = int(max(ws.time_count(), 1)) - if rotate: - base_frames = n_time * int(max(time_repeat, 1)) - if rotation_frames is not None: - total_frames = max(int(rotation_frames), base_frames) - else: - total_frames = base_frames - else: - total_frames = n_time * int(max(time_repeat, 1)) - - p = _make_plotter() - frames = [] - - for frame_idx in range(total_frames): - if rotate: - t, az = _time_and_azimuth( - frame_idx, - rotation_frames=rotation_frames if rotation_frames is not None else total_frames, - n_time=n_time, - time_repeat=time_repeat, - ) - camera_position = _orbit_camera(context_surf, az, elevation_deg, distance_scale) - else: - t = min(frame_idx, n_time - 1) - camera_position = _camera_from_view(context_surf, view, distance_scale) - - p.clear() - p.set_background("white") - p.add_mesh(context_surf, opacity=0.08, color="white") - - surf = ws.derived.wss_surfaces[min(max(0, t), len(ws.derived.wss_surfaces) - 1)] - if surf is not None and surf.n_points > 0 and "wss" in surf.point_data: - p.add_mesh( - surf, - scalars="wss", - cmap="jet", - clim=clim, - show_scalar_bar=True, - scalar_bar_args=_scalar_bar_args("WSS (Pa)", wss_bar_cfg), - ) - - if rotate: - txt = f"t={t} | rot {frame_idx + 1}/{total_frames}" - else: - txt = f"t={t}" - - p.add_text(txt, position="upper_left", font_size=14, color="black") - p.camera_position = camera_position - p.render() - frames.append(np.asarray(p.screenshot(return_img=True))) - - p.close() - suffix = "rotate" if rotate else "video" - return _write_video(frames, os.path.join(out_dir, f"wss_{suffix}.mp4"), fps=fps) - -def _streamline_speed_max(ws): - if ws.flow_raw is None: - return 1e-6 - speed = np.linalg.norm(np.asarray(ws.flow_raw, dtype=float) / 100.0, axis=-1) - if ws.segmask_binary is not None and np.any(ws.segmask_binary): - vals = speed[np.asarray(ws.segmask_binary, dtype=bool)] - if vals.size: - return max(float(np.nanmax(vals)), 1e-6) - return max(float(np.nanmax(speed)), 1e-6) - - -def _ensure_streamline_scalars(sl): - if sl is None: - return sl - if "Velocity" in sl.point_data or "Velocity" in sl.cell_data: - return sl - if "vector" in sl.point_data: - sl.point_data["Velocity"] = np.linalg.norm(np.asarray(sl.point_data["vector"], dtype=float), axis=1) - return sl - if "vector" in sl.cell_data: - sl.cell_data["Velocity"] = np.linalg.norm(np.asarray(sl.cell_data["vector"], dtype=float), axis=1) - return sl - return sl - - -def render_streamlines_video( - ws, - out_dir, - fps=24, - smoothing_iteration=200, - view="iso", - distance_scale=1.0, - streamline_clim=None, - streamline_bar_cfg=None, - rotate=False, - rotation_frames=None, - elevation_deg=None, - time_repeat=1 -): - if ws.flow_raw is None or ws.segmask_binary is None or ws.segmask_3d is None: - return None - - mesh, surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) - if mesh is None or surf is None or surf.n_points == 0: - return None - - seeds = generate_seed_points( - ws.segmask_3d, - ws.resolution, - ws.origin, - ratio=ws.streamline_params.seed_ratio, - rng_seed=ws.streamline_params.rng_seed, - min_seeds=50, - ) - - v_max = _streamline_speed_max(ws) - clim = streamline_clim if streamline_clim is not None else (0.0, v_max) - - az0, el0 = _resolve_view(view) - if elevation_deg is None: - elevation_deg = el0 - - n_time = int(max(ws.time_count(), 1)) - if rotate: - base_frames = n_time * int(max(time_repeat, 1)) - if rotation_frames is not None: - total_frames = max(int(rotation_frames), base_frames) - else: - total_frames = base_frames - else: - total_frames = n_time * int(max(time_repeat, 1)) - - p = _make_plotter() - frames = [] - - for frame_idx in range(total_frames): - if rotate: - t, az = _time_and_azimuth( - frame_idx, - rotation_frames=rotation_frames if rotation_frames is not None else total_frames, - n_time=n_time, - time_repeat=time_repeat, - ) - camera_position = _orbit_camera(surf, az, elevation_deg, distance_scale) - else: - t = min(frame_idx, n_time - 1) - camera_position = _camera_from_view(surf, view, distance_scale) - - mask_t = np.asarray( - ws.segmask_binary[..., min(max(0, t), ws.segmask_binary.shape[3] - 1)], - dtype=bool, - ) - - sl = generate_streamlines_at_t( - ws.flow_raw, - t, - seeds, - ws.resolution, - ws.origin, - mask_3d=mask_t, - max_steps=ws.streamline_params.max_steps, - terminal_speed=ws.streamline_params.terminal_speed, - seed_ratio=ws.streamline_params.seed_ratio, - min_seeds=50, - rng_seed=ws.streamline_params.rng_seed, - ) - sl = _ensure_streamline_scalars(sl) - - p.clear() - p.set_background("white") - p.add_mesh(surf, opacity=0.18, color="lightgray") - - if sl is not None and sl.n_points > 0: - p.add_mesh( - sl, - scalars="Velocity", - cmap="turbo", - clim=clim, - show_scalar_bar=True, - scalar_bar_args=_scalar_bar_args("Velocity (m/s)", streamline_bar_cfg), - render_lines_as_tubes=True, - line_width=3, - ) - - if rotate: - txt = f"t={t} | rot {frame_idx + 1}/{total_frames}" - else: - txt = f"t={t}" - - p.add_text(txt, position="upper_left", font_size=14, color="black") - p.camera_position = camera_position - p.render() - frames.append(np.asarray(p.screenshot(return_img=True))) - - p.close() - suffix = "rotate" if rotate else "video" - return _write_video(frames, os.path.join(out_dir, f"streamlines_{suffix}.mp4"), fps=fps) - -def _tke_max(ws): - if ws.derived.tke_array is not None: - return max(float(np.nanmax(np.asarray(ws.derived.tke_array, dtype=float))), 1e-6) - tke_mesh = ws.derived.tke_volume - if tke_mesh is None: - return 1e-6 - if "TKE" in tke_mesh.point_data: - return max(float(np.nanmax(np.asarray(tke_mesh.point_data["TKE"], dtype=float))), 1e-6) - if "TKE" in tke_mesh.cell_data: - return max(float(np.nanmax(np.asarray(tke_mesh.cell_data["TKE"], dtype=float))), 1e-6) - return 1e-6 - -def render_tke_video( - ws, - out_dir, - fps=24, - smoothing_iteration=200, - view="iso", - distance_scale=1.0, - tke_clim=None, - tke_bar_cfg=None, - rotate=False, - rotation_frames=None, - elevation_deg=None, - time_repeat=1 -): - if ws.derived.tke_array is None and ws.derived.tke_volume is None: - return None - - _, surf = _build_union_surface(ws, smoothing_iteration=smoothing_iteration) - if surf is None or surf.n_points == 0: - return None - - tke_max = _tke_max(ws) - clim = tke_clim if tke_clim is not None else (0.0, tke_max) - - az0, el0 = _resolve_view(view) - if elevation_deg is None: - elevation_deg = el0 - - n_time = int(max(ws.time_count(), 1)) - if rotate: - base_frames = n_time * int(max(time_repeat, 1)) - if rotation_frames is not None: - total_frames = max(int(rotation_frames), base_frames) - else: - total_frames = base_frames - else: - total_frames = n_time * int(max(time_repeat, 1)) - - p = _make_plotter() - frames = [] - - for frame_idx in range(total_frames): - if rotate: - t, az = _time_and_azimuth( - frame_idx, - rotation_frames=rotation_frames if rotation_frames is not None else total_frames, - n_time=n_time, - time_repeat=time_repeat, - ) - camera_position = _orbit_camera(surf, az, elevation_deg, distance_scale) - else: - t = min(frame_idx, n_time - 1) - camera_position = _camera_from_view(surf, view, distance_scale) - - p.clear() - p.set_background("white") - p.add_mesh(surf, opacity=0.08, color="white") - - if ws.derived.tke_array is not None: - arr = np.asarray(ws.derived.tke_array, dtype=np.float32) - if arr.ndim == 4: - vol_t = arr[..., min(max(0, t), arr.shape[3] - 1)] - else: - vol_t = arr - tke_mesh = create_uniform_grid(vol_t, ws.resolution, origin=ws.origin, name="TKE") - mesh_union = create_uniform_grid( - np.max(ws.segmask_binary > 0, axis=-1), - ws.resolution, - origin=ws.origin, - ) - mesh_union = mesh_union.threshold(0.1) - tke_mesh = mesh_union.sample(tke_mesh) - p.add_mesh( - tke_mesh, - scalars="TKE", - cmap="hot", - clim=clim, - show_scalar_bar=True, - scalar_bar_args=_scalar_bar_args("TKE (J/m³)", tke_bar_cfg), - ) - else: - p.add_mesh( - ws.derived.tke_volume, - scalars="TKE", - cmap="hot", - clim=clim, - show_scalar_bar=True, - scalar_bar_args=_scalar_bar_args("TKE (J/m³)", tke_bar_cfg), - ) - - if rotate: - txt = f"t={t} | rot {frame_idx + 1}/{total_frames}" - else: - txt = f"t={t}" - - p.add_text(txt, position="upper_left", font_size=14, color="black") - p.camera_position = camera_position - p.render() - frames.append(np.asarray(p.screenshot(return_img=True))) - - p.close() - suffix = "rotate" if rotate else "video" - return _write_video(frames, os.path.join(out_dir, f"tke_{suffix}.mp4"), fps=fps) -def _format_path_group(v): - if v is None: - return "?" - if isinstance(v, dict): - if "paths" in v: - v = v["paths"] - elif "path_indices" in v: - v = v["path_indices"] - if isinstance(v, (int, np.integer)): - return f"Branch{int(v)}" - if isinstance(v, str): - return v - try: - vals = list(v) - except Exception: - return str(v) - if not vals: - return "-" - out = [] - for x in vals: - if isinstance(x, (int, np.integer)): - out.append(f"Branch{int(x)}") - else: - out.append(str(x)) - return "+".join(out) - - -def _fork_side_text(fork): - if not isinstance(fork, dict): - return "left=?", "right=?" - - left = ( - fork.get("left") - or fork.get("left_paths") - or fork.get("left_path_indices") - or fork.get("in_paths") - or fork.get("in_path_indices") - ) - right = ( - fork.get("right") - or fork.get("right_paths") - or fork.get("right_path_indices") - or fork.get("out_paths") - or fork.get("out_path_indices") - ) - - return f"{_format_path_group(left)}", f"{_format_path_group(right)}" - - -def print_qc_summary(qc_data, forks=None): - if not qc_data: - print(" No QC results to summarize.") - return - - items = [] - if isinstance(qc_data, list): - items = qc_data - elif isinstance(qc_data, dict): - if isinstance(qc_data.get("forks"), list): - items = qc_data["forks"] - elif isinstance(qc_data.get("fork_qc"), list): - items = qc_data["fork_qc"] - else: - for k, v in qc_data.items(): - if isinstance(v, dict): - row = dict(v) - row.setdefault("fork_index", k) - items.append(row) - - rows = [] - for i, item in enumerate(items): - if not isinstance(item, dict): - continue - fork_idx = item.get("fork_index", item.get("fork_id", item.get("fork", i))) - ic = item.get("internal_consistency", item.get("ic", item.get("path_ic", np.nan))) - - fork_obj = None - if isinstance(forks, list): - try: - if isinstance(fork_idx, (int, np.integer)) and 0 <= int(fork_idx) < len(forks): - fork_obj = forks[int(fork_idx)] - elif i < len(forks): - fork_obj = forks[i] - except Exception: - pass - - left_txt, right_txt = _fork_side_text(fork_obj) - rows.append((fork_idx, ic, left_txt, right_txt)) - - if not rows: - print(" No fork-level QC results found.") - print(json.dumps(qc_data, ensure_ascii=False, indent=2)) - return - - print(f" {'Fork':>6} {'Internal Consistency':>24} {'Left':>24} {'Right':>24}") - print(f" {'-'*6} {'-'*24} {'-'*24} {'-'*24}") - for fork_idx, ic, left_txt, right_txt in rows: - ic_str = f"{float(ic):.6f}" if np.isfinite(ic) else "nan" - print(f" {str(fork_idx):>6} {ic_str:>24} {left_txt:>24} {right_txt:>24}") - - -def process_single( - h5_path, - out_dir, - workspace=None, - skip_derived=False, - skip_plane_metrics=False, - use_multithread=False, - reuse_planes_path="", - fps=24, - plane_rotation_frames=180, - rotate_dynamic_video=False, - dynamic_rotation_frames=180, - dynamic_rotation_elevation_deg=None, - make_plane_video=True, - make_wss_video=True, - make_streamlines_video=True, - make_tke_video=True, - camera_view="iso", - camera_distance_scale=1.0, - add_plane_idx=False, - add_path_idx=False, - wss_clim=None, - wss_bar_cfg=None, - tke_clim=None, - tke_bar_cfg=None, - streamline_clim=None, - streamline_bar_cfg=None, - dynamic_time_repeat=1, -): - print(f"\n{'=' * 60}") - print(f"Processing: {h5_path}") - print(f"Output dir: {out_dir}") - print(f"{'=' * 60}") - - os.makedirs(out_dir, exist_ok=True) - ws = copy.deepcopy(workspace) if workspace is not None else Workspace() - ws.paths.segmask_path = h5_path - ws.paths.flow_path = h5_path - ws.paths.output_dir = out_dir - ws.derived_params.use_multithread = use_multithread - engine = PipelineEngine() - logger = lambda msg: None - import time as _time - _t_total_start = _time.time() - - print("[1/7] Loading data...") - engine.load_data(ws, logger) - - print("[2/7] Generate Skeleton...") - r = engine.run_step(ws, StepId.GENERATE_SKELETON, logger) - print(f" -> {r.message}") - - print("[3/7] Generate Graph (+ branches/forks)...") - r = engine.run_step(ws, StepId.GENERATE_GRAPH, logger) - print(f" -> {r.message}") - - print("[4/7] Generate Planes...") - r = engine.run_step(ws, StepId.GENERATE_PLANES, logger) - print(f" -> {r.message}") - - if reuse_planes_path: - print(f"[5/7] Reuse Plane Positions: {reuse_planes_path}") - plane_items = load_plane_positions(reuse_planes_path) - ws.planes = project_planes_to_workspace(plane_items, ws) - planes_json = engine._save_planes_json(ws) - print(f" -> Reused {len(ws.planes)} planes saved={planes_json}") - else: - print("[5/7] Use generated planes") - - if skip_plane_metrics: - print("[6/7] Skipped plane metrics") - else: - print("[6/7] Calculate & Save Metrics...") - _, _, metric_msg = engine._compute_plane_metrics_internal( - ws, - save=True, - use_multithread=use_multithread, - ) - print(f" -> {metric_msg}") - try: - engine._save_planes_json(ws) - except Exception: - pass - - pixelwise_result = {} - if not skip_derived: - print("[7/7] Compute Derived Metrics (WSS/TKE)...") - dp = ws.derived_params - engine.preprocess(ws) - loaded_tke = ws.derived.tke_array - - result = compute_derived_metrics( - flow=ws.flow_raw * ws.segmask_binary[..., None], - mask4d=ws.segmask_binary, - spacing=ws.resolution, - origin=ws.origin, - smoothing_iteration=dp.smoothing_iteration, - viscosity=dp.viscosity, - inward_distance=dp.inward_distance, - parabolic_fitting=dp.parabolic_fitting, - no_slip_condition=dp.no_slip_condition, - step_size=dp.step_size, - tube_radius=dp.tube_radius, - rho=dp.rho, - save_pixelwise=True, - tke_array=loaded_tke, - ) - ws.derived.wss_surfaces = result["wss_surfaces"] - ws.derived.wss_volume = result.get("wss_volume") - ws.derived.tke_volume = result["tke_volume"] - ws.derived.tke_array = result.get("tke_array") - ws.derived.pixelwise_export = result.get("pixelwise_export", {}) - pixelwise_result = ws.derived.pixelwise_export - pixel_path = os.path.join(out_dir, "derived_metrics_pixelwise.npz") - if pixelwise_result: - np.savez_compressed(pixel_path, **pixelwise_result) - print(f" -> Saved pixelwise: {pixel_path}") - ws.pipeline.mark_done(StepId.COMPUTE_DERIVED_METRICS) - print(f" -> Derived: Nt={len(ws.derived.wss_surfaces)}") - else: - print("[7/7] Skipped derived metrics (WSS/TKE)") - total_time_sec = _time.time() - _t_total_start - print(f" => Total pipeline took {total_time_sec:.2f}s") - - plane_positions_path = save_plane_positions(ws, os.path.join(out_dir, "plane_positions.json"), source_path=h5_path) - print(f"Plane positions saved: {plane_positions_path}") - - video_paths = {} - if make_plane_video: - try: - video_paths["planes"] = render_plane_rotation_video( - ws, - out_dir, - fps=fps, - n_frames=plane_rotation_frames, - smoothing_iteration=ws.derived_params.smoothing_iteration, - distance_scale=camera_distance_scale, - add_plane_idx=add_plane_idx, - add_path_idx=add_path_idx, - ) - if video_paths["planes"]: - print(f"Plane video saved: {video_paths['planes']}") - except Exception: - print("[WARN] Plane video failed") - print(traceback.format_exc()) - video_paths["planes"] = "" - - if make_streamlines_video: - try: - video_paths["streamlines"] = render_streamlines_video( - ws, - out_dir, - fps=fps, - smoothing_iteration=ws.derived_params.smoothing_iteration, - view=camera_view, - distance_scale=camera_distance_scale, - streamline_clim=streamline_clim, - streamline_bar_cfg=streamline_bar_cfg, - rotate=rotate_dynamic_video, - rotation_frames=dynamic_rotation_frames, - elevation_deg=dynamic_rotation_elevation_deg, - time_repeat=dynamic_time_repeat, - ) - if video_paths["streamlines"]: - print(f"Streamlines video saved: {video_paths['streamlines']}") - except Exception: - print("[WARN] Streamlines video failed") - print(traceback.format_exc()) - video_paths["streamlines"] = "" - - if not skip_derived and make_wss_video: - try: - video_paths["wss"] = render_wss_video( - ws, - out_dir, - fps=fps, - smoothing_iteration=ws.derived_params.smoothing_iteration, - view=camera_view, - distance_scale=camera_distance_scale, - wss_clim=wss_clim, - wss_bar_cfg=wss_bar_cfg, - rotate=rotate_dynamic_video, - rotation_frames=dynamic_rotation_frames, - elevation_deg=dynamic_rotation_elevation_deg, - time_repeat=dynamic_time_repeat, - ) - if video_paths["wss"]: - print(f"WSS video saved: {video_paths['wss']}") - except Exception: - print("[WARN] WSS video failed") - print(traceback.format_exc()) - video_paths["wss"] = "" - - if not skip_derived and make_tke_video: - try: - video_paths["tke"] = render_tke_video( - ws, - out_dir, - fps=fps, - smoothing_iteration=ws.derived_params.smoothing_iteration, - view=camera_view, - distance_scale=camera_distance_scale, - tke_clim=tke_clim, - tke_bar_cfg=tke_bar_cfg, - rotate=rotate_dynamic_video, - rotation_frames=dynamic_rotation_frames, - elevation_deg=dynamic_rotation_elevation_deg, - time_repeat=dynamic_time_repeat, - ) - if video_paths["tke"]: - print(f"TKE video saved: {video_paths['tke']}") - except Exception: - print("[WARN] TKE video failed") - print(traceback.format_exc()) - video_paths["tke"] = "" - table_rows, raw_metrics, qc_data = (None, None, None) - if not skip_plane_metrics: - table_rows, raw_metrics, qc_data = load_metrics_from_output(out_dir) - - if table_rows: - print("\n === Plane Metrics Summary ===") - print_metrics_summary(table_rows) - - if qc_data: - print("\n === Fork QC Summary ===") - print_qc_summary(qc_data, ws.forks) - - summary = { - "input": h5_path, - "output_dir": out_dir, - "resolution": ws.resolution.tolist(), - "origin": np.asarray(ws.origin, dtype=float).reshape(3).tolist(), - "rr": ws.rr, - "total_time_sec": float(total_time_sec), - "n_planes": len(ws.planes), - "n_skeleton_pts": len(ws.skeleton_points) if ws.skeleton_points is not None else 0, - "n_graph_nodes": len(ws.graph.points), - "n_graph_edges": len(ws.graph.edges), - "n_paths": len(ws.centerline_paths_smooth), - "n_forks": len(ws.forks), - "path_info": ws.path_info, - "forks": ws.forks, - "plane_metrics": ws.derived.plane_metrics, - "plane_qc": ws.derived.plane_qc, - "plane_positions_file": plane_positions_path, - "reused_planes_file": reuse_planes_path, - "videos": video_paths, - "pixelwise_export": {k: list(np.asarray(v).shape) for k, v in pixelwise_result.items()} if pixelwise_result else {}, - } - summary_path = os.path.join(out_dir, "summary.json") - with open(summary_path, "w", encoding="utf-8") as f: - json.dump(summary, f, ensure_ascii=False, indent=2) - print(f"\nSummary saved: {summary_path}") - return summary - - -def collect_h5_files(inputs): - files = [] - for inp in inputs: - if os.path.isfile(inp) and inp.lower().endswith((".h5", ".hdf5")): - files.append(inp) - elif os.path.isdir(inp): - files.extend(sorted(glob.glob(os.path.join(inp, "**", "*.h5"), recursive=True))) - files.extend(sorted(glob.glob(os.path.join(inp, "**", "*.hdf5"), recursive=True))) - return sorted(dict.fromkeys(files)) - - -def build_base_workspace(): - ws = Workspace() - ws.plane_gen_params.use_center_plane = globals().get("USE_CENTER_PLANE", True) - ws.plane_gen_params.cross_section_distance = globals().get("CROSS_SECTION_DIST", 5.0) - ws.plane_gen_params.start_distance = globals().get("START_DIST", 5.0) - ws.plane_gen_params.end_distance = globals().get("END_DIST", 0.0) - ws.skeleton_params.remove_small_cc = globals().get("REMOVE_SMALL_CC", True) - ws.skeleton_params.min_cc_volume_mm3 = globals().get("MIN_CC_VOLUME", 50.0) - ws.streamline_params.max_steps = 2000 - ws.streamline_params.min_seeds = 50 - ws.streamline_params.seed_ratio = globals().get("SEED_RATIO", 0.02) - ws.streamline_params.tube_radius = globals().get("TUBE_RADIUS", 0.05) - return ws - - -def run_batch(): - inputs = globals().get("INPUT", None) - if inputs is None: - raise ValueError("INPUT is not defined.") - dynamic_time_repeat = globals().get("DYNAMIC_TIME_REPEAT", 1) - output_dir = globals().get("OUTPUT_DIR", "./batch_output") - skip_derived = globals().get("SKIP_DERIVED", False) - use_multithread = globals().get("USE_MULTITHREAD", True) - reuse_planes = globals().get("REUSE_PLANES", "") - fps = globals().get("FPS", 12) - plane_rotation_frames = globals().get("PLANE_ROTATION_FRAMES", 180) - make_plane_video = globals().get("MAKE_PLANE_VIDEO", True) - make_wss_video = globals().get("MAKE_WSS_VIDEO", True) - make_streamlines_video = globals().get("MAKE_STREAMLINES_VIDEO", True) - make_tke_video = globals().get("MAKE_TKE_VIDEO", True) - camera_view = globals().get("CAMERA_VIEW", "posterior") - camera_distance_scale = globals().get("CAMERA_DISTANCE_SCALE", 1.5) - skip_plane_metrics = globals().get("SKIP_PLANE_METRICS", False) - rotate_dynamic_video = globals().get("ROTATE_DYNAMIC_VIDEO", False) - dynamic_rotation_frames = globals().get("DYNAMIC_ROTATION_FRAMES", 180) - dynamic_rotation_elevation_deg = globals().get("DYNAMIC_ROTATION_ELEVATION_DEG", None) - add_plane_idx = globals().get("ADD_PLANE_IDX", False) - add_path_idx = globals().get("ADD_PATH_IDX", False) - - wss_clim = globals().get("WSS_CLIM", (0, 5)) - wss_bar_cfg = globals().get( - "WSS_BAR_CFG", - {"position_x": 0.75, "position_y": 0.2, "height": 0.22, "width": 0.05, "title_font_size": 40, "label_font_size": 32}, - ) - tke_clim = globals().get("TKE_CLIM", (0, 2)) - tke_bar_cfg = globals().get( - "TKE_BAR_CFG", - {"position_x": 0.75, "position_y": 0.2, "height": 0.22, "width": 0.05, "title_font_size": 40, "label_font_size": 32}, - ) - streamline_clim = globals().get("STREAMLINE_CLIM", (0, 0.6)) - streamline_bar_cfg = globals().get( - "STREAMLINE_BAR_CFG", - {"position_x": 0.75, "position_y": 0.2, "height": 0.22, "width": 0.05, "title_font_size": 40, "label_font_size": 32}, - ) - - h5_files = collect_h5_files(inputs) - if not h5_files: - print("No H5 files found.") - return [], "" - - print(f"Found {len(h5_files)} file(s) to process.") - base_ws = build_base_workspace() - results = [] - case_out = "" - - for path in h5_files: - name = os.path.splitext(os.path.basename(path))[0] - case_out = os.path.join(output_dir, name) - reuse_file = resolve_reuse_plane_file(reuse_planes, name) - - if reuse_planes and not os.path.isfile(reuse_file): - results.append({"file": path, "status": "error", "error": f"reuse plane file not found: {reuse_planes}"}) - print(f"\n[ERROR] Reuse plane file not found: {reuse_planes}") - continue - - try: - summary = process_single( - path, - case_out, - workspace=base_ws, - skip_derived=skip_derived, - use_multithread=use_multithread, - reuse_planes_path=reuse_file, - fps=fps, - plane_rotation_frames=plane_rotation_frames, - make_plane_video=make_plane_video, - make_wss_video=make_wss_video, - make_streamlines_video=make_streamlines_video, - make_tke_video=make_tke_video, - camera_view=camera_view, - camera_distance_scale=camera_distance_scale, - add_plane_idx=add_plane_idx, - add_path_idx=add_path_idx, - wss_clim=wss_clim, - wss_bar_cfg=wss_bar_cfg, - tke_clim=tke_clim, - tke_bar_cfg=tke_bar_cfg, - streamline_clim=streamline_clim, - streamline_bar_cfg=streamline_bar_cfg, - skip_plane_metrics=skip_plane_metrics, - rotate_dynamic_video=rotate_dynamic_video, - dynamic_rotation_frames=dynamic_rotation_frames, - dynamic_rotation_elevation_deg=dynamic_rotation_elevation_deg, - dynamic_time_repeat=dynamic_time_repeat, - ) - results.append({"file": path, "status": "ok", "summary": summary}) - except Exception: - print(f"\n[ERROR] Failed: {path}") - print(traceback.format_exc()) - results.append({"file": path, "status": "error", "error": traceback.format_exc()}) - - os.makedirs(output_dir, exist_ok=True) - batch_report = os.path.join(output_dir, "batch_report.json") - with open(batch_report, "w", encoding="utf-8") as f: - json.dump(results, f, ensure_ascii=False, indent=2) - - n_ok = sum(1 for r in results if r["status"] == "ok") - - times_sec = [] - for r in results: - if r.get("status") == "ok": - t = r.get("summary", {}).get("total_time_sec", None) - if t is not None: - times_sec.append(float(t)) - - if times_sec: - times_sec = np.asarray(times_sec, dtype=float) - mean_sec = times_sec.mean() - std_sec = times_sec.std(ddof=1) if len(times_sec) > 1 else 0.0 - time_text = f"{mean_sec:.2f} ± {std_sec:.2f} s" - print(f"\nCase time: {time_text}") - - with open(os.path.join(output_dir, "time_summary.txt"), "w", encoding="utf-8") as f: - f.write(f"n = {len(times_sec)}\n") - f.write(f"mean_sec = {mean_sec:.6f}\n") - f.write(f"std_sec = {std_sec:.6f}\n") - f.write(f"formatted = {time_text}\n") - print(f"\nDone: {n_ok}/{len(results)} succeeded. Report: {batch_report}") - return results, case_out - - -def extract_frame(mp4_path, frame_index, out_png): - reader = imageio.get_reader(mp4_path, format="ffmpeg") - frame = reader.get_data(frame_index) - reader.close() - Image.fromarray(np.asarray(frame)).save(out_png, format="PNG", compress_level=0) - print(f"Saved frame {frame_index} -> {out_png}") +from .reporting import load_metrics_from_output, print_metrics_summary, print_qc_summary + +__all__ = [ + "CAMERA_PRESETS", + "WINDOW_SIZE", + "build_base_workspace", + "collect_h5_files", + "extract_frame", + "load_metrics_from_output", + "load_plane_positions", + "print_metrics_summary", + "print_qc_summary", + "process_single", + "project_planes_to_workspace", + "render_plane_rotation_video", + "render_streamlines_video", + "render_tke_video", + "render_wss_video", + "resolve_reuse_plane_file", + "run_batch", + "save_plane_positions", +] diff --git a/autoflow/viewer.py b/autoflow/viewer.py index cfe1e4e..205af33 100755 --- a/autoflow/viewer.py +++ b/autoflow/viewer.py @@ -1,866 +1,11 @@ -import numpy as np -import pyvista as pv +"""Compatibility re-exports for UI viewer components.""" -from .models import ObjectKind -from .algorithms import ( - build_multilabel_surface_t, - build_surface_from_mask3d, - graph_to_polydata, - generate_seed_points, - generate_streamlines_at_t, - generate_streamlines_from_plane_at_t, - create_uniform_grid, -) +__all__ = ["SceneController"] -def _parse_indexed_data_key(data_key, prefix): - token = f"{prefix}_" - if not isinstance(data_key, str) or not data_key.startswith(token): - return None - suffix = data_key[len(token):] - if suffix.isdigit(): - return int(suffix) - return None +def __getattr__(name): + if name == "SceneController": + from .ui.viewer import SceneController - -def _path_polydata(path, origin): - pts = np.asarray(path, dtype=float) - if len(pts) == 0: - return None - poly = pv.PolyData(pts + np.asarray(origin, dtype=float).reshape(1, 3)) - if len(pts) >= 2: - cells = np.empty((len(pts) - 1, 3), dtype=np.int64) - cells[:, 0] = 2 - cells[:, 1] = np.arange(len(pts) - 1) - cells[:, 2] = np.arange(1, len(pts)) - poly.lines = cells.ravel() - return poly - - -class SceneController: - def __init__(self, plotter, workspace, logger): - self.plotter = plotter - self.workspace = workspace - self.logger = logger - self._axes_shown = True - self._mesh_cache = {} - self._tracked_actors = {} - self._saved_camera = None - self._playback_active = False - self._highlight_plane_uid = None - self._highlight_plane_actor = None - self._highlight_path_uid = None - self._highlight_path_actor = None - self._context_path_actors = [] - self._highlight_fork_actor = None - self._plane_pick_obs_id = None - self._path_pick_obs_id = None - self._plane_pick_callback = None - self._path_pick_callback = None - self._shared_pick_obs_id = None - - def initialize(self): - self.plotter.set_background("white") - self.plotter.add_axes(line_width=2) - self.plotter.reset_camera() - - def reset_scene(self): - try: - self.plotter.clear() - except Exception: - try: - self.plotter.renderer.RemoveAllViewProps() - except Exception: - pass - for obj in self.workspace.scene_objects.values(): - obj.actor = None - obj.label_actor = None - self._tracked_actors.clear() - self._mesh_cache.clear() - self._remove_plane_highlight() - self._remove_path_highlight() - self.initialize() - - def invalidate_cache(self, prefix=None): - if prefix is None: - self._mesh_cache.clear() - else: - self._mesh_cache = {k: v for k, v in self._mesh_cache.items() if not k[0].startswith(prefix)} - - def set_background(self, color): - self.plotter.set_background(color) - self.render_all() - - def toggle_axes(self): - self._axes_shown = not self._axes_shown - self.reset_scene() - if not self._axes_shown: - try: - self.plotter.hide_axes() - except Exception: - pass - self.render_all() - - def reset_camera(self): - try: - self.plotter.reset_camera() - self.plotter.render() - except Exception: - pass - - def save_camera(self): - try: - self._saved_camera = self.plotter.camera_position - except Exception: - self._saved_camera = None - - def restore_camera(self): - if self._saved_camera is not None: - try: - self.plotter.camera_position = self._saved_camera - except Exception: - pass - - def set_playback_active(self, active): - self._playback_active = active - if active: - self.save_camera() - - def sync_from_workspace(self): - current_uids = set(self.workspace.scene_objects.keys()) - stale = set(self._tracked_actors.keys()) - current_uids - for uid in stale: - actor = self._tracked_actors.pop(uid, None) - if actor is not None: - try: - self.plotter.remove_actor(actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(actor) - except Exception: - pass - self.render_all() - - def remove_object(self, uid): - obj = self.workspace.scene_objects.get(uid) - if obj is not None: - self._remove_actor(obj) - del self.workspace.scene_objects[uid] - actor = self._tracked_actors.pop(uid, None) - if actor is not None: - try: - self.plotter.remove_actor(actor) - except Exception: - pass - if self._highlight_plane_uid == uid: - self._remove_plane_highlight() - if self._highlight_path_uid == uid: - self._remove_path_highlight() - - def render_all(self): - for obj in self.workspace.scene_objects.values(): - self._render_object(obj) - try: - self.plotter.render() - except Exception: - pass - - def update_time(self, t): - self.workspace.current_t = int(t) - cam_before = None - if self._playback_active: - try: - cam_before = self.plotter.camera_position - except Exception: - cam_before = None - for obj in self.workspace.scene_objects.values(): - if obj.dynamic: - self.readd_object(obj) - if self._playback_active and cam_before is not None: - try: - self.plotter.camera_position = cam_before - except Exception: - pass - try: - self.plotter.render() - except Exception: - pass - - def rebuild_dynamic(self): - for obj in self.workspace.scene_objects.values(): - if obj.dynamic: - self.readd_object(obj) - - def readd_object(self, obj): - self._remove_actor(obj) - self._render_object(obj) - - def apply_object_properties(self, obj): - if obj.actor is None: - self._render_object(obj) - return - try: - obj.actor.SetVisibility(1 if obj.visible else 0) - except Exception: - pass - try: - prop = obj.actor.GetProperty() - prop.SetOpacity(float(obj.opacity)) - prop.SetLineWidth(float(obj.line_width)) - prop.SetPointSize(float(obj.point_size)) - except Exception: - pass - if obj.visible: - self.readd_object(obj) - return - try: - self.plotter.render() - except Exception: - pass - - def highlight_plane(self, uid): - self._remove_plane_highlight() - self._highlight_plane_uid = uid - if uid is None: - try: - self.plotter.render() - except Exception: - pass - return - obj = self.workspace.scene_objects.get(uid) - if obj is None or obj.kind != ObjectKind.PLANE: - self._highlight_plane_uid = None - return - data = self._build_dataset(obj.data_key) - if data is None: - return - try: - self._highlight_plane_actor = self.plotter.add_mesh( - data, color="magenta", opacity=0.9, line_width=4, - style="wireframe", name="__plane_highlight__") - self._promote_overlay_actor(self._highlight_plane_actor) - except Exception: - self._highlight_plane_actor = None - try: - self.plotter.render() - except Exception: - pass - - def highlight_path(self, uid): - self._remove_path_highlight() - self._highlight_path_uid = uid - if uid is None: - try: - self.plotter.render() - except Exception: - pass - return - obj = self.workspace.scene_objects.get(uid) - if obj is None or obj.kind != ObjectKind.BRANCH: - self._highlight_path_uid = None - return - data = self._build_dataset(obj.data_key) - if data is None: - return - try: - self._highlight_path_actor = self.plotter.add_mesh( - data, color="magenta", opacity=1.0, line_width=8, - render_lines_as_tubes=True, - name="__path_highlight__") - self._promote_overlay_actor(self._highlight_path_actor) - except Exception: - self._highlight_path_actor = None - try: - self.plotter.render() - except Exception: - pass - - def show_forks_for_path(self, path_idx): - self._clear_fork_and_context_actors() - if int(path_idx) < 0: - try: - self.plotter.render() - except Exception: - pass - return - org = np.asarray(self.workspace.origin, dtype=float).reshape(3) - pts = [] - incoming_ids = set() - outgoing_ids = set() - if 0 <= int(path_idx) < len(self.workspace.path_info): - info = self.workspace.path_info[int(path_idx)] - incoming_ids.update(int(x) for x in info.get("incoming_path_ids", [])) - outgoing_ids.update(int(x) for x in info.get("outgoing_path_ids", [])) - for fork in self.workspace.forks: - if int(path_idx) in fork.get("left", []) or int(path_idx) in fork.get("right", []): - pts.append(np.asarray(fork.get("crosspoint", [0.0, 0.0, 0.0]), dtype=float) + org) - incoming_ids.update(int(x) for x in fork.get("left", []) if int(x) != int(path_idx)) - outgoing_ids.update(int(x) for x in fork.get("right", []) if int(x) != int(path_idx)) - incoming_ids.discard(int(path_idx)) - outgoing_ids.discard(int(path_idx)) - for pid, color in [(sorted(incoming_ids), "deepskyblue"), (sorted(outgoing_ids), "orange")]: - for idx in pid: - if not (0 <= int(idx) < len(self.workspace.centerline_paths_smooth)): - continue - poly = _path_polydata(self.workspace.centerline_paths_smooth[int(idx)], org) - if poly is None: - continue - try: - actor = self.plotter.add_mesh( - poly, color=color, opacity=1.0, line_width=8, - render_lines_as_tubes=True, - name=f"__path_context_{color}_{int(idx)}__") - self._promote_overlay_actor(actor) - self._context_path_actors.append(actor) - except Exception: - pass - if pts: - try: - poly = pv.PolyData(np.asarray(pts, dtype=float).reshape(-1, 3)) - self._highlight_fork_actor = self.plotter.add_mesh( - poly, color="magenta", point_size=22, render_points_as_spheres=True, - name="__fork_highlight__") - self._promote_overlay_actor(self._highlight_fork_actor) - except Exception: - self._highlight_fork_actor = None - try: - self.plotter.render() - except Exception: - pass - - def _remove_plane_highlight(self): - if self._highlight_plane_actor is not None: - try: - self.plotter.remove_actor(self._highlight_plane_actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(self._highlight_plane_actor) - except Exception: - pass - self._highlight_plane_actor = None - self._highlight_plane_uid = None - - def _remove_path_highlight(self): - if self._highlight_path_actor is not None: - try: - self.plotter.remove_actor(self._highlight_path_actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(self._highlight_path_actor) - except Exception: - pass - self._highlight_path_actor = None - self._highlight_path_uid = None - self._clear_fork_and_context_actors() - - def _clear_fork_and_context_actors(self): - if self._highlight_fork_actor is not None: - try: - self.plotter.remove_actor(self._highlight_fork_actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(self._highlight_fork_actor) - except Exception: - pass - self._highlight_fork_actor = None - for actor in list(self._context_path_actors): - if actor is not None: - try: - self.plotter.remove_actor(actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(actor) - except Exception: - pass - self._context_path_actors = [] - try: - self.plotter.remove_actor("__fork_highlight__") - except Exception: - pass - try: - renderer = self.plotter.renderer - actors_to_remove = [] - it = renderer.GetActors() - it.InitTraversal() - for _ in range(it.GetNumberOfItems()): - a = it.GetNextItem() - if a is not None: - try: - name = a.GetObjectName() if hasattr(a, "GetObjectName") else "" - if name and ("__path_context_" in name or "__fork_highlight__" in name): - actors_to_remove.append(a) - except Exception: - pass - for a in actors_to_remove: - try: - renderer.RemoveActor(a) - except Exception: - pass - except Exception: - pass - - - def _promote_overlay_actor(self, actor): - if actor is None: - return - try: - actor.PickableOff() - except Exception: - pass - try: - prop = actor.GetProperty() - prop.SetLighting(False) - except Exception: - pass - - def refresh_plane_labels(self): - pass - - def remove_all_plane_labels(self): - pass - - def _remove_actor(self, obj): - if obj.actor is not None: - try: - self.plotter.remove_actor(obj.actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(obj.actor) - except Exception: - pass - if getattr(obj, "label_actor", None) is not None: - try: - self.plotter.remove_actor(obj.label_actor) - except Exception: - try: - self.plotter.renderer.RemoveActor(obj.label_actor) - except Exception: - pass - self._tracked_actors.pop(obj.uid, None) - obj.actor = None - obj.label_actor = None - - def _render_object(self, obj): - if not obj.visible: - if obj.actor is not None: - try: - obj.actor.SetVisibility(0) - except Exception: - pass - return - data = self._build_dataset(obj.data_key) - if data is None: - self._remove_actor(obj) - return - if obj.actor is not None: - self._remove_actor(obj) - kwargs = self._mesh_kwargs(obj, data) - try: - if obj.tube_radius > 0 and hasattr(data, "tube") and obj.kind.value in ("Graph", "Branch", "Flow", "Metric", "Skeleton"): - data_show = data.tube(radius=float(obj.tube_radius)) - else: - data_show = data - obj.actor = self.plotter.add_mesh(data_show, name=obj.uid, **kwargs) - self._tracked_actors[obj.uid] = obj.actor - self._apply_basic_properties_only(obj) - except Exception as e: - self.logger(f"Render failed: {obj.name}: {type(e).__name__}: {e}") - - def _apply_basic_properties_only(self, obj): - try: - obj.actor.SetVisibility(1 if obj.visible else 0) - except Exception: - pass - try: - prop = obj.actor.GetProperty() - prop.SetOpacity(float(obj.opacity)) - prop.SetLineWidth(float(obj.line_width)) - prop.SetPointSize(float(obj.point_size)) - except Exception: - pass - - def _mesh_kwargs(self, obj, data): - kw = {"opacity": float(obj.opacity), "show_scalar_bar": bool(obj.show_scalar_bar)} - use_scalars = False - if obj.scalars: - if hasattr(data, "point_data") and obj.scalars in data.point_data: - use_scalars = True - if hasattr(data, "cell_data") and obj.scalars in data.cell_data: - use_scalars = True - if use_scalars: - kw["scalars"] = obj.scalars - kw["cmap"] = obj.cmap - if obj.clim: - kw["clim"] = obj.clim - if obj.scalar_bar_title: - kw["scalar_bar_args"] = { - "title": obj.scalar_bar_title, - "vertical": True, - "title_font_size": 14, - "label_font_size": 12, - "n_labels": 5, - "fmt": "%.3g", - } - else: - kw["color"] = obj.color - if obj.kind.value in ("Skeleton", "Aux"): - kw["render_points_as_spheres"] = True - kw["point_size"] = obj.point_size - if obj.kind.value in ("Graph", "Branch", "Flow", "Aux"): - kw["line_width"] = obj.line_width - kw["render_lines_as_tubes"] = True - if obj.kind == ObjectKind.PLANE: - kw["show_edges"] = True - kw["edge_color"] = "black" - kw["line_width"] = max(float(obj.line_width), 2.0) - return kw - - def _build_dataset(self, data_key): - ws = self.workspace - t = ws.current_t - sp = ws.resolution - org = ws.origin - - if data_key == "segmask_raw_surface": - if ws.segmask_raw is None: - return None - return self._cached(data_key, t, lambda: build_multilabel_surface_t(ws.segmask_raw, t, sp, org)) - - if data_key == "segmask_pre_surface": - if ws.segmask_labels is None: - return None - return self._cached(data_key, t, lambda: build_multilabel_surface_t(ws.segmask_labels, t, sp, org)) - - if data_key == "segmask_3d_surface": - if ws.segmask_3d is None: - return None - return self._cached(data_key, 0, lambda: build_surface_from_mask3d(ws.segmask_3d, sp, org, smooth_iter=1000)) - - if data_key == "skeleton_points": - if ws.skeleton_points is None or len(ws.skeleton_points) == 0: - return None - return pv.PolyData(np.asarray(ws.skeleton_points, dtype=float) + np.asarray(org, dtype=float).reshape(1, 3)) - - if data_key == "skeleton_mask_surface": - if ws.skeleton_mask is None: - return None - return self._cached(data_key, 0, lambda: build_surface_from_mask3d(ws.skeleton_mask, sp, org, smooth_iter=1000)) - - if data_key == "graph_lines": - if ws.graph is None or len(ws.graph.points) == 0: - return None - return graph_to_polydata(np.asarray(ws.graph.points) + np.asarray(org).reshape(1, 3), ws.graph.edges) - - if data_key == "streamlines_live": - return self._get_streamline_mesh(t) - - if data_key == "plane_streamlines_live": - return self._get_plane_streamline_mesh(t) - - if data_key == "wss_surface_live": - if not ws.derived.wss_surfaces: - return None - return ws.derived.wss_surfaces[min(max(0, t), len(ws.derived.wss_surfaces) - 1)] - - if data_key == "tke_volume": - if ws.derived.tke_array is not None: - def _build_tke_t(): - arr = np.asarray(ws.derived.tke_array, dtype=np.float32) - if arr.ndim == 4: - vol_t = arr[..., min(max(0, int(t)), arr.shape[3] - 1)] - else: - vol_t = arr - if ws.segmask_binary is not None: - if ws.segmask_binary.ndim == 4: - mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] - else: - mask_t = ws.segmask_binary - elif ws.segmask_3d is not None: - mask_t = ws.segmask_3d - else: - mask_t = np.ones(vol_t.shape, dtype=bool) - vol_t = vol_t * np.asarray(mask_t, dtype=np.float32) - - tke_grid = create_uniform_grid(vol_t, sp, origin=org, name="TKE") - mask_grid = create_uniform_grid(np.asarray(mask_t, dtype=np.float32), sp, origin=org, name="mask") - mask_mesh = mask_grid.threshold(0.1, scalars="mask") - if mask_mesh is None or mask_mesh.n_cells == 0: - return None - return mask_mesh.sample(tke_grid) - return self._cached(data_key, t, _build_tke_t) - return ws.derived.tke_volume - - if data_key == "derived_streamlines_live": - if not ws.derived.streamlines: - return None - return ws.derived.streamlines[min(max(0, t), len(ws.derived.streamlines) - 1)] - - idx = _parse_indexed_data_key(data_key, "smooth_path") - if idx is not None: - if idx >= len(ws.centerline_paths_smooth): - return None - path = np.asarray(ws.centerline_paths_smooth[idx], dtype=float) - if len(path) == 0: - return None - return _path_polydata(path, org) - - idx = _parse_indexed_data_key(data_key, "path_arrow") - if idx is not None: - if idx >= len(ws.centerline_paths_smooth): - return None - path = np.asarray(ws.centerline_paths_smooth[idx], dtype=float) - if len(path) < 2: - return None - org_r = np.asarray(org, dtype=float).reshape(3) - seglens = np.linalg.norm(np.diff(path, axis=0), axis=1) - total = float(np.sum(seglens)) - if total < 1e-6: - return None - overall = path[-1] - path[0] - n = np.linalg.norm(overall) - if n < 1e-12: - return None - overall = overall / n - mid = 0.5 * (path[0] + path[-1]) - arrow_len = max(2.0, total * 0.45) - shaft_r = max(0.25, arrow_len * 0.05) - tip_r = max(0.6, arrow_len * 0.12) - tip_l = max(2.0, arrow_len * 0.25) - start = mid + org_r - overall * (arrow_len * 0.5) - return pv.Arrow( - start=start, - direction=overall * arrow_len, - shaft_radius=shaft_r, - tip_radius=tip_r, - tip_length=tip_l, - ) - - idx = _parse_indexed_data_key(data_key, "path") - if idx is not None: - if idx >= len(ws.centerline_paths): - return None - path = np.asarray(ws.centerline_paths[idx], dtype=float) - if len(path) == 0: - return None - return _path_polydata(path, org) - - if data_key == "fork_markers": - pts = [np.asarray(f.get("crosspoint", [0.0, 0.0, 0.0]), dtype=float) + np.asarray(org, dtype=float).reshape(3) for f in ws.forks] - if not pts: - return None - return pv.PolyData(np.asarray(pts, dtype=float).reshape(-1, 3)) - - idx = _parse_indexed_data_key(data_key, "plane") - if idx is not None: - if idx >= len(ws.planes): - return None - p = ws.planes[idx] - return pv.Plane(center=np.asarray(p.center) + np.asarray(org), direction=np.asarray(p.normal), i_size=25, j_size=25) - - return None - - def _cached(self, data_key, t, builder): - key = (data_key, t) - if key in self._mesh_cache: - return self._mesh_cache[key] - mesh = builder() - if mesh is not None: - self._mesh_cache[key] = mesh - return mesh - - def _get_streamline_mesh(self, t): - ws = self.workspace - if not ws.streamline_active: - return None - if t in ws.streamline_cache: - return ws.streamline_cache[t] - if ws.flow_raw is None or ws.segmask_binary is None: - return None - p = ws.streamline_params - mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] - sl = generate_streamlines_at_t( - ws.flow_raw, t, ws.streamline_seeds, ws.resolution, ws.origin, - mask_3d=mask_t, - max_steps=p.max_steps, - terminal_speed=p.terminal_speed, - seed_ratio=p.seed_ratio, - min_seeds=p.min_seeds, - rng_seed=p.rng_seed, - ) - ws.streamline_cache[t] = sl - return sl - - def _get_plane_streamline_mesh(self, t): - ws = self.workspace - if not ws.plane_streamline_active: - return None - if t in ws.plane_streamline_cache: - return ws.plane_streamline_cache[t] - if ws.flow_raw is None or ws.segmask_binary is None: - return None - pidx = ws.plane_streamline_plane_idx - if pidx < 0 or pidx >= len(ws.planes): - return None - plane = ws.planes[pidx] - p = ws.streamline_params - mask_t = ws.segmask_binary[..., min(max(0, int(t)), ws.segmask_binary.shape[3] - 1)] - sl = generate_streamlines_from_plane_at_t( - ws.flow_raw, t, plane, ws.resolution, ws.origin, - mask_3d=mask_t, - max_steps=p.max_steps, - terminal_speed=p.terminal_speed, - seed_ratio=p.seed_ratio, - min_seeds=p.min_seeds, - rng_seed=p.rng_seed, - branch_labels_3d=ws.branch_labels, - ) - ws.plane_streamline_cache[t] = sl - return sl - - def trigger_streamlines(self): - ws = self.workspace - if ws.flow_raw is None or ws.segmask_3d is None: - self.logger("Cannot generate streamlines: need flow + segmask_3d") - return - ws.streamline_seeds = generate_seed_points( - ws.segmask_3d, - ws.resolution, - ws.origin, - ratio=ws.streamline_params.seed_ratio, - rng_seed=ws.streamline_params.rng_seed, - min_seeds=ws.streamline_params.min_seeds, - ) - ws.streamline_cache.clear() - ws.streamline_active = True - p = ws.streamline_params - self.logger(f"Streamlines enabled: seed_ratio={p.seed_ratio} max_steps={p.max_steps} min_seeds={p.min_seeds} terminal_speed={p.terminal_speed} rng_seed={p.rng_seed}") - ws.remove_object_by_data_key("streamlines_live") - ws.add_object(name="streamlines", kind=ObjectKind.FLOW, - data_key="streamlines_live", visible=True, opacity=1.0, - scalars="Velocity", cmap="turbo", dynamic=True, - show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") - self.sync_from_workspace() - - def trigger_plane_streamlines(self, plane_idx): - ws = self.workspace - if ws.flow_raw is None or ws.segmask_3d is None: - self.logger("Cannot generate plane streamlines: need flow + segmask_3d") - return - if plane_idx < 0 or plane_idx >= len(ws.planes): - self.logger(f"Invalid plane index: {plane_idx}") - return - ws.plane_streamline_cache.clear() - ws.plane_streamline_active = True - ws.plane_streamline_plane_idx = plane_idx - p = ws.streamline_params - self.logger(f"Plane streamlines enabled from plane {plane_idx}: seed_ratio={p.seed_ratio} min_seeds={p.min_seeds} max_steps={p.max_steps} terminal_speed={p.terminal_speed} rng_seed={p.rng_seed}") - ws.remove_object_by_data_key("plane_streamlines_live") - ws.add_object(name="plane_streamlines", kind=ObjectKind.FLOW, - data_key="plane_streamlines_live", visible=True, opacity=1.0, - scalars="Velocity", cmap="turbo", dynamic=True, - show_scalar_bar=True, scalar_bar_title="Velocity (m/s)") - self.sync_from_workspace() - - def clear_streamlines(self): - self.workspace.clear_streamlines() - self.invalidate_cache("streamlines") - self.sync_from_workspace() - self.logger("Streamlines cleared") - - def clear_plane_streamlines(self): - self.workspace.clear_plane_streamlines() - self.invalidate_cache("plane_streamlines") - self.sync_from_workspace() - self.logger("Plane streamlines cleared") - - def find_plane_uid_at_position(self, picked_point): - ws = self.workspace - if picked_point is None: - return None, None - picked = np.asarray(picked_point, dtype=float).reshape(3) - best_uid, best_idx, best_dist = None, None, float("inf") - org = np.asarray(ws.origin, dtype=float).reshape(3) - for uid, obj in ws.scene_objects.items(): - if obj.kind != ObjectKind.PLANE: - continue - pidx = _parse_indexed_data_key(obj.data_key, "plane") - if pidx is None: - continue - if pidx >= len(ws.planes): - continue - center = np.asarray(ws.planes[pidx].center, dtype=float) + org - d = float(np.linalg.norm(picked - center)) - if d < best_dist: - best_uid, best_idx, best_dist = uid, pidx, d - return (best_uid, best_idx) if best_dist <= 30.0 else (None, None) - - def find_path_uid_at_position(self, picked_point): - ws = self.workspace - if picked_point is None: - return None, None - picked = np.asarray(picked_point, dtype=float).reshape(3) - best_uid, best_idx, best_dist = None, None, float("inf") - org = np.asarray(ws.origin, dtype=float).reshape(3) - for uid, obj in ws.scene_objects.items(): - if obj.kind != ObjectKind.BRANCH: - continue - if not obj.data_key.startswith("smooth_path_"): - continue - try: - pidx = int(obj.data_key.split("_")[2]) - except Exception: - continue - if pidx >= len(ws.centerline_paths_smooth): - continue - path = np.asarray(ws.centerline_paths_smooth[pidx], dtype=float) + org.reshape(1, 3) - if len(path) == 0: - continue - d = float(np.min(np.linalg.norm(path - picked.reshape(1, 3), axis=1))) - if d < best_dist: - best_uid, best_idx, best_dist = uid, pidx, d - return (best_uid, best_idx) if best_dist <= 15.0 else (None, None) - - def _ensure_shared_right_click_picking(self): - if self._shared_pick_obs_id is not None: - return - try: - iren = self.plotter.iren.interactor - except Exception: - return - picker = pv._vtk.vtkCellPicker() - picker.SetTolerance(0.005) - - def _on_right_click(obj, ev): - try: - x, y = iren.GetEventPosition() - except Exception: - return - ren = self.plotter.renderer - ok = picker.Pick(float(x), float(y), 0.0, ren) - pos = picker.GetPickPosition() if ok else None - plane_uid, plane_idx = self.find_plane_uid_at_position(pos) if pos is not None else (None, None) - if plane_uid is not None and plane_idx is not None: - if self._plane_pick_callback is not None: - self._plane_pick_callback(plane_uid, plane_idx) - return - path_uid, path_idx = self.find_path_uid_at_position(pos) if pos is not None else (None, None) - if path_uid is not None and path_idx is not None: - if self._path_pick_callback is not None: - self._path_pick_callback(path_uid, path_idx) - return - if self._plane_pick_callback is not None: - self._plane_pick_callback(None, None) - if self._path_pick_callback is not None: - self._path_pick_callback(None, None) - - self._shared_pick_obs_id = iren.AddObserver("RightButtonPressEvent", _on_right_click) - self._plane_pick_obs_id = self._shared_pick_obs_id - self._path_pick_obs_id = self._shared_pick_obs_id - - def enable_plane_picking(self, callback): - self._plane_pick_callback = callback - self._ensure_shared_right_click_picking() - - def enable_path_picking(self, callback): - self._path_pick_callback = callback - self._ensure_shared_right_click_picking() + return SceneController + raise AttributeError(f"module 'autoflow.viewer' has no attribute {name!r}") diff --git a/tests/test_imports.py b/tests/test_imports.py index 4bf1f28..7cf19bf 100755 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -105,6 +105,52 @@ def test_base_imports_do_not_require_optional_gui_dependencies(): assert result.returncode == 0, result.stderr or result.stdout +def test_refactor_modules_and_utils_compatibility_import(): + import autoflow.plane_io as plane_io + import autoflow.processing as processing + import autoflow.rendering as rendering + import autoflow.reporting as reporting + import autoflow.utils as utils + + assert utils.process_single is processing.process_single + assert utils.resolve_reuse_plane_file is plane_io.resolve_reuse_plane_file + assert utils.print_qc_summary is reporting.print_qc_summary + assert utils.render_tke_video is rendering.render_tke_video + + +def test_core_and_ui_packages_reexport_legacy_entry_points(): + import autoflow.core as core + import autoflow.core.models as core_models + import autoflow.core.pipeline as core_pipeline + import autoflow.editors as editors + import autoflow.gui as gui + import autoflow.models as models + import autoflow.pipeline as pipeline + import autoflow.ui as ui + + assert models.Workspace is core_models.Workspace + assert pipeline.PipelineEngine is core_pipeline.PipelineEngine + assert editors.PlaneEditor is ui.PlaneEditor + assert gui.launch_gui is ui.launch_gui + assert core.PipelineEngine is core_pipeline.PipelineEngine + + +def test_gui_shims_remain_lazy_without_optional_dependencies(): + code = _block_optional_gui_imports( + """ + import autoflow.app as app + import autoflow.ortho_viewer as ortho_viewer + import autoflow.viewer as viewer + + print(app.__all__) + print(ortho_viewer.__all__) + print(viewer.__all__) + """ + ) + result = _run_python(code) + assert result.returncode == 0, result.stderr or result.stdout + + def test_launch_gui_reports_missing_optional_dependencies(): code = _block_optional_gui_imports( """