Skip to content

Commit

Permalink
added pytorch forecasting wrapper to make it easier
Browse files Browse the repository at this point in the history
  • Loading branch information
youssefmecky96 committed Oct 31, 2023
1 parent 85216b7 commit 119dbcf
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 100 deletions.
98 changes: 8 additions & 90 deletions icu_benchmarks/models/dl_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
PositionalEncoding,
)
import matplotlib.pyplot as plt
from icu_benchmarks.models.wrappers import DLPredictionWrapper
from torch import Tensor, FloatTensor, zeros_like, ones_like, randn_like, from_numpy
from icu_benchmarks.models.wrappers import DLPredictionWrapper, DLPredictionPytorchForecastingWrapper
from torch import Tensor, FloatTensor, zeros_like, ones_like, randn_like, is_tensor
from pytorch_forecasting import TemporalFusionTransformer, RecurrentNetwork, DeepAR
from pytorch_forecasting.metrics import QuantileLoss
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -464,7 +464,7 @@ def forward(self, x: Dict[str, Tensor]) -> Tensor:


@gin.configurable
class TFTpytorch(DLPredictionWrapper):
class TFTpytorch(DLPredictionPytorchForecastingWrapper):
"""
Implementation of https://arxiv.org/abs/1912.09363 from pytorch forecasting
"""
Expand Down Expand Up @@ -534,6 +534,7 @@ def actual_vs_predictions_plot(self, dataloader):
return predictions_vs_actuals

def interpertations(self, dataloader, log_dir):

raw_predictions = self.model.predict(dataloader, return_x=True, mode="raw")
interpretation = self.model.interpret_output(
raw_predictions.output, reduction="sum"
Expand All @@ -542,6 +543,8 @@ def interpertations(self, dataloader, log_dir):
for key, fig in figs.items():
fig.savefig(log_dir / f"interpretation_{key}.png", bbox_inches="tight")

self.model = self.model.to(self.device)

return interpretation

def predict_dependency(self, dataloader, variable, log_dir):
Expand Down Expand Up @@ -688,94 +691,9 @@ def explantation_captum(self, test_loader, log_dir, method, target):
plt.savefig(log_dir / "attribution_plot.png", bbox_inches="tight")
return means

def faithfulness_correlation(self, test_loader, attribution, nr_runs=100, pertrub=None, subset_size=4):
"""
Implementation of faithfulness correlation by Bhatt et al., 2020.
The Faithfulness Correlation metric intend to capture an explanation's relative faithfulness
(or 'fidelity') with respect to the model behaviour.
Faithfulness correlation scores shows to what extent the predicted logits of each modified test point and
the average explanation attribution for only the subset of features are (linearly) correlated, taking the
average over multiple runs and test samples. The metric returns one float per input-attribution pair that
ranges between -1 and 1, where higher scores are better.
For each test sample, |S| features are randomly selected and replace them with baseline values (zero baseline
or average of set). Thereafter, Pearson’s correlation coefficient between the predicted logits of each modified
test point and the average explanation attribution for only the subset of features is calculated. Results is
average over multiple runs and several test samples.
This code is adapted from the quantus libray to suit our use case
References:
1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model
explanations." IJCAI (2020): 3016-3022.
2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for responsible evaluation of neural network explanations and beyond." Journal of Machine Learning Research 24.34 (2023): 1-11.
"""
if pertrub == None:
pertrub = "baseline"
similarities = []
for batch in test_loader:

for key, value in batch[0].items():

batch[0][key] = batch[0][key].to(self.device)
x = batch[0]
data = (
x["encoder_cat"],
x["encoder_cont"],
x["encoder_target"],
x["encoder_lengths"],
x["decoder_cat"],
x["decoder_cont"],
x["decoder_target"],
x["decoder_lengths"],
x["decoder_time_idx"],
x["groups"],
x["target_scale"],
)

y_pred = self(data).detach().cpu().numpy()
pred_deltas = []
att_sums = []
for i_ix in range(nr_runs):
# Randomly mask by subset size.
a_ix = np.random.choice(x["encoder_cont"].shape[1], subset_size, replace=False)

# Move a_ix_tensor to the same device as mask

if pertrub == "Noise":
# add normal noise to input
noise = randn_like(x["encoder_cont"])

x["encoder_cont"][:, a_ix, :] += noise[:, a_ix, :]
elif pertrub == "baseline":
# Create a mask tensor with zeros at specified time steps and ones everywhere else
# pytorch bug need to change to cpu for next step and then revert
mask = ones_like(x["encoder_cont"]).cpu()

mask[:, a_ix, :] = 0
mask = mask.to(x["encoder_cont"].device)

x["encoder_cont"] = x["encoder_cont"] * mask

# Predict on perturbed input x.
y_pred_perturb = self(data).detach().cpu().numpy()
pred_deltas.append((y_pred - y_pred_perturb).mean(axis=(0, 2)))

# Sum attributions of the random subset.
att_sums.append(np.sum(attribution[a_ix]))
print(np.asarray(pred_deltas).shape)
print(np.asarray(att_sums).shape)
correlation_matrix = np.corrcoef(pred_deltas, att_sums, rowvar=False)

# Get the correlation coefficient from the correlation matrix
pearson_correlation = correlation_matrix[0, 1]
similarities.append(pearson_correlation)
return np.mean(similarities)


@gin.configurable
class RNNpytorch(DLPredictionWrapper):
class RNNpytorch(DLPredictionPytorchForecastingWrapper):
"""
Implementation of RNN from pytorch forecasting
"""
Expand Down Expand Up @@ -835,7 +753,7 @@ def forward(


@gin.configurable
class DeepARpytorch(DLPredictionWrapper):
class DeepARpytorch(DLPredictionPytorchForecastingWrapper):
"""
Implementation of RNN from pytorch forecasting
"""
Expand Down
19 changes: 9 additions & 10 deletions icu_benchmarks/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,19 +271,18 @@ def train_common(
return 0

if explain:
explanintations = model.explantation_captum(
attributions_IG = model.explantation_captum(
test_loader, log_dir, IntegratedGradients, target=(0, 1)
)

print("IG", explanintations)
interperations = model.interpertations(test_loader, log_dir)
print("attention", interperations)
if XAI_metric:
explanintations = model.explantation_captum(
test_loader, log_dir, IntegratedGradients, target=(0, 1)
)
scores = model.faithfulness_correlation(test_loader, explanintations)
print(scores)
# print("IG", attributions_IG)
# Attention_weights = model.interpertations(test_loader, log_dir)
# print("attention", Attention_weights)
if XAI_metric:
scores = model.faithfulness_correlation(test_loader, attributions_IG)
print('Attributions faithfulness correlation', scores)
# scores = model.faithfulness_correlation(test_loader, Attention_weights["attention"])
print('Attention faithfulness correlation', scores)

model.set_weight("balanced", train_dataset)
test_loss = trainer.test(model, dataloaders=test_loader, verbose=verbose)[0][
Expand Down
214 changes: 214 additions & 0 deletions icu_benchmarks/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,220 @@ def step_fn(self, element, step_prefix=""):
return loss


@gin.configurable("DLPredictionPytorchForecastingWrapper")
class DLPredictionPytorchForecastingWrapper(DLPredictionWrapper):
"""Interface for Deep Learning models."""

_supported_run_modes = [RunMode.classification, RunMode.regression]

def __init__(
self,
loss=CrossEntropyLoss(),
optimizer=torch.optim.Adam,
run_mode: RunMode = RunMode.classification,
input_shape=None,
lr: float = 0.002,
momentum: float = 0.9,
lr_scheduler: Optional[str] = None,
lr_factor: float = 0.99,
lr_steps: Optional[List[int]] = None,
epochs: int = 100,
input_size: Tensor = None,
initialization_method: str = "normal",
pytorch_forecasting: bool = False,
**kwargs,
):
super().__init__(
loss=loss,
optimizer=optimizer,
run_mode=run_mode,
input_shape=input_shape,
lr=lr,
momentum=momentum,
lr_scheduler=lr_scheduler,
lr_factor=lr_factor,
lr_steps=lr_steps,
epochs=epochs,
input_size=input_size,
initialization_method=initialization_method,
kwargs=kwargs,
)

def step_fn(self, element, step_prefix=""):
"""Perform a step in the DL prediction model training loop.
Args:
element (object):
step_prefix (str): Step type, by default: test, train, val.
"""

dic, labels = element[0], element[1][0]

if isinstance(labels, list):
labels = labels[-1]
data = (
dic["encoder_cat"],
dic["encoder_cont"],
dic["encoder_target"],
dic["encoder_lengths"],
dic["decoder_cat"],
dic["decoder_cont"],
dic["decoder_target"],
dic["decoder_lengths"],
dic["decoder_time_idx"],
dic["groups"],
dic["target_scale"],
)

mask = torch.ones_like(labels).bool()

out = self(data)

# If aux_loss is present, it is returned as a tuple
if len(out) == 2 and isinstance(out, tuple):
out, aux_loss = out
else:
aux_loss = 0
# Get prediction and target

prediction = (
torch.masked_select(out, mask.unsqueeze(-1))
.reshape(-1, out.shape[-1])
.to(self.device)
)

target = torch.masked_select(labels, mask).to(self.device)

if prediction.shape[-1] > 1 and self.run_mode == RunMode.classification:
# Classification task
loss = (
self.loss(
prediction, target.long(), weight=self.loss_weights.to(self.device)
)
+ aux_loss
)
# Returns torch.long because negative log likelihood loss
elif self.run_mode == RunMode.regression:
# Regression task
loss = self.loss(prediction[:, 0], target.float()) + aux_loss
else:
raise ValueError(
f"Run mode {self.run_mode} not yet supported. Please implement it."
)
transformed_output = self.output_transform((prediction, target))

for key, value in self.metrics[step_prefix].items():
if isinstance(value, torchmetrics.Metric):
if key == "Binary_Fairness":
feature_names = key.feature_helper(self.trainer)
value.update(
transformed_output[0],
transformed_output[1],
data,
feature_names,
)
else:
value.update(transformed_output[0], transformed_output[1])
else:
value.update(transformed_output)
self.log(
f"{step_prefix}/loss", loss, on_step=False, on_epoch=True, sync_dist=True
)
return loss

def faithfulness_correlation(self, test_loader, attribution, nr_runs=100, pertrub=None, subset_size=4):
"""
Implementation of faithfulness correlation by Bhatt et al., 2020.
The Faithfulness Correlation metric intend to capture an explanation's relative faithfulness
(or 'fidelity') with respect to the model behaviour.
Faithfulness correlation scores shows to what extent the predicted logits of each modified test point and
the average explanation attribution for only the subset of features are (linearly) correlated, taking the
average over multiple runs and test samples. The metric returns one float per input-attribution pair that
ranges between -1 and 1, where higher scores are better.
For each test sample, |S| features are randomly selected and replace them with baseline values (zero baseline
or average of set). Thereafter, Pearson’s correlation coefficient between the predicted logits of each modified
test point and the average explanation attribution for only the subset of features is calculated. Results is
average over multiple runs and several test samples.
This code is adapted from the quantus libray to suit our use case
References:
1) Umang Bhatt et al.: "Evaluating and aggregating feature-based model
explanations." IJCAI (2020): 3016-3022.
2)Hedström, Anna, et al. "Quantus: An explainable ai toolkit for responsible evaluation of neural network explanations and beyond." Journal of Machine Learning Research 24.34 (2023): 1-11.
"""

if torch.is_tensor(attribution):
# Convert the tensor to a NumPy array
example_numpy_array = attribution.cpu().detach().numpy()
if pertrub == None:
pertrub = "baseline"
similarities = []
for batch in test_loader:

for key, value in batch[0].items():

batch[0][key] = batch[0][key].to(self.device)
x = batch[0]
data = (
x["encoder_cat"],
x["encoder_cont"],
x["encoder_target"],
x["encoder_lengths"],
x["decoder_cat"],
x["decoder_cont"],
x["decoder_target"],
x["decoder_lengths"],
x["decoder_time_idx"],
x["groups"],
x["target_scale"],
)

y_pred = self(data).detach().cpu().numpy()
pred_deltas = []
att_sums = []
for i_ix in range(nr_runs):
# Randomly mask by subset size.
a_ix = np.random.choice(x["encoder_cont"].shape[1], subset_size, replace=False)

# Move a_ix_tensor to the same device as mask

if pertrub == "Noise":
# add normal noise to input
noise = torch.randn_like(x["encoder_cont"])

x["encoder_cont"][:, a_ix, :] += noise[:, a_ix, :]
elif pertrub == "baseline":
# Create a mask tensor with zeros at specified time steps and ones everywhere else
# pytorch bug need to change to cpu for next step and then revert
mask = torch.ones_like(x["encoder_cont"]).cpu()

mask[:, a_ix, :] = 0
mask = mask.to(x["encoder_cont"].device)

x["encoder_cont"] = x["encoder_cont"] * mask

# Predict on perturbed input x.
y_pred_perturb = self(data).detach().cpu().numpy()
print(y_pred - y_pred_perturb)
break
pred_deltas.append((y_pred - y_pred_perturb).mean(axis=(0, 2)))

# Sum attributions of the random subset.

att_sums.append(np.sum(attribution[a_ix]))
print(pred_deltas, att_sums)
correlation_matrix = np.corrcoef(pred_deltas, att_sums, rowvar=False)

# Get the correlation coefficient from the correlation matrix
pearson_correlation = correlation_matrix[0, 1]
similarities.append(pearson_correlation)
print(similarities)
return np.nanmean(similarities)


@gin.configurable("MLWrapper")
class MLWrapper(BaseModule, ABC):
"""Interface for prediction with traditional Scikit-learn-like Machine Learning models."""
Expand Down

0 comments on commit 119dbcf

Please sign in to comment.