From fe62ee281f5860fe106d9dd9d9818553270119c4 Mon Sep 17 00:00:00 2001 From: youssefmecky96 Date: Fri, 8 Mar 2024 14:38:13 +0100 Subject: [PATCH] formatting --- icu_benchmarks/imputation/diffwave.py | 2 +- icu_benchmarks/models/custom_metrics.py | 30 +- icu_benchmarks/models/utils.py | 379 ++++++++++++++++- icu_benchmarks/models/wrappers.py | 540 ++---------------------- scripts/plotting/plotting.py | 192 +++++++++ 5 files changed, 615 insertions(+), 528 deletions(-) diff --git a/icu_benchmarks/imputation/diffwave.py b/icu_benchmarks/imputation/diffwave.py index 458c2c5e..437303ed 100644 --- a/icu_benchmarks/imputation/diffwave.py +++ b/icu_benchmarks/imputation/diffwave.py @@ -330,7 +330,7 @@ def forward(self, input_data): cond = self.cond_conv(cond) h += cond - out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) + out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :]) res = self.res_conv(out) assert x.shape == res.shape diff --git a/icu_benchmarks/models/custom_metrics.py b/icu_benchmarks/models/custom_metrics.py index fd88699f..0cff08c0 100644 --- a/icu_benchmarks/models/custom_metrics.py +++ b/icu_benchmarks/models/custom_metrics.py @@ -422,30 +422,30 @@ def norm_function(arr): x_original = dataloader.dataset.data["reals"].clone() dataloader.dataset.add_noise() - x_preturb = dataloader.dataset.data["reals"].clone() - y_pred_preturb = model.model.predict(dataloader) + x_perturb = dataloader.dataset.data["reals"].clone() + y_pred_perturb = model.model.predict(dataloader) Attention_weights = model.interpertations(dataloader) - att_preturb = Attention_weights["attention"] + att_perturb = Attention_weights["attention"] # Calculate the absolute difference - difference = torch.abs(y_pred_preturb - y_pred) + difference = torch.abs(y_pred_perturb - y_pred) # Find where the difference is less than or equal to a thershold close_indices = torch.nonzero(difference <= thershold).squeeze()[:, 0].to(device) RIS = relative_stability_objective( x_original.detach(), - x_preturb.detach(), + x_perturb.detach(), attribution, - att_preturb, + att_perturb, close_indices=close_indices, input=True, attention=True, ) ROS = relative_stability_objective( y_pred, - y_pred_preturb, + y_pred_perturb, attribution, - att_preturb, + att_perturb, close_indices=close_indices, input=False, attention=True, @@ -458,16 +458,16 @@ def norm_function(arr): with torch.no_grad(): noise = torch.randn_like(x["encoder_cont"]) * 0.01 x["encoder_cont"] += noise - y_pred_preturb = model(model.prep_data(x)).detach() + y_pred_perturb = model(model.prep_data(x)).detach() if explain_method == "Random": - att_preturb = np.random.normal(size=[64, 24, 53]) + att_perturb = np.random.normal(size=[64, 24, 53]) else: - att_preturb, features_attrs, timestep_attrs = model.explantation2(x, explain_method) + att_perturb, features_attrs, timestep_attrs = model.explantation2(x, explain_method) # # Calculate the absolute difference - difference = torch.abs(y_pred_preturb - y_pred) + difference = torch.abs(y_pred_perturb - y_pred) # Find where the difference is less than or equal to a thershold close_indices = torch.nonzero(difference <= thershold).squeeze()[:, 0].to(device) @@ -476,15 +476,15 @@ def norm_function(arr): x_original.detach(), x["encoder_cont"].detach(), attribution, - att_preturb, + att_perturb, close_indices=close_indices, input=True, ) ROS = relative_stability_objective( y_pred, - y_pred_preturb, + y_pred_perturb, attribution, - att_preturb, + att_perturb, 
close_indices=close_indices, input=False, ) diff --git a/icu_benchmarks/models/utils.py b/icu_benchmarks/models/utils.py index 6c944ae7..34a5de0b 100644 --- a/icu_benchmarks/models/utils.py +++ b/icu_benchmarks/models/utils.py @@ -8,13 +8,14 @@ import logging import numpy as np import torch - +from quantus.functions.similarity_func import correlation_spearman, cosine from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities import rank_zero_only from torch.nn import Module from torch.optim import Optimizer, Adam, SGD, RAdam from typing import Optional, Union from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR, MultiStepLR, ExponentialLR +import captum def save_config_file(log_dir): @@ -188,3 +189,379 @@ def version(self): @rank_zero_only def log_hyperparams(self, params): pass + + +def Faithfulness_Correlation( + model, + x, + attribution, + similarity_func=None, + nr_runs=100, + pertrub=None, + subset_size=3, + feature=False, + time_step=False, + feature_timestep=False, +): + """ + Calculates faithfulness scores for captum attributions. + + Args: + - x: Batch input + - attribution: attribution generated by captum + - similarity_func: function to determine similarity between the sum of attributions and the difference in prediction + - nr_runs: how many times to repeat the experiment + - pertrub: which perturbation to apply to the input ("Noise" or "baseline") + - subset_size: the size of the subset of features to alter + - feature: whether to calculate faithfulness of feature attributions + - time_step: whether to calculate faithfulness of timestep attributions + - feature_timestep: whether to calculate faithfulness of feature-per-timestep attributions + Returns: + score: similarity score between the sum of attributions and the difference in prediction, averaged over nr_runs + + Implementation of faithfulness correlation by Bhatt et al., 2020. + + The Faithfulness Correlation metric intends to capture an explanation's relative faithfulness + (or 'fidelity') with respect to the model behaviour. + + Faithfulness correlation scores show to what extent the predicted logits of each modified test point and + the average explanation attribution for only the subset of features are (linearly) correlated, taking the + average over multiple runs and test samples. The metric returns one float per input-attribution pair that + ranges between -1 and 1, where higher scores are better. + + For each test sample, |S| features are randomly selected and replaced with baseline values (zero baseline + or average of set). Thereafter, the correlation coefficient (Spearman by default in this implementation) between the predicted logits of each modified + test point and the average explanation attribution for only the subset of features is calculated. Results are + averaged over multiple runs and several test samples. + This code is adapted from the quantus library to suit our use case. + + References: + 1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model + explanations." IJCAI (2020): 3016-3022. + 2) Hedström, Anna, et al. "Quantus: An explainable AI toolkit for + responsible evaluation of neural network explanations and beyond." + Journal of Machine Learning Research 24.34 (2023): 1-11.
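Schematically, with S the feature subset drawn on a given run, f the model, a the attribution tensor and \bar{x} the baseline (symbols introduced here only for this sketch), the loop below computes

    FC = \rho\big( \{\, f(x) - f(x_{[x_S \leftarrow \bar{x}]}) \,\}_{r=1}^{nr\_runs},\; \{\, \textstyle\sum_{j \in S} a_j \,\}_{r=1}^{nr\_runs} \big)

where \rho is similarity_func, i.e. Spearman correlation unless another function is passed in.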
+ """ + + attribution = torch.tensor(attribution).to(model.device) + + # Other initializations + if similarity_func is None: + similarity_func = correlation_spearman + if pertrub is None: + pertrub = "baseline" + similarities = [] + + # Prepare the data and compute the unperturbed prediction + + y_pred = model(model.prep_data(x)).detach() # Keep on GPU + pred_deltas = [] + att_sums = [] + + for i_ix in range(nr_runs): + if time_step: + timesteps_idx = np.random.choice(24, subset_size, replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, timesteps_idx] + + elif feature: + feature_idx = np.random.choice(53, subset_size, replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, feature_idx] + elif feature_timestep: + timesteps_idx = np.random.choice(24, subset_size[0], replace=False) + feature_idx = np.random.choice(53, subset_size[1], replace=False) + patient_idx = np.random.choice(64, 1, replace=False) + a_ix = [patient_idx, timesteps_idx, feature_idx] + + # Apply perturbation + if pertrub == "Noise": + x = model.add_noise(x, a_ix, time_step, feature, feature_timestep) + elif pertrub == "baseline": + x = model.apply_baseline(x, a_ix, time_step, feature, feature_timestep) + + # Predict on perturbed input and calculate deltas + y_pred_perturb = (model(model.prep_data(x))).detach() # Keep on GPU + + if time_step: + if attribution.size() == torch.Size([24]): + att_sums.append((attribution[timesteps_idx]).sum()) + else: + att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :]).sum()) + elif feature: + if len(attribution) == 53: + att_sums.append((attribution[feature_idx]).sum()) + else: + att_sums.append((attribution[patient_idx, :, :][:, :, feature_idx]).sum()) + elif feature_timestep: + att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :][:, :, feature_idx]).sum()) + + pred_deltas.append((y_pred - y_pred_perturb)[patient_idx].item()) + # Convert to CPU for numpy operations + + pred_deltas_cpu = torch.tensor(pred_deltas).cpu().numpy() + att_sums_cpu = torch.tensor(att_sums).cpu().numpy() + + similarities.append(similarity_func(pred_deltas_cpu, att_sums_cpu)) + + score = np.nanmean(similarities) + return score + + +def Data_Randomization( + model, + x, + attribution, + explain_method, + random_model, + similarity_func=cosine, + dataloader=None, + method_name="", + **kwargs, +): + """ + + Args: + - x: Batch input + - attribution: attribution + - explain_method: function to generate explanations + - random_model: reference to a model trained on random labels + - similarity_func: function to measure similarity + - dataloader: when using Attention as the explain method, pass the dataloader instead of the batch + - method_name: name of the explanation method + + Returns: + score: similarity score between the attributions of the model trained on random labels and the model trained on real data + + Adapted from the Random Logit Metric by Sixt et al., 2020. + + The Random Logit Metric computes the distance between the original explanation and a reference explanation; + here the reference explanation is produced by a model trained on random labels rather than by a randomly + chosen non-target class. + This code is adapted from the quantus library to suit our use case. + + References: + 1) Leon Sixt et al.: "When Explanations Lie: Why Many Modified BP + Attributions Fail." ICML (2020): 9046-9057. + 2) Hedström, Anna, et al. "Quantus: An explainable AI + toolkit for responsible evaluation of neural network explanations and beyond." + Journal of Machine Learning Research 24.34 (2023): 1-11.
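Schematically, writing e for the attribution of the trained model and e_rand for the attribution obtained from random_model, each min-max normalised (symbols introduced here only for this sketch), the score returned below with the default similarity_func is the cosine similarity

    \hat{e} = \frac{e - \min(e)}{\max(e) - \min(e)}, \qquad score = \frac{\langle \hat{e}_{rand}, \hat{e} \rangle}{\lVert \hat{e}_{rand} \rVert \, \lVert \hat{e} \rVert}

A low similarity suggests the explanation depends on what the model has learned rather than on the architecture or the input alone.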
+ + """ + + if explain_method == "Attention": + Attention_weights = random_model.interpertations(dataloader) + attribution = attribution.cpu().numpy() + min_val = np.min(attribution) + max_val = np.max(attribution) + + attribution = (attribution - min_val) / (max_val - min_val) + random_attr = Attention_weights["attention"].cpu().numpy() + min_val = np.min(random_attr) + max_val = np.max(random_attr) + random_attr = (random_attr - min_val) / (max_val - min_val) + score = similarity_func(random_attr, attribution) + elif explain_method == "Random": + score = similarity_func(np.random.normal(size=[64, 24, 53]).flatten(), attribution.flatten()) + else: + data, baselines = model.prep_data_captum(x) + + explantation = explain_method(random_model.forward_captum) + # Reformat attributions. + if explain_method is not captum.attr.Saliency: + attr = explantation.attribute(data, baselines=baselines, **kwargs) + else: + attr = explantation.attribute(data, **kwargs) + + # Process and store the calculated attributions + random_attr = ( + attr[1].cpu().detach().numpy() + if method_name in ["Lime", "FeatureAblation"] + else torch.stack(attr).cpu().detach().numpy() + ) + + attribution = attribution.flatten() + min_val = np.min(attribution) + max_val = np.max(attribution) + attribution = (attribution - min_val) / (max_val - min_val) + random_attr = random_attr.flatten() + min_val = np.min(random_attr) + max_val = np.max(random_attr) + random_attr = (random_attr - min_val) / (max_val - min_val) + + score = similarity_func(random_attr, attribution) + return score + + +def Relative_Stability( + model, + x, + attribution, + explain_method, + method_name, + dataloader=None, + threshold=0.5, + **kwargs, +): + """ + Args: + - x: Batch input + - attribution: attribution + - explain_method: function to generate explanations + - method_name: name of the explanation method + - dataloader: when using Attention as the explain method, pass the dataloader instead of the batch + + + Returns: + RIS: Relative Input Stability, the relative distance between the change in explanation and the change in input + ROS: Relative Output Stability, the relative distance between the change in explanation and the change in output + + + References: + 1) https://arxiv.org/pdf/2203.06877.pdf + 2) Hedström, Anna, et al. "Quantus: An explainable AI toolkit for responsible evaluation + of neural network explanations and beyond." Journal of Machine Learning Research 24.34 (2023): 1-11. + + """ + + def relative_stability_objective(x, xs, e_x, e_xs, eps_min=0.0001, input=False, device="cuda") -> torch.Tensor: + """ + Computes the relative input and output stability maximization objective + as defined by the authors in https://arxiv.org/pdf/2203.06877.pdf. + + Args: + + x: Input tensor + xs: perturbed tensor. + e_x: Explanations for x. + e_xs: Explanations for xs. + eps_min: value to avoid division by zero if needed + input: Boolean to indicate if this is an input or an output + device: the device to keep the tensors on + + Returns: + + ris_obj: Tensor + RIS maximization objective.
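Schematically, the objective evaluated below follows the referenced paper, with element-wise division guarded by eps_min wherever a denominator would be zero (symbols introduced here only for this sketch):

    RIS(x, x', e_x, e_{x'}) = \frac{\lVert (e_x - e_{x'}) / e_x \rVert}{\max\big( \lVert (x - x') / x \rVert,\ \epsilon_{min} \big)}, \qquad ROS(x, x', e_x, e_{x'}) = \frac{\lVert (e_x - e_{x'}) / e_x \rVert}{\max\big( \lVert f(x) - f(x') \rVert,\ \epsilon_{min} \big)}

Relative_Stability then reports the maximum of this ratio over the perturbed samples whose change in prediction stays within the threshold.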
+ """ + + # Function to convert inputs to tensors if they are numpy arrays + def to_tensor(input_array): + if isinstance(input_array, np.ndarray): + return torch.tensor(input_array).to(device) + return input_array.to(device) + + # Convert all inputs to tensors and move to GPU + x, xs, e_x, e_xs = map(to_tensor, [x, xs, e_x, e_xs]) + + if input: + num_dim = x.ndim + else: + num_dim = e_x.ndim + + if num_dim == 3: + + def norm_function(arr): + return torch.norm(arr, dim=(-1, -2)) + + elif num_dim == 2: + + def norm_function(arr): + return torch.norm(arr, dim=-1) + + else: + + def norm_function(arr): + return torch.norm(arr) + + nominator = (e_x - e_xs) / (e_x + (e_x == 0) * eps_min) + nominator = norm_function(nominator) + + if input: + denominator = x - xs + denominator /= x + (x == 0) * eps_min + denominator = norm_function(denominator) + denominator += (denominator == 0) * eps_min + else: + denominator = torch.squeeze(x) - torch.squeeze(xs) + denominator = torch.norm(denominator, dim=-1) + denominator += (denominator == 0) * eps_min + + return nominator / denominator + + attribution = torch.tensor(attribution).to(model.device) + if explain_method == "Attention": + y_pred = model.model.predict(dataloader) + x_original = dataloader.dataset.data["reals"].clone() + + dataloader.dataset.add_noise() + x_perturb = dataloader.dataset.data["reals"].clone() + y_pred_perturb = model.model.predict(dataloader) + Attention_weights = model.interpertations(dataloader) + att_perturb = Attention_weights["attention"] + # Calculate the absolute difference + difference = torch.abs(y_pred_perturb - y_pred) + + # Find where the difference is less than or equal to a threshold + close_indices = torch.nonzero(difference <= threshold).squeeze() + RIS = relative_stability_objective( + x_original[close_indices, :, :].detach(), + x_perturb[close_indices, :, :].detach(), + attribution, + att_perturb, + input=True, + ) + + ROS = relative_stability_objective( + y_pred[close_indices], + y_pred_perturb[close_indices], + attribution, + att_perturb, + input=False, + ) + + else: + y_pred = model(model.prep_data(x)).detach() + x_original = x["encoder_cont"].detach().clone() + + with torch.no_grad(): + noise = torch.randn_like(x["encoder_cont"]) * 0.01 + x["encoder_cont"] += noise + y_pred_perturb = model(model.prep_data(x)).detach() + if explain_method == "Random": + att_perturb = np.random.normal(size=[64, 24, 53]) + att_perturb = torch.tensor(att_perturb).to(model.device) + else: + data, baselines = model.prep_data_captum(x) + + explantation = explain_method(model.forward_captum) + # Reformat attributions. 
+ if explain_method is not captum.attr.Saliency: + att_perturb = explantation.attribute(data, baselines=baselines, **kwargs) + else: + att_perturb = explantation.attribute(data, **kwargs) + + # Process and store the calculated attributions + att_perturb = ( + att_perturb[1].detach() if method_name in ["Lime", "FeatureAblation"] else torch.stack(att_perturb).detach() + ) + # Calculate the absolute difference + difference = torch.abs(y_pred_perturb - y_pred) + + # Find where the difference is less than or equal to a threshold + close_indices = torch.nonzero(difference <= threshold).squeeze() + RIS = relative_stability_objective( + x_original[close_indices, :, :].detach(), + x["encoder_cont"][close_indices, :, :].detach(), + attribution[close_indices, :, :], + att_perturb[close_indices, :, :], + input=True, + ) + ROS = relative_stability_objective( + y_pred[close_indices], + y_pred_perturb[close_indices], + attribution[close_indices, :, :], + att_perturb[close_indices, :, :], + input=False, + ) + + return np.max(RIS.cpu().numpy()).astype(np.float64), np.max(ROS.cpu().numpy()).astype(np.float64) diff --git a/icu_benchmarks/models/wrappers.py b/icu_benchmarks/models/wrappers.py index 79f13ef4..ae70406a 100644 --- a/icu_benchmarks/models/wrappers.py +++ b/icu_benchmarks/models/wrappers.py @@ -18,8 +18,7 @@ from pytorch_lightning import LightningModule from icu_benchmarks.models.constants import MLMetrics, DLMetrics from icu_benchmarks.contants import RunMode -import matplotlib.pyplot as plt -from quantus.functions.similarity_func import correlation_spearman, cosine +from icu_benchmarks.models.utils import Faithfulness_Correlation, Data_Randomization, Relative_Stability import captum from captum._utils.models.linear_model import SkLearnLasso @@ -522,118 +521,6 @@ def prep_data_captum(self, x): ) return data, baselines - def plot_attributions(self, features_attrs, timestep_attrs, method_name, log_dir): - """ - Plots the attribution values for features and timesteps. - - Args: - - features_attrs: Array of feature attribution values. - - timestep_attrs: Array of timestep attribution values. - - method_name: Name of the attribution method. - - log_dir: Directory to save the plots. 
- Returns: - Nothing - """ - - # Plot for feature attributions - x_values = np.arange(1, len(features_attrs) + 1) - plt.figure(figsize=(8, 6)) - plt.plot( - x_values, - features_attrs, - marker="o", - color="skyblue", - linestyle="-", - linewidth=2, - markersize=8, - ) - plt.xlabel("Feature") - plt.ylabel("{} Attribution".format(method_name)) - plt.title("{} Attribution Values".format(method_name)) - plt.xticks( - x_values, - [ - "height", - "weight", - "age", - "sex", - "time_idx", - "alb", - "alp", - "alt", - "ast", - "be", - "bicar", - "bili", - "bili_dir", - "bnd", - "bun", - "ca", - "cai", - "ck", - "ckmb", - "cl", - "crea", - "crp", - "dbp", - "fgn", - "fio2", - "glu", - "hgb", - "hr", - "inr_pt", - "k", - "lact", - "lymph", - "map", - "mch", - "mchc", - "mcv", - "methb", - "mg", - "na", - "neut", - "o2sat", - "pco2", - "ph", - "phos", - "plt", - "po2", - "ptt", - "resp", - "sbp", - "temp", - "tnt", - "urine", - "wbc", - ], - rotation=90, - ) - plt.tight_layout() - plt.savefig( - log_dir / "{}_attribution_features_plot.png".format(method_name), - bbox_inches="tight", - ) - - # Plot for timestep attributions - x_values = np.arange(1, len(timestep_attrs) + 1) - plt.figure(figsize=(8, 6)) - plt.plot( - x_values, - timestep_attrs, - marker="o", - color="skyblue", - linestyle="-", - linewidth=2, - markersize=8, - ) - plt.xlabel("Time Step") - plt.ylabel("{} Attribution".format(method_name)) - plt.title("{} Attribution Values".format(method_name)) - plt.xticks(x_values) - plt.tight_layout() - plt.savefig(log_dir / "{}_attribution_plot.png".format(method_name), bbox_inches="tight") - def explantation( self, dataloader, @@ -677,7 +564,8 @@ def explantation( timestep_attrs = Interpertations["attention"] features_attrs = Interpertations["static_variables"].tolist() features_attrs.extend(Interpertations["encoder_variables"].tolist()) - r_score = self.Data_Randomization( + r_score = Data_Randomization( + self, x=None, attribution=timestep_attrs, explain_method=method, @@ -685,7 +573,8 @@ def explantation( dataloader=dataloader, method_name=method_name, ) - st_i_score, st_o_score = self.Relative_Stability( + st_i_score, st_o_score = Relative_Stability( + self, x=None, attribution=timestep_attrs, explain_method=method, @@ -706,7 +595,8 @@ def explantation( if method_name == "Random": f_ts_v_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, all_attrs, pertrub="baseline", @@ -716,7 +606,8 @@ def explantation( ) ) f_ts_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, all_attrs, pertrub="baseline", @@ -726,7 +617,8 @@ def explantation( ) ) f_v_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, all_attrs, pertrub="baseline", @@ -737,7 +629,8 @@ def explantation( ) r_score.append( - self.Data_Randomization( + Data_Randomization( + self, x, attribution=all_attrs, explain_method=method, @@ -745,7 +638,8 @@ def explantation( method_name=method_name, ) ) - res1, res2 = self.Relative_Stability( + res1, res2 = Relative_Stability( + self, x, all_attrs, explain_method=method, @@ -757,7 +651,8 @@ def explantation( st_o_score.append(res2) else: f_ts_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, timestep_attrs, pertrub="baseline", @@ -767,7 +662,8 @@ def explantation( ) ) f_v_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, features_attrs, pertrub="baseline", @@ -829,7 +725,8 @@ def explantation( ) if XAI_metric: 
f_ts_v_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, stacked_attr, pertrub="baseline", @@ -840,7 +737,8 @@ def explantation( ) f_ts_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, stacked_attr, pertrub="baseline", @@ -850,7 +748,8 @@ def explantation( ) ) f_v_score.append( - self.Faithfulness_Correlation( + Faithfulness_Correlation( + self, x, stacked_attr, pertrub="baseline", @@ -860,7 +759,8 @@ def explantation( ) ) r_score.append( - self.Data_Randomization( + Data_Randomization( + self, x, attribution=stacked_attr, explain_method=method, @@ -869,7 +769,8 @@ def explantation( ) ) - res1, res2 = self.Relative_Stability( + res1, res2 = Relative_Stability( + self, x, stacked_attr, explain_method=method, @@ -901,14 +802,6 @@ def explantation( st_i_score = np.max(st_i_score) st_o_score = np.max(st_o_score) - """ if plot: - log_dir_plots = log_dir / "plots" - if not (log_dir_plots.exists()): - log_dir_plots.mkdir(parents=True) - # Plot attributions for features and timesteps - - self.plot_attributions(features_attrs, timestep_attrs, method_name, log_dir_plots) """ - # Return computed attributions and metrics return ( all_attrs, @@ -994,381 +887,6 @@ def apply_baseline(self, x, indices, time_step, feature, feature_timestep): x["encoder_cont"] *= mask return x - def Faithfulness_Correlation( - self, - x, - attribution, - similarity_func=None, - nr_runs=100, - pertrub=None, - subset_size=3, - feature=False, - time_step=False, - feature_timestep=False, - ): - """ - Calculates faithfulness scores for captum attributions - - Args: - - x:Batch input - -attribution: attribution generated by captum, - - similarity_func:function to determine similarity between sum of attributions and difference in prediction - - nr_runs: How many times to repeat the experiment, - - pertrub: What change to do to the input, - - subset_size: The size of the subset of featrues to alter , - - feature: Determines if to calcualte faithfulness of feature attributions, - - time_step: Determines if to calcualte faithfulness of timesteps attributions, - - feature_timestep: Determines if to calcualte faithfulness of featrues per timesteps attributions, - Returns: - score: similarity score between sum of attributions and difference in prediction averaged over nr_runs - - Implementation of faithfulness correlation by Bhatt et al., 2020. - - The Faithfulness Correlation metric intend to capture an explanation's relative faithfulness - (or 'fidelity') with respect to the model behaviour. - - Faithfulness correlation scores shows to what extent the predicted logits of each modified test point and - the average explanation attribution for only the subset of features are (linearly) correlated, taking the - average over multiple runs and test samples. The metric returns one float per input-attribution pair that - ranges between -1 and 1, where higher scores are better. - - For each test sample, |S| features are randomly selected and replace them with baseline values (zero baseline - or average of set). Thereafter, Pearson’s correlation coefficient between the predicted logits of each modified - test point and the average explanation attribution for only the subset of features is calculated. Results is - average over multiple runs and several test samples. - This code is adapted from the quantus libray to suit our use case - - References: - 1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model - explanations." IJCAI (2020): 3016-3022. 
- 2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for - responsible evaluation of neural network explanations and beyond." - Journal of Machine Learning Research 24.34 (2023): 1-11. - """ - - attribution = torch.tensor(attribution).to(self.device) - - # Other initializations - if similarity_func is None: - similarity_func = correlation_spearman - if pertrub is None: - pertrub = "baseline" - similarities = [] - - # Assuming this is a method to prepare your data - - y_pred = self(self.prep_data(x)).detach() # Keep on GPU - pred_deltas = [] - att_sums = [] - - for i_ix in range(nr_runs): - if time_step: - timesteps_idx = np.random.choice(24, subset_size, replace=False) - patient_idx = np.random.choice(64, 1, replace=False) - a_ix = [patient_idx, timesteps_idx] - - elif feature: - feature_idx = np.random.choice(53, subset_size, replace=False) - patient_idx = np.random.choice(64, 1, replace=False) - a_ix = [patient_idx, feature_idx] - elif feature_timestep: - timesteps_idx = np.random.choice(24, subset_size[0], replace=False) - feature_idx = np.random.choice(53, subset_size[1], replace=False) - patient_idx = np.random.choice(64, 1, replace=False) - a_ix = [patient_idx, timesteps_idx, feature_idx] - - # Apply perturbation - if pertrub == "Noise": - x = self.add_noise(x, a_ix, time_step, feature, feature_timestep) - elif pertrub == "baseline": - x = self.apply_baseline(x, a_ix, time_step, feature, feature_timestep) - - # Predict on perturbed input and calculate deltas - y_pred_perturb = (self(self.prep_data(x))).detach() # Keep on GPU - - if time_step: - if attribution.size() == torch.Size([24]): - att_sums.append((attribution[timesteps_idx]).sum()) - else: - att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :]).sum()) - elif feature: - if len(attribution) == 53: - att_sums.append((attribution[feature_idx]).sum()) - else: - att_sums.append((attribution[patient_idx, :, :][:, :, feature_idx]).sum()) - elif feature_timestep: - att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :][:, :, feature_idx]).sum()) - - pred_deltas.append((y_pred - y_pred_perturb)[patient_idx].item()) - # Convert to CPU for numpy operations - - pred_deltas_cpu = torch.tensor(pred_deltas).cpu().numpy() - att_sums_cpu = torch.tensor(att_sums).cpu().numpy() - - similarities.append(similarity_func(pred_deltas_cpu, att_sums_cpu)) - - score = np.nanmean(similarities) - return score - - def Data_Randomization( - self, - x, - attribution, - explain_method, - random_model, - similarity_func=cosine, - dataloader=None, - method_name="", - **kwargs, - ): - """ - - Args: - - x:Batch input - -attribution: attribution - - explain_method:function to generate explantations - - random_model: Reference to model trained on random labels - - similarity_func: Function to measure similiarity - - dataloader:In case of using Attention as the explain method need to pass the dataloader instead of the batch , - - method_name: Name of the explantation - - Returns: - score: similarity score between attributions of model trained on random data and model trained on real data - - Implementation of the Random Logit Metric by Sixt et al., 2020. - - The Random Logit Metric computes the distance between the original explanation and a reference explanation of - a randomly chosen non-target class. - This code is adapted from the quantus libray to suit our use case - - References: - 1) Leon Sixt et al.: "When Explanations Lie: Why Many Modified BP - Attributions Fail." ICML (2020): 9046-9057. - 2)Hedström, Anna, et al. 
"Quantus: An explainable ai - toolkit for responsible evaluation of neural network explanations and beyond." - Journal of Machine Learning Research 24.34 (2023): 1-11. - - """ - - if explain_method == "Attention": - Attention_weights = random_model.interpertations(dataloader) - attribution = attribution.cpu().numpy() - min_val = np.min(attribution) - max_val = np.max(attribution) - - attribution = (attribution - min_val) / (max_val - min_val) - random_attr = Attention_weights["attention"].cpu().numpy() - min_val = np.min(random_attr) - max_val = np.max(random_attr) - random_attr = (random_attr - min_val) / (max_val - min_val) - score = similarity_func(random_attr, attribution) - elif explain_method == "Random": - score = similarity_func(np.random.normal(size=[64, 24, 53]).flatten(), attribution.flatten()) - else: - data, baselines = self.prep_data_captum(x) - - explantation = explain_method(random_model.forward_captum) - # Reformat attributions. - if explain_method is not captum.attr.Saliency: - attr = explantation.attribute(data, baselines=baselines, **kwargs) - else: - attr = explantation.attribute(data, **kwargs) - - # Process and store the calculated attributions - random_attr = ( - attr[1].cpu().detach().numpy() - if method_name in ["Lime", "FeatureAblation"] - else torch.stack(attr).cpu().detach().numpy() - ) - - attribution = attribution.flatten() - min_val = np.min(attribution) - max_val = np.max(attribution) - attribution = (attribution - min_val) / (max_val - min_val) - random_attr = random_attr.flatten() - min_val = np.min(random_attr) - max_val = np.max(random_attr) - random_attr = (random_attr - min_val) / (max_val - min_val) - - score = similarity_func(random_attr, attribution) - return score - - def Relative_Stability( - self, - x, - attribution, - explain_method, - method_name, - dataloader=None, - thershold=0.5, - **kwargs, - ): - """ - Args: - - x:Batch input - -attribution: attribution - - explain_method:function to generate explantations - - method_name: Name of the explantation - - dataloader:In case of using Attention as the explain method need to pass the dataloader instead of the batch , - - - Returns: - RIS : relative distance between the explantation and the input - ROS: relative distance between the explantation and the output - - - References: - 1) `https://arxiv.org/pdf/2203.06877.pdf - 2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for responsible evaluation - of neural network explanations and beyond." Journal of Machine Learning Research 24.34 (2023): 1-11. - - """ - - def relative_stability_objective(x, xs, e_x, e_xs, eps_min=0.0001, input=False, device="cuda") -> torch.Tensor: - """ - Computes relative input and output stabilities maximization objective - as defined here :ref:`https://arxiv.org/pdf/2203.06877.pdf` by the authors. - - Args: - - x: Input tensor - xs: perturbed tensor. - e_x: Explanations for x. - e_xs: Explanations for xs. - eps_min:Value to avoid division by zero if needed - input:Boolean to indicate if this is an input or an output - device: the device to keep the tensors on - - Returns: - - ris_obj: Tensor - RIS maximization objective. 
- """ - - # Function to convert inputs to tensors if they are numpy arrays - def to_tensor(input_array): - if isinstance(input_array, np.ndarray): - return torch.tensor(input_array).to(device) - return input_array.to(device) - - # Convert all inputs to tensors and move to GPU - x, xs, e_x, e_xs = map(to_tensor, [x, xs, e_x, e_xs]) - - if input: - num_dim = x.ndim - else: - num_dim = e_x.ndim - - if num_dim == 3: - - def norm_function(arr): - return torch.norm(arr, dim=(-1, -2)) - - elif num_dim == 2: - - def norm_function(arr): - return torch.norm(arr, dim=-1) - - else: - - def norm_function(arr): - return torch.norm(arr) - - nominator = (e_x - e_xs) / (e_x + (e_x == 0) * eps_min) - nominator = norm_function(nominator) - - if input: - denominator = x - xs - denominator /= x + (x == 0) * eps_min - denominator = norm_function(denominator) - denominator += (denominator == 0) * eps_min - else: - denominator = torch.squeeze(x) - torch.squeeze(xs) - denominator = torch.norm(denominator, dim=-1) - denominator += (denominator == 0) * eps_min - - return nominator / denominator - - attribution = torch.tensor(attribution).to(self.device) - if explain_method == "Attention": - y_pred = self.model.predict(dataloader) - x_original = dataloader.dataset.data["reals"].clone() - - dataloader.dataset.add_noise() - x_preturb = dataloader.dataset.data["reals"].clone() - y_pred_preturb = self.model.predict(dataloader) - Attention_weights = self.interpertations(dataloader) - att_preturb = Attention_weights["attention"] - # Calculate the absolute difference - difference = torch.abs(y_pred_preturb - y_pred) - - # Find where the difference is less than or equal to a thershold - close_indices = torch.nonzero(difference <= thershold).squeeze() - RIS = relative_stability_objective( - x_original[close_indices, :, :].detach(), - x_preturb[close_indices, :, :].detach(), - attribution, - att_preturb, - input=True, - ) - - ROS = relative_stability_objective( - y_pred[close_indices], - y_pred_preturb[close_indices], - attribution, - att_preturb, - input=False, - ) - - else: - y_pred = self(self.prep_data(x)).detach() - x_original = x["encoder_cont"].detach().clone() - - with torch.no_grad(): - noise = torch.randn_like(x["encoder_cont"]) * 0.01 - x["encoder_cont"] += noise - y_pred_preturb = self(self.prep_data(x)).detach() - if explain_method == "Random": - att_preturb = np.random.normal(size=[64, 24, 53]) - att_preturb = torch.tensor(att_preturb).to(self.device) - else: - data, baselines = self.prep_data_captum(x) - - explantation = explain_method(self.forward_captum) - # Reformat attributions. 
- if explain_method is not captum.attr.Saliency: - att_preturb = explantation.attribute(data, baselines=baselines, **kwargs) - else: - att_preturb = explantation.attribute(data, **kwargs) - - # Process and store the calculated attributions - att_preturb = ( - att_preturb[1].detach() - if method_name in ["Lime", "FeatureAblation"] - else torch.stack(att_preturb).detach() - ) - # Calculate the absolute difference - difference = torch.abs(y_pred_preturb - y_pred) - - # Find where the difference is less than or equal to a thershold - close_indices = torch.nonzero(difference <= thershold).squeeze() - RIS = relative_stability_objective( - x_original[close_indices, :, :].detach(), - x["encoder_cont"][close_indices, :, :].detach(), - attribution[close_indices, :, :], - att_preturb[close_indices, :, :], - input=True, - ) - ROS = relative_stability_objective( - y_pred[close_indices], - y_pred_preturb[close_indices], - attribution[close_indices, :, :], - att_preturb[close_indices, :, :], - input=False, - ) - - return np.max(RIS.cpu().numpy()).astype(np.float64), np.max(ROS.cpu().numpy()).astype(np.float64) - @gin.configurable("MLWrapper") class MLWrapper(BaseModule, ABC): diff --git a/scripts/plotting/plotting.py b/scripts/plotting/plotting.py index 779eea26..881d4e1b 100644 --- a/scripts/plotting/plotting.py +++ b/scripts/plotting/plotting.py @@ -1,4 +1,5 @@ import matplotlib.pyplot as plt +import numpy as np class Plotter: @@ -49,3 +50,194 @@ def calibration_curve(self): plt.legend(loc="lower right") plt.savefig(self.save_dir / f"call_curve {self.specifier}.png") plt.clf() + + def plot_XAI_Metrics(accumulated_metrics, log_dir_plots): + groups = {} + for key in accumulated_metrics["avg"]: + if key in ["loss", "MAE"]: + continue + suffix = key.split("_")[-1] + if suffix not in groups: + groups[suffix] = [] + groups[suffix].append(key) + + # Define a dictionary for legend labels + legend_labels = { + "IG": "Integrated Gradient", + "G": "Gradient", + "R": "Random", + "FA": "Feature Ablation", + "Att": "Attention", + "VSN": "Variable Selection Network", + "L": "Lime", + } + colors = [ + "navy", + "skyblue", + "crimson", + "salmon", + "teal", + "orange", + "darkgreen", + "lightgreen", + ] + + # Plotting + num_groups = len(groups) + fig, axs = plt.subplots(num_groups, 1, figsize=(10, num_groups * 5)) + + # Custom handles for the legend + # handles = [plt.Rectangle((0, 0), 1, 1, color="none", + # label=f"{key}: {value}") for key, value in legend_labels.items()] + + for i, (suffix, keys) in enumerate(groups.items()): + ax = axs[i] if num_groups > 1 else axs + # Extract values and errors + avg_values = [accumulated_metrics["avg"][key] for key in keys] + ci_lower = [accumulated_metrics["CI_0.95"][key][0] for key in keys] + ci_upper = [accumulated_metrics["CI_0.95"][key][1] for key in keys] + ci_error = [np.abs([a - b, c - a]) for a, b, c in zip(avg_values, ci_lower, ci_upper)] + + # Sort by absolute values of avg_values + sorted_indices = np.argsort([np.abs(val) for val in avg_values])[::-1] # Indices to sort in descending order + sorted_keys = np.array(keys)[sorted_indices] + sorted_avg_values = np.array(avg_values)[sorted_indices] + sorted_ci_error = np.array(ci_error)[sorted_indices] + + # Plot bars + bars = ax.bar( + sorted_keys, + np.abs(sorted_avg_values), + yerr=np.array(sorted_ci_error).T, + capsize=5, + color=colors, + ) + + # Set titles and labels + title_suffix = sorted_keys[0].split("_")[1] + ax.set_title(f'Metric: "{title_suffix}"') + ax.set_ylabel("Values") + ax.axhline(0, color="grey", 
linewidth=0.8) + ax.grid(axis="y") + + # Set x-ticks + ax.set_xticks(sorted_keys) + ax.set_xticklabels([key.split("_")[0] for key in sorted_keys], rotation=45, ha="right") + # Create a custom legend for each subplot + custom_labels = [legend_labels[key.split("_")[0]] for key in sorted_keys] + ax.legend(bars, custom_labels, loc="upper right") + + plt.tight_layout() + plt.savefig(log_dir_plots / "metrics_plot.png", bbox_inches="tight") + + def plot_attributions(self, features_attrs, timestep_attrs, method_name, log_dir): + """ + Plots the attribution values for features and timesteps. + + Args: + - features_attrs: Array of feature attribution values. + - timestep_attrs: Array of timestep attribution values. + - method_name: Name of the attribution method. + - log_dir: Directory to save the plots. + Returns: + Nothing + """ + + # Plot for feature attributions + x_values = np.arange(1, len(features_attrs) + 1) + plt.figure(figsize=(8, 6)) + plt.plot( + x_values, + features_attrs, + marker="o", + color="skyblue", + linestyle="-", + linewidth=2, + markersize=8, + ) + plt.xlabel("Feature") + plt.ylabel("{} Attribution".format(method_name)) + plt.title("{} Attribution Values".format(method_name)) + plt.xticks( + x_values, + [ + "height", + "weight", + "age", + "sex", + "time_idx", + "alb", + "alp", + "alt", + "ast", + "be", + "bicar", + "bili", + "bili_dir", + "bnd", + "bun", + "ca", + "cai", + "ck", + "ckmb", + "cl", + "crea", + "crp", + "dbp", + "fgn", + "fio2", + "glu", + "hgb", + "hr", + "inr_pt", + "k", + "lact", + "lymph", + "map", + "mch", + "mchc", + "mcv", + "methb", + "mg", + "na", + "neut", + "o2sat", + "pco2", + "ph", + "phos", + "plt", + "po2", + "ptt", + "resp", + "sbp", + "temp", + "tnt", + "urine", + "wbc", + ], + rotation=90, + ) + plt.tight_layout() + plt.savefig( + log_dir / "{}_attribution_features_plot.png".format(method_name), + bbox_inches="tight", + ) + + # Plot for timestep attributions + x_values = np.arange(1, len(timestep_attrs) + 1) + plt.figure(figsize=(8, 6)) + plt.plot( + x_values, + timestep_attrs, + marker="o", + color="skyblue", + linestyle="-", + linewidth=2, + markersize=8, + ) + plt.xlabel("Time Step") + plt.ylabel("{} Attribution".format(method_name)) + plt.title("{} Attribution Values".format(method_name)) + plt.xticks(x_values) + plt.tight_layout() + plt.savefig(log_dir / "{}_attribution_plot.png".format(method_name), bbox_inches="tight")
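A minimal usage sketch for the new plot_XAI_Metrics helper, inferred from its grouping and confidence-interval logic; the key pattern "<method>_<metric>_<group>", the numeric values, and the import path are illustrative assumptions rather than anything fixed by this patch.

from pathlib import Path

from scripts.plotting.plotting import Plotter  # assumed import path

# "avg" holds point estimates and "CI_0.95" holds (lower, upper) bounds; keys are
# assumed to follow "<method>_<metric>_<group>" so the function can group bars by
# the last token and label them via legend_labels using the first token.
accumulated_metrics = {
    "avg": {
        "IG_Faithfulness_TS": 0.31, "Att_Faithfulness_TS": 0.12,
        "IG_Faithfulness_V": 0.27, "Att_Faithfulness_V": 0.09,
    },
    "CI_0.95": {
        "IG_Faithfulness_TS": (0.22, 0.40), "Att_Faithfulness_TS": (0.05, 0.19),
        "IG_Faithfulness_V": (0.18, 0.36), "Att_Faithfulness_V": (0.02, 0.16),
    },
}

log_dir_plots = Path("logs/plots")
log_dir_plots.mkdir(parents=True, exist_ok=True)

# plot_XAI_Metrics is defined without `self`, so it can be used as an unbound helper.
Plotter.plot_XAI_Metrics(accumulated_metrics, log_dir_plots)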