Commit: attention ROS RIS

youssefmecky96 committed Dec 5, 2023
1 parent 08c3c31 commit 4893f3e
Showing 7 changed files with 61 additions and 35 deletions.
2 changes: 1 addition & 1 deletion configs/prediction_models/DeepARpytorch.gin
@@ -13,7 +13,7 @@ optimizer/hyperparameter.lr = (1e-5, 1e-3)
# Model params

model/hyperparameter.class_to_tune = @DeepARpytorch
-model/hyperparameter.hidden = (40, 120, "log-uniform", 2)
+model/hyperparameter.hidden = (4, 64, "log-uniform", 2)
model/hyperparameter.rnn_layers=(1, 3)
model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.cell_type='LSTM'
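(Aside: the 4-tuple on the changed line is a hyperparameter search space read by the benchmark's tuner. A minimal sketch of one plausible interpretation, assuming the tuple encodes (low, high, prior, base) for a log-uniform integer range; `sample_hidden` is a hypothetical helper, not part of the repository, and the actual tuner may parse the tuple differently.)

# Hedged sketch: sampling from a (low, high, "log-uniform", base) spec.
import math
import random

def sample_hidden(spec=(4, 64, "log-uniform", 2)):
    low, high, prior, base = spec
    if prior == "log-uniform":
        # Sample uniformly in log space so small and large widths are equally likely.
        lo, hi = math.log(low, base), math.log(high, base)
        return round(base ** random.uniform(lo, hi))
    return random.randint(low, high)

print(sample_hidden())  # e.g. 23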
2 changes: 1 addition & 1 deletion configs/prediction_models/RNNpytorch.gin
@@ -13,7 +13,7 @@ optimizer/hyperparameter.lr = (1e-5, 3e-4)
# Model params

model/hyperparameter.class_to_tune = @RNNpytorch
-model/hyperparameter.hidden = (32, 256, "log-uniform", 2)
+model/hyperparameter.hidden = (2, 64, "log-uniform", 2)
model/hyperparameter.rnn_layers=(1,3)
model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.cell_type='LSTM'
6 changes: 5 additions & 1 deletion icu_benchmarks/data/loader.py
@@ -3,7 +3,7 @@
import pandas as pd
import gin
import numpy as np
-from torch import Tensor, cat, from_numpy, float32, min, max
+from torch import Tensor, cat, from_numpy, float32, min, max, randn_like
from torch.utils.data import Dataset
import logging
from typing import Dict, Tuple, Union
@@ -531,3 +531,7 @@ def randomize_labels(self, num_classes=None, min=None, max=None):
else:
random_target = np.random.randint(num_classes, size=len(self.data["target"][0]))
self.data["target"][0] = Tensor(random_target)
+
+def add_noise(self, num_classes=None, min=None, max=None):
+noise = randn_like(self.data["reals"])*0.1
+self.data["reals"] += noise
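(For orientation, a standalone sketch of what the new `add_noise` does, assuming `data["reals"]` holds a float tensor of continuous features; the tensor shape below is invented for illustration.)

# Hedged sketch of the perturbation above: zero-mean Gaussian noise at 0.1 scale.
import torch

reals = torch.randn(4, 24, 8)           # stand-in for dataset.data["reals"]
noise = torch.randn_like(reals) * 0.1   # same construction as add_noise
reals += noise                          # in-place perturbation of the features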
2 changes: 1 addition & 1 deletion icu_benchmarks/models/dl_models.py
@@ -476,7 +476,7 @@ def __init__(
**kwargs,
):
super().__init__(optimizer=optimizer, pytorch_forecasting=True, *args, **kwargs)
-
+self.dataset = dataset
self.model = TemporalFusionTransformer.from_dataset(
dataset=dataset,
hidden_size=hidden,
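(Context for the new `self.dataset = dataset` line: pytorch-forecasting builds the TFT from a `TimeSeriesDataSet`, and keeping the reference is what later lets the wrapper rebuild dataloaders, as in the commented-out `self.dataset.to_dataloader(...)` block further down. A minimal sketch of that pattern; constructor arguments beyond `hidden_size` are elided here.)

# Hedged sketch of the pytorch-forecasting pattern used in this class.
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet

def build_tft(dataset: TimeSeriesDataSet, hidden: int = 16) -> TemporalFusionTransformer:
    model = TemporalFusionTransformer.from_dataset(
        dataset,              # dataset supplies encoders, variable groups, targets
        hidden_size=hidden,   # maps to the `hidden` hyperparameter in the gin configs
    )
    return model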
4 changes: 2 additions & 2 deletions icu_benchmarks/models/train.py
@@ -342,10 +342,10 @@ def create_default_mask(shape):
print("{} Attributions Faithfulness Timesteps ".format(key), ts_score)
XAI_dict["{}_Faith Timesteps".format(key)] = ts_score
print("{}_ROS ".format(
-key), st_o_score)
+key), st_o_score, type(st_o_score))
XAI_dict["{}_ROS".format(key)] = st_o_score
print("{}_RIS ".format(
-key), st_i_score)
+key), st_i_score, type(st_i_score))
XAI_dict["{}_RIS".format(key)] = st_i_score
if key == "Att":
print("Variable selection weights faithfulness featrues ".format(key), v_score)
55 changes: 33 additions & 22 deletions icu_benchmarks/models/wrappers.py
@@ -605,8 +605,8 @@ def explantation(self, dataloader, method, log_dir=".", plot=False, XAI_metric=F
timestep_attrs = Interpertations["attention"]
features_attrs = Interpertations["static_variables"].tolist()
features_attrs.extend(Interpertations["encoder_variables"].tolist())
-r_score = self.Data_Randomization(x=None, attribution=timestep_attrs,
-explain_method=method, random_model=random_model, dataloader=dataloader, method_name=method_name)
+""" r_score = self.Data_Randomization(x=None, attribution=timestep_attrs,
+explain_method=method, random_model=random_model, dataloader=dataloader, method_name=method_name) """
st_i_score, st_o_score = self.Relative_Stability(x=None,
attribution=timestep_attrs, explain_method=method, method_name=method_name, dataloader=dataloader, **kwargs
)
@@ -649,8 +649,7 @@ def explantation(self, dataloader, method, log_dir=".", plot=False, XAI_metric=F
# Faithfulness score for attributions of timesteps averaged over features
f_ts_score = np.mean(f_ts_score)
f_v_score = np.mean(f_v_score)
-# min_val = np.min(r_score)
-# max_val = np.max(r_score)
+
if method_name != "Attention":
# r_score = (r_score - min_val) / (max_val - min_val)
r_score = np.mean(r_score)
@@ -944,7 +943,7 @@ def Data_Randomization(
min_val = np.min(random_attr)
max_val = np.max(random_attr)
random_attr = (random_attr - min_val) / (max_val - min_val)
-score = similarity_func(random_attr, attribution)*53
+score = similarity_func(random_attr, attribution)
elif explain_method == "Random":
score = similarity_func(np.random.normal(size=[64, 24, 53]).flatten(), attribution.flatten())
else:
@@ -989,7 +988,7 @@ def relative_stability_objective(
x: np.ndarray, xs: np.ndarray, e_x: np.ndarray, e_xs: np.ndarray, eps_min=0.0001, input=False
) -> np.ndarray:
"""
-Computes relative input stabilities maximization objective
+Computes relative input and output stabilities maximization objective
as defined here :ref:`https://arxiv.org/pdf/2203.06877.pdf` by the authors.
Parameters
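(For reference, the two objectives from the cited paper, paraphrased; the code's norms and epsilon handling may differ in detail. RIS normalizes the shift in the explanation by the relative shift in the input, while ROS normalizes it by the shift in the model's outputs h(x).)

\mathrm{RIS}(x, x', e_x, e_{x'}) =
  \frac{\bigl\lVert (e_x - e_{x'}) / e_x \bigr\rVert_p}
       {\max\bigl( \lVert (x - x') / x \rVert_p,\; \epsilon_{\min} \bigr)},
\qquad
\mathrm{ROS}(x, x', e_x, e_{x'}) =
  \frac{\bigl\lVert (e_x - e_{x'}) / e_x \bigr\rVert_p}
       {\max\bigl( \lVert h(x) - h(x') \rVert_p,\; \epsilon_{\min} \bigr)}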
@@ -1019,9 +1018,7 @@ def norm_function(arr): return np.linalg.norm(arr, axis=(-1, -2)) # noqa
elif num_dim == 2:
def norm_function(arr): return np.linalg.norm(arr, axis=-1)
else:
-raise ValueError(
-"Relative Input Stability only supports 4D, 3D and 2D inputs (batch dimension inclusive)."
-)
+def norm_function(arr): return np.linalg.norm(arr)

# fmt: off

@@ -1043,16 +1040,30 @@ def norm_function(arr): return np.linalg.norm(arr, axis=-1)

return nominator / denominator

-y_pred = self(self.prep_data(x)).detach()
-
if explain_method == "Attention":
-with torch.no_grad():
-noise = torch.randn_like(self.dataset["encoder_cont"])*0.1
-self.dataset += noise

+y_pred = self.model.predict(dataloader)
+x_original = dataloader.dataset.data["reals"].clone()
+
+dataloader.dataset.add_noise()
+x_preturb = dataloader.dataset.data["reals"].clone()
+""" dataloader = self.dataset.to_dataloader(
+train=False,
+batch_size=64,
+num_workers=8,
+pin_memory=False,
+drop_last=True,
+shuffle=False,
+) """
+
+y_pred_preturb = self.model.predict(dataloader)
+Attention_weights = self.interpertations(dataloader)
+att_preturb = Attention_weights["attention"].cpu().numpy()
+
+RIS = relative_stability_objective(
+x_original.detach().cpu().numpy(), x_preturb.detach().cpu().numpy(), attribution.cpu().numpy(), att_preturb, input=True
+)
+ROS = relative_stability_objective(
+y_pred.cpu().numpy(), y_pred_preturb.cpu().numpy(), attribution.cpu().numpy(), att_preturb, input=False
+)
else:
+y_pred = self(self.prep_data(x)).detach()
x_original = x["encoder_cont"].detach().clone()
@@ -1078,12 +1089,12 @@
# Process and store the calculated attributions
att_preturb = att_preturb[1].cpu().detach().numpy() if method_name in [
'Lime', 'FeatureAblation'] else torch.stack(att_preturb).cpu().detach().numpy()
-RIS = relative_stability_objective(
-x_original.detach().cpu().numpy(), x["encoder_cont"].detach().cpu().numpy(), attribution, att_preturb, input=True
-)
-ROS = relative_stability_objective(
-y_pred.cpu().numpy(), y_pred_preturb.cpu().numpy(), attribution, att_preturb, input=False
-)
+    RIS = relative_stability_objective(
+        x_original.detach().cpu().numpy(), x["encoder_cont"].detach().cpu().numpy(), attribution, att_preturb, input=True
+    )
+    ROS = relative_stability_objective(
+        y_pred.cpu().numpy(), y_pred_preturb.cpu().numpy(), attribution, att_preturb, input=False
+    )

return np.max(RIS), np.max(ROS)

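(For intuition, a self-contained sketch of the RIS/ROS computation over a batch; this is my paraphrase of arXiv:2203.06877, not the repository's exact code, and `relative_stability` plus all shapes below are invented for illustration.)

# Hedged sketch of relative input/output stability; `input` mirrors the flag above.
import numpy as np

def relative_stability(x, xs, e_x, e_xs, eps_min=1e-4, input=True):
    # Per-sample flattened L2 norms.
    norm = lambda a: np.linalg.norm(a.reshape(len(a), -1), axis=-1)
    num = norm((e_x - e_xs) / (e_x + (e_x == 0) * eps_min))
    if input:
        # RIS: relative change of the input in the denominator.
        denom = norm((x - xs) / (x + (x == 0) * eps_min))
    else:
        # ROS: absolute change of the model output in the denominator.
        denom = norm(x - xs)
    return num / np.maximum(denom, eps_min)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 24, 8)); xs = x + rng.normal(scale=0.1, size=x.shape)
e = rng.normal(size=x.shape);    es = e + rng.normal(scale=0.1, size=e.shape)
print(relative_stability(x, xs, e, es).max())  # reported as the batch max, as above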
25 changes: 18 additions & 7 deletions icu_benchmarks/run_utils.py
@@ -291,6 +291,7 @@ def plot_XAI_Metrics(accumulated_metrics, log_dir_plots):
'VSN': 'Variable Selection Network',
'L': 'Lime'
}
+colors = ['navy', 'skyblue', 'crimson', 'salmon', 'teal', 'orange', 'darkgreen', 'lightgreen']

# Plotting
num_groups = len(groups)
@@ -303,23 +304,33 @@
for i, (suffix, keys) in enumerate(groups.items()):

ax = axs[i] if num_groups > 1 else axs
+# Extract values and errors
avg_values = [accumulated_metrics['avg'][key] for key in keys]
ci_lower = [accumulated_metrics['CI_0.95'][key][0] for key in keys]
ci_upper = [accumulated_metrics['CI_0.95'][key][1] for key in keys]
ci_error = [np.abs([a - b, c - a]) for a, b, c in zip(avg_values, ci_lower, ci_upper)]

-bars = ax.bar(keys, np.abs(avg_values), yerr=np.array(ci_error).T, capsize=5)
-# Modify the title to use the second suffix
-title_suffix = keys[0].split('_')[1]
+# Sort by absolute values of avg_values
+sorted_indices = np.argsort([np.abs(val) for val in avg_values])[::-1]  # Indices to sort in descending order
+sorted_keys = np.array(keys)[sorted_indices]
+sorted_avg_values = np.array(avg_values)[sorted_indices]
+sorted_ci_error = np.array(ci_error)[sorted_indices]
+
+# Plot bars
+bars = ax.bar(sorted_keys, np.abs(sorted_avg_values), yerr=np.array(sorted_ci_error).T, capsize=5, color=colors)
+
+# Set titles and labels
+title_suffix = sorted_keys[0].split('_')[1]
ax.set_title(f'Metric: "{title_suffix}"')
ax.set_ylabel('Values')
ax.axhline(0, color='grey', linewidth=0.8)
ax.grid(axis='y')
-# Modify x-axis labels to show only the prefix
-ax.set_xticks(keys)
-ax.set_xticklabels([key.split('_')[0] for key in keys], rotation=45, ha="right")

+# Set x-ticks
+ax.set_xticks(sorted_keys)
+ax.set_xticklabels([key.split('_')[0] for key in sorted_keys], rotation=45, ha="right")
# Create a custom legend for each subplot
-custom_labels = [legend_labels[key.split('_')[0]] for key in keys]
+custom_labels = [legend_labels[key.split('_')[0]] for key in sorted_keys]
ax.legend(bars, custom_labels, loc='upper right')

plt.tight_layout()
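(A standalone illustration of the two ideas this hunk combines: matplotlib expects asymmetric yerr as shape (2, N), which is what `np.array(ci_error).T` produces, and bars are reordered by descending |avg|. The data and labels below are invented for the example.)

# Hedged sketch of sorted bars with asymmetric confidence-interval error bars.
import numpy as np
import matplotlib.pyplot as plt

avg = np.array([0.8, 0.5, 0.3])
ci_lo = np.array([0.7, 0.35, 0.25])
ci_hi = np.array([0.95, 0.6, 0.4])
err = np.abs(np.stack([avg - ci_lo, ci_hi - avg]))  # shape (2, N): lower, upper

order = np.argsort(np.abs(avg))[::-1]               # descending by |avg|, as above
labels = np.array(['A', 'B', 'C'])[order]
fig, ax = plt.subplots()
ax.bar(labels, np.abs(avg)[order], yerr=err[:, order], capsize=5)
fig.savefig("bars.png")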
