Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
youssefmecky96 committed Mar 8, 2024
1 parent 843c7e8 commit fe62ee2
Show file tree
Hide file tree
Showing 5 changed files with 615 additions and 528 deletions.
2 changes: 1 addition & 1 deletion icu_benchmarks/imputation/diffwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def forward(self, input_data):
cond = self.cond_conv(cond)
h += cond

out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])

res = self.res_conv(out)
assert x.shape == res.shape
Expand Down
30 changes: 15 additions & 15 deletions icu_benchmarks/models/custom_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,30 +422,30 @@ def norm_function(arr):
x_original = dataloader.dataset.data["reals"].clone()

dataloader.dataset.add_noise()
x_preturb = dataloader.dataset.data["reals"].clone()
y_pred_preturb = model.model.predict(dataloader)
x_perturb = dataloader.dataset.data["reals"].clone()
y_pred_perturb = model.model.predict(dataloader)
Attention_weights = model.interpertations(dataloader)
att_preturb = Attention_weights["attention"]
att_perturb = Attention_weights["attention"]
# Calculate the absolute difference
difference = torch.abs(y_pred_preturb - y_pred)
difference = torch.abs(y_pred_perturb - y_pred)

# Find where the difference is less than or equal to a thershold
close_indices = torch.nonzero(difference <= thershold).squeeze()[:, 0].to(device)

RIS = relative_stability_objective(
x_original.detach(),
x_preturb.detach(),
x_perturb.detach(),
attribution,
att_preturb,
att_perturb,
close_indices=close_indices,
input=True,
attention=True,
)
ROS = relative_stability_objective(
y_pred,
y_pred_preturb,
y_pred_perturb,
attribution,
att_preturb,
att_perturb,
close_indices=close_indices,
input=False,
attention=True,
Expand All @@ -458,16 +458,16 @@ def norm_function(arr):
with torch.no_grad():
noise = torch.randn_like(x["encoder_cont"]) * 0.01
x["encoder_cont"] += noise
y_pred_preturb = model(model.prep_data(x)).detach()
y_pred_perturb = model(model.prep_data(x)).detach()
if explain_method == "Random":
att_preturb = np.random.normal(size=[64, 24, 53])
att_perturb = np.random.normal(size=[64, 24, 53])

else:
att_preturb, features_attrs, timestep_attrs = model.explantation2(x, explain_method)
att_perturb, features_attrs, timestep_attrs = model.explantation2(x, explain_method)

#
# Calculate the absolute difference
difference = torch.abs(y_pred_preturb - y_pred)
difference = torch.abs(y_pred_perturb - y_pred)

# Find where the difference is less than or equal to a thershold
close_indices = torch.nonzero(difference <= thershold).squeeze()[:, 0].to(device)
Expand All @@ -476,15 +476,15 @@ def norm_function(arr):
x_original.detach(),
x["encoder_cont"].detach(),
attribution,
att_preturb,
att_perturb,
close_indices=close_indices,
input=True,
)
ROS = relative_stability_objective(
y_pred,
y_pred_preturb,
y_pred_perturb,
attribution,
att_preturb,
att_perturb,
close_indices=close_indices,
input=False,
)
Expand Down
Loading

0 comments on commit fe62ee2

Please sign in to comment.