Skip to content

Commit

Permalink
fix merge
Browse files (browse the repository at this point in the history)
  • Loading branch information
derpyplops committed Oct 13, 2023
1 parent a4874e1 commit f3319c1
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
5 changes: 4 additions & 1 deletion elk/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ class AccuracyResult:
"""Lower bound of the confidence interval."""
upper: float
"""Upper bound of the confidence interval."""
cal_thresh: float | None
"""The threshold used to compute the calibrated accuracy."""


def accuracy_ci(
y_true: Tensor,
y_pred: Tensor,
cal_thresh: float | None = None,
*,
num_samples: int = 1000,
level: float = 0.95,
Expand Down Expand Up @@ -79,4 +82,4 @@ def accuracy_ci(
# Compute the point estimate. Call flatten to ensure that we get a single number
# computed across cluster boundaries even if the inputs were clustered.
estimate = y_true.flatten().eq(y_pred.flatten()).float().mean().item()
return AccuracyResult(estimate, lower, upper)
return AccuracyResult(estimate, lower, upper, cal_thresh)
15 changes: 8 additions & 7 deletions elk/metrics/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class EvalResult:
cal_thresh: float | None
"""The threshold used to compute the calibrated accuracy."""

def to_dict(self, prefix: str = "") -> dict[str, float]:
def to_dict(self, prefix: str = "") -> dict[str, float | None]:
"""Convert the result to a dictionary."""
acc_dict = {f"{prefix}acc_{k}": v for k, v in asdict(self.accuracy).items()}
cal_acc_dict = (
Expand Down Expand Up @@ -102,9 +102,9 @@ def calc_calibrated_accuracies(y_true, pos_probs) -> AccuracyResult:
AccuracyResult: The calibrated accuracy, its confidence interval, and the threshold used to compute it.
"""

cal_thresh = pos_probs.float().quantile(y_true.float().mean())
cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item()
cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
cal_acc = accuracy_ci(y_true, cal_preds)
cal_acc = accuracy_ci(y_true, cal_preds, cal_thresh)
return cal_acc


Expand Down Expand Up @@ -196,10 +196,11 @@ def calc_eval_results(
acc = calc_accuracies(y_logits=y_logits, y_true=y_true)

pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])
cal_acc = (
cal_acc, cal_thresh = (
calc_calibrated_accuracies(y_true=y_true, pos_probs=pos_probs)
if num_classes == 2
else None
else None,
None,
)
cal_err = (
calc_calibrated_errors(y_true=y_true, pos_probs=pos_probs)
Expand All @@ -210,11 +211,11 @@ def calc_eval_results(
auroc = calc_auroc(
y_logits=y_logits,
y_true=y_true,
prompt_ensembling=prompt_ensembling,
ensembling=prompt_ensembling,
num_classes=num_classes,
)

return EvalResult(acc, cal_acc, cal_err, auroc)
return EvalResult(acc, cal_acc, cal_err, auroc, cal_thresh)


def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies = [
# We upstreamed bugfixes for Literal types in 0.1.1
"simple-parsing>=0.1.1",
# Version 1.11 introduced Fully Sharded Data Parallel, which we plan to use soon
"torch>=1.11.0",
"torch==2.0",
# Doesn't really matter but versions < 4.0 are very very old (pre-2016)
"tqdm>=4.0.0",
# 4.0 introduced the breaking change of using return_dict=True by default
Expand All @@ -37,7 +37,7 @@ dependencies = [
# For visualization of results
"plotly==5.14.1",
"kaleido==0.2.1",
"rich==13.3.5"
"rich"

]
version = "0.1.1"
Expand Down

0 comments on commit f3319c1

Please sign in to comment.