Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions arcos4py/tools/_detect_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,19 @@ def transportation_linking(
cost_threshold: float = 0,
**kwargs: Dict[str, Any],
) -> Tuple[np.ndarray, int]:
"""
Optimized pixel-wise transportation linking of clusters across frames.
"""Optimized pixel-wise transportation linking of clusters across frames.

Arguments:
cluster_labels (np.ndarray): The cluster labels for the current frame.
cluster_coordinates (np.ndarray): The cluster coordinates for the current frame.
memory_cluster_labels (np.ndarray): The cluster labels for previous frames.
memory_coordinates (np.ndarray): The cluster coordinates for previous frames.
memory_kdtree (KDTree): KDTree for the previous frame's clusters.
epsPrev (float): Frame-to-frame distance, used to connect clusters across frames.
max_cluster_label (int): The maximum label for clusters.
reg (float): Entropy regularization parameter for unbalanced OT algorithm.
reg_m (float): Marginal relaxation parameter for unbalanced OT.
cost_threshold (float): Cost threshold for assigning clusters.

Uses unbalanced OT to assign each current pixel to a previous pixel within epsPrev.
"""
Expand All @@ -249,14 +260,14 @@ def transportation_linking(

# Build cost matrix between each current pixel and each candidate memory pixel
curr_coords = cluster_coordinates
mem_coords = memory_coordinates[valid_mem_idx]
mem_coords = memory_coordinates[valid_mem_idx]
cost_matrix = _compute_filtered_distances(curr_coords, mem_coords)

# Uniform distributions
n_curr = curr_coords.shape[0]
n_mem = mem_coords.shape[0]
n_mem = mem_coords.shape[0]
a = np.ones(n_curr) / n_curr
b = np.ones(n_mem) / n_mem
b = np.ones(n_mem) / n_mem

# Solve unbalanced OT
ot_plan = ot.unbalanced.sinkhorn_unbalanced(a, b, cost_matrix, reg, reg_m)
Expand Down