diff --git a/icu_benchmarks/models/custom_metrics.py b/icu_benchmarks/models/custom_metrics.py
index ddb5d37e..75648466 100644
--- a/icu_benchmarks/models/custom_metrics.py
+++ b/icu_benchmarks/models/custom_metrics.py
@@ -1,8 +1,9 @@
 import torch
 from typing import Callable
 import numpy as np
 from ignite.metrics import EpochMetric
-from sklearn.metrics import balanced_accuracy_score, mean_absolute_error
+from numpy import ndarray
+from sklearn.metrics import balanced_accuracy_score, mean_absolute_error, confusion_matrix as sk_confusion_matrix
 from sklearn.calibration import calibration_curve
 from scipy.spatial.distance import jensenshannon
 from torchmetrics.classification import BinaryFairness
@@ -130,3 +131,27 @@ def feature_helper(self, trainer, step_prefix):
     else:
         feature_names = trainer.test_dataloaders.dataset.features
     return feature_names
+
+
+def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize: bool = False) -> dict:
+    """Compute a confusion matrix and flatten it into a per-cell metric dict.
+
+    Args:
+        y_true: Ground-truth integer class labels.
+        y_pred: Predicted labels or scores; rounded to the nearest integer class.
+        normalize: If True, divide every cell by the total count so entries sum to 1.
+
+    Returns:
+        Dict mapping ``"class_{i}_pred_{j}"`` to the count (or fraction, when
+        normalized) of samples with true class ``i`` predicted as class ``j``.
+    """
+    # Scores/probabilities are rounded to the nearest integer class label.
+    y_pred = np.rint(y_pred).astype(int)
+    confusion = sk_confusion_matrix(y_true, y_pred)
+    if normalize:
+        confusion = confusion / confusion.sum()
+    return {
+        f"class_{i}_pred_{j}": confusion[i][j]
+        for i in range(confusion.shape[0])
+        for j in range(confusion.shape[1])
+    }