Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions 1.0.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Collecting poetry-core
Using cached poetry_core-2.1.3-py3-none-any.whl.metadata (3.5 kB)
Using cached poetry_core-2.1.3-py3-none-any.whl (332 kB)
Installing collected packages: poetry-core
Successfully installed poetry-core-2.1.3
80 changes: 32 additions & 48 deletions arcos4py/tools/_detect_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,13 @@ def brute_force_linking(
return cluster_labels, max_cluster_label



@njit(parallel=True)
def _compute_filtered_distances(current_coords, memory_coords):
n, m = len(current_coords), len(memory_coords)
distances = np.empty((n, m))
for i in prange(n):
for j in prange(m):
for j in range(m):
distances[i, j] = np.sum((current_coords[i] - memory_coords[j]) ** 2)
return np.sqrt(distances)

Expand All @@ -229,70 +230,53 @@ def transportation_linking(
cost_threshold: float = 0,
**kwargs: Dict[str, Any],
) -> Tuple[np.ndarray, int]:
"""Optimized transportation linking of clusters across frames, using a pre-constructed sklearn KDTree.

Args:
cluster_labels (np.ndarray): The cluster labels for the current frame.
cluster_coordinates (np.ndarray): The cluster coordinates for the current frame.
memory_cluster_labels (np.ndarray): The cluster labels for previous frames.
memory_coordinates (np.ndarray): The coordinates for previous frames.
memory_kdtree (KDTree): Pre-constructed sklearn KDTree for memory coordinates.
epsPrev (float): Frame-to-frame distance, used to connect clusters across frames.
max_cluster_label (int): The maximum label for clusters.
reg (float): Entropy regularization parameter for Sinkhorn algorithm.
reg_m (float): Marginal relaxation parameter for unbalanced OT.
cost_threshold (float): Threshold for filtering low-probability matches.
**kwargs: Additional keyword arguments.
"""
Optimized pixel-wise transportation linking of clusters across frames.

Returns:
Tuple[np.ndarray, int]: Updated cluster labels and the maximum cluster label.
Uses unbalanced OT to assign each current pixel to a previous pixel within epsPrev.
"""
# Find neighbors within the maximum allowed distance (epsPrev)
indices = memory_kdtree.query_radius(cluster_coordinates, r=epsPrev)
# Find all memory indices within epsPrev of any current pixel
neighbors = memory_kdtree.query_radius(cluster_coordinates, r=epsPrev)

if all(len(ind) == 0 for ind in indices):
if all(len(ind) == 0 for ind in neighbors):
max_cluster_label += 1
return np.full_like(cluster_labels, max_cluster_label), max_cluster_label
return np.full(cluster_labels.shape, max_cluster_label, dtype=int), max_cluster_label

# Prepare indices of valid points
valid_mask = np.array([len(ind) > 0 for ind in indices])
current_indices = np.arange(len(indices))[valid_mask]
memory_indices = np.array([ind[0] for ind in indices if len(ind) > 0])

if len(current_indices) == 0:
valid_mem_idx = np.unique(np.concatenate([ind for ind in neighbors if len(ind) > 0]))
if valid_mem_idx.size == 0:
max_cluster_label += 1
return np.full_like(cluster_labels, max_cluster_label), max_cluster_label

# Compute distance matrix for valid pairs
filtered_distances = _compute_filtered_distances(
cluster_coordinates[current_indices], memory_coordinates[memory_indices]
)
return np.full(cluster_labels.shape, max_cluster_label, dtype=int), max_cluster_label

# Uniform distribution on the valid points
a = np.ones(len(current_indices)) / len(current_indices)
b = np.ones(len(memory_indices)) / len(memory_indices)
# Build cost matrix between each current pixel and each candidate memory pixel
curr_coords = cluster_coordinates
mem_coords = memory_coordinates[valid_mem_idx]
cost_matrix = _compute_filtered_distances(curr_coords, mem_coords)

# Solve the unbalanced OT problem
ot_plan = ot.unbalanced.sinkhorn_unbalanced(a, b, filtered_distances, reg, reg_m)
# Uniform distributions
n_curr = curr_coords.shape[0]
n_mem = mem_coords.shape[0]
a = np.ones(n_curr) / n_curr
b = np.ones(n_mem) / n_mem

# Propagate cluster id from previous frame
matches = np.argmax(ot_plan, axis=1)
# Solve unbalanced OT
ot_plan = ot.unbalanced.sinkhorn_unbalanced(a, b, cost_matrix, reg, reg_m)

# Set matches to -1 if the cost is too high
matches[ot_plan[np.arange(len(matches)), matches] < cost_threshold] = -1
# Determine best assignment for each current pixel
best_mem = np.argmax(ot_plan, axis=1)
probs = ot_plan[np.arange(n_curr), best_mem]
best_mem[probs < cost_threshold] = -1

new_cluster_labels = _assign_labels(
matches, current_indices, memory_indices, memory_cluster_labels, cluster_labels.size
)
new_cluster_labels = np.full(n_curr, -1, dtype=int)
for i, m in enumerate(best_mem):
if m != -1:
new_cluster_labels[i] = int(memory_cluster_labels[valid_mem_idx[m]])

# Assign new labels to unmatched clusters
if np.any(new_cluster_labels == -1):
max_cluster_label += 1
new_cluster_labels[new_cluster_labels == -1] = max_cluster_label

return new_cluster_labels, max_cluster_label


@dataclass
class Memory:
"""Memory class for retaining coordinates and cluster IDs over a specified number of time points.
Expand Down Expand Up @@ -1227,7 +1211,7 @@ def link(self, input_coordinates: np.ndarray) -> None:
linked_cluster_ids = self._update_id(original_cluster_ids, coordinates)

if self._remove_small_clusters:
final_cluster_ids = self._apply_remove_small_clusters(linked_cluster_ids, original_cluster_ids)
linked_cluster_ids = self._apply_remove_small_clusters(linked_cluster_ids, original_cluster_ids)

# Apply stable merges and splits
final_cluster_ids = self._apply_stable_merges_splits(linked_cluster_ids, original_cluster_ids)
Expand Down
Binary file modified tests/testdata/pix/4_colliding_transportation.tif
Binary file not shown.
Loading