Commit 100635c: save logprobs
AlexTMallen committed Nov 10, 2023 (1 parent: 2deefc3)

Showing 11 changed files with 298 additions and 95 deletions.

ccs/debug_logging.py (5 additions, 1 deletion)
@@ -31,7 +31,11 @@ def save_debug_log(datasets: list[DatasetDictWithName], out_dir: Path) -> None:
         else:
             train_split, val_split = select_train_val_splits(ds)
 
-        text_questions = ds[val_split][0]["text_questions"]
+        if len(ds[val_split]) == 0:
+            logging.warning(f"Val split '{val_split}' is empty!")
+            continue
+
+        text_questions = ds[val_split][0]["texts"]
         template_ids = ds[val_split][0]["variant_ids"]
         label = ds[val_split][0]["label"]

ccs/evaluation/evaluate.py (42 additions, 9 deletions)
@@ -7,7 +7,7 @@
 from simple_parsing.helpers import field
 
 from ..files import ccs_reporter_dir
-from ..metrics import evaluate_preds
+from ..metrics import evaluate_preds, get_logprobs
 from ..run import Run
 from ..utils import Color

@@ -31,7 +31,7 @@ def execute(self, highlight_color: Color = "cyan"):
     @torch.inference_mode()
     def apply_to_layer(
         self, layer: int, devices: list[str], world_size: int
-    ) -> dict[str, pd.DataFrame]:
+    ) -> tuple[dict[str, pd.DataFrame], dict]:
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)
         val_output = self.prepare_data(device, layer, "val")
@@ -41,28 +41,51 @@ def apply_to_layer(
         reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
         reporter = torch.load(reporter_path, map_location=device)
 
+        out_logprobs = defaultdict(dict)
         row_bufs = defaultdict(list)
-        for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items():
+        for ds_name, val_data in val_output.items():
             meta = {"dataset": ds_name, "layer": layer}
+            if self.save_logprobs:
+                out_logprobs[ds_name] = dict(
+                    row_ids=val_data.row_ids.cpu(),
+                    variant_ids=val_data.variant_ids,
+                    texts=val_data.texts,
+                    labels=val_data.labels.cpu(),
+                    lm=dict(),
+                    lr=dict(),
+                    reporter=dict(),
+                )
 
-            val_credences = reporter(val_h)
+            val_credences = reporter(val_data.hiddens)
             for mode in ("none", "partial", "full"):
                 row_bufs["eval"].append(
                     {
                         **meta,
                         "ensembling": mode,
-                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                        **evaluate_preds(
+                            val_data.labels, val_credences, mode
+                        ).to_dict(),
                     }
                 )
+                if self.save_logprobs:
+                    out_logprobs[ds_name]["reporter"][mode] = (
+                        get_logprobs(val_credences, mode).detach().cpu()
+                    )
 
-                if val_lm_preds is not None:
+                if val_data.lm_preds is not None:
                     row_bufs["lm_eval"].append(
                         {
                             **meta,
                             "ensembling": mode,
-                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
+                            **evaluate_preds(
+                                val_data.labels, val_data.lm_preds, mode
+                            ).to_dict(),
                         }
                     )
+                    if self.save_logprobs:
+                        out_logprobs[ds_name]["lm"][mode] = get_logprobs(
+                            val_data.lm_preds, mode
+                        ).cpu()
 
                 lr_dir = experiment_dir / "lr_models"
                 if not self.skip_supervised and lr_dir.exists():
@@ -71,15 +94,25 @@
                     if not isinstance(lr_models, list):  # backward compatibility
                         lr_models = [lr_models]
 
+                    if self.save_logprobs:
+                        out_logprobs[ds_name]["lr"][mode] = dict()
+
                     for i, model in enumerate(lr_models):
                         model.eval()
+                        val_credences = model(val_data.hiddens)
+                        if self.save_logprobs:
+                            out_logprobs[ds_name]["lr"][mode][i] = get_logprobs(
+                                val_credences, mode
+                            ).cpu()
                         row_bufs["lr_eval"].append(
                             {
                                 "ensembling": mode,
                                 "inlp_iter": i,
                                 **meta,
-                                **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
+                                **evaluate_preds(
+                                    val_data.labels, val_credences, mode
+                                ).to_dict(),
                             }
                         )
 
-        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
+        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs
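
For orientation, here is a minimal sketch of how the per-layer `out_logprobs` dict returned above could be consumed. The file path is hypothetical, and it assumes the caller serializes the dict with `torch.save` (the saving side lives in one of the files not shown):

```python
import torch

# Hypothetical path; use whatever the evaluation run actually writes per layer.
logprobs = torch.load("eval_output/logprobs_layer_12.pt")

for ds_name, data in logprobs.items():
    print(ds_name, data["row_ids"].shape, data["labels"].shape)
    for mode in ("none", "partial", "full"):
        # Reporter logprobs: (n, v) for "none"/"partial", (n,) for "full"
        print(mode, data["reporter"][mode].shape)
        # LR logprobs are further keyed by INLP iteration
        for inlp_iter, lr_logprobs in data["lr"][mode].items():
            print(f"lr[{inlp_iter}]", lr_logprobs.shape)
```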
ccs/extraction/extraction.py (18 additions, 6 deletions)
@@ -78,6 +78,9 @@ class Extract(Serializable):
"""The number of prompt templates to use for each example. If -1, all available
templates are used."""

balance: bool = True
"""Whether to balance the number of examples per class."""

layers: tuple[int, ...] = ()
"""Indices of layers to extract hidden states from. We ignore the embedding,
have only the output of the transformer layers."""
@@ -189,6 +192,7 @@ def extract_hiddens(
         num_shots=cfg.num_shots,
         split_type=split_type,
         template_path=cfg.template_path,
+        balance=cfg.balance,
         rank=rank,
         world_size=world_size,
         seed=cfg.seed,
@@ -229,7 +233,7 @@
             )
             for layer_idx in layer_indices
         }
-        lm_logits = torch.empty(
+        lm_log_odds = torch.empty(
             num_variants,
             num_choices,
             device=device,
@@ -289,7 +293,12 @@

                 # Compute the log probability of the answer tokens if available
                 if has_lm_preds:
-                    lm_logits[i, j] = -assert_type(Tensor, outputs.loss)
+                    logprob = -assert_type(Tensor, outputs.loss)
+                    # Convert logprob to logodds to be consistent with reporters.
+                    # Because we went through logprobs, logodds corresponding to
+                    # probs near 1 will be somewhat imprecise:
+                    # log(p/(1-p)) = log(p) - log(1-p) = logp - log(1 - exp(logp))
+                    lm_log_odds[i, j] = logprob - torch.log1p(-logprob.exp())
 
                 hiddens = (
                     outputs.get("decoder_hidden_states") or outputs["hidden_states"]
@@ -323,14 +332,16 @@ def extract_hiddens(
             continue
 
         out_record: dict[str, Any] = dict(
+            row_id=example["row_id"],
             label=example["label"],
             variant_ids=example["template_names"],
-            text_questions=text_questions,
+            texts=text_questions,
             **hidden_dict,
         )
         if has_lm_preds:
-            out_record["model_logits"] = lm_logits.log_softmax(dim=-1)
+            out_record["lm_log_odds"] = lm_log_odds.log_softmax(dim=-1)
 
+        assert out_record["variant_ids"] == sorted(out_record["variant_ids"])
         num_yielded += 1
         yield out_record

@@ -375,12 +386,13 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]:
         for layer in layer_indices
     }
     other_cols = {
+        "row_id": Value(dtype="int64"),
         "variant_ids": Sequence(
             Value(dtype="string"),
             length=num_variants,
         ),
         "label": Value(dtype="int64"),
-        "text_questions": Sequence(
+        "texts": Sequence(
             Sequence(
                 Value(dtype="string"),
             ),
@@ -390,7 +402,7 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]:

     # Only add model_logits if the model is an autoregressive model
     if is_autoregressive(model_cfg, not cfg.use_encoder_states):
-        other_cols["model_logits"] = Array2D(
+        other_cols["lm_log_odds"] = Array2D(
             shape=(num_variants, num_classes),
             dtype="float32",
         )
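
As a standalone check (not part of the commit), here is the logprob-to-log-odds identity used above, along with the precision loss near p = 1 that the in-code comment warns about:

```python
import torch

p = torch.tensor(0.9)
logprob = p.log()
# log(p / (1 - p)) = log(p) - log(1 - p) = logprob - log1p(-exp(logprob))
log_odds = logprob - torch.log1p(-logprob.exp())
print(log_odds.item(), torch.logit(p).item())  # both ~= 2.1972

# Near p = 1, exp(logprob) rounds to exactly 1.0 in float32, so
# log1p(-1.0) = -inf and the computed log-odds overflow to inf,
# even though the true log-odds are only about 20.7.
logprob = torch.tensor(-1e-9)
print(logprob - torch.log1p(-logprob.exp()))  # tensor(inf)
```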
ccs/extraction/prompt_loading.py (13 additions, 4 deletions)
@@ -21,6 +21,7 @@ def load_prompts(
     seed: int = 42,
     split_type: Literal["train", "val"] = "train",
     template_path: str | None = None,
+    balance: bool = True,
     rank: int = 0,
     world_size: int = 1,
 ) -> Iterator[dict]:
@@ -45,7 +46,14 @@
     ds_dict = assert_type(dict, load_dataset(ds_name, config_name or None))
     split_name = select_split(ds_dict, split_type)
 
-    ds = assert_type(Dataset, ds_dict[split_name].shuffle(seed=seed))
+    ds = assert_type(Dataset, ds_dict[split_name])
+
+    if "row_id" not in ds.column_names:
+        ds = ds.add_column("row_id", range(len(ds)))  # type: ignore
+    else:
+        print("Found `row_id` column, using it as the example id")
+    ds = ds.shuffle(seed=seed)
+
     if world_size > 1:
         ds = ds.shard(world_size, rank)

@@ -89,14 +97,14 @@
     else:
         fewshot_iter = None
 
-    if label_column in ds.features:
+    if label_column in ds.features and balance:
         ds = BalancedSampler(
             ds.to_iterable_dataset(),
             set(label_choices),
             label_col=label_column,
         )
     else:
-        if rank == 0:
+        if rank == 0 and balance:
             print("No label column found, not balancing")
         ds = ds.to_iterable_dataset()
@@ -123,7 +131,7 @@ def _convert_to_prompts(
 ) -> dict[str, Any]:
     """Prompt-generating function to pass to `IterableDataset.map`."""
     prompts = []
-    templates = list(prompter.templates.values())
+    templates = sorted(list(prompter.templates.values()), key=lambda t: t.name)
 
     def qa_cat(q: str, a: str) -> str:
         # if the jinja template already adds whitespace, don't add more
@@ -182,6 +190,7 @@ def qa_cat(q: str, a: str) -> str:
     # If they're not, we need to convert them with index(). label_choices is guaranteed
     # to be sorted (see above).
     return dict(
+        row_id=example["row_id"],
         label=label_choices.index(label),
         prompts=prompts,
         template_names=[template.name for template in templates],
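
A small sketch of the new `row_id` bookkeeping (not part of the commit), using a toy Hugging Face dataset:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 0]})
if "row_id" not in ds.column_names:
    ds = ds.add_column("row_id", list(range(len(ds))))
ds = ds.shuffle(seed=42)
# row_id still identifies each example's original row after shuffling,
# so saved logprobs can later be joined back to the source data.
print(ds["row_id"])
```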
ccs/metrics/__init__.py (2 additions, 1 deletion)
@@ -1,6 +1,6 @@
 from .accuracy import accuracy_ci
 from .calibration import CalibrationError, CalibrationEstimate
-from .eval import EvalResult, evaluate_preds, to_one_hot
+from .eval import EvalResult, evaluate_preds, get_logprobs, to_one_hot
 from .roc_auc import RocAucResult, roc_auc, roc_auc_ci
 
 __all__ = [
__all__ = [
@@ -9,6 +9,7 @@
"CalibrationEstimate",
"EvalResult",
"evaluate_preds",
"get_logprobs",
"roc_auc",
"roc_auc_ci",
"to_one_hot",
ccs/metrics/eval.py (24 additions, 0 deletions)
@@ -2,6 +2,7 @@
 from typing import Any, Literal
 
 import torch
+import torch.nn.functional as F
 from einops import repeat
 from torch import Tensor

@@ -49,6 +50,29 @@ def to_dict(self, prefix: str = "") -> dict[str, Any]:
         }
 
 
+def get_logprobs(
+    y_logits: Tensor, ensembling: Literal["none", "partial", "full"] = "none"
+) -> Tensor:
+    """Get the class log-probabilities from a tensor of logits.
+
+    Args:
+        y_logits: Predicted log-odds of the positive class, tensor of shape (n, v, c).
+
+    Returns:
+        Logprobs of shape (n, v) if `ensembling` is "none" or "partial",
+        and of shape (n,) if `ensembling` is "full".
+    """
+    assert y_logits.shape[-1] == 2, "Logits must be binary."
+    if ensembling == "full":
+        y_logits = y_logits.mean(dim=1)
+
+    y_logits = (
+        y_logits[..., 1]
+        if ensembling == "none"
+        else y_logits[..., 1] - y_logits[..., 0]
+    )
+    return F.logsigmoid(y_logits)
+
+
 def evaluate_preds(
     y_true: Tensor,
     y_logits: Tensor,
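
A usage sketch for the new `get_logprobs` (not part of the commit), with n = 4 examples, v = 3 prompt variants, and c = 2 classes:

```python
import torch

from ccs.metrics import get_logprobs

y_logits = torch.randn(4, 3, 2)
print(get_logprobs(y_logits, "none").shape)     # (4, 3): logsigmoid of the positive logit
print(get_logprobs(y_logits, "partial").shape)  # (4, 3): per-variant log-odds difference
print(get_logprobs(y_logits, "full").shape)     # (4,): variants averaged before the difference
```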
ccs/promptsource/templates.py (1 addition, 1 deletion)
@@ -134,7 +134,7 @@ def get_fixed_answer_choices_list(self):
         else:
             return None
 
-    def apply(self, example, truncate=True, highlight_variables=False):
+    def apply(self, example, truncate=False, highlight_variables=False):
         """
         Creates a prompt by applying this template to an example
(4 more changed files not shown.)
