diff --git a/src/vlab4mic/analysis/_plots.py b/src/vlab4mic/analysis/_plots.py index 9a30a7e..02435b8 100644 --- a/src/vlab4mic/analysis/_plots.py +++ b/src/vlab4mic/analysis/_plots.py @@ -26,7 +26,6 @@ def sns_heatmap_pivots( f, axes = plt.subplots(nconditions, 2, figsize=figsize, squeeze = False) plot_num = 0 if cmaps_range == "same": - # min and max here correspond to SSIM hist_params = dict(vmin=0, vmax=1) elif cmaps_range == "each": hist_params = dict() diff --git a/src/vlab4mic/analysis/metrics.py b/src/vlab4mic/analysis/metrics.py index db1fd3c..f1119fb 100644 --- a/src/vlab4mic/analysis/metrics.py +++ b/src/vlab4mic/analysis/metrics.py @@ -5,8 +5,89 @@ from skimage.feature import peak_local_max from scipy.stats import pearsonr import cv2 +import copy + +def match_image_sizes( + reference_image = None, + reference_image_pixelsize_nm = None, + reference_image_mask = None, + simulated_image = None, + simulated_image_pixelsize_nm = None, + simulated_image_mask = None + ): + union_mask = None + if reference_image_pixelsize_nm and simulated_image_pixelsize_nm: + reference_interpolated, simulated_image_interpolated = resize_images_interpolation( + img1=reference_image, + img2=simulated_image, + px_size_im1=reference_image_pixelsize_nm, + px_size_im2=simulated_image_pixelsize_nm, + ) + if reference_image_mask is not None and simulated_image_mask is not None: + reference_mask_interpolated, simulated_image_mask_interpolated = resize_images_interpolation( + img1=reference_image_mask, + img2=simulated_image_mask, + px_size_im1=reference_image_pixelsize_nm, + px_size_im2=simulated_image_pixelsize_nm, + interpolation_order=0 # becauese they are masks + ) + union_mask = np.logical_or( + reference_mask_interpolated, + simulated_image_mask_interpolated) + #masks_interpolated["reference_mask"] = reference_mask_interpolated + #masks_interpolated["query_mask"] = simulated_image_mask_interpolated + #masks_interpolated["union_mask"] = union_mask + return reference_interpolated, simulated_image_interpolated, union_mask + + +def structural_similarity( + reference_image = None, + reference_image_pixelsize_nm = None, + reference_image_mask = None, + simulated_image = None, + simulated_image_pixelsize_nm = None, + simulated_image_mask = None +): + reference_interpolated, simulated_image_interpolated, union_mask = match_image_sizes( + reference_image = reference_image, + reference_image_pixelsize_nm = reference_image_pixelsize_nm, + reference_image_mask = reference_image_mask, + simulated_image = simulated_image, + simulated_image_pixelsize_nm = simulated_image_pixelsize_nm, + simulated_image_mask = simulated_image_mask + ) + similarity = ssim( + reference_interpolated[union_mask], + simulated_image_interpolated[union_mask], + data_range=simulated_image_interpolated[union_mask].max() - simulated_image_interpolated[union_mask].min() + ) + return similarity + + +def pearson_correlation( + reference_image = None, + reference_image_pixelsize_nm = None, + reference_image_mask = None, + simulated_image = None, + simulated_image_pixelsize_nm = None, + simulated_image_mask = None +): + reference_interpolated, simulated_image_interpolated, union_mask = match_image_sizes( + reference_image = reference_image, + reference_image_pixelsize_nm = reference_image_pixelsize_nm, + reference_image_mask = reference_image_mask, + simulated_image = simulated_image, + simulated_image_pixelsize_nm = simulated_image_pixelsize_nm, + simulated_image_mask = simulated_image_mask + ) + pearson_correlation, pval = pearsonr( + reference_interpolated[union_mask].flatten(), + simulated_image_interpolated[union_mask].flatten() + ) + return pearson_correlation + -def img_compare(ref, query, metric=["ssim",], force_match=False, zoom_in=0, ref_mask = None, query_mask = None, **kwargs): +def img_compare(ref, query, metric=["ssim",], force_match=False, zoom_in=0, ref_mask = None, query_mask = None, custom_metrics = None, **kwargs): """ Compare two images using specified similarity metrics. @@ -34,8 +115,14 @@ def img_compare(ref, query, metric=["ssim",], force_match=False, zoom_in=0, ref_ query : numpy.ndarray (Possibly resized) query image. """ + reference_image = ref.copy() + query_image = query.copy() + ref_pixelsize = None + query_pixelsize = None if force_match: if 'ref_pixelsize' in kwargs and 'modality_pixelsize' in kwargs: + ref_pixelsize = copy.copy(kwargs['ref_pixelsize']) + query_pixelsize = copy.copy(kwargs['modality_pixelsize']) ref, query = resize_images_interpolation( img1=ref, img2=query, @@ -66,7 +153,6 @@ def img_compare(ref, query, metric=["ssim",], force_match=False, zoom_in=0, ref_ zoom_in=zoom_in ) similarity_vector = [] - for method in metric: if method == "ssim": similarity = ssim(ref[union_mask], query[union_mask], data_range=query[union_mask].max() - query[union_mask].min()) @@ -74,6 +160,22 @@ def img_compare(ref, query, metric=["ssim",], force_match=False, zoom_in=0, ref_ elif method == "pearson": similarity, pval = pearsonr(ref[union_mask].flatten(), query[union_mask].flatten()) similarity_vector.append(similarity) + elif method in custom_metrics.keys(): + custom_measurement = custom_metrics[method]( + reference_image = reference_image, + reference_image_pixelsize_nm = ref_pixelsize, + simulated_image = query_image, + simulated_image_pixelsize_nm = query_pixelsize, + image_mask = union_mask, + resized_reference_image = ref, + resized_simulated_image = query, + **kwargs + ) + #custom_measurement = metric_calculator.run_metric() + if isinstance(custom_measurement, float): + similarity_vector.append(custom_measurement) + else: + similarity_vector.append(None) return similarity_vector, ref, query, masks_used diff --git a/src/vlab4mic/analysis/sweep.py b/src/vlab4mic/analysis/sweep.py index 1ad0d1b..949c185 100644 --- a/src/vlab4mic/analysis/sweep.py +++ b/src/vlab4mic/analysis/sweep.py @@ -503,51 +503,6 @@ def generate_global_reference_modality( return reference_output[modality]["ch0"], reference_parameters, reference_output_mask -def analyse_image_sweep( - img_outputs, img_params, reference, analysis_case_params=None -): - """ - Analyse a sweep of images against a reference image. - - Parameters - ---------- - img_outputs : dict - Dictionary of simulated image outputs. - img_params : dict - Dictionary of image parameters. - reference : numpy.ndarray - Reference image. - analysis_case_params : dict, optional - Additional parameters for analysis. - - Returns - ------- - measurement_vectors : list - List of measurement results for each image. - inputs : dict - Dictionary of input images and used references. - """ - measurement_vectors = [] - # ref_pixelsize = analysis_case_params["ref_pixelsize"] - inputs = dict() - for params_id in img_params.keys(): - inputs[params_id] = dict() - rep_number = 0 - mod_name = img_params[params_id][5] # 5th item corresponds to Modality - for img_r in img_outputs[params_id]: - im1 = img_r[0] - im_ref = reference[0] - rep_measurement, ref_used, qry_used_, masks_used = metrics.img_compare( - im_ref, im1, **analysis_case_params[mod_name] - ) - measurement_vectors.append( - [params_id, rep_number, rep_measurement] - ) - inputs[params_id][rep_number] = [qry_used, im1] - rep_number += 1 - return measurement_vectors, inputs - - def analyse_sweep_single_reference( img_outputs, img_outputs_masks, @@ -556,9 +511,8 @@ def analyse_sweep_single_reference( reference_image_mask, reference_params, zoom_in=0, - metrics_list: list = [ - "ssim", - ], + metrics: dict = None, + #custom_metrics = None, **kwargs, ): """ @@ -600,26 +554,39 @@ def analyse_sweep_single_reference( modality_pixelsize = img_params[params_id][6]["pixelsize"] for img_r, img_mask in zip(img_outputs[params_id], img_outputs_masks[params_id]): im1 = img_r[0] - im1_mask = img_mask - #print(f"query: {im1.shape},{im1_mask.shape}") - im_ref = reference_image - rep_measurement, ref_used, qry_used, masks_used = metrics.img_compare( - ref = im_ref, - ref_mask=reference_image_mask, - query=im1, - query_mask=im1_mask, - modality_pixelsize=modality_pixelsize, - ref_pixelsize=reference_params["ref_pixelsize"], - force_match=True, - zoom_in=zoom_in, - metric=metrics_list, - ) - r_vector = list([params_id, rep_number]) + list([*rep_measurement]) - measurement_vectors.append(r_vector) - # measurement_vectors = measurement_vectors + rep_measurement[0] - inputs[params_id][rep_number] = [qry_used, im1, ref_used, masks_used] + #im1_mask = img_mask + #im_ref = reference_image + similarity_vector = [] + metrics_names_list = [] + for metric_name, metric in metrics.items(): + #metrics_names_list.append(metric_name) + similarity_vector.append(metric( + reference_image = reference_image, + reference_image_pixelsize_nm = reference_params["ref_pixelsize"], + reference_image_mask = reference_image_mask, + simulated_image = im1, + simulated_image_pixelsize_nm = modality_pixelsize, + simulated_image_mask = img_mask, + ) + ) + #rep_measurement, ref_used, qry_used, masks_used = metrics.img_compare( + # ref = im_ref, + # ref_mask=reference_image_mask, + # query=im1, + # query_mask=im1_mask, + # modality_pixelsize=modality_pixelsize, + # ref_pixelsize=reference_params["ref_pixelsize"], + # force_match=True, + # zoom_in=zoom_in, + # metrics = metrics + #) + # each image as its ID, a replica number and the metrics associated to it + replicaID_repN_metrics = list([params_id, rep_number]) + list([*similarity_vector]) + measurement_vectors.append(replicaID_repN_metrics) + # for methods that require image resizing + #inputs[params_id][rep_number] = [qry_used, im1, ref_used, masks_used] rep_number += 1 - return measurement_vectors, inputs, metrics_list + return measurement_vectors, inputs, metrics_names_list def measurements_dataframe( @@ -683,6 +650,7 @@ def measurements_dataframe( nmetrics = len(metric_names) metrics_dictionary = dict() for metric_number in range(nmetrics): + # first two indices are ID and replica number, then one value per metric metricvector = measurement_array[:, 2 + metric_number] metrics_dictionary[metric_names[metric_number]] = np.array( metricvector, dtype=np.float64 diff --git a/src/vlab4mic/sweep_generator.py b/src/vlab4mic/sweep_generator.py index 184e541..ec15c84 100644 --- a/src/vlab4mic/sweep_generator.py +++ b/src/vlab4mic/sweep_generator.py @@ -13,6 +13,7 @@ import copy import tifffile as tiff from pandas.api.types import is_numeric_dtype +from .analysis.metrics import structural_similarity, pearson_correlation #output_dir = Path.home() / "vlab4mic_outputs" @@ -76,11 +77,13 @@ def __init__(self): self.parameter_settings = load_yaml(param_settings_file) self.analysis_parameters = {} self.analysis_parameters["zoom_in"] = 0 - self.analysis_parameters["metrics_list"] = [ - "ssim", - "pearson" - ] self.plot_parameters = {} + self.plot_parameters["ssim"] = {} + self.plot_parameters["ssim"]["heatmaps"] = {} + self.plot_parameters["ssim"]["lineplots"] = {} + self.plot_parameters["pearson"] = {} + self.plot_parameters["pearson"]["heatmaps"] = {} + self.plot_parameters["pearson"]["lineplots"] = {} self.plot_parameters["heatmaps"] = {} self.plot_parameters["heatmaps"]["category"] = "modality_name" self.plot_parameters["heatmaps"]["param1"] = None @@ -103,6 +106,11 @@ def __init__(self): self.param_settings = self.parameter_settings self.use_experiment_structure = False self.reference_parameters_unsorted = dict() + self.default_metrics = { + "ssim": structural_similarity, + "pearson": pearson_correlation + } + self.metrics = {} print("vLab4mic sweep generator initialised") def set_number_of_repetitions(self, repeats: int = 3): @@ -367,7 +375,10 @@ def load_reference_image( image_mask = tiff.imread(ref_image_mask_path) ref_image_mask = image_mask > 0 else: - image_mask = np.ones(shape=ref_image[0].shape) + if len(ref_image.shape) == 3: + image_mask = np.ones(shape=ref_image[0].shape) + else: + image_mask = np.ones(shape=ref_image.shape) ref_image_mask = image_mask > 0 if override: self.reference_image = ref_image @@ -459,7 +470,10 @@ def preview_reference_image(self, return_image=False, cmap="Grays_r"): if return_image: return self.reference_image else: - plt.imshow(self.reference_image[0], cmap=cmap) + if len(self.reference_image.shape) > 2: + plt.imshow(self.reference_image[0], cmap=cmap) + else: + plt.imshow(self.reference_image, cmap=cmap) print(self.reference_image_parameters) # set and change parameters @@ -734,8 +748,8 @@ def set_analysis_parameters( ------- None """ - if metrics_list is not None and type(metrics_list) == list: - self.analysis_parameters["metrics_list"] = metrics_list + #if metrics_list is not None and type(metrics_list) == list: + # self.analysis_parameters["metrics_list"] = metrics_list if zoom_in is not None: self.analysis_parameters["zoom_in"] = zoom_in @@ -771,7 +785,12 @@ def set_na_as_zero_in_plots(self, na_as_zero: bool = True): ------- None """ - self.plot_parameters["general"]["na_as_zero"] = na_as_zero + self.plot_parameters["general"]["na_as_zero"] = na_as_zero + + def use_default_metrics(self, metrics = ["ssim", "pearson"]): + if metrics is not None: + for metric_name in metrics: + self.metrics[metric_name] = self.default_metrics[metric_name] def run_analysis( self, @@ -816,6 +835,8 @@ def run_analysis( else: reference_image = self.reference_image reference_image_mask = self.reference_image_mask + if len(self.metrics.keys()) == 0: + self.use_default_metrics() print("Running analysis...") measurement_vectors, inputs, metric = ( sweep.analyse_sweep_single_reference( @@ -825,6 +846,7 @@ def run_analysis( reference_image=reference_image, reference_image_mask=reference_image_mask, reference_params=self.reference_image_parameters, + metrics=self.metrics, **self.analysis_parameters, ) ) @@ -836,14 +858,15 @@ def run_analysis( print("Analysis dataframe generated.") if plots: print("Generating analysis plots...") - for metric_name in self.analysis_parameters["metrics_list"]: - for plot_type in self.plot_parameters.keys(): + for metric_name in self.metrics.keys(): + for plot_type in ["heatmaps", "lineplots"]: self.generate_analysis_plots( plot_type=plot_type, return_figure=True, metric_name=metric_name, filter_dictionary=None, na_as_zero=self.plot_parameters["general"]["na_as_zero"], + **self.plot_parameters[metric_name][plot_type] ) print("Analysis plots generated.") if save: @@ -873,7 +896,7 @@ def gen_analysis_dataframe(self): mod_acq=self.acquisition_parameters, mod_names=self.modalities, mod_params=self.modality_parameters, - metric_names=self.analysis_parameters["metrics_list"], + metric_names=list(self.metrics.keys()), ) ) @@ -1006,7 +1029,7 @@ def _gen_heatmaps( The generated heatmap figure. """ if metric_name is None: - metric_name = self.analysis_parameters["metrics_list"][0] + metric_name = list(self.metrics.keys())[0] if category is None: category = "modality_name" if param1 is None: @@ -1111,7 +1134,7 @@ def _gen_lineplots( else: x_param = "labelling_efficiency" if metric_name is None: - metric_name = self.analysis_parameters["metrics_list"][0] + metric_name = list(self.metrics.keys())[0] if style is None and len(self.parameters_with_set_values) > 1: style = self.parameters_with_set_values[1] fig, axes = plt.subplots(figsize=figsize) @@ -1314,6 +1337,47 @@ def save_images(self, output_name=None, output_directory=None, floats_as=float): #name_ref = output_directory + "reference.tiff" tiff.imwrite(dir_name_ref, self.reference_image) + def add_custom_analysis_metrics(self, custom_metrics: list[callable] = None, heatmap_params={"cmaps_range": "each"}, lineplots_params=None, **kwargs): + """ + Add a custom analysis metric function to the sweep generator. + + Parameters + ---------- + :param metric_function: callable + A function that takes two numpy arrays (reference image and query image) + and returns a float representing the calculated metric. + :param metric_name: str + The name of the custom metric to be added. + + Returns + ------- + None + + Notes + ----- + The custom metric function should have the following signature: + def custom_metric(ref: np.ndarray, query: np.ndarray, union_mask, **kwargs) -> float: + # here union_mask is a binary mask that can be used to filter pixel to use for calculations + # Calculate and return the metric value + """ + current_metrics = list(self.metrics.keys()) + for m in range(len(custom_metrics)): + metric_name = custom_metrics[m].__name__ + if metric_name not in current_metrics: + #self.analysis_parameters["metrics_list"].append(metric_name) + self.plot_parameters[metric_name] = dict() + # change plot parameters to different method + if heatmap_params is not None: + self.plot_parameters[metric_name]["heatmaps"] = heatmap_params + else: + self.plot_parameters[metric_name]["heatmaps"] = {} + if lineplots_params is not None: + self.plot_parameters[metric_name]["lineplots"] = lineplots_params + else: + self.plot_parameters[metric_name]["lineplots"] = {} + self.metrics[metric_name] = custom_metrics[m] + + def run_parameter_sweep( structures: list[str] = None, @@ -1363,6 +1427,10 @@ def run_parameter_sweep( exp_time = None, # for plot generation na_as_zero = True, + custom_metrics: list[callable] = None, + default_metrics = ["ssim", "pearson"], + #custom_metric_name: str = None, + plot_parameters=None # Add more as needed for your sweep ): """ @@ -1509,6 +1577,17 @@ def run_parameter_sweep( reference_probe=reference_probe, **reference_parameters) sweep_gen.set_na_as_zero_in_plots(na_as_zero=na_as_zero) + sweep_gen.use_default_metrics(metrics=default_metrics) + if custom_metrics is not None: + sweep_gen.add_custom_analysis_metrics( + custom_metrics=custom_metrics + ) + if plot_parameters is not None: + for plot_type, parameters in plot_parameters.items(): + sweep_gen.set_plot_parameters( + plot_type=plot_type, + **parameters) + if run_analysis: sweep_gen.run_analysis( save=save_analysis_results, diff --git a/tests/test_sweeps.py b/tests/test_sweeps.py index 97132d2..c416821 100644 --- a/tests/test_sweeps.py +++ b/tests/test_sweeps.py @@ -19,7 +19,7 @@ def test_run_parameter_sweep(): output_name="vlab_script", return_generator=True, save_sweep_images=False, - save_analysis_results=True, + save_analysis_results=False, run_analysis=True ) @@ -33,4 +33,29 @@ def test_run_parameter_sweep(): sweep_gen_test.params_by_group[group_name][param_name] ) vsamples_unique_ids = len(sweep_gen_test.virtual_samples_parameters.keys()) - assert total_combinations == vsamples_unique_ids \ No newline at end of file + assert total_combinations == vsamples_unique_ids + +def test_custom_metric(): + + def mean_value(reference_image = None, + reference_image_pixelsize_nm = None, + simulated_image = None, + simulated_image_pixelsize_nm = None, + image_mask = None, + resized_reference_image = None, + resized_simulated_image = None, + *args,**kwargs): + return np.mean(simulated_image) + + + sweep_gen = sweep_generator.run_parameter_sweep( + sweep_repetitions=3, + # parameters for sweep + labelling_efficiency=(0, 1, 0.5), # values between 0 and 1 with step of 0.5 + return_generator=True, + analysis_plots=True, + save_sweep_images=False, # By default, the saving directory is set to the home path of the user + save_analysis_results=False, + run_analysis=True, + custom_metrics=[mean_value,], + ) \ No newline at end of file