From f3319c1e00de02b0b604b7ee51d75b58b23b8d9b Mon Sep 17 00:00:00 2001 From: jon Date: Fri, 13 Oct 2023 16:35:04 +0000 Subject: [PATCH] fix merge --- elk/metrics/accuracy.py | 5 ++++- elk/metrics/eval.py | 15 ++++++++------- pyproject.toml | 4 ++-- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/elk/metrics/accuracy.py b/elk/metrics/accuracy.py index 33b946321..2f9685f8b 100644 --- a/elk/metrics/accuracy.py +++ b/elk/metrics/accuracy.py @@ -14,11 +14,14 @@ class AccuracyResult: """Lower bound of the confidence interval.""" upper: float """Upper bound of the confidence interval.""" + cal_thresh: float | None + """The threshold used to compute the calibrated accuracy.""" def accuracy_ci( y_true: Tensor, y_pred: Tensor, + cal_thresh: float | None = None, *, num_samples: int = 1000, level: float = 0.95, @@ -79,4 +82,4 @@ def accuracy_ci( # Compute the point estimate. Call flatten to ensure that we get a single number # computed across cluster boundaries even if the inputs were clustered. estimate = y_true.flatten().eq(y_pred.flatten()).float().mean().item() - return AccuracyResult(estimate, lower, upper) + return AccuracyResult(estimate, lower, upper, cal_thresh) diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py index 0cdfb032d..c77dd09e9 100644 --- a/elk/metrics/eval.py +++ b/elk/metrics/eval.py @@ -33,7 +33,7 @@ class EvalResult: cal_thresh: float | None """The threshold used to compute the calibrated accuracy.""" - def to_dict(self, prefix: str = "") -> dict[str, float]: + def to_dict(self, prefix: str = "") -> dict[str, float | None]: """Convert the result to a dictionary.""" acc_dict = {f"{prefix}acc_{k}": v for k, v in asdict(self.accuracy).items()} cal_acc_dict = ( @@ -102,9 +102,9 @@ def calc_calibrated_accuracies(y_true, pos_probs) -> AccuracyResult: AccuracyResult: A dictionary containing the accuracy and confidence interval. """ - cal_thresh = pos_probs.float().quantile(y_true.float().mean()) + cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item() cal_preds = pos_probs.gt(cal_thresh).to(torch.int) - cal_acc = accuracy_ci(y_true, cal_preds) + cal_acc = accuracy_ci(y_true, cal_preds, cal_thresh) return cal_acc @@ -196,10 +196,11 @@ def calc_eval_results( acc = calc_accuracies(y_logits=y_logits, y_true=y_true) pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0]) - cal_acc = ( + cal_acc, cal_thresh = ( calc_calibrated_accuracies(y_true=y_true, pos_probs=pos_probs) if num_classes == 2 - else None + else None, + None, ) cal_err = ( calc_calibrated_errors(y_true=y_true, pos_probs=pos_probs) @@ -210,11 +211,11 @@ def calc_eval_results( auroc = calc_auroc( y_logits=y_logits, y_true=y_true, - prompt_ensembling=prompt_ensembling, + ensembling=prompt_ensembling, num_classes=num_classes, ) - return EvalResult(acc, cal_acc, cal_err, auroc) + return EvalResult(acc, cal_acc, cal_err, auroc, cal_thresh) def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: diff --git a/pyproject.toml b/pyproject.toml index f3f16504a..b0a078cde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ # We upstreamed bugfixes for Literal types in 0.1.1 "simple-parsing>=0.1.1", # Version 1.11 introduced Fully Sharded Data Parallel, which we plan to use soon - "torch>=1.11.0", + "torch==2.0", # Doesn't really matter but versions < 4.0 are very very old (pre-2016) "tqdm>=4.0.0", # 4.0 introduced the breaking change of using return_dict=True by default @@ -37,7 +37,7 @@ dependencies = [ # For visualization of results "plotly==5.14.1", "kaleido==0.2.1", - "rich==13.3.5" + "rich" ] version = "0.1.1"