diff --git a/arcos4py/tools/_detect_events.py b/arcos4py/tools/_detect_events.py index 049968a..14f25de 100644 --- a/arcos4py/tools/_detect_events.py +++ b/arcos4py/tools/_detect_events.py @@ -230,8 +230,19 @@ def transportation_linking( cost_threshold: float = 0, **kwargs: Dict[str, Any], ) -> Tuple[np.ndarray, int]: - """ - Optimized pixel-wise transportation linking of clusters across frames. + """Optimized pixel-wise transportation linking of clusters across frames. + + Arguments: + cluster_labels (np.ndarray): The cluster labels for the current frame. + cluster_coordinates (np.ndarray): The cluster coordinates for the current frame. + memory_cluster_labels (np.ndarray): The cluster labels for previous frames. + memory_coordinates (np.ndarray): The cluster coordinates for previous frames. + memory_kdtree (KDTree): KDTree for the previous frame's clusters. + epsPrev (float): Frame-to-frame distance, used to connect clusters across frames. + max_cluster_label (int): The maximum label for clusters. + reg (float): Entropy regularization parameter for unbalanced OT algorithm. + reg_m (float): Marginal relaxation parameter for unbalanced OT. + cost_threshold (float): Cost threshold for assigning clusters. Uses unbalanced OT to assign each current pixel to a previous pixel within epsPrev. """ @@ -249,14 +260,14 @@ def transportation_linking( # Build cost matrix between each current pixel and each candidate memory pixel curr_coords = cluster_coordinates - mem_coords = memory_coordinates[valid_mem_idx] + mem_coords = memory_coordinates[valid_mem_idx] cost_matrix = _compute_filtered_distances(curr_coords, mem_coords) # Uniform distributions n_curr = curr_coords.shape[0] - n_mem = mem_coords.shape[0] + n_mem = mem_coords.shape[0] a = np.ones(n_curr) / n_curr - b = np.ones(n_mem) / n_mem + b = np.ones(n_mem) / n_mem # Solve unbalanced OT ot_plan = ot.unbalanced.sinkhorn_unbalanced(a, b, cost_matrix, reg, reg_m)