Skip to content

Commit

Permalink
changes for quantus
Browse files Browse the repository at this point in the history
  • Loading branch information
youssefmecky96 committed Oct 25, 2023
1 parent d42bd73 commit 46eea23
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions icu_benchmarks/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,15 @@ def train_common(

if explain:
explanintations = model.explantation_captum(
test_loader, log_dir, IntegratedGradients, target=(0, 1)
test_loader, log_dir, IntegratedGradients, target=0
)

print("IG", explanintations)
interperations = model.interpertations(test_loader, log_dir)
print("attention", interperations)
if XAI_metric:
batch = iter(test_loader).next()

batch = next(iter(test_loader))
model = model.cpu()
for key, value in batch[0].items():
batch[0][key] = batch[0][key].to(model.device)

Expand Down Expand Up @@ -316,7 +316,7 @@ def train_common(
# Reformat attributions.
attr, delta = explantation.attribute(
data,
target=(0, 1),
target=0,
return_convergence_delta=True,
baselines=baselines,
n_steps=20,
Expand Down

0 comments on commit 46eea23

Please sign in to comment.