Skip to content

Commit

Permalink
Confusion matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater committed Sep 17, 2024
1 parent b9aad2b commit 32d1707
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion icu_benchmarks/models/custom_metrics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import logging

import torch
from typing import Callable
import numpy as np
from ignite.metrics import EpochMetric
from sklearn.metrics import balanced_accuracy_score, mean_absolute_error
from numpy import ndarray
from sklearn.metrics import balanced_accuracy_score, mean_absolute_error, confusion_matrix as sk_confusion_matrix
from sklearn.calibration import calibration_curve
from scipy.spatial.distance import jensenshannon
from torchmetrics.classification import BinaryFairness
Expand Down Expand Up @@ -130,3 +133,18 @@ def feature_helper(self, trainer, step_prefix):
else:
feature_names = trainer.test_dataloaders.dataset.features
return feature_names

def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize=False) -> torch.tensor:
y_pred = np.rint(y_pred).astype(int)
confusion = sk_confusion_matrix(y_true, y_pred)
if normalize:
confusion = confusion / confusion.sum()
confusion_tensor = torch.tensor(confusion)
# confusion = confusion.tolist()
confusion_dict = {}
for i in range(confusion.shape[0]):
for j in range(confusion.shape[1]):
confusion_dict[f"class_{i}_pred_{j}"] = confusion[i][j]
# logging.info(f"Confusion matrix: {confusion_dict}")
# dict = {"TP": confusion[0][0], "FP": confusion[0][1], "FN": confusion[1][0], "TN": confusion[1][1]}
return confusion_dict

0 comments on commit 32d1707

Please sign in to comment.