diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index 4e6ec59f3..253d30aa5 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -7,7 +7,7 @@
 from simple_parsing.helpers import field
 
 from ..files import elk_reporter_dir
-from ..metrics import evaluate_preds
+from ..metrics import evaluate_preds, get_logprobs
 from ..run import Run
 from ..utils import Color
 
@@ -31,7 +31,7 @@ def execute(self, highlight_color: Color = "cyan"):
     @torch.inference_mode()
     def apply_to_layer(
         self, layer: int, devices: list[str], world_size: int
-    ) -> dict[str, pd.DataFrame]:
+    ) -> tuple[dict[str, pd.DataFrame], dict]:
         """Evaluate a single reporter on a single layer."""
         device = self.get_device(devices, world_size)
         val_output = self.prepare_data(device, layer, "val")
@@ -41,31 +41,56 @@ def apply_to_layer(
         reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
         reporter = torch.load(reporter_path, map_location=device)
 
+        out_logprobs = defaultdict(dict)
         row_bufs = defaultdict(list)
-        for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items():
+        for ds_name, val_data in val_output.items():
             meta = {"dataset": ds_name, "layer": layer}
 
-            val_credences = reporter(val_h)
+            if self.save_logprobs:
+                out_logprobs[ds_name]["texts"] = val_data.text_questions
+                out_logprobs[ds_name]["labels"] = val_data.labels.cpu()
+                out_logprobs[ds_name]["reporter"] = dict()
+                out_logprobs[ds_name]["lr"] = dict()
+                out_logprobs[ds_name]["lm"] = dict()
+
+            val_credences = reporter(val_data.hiddens)
             for mode in ("none", "partial", "full"):
+                if self.save_logprobs:
+                    out_logprobs[ds_name]["reporter"][mode] = get_logprobs(
+                        val_credences, mode
+                    ).cpu()
+                    out_logprobs[ds_name]["lm"][mode] = (
+                        get_logprobs(val_data.lm_preds, mode).cpu()
+                        if val_data.lm_preds is not None
+                        else None
+                    )
+
                 row_bufs["eval"].append(
                     {
                         **meta,
                         "ensembling": mode,
-                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                        **evaluate_preds(
+                            val_data.labels, val_credences, mode
+                        ).to_dict(),
                     }
                 )
 
-                if val_lm_preds is not None:
+                if val_data.lm_preds is not None:
                     row_bufs["lm_eval"].append(
                         {
                             **meta,
                             "ensembling": mode,
-                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
+                            **evaluate_preds(
+                                val_data.labels, val_data.lm_preds, mode
+                            ).to_dict(),
                         }
                     )
 
                 lr_dir = experiment_dir / "lr_models"
                 if not self.skip_supervised and lr_dir.exists():
+                    if self.save_logprobs:
+                        out_logprobs[ds_name]["lr"][mode] = dict()
+
                     with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
                         lr_models = torch.load(f, map_location=device)
                         if not isinstance(lr_models, list):  # backward compatibility
@@ -73,13 +98,20 @@ def apply_to_layer(
 
                     for i, model in enumerate(lr_models):
                         model.eval()
+                        val_lr_credences = model(val_data.hiddens)
+                        if self.save_logprobs:
+                            out_logprobs[ds_name]["lr"][mode][i] = get_logprobs(
+                                val_lr_credences, mode
+                            ).cpu()
                         row_bufs["lr_eval"].append(
                             {
                                 "ensembling": mode,
                                 "inlp_iter": i,
                                 **meta,
-                                **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
+                                **evaluate_preds(
+                                    val_data.labels, val_lr_credences, mode
+                                ).to_dict(),
                             }
                         )
 
-        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
+        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs
diff --git a/elk/metrics/__init__.py b/elk/metrics/__init__.py
index 7fb214501..b07f67ff2 100644
--- a/elk/metrics/__init__.py
+++ b/elk/metrics/__init__.py
@@ -1,6 +1,6 @@
 from .accuracy import accuracy_ci
 from .calibration import CalibrationError, CalibrationEstimate
-from .eval import EvalResult, evaluate_preds, to_one_hot
+from .eval import EvalResult, evaluate_preds, get_logprobs, to_one_hot
 from .roc_auc import RocAucResult, roc_auc, roc_auc_ci
 
 __all__ = [
@@ -9,6 +9,7 @@
     "CalibrationEstimate",
     "EvalResult",
     "evaluate_preds",
+    "get_logprobs",
     "roc_auc",
     "roc_auc_ci",
     "to_one_hot",
diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py
index 653beae55..319b95fdb 100644
--- a/elk/metrics/eval.py
+++ b/elk/metrics/eval.py
@@ -2,6 +2,7 @@
 from typing import Literal
 
 import torch
+import torch.nn.functional as F
 from einops import repeat
 from torch import Tensor
 
@@ -41,6 +42,32 @@ def to_dict(self, prefix: str = "") -> dict[str, float]:
         return {**auroc_dict, **cal_acc_dict, **acc_dict, **cal_dict}
 
 
+def get_logprobs(
+    y_logits: Tensor, ensembling: Literal["none", "partial", "full"] = "none"
+) -> Tensor:
+    """
+    Get the log probabilities from a tensor of logits.
+
+    Args:
+        y_logits: Predicted class tensor of shape (N, variants, 2).
+        ensembling: The ensembling mode used to aggregate the logits.
+
+    Returns:
+        Tensor: If ensembling is "none", a tensor of shape (N, variants, 2).
+            If ensembling is "partial", a tensor of shape (N, variants).
+            If ensembling is "full", a tensor of shape (N,).
+    """
+    assert y_logits.shape[-1] == 2, "Saving logprobs only supported for binary labels"
+    if ensembling == "none":
+        return F.logsigmoid(y_logits)
+    elif ensembling == "partial":
+        return F.logsigmoid(y_logits[:, :, 1] - y_logits[:, :, 0])
+    elif ensembling == "full":
+        y_logits = y_logits.mean(dim=1)
+        return F.logsigmoid(y_logits[:, 1] - y_logits[:, 0])
+    else:
+        raise ValueError(f"Unknown mode: {ensembling}")
+
+
 def evaluate_preds(
     y_true: Tensor,
     y_logits: Tensor,
diff --git a/elk/run.py b/elk/run.py
index fb8903ccf..0a621eb21 100644
--- a/elk/run.py
+++ b/elk/run.py
@@ -31,6 +31,14 @@
 )
 
 
+@dataclass
+class LayerData:
+    hiddens: Tensor
+    labels: Tensor
+    lm_preds: Tensor | None
+    text_questions: list[list[tuple[str, str]]]  # (n, v, 2)
+
+
 @dataclass
 class Run(ABC, Serializable):
     data: Extract
@@ -52,6 +60,13 @@ class Run(ABC, Serializable):
     num_gpus: int = -1
     out_dir: Path | None = None
     disable_cache: bool = field(default=False, to_dict=False)
+    save_logprobs: bool = field(default=False, to_dict=False)
+    """Saves logprobs.pt containing
+    {<dataset>: {"texts": [n, v, 2], "labels": [n,],
+        "lm": {"none": [n, v, 2], "partial": [n, v], "full": [n,]},
+        "reporter": {<layer>: {"none": [n, v, 2], "partial": [n, v], "full": [n,]}},
+        "lr": {<layer>: {<inlp_iter>: {"none": ..., "partial": ..., "full": ...}}}}}
+    """
 
     def execute(
         self,
@@ -98,7 +113,7 @@ def execute(
         devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
         num_devices = len(devices)
 
-        func: Callable[[int], dict[str, pd.DataFrame]] = partial(
+        func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]] = partial(
            self.apply_to_layer, devices=devices, world_size=num_devices
         )
         self.apply_to_layers(func=func, num_devices=num_devices)
@@ -106,12 +121,11 @@ def execute(
     @abstractmethod
     def apply_to_layer(
         self, layer: int, devices: list[str], world_size: int
-    ) -> dict[str, pd.DataFrame]:
+    ) -> tuple[dict[str, pd.DataFrame], dict]:
         """Train or eval a reporter on a single layer."""
 
     def make_reproducible(self, seed: int):
         """Make the run reproducible by setting the random seed."""
-        np.random.seed(seed)
         random.seed(seed)
         torch.manual_seed(seed)
 
@@ -125,7 +139,7 @@ def get_device(self, devices, world_size: int) -> str:
 
     def prepare_data(
         self, device: str, layer: int, split_type: Literal["train", "val"]
-    ) -> dict[str, tuple[Tensor, Tensor, Tensor | None]]:
+    ) -> dict[str, LayerData]:
         """Prepare data for the specified layer and split type."""
         out = {}
 
@@ -140,9 +154,13 @@ def prepare_data(
 
             with split.formatted_as("torch", device=device):
                 has_preds = "model_logits" in split.features
-                lm_preds = split["model_logits"] if has_preds else None
+                lm_preds = (
+                    assert_type(Tensor, split["model_logits"]) if has_preds else None
+                )
+
+                text_questions = split["text_questions"]
 
-                out[ds_name] = (hiddens, labels.to(hiddens.device), lm_preds)
+                out[ds_name] = LayerData(hiddens, labels, lm_preds, text_questions)
 
         return out
 
@@ -155,7 +173,7 @@ def concatenate(self, layers):
 
     def apply_to_layers(
         self,
-        func: Callable[[int], dict[str, pd.DataFrame]],
+        func: Callable[[int], tuple[dict[str, pd.DataFrame], dict]],
         num_devices: int,
     ):
         """Apply a function to each layer of the datasets in parallel
@@ -178,11 +196,16 @@ def apply_to_layers(
         with ctx.Pool(num_devices) as pool:
             mapper = pool.imap_unordered if num_devices > 1 else map
             df_buffers = defaultdict(list)
+            logprobs_dicts = defaultdict(dict)
 
             try:
-                for df_dict in tqdm(mapper(func, layers), total=len(layers)):
+                for layer, (df_dict, logprobs_dict) in tqdm(
+                    zip(layers, mapper(func, layers)), total=len(layers)
+                ):
                     for k, v in df_dict.items():
                         df_buffers[k].append(v)
+                    for k, v in logprobs_dict.items():
+                        logprobs_dicts[k][layer] = logprobs_dict[k]
             finally:
                 # Make sure the CSVs are written even if we crash or get interrupted
                 for name, dfs in df_buffers.items():
@@ -190,3 +213,21 @@ def apply_to_layers(
                     df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False)
                 if self.debug:
                     save_debug_log(self.datasets, self.out_dir)
+                if self.save_logprobs:
+                    save_dict = defaultdict(dict)
+                    for ds_name, logprobs_dict in logprobs_dicts.items():
+                        save_dict[ds_name]["texts"] = logprobs_dict[layers[0]]["texts"]
+                        save_dict[ds_name]["labels"] = logprobs_dict[layers[0]][
+                            "labels"
+                        ]
+                        save_dict[ds_name]["lm"] = logprobs_dict[layers[0]]["lm"]
+                        save_dict[ds_name]["reporter"] = dict()
+                        save_dict[ds_name]["lr"] = dict()
+                        for layer, logprobs_dict_by_mode in logprobs_dict.items():
+                            save_dict[ds_name]["reporter"][
+                                layer
+                            ] = logprobs_dict_by_mode["reporter"]
+                            save_dict[ds_name]["lr"][layer] = logprobs_dict_by_mode[
+                                "lr"
+                            ]
+                    torch.save(save_dict, self.out_dir / "logprobs.pt")
diff --git a/elk/training/classifier.py b/elk/training/classifier.py
index 148da939f..0f4f25bb3 100644
--- a/elk/training/classifier.py
+++ b/elk/training/classifier.py
@@ -63,7 +63,7 @@ def fit(
         x: Tensor,
         y: Tensor,
         *,
-        l2_penalty: float = 0.0,
+        l2_penalty: float = 0.001,
         max_iter: int = 10_000,
     ) -> float:
         """Fits the model to the input data using L-BFGS with L2 regularization.
@@ -180,6 +180,7 @@ def fit_cv(
 
         # Refit with the best penalty
         best_penalty = l2_penalties[best_idx]
+        print(f"Best L2 penalty: {best_penalty}")
         self.fit(x, y, l2_penalty=best_penalty, max_iter=max_iter)
         return RegularizationPath(l2_penalties, mean_losses.tolist())
diff --git a/elk/training/supervised.py b/elk/training/supervised.py
index d2eef5f7f..e629a84e8 100644
--- a/elk/training/supervised.py
+++ b/elk/training/supervised.py
@@ -2,19 +2,20 @@
 from einops import rearrange, repeat
 
 from ..metrics import to_one_hot
+from ..run import LayerData
 from .classifier import Classifier
 
 
 def train_supervised(
-    data: dict[str, tuple], device: str, mode: str
+    data: dict[str, LayerData], device: str, mode: str
 ) -> list[Classifier]:
     Xs, train_labels = [], []
 
-    for train_h, labels, _ in data.values():
-        (_, v, k, _) = train_h.shape
-        train_h = rearrange(train_h, "n v k d -> (n v k) d")
+    for train_data in data.values():
+        (_, v, k, _) = train_data.hiddens.shape
+        train_h = rearrange(train_data.hiddens, "n v k d -> (n v k) d")
 
-        labels = repeat(labels, "n -> (n v)", v=v)
+        labels = repeat(train_data.labels, "n -> (n v)", v=v)
         labels = to_one_hot(labels, k).flatten()
 
         Xs.append(train_h)
diff --git a/elk/training/train.py b/elk/training/train.py
index a7f0ef079..2d3072499 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -11,7 +11,7 @@
 from simple_parsing import subgroups
 from simple_parsing.helpers.serialization import save
 
-from ..metrics import evaluate_preds, to_one_hot
+from ..metrics import evaluate_preds, get_logprobs, to_one_hot
 from ..run import Run
 from ..training.supervised import train_supervised
 from ..utils.typing import assert_type
@@ -53,7 +53,7 @@
         layer: int,
         devices: list[str],
         world_size: int,
-    ) -> dict[str, pd.DataFrame]:
+    ) -> tuple[dict[str, pd.DataFrame], dict]:
         """Train a single reporter on a single layer."""
 
         self.make_reproducible(seed=self.net.seed + layer)
@@ -62,16 +62,16 @@ def apply_to_layer(
         train_dict = self.prepare_data(device, layer, "train")
         val_dict = self.prepare_data(device, layer, "val")
 
-        (first_train_h, train_gt, _), *rest = train_dict.values()
-        (_, v, k, d) = first_train_h.shape
-        if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
+        first_train_data, *rest = train_dict.values()
+        (_, v, k, d) = first_train_data.hiddens.shape
+        if not all(other_data.hiddens.shape[-1] == d for other_data in rest):
             raise ValueError("All datasets must have the same hidden state size")
 
         # For a while we did support datasets with different numbers of classes, but
         # we reverted this once we switched to ConceptEraser. There are a few options
         # for re-enabling it in the future but they are somewhat complex and it's not
         # clear that it's worth it.
-        if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
+        if not all(other_data.hiddens.shape[-2] == k for other_data in rest):
             raise ValueError("All datasets must have the same number of classes")
 
         reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
@@ -81,13 +81,15 @@ def apply_to_layer(
             assert len(train_dict) == 1, "CCS only supports single-task training"
 
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
-            train_loss = reporter.fit(first_train_h)
+            train_loss = reporter.fit(first_train_data.hiddens)
 
             if not self.net.norm == "burns":
-                (_, v, k, _) = first_train_h.shape
+                (_, v, k, _) = first_train_data.hiddens.shape
                 reporter.platt_scale(
-                    to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
-                    rearrange(first_train_h, "n v k d -> (n v k) d"),
+                    to_one_hot(
+                        repeat(first_train_data.labels, "n -> (n v)", v=v), k
+                    ).flatten(),
+                    rearrange(first_train_data.hiddens, "n v k d -> (n v k) d"),
                 )
 
         elif isinstance(self.net, EigenFitterConfig):
@@ -96,16 +98,20 @@ def apply_to_layer(
             )
 
             hidden_list, label_list = [], []
-            for ds_name, (train_h, train_gt, _) in train_dict.items():
-                (_, v, _, _) = train_h.shape
+            for ds_name, train_data in train_dict.items():
+                (_, v, _, _) = train_data.hiddens.shape
 
                 # Datasets can have different numbers of variants, so we need to
                 # flatten them here before concatenating
-                hidden_list.append(rearrange(train_h, "n v k d -> (n v k) d"))
+                hidden_list.append(
+                    rearrange(train_data.hiddens, "n v k d -> (n v k) d")
+                )
                 label_list.append(
-                    to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten()
+                    to_one_hot(
+                        repeat(train_data.labels, "n -> (n v)", v=v), k
+                    ).flatten()
                 )
-                fitter.update(train_h)
+                fitter.update(train_data.hiddens)
 
             reporter = fitter.fit_streaming()
             reporter.platt_scale(
@@ -131,19 +137,36 @@ def apply_to_layer(
             lr_models = []
 
         row_bufs = defaultdict(list)
+        out_logprobs = defaultdict(dict)
        for ds_name in val_dict:
-            val_h, val_gt, val_lm_preds = val_dict[ds_name]
-            train_h, train_gt, train_lm_preds = train_dict[ds_name]
+            val, train = val_dict[ds_name], train_dict[ds_name]
             meta = {"dataset": ds_name, "layer": layer}
 
-            val_credences = reporter(val_h)
-            train_credences = reporter(train_h)
+            if self.save_logprobs:
+                out_logprobs[ds_name]["texts"] = val.text_questions
+                out_logprobs[ds_name]["labels"] = val.labels.cpu()
+                out_logprobs[ds_name]["reporter"] = dict()
+                out_logprobs[ds_name]["lr"] = dict()
+                out_logprobs[ds_name]["lm"] = dict()
+
+            val_credences = reporter(val.hiddens)
+            train_credences = reporter(train.hiddens)
             for mode in ("none", "partial", "full"):
+                if self.save_logprobs:
+                    out_logprobs[ds_name]["reporter"][mode] = (
+                        get_logprobs(val_credences, mode).detach().cpu()
+                    )
+                    out_logprobs[ds_name]["lm"][mode] = (
+                        get_logprobs(val.lm_preds, mode).detach().cpu()
+                        if val.lm_preds is not None
+                        else None
+                    )
+
                 row_bufs["eval"].append(
                     {
                         **meta,
                         "ensembling": mode,
-                        **evaluate_preds(val_gt, val_credences, mode).to_dict(),
+                        **evaluate_preds(val.labels, val_credences, mode).to_dict(),
                         "train_loss": train_loss,
                     }
                 )
@@ -152,37 +175,63 @@
                     {
                         **meta,
                         "ensembling": mode,
-                        **evaluate_preds(train_gt, train_credences, mode).to_dict(),
+                        **evaluate_preds(train.labels, train_credences, mode).to_dict(),
                         "train_loss": train_loss,
                     }
                 )
 
-                if val_lm_preds is not None:
+                if val.lm_preds is not None:
                     row_bufs["lm_eval"].append(
                         {
                             **meta,
                             "ensembling": mode,
-                            **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(),
+                            **evaluate_preds(val.labels, val.lm_preds, mode).to_dict(),
                         }
                     )
-                if train_lm_preds is not None:
+                if train.lm_preds is not None:
                     row_bufs["train_lm_eval"].append(
                         {
                             **meta,
                             "ensembling": mode,
-                            **evaluate_preds(train_gt, train_lm_preds, mode).to_dict(),
-                        }
-                    )
-
-                for i, model in enumerate(lr_models):
-                    row_bufs["lr_eval"].append(
-                        {
-                            **meta,
-                            "ensembling": mode,
-                            "inlp_iter": i,
-                            **evaluate_preds(val_gt, model(val_h), mode).to_dict(),
+                            **evaluate_preds(
+                                train.labels, train.lm_preds, mode
+                            ).to_dict(),
                         }
                     )
 
-        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}
+                if self.supervised != "none":
+                    if self.save_logprobs:
+                        out_logprobs[ds_name]["lr"][mode] = dict()
+
+                    for i, model in enumerate(lr_models):
+                        model.eval()
+                        val_lr_credences = model(val.hiddens)
+                        train_lr_credences = model(train.hiddens)
+                        if self.save_logprobs:
+                            out_logprobs[ds_name]["lr"][mode][i] = (
+                                get_logprobs(val_lr_credences, mode).detach().cpu()
+                            )
+
+                        row_bufs["train_lr_eval"].append(
+                            {
+                                **meta,
+                                "ensembling": mode,
+                                "inlp_iter": i,
+                                **evaluate_preds(
+                                    train.labels, train_lr_credences, mode
+                                ).to_dict(),
+                            }
+                        )
+                        row_bufs["lr_eval"].append(
+                            {
+                                **meta,
+                                "ensembling": mode,
+                                "inlp_iter": i,
+                                **evaluate_preds(
+                                    val.labels, val_lr_credences, mode
+                                ).to_dict(),
+                            }
+                        )
+
+        return {k: pd.DataFrame(v) for k, v in row_bufs.items()}, out_logprobs
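
Usage note: a minimal sketch of how the new logprobs.pt output could be loaded downstream, assuming a run was launched with --save_logprobs. Only the nested dictionary layout comes from the save_logprobs docstring in elk/run.py above; the output path and the accuracy check below are illustrative placeholders, not part of the patch.

# Hypothetical inspection script; only the dict layout is taken from the diff above.
import math
from pathlib import Path

import torch

out_dir = Path("output/my-run")  # placeholder: wherever the run wrote its CSVs
logprobs = torch.load(out_dir / "logprobs.pt")

for ds_name, ds_out in logprobs.items():
    labels = ds_out["labels"]  # (n,) tensor of 0/1 labels
    for layer, by_mode in ds_out["reporter"].items():
        full = by_mode["full"]  # fully ensembled log probs, shape (n,)
        preds = full > math.log(0.5)  # log prob above log(0.5) means "predict True"
        acc = (preds == labels.bool()).float().mean().item()
        print(f"{ds_name} layer {layer}: reporter accuracy {acc:.3f}")
    # ds_out["lm"][mode] holds the LM's own log probs (or None), and
    # ds_out["lr"][layer][mode][inlp_iter] holds the supervised probes' log probs.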