Skip to content

Commit fa0dada

Browse files
committed
Optimized SSD performance
1 parent a294a42 commit fa0dada

2 files changed

Lines changed: 54 additions & 33 deletions

File tree

benchmarks/results/latency.png

3.58 KB
Loading

src/pyversity/strategies/ssd.py

Lines changed: 54 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@ def _prepare_vectors(matrix: np.ndarray) -> np.ndarray:
113113
std = float(np.std(relevance_scores))
114114
relevance_scores = (relevance_scores - mean) / std if std > 0.0 else (relevance_scores - mean)
115115

116-
num_items, _ = feature_matrix.shape
116+
num_items, n_dims = feature_matrix.shape
117117

118118
# Initialize selection state
119119
selected_mask = np.zeros(num_items, dtype=bool)
@@ -123,31 +123,52 @@ def _prepare_vectors(matrix: np.ndarray) -> np.ndarray:
123123
# Current residuals under the sliding window
124124
residual_matrix = feature_matrix.astype(np.float32, copy=True)
125125

126-
# Sliding window storage
127-
basis_vectors: list[np.ndarray] = []
128-
projection_coefficients_per_basis: list[np.ndarray] = []
126+
# Incrementally maintained squared norms: residual_sq_norms[i] = ||residual_matrix[i]||^2
127+
residual_sq_norms: np.ndarray = np.einsum("ij,ij->i", residual_matrix, residual_matrix)
128+
129+
# Pre-allocated circular buffer
130+
basis_matrix = np.zeros((window_size, n_dims), dtype=np.float32)
131+
coeff_matrix = np.zeros((window_size, num_items), dtype=np.float32)
132+
window_count = 0
133+
window_head = 0
134+
135+
# Pre-allocated buffer for rank-1 updates
136+
_outer_buf = np.empty((num_items, n_dims), dtype=np.float32)
129137

130138
def _push_basis_vector(basis_vector: np.ndarray) -> None:
131139
"""Add a new basis vector to the sliding window and update residuals/projections."""
132-
if len(basis_vectors) == window_size:
133-
# Remove oldest basis and restore its contribution to residuals
134-
oldest_basis = basis_vectors.pop(0)
135-
oldest_coefficients = projection_coefficients_per_basis.pop(0)
136-
mask_unselected = ~selected_mask
137-
if np.any(mask_unselected):
138-
residual_matrix[mask_unselected] += oldest_coefficients[mask_unselected, None] * oldest_basis
139-
140-
denominator = float(basis_vector @ basis_vector) + EPS32
141-
basis_vectors.append(basis_vector.astype(np.float32, copy=False))
142-
143-
mask_unselected = ~selected_mask
144-
coefficients = np.zeros(num_items, dtype=np.float32)
145-
if np.any(mask_unselected):
146-
projections = (residual_matrix[mask_unselected] @ basis_vector) / denominator
147-
coefficients[mask_unselected] = projections
148-
residual_matrix[mask_unselected] -= projections[:, None] * basis_vector
149-
150-
projection_coefficients_per_basis.append(coefficients)
140+
nonlocal window_count, window_head
141+
142+
if window_count == window_size:
143+
# Evict oldest: restore its contribution to residuals (full-array op).
144+
# Zero out selected items so their residuals stay untouched.
145+
oldest_slot = window_head
146+
coeff_matrix[oldest_slot][selected_mask] = 0.0
147+
old_coeffs = coeff_matrix[oldest_slot]
148+
old_basis = basis_matrix[oldest_slot]
149+
old_basis_sq = float(old_basis @ old_basis)
150+
# r_new = r + c * b → ||r_new||^2 = ||r||^2 + 2c(r·b) + c^2||b||^2
151+
dots_evict = residual_matrix @ old_basis
152+
residual_sq_norms[:] += old_coeffs * (2.0 * dots_evict + old_coeffs * old_basis_sq)
153+
np.outer(old_coeffs, old_basis, out=_outer_buf)
154+
np.add(residual_matrix, _outer_buf, out=residual_matrix)
155+
else:
156+
window_count += 1
157+
158+
basis_sq = float(basis_vector @ basis_vector)
159+
denominator = basis_sq + EPS32
160+
basis_matrix[window_head] = basis_vector
161+
dots = residual_matrix @ basis_vector
162+
coefficients = dots / denominator
163+
coefficients[selected_mask] = 0.0
164+
coeff_matrix[window_head] = coefficients
165+
# r_new = r - c * b → ||r_new||^2 = ||r||^2 - 2c(r·b) + c^2||b||^2
166+
# = ||r||^2 - c(2·dot - c·basis_sq)
167+
residual_sq_norms[:] -= coefficients * (2.0 * dots - coefficients * basis_sq)
168+
np.maximum(residual_sq_norms, 0.0, out=residual_sq_norms)
169+
np.outer(coefficients, basis_vector, out=_outer_buf)
170+
np.subtract(residual_matrix, _outer_buf, out=residual_matrix)
171+
window_head = (window_head + 1) % window_size
151172

152173
# Seed with recent context (oldest → newest) if provided
153174
seeded_bases = 0
@@ -156,7 +177,9 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
156177
context = context[-window_size:] # keep only the latest `window_size` items
157178
for context_vector in context:
158179
residual_context = context_vector.copy()
159-
for basis in basis_vectors:
180+
for slot_offset in range(window_count):
181+
slot_idx = (window_head - window_count + slot_offset) % window_size
182+
basis = basis_matrix[slot_idx]
160183
denominator_b = float(basis @ basis) + EPS32
161184
residual_context -= float(residual_context @ basis) / denominator_b * basis
162185
_push_basis_vector(residual_context)
@@ -165,7 +188,7 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
165188
# Decide what to select first
166189
if seeded_bases > 0:
167190
# Use combined scores with diversity from seeded context
168-
residual_norms = np.linalg.norm(residual_matrix, axis=1)
191+
residual_norms = np.sqrt(residual_sq_norms)
169192
combined_scores = theta * relevance_scores + (1.0 - theta) * gamma * residual_norms
170193
combined_scores[selected_mask] = -np.inf
171194
first_index = int(np.argmax(combined_scores))
@@ -186,14 +209,12 @@ def _push_basis_vector(basis_vector: np.ndarray) -> None:
186209

187210
# Main loop
188211
for step in range(1, top_k):
189-
# Find best candidate among unselected items
190-
available_indices = np.where(~selected_mask)[0]
191-
# Residual norms measure novelty relative to the last `window` selections/context
192-
residual_norms = np.linalg.norm(residual_matrix[available_indices], axis=1)
193-
combined_scores = theta * relevance_scores[available_indices] + (1.0 - theta) * gamma * residual_norms
194-
local_best = int(np.argmax(combined_scores))
195-
best_index = int(available_indices[local_best])
196-
best_score = float(combined_scores[local_best])
212+
# Compute scores using incrementally maintained squared norms
213+
residual_norms = np.sqrt(residual_sq_norms)
214+
combined_scores = theta * relevance_scores + (1.0 - theta) * gamma * residual_norms
215+
combined_scores[selected_mask] = -np.inf
216+
best_index = int(np.argmax(combined_scores))
217+
best_score = float(combined_scores[best_index])
197218

198219
# Update selection state
199220
selected_mask[best_index] = True

0 commit comments

Comments
 (0)