fixes
youssefmecky96 committed Dec 27, 2023
1 parent a9ec05e, commit 1534c13
Showing 2 changed files with 62 additions and 49 deletions.
35 changes: 7 additions & 28 deletions icu_benchmarks/models/custom_metrics.py
@@ -243,7 +243,6 @@ def apply_baseline(x, indices, time_step, feature_timestep):
# Assuming 'attribution' is already a GPU tensor
if not torch.is_tensor(attribution):
attribution = torch.tensor(attribution).to(device)

# Other initializations
if similarity_func is None:
similarity_func = correlation_spearman
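Note: correlation_spearman is the default similarity function here. A minimal sketch of such a function, assuming scipy is available (the repository's actual helper may differ):

```python
# Hedged sketch of a Spearman similarity function as used above.
# Assumes scipy; not necessarily the repository's implementation.
import numpy as np
from scipy import stats

def correlation_spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Spearman rank correlation between two flattened arrays."""
    rho, _ = stats.spearmanr(a.flatten(), b.flatten())
    return float(rho)
```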
@@ -287,16 +286,15 @@ def apply_baseline(x, indices, time_step, feature_timestep):
if attribution.size() == torch.Size([24]):
att_sums.append((attribution[timesteps_idx]).sum())
else:
- att_sums.append((attribution[patient_idx, :, :][:, timesteps_idx, :]).sum())
+ att_sums.append((attribution[patient_idx, :][:, timesteps_idx]).sum())
elif feature:

if len(attribution) == 53:
att_sums.append((attribution[feature_idx]).sum())
else:

- att_sums.append((attribution[patient_idx, :, :][:, :, feature_idx]).sum())
+ att_sums.append((attribution[patient_idx, :][:, feature_idx]).sum())
elif feature_timestep:

att_sums.append((attribution[patient_idx, :, :]
[:, timesteps_idx, :][:, :, feature_idx]).sum())
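For reference, the slicing above sums attribution mass over selected patients, timesteps, and/or features of a [patients, timesteps, features] tensor. A self-contained sketch with hypothetical indices (illustrative, not repository code):

```python
# Illustrative sketch of the attribution-sum slicing used in apply_baseline.
import torch

attribution = torch.randn(64, 24, 53)      # [patients, timesteps, features]
patient_idx, timesteps_idx, feature_idx = [0, 3, 7], [5, 6, 7], [10, 20]

# Sum over chosen timesteps for chosen patients.
ts_sum = attribution[patient_idx, :, :][:, timesteps_idx, :].sum()
# Sum over chosen features for chosen patients.
ft_sum = attribution[patient_idx, :, :][:, :, feature_idx].sum()
# Joint timestep-and-feature selection, as in the feature_timestep branch.
both_sum = attribution[patient_idx, :, :][:, timesteps_idx, :][:, :, feature_idx].sum()
```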

@@ -323,7 +321,7 @@ def __init__(
)

def update(self, x,
- attribution, model, explain_method, method_name, dataloader=None, thershold=0.5, device='cuda', **kwargs
+ attribution, model, explain_method, dataloader=None, thershold=0.5, device='cuda', **kwargs
):
"""
Args:
@@ -441,19 +439,9 @@ def norm_function(arr): return torch.norm(arr)
att_preturb = np.random.normal(size=[64, 24, 53])

else:
+ att_preturb, features_attrs, timestep_attrs = model.explantation2(x, explain_method)

- data, baselines = model.prep_data_captum(x)
-
- explantation = explain_method(model.forward_captum)
- # Reformat attributions.
- if explain_method is not captum.attr.Saliency:
-     att_preturb = explantation.attribute(data, baselines=baselines, **kwargs)
- else:
-     att_preturb = explantation.attribute(data, **kwargs)
-
- # Process and store the calculated attributions
- att_preturb = att_preturb[1].detach() if method_name in [
-     'Lime', 'FeatureAblation'] else torch.stack(att_preturb).detach()
+ #
# Calculate the absolute difference
difference = torch.abs(y_pred_preturb - y_pred)
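The absolute prediction difference above feeds a stability-style score: under perturbation, attributions should not shift much more than predictions do. A hedged sketch of that comparison step (illustrative names, not the repository's API):

```python
# Hedged sketch of a stability score relating attribution change to
# prediction change under input perturbation. Names are illustrative.
import torch

def stability_score(attr: torch.Tensor, attr_perturbed: torch.Tensor,
                    y_pred: torch.Tensor, y_pred_perturbed: torch.Tensor,
                    eps: float = 1e-8) -> torch.Tensor:
    """Lower is more stable: attribution shift per unit of prediction shift."""
    attr_shift = torch.norm(attr - attr_perturbed)
    pred_shift = torch.norm(y_pred - y_pred_perturbed)
    return attr_shift / (pred_shift + eps)
```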

@@ -484,7 +472,7 @@ def __init__(

def update(
self, x,
- attribution, model, explain_method, random_model, similarity_func=cosine, dataloader=None, method_name="", **kwargs
+ attribution, model, explain_method, random_model, similarity_func=cosine, dataloader=None, **kwargs
):
"""
@@ -531,16 +519,7 @@ def update(
data, baselines = model.prep_data_captum(x)
y_pred = model(data).detach()

- explantation = explain_method(random_model.forward_captum)
- # Reformat attributions.
- if explain_method is not captum.attr.Saliency:
-     attr = explantation.attribute(data, baselines=baselines, **kwargs)
- else:
-     attr = explantation.attribute(data, **kwargs)
-
- # Process and store the calculated attributions
- random_attr = attr[1].cpu().detach().numpy() if method_name in [
-     'Lime', 'FeatureAblation'] else torch.stack(attr).cpu().detach().numpy()
+ random_attr, features_attrs, timestep_attrs = model.explantation2(x, explain_method)

attribution = attribution.flatten()
min_val = np.min(attribution)
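The truncated hunk above min-max normalises the attribution before comparing it against attributions from a randomised model. A hedged standalone sketch of that sanity check (illustrative, not repository code):

```python
# Hedged sketch of the model-randomisation check: normalise both maps,
# then take cosine similarity. Low similarity suggests the explanation
# actually depends on the learned weights.
import numpy as np

def randomization_score(attribution: np.ndarray, random_attr: np.ndarray) -> float:
    def minmax(a: np.ndarray) -> np.ndarray:
        a = a.flatten().astype(float)
        span = a.max() - a.min()
        return (a - a.min()) / span if span > 0 else np.zeros_like(a)

    a, b = minmax(attribution), minmax(random_attr)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom > 0 else 0.0
```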
76 changes: 55 additions & 21 deletions icu_benchmarks/models/wrappers.py
@@ -518,15 +518,16 @@ def step_fn(self, element, step_prefix=""):
if self.explain is not None:

for method in self.explain:
+ print(method)
if method == "Attention":
- all_attrs, features_attrs, timestep_attrs = self.explantation2(dic, method)
+ all_attrs, features_attrs, timestep_attrs = self.explantation2(dic, method, self.test_loader)
else:
all_attrs, features_attrs, timestep_attrs = self.explantation2(dic, method)
for key, value in self.metrics[step_prefix].items():
if key == "Faithfulness_timesteps":
value.update(
dic,
- all_attrs,
+ timestep_attrs,
self,
correlation_pearson,
100,
@@ -540,7 +541,7 @@ def step_fn(self, element, step_prefix=""):
elif key == "Faithfulness_features":
value.update(
dic,
- all_attrs,
+ features_attrs,
self,
correlation_pearson,
100,
@@ -573,7 +574,6 @@ def step_fn(self, element, step_prefix=""):
dic,
timestep_attrs,
self,
- self.methods[method],
method,
self.test_loader,
0.5,
@@ -584,7 +584,6 @@ def step_fn(self, element, step_prefix=""):
dic,
all_attrs,
self,
- self.methods[method],
method,
None,
0.5,
@@ -596,27 +595,29 @@ def step_fn(self, element, step_prefix=""):
dic,
timestep_attrs,
self,
- self.methods[method],
method,
self.random_model,
cosine,
None,
- method

)
else:

value.update(
dic,
all_attrs,
self,
- self.methods[method],
method,
self.random_model,
cosine,
self.test_loader,
- method

)

for key, value in self.metrics[step_prefix].items():

if (key == "Faithfulness_timesteps" or key == "Faithfulness_features" or key == "Faithfulness_feature_timestep"
or key == "Stability" or key == "Randomization"):
continue
if isinstance(value, torchmetrics.Metric):
if key == "Binary_Fairness":
feature_names = self.metrics[step_prefix][key].feature_helper(self.trainer, step_prefix)
@@ -626,9 +627,6 @@ def step_fn(self, element, step_prefix=""):
data,
feature_names,
)
elif (key == "Faithfulness_timesteps" or key == "Faithfulness_features" or key == "Faithfulness_feature_timestep"
or key == "Stability" or key == "Randomization"):
continue

else:
value.update(transformed_output[0], transformed_output[1].int())
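The loop above skips the XAI metrics, which were already fed attributions earlier in step_fn, and updates the ordinary torchmetrics with predictions and targets. A compact sketch of that dispatch (hypothetical helper, not repository code):

```python
# Illustrative dispatch mirroring step_fn: XAI metrics consume attributions
# elsewhere; standard torchmetrics consume (preds, targets).
import torchmetrics

XAI_KEYS = {"Faithfulness_timesteps", "Faithfulness_features",
            "Faithfulness_feature_timestep", "Stability", "Randomization"}

def update_standard_metrics(metrics: dict, preds, targets) -> None:
    for key, metric in metrics.items():
        if key in XAI_KEYS:
            continue  # handled by the attribution-based updates above
        if isinstance(metric, torchmetrics.Metric):
            metric.update(preds, targets.int())
```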
@@ -780,23 +778,59 @@ def explantation2(self, dic, method, dataloader=None, log_dir=".", plot=False, *

# Calculate attributions using the selected method
if method is not Saliency:
- attr = explanation.attribute(data, baselines=baselines, **kwargs)
+ if method_name == "IG":
+     attr = explanation.attribute(data, baselines=baselines, n_steps=50, **kwargs)
+ elif method_name == "L" or method_name == "FA":
+     # For Lime and FeatureAblation we must define what counts as a feature;
+     # here each variable at each timestep is treated as its own feature
+     # (see the standalone sketch after the diff).
+     shapes = [
+         torch.Size([64, 24, 0]),
+         torch.Size([64, 24, 53]),
+         torch.Size([64, 24]),
+         torch.Size([64]),
+         torch.Size([64, 1, 0]),
+         torch.Size([64, 1, 53]),
+         torch.Size([64, 1]),
+         torch.Size([64]),
+         torch.Size([64, 1]),
+         torch.Size([64, 1]),
+         torch.Size([64, 2])
+     ]
+
+     # Create a default mask for non-targeted tensors
+     def create_default_mask(shape):
+         if len(shape) == 3:
+             return torch.zeros(shape[0], shape[1], max(1, shape[2]), dtype=torch.int32)
+         elif len(shape) == 2:
+             return torch.zeros(shape[0], max(1, shape[1]), dtype=torch.int32)
+         else:  # len(shape) == 1
+             return torch.zeros(shape[0], dtype=torch.int32)
+
+     # Create a feature mask for the second tensor that covers both features and timesteps
+     num_timesteps = shapes[1][1]
+     num_features = shapes[1][2]
+     feature_mask_second = torch.arange(num_timesteps * num_features).reshape(num_timesteps, num_features)
+     feature_mask_second = feature_mask_second.unsqueeze(0).repeat(shapes[1][0], 1, 1)
+
+     # Create a tuple of masks
+     feature_masks = tuple([create_default_mask(shape) if i !=
+                            1 else feature_mask_second for i, shape in enumerate(shapes)])
+
+     attr = explanation.attribute(data, baselines=baselines, feature_mask=feature_masks, **kwargs)
else:
attr = explanation.attribute(data, **kwargs)
+ self.eval()
# Process and store the calculated attributions
stacked_attr = attr[1].cpu().detach().numpy() if method_name in [
'Lime', 'FeatureAblation'] else torch.stack(attr).cpu().detach().numpy()

# aggregate over batch
- attr = np.mean(stacked_attr, axis=0)
- all_attrs.append(attr)
+ # attr = np.mean(stacked_attr, axis=0)

# aggregate over all batches
- all_attrs = np.array(all_attrs).mean(axis=(0))
+ # all_attrs = np.array(all_attrs).mean(axis=(0))
# aggregate over all timesteps
- features_attrs = all_attrs.mean(axis=(0))
+ features_attrs = stacked_attr.mean(axis=(1))
# aggregate over all features
- timestep_attrs = all_attrs.mean(axis=(1))
+ timestep_attrs = stacked_attr.mean(axis=(2))

if plot:
log_dir_plots = log_dir / 'plots'
@@ -807,7 +841,7 @@ def explantation2(self, dic, method, dataloader=None, log_dir=".", plot=False, *
self.plot_attributions(features_attrs, timestep_attrs, method_name, log_dir_plots)

# Return computed attributions and metrics
- return all_attrs, features_attrs, timestep_attrs
+ return stacked_attr, features_attrs, timestep_attrs

def explantation(self, dataloader, method, log_dir=".", plot=False, XAI_metric=False, random_model=None, test_dataset=None, **kwargs):
"""
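The feature-mask logic added in this commit gives every (timestep, feature) cell of the [64, 24, 53] input its own group id, so Lime and FeatureAblation perturb one cell at a time rather than whole tensors. A standalone distillation under those assumed shapes (everything beyond the shapes is illustrative):

```python
# Standalone sketch of the per-cell feature mask used for Lime/FeatureAblation.
# Shapes are taken from the diff; the rest is an assumption-laden illustration.
import torch

batch, timesteps, features = 64, 24, 53

cell_ids = torch.arange(timesteps * features).reshape(timesteps, features)
feature_mask = cell_ids.unsqueeze(0).repeat(batch, 1, 1)   # [64, 24, 53]

assert feature_mask.shape == (batch, timesteps, features)
assert int(feature_mask.max()) == timesteps * features - 1
# Passed as feature_mask= to captum's Lime(...).attribute or
# FeatureAblation(...).attribute alongside the matching input tensor.
```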
