diff --git a/1.0.0 b/1.0.0 new file mode 100644 index 0000000..2d8d6fd --- /dev/null +++ b/1.0.0 @@ -0,0 +1,5 @@ +Collecting poetry-core + Using cached poetry_core-2.1.3-py3-none-any.whl.metadata (3.5 kB) +Using cached poetry_core-2.1.3-py3-none-any.whl (332 kB) +Installing collected packages: poetry-core +Successfully installed poetry-core-2.1.3 diff --git a/arcos4py/tools/_detect_events.py b/arcos4py/tools/_detect_events.py index 7ff6e32..049968a 100644 --- a/arcos4py/tools/_detect_events.py +++ b/arcos4py/tools/_detect_events.py @@ -197,12 +197,13 @@ def brute_force_linking( return cluster_labels, max_cluster_label + @njit(parallel=True) def _compute_filtered_distances(current_coords, memory_coords): n, m = len(current_coords), len(memory_coords) distances = np.empty((n, m)) for i in prange(n): - for j in prange(m): + for j in range(m): distances[i, j] = np.sum((current_coords[i] - memory_coords[j]) ** 2) return np.sqrt(distances) @@ -229,70 +230,53 @@ def transportation_linking( cost_threshold: float = 0, **kwargs: Dict[str, Any], ) -> Tuple[np.ndarray, int]: - """Optimized transportation linking of clusters across frames, using a pre-constructed sklearn KDTree. - - Args: - cluster_labels (np.ndarray): The cluster labels for the current frame. - cluster_coordinates (np.ndarray): The cluster coordinates for the current frame. - memory_cluster_labels (np.ndarray): The cluster labels for previous frames. - memory_coordinates (np.ndarray): The coordinates for previous frames. - memory_kdtree (KDTree): Pre-constructed sklearn KDTree for memory coordinates. - epsPrev (float): Frame-to-frame distance, used to connect clusters across frames. - max_cluster_label (int): The maximum label for clusters. - reg (float): Entropy regularization parameter for Sinkhorn algorithm. - reg_m (float): Marginal relaxation parameter for unbalanced OT. - cost_threshold (float): Threshold for filtering low-probability matches. - **kwargs: Additional keyword arguments. + """ + Optimized pixel-wise transportation linking of clusters across frames. - Returns: - Tuple[np.ndarray, int]: Updated cluster labels and the maximum cluster label. + Uses unbalanced OT to assign each current pixel to a previous pixel within epsPrev. """ - # Find neighbors within the maximum allowed distance (epsPrev) - indices = memory_kdtree.query_radius(cluster_coordinates, r=epsPrev) + # Find all memory indices within epsPrev of any current pixel + neighbors = memory_kdtree.query_radius(cluster_coordinates, r=epsPrev) - if all(len(ind) == 0 for ind in indices): + if all(len(ind) == 0 for ind in neighbors): max_cluster_label += 1 - return np.full_like(cluster_labels, max_cluster_label), max_cluster_label + return np.full(cluster_labels.shape, max_cluster_label, dtype=int), max_cluster_label - # Prepare indices of valid points - valid_mask = np.array([len(ind) > 0 for ind in indices]) - current_indices = np.arange(len(indices))[valid_mask] - memory_indices = np.array([ind[0] for ind in indices if len(ind) > 0]) - - if len(current_indices) == 0: + valid_mem_idx = np.unique(np.concatenate([ind for ind in neighbors if len(ind) > 0])) + if valid_mem_idx.size == 0: max_cluster_label += 1 - return np.full_like(cluster_labels, max_cluster_label), max_cluster_label - - # Compute distance matrix for valid pairs - filtered_distances = _compute_filtered_distances( - cluster_coordinates[current_indices], memory_coordinates[memory_indices] - ) + return np.full(cluster_labels.shape, max_cluster_label, dtype=int), max_cluster_label - # Uniform distribution on the valid points - a = np.ones(len(current_indices)) / len(current_indices) - b = np.ones(len(memory_indices)) / len(memory_indices) + # Build cost matrix between each current pixel and each candidate memory pixel + curr_coords = cluster_coordinates + mem_coords = memory_coordinates[valid_mem_idx] + cost_matrix = _compute_filtered_distances(curr_coords, mem_coords) - # Solve the unbalanced OT problem - ot_plan = ot.unbalanced.sinkhorn_unbalanced(a, b, filtered_distances, reg, reg_m) + # Uniform distributions + n_curr = curr_coords.shape[0] + n_mem = mem_coords.shape[0] + a = np.ones(n_curr) / n_curr + b = np.ones(n_mem) / n_mem - # Propagate cluster id from previous frame - matches = np.argmax(ot_plan, axis=1) + # Solve unbalanced OT + ot_plan = ot.unbalanced.sinkhorn_unbalanced(a, b, cost_matrix, reg, reg_m) - # Set matches to -1 if the cost is too high - matches[ot_plan[np.arange(len(matches)), matches] < cost_threshold] = -1 + # Determine best assignment for each current pixel + best_mem = np.argmax(ot_plan, axis=1) + probs = ot_plan[np.arange(n_curr), best_mem] + best_mem[probs < cost_threshold] = -1 - new_cluster_labels = _assign_labels( - matches, current_indices, memory_indices, memory_cluster_labels, cluster_labels.size - ) + new_cluster_labels = np.full(n_curr, -1, dtype=int) + for i, m in enumerate(best_mem): + if m != -1: + new_cluster_labels[i] = int(memory_cluster_labels[valid_mem_idx[m]]) - # Assign new labels to unmatched clusters if np.any(new_cluster_labels == -1): max_cluster_label += 1 new_cluster_labels[new_cluster_labels == -1] = max_cluster_label return new_cluster_labels, max_cluster_label - @dataclass class Memory: """Memory class for retaining coordinates and cluster IDs over a specified number of time points. @@ -1227,7 +1211,7 @@ def link(self, input_coordinates: np.ndarray) -> None: linked_cluster_ids = self._update_id(original_cluster_ids, coordinates) if self._remove_small_clusters: - final_cluster_ids = self._apply_remove_small_clusters(linked_cluster_ids, original_cluster_ids) + linked_cluster_ids = self._apply_remove_small_clusters(linked_cluster_ids, original_cluster_ids) # Apply stable merges and splits final_cluster_ids = self._apply_stable_merges_splits(linked_cluster_ids, original_cluster_ids) diff --git a/tests/testdata/pix/4_colliding_transportation.tif b/tests/testdata/pix/4_colliding_transportation.tif index 9f47a22..abf7511 100644 Binary files a/tests/testdata/pix/4_colliding_transportation.tif and b/tests/testdata/pix/4_colliding_transportation.tif differ