2222sys .path .append (os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
2323from evaluator import evaluate_model
2424
25+ def group_features_by_compound (model_name , data_root ):
26+ """Groups features by compound name."""
27+ feature_dir = os .path .join (data_root , "bbbc021_features" , model_name )
28+ if not os .path .exists (feature_dir ):
29+ raise FileNotFoundError (f"Features directory not found: { feature_dir } " )
30+
31+ compound_features = defaultdict (list )
32+ for file in os .listdir (feature_dir ):
33+ if not file .endswith (".pkl" ) or "DMSO" in file :
34+ continue
35+ filepath = os .path .join (feature_dir , file )
36+ try :
37+ with open (filepath , 'rb' ) as f :
38+ (compound_info , feat ) = pickle .load (f )
39+ compound , _ , moa = compound_info
40+ if moa != "null" :
41+ compound_features [compound ].append ({'feature' : feat .numpy () if hasattr (feat , 'numpy' ) else feat , 'moa' : moa })
42+ except Exception as e :
43+ print (f"Warning: Could not load { filepath } : { e } " )
44+ return compound_features
45+
2546def load_model_features_and_moas (model_name , data_root = "/scratch/cv-course2025/group8" ):
2647 """
2748 Load features and MoAs for a single model.
@@ -120,6 +141,7 @@ def plot_accuracy_comparison(model_names, data_root="/scratch/cv-course2025/grou
120141def plot_confusion_matrices (model_names , data_root = "/scratch/cv-course2025/group8" , output_dir = "/scratch/cv-course2025/group8/plots" , distance_measure = "cosine" ):
121142 """
122143 Create confusion matrices for each model showing predicted vs actual MoA.
144+ This uses a leave-one-compound-out cross-validation approach.
123145
124146 Args:
125147 model_names (list): List of model names to compare
@@ -132,45 +154,77 @@ def plot_confusion_matrices(model_names, data_root="/scratch/cv-course2025/group
132154 n_cols = 2
133155 n_rows = math .ceil (n_models / n_cols )
134156 fig , axes = plt .subplots (n_rows , n_cols , figsize = (6 * n_cols , 5 * n_rows ))
135- axes = axes .flatten () # Flatten in case of single row
157+
158+ # Handle case where we have only one subplot
159+ if n_models == 1 :
160+ axes = [axes ]
161+ else :
162+ axes = axes .flatten ()
163+
164+ # Get a superset of all MoAs across all models for consistent plotting
165+ all_moas = set ()
166+ all_compound_features = {}
167+ for model_name in model_names :
168+ try :
169+ _ , moas = load_model_features_and_moas (model_name , data_root )
170+ all_moas .update (moas )
171+ all_compound_features [model_name ] = group_features_by_compound (model_name , data_root )
172+ except (FileNotFoundError , ValueError ) as e :
173+ print (f"Could not load data for { model_name } : { e } " )
174+ unique_moas = sorted (list (all_moas ))
136175
137176 for i , model_name in enumerate (model_names ):
177+ ax = axes [i ]
138178 try :
139179 print (f"Creating confusion matrix for { model_name } ..." )
140180
141- # Load features and MoAs
142- features , moas = load_model_features_and_moas (model_name , data_root )
143-
144- # Get unique MoAs
145- unique_moas = sorted (set (moas ))
146-
147- # Perform 1-NN classification
148- if distance_measure == "cosine" :
149- # Normalize features for cosine similarity
150- features_norm = features / np .linalg .norm (features , axis = 1 , keepdims = True )
151- nbrs = NearestNeighbors (n_neighbors = 2 , metric = 'cosine' )
152- nbrs .fit (features_norm )
153- distances , indices = nbrs .kneighbors (features_norm )
154- else :
155- nbrs = NearestNeighbors (n_neighbors = 2 , metric = distance_measure )
156- nbrs .fit (features )
157- distances , indices = nbrs .kneighbors (features )
158-
159- # Get predictions (nearest neighbor that's not itself)
181+ compound_features = all_compound_features .get (model_name )
182+ if not compound_features :
183+ raise ValueError ("No compound features loaded." )
184+
185+ compounds = list (compound_features .keys ())
160186 predicted_moas = []
161187 actual_moas = []
162-
163- for j , (dist , idx ) in enumerate (zip (distances , indices )):
164- # Skip self (first neighbor)
165- nearest_idx = idx [1 ]
166- predicted_moas .append (moas [nearest_idx ])
167- actual_moas .append (moas [j ])
168-
188+
189+ for compound_to_test in compounds :
190+ # Gallery: all other compounds
191+ gallery_features = []
192+ gallery_moas = []
193+ for c in compounds :
194+ if c != compound_to_test :
195+ for sample in compound_features [c ]:
196+ gallery_features .append (sample ['feature' ])
197+ gallery_moas .append (sample ['moa' ])
198+
199+ if not gallery_features :
200+ continue
201+
202+ gallery_features = np .array (gallery_features )
203+ gallery_moas = np .array (gallery_moas )
204+
205+ # Query: features of the compound to test
206+ query_features = [sample ['feature' ] for sample in compound_features [compound_to_test ]]
207+ query_moas = [sample ['moa' ] for sample in compound_features [compound_to_test ]]
208+
209+ # Fit NearestNeighbors on the gallery
210+ if distance_measure == "cosine" :
211+ gallery_features = gallery_features / np .linalg .norm (gallery_features , axis = 1 , keepdims = True )
212+ query_features = np .array (query_features ) / np .linalg .norm (query_features , axis = 1 , keepdims = True )
213+
214+ nbrs = NearestNeighbors (n_neighbors = 1 , metric = distance_measure )
215+ nbrs .fit (gallery_features )
216+
217+ # Find nearest neighbor for each query sample
218+ _ , indices = nbrs .kneighbors (query_features )
219+
220+ for j , idx_list in enumerate (indices ):
221+ predicted_moas .append (gallery_moas [idx_list [0 ]])
222+ actual_moas .append (query_moas [j ])
223+
169224 # Create confusion matrix
170225 cm = confusion_matrix (actual_moas , predicted_moas , labels = unique_moas )
171226
172227 # Plot
173- ax = axes [i ]
174228 sns .heatmap (cm , annot = True , fmt = 'd' , cmap = 'Blues' ,
175229 xticklabels = unique_moas , yticklabels = unique_moas ,
176230 ax = ax , cbar = i == n_models - 1 ) # Only show colorbar for last plot
@@ -179,13 +233,11 @@ def plot_confusion_matrices(model_names, data_root="/scratch/cv-course2025/group
179233 ax .set_xlabel ('Predicted MoA' )
180234 ax .set_ylabel ('Actual MoA' )
181235
182- # Rotate labels for better readability
183236 ax .set_xticklabels (ax .get_xticklabels (), rotation = 45 , ha = 'right' )
184237 ax .set_yticklabels (ax .get_yticklabels (), rotation = 0 )
185238
186239 except Exception as e :
187240 print (f"Error creating confusion matrix for { model_name } : { e } " )
188- ax = axes [i ]
189241 ax .text (0.5 , 0.5 , f'Error: { str (e )} ' , transform = ax .transAxes ,
190242 ha = 'center' , va = 'center' , fontsize = 12 )
191243 ax .set_title (f'{ model_name } - Error' )
0 commit comments