diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index 1caaf9c6..bf6ab4f9 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -272,15 +272,15 @@ def train_common( if explain: explanintations = model.explantation_captum( - test_loader, log_dir, IntegratedGradients, target=(0, 1) + test_loader, log_dir, IntegratedGradients, target=0 ) print("IG", explanintations) interperations = model.interpertations(test_loader, log_dir) print("attention", interperations) if XAI_metric: - batch = iter(test_loader).next() - + batch = next(iter(test_loader)) + model = model.cpu() for key, value in batch[0].items(): batch[0][key] = batch[0][key].to(model.device) @@ -316,7 +316,7 @@ def train_common( # Reformat attributions. attr, delta = explantation.attribute( data, - target=(0, 1), + target=0, return_convergence_delta=True, baselines=baselines, n_steps=20,