Skip to content

Commit 6da2aba

Browse files
committed
fix conf matrix eval
1 parent 57bb453 commit 6da2aba

3 files changed

Lines changed: 93 additions & 43 deletions

File tree

4.91 KB
Loading

evaluation/visualization/model_comparison.py

Lines changed: 82 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,27 @@
2222
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
2323
from evaluator import evaluate_model
2424

25+
def group_features_by_compound(model_name, data_root):
26+
"""Groups features by compound name."""
27+
feature_dir = os.path.join(data_root, "bbbc021_features", model_name)
28+
if not os.path.exists(feature_dir):
29+
raise FileNotFoundError(f"Features directory not found: {feature_dir}")
30+
31+
compound_features = defaultdict(list)
32+
for file in os.listdir(feature_dir):
33+
if not file.endswith(".pkl") or "DMSO" in file:
34+
continue
35+
filepath = os.path.join(feature_dir, file)
36+
try:
37+
with open(filepath, 'rb') as f:
38+
(compound_info, feat) = pickle.load(f)
39+
compound, _, moa = compound_info
40+
if moa != "null":
41+
compound_features[compound].append({'feature': feat.numpy() if hasattr(feat, 'numpy') else feat, 'moa': moa})
42+
except Exception as e:
43+
print(f"Warning: Could not load {filepath}: {e}")
44+
return compound_features
45+
2546
def load_model_features_and_moas(model_name, data_root="/scratch/cv-course2025/group8"):
2647
"""
2748
Load features and MoAs for a single model.
@@ -120,6 +141,7 @@ def plot_accuracy_comparison(model_names, data_root="/scratch/cv-course2025/grou
120141
def plot_confusion_matrices(model_names, data_root="/scratch/cv-course2025/group8", output_dir="/scratch/cv-course2025/group8/plots", distance_measure="cosine"):
121142
"""
122143
Create confusion matrices for each model showing predicted vs actual MoA.
144+
This uses a leave-one-compound-out cross-validation approach.
123145
124146
Args:
125147
model_names (list): List of model names to compare
@@ -132,45 +154,77 @@ def plot_confusion_matrices(model_names, data_root="/scratch/cv-course2025/group
132154
n_cols = 2
133155
n_rows = math.ceil(n_models / n_cols)
134156
fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows))
135-
axes = axes.flatten() # Flatten in case of single row
157+
158+
# Handle case where we have only one subplot
159+
if n_models == 1:
160+
axes = [axes]
161+
else:
162+
axes = axes.flatten()
163+
164+
# Get a superset of all MoAs across all models for consistent plotting
165+
all_moas = set()
166+
all_compound_features = {}
167+
for model_name in model_names:
168+
try:
169+
_, moas = load_model_features_and_moas(model_name, data_root)
170+
all_moas.update(moas)
171+
all_compound_features[model_name] = group_features_by_compound(model_name, data_root)
172+
except (FileNotFoundError, ValueError) as e:
173+
print(f"Could not load data for {model_name}: {e}")
174+
unique_moas = sorted(list(all_moas))
136175

137176
for i, model_name in enumerate(model_names):
177+
ax = axes[i]
138178
try:
139179
print(f"Creating confusion matrix for {model_name}...")
140180

141-
# Load features and MoAs
142-
features, moas = load_model_features_and_moas(model_name, data_root)
143-
144-
# Get unique MoAs
145-
unique_moas = sorted(set(moas))
146-
147-
# Perform 1-NN classification
148-
if distance_measure == "cosine":
149-
# Normalize features for cosine similarity
150-
features_norm = features / np.linalg.norm(features, axis=1, keepdims=True)
151-
nbrs = NearestNeighbors(n_neighbors=2, metric='cosine')
152-
nbrs.fit(features_norm)
153-
distances, indices = nbrs.kneighbors(features_norm)
154-
else:
155-
nbrs = NearestNeighbors(n_neighbors=2, metric=distance_measure)
156-
nbrs.fit(features)
157-
distances, indices = nbrs.kneighbors(features)
158-
159-
# Get predictions (nearest neighbor that's not itself)
181+
compound_features = all_compound_features.get(model_name)
182+
if not compound_features:
183+
raise ValueError("No compound features loaded.")
184+
185+
compounds = list(compound_features.keys())
160186
predicted_moas = []
161187
actual_moas = []
162-
163-
for j, (dist, idx) in enumerate(zip(distances, indices)):
164-
# Skip self (first neighbor)
165-
nearest_idx = idx[1]
166-
predicted_moas.append(moas[nearest_idx])
167-
actual_moas.append(moas[j])
168-
188+
189+
for compound_to_test in compounds:
190+
# Gallery: all other compounds
191+
gallery_features = []
192+
gallery_moas = []
193+
for c in compounds:
194+
if c != compound_to_test:
195+
for sample in compound_features[c]:
196+
gallery_features.append(sample['feature'])
197+
gallery_moas.append(sample['moa'])
198+
199+
if not gallery_features:
200+
continue
201+
202+
gallery_features = np.array(gallery_features)
203+
gallery_moas = np.array(gallery_moas)
204+
205+
# Query: features of the compound to test
206+
query_features = [sample['feature'] for sample in compound_features[compound_to_test]]
207+
query_moas = [sample['moa'] for sample in compound_features[compound_to_test]]
208+
209+
# Fit NearestNeighbors on the gallery
210+
if distance_measure == "cosine":
211+
gallery_features = gallery_features / np.linalg.norm(gallery_features, axis=1, keepdims=True)
212+
query_features = np.array(query_features) / np.linalg.norm(query_features, axis=1, keepdims=True)
213+
214+
nbrs = NearestNeighbors(n_neighbors=1, metric=distance_measure)
215+
nbrs.fit(gallery_features)
216+
217+
# Find nearest neighbor for each query sample
218+
_, indices = nbrs.kneighbors(query_features)
219+
220+
for j, idx_list in enumerate(indices):
221+
predicted_moas.append(gallery_moas[idx_list[0]])
222+
actual_moas.append(query_moas[j])
223+
169224
# Create confusion matrix
170225
cm = confusion_matrix(actual_moas, predicted_moas, labels=unique_moas)
171226

172227
# Plot
173-
ax = axes[i]
174228
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
175229
xticklabels=unique_moas, yticklabels=unique_moas,
176230
ax=ax, cbar=i == n_models - 1) # Only show colorbar for last plot
@@ -179,13 +233,11 @@ def plot_confusion_matrices(model_names, data_root="/scratch/cv-course2025/group
179233
ax.set_xlabel('Predicted MoA')
180234
ax.set_ylabel('Actual MoA')
181235

182-
# Rotate labels for better readability
183236
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
184237
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
185238

186239
except Exception as e:
187240
print(f"Error creating confusion matrix for {model_name}: {e}")
188-
ax = axes[i]
189241
ax.text(0.5, 0.5, f'Error: {str(e)}', transform=ax.transAxes,
190242
ha='center', va='center', fontsize=12)
191243
ax.set_title(f'{model_name} - Error')

notebooks/evaluate_models.ipynb

Lines changed: 11 additions & 13 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)