Add option to save probabilities #289

Closed
wants to merge 6 commits into from
50 changes: 41 additions & 9 deletions elk/evaluation/evaluate.py
@@ -7,7 +7,7 @@
from simple_parsing.helpers import field

from ..files import elk_reporter_dir
from ..metrics import evaluate_preds
from ..metrics import evaluate_preds, get_logprobs
from ..run import Run
from ..utils import Color

@@ -31,7 +31,7 @@ def execute(self, highlight_color: Color = "cyan"):
@torch.inference_mode()
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> dict[str, pd.DataFrame]:
) -> tuple[dict[str, pd.DataFrame], dict]:
Collaborator review comment:
Maybe you could do something similar to what is done here: https://github.com/EleutherAI/elk/pull/259/files#diff-d13b83b80dc8fe2ae73e22669dd7a1a3167a1ae731d341fa96f03a766d877933R37 🟢

instead of having tuple[dict[str, pd.DataFrame], dict].

But we can also leave it for now; once we merge our pull request it will be changed anyway.
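
A minimal sketch of what that alternative could look like, assuming a small dataclass container (the name LayerApplied and its fields are hypothetical, not taken from PR #259):

# Hypothetical sketch only: a named container instead of an anonymous tuple.
from dataclasses import dataclass, field

import pandas as pd


@dataclass
class LayerApplied:
    dfs: dict[str, pd.DataFrame]  # per-table DataFrames of eval rows, keyed by table name
    logprobs: dict = field(default_factory=dict)  # per-dataset log-prob payload, empty unless save_logprobs is set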

"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)
val_output = self.prepare_data(device, layer, "val")
@@ -41,45 +41,77 @@ def apply_to_layer(
reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
reporter = torch.load(reporter_path, map_location=device)

out_logprobs = defaultdict(dict)
row_bufs = defaultdict(list)
for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items():
for ds_name, val_data in val_output.items():
meta = {"dataset": ds_name, "layer": layer}

val_credences = reporter(val_h)
if self.save_logprobs:
out_logprobs[ds_name]["texts"] = val_data.text_questions
out_logprobs[ds_name]["labels"] = val_data.labels.cpu()
out_logprobs[ds_name]["reporter"] = dict()
out_logprobs[ds_name]["lr"] = dict()
out_logprobs[ds_name]["lm"] = dict()

val_credences = reporter(val_data.hiddens)
for mode in ("none", "partial", "full"):
if self.save_logprobs:
out_logprobs[ds_name]["reporter"][mode] = get_logprobs(
val_credences, mode
).cpu()
out_logprobs[ds_name]["lm"][mode] = (
get_logprobs(val_data.lm_preds, mode).cpu()
if val_data.lm_preds is not None
else None
)

row_bufs["eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_credences, mode).to_dict(),
**evaluate_preds(
val_data.labels, val_credences, mode
).to_dict(),
}
)

if val_lm_preds is not None:
if val_data.lm_preds is not None:
row_bufs["lm_eval"].append(
{
**meta,
"ensembling": mode,
**evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
**evaluate_preds(
val_data.labels, val_data.lm_preds, mode
).to_dict(),
}
)

lr_dir = experiment_dir / "lr_models"
if not self.skip_supervised and lr_dir.exists():
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode] = dict()

with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
lr_models = torch.load(f, map_location=device)
if not isinstance(lr_models, list): # backward compatibility
lr_models = [lr_models]

for i, model in enumerate(lr_models):
model.eval()
val_lr_credences = model(val_data.hiddens)
if self.save_logprobs:
out_logprobs[ds_name]["lr"][mode][i] = get_logprobs(
val_lr_credences, mode
).cpu()
row_bufs["lr_eval"].append(
{
"ensembling": mode,
"inlp_iter": i,
**meta,
**evaluate_preds(val_gt, model(val_h), mode).to_dict(),
**evaluate_preds(
val_data.labels, val_lr_credences, mode
).to_dict(),
}
)

return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs
3 changes: 2 additions & 1 deletion elk/metrics/__init__.py
@@ -1,6 +1,6 @@
from .accuracy import accuracy_ci
from .calibration import CalibrationError, CalibrationEstimate
from .eval import EvalResult, evaluate_preds, to_one_hot
from .eval import EvalResult, evaluate_preds, get_logprobs, to_one_hot
from .roc_auc import RocAucResult, roc_auc, roc_auc_ci

__all__ = [
@@ -9,6 +9,7 @@
"CalibrationEstimate",
"EvalResult",
"evaluate_preds",
"get_logprobs",
"roc_auc",
"roc_auc_ci",
"to_one_hot",
27 changes: 27 additions & 0 deletions elk/metrics/eval.py
@@ -2,6 +2,7 @@
from typing import Literal

import torch
import torch.nn.functional as F
from einops import repeat
from torch import Tensor

@@ -41,6 +42,32 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}


def get_logprobs(
y_logits: Tensor, ensembling: Literal["none", "partial", "full"] = "none"
) -> Tensor:
"""
Get the log probabilities from a tensor of logits.

Args:
y_logits: Predicted class tensor of shape (N, variants, 2).
ensembling: The ensembling mode: "none", "partial", or "full".

Returns:
Tensor: If ensembling is "none", a tensor of shape (N, variants, 2).
If ensembling is "partial", a tensor of shape (N, variants).
If ensembling is "full", a tensor of shape (N,).
"""
assert y_logits.shape[-1] == 2, "Save probs only supported for binary labels"
if ensembling == "none":
return F.logsigmoid(y_logits)
elif ensembling == "partial":
return F.logsigmoid(y_logits[:, :, 1] - y_logits[:, :, 0])
elif ensembling == "full":
y_logits = y_logits.mean(dim=1)
return F.logsigmoid(y_logits[:, 1] - y_logits[:, 0])
else:
raise ValueError(f"Unknown mode: {ensembling}")


def evaluate_preds(
y_true: Tensor,
y_logits: Tensor,
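
As a usage sketch for get_logprobs above (the random input and the sizes N=4, variants=3 below are arbitrary, chosen only to illustrate the output shapes):

import torch

from elk.metrics import get_logprobs

y_logits = torch.randn(4, 3, 2)                 # (N, variants, 2) binary logits
print(get_logprobs(y_logits, "none").shape)     # torch.Size([4, 3, 2]): elementwise log-sigmoid
print(get_logprobs(y_logits, "partial").shape)  # torch.Size([4, 3]): log P(class 1) per variant
print(get_logprobs(y_logits, "full").shape)     # torch.Size([4]): logits averaged over variants first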
57 changes: 49 additions & 8 deletions elk/run.py
@@ -31,6 +31,14 @@
)


@dataclass
class LayerData:
hiddens: Tensor
labels: Tensor
lm_preds: Tensor | None
text_questions: list[list[tuple[str, str]]] # (n, v, 2)


@dataclass
class Run(ABC, Serializable):
data: Extract
@@ -52,6 +60,13 @@ class Run(ABC, Serializable):
num_gpus: int = -1
out_dir: Path | None = None
disable_cache: bool = field(default=False, to_dict=False)
save_logprobs: bool = field(default=False, to_dict=False)
"""Saves logprobs.pt containing {<dsname>: {"texts": [n, v, 2], "labels": [n,],
"lm": {"none": [n, v, 2], "partial": [n, v], "full": [n,]},
"reporter": {<layer>: {"none": [n, v, 2], "partial": [n, v], "full": [n,]}},
"lr": {<layer>: {"none"|"partial"|"full": {<inlp_iter>: ...}}}
}}
"""

def execute(
self,
@@ -98,20 +113,19 @@ def execute(

devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
num_devices = len(devices)
func: Callable[[int], dict[str, pd.DataFrame]] = partial(
func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]] = partial(
self.apply_to_layer, devices=devices, world_size=num_devices
)
self.apply_to_layers(func=func, num_devices=num_devices)

@abstractmethod
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
) -> dict[str, pd.DataFrame]:
) -> tuple[dict[str, pd.DataFrame], dict]:
"""Train or eval a reporter on a single layer."""

def make_reproducible(self, seed: int):
"""Make the run reproducible by setting the random seed."""

np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
@@ -125,7 +139,7 @@ def get_device(self, devices, world_size: int) -> str:

def prepare_data(
self, device: str, layer: int, split_type: Literal["train", "val"]
) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
) -> dict[str, LayerData]:
"""Prepare data for the specified layer and split type."""
out = {}

@@ -140,9 +154,13 @@

with split.formatted_as("torch", device=device):
has_preds = "model_logits" in split.features
lm_preds = split["model_logits"] if has_preds else None
lm_preds = (
assert_type(Tensor, split["model_logits"]) if has_preds else None
)

text_questions = split["text_questions"]

out[ds_name] = (hiddens, labels.to(hiddens.device), lm_preds)
out[ds_name] = LayerData(hiddens, labels, lm_preds, text_questions)

return out

@@ -155,7 +173,7 @@ def concatenate(self, layers):

def apply_to_layers(
self,
func: Callable[[int], dict[str, pd.DataFrame]],
func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]],
num_devices: int,
):
"""Apply a function to each layer of the datasets in parallel
@@ -178,15 +196,38 @@
with ctx.Pool(num_devices) as pool:
mapper = pool.imap_unordered if num_devices > 1 else map
df_buffers = defaultdict(list)
logprobs_dicts = defaultdict(dict)

try:
for df_dict in tqdm(mapper(func, layers), total=len(layers)):
for layer, (df_dict, logprobs_dict) in tqdm(
zip(layers, mapper(func, layers)), total=len(layers)
):
for k, v in df_dict.items():
df_buffers[k].append(v)
for k, v in logprobs_dict.items():
logprobs_dicts[k][layer] = logprobs_dict[k]
finally:
# Make sure the CSVs are written even if we crash or get interrupted
for name, dfs in df_buffers.items():
df = pd.concat(dfs).sort_values(by=["layer", "ensembling"])
df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
if self.debug:
save_debug_log(self.datasets, self.out_dir)
if self.save_logprobs:
save_dict = defaultdict(dict)
for ds_name, logprobs_dict in logprobs_dicts.items():
save_dict[ds_name]["texts"] = logprobs_dict[layers[0]]["texts"]
save_dict[ds_name]["labels"] = logprobs_dict[layers[0]][
"labels"
]
save_dict[ds_name]["lm"] = logprobs_dict[layers[0]]["lm"]
save_dict[ds_name]["reporter"] = dict()
save_dict[ds_name]["lr"] = dict()
for layer, logprobs_dict_by_mode in logprobs_dict.items():
save_dict[ds_name]["reporter"][
layer
] = logprobs_dict_by_mode["reporter"]
save_dict[ds_name]["lr"][layer] = logprobs_dict_by_mode[
"lr"
]
torch.save(save_dict, self.out_dir / "logprobs.pt")
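
For reference, the nested structure written to logprobs.pt above could be read back roughly as follows (the path, dataset name "imdb", layer 12, and INLP iteration 0 are placeholders, not values produced by this code):

import torch

out = torch.load("out_dir/logprobs.pt")
ds = out["imdb"]
texts = ds["texts"]                         # [n, v, 2] nested lists of question/answer text
labels = ds["labels"]                       # tensor of shape (n,)
lm_full = ds["lm"]["full"]                  # tensor of shape (n,), or None if no LM logits were stored
reporter_full = ds["reporter"][12]["full"]  # reporter log-probs at layer 12, fully ensembled, shape (n,)
lr_none = ds["lr"][12]["none"][0]           # LR log-probs: layer 12, no ensembling, INLP iteration 0
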
3 changes: 2 additions & 1 deletion elk/training/classifier.py
@@ -63,7 +63,7 @@ def fit(
x: Tensor,
y: Tensor,
*,
l2_penalty: float = 0.0,
l2_penalty: float = 0.001,
max_iter: int = 10_000,
) -> float:
"""Fits the model to the input data using L-BFGS with L2 regularization.
@@ -180,6 +180,7 @@ def fit_cv(

# Refit with the best penalty
best_penalty = l2_penalties[best_idx]
print(f"Best L2 penalty: {best_penalty}")
self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter)
return RegularizationPath(l2_penalties, mean_losses.tolist())

11 changes: 6 additions & 5 deletions elk/training/supervised.py
@@ -2,19 +2,20 @@
from einops import rearrange, repeat

from ..metrics import to_one_hot
from ..run import LayerData
from .classifier import Classifier


def train_supervised(
data: dict[str, tuple], device: str, mode: str
data: dict[str, LayerData], device: str, mode: str
) -> list[Classifier]:
Xs, train_labels = [], []

for train_h, labels, _ in data.values():
(_, v, k, _) = train_h.shape
train_h = rearrange(train_h, "n v k d -> (n v k) d")
for train_data in data.values():
(_, v, k, _) = train_data.hiddens.shape
train_h = rearrange(train_data.hiddens, "n v k d -> (n v k) d")

labels = repeat(labels, "n -> (n v)", v=v)
labels = repeat(train_data.labels, "n -> (n v)", v=v)
labels = to_one_hot(labels, k).flatten()

Xs.append(train_h)