Merge branch 'main' into ensembling_layer
derpyplops committed Oct 13, 2023
2 parents 9629ba5 + 670eaec commit 45b527f
Showing 11 changed files with 156 additions and 59 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -21,8 +21,7 @@ Our code is based on [PyTorch](http://pytorch.org)
and [Huggingface Transformers](https://huggingface.co/docs/transformers/index). We test the code on Python 3.10 and
3.11.

First install the package with `pip install -e .` in the root directory, or `pip install -e .[dev]` if you'd like to
contribute to the project (see **Development** section below). This should install all the necessary dependencies.
First install the package with `pip install -e .` in the root directory, or `pip install eleuther-elk` to install from PyPi. Use `pip install -e .[dev]` if you'd like to contribute to the project (see **Development** section below). This should install all the necessary dependencies.

To fit reporters for the HuggingFace model `model` and dataset `dataset`, just run:

1 change: 1 addition & 0 deletions comparison-sweeps
Submodule comparison-sweeps added at f4ed88
20 changes: 9 additions & 11 deletions elk/extraction/extraction.py
@@ -79,11 +79,11 @@ class Extract(Serializable):
templates are used."""

layers: tuple[int, ...] = ()
"""Indices of layers to extract hidden states from. We follow the HF convention, so
0 is the embedding, and 1 is the output of the first transformer layer."""
"""Indices of layers to extract hidden states from. We ignore the embedding,
have only the output of the transformer layers."""

layer_stride: InitVar[int] = 1
"""Shortcut for `layers = (0,) + tuple(range(1, num_layers + 1, stride))`."""
"""Shortcut for `tuple(range(1, num_layers, stride))`."""

seed: int = 42
"""Seed to use for prompt randomization. Defaults to 42."""
@@ -134,9 +134,8 @@ def __post_init__(self, layer_stride: int):
config = assert_type(
PretrainedConfig, AutoConfig.from_pretrained(self.model)
)
# Note that we always include 0 which is the embedding layer
layer_range = range(1, config.num_hidden_layers + 1, layer_stride)
self.layers = (0,) + tuple(layer_range)
layer_range = range(1, config.num_hidden_layers, layer_stride)
self.layers = tuple(layer_range)

def explode(self) -> list["Extract"]:
"""Explode this config into a list of configs, one for each layer."""
@@ -195,8 +194,7 @@ def extract_hiddens(
seed=cfg.seed,
)

# Add one to the number of layers to account for the embedding layer
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1))
layer_indices = cfg.layers or tuple(range(1, model.config.num_hidden_layers))

global_max_examples = cfg.max_examples[0 if split_type == "train" else 1]

@@ -331,7 +329,7 @@ def extract_hiddens(
**hidden_dict,
)
if has_lm_preds:
out_record["model_logits"] = lm_logits
out_record["model_logits"] = lm_logits.log_softmax(dim=-1)

num_yielded += 1
yield out_record
@@ -368,13 +366,13 @@ def hidden_features(cfg: Extract) -> tuple[DatasetInfo, Features]:
if num_dropped:
print(f"Dropping {num_dropped} non-multiple choice templates")

layer_indices = cfg.layers or tuple(range(1, model_cfg.num_hidden_layers))
layer_cols = {
f"hidden_{layer}": Array3D(
dtype="int16",
shape=(num_variants, num_classes, model_cfg.hidden_size),
)
# Add 1 to include the embedding layer
for layer in cfg.layers or range(model_cfg.num_hidden_layers + 1)
for layer in layer_indices
}
other_cols = {
"variant_ids": Sequence(
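For intuition on the layer-selection change above, here is a small standalone sketch (not part of the diff) comparing the old and new default layer indices; `num_hidden_layers = 6` and `layer_stride = 1` are made-up example values:

```python
# Illustrative only: the default layer indices in Extract.__post_init__ before
# and after this commit, using assumed example values.
num_hidden_layers = 6
layer_stride = 1

# Old behaviour: always prepend the embedding layer (index 0), then step through
# transformer outputs 1..num_hidden_layers inclusive.
old_layers = (0,) + tuple(range(1, num_hidden_layers + 1, layer_stride))
print(old_layers)  # (0, 1, 2, 3, 4, 5, 6)

# New behaviour: drop the embedding layer; the range also stops before index
# num_hidden_layers, so the final layer's output is no longer included by default.
new_layers = tuple(range(1, num_hidden_layers, layer_stride))
print(new_layers)  # (1, 2, 3, 4, 5)
```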
21 changes: 13 additions & 8 deletions elk/metrics/eval.py
@@ -30,6 +30,8 @@ class EvalResult:
roc_auc: RocAucResult
"""Area under the ROC curve. For multi-class classification, each class is treated
as a one-vs-rest binary classification problem."""
cal_thresh: float | None
"""The threshold used to compute the calibrated accuracy."""

def to_dict(self, prefix: str = "") -> dict[str, float]:
"""Convert the result to a dictionary."""
@@ -45,13 +47,19 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
else {}
)
auroc_dict = {f"{prefix}auroc_{k}": v for k, v in asdict(self.roc_auc).items()}
return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}
return {
**auroc_dict,
**cal_acc_dict,
**acc_dict,
**cal_dict,
f"{prefix}cal_thresh": self.cal_thresh,

[Check failure on line 55 in elk/metrics/eval.py, GitHub Actions / run-tests (3.10, macos-latest): Expression of type "dict[str, Any | float | None]" cannot be assigned to return type "dict[str, float]"; type "None" cannot be assigned to type "float" (reportGeneralTypeIssues)]
}


def calc_auroc(
y_logits: Tensor,
y_true: Tensor,
prompt_ensembling: PromptEnsembling,
ensembling: PromptEnsembling,
num_classes: int,
) -> RocAucResult:
"""
@@ -66,22 +74,21 @@ def calc_auroc(
Returns:
RocAucResult: A dictionary containing the AUROC and confidence interval.
"""
if prompt_ensembling == PromptEnsembling.NONE:
if ensembling == PromptEnsembling.NONE:
auroc = roc_auc_ci(
to_one_hot(y_true, num_classes).long().flatten(1), y_logits.flatten(1)
)
elif prompt_ensembling in (PromptEnsembling.PARTIAL, PromptEnsembling.FULL):
elif ensembling in (PromptEnsembling.PARTIAL, PromptEnsembling.FULL):
# Pool together the negative and positive class logits
if num_classes == 2:
auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0])
else:
auroc = roc_auc_ci(to_one_hot(y_true, num_classes).long(), y_logits)
else:
raise ValueError(f"Unknown mode: {prompt_ensembling}")
raise ValueError(f"Unknown mode: {ensembling}")

return auroc


def calc_calibrated_accuracies(y_true, pos_probs) -> AccuracyResult:
"""
Calculate the calibrated accuracies
@@ -99,7 +106,6 @@ def calc_calibrated_accuracies(y_true, pos_probs) -> AccuracyResult:
cal_acc = accuracy_ci(y_true, cal_preds)
return cal_acc


def calc_calibrated_errors(y_true, pos_probs) -> CalibrationEstimate:
"""
Calculate the expected calibration error.
@@ -116,7 +122,6 @@ def calc_calibrated_errors(y_true, pos_probs) -> CalibrationEstimate:
cal_err = cal.compute()
return cal_err


def calc_accuracies(y_logits, y_true) -> AccuracyResult:
"""
Calculate the accuracy
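The CI annotation above flags that `cal_thresh` is typed `float | None`, so adding it to the returned mapping breaks the declared `dict[str, float]` return type of `to_dict`. Below is a minimal, self-contained sketch of one way this could be reconciled; it is not what the commit does, and the trimmed-down class is hypothetical:

```python
from dataclasses import dataclass


@dataclass
class ResultSketch:
    """Hypothetical stand-in for EvalResult, reduced to the fields needed here."""

    auroc: float
    cal_thresh: float | None = None

    def to_dict(self, prefix: str = "") -> dict[str, float]:
        out = {f"{prefix}auroc": self.auroc}
        # Only emit cal_thresh when it is set, so every value stays a float and
        # the declared return type remains accurate.
        if self.cal_thresh is not None:
            out[f"{prefix}cal_thresh"] = self.cal_thresh
        return out


print(ResultSketch(auroc=0.91).to_dict("eval_"))                  # {'eval_auroc': 0.91}
print(ResultSketch(auroc=0.91, cal_thresh=0.5).to_dict("eval_"))  # adds 'eval_cal_thresh'
```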
16 changes: 8 additions & 8 deletions elk/plotting/visualize.py
@@ -221,26 +221,26 @@ def collect(cls, model_path: Path) -> "ModelVisualization":
model_name = model_path.name
is_transfer = False

def get_train_dirs(model_path):
def get_eval_dirs(model_path):
# toplevel is either repo/dataset or dataset
for toplevel in model_path.iterdir():
if (toplevel / "eval.csv").exists():
yield toplevel
else:
for train_dir in toplevel.iterdir():
yield train_dir
for eval_dir in toplevel.iterdir():
yield eval_dir

for train_dir in get_train_dirs(model_path):
for train_dir in get_eval_dirs(model_path):
eval_df = cls._read_eval_csv(train_dir, train_dir.name, train_dir.name)
df = pd.concat([df, eval_df], ignore_index=True)
transfer_dir = train_dir / "transfer"
if transfer_dir.exists():
is_transfer = True
for eval_ds_dir in transfer_dir.iterdir():
eval_df = cls._read_eval_csv(
eval_ds_dir, eval_ds_dir.name, train_dir.name
for tfr_ds_dir in get_eval_dirs(transfer_dir):
tfr_df = cls._read_eval_csv(
tfr_ds_dir, tfr_ds_dir.name, train_dir.name
)
df = pd.concat([df, eval_df], ignore_index=True)
df = pd.concat([df, tfr_df], ignore_index=True)

df["model_name"] = model_name
return cls(df, model_name, is_transfer)
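For context, a runnable sketch of the two sweep layouts the renamed `get_eval_dirs` helper walks; the directory names are made-up examples, and the key change above is that transfer directories are now traversed with the same helper instead of a flat `iterdir()`:

```python
import tempfile
from pathlib import Path


def get_eval_dirs(model_path: Path):
    # Same traversal as in the diff: yield dataset dirs containing eval.csv
    # directly, otherwise yield their children (repo/dataset layout).
    for toplevel in model_path.iterdir():
        if (toplevel / "eval.csv").exists():
            yield toplevel
        else:
            for eval_dir in toplevel.iterdir():
                yield eval_dir


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    # Hypothetical layout: one flat dataset and one repo/dataset pair.
    (root / "imdb").mkdir()
    (root / "imdb" / "eval.csv").touch()
    (root / "super_glue" / "boolq").mkdir(parents=True)
    (root / "super_glue" / "boolq" / "eval.csv").touch()

    print(sorted(p.name for p in get_eval_dirs(root)))  # ['boolq', 'imdb']
```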
36 changes: 36 additions & 0 deletions elk/training/burns_norm.py
@@ -0,0 +1,36 @@
import torch
from torch import Tensor, nn


class BurnsNorm(nn.Module):
"""Burns et al. style normalization. Minimal changes from the original code."""

def __init__(self, scale: bool = True):
super().__init__()
self.scale: bool = scale

def forward(self, x: Tensor) -> Tensor:
"""Normalizes per prompt template
Args:
x: input of dimension (n, v, c, d) or (n, v, d)
Returns:
x_normalized: normalized output
"""
num_elements = x.shape[0]
x_normalized: Tensor = x - x.mean(dim=0) if num_elements > 1 else x

if not self.scale:
return x_normalized
else:
std = torch.linalg.norm(x_normalized, dim=0) / x_normalized.shape[0] ** 0.5
assert std.dim() == x.dim() - 1

# Compute the dimensions over which
# we want to compute the mean standard deviation
# exclude the first dimension v,
# which is the template dimension
dims = tuple(range(1, std.dim()))

avg_norm = std.mean(dim=dims, keepdim=True)

return x_normalized / avg_norm
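A quick usage sketch for the new `BurnsNorm` module; shapes follow the docstring above, the batch/template/hidden sizes are arbitrary example values, and it assumes the `elk` package is installed so the new module is importable:

```python
import torch

from elk.training.burns_norm import BurnsNorm  # added in this commit

n, v, d = 128, 8, 64  # examples, prompt templates, hidden size (arbitrary)
x = torch.randn(n, v, d)

norm = BurnsNorm()  # scale=True by default
x_normed = norm(x)

print(x_normed.shape)                           # torch.Size([128, 8, 64])
print(x_normed.mean(dim=0).abs().max().item())  # ~0: centred per template
```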
46 changes: 32 additions & 14 deletions elk/training/ccs_reporter.py
@@ -9,9 +9,11 @@
import torch.nn as nn
from concept_erasure import LeaceFitter
from torch import Tensor
from typing_extensions import override

from ..parsing import parse_loss
from ..utils.typing import assert_type
from .burns_norm import BurnsNorm
from .common import FitterConfig
from .losses import LOSSES
from .platt_scaling import PlattMixin
@@ -41,6 +43,7 @@ class CcsConfig(FitterConfig):
function 1.0*consistency_squared + 0.5*prompt_var.
"""
loss_dict: dict[str, float] = field(default_factory=dict, init=False)
norm: Literal["leace", "burns"] = "leace"
num_layers: int = 1
"""The number of layers in the MLP."""
pre_ln: bool = False
@@ -86,6 +89,7 @@ def __init__(
num_variants: int = 1,
):
super().__init__()

self.config = cfg
self.in_features = in_features
self.num_variants = num_variants
@@ -105,6 +109,7 @@ def __init__(
device=device,
),
)

if cfg.pre_ln:
self.probe.insert(0, nn.LayerNorm(in_features, elementwise_affine=False))

@@ -125,6 +130,15 @@ def __init__(
)
)

@override
def parameters(self, recurse=True):
parameters = super(CcsReporter, self).parameters(recurse=recurse)
for param in parameters:
# exclude the platt scaling parameters
# kind of a hack for now; we should probably find a cleaner way
if param is not self.scale and param is not self.bias:
yield param

def reset_parameters(self):
"""Reset the parameters of the probe.
@@ -161,9 +175,9 @@ def reset_parameters(self):
def forward(self, x: Tensor) -> Tensor:
"""Return the credence assigned to the hidden state `x`."""
assert self.norm is not None, "Must call fit() before forward()"

raw_scores = self.probe(self.norm(x)).squeeze(-1)
return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
return platt_scaled_scores

def loss(self, logit0: Tensor, logit1: Tensor) -> Tensor:
"""Return the loss of the reporter on the contrast pair (x0, x1).
@@ -193,18 +207,21 @@ def fit(self, hiddens: Tensor) -> float:
n, v, d = x_neg.shape
prompt_ids = torch.eye(v, device=x_neg.device).expand(n, -1, -1)

fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(
x=x_neg,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
)
fitter.update(
x=x_pos,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
)
self.norm = fitter.eraser
if self.config.norm == "burns":
self.norm = BurnsNorm()
else:
fitter = LeaceFitter(d, 2 * v, dtype=x_neg.dtype, device=x_neg.device)
fitter.update(
x=x_neg,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
)
fitter.update(
x=x_pos,
# Independent indicator for each (template, pseudo-label) pair
z=torch.cat([prompt_ids, torch.zeros_like(prompt_ids)], dim=-1),
)
self.norm = fitter.eraser

x_neg, x_pos = self.norm(x_neg), self.norm(x_pos)

@@ -236,6 +253,7 @@ def fit(self, hiddens: Tensor) -> float:
raise RuntimeError("Got NaN/infinite loss during training")

self.load_state_dict(best_state)

return best_loss

def train_loop_adam(self, x_neg: Tensor, x_pos: Tensor) -> float:
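For intuition on the new `parameters()` override, here is a standalone toy module (not the repo's class) showing the same trick: hide the Platt-scaling `scale` and `bias` from the parameter iterator so the main training loop's optimizer never updates them:

```python
import torch
import torch.nn as nn


class ToyProbe(nn.Module):
    """Toy stand-in for CcsReporter: a linear probe plus Platt-scaling parameters."""

    def __init__(self, d: int):
        super().__init__()
        self.probe = nn.Linear(d, 1)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def parameters(self, recurse: bool = True):
        # Same idea as the override above: skip the Platt-scaling parameters.
        for param in super().parameters(recurse=recurse):
            if param is not self.scale and param is not self.bias:
                yield param


probe = ToyProbe(16)
optim = torch.optim.Adam(probe.parameters())  # sees only the probe's weight and bias
print(sum(p.numel() for p in probe.parameters()))  # 17, not 19
```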
1 change: 0 additions & 1 deletion elk/training/eigen_reporter.py
@@ -147,7 +147,6 @@ def update(self, hiddens: Tensor) -> None:
if self.config.erase_prompts:
# Independent indicator for each (template, pseudo-label) pair
indicators = torch.eye(k * v, device=hiddens.device).expand(n, -1, -1)
self.leace.update(x=hiddens, z=indicators)
else:
# Only use indicators for each pseudo-label
indicators = torch.eye(k, device=hiddens.device).expand(n, v, -1, -1)
13 changes: 4 additions & 9 deletions elk/training/sweep.py
@@ -1,12 +1,11 @@
from dataclasses import InitVar, dataclass, replace
from dataclasses import InitVar, dataclass, field, replace

import numpy as np
import torch
from datasets import get_dataset_config_info
from transformers import AutoConfig

from ..evaluation import Eval
from ..extraction import Extract
from ..files import memorably_named_dir, sweeps_dir
from ..plotting.visualize import visualize_sweep
from ..training.eigen_reporter import EigenFitterConfig
@@ -22,7 +21,8 @@ def assert_models_exist(model_names):

def assert_datasets_exist(dataset_names):
for dataset_name in dataset_names:
get_dataset_config_info(dataset_name)
ds_name, _, config_name = dataset_name.partition(":")
get_dataset_config_info(ds_name, config_name=config_name)


@dataclass
@@ -52,12 +52,7 @@ class Sweep:
name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
run_template: Elicit = Elicit(
data=Extract(
model="<placeholder>",
datasets=("<placeholder>",),
)
)
run_template: Elicit = field(default_factory=Elicit.default)

def __post_init__(self, add_pooled: bool):
if not self.datasets:
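A tiny illustration of the new `dataset_name.partition(":")` convention in `assert_datasets_exist`; the dataset names below are made-up examples:

```python
# Illustrative only: how "dataset" vs "dataset:config" strings are split before
# being passed to get_dataset_config_info(ds_name, config_name=config_name).
for dataset_name in ("imdb", "super_glue:boolq"):
    ds_name, _, config_name = dataset_name.partition(":")
    print(repr(ds_name), repr(config_name))
# 'imdb' ''
# 'super_glue' 'boolq'
```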