Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensembling over layers #259

Open
wants to merge 79 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
934cd54
Log ensembled metrics
norabelrose Apr 26, 2023
dff69bf
Fixing pyright version
norabelrose Apr 26, 2023
b181d3e
Merge remote-tracking branch 'origin/main' into ensembling
norabelrose Apr 26, 2023
a493b85
experiment with layer ensembling
lauritowal Apr 29, 2023
af5def6
add draft example for ensembling datasets
lauritowal Apr 30, 2023
04a2a82
add comment
lauritowal Apr 30, 2023
f433885
Merge branch 'main' into ensembling_layer
lauritofzi Apr 30, 2023
cda7de7
add eval in comments
lauritowal Apr 30, 2023
c9f2558
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Apr 30, 2023
0ceaa3a
add different root
lauritofzi May 1, 2023
86fb1c8
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritofzi May 1, 2023
47a3f60
Merge branch 'main' into ensembling_layer
lauritowal May 24, 2023
0bd274f
add empty list of vals
lauritowal May 27, 2023
04f0b4c
Merge branch 'main' into ensembling_layer
lauritowal Jun 16, 2023
994af9b
add first version of layer ensembling to eval
lauritowal Jun 17, 2023
6ca1916
add vals to train
lauritowal Jun 19, 2023
b0d0f83
refactoring & cleanup of eval and layer ensembling
lauritowal Jun 19, 2023
241a03a
add annotations
lauritowal Jun 19, 2023
e8d042a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2023
a4ace25
rename vals to layer_outputs
lauritowal Jun 19, 2023
b025c71
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jun 19, 2023
2156ad8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2023
e391da6
fir formatting
lauritowal Jun 19, 2023
449971f
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jun 19, 2023
528367d
make layer ensembling work on multiple gpus
lauritowal Jun 21, 2023
d4df517
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2023
2661ea1
make sure we use the same device
lauritowal Jun 21, 2023
043aa7a
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jun 21, 2023
21cccb7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 21, 2023
2495c3a
calc layer ensembling for all prompt ensembling modes
lauritowal Jun 21, 2023
69af43c
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jun 21, 2023
908308b
implement ensembling enum
derpyplops Jun 23, 2023
fc980d7
Fix ensembling value writing error
derpyplops Jun 23, 2023
5aa30a9
accidentally a print
derpyplops Jun 23, 2023
d5b8584
slightly refactor layer stuff and fix tests
derpyplops Jun 23, 2023
6380814
try fixing type hints
derpyplops Jun 23, 2023
98d19b7
tidy up output
derpyplops Jun 23, 2023
e6914e1
accidentally a char
derpyplops Jun 23, 2023
29b1cb8
rename to PromptEnsembling
lauritowal Jun 24, 2023
421590c
Merge branch 'main' into ensembling_layer
lauritowal Jul 9, 2023
bed615a
add annotations and types
lauritowal Jul 9, 2023
bf49e99
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2023
1f5d8be
clearer naming: prompt_ensembling
lauritowal Jul 9, 2023
03a37d2
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jul 12, 2023
ec37716
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2023
1936624
better name for ensembling
lauritowal Jul 12, 2023
b243932
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jul 12, 2023
484788e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2023
8c34797
Merge branch 'main' into ensembling_layer
lauritowal Jul 12, 2023
ea5e9e8
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jul 12, 2023
c0545aa
Merge branch 'main' into ensembling_layer
lauritowal Jul 13, 2023
b6de957
remove pseudo auroc
lauritowal Jul 13, 2023
cf32b0c
rename to prompt_ensembling
lauritowal Jul 13, 2023
769676a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 13, 2023
e6c9d4c
precomit fixes
lauritowal Jul 13, 2023
8093294
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jul 13, 2023
6028152
fix num_classes
lauritowal Jul 18, 2023
6d7d99a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2023
4148857
fix bug where y_true has a dimension of two
lauritowal Jul 18, 2023
964f03d
cleanup
lauritowal Jul 18, 2023
f7ed262
Merge branch 'ensembling_layer' of https://github.com/EleutherAI/elk …
lauritowal Jul 18, 2023
06dad69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2023
0d2545b
add y_true_initial
lauritowal Jul 18, 2023
4a717ce
merge
lauritowal Jul 18, 2023
5952b4b
fix test error
lauritowal Jul 18, 2023
7efe38f
Merge branch 'main' into ensembling_layer
lauritowal Jul 22, 2023
049cd63
replace mode with prompt_ensembling.value
lauritowal Jul 22, 2023
c8236dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2023
56d1796
remove .value for lm_eval
lauritowal Jul 23, 2023
4d9c781
add LayerApplied
derpyplops Jul 27, 2023
8961e95
fix run.py part
derpyplops Jul 27, 2023
bd06cd3
multidataset layer ensembling
derpyplops Jul 27, 2023
f8882c6
little refactoring
derpyplops Jul 27, 2023
23183bc
fix tests
derpyplops Jul 27, 2023
d091f9d
add annotation + cleanup
lauritowal Jul 31, 2023
776c186
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2023
9629ba5
Merge pull request #282 from EleutherAI/fix-ensembling-jon
derpyplops Jul 31, 2023
45b527f
Merge branch 'main' into ensembling_layer
derpyplops Oct 13, 2023
64e762a
[pre-commit.ci] auto fixes from pre-commit.com hooks
derpyplops Oct 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..metrics import evaluate_preds
from ..run import Run
from ..utils import Color
from ..utils.types import PromptEnsembling


@dataclass(kw_only=True)
Expand All @@ -31,7 +32,7 @@ def execute(self, highlight_color: Color = "cyan"):
@torch.inference_mode()
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> dict[str, pd.DataFrame]:
) -> tuple[dict[str, pd.DataFrame], list[dict]]:
lauritowal marked this conversation as resolved.
Show resolved Hide resolved
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)
val_output = self.prepare_data(device, layer, "val")
Expand All @@ -42,16 +43,23 @@ def apply_to_layer(
reporter = torch.load(reporter_path, map_location=device)

row_bufs = defaultdict(list)

layer_outputs = []
for ds_name, (val_h, val_gt, _) in val_output.items():
meta = {"dataset": ds_name, "layer": layer}

val_credences = reporter(val_h)
for mode in ("none", "partial", "full"):
layer_outputs.append(
{**meta, "val_gt": val_gt, "val_credences": val_credences}
)
for prompt_ensembling in PromptEnsembling.all():
row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
"prompt_ensembling": prompt_ensembling.value,
**evaluate_preds(
val_gt, val_credences, prompt_ensembling
).to_dict(),
}
)

Expand All @@ -66,11 +74,13 @@ def apply_to_layer(
model.eval()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"prompt_ensembling": prompt_ensembling.value,
"inlp_iter": i,
**meta,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
**evaluate_preds(
val_gt, model(val_h), prompt_ensembling
).to_dict(),
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
return ({k: pd.DataFrame(v) for k, v in row_bufs.items()}, layer_outputs)
197 changes: 167 additions & 30 deletions elk/metrics/eval.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import asdict, dataclass
from typing import Literal

import torch
from einops import repeat
from torch import Tensor

from ..utils.types import PromptEnsembling
from .accuracy import AccuracyResult, accuracy_ci
from .calibration import CalibrationError, CalibrationEstimate
from .roc_auc import RocAucResult, roc_auc_ci
Expand Down Expand Up @@ -41,59 +41,196 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}


def calc_auroc(
y_logits: Tensor,
y_true: Tensor,
prompt_ensembling: PromptEnsembling,
num_classes: int,
) -> RocAucResult:
"""
Calculate the AUROC

Args:
y_true: Ground truth tensor of shape (n,).
y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
prompt_ensembling: The prompt_ensembling mode.
num_classes: The number of classes.

Returns:
RocAucResult: A dictionary containing the AUROC and confidence interval.
"""
if prompt_ensembling == PromptEnsembling.NONE:
auroc = roc_auc_ci(
to_one_hot(y_true, num_classes).long().flatten(1), y_logits.flatten(1)
)
elif prompt_ensembling in (PromptEnsembling.PARTIAL, PromptEnsembling.FULL):
# Pool together the negative and positive class logits
if num_classes == 2:
auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0])
else:
derpyplops marked this conversation as resolved.
Show resolved Hide resolved
auroc = roc_auc_ci(to_one_hot(y_true, num_classes).long(), y_logits)
else:
raise ValueError(f"Unknown mode: {prompt_ensembling}")

return auroc


def calc_calibrated_accuracies(y_true, pos_probs) -> AccuracyResult:
"""
Calculate the calibrated accuracies

Args:
y_true: Ground truth tensor of shape (n,).
pos_probs: Predicted class tensor of shape (n, num_variants, num_classes).

Returns:
AccuracyResult: A dictionary containing the accuracy and confidence interval.
"""

cal_thresh = pos_probs.float().quantile(y_true.float().mean())
cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
cal_acc = accuracy_ci(y_true, cal_preds)
return cal_acc


def calc_calibrated_errors(y_true, pos_probs) -> CalibrationEstimate:
"""
Calculate the expected calibration error.

Args:
y_true: Ground truth tensor of shape (n,).
y_logits: Predicted class tensor of shape (n, num_variants, num_classes).

Returns:
CalibrationEstimate:
"""

cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
cal_err = cal.compute()
return cal_err


def calc_accuracies(y_logits, y_true) -> AccuracyResult:
"""
Calculate the accuracy

Args:
y_true: Ground truth tensor of shape (n,).
y_logits: Predicted class tensor of shape (n, num_variants, num_classes).

Returns:
AccuracyResult: A dictionary containing the accuracy and confidence interval.
"""
y_pred = y_logits.argmax(dim=-1)
return accuracy_ci(y_true, y_pred)


def evaluate_preds(
y_true: Tensor,
y_logits: Tensor,
ensembling: Literal["none", "partial", "full"] = "none",
prompt_ensembling: PromptEnsembling = PromptEnsembling.NONE,
) -> EvalResult:
"""
Evaluate the performance of a classification model.

Args:
y_true: Ground truth tensor of shape (N,).
y_logits: Predicted class tensor of shape (N, variants, n_classes).
y_true: Ground truth tensor of shape (n,).
y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
prompt_ensembling: The prompt_ensembling mode.

Returns:
dict: A dictionary containing the accuracy, AUROC, and ECE.
"""
(n, v, c) = y_logits.shape
(n, num_variants, num_classes) = y_logits.shape
assert y_true.shape == (n,)

if ensembling == "full":
if prompt_ensembling == PromptEnsembling.FULL:
y_logits = y_logits.mean(dim=1)
else:
y_true = repeat(y_true, "n -> n v", v=v)

y_pred = y_logits.argmax(dim=-1)
if ensembling == "none":
auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1))
elif ensembling in ("partial", "full"):
# Pool together the negative and positive class logits
if c == 2:
auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0])
else:
auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits)
else:
raise ValueError(f"Unknown mode: {ensembling}")
y_true = repeat(y_true, "n -> n v", v=num_variants)
return calc_eval_results(y_true, y_logits, prompt_ensembling, num_classes)

acc = accuracy_ci(y_true, y_pred)
cal_acc = None
cal_err = None

if c == 2:
pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])
def calc_eval_results(
y_true: Tensor,
y_logits: Tensor,
prompt_ensembling: PromptEnsembling,
num_classes: int,
) -> EvalResult:
"""
Calculate the evaluation results

# Calibrated accuracy
cal_thresh = pos_probs.float().quantile(y_true.float().mean())
cal_preds = pos_probs.gt(cal_thresh).to(torch.int)
cal_acc = accuracy_ci(y_true, cal_preds)
Args:
y_true: Ground truth tensor of shape (n,).
y_logits: Predicted class tensor of shape (n, num_variants, num_classes).
prompt_ensembling: The prompt_ensembling mode.

cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten())
cal_err = cal.compute()
Returns:
EvalResult: The result of evaluating a classifier containing the accuracy,
calibrated accuracies, calibrated errors, and AUROC.
"""
acc = calc_accuracies(y_logits=y_logits, y_true=y_true)

pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0])
cal_acc = (
calc_calibrated_accuracies(y_true=y_true, pos_probs=pos_probs)
if num_classes == 2
else None
)
cal_err = (
calc_calibrated_errors(y_true=y_true, pos_probs=pos_probs)
if num_classes == 2
else None
)

auroc = calc_auroc(
y_logits=y_logits,
y_true=y_true,
prompt_ensembling=prompt_ensembling,
num_classes=num_classes,
)

return EvalResult(acc, cal_acc, cal_err, auroc)


def layer_ensembling(
layer_outputs: list, prompt_ensembling: PromptEnsembling
) -> EvalResult:
"""
Return EvalResult after prompt_ensembling the probe output of the middle to last layers

Args:
layer_outputs: A list of dictionaries containing the ground truth and
predicted class tensor of shape (n, num_variants, num_classes).
prompt_ensembling: The prompt_ensembling mode.

Returns:
EvalResult: The result of evaluating a classifier containing the accuracy,
calibrated accuracies, calibrated errors, and AUROC.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
y_logits_means = []
y_true = layer_outputs[0][0]["val_gt"].to(device)
lauritowal marked this conversation as resolved.
Show resolved Hide resolved

for layer_output in layer_outputs:
y_logits = layer_output[0]["val_credences"].to(device)
y_logits_means.append(y_logits.mean(dim=1)) # full prompt_ensembling

num_classes = layer_outputs[0][0]["val_credences"].shape[2]
# get logits and ground_truth from middle to last layer
middle_index = len(layer_outputs) // 2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in some ways I think we should allow the layers over which we ensemble to be configurable. E.g. sometimes the last layers perform worse.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, it makes sense to make it configurable. However, I'm curious, how would you decide which layers to pick?

y_logits_stacked = torch.stack(y_logits_means[middle_index:])
# layer prompt_ensembling of the stacked logits
y_logits_stacked_mean = torch.mean(y_logits_stacked, dim=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the ensembling is done by taking the mean over layers, rather than concatenating. This isn't super clear from comments/docstrings, and hard to tell from reading the code because the shapes aren't commented.


return calc_eval_results(
y_true=y_true,
y_logits=y_logits_stacked_mean,
prompt_ensembling=prompt_ensembling,
num_classes=num_classes,
)


def to_one_hot(labels: Tensor, n_classes: int) -> Tensor:
"""
Convert a tensor of class labels to a one-hot representation.
Expand Down
Loading