diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py index 4e6ec59f3..a74dca262 100644 --- a/elk/evaluation/evaluate.py +++ b/elk/evaluation/evaluate.py @@ -8,8 +8,11 @@ from ..files import elk_reporter_dir from ..metrics import evaluate_preds -from ..run import Run +from ..run import LayerApplied, LayerOutput, Run from ..utils import Color +from ..utils.types import PromptEnsembling + +PROMPT_ENSEMBLING = "prompt_ensembling" @dataclass(kw_only=True) @@ -31,7 +34,7 @@ def execute(self, highlight_color: Color = "cyan"): @torch.inference_mode() def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> dict[str, pd.DataFrame]: + ) -> LayerApplied: """Evaluate a single reporter on a single layer.""" device = self.get_device(devices, world_size) val_output = self.prepare_data(device, layer, "val") @@ -42,16 +45,23 @@ def apply_to_layer( reporter = torch.load(reporter_path, map_location=device) row_bufs = defaultdict(list) + + layer_outputs: list[LayerOutput] = [] + for ds_name, (val_h, val_gt, val_lm_preds) in val_output.items(): meta = {"dataset": ds_name, "layer": layer} val_credences = reporter(val_h) - for mode in ("none", "partial", "full"): + + layer_outputs.append(LayerOutput(val_gt, val_credences, meta)) + for prompt_ensembling in PromptEnsembling.all(): row_bufs["eval"].append( { **meta, - "ensembling": mode, - **evaluate_preds(val_gt, val_credences, mode).to_dict(), + PROMPT_ENSEMBLING: prompt_ensembling.value, + **evaluate_preds( + val_gt, val_credences, prompt_ensembling + ).to_dict(), } ) @@ -59,8 +69,10 @@ def apply_to_layer( row_bufs["lm_eval"].append( { **meta, - "ensembling": mode, - **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(), + PROMPT_ENSEMBLING: prompt_ensembling.value, + **evaluate_preds( + val_gt, val_lm_preds, prompt_ensembling + ).to_dict(), } ) @@ -75,11 +87,14 @@ def apply_to_layer( model.eval() row_bufs["lr_eval"].append( { - "ensembling": mode, + PROMPT_ENSEMBLING: prompt_ensembling.value, "inlp_iter": i, **meta, - **evaluate_preds(val_gt, model(val_h), mode).to_dict(), + **evaluate_preds( + val_gt, model(val_h), prompt_ensembling + ).to_dict(), } ) - - return {k: pd.DataFrame(v) for k, v in row_bufs.items()} + return LayerApplied( + layer_outputs, {k: pd.DataFrame(v) for k, v in row_bufs.items()} + ) diff --git a/elk/metrics/accuracy.py b/elk/metrics/accuracy.py index 33b946321..2f9685f8b 100644 --- a/elk/metrics/accuracy.py +++ b/elk/metrics/accuracy.py @@ -14,11 +14,14 @@ class AccuracyResult: """Lower bound of the confidence interval.""" upper: float """Upper bound of the confidence interval.""" + cal_thresh: float | None + """The threshold used to compute the calibrated accuracy.""" def accuracy_ci( y_true: Tensor, y_pred: Tensor, + cal_thresh: float | None = None, *, num_samples: int = 1000, level: float = 0.95, @@ -79,4 +82,4 @@ def accuracy_ci( # Compute the point estimate. Call flatten to ensure that we get a single number # computed across cluster boundaries even if the inputs were clustered. 
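Note: the body of `accuracy_ci` between its signature and the point estimate is elided from this hunk, so the sketch below is only a plausible stand-in for the kind of percentile bootstrap its `num_samples` and `level` arguments suggest. The helper name `bootstrap_accuracy_ci` is hypothetical, and the sketch ignores the clustered-input handling mentioned in the comment; it is not the repository's implementation.

```python
import torch

def bootstrap_accuracy_ci(y_true, y_pred, num_samples=1000, level=0.95, seed=0):
    # Per-example correctness, flattened the same way as the point estimate below.
    correct = y_true.flatten().eq(y_pred.flatten()).float()
    n = correct.numel()
    gen = torch.Generator().manual_seed(seed)
    # Resample indices with replacement and score each bootstrap sample.
    idx = torch.randint(0, n, (num_samples, n), generator=gen)
    boot_accs = correct[idx].mean(dim=1)
    alpha = (1 - level) / 2
    lower, upper = boot_accs.quantile(torch.tensor([alpha, 1 - alpha])).tolist()
    return correct.mean().item(), lower, upper
```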
estimate = y_true.flatten().eq(y_pred.flatten()).float().mean().item() - return AccuracyResult(estimate, lower, upper) + return AccuracyResult(estimate, lower, upper, cal_thresh) diff --git a/elk/metrics/eval.py b/elk/metrics/eval.py index d0e2bf7a5..c77dd09e9 100644 --- a/elk/metrics/eval.py +++ b/elk/metrics/eval.py @@ -1,15 +1,22 @@ from dataclasses import asdict, dataclass -from typing import Literal import torch from einops import repeat from torch import Tensor +from ..utils.types import PromptEnsembling from .accuracy import AccuracyResult, accuracy_ci from .calibration import CalibrationError, CalibrationEstimate from .roc_auc import RocAucResult, roc_auc_ci +@dataclass +class LayerOutput: + val_gt: Tensor + val_credences: Tensor + meta: dict + + @dataclass(frozen=True) class EvalResult: """The result of evaluating a classifier.""" @@ -26,7 +33,7 @@ class EvalResult: cal_thresh: float | None """The threshold used to compute the calibrated accuracy.""" - def to_dict(self, prefix: str = "") -> dict[str, float]: + def to_dict(self, prefix: str = "") -> dict[str, float | None]: """Convert the result to a dictionary.""" acc_dict = {f"{prefix}acc_{k}": v for k, v in asdict(self.accuracy).items()} cal_acc_dict = ( @@ -49,67 +56,164 @@ def to_dict(self, prefix: str = "") -> dict[str, float]: } +def calc_auroc( + y_logits: Tensor, + y_true: Tensor, + ensembling: PromptEnsembling, + num_classes: int, +) -> RocAucResult: + """ + Calculate the AUROC + + Args: + y_true: Ground truth tensor of shape (n,). + y_logits: Predicted class tensor of shape (n, num_variants, num_classes). + prompt_ensembling: The prompt_ensembling mode. + num_classes: The number of classes. + + Returns: + RocAucResult: A dictionary containing the AUROC and confidence interval. + """ + if ensembling == PromptEnsembling.NONE: + auroc = roc_auc_ci( + to_one_hot(y_true, num_classes).long().flatten(1), y_logits.flatten(1) + ) + elif ensembling in (PromptEnsembling.PARTIAL, PromptEnsembling.FULL): + # Pool together the negative and positive class logits + if num_classes == 2: + auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0]) + else: + auroc = roc_auc_ci(to_one_hot(y_true, num_classes).long(), y_logits) + else: + raise ValueError(f"Unknown mode: {ensembling}") + + return auroc + + +def calc_calibrated_accuracies(y_true, pos_probs) -> AccuracyResult: + """ + Calculate the calibrated accuracies + + Args: + y_true: Ground truth tensor of shape (n,). + pos_probs: Predicted class tensor of shape (n, num_variants, num_classes). + + Returns: + AccuracyResult: A dictionary containing the accuracy and confidence interval. + """ + + cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item() + cal_preds = pos_probs.gt(cal_thresh).to(torch.int) + cal_acc = accuracy_ci(y_true, cal_preds, cal_thresh) + return cal_acc + + +def calc_calibrated_errors(y_true, pos_probs) -> CalibrationEstimate: + """ + Calculate the expected calibration error. + + Args: + y_true: Ground truth tensor of shape (n,). + y_logits: Predicted class tensor of shape (n, num_variants, num_classes). + + Returns: + CalibrationEstimate: + """ + + cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten()) + cal_err = cal.compute() + return cal_err + + +def calc_accuracies(y_logits, y_true) -> AccuracyResult: + """ + Calculate the accuracy + + Args: + y_true: Ground truth tensor of shape (n,). + y_logits: Predicted class tensor of shape (n, num_variants, num_classes). 
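A minimal worked example of the calibration step that `calc_calibrated_accuracies` factors out, with invented numbers: instead of a fixed 0.5 cutoff, the threshold is the base-rate quantile of the positive-class probabilities, which rewards a probe whose scores are well ordered even when they are badly scaled.

```python
import torch

y_true = torch.tensor([0, 0, 1, 1])
pos_probs = torch.tensor([0.55, 0.60, 0.70, 0.90])  # well ordered, badly scaled

naive_acc = pos_probs.gt(0.5).int().eq(y_true).float().mean().item()       # 0.5
cal_thresh = pos_probs.quantile(y_true.float().mean()).item()              # 0.65
cal_acc = pos_probs.gt(cal_thresh).int().eq(y_true).float().mean().item()  # 1.0
```

Passing `cal_thresh` through as the new third positional argument of `accuracy_ci` is what lets the threshold surface in the reported `AccuracyResult`.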
+ + Returns: + AccuracyResult: A dictionary containing the accuracy and confidence interval. + """ + y_pred = y_logits.argmax(dim=-1) + return accuracy_ci(y_true, y_pred) + + def evaluate_preds( y_true: Tensor, y_logits: Tensor, - ensembling: Literal["none", "partial", "full"] = "none", + prompt_ensembling: PromptEnsembling = PromptEnsembling.NONE, ) -> EvalResult: """ Evaluate the performance of a classification model. Args: - y_true: Ground truth tensor of shape (N,). - y_logits: Predicted class tensor of shape (N, variants, n_classes). + y_true: Ground truth tensor of shape (n,). + y_logits: Predicted class tensor of shape (n, num_variants, num_classes). + prompt_ensembling: The prompt_ensembling mode. Returns: dict: A dictionary containing the accuracy, AUROC, and ECE. """ - (n, v, c) = y_logits.shape - assert y_true.shape == (n,) + y_logits, y_true, num_classes = prepare(y_logits, y_true, prompt_ensembling) + return calc_eval_results(y_true, y_logits, prompt_ensembling, num_classes) + + +def prepare(y_logits: Tensor, y_true: Tensor, prompt_ensembling: PromptEnsembling): + """ + Prepare the logits and ground truth for evaluation + """ + (n, num_variants, num_classes) = y_logits.shape + assert y_true.shape == (n,), f"y_true.shape: {y_true.shape} is not equal to n: {n}" - if ensembling == "full": + if prompt_ensembling == PromptEnsembling.FULL: y_logits = y_logits.mean(dim=1) else: - y_true = repeat(y_true, "n -> n v", v=v) + y_true = repeat(y_true, "n -> n v", v=num_variants) - THRESHOLD = 0.5 - if ensembling == "none": - y_pred = y_logits[..., 1].gt(THRESHOLD).to(torch.int) - else: - y_pred = y_logits.argmax(dim=-1) + return y_logits, y_true, num_classes - acc = accuracy_ci(y_true, y_pred) - if ensembling == "none": - auroc = roc_auc_ci(to_one_hot(y_true, c).long().flatten(1), y_logits.flatten(1)) - elif ensembling in ("partial", "full"): - # Pool together the negative and positive class logits - if c == 2: - auroc = roc_auc_ci(y_true, y_logits[..., 1] - y_logits[..., 0]) - else: - auroc = roc_auc_ci(to_one_hot(y_true, c).long(), y_logits) - else: - raise ValueError(f"Unknown mode: {ensembling}") +def calc_eval_results( + y_true: Tensor, + y_logits: Tensor, + prompt_ensembling: PromptEnsembling, + num_classes: int, +) -> EvalResult: + """ + Calculate the evaluation results - cal_acc = None - cal_err = None - cal_thresh = None + Args: + y_true: Ground truth tensor of shape (n,). + y_logits: Predicted class tensor of shape (n, num_variants, num_classes). + prompt_ensembling: The prompt_ensembling mode. - if c == 2: - pooled_logits = ( - y_logits[..., 1] - if ensembling == "none" - else y_logits[..., 1] - y_logits[..., 0] - ) - pos_probs = torch.sigmoid(pooled_logits) + Returns: + EvalResult: The result of evaluating a classifier containing the accuracy, + calibrated accuracies, calibrated errors, and AUROC. 
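As a shape-level sketch of what `prepare` does, using toy tensors and the `(n, num_variants, num_classes)` layout documented above: FULL ensembling averages the logits over prompt variants, while NONE and PARTIAL keep the variants separate and broadcast the labels so each variant is scored on its own.

```python
import torch
from einops import repeat

n, num_variants, num_classes = 4, 3, 2
y_logits = torch.randn(n, num_variants, num_classes)
y_true = torch.randint(0, num_classes, (n,))

full_logits = y_logits.mean(dim=1)                               # (n, num_classes)
per_variant_labels = repeat(y_true, "n -> n v", v=num_variants)  # (n, num_variants)
```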
+ """ + acc = calc_accuracies(y_logits=y_logits, y_true=y_true) - # Calibrated accuracy - cal_thresh = pos_probs.float().quantile(y_true.float().mean()).item() - cal_preds = pos_probs.gt(cal_thresh).to(torch.int) - cal_acc = accuracy_ci(y_true, cal_preds) + pos_probs = torch.sigmoid(y_logits[..., 1] - y_logits[..., 0]) + cal_acc, cal_thresh = ( + calc_calibrated_accuracies(y_true=y_true, pos_probs=pos_probs) + if num_classes == 2 + else None, + None, + ) + cal_err = ( + calc_calibrated_errors(y_true=y_true, pos_probs=pos_probs) + if num_classes == 2 + else None + ) - cal = CalibrationError().update(y_true.flatten(), pos_probs.flatten()) - cal_err = cal.compute() + auroc = calc_auroc( + y_logits=y_logits, + y_true=y_true, + ensembling=prompt_ensembling, + num_classes=num_classes, + ) return EvalResult(acc, cal_acc, cal_err, auroc, cal_thresh) @@ -127,3 +231,49 @@ def to_one_hot(labels: Tensor, n_classes: int) -> Tensor: """ one_hot_labels = labels.new_zeros(*labels.shape, n_classes) return one_hot_labels.scatter_(-1, labels.unsqueeze(-1).long(), 1) + + +def layer_ensembling( + layer_outputs: list[LayerOutput], prompt_ensembling: PromptEnsembling +) -> EvalResult: + """ + Return EvalResult after prompt_ensembling + the probe output of the middle to last layers + + Args: + layer_outputs: A list of LayerOutput containing the ground truth and + predicted class tensor of shape (n, num_variants, num_classes). + prompt_ensembling: The prompt_ensembling mode. + + Returns: + EvalResult: The result of evaluating a classifier containing the accuracy, + calibrated accuracies, calibrated errors, and AUROC. + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + y_logits_collection = [] + + num_classes = 2 + y_true = layer_outputs[0].val_gt.to(device) + + for layer_output in layer_outputs: + # all y_trues are identical, so just get the first + y_logits = layer_output.val_credences.to(device) + y_logits, y_true, num_classes = prepare( + y_logits=y_logits, + y_true=layer_outputs[0].val_gt.to(device), + prompt_ensembling=prompt_ensembling, + ) + y_logits_collection.append(y_logits) + + # get logits and ground_truth from middle to last layer + middle_index = len(layer_outputs) // 2 + y_logits_stacked = torch.stack(y_logits_collection[middle_index:]) + # layer prompt_ensembling of the stacked logits + y_logits_stacked_mean = torch.mean(y_logits_stacked, dim=0) + + return calc_eval_results( + y_true=y_true, + y_logits=y_logits_stacked_mean, + prompt_ensembling=prompt_ensembling, + num_classes=num_classes, + ) diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index fa183e5af..04fdb0393 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from pathlib import Path +from typing import Iterable import pandas as pd import plotly.express as px @@ -9,6 +10,8 @@ from rich.console import Console from rich.table import Table +from elk.utils.types import PromptEnsembling + @dataclass class SweepByDsMultiplot: @@ -19,16 +22,16 @@ class SweepByDsMultiplot: def render( self, sweep: "SweepVisualization", - with_transfer=False, - ensembles=["full", "partial", "none"], - write=False, + with_transfer: bool = False, + ensemblings: Iterable[PromptEnsembling] = PromptEnsembling.all(), + write: bool = False, ) -> go.Figure: """Render the multiplot visualization. Args: sweep: The SweepVisualization instance containing the data. with_transfer: Flag indicating whether to include transfer eval data. 
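The core of `layer_ensembling` reduces to a stack-and-average over the upper half of the layers; a toy version with random credences and made-up sizes is below (in the real function, `prepare` is applied first, so under FULL ensembling the variant axis is already averaged away before stacking).

```python
import torch

num_layers, n, v, c = 6, 8, 3, 2
per_layer_logits = [torch.randn(n, v, c) for _ in range(num_layers)]

middle_index = num_layers // 2
stacked = torch.stack(per_layer_logits[middle_index:])  # (num_layers - middle_index, n, v, c)
ensembled = stacked.mean(dim=0)                         # (n, v, c), then scored by calc_eval_results
```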
- ensembles: Filter for which ensembing options to include. + ensemblings: Filter for which ensembing options to include. write: Flag indicating whether to write the visualization to disk. Returns: @@ -49,13 +52,15 @@ def render( x_title="Layer", y_title="AUROC", ) - color_map = dict(zip(ensembles, qualitative.Plotly)) + color_map = dict(zip(ensemblings, qualitative.Plotly)) - for ensemble in ensembles: - ensemble_data: pd.DataFrame = df[df["ensembling"] == ensemble] + for prompt_ensembling in ensemblings: + ensemble_data: pd.DataFrame = df[ + df["prompt_ensembling"] == prompt_ensembling.value + ] if with_transfer: # TODO write tests ensemble_data = ensemble_data.groupby( - ["eval_dataset", "layer", "ensembling"], as_index=False + ["eval_dataset", "layer", "prompt_ensembling"], as_index=False ).agg({"auroc_estimate": "mean"}) else: ensemble_data = ensemble_data[ @@ -77,11 +82,11 @@ def render( x=dataset_data["layer"], y=dataset_data["auroc_estimate"], mode="lines", - name=ensemble, + name=prompt_ensembling.value, showlegend=False if dataset_name != unique_datasets[0] else True, - line=dict(color=color_map[ensemble]), + line=dict(color=color_map[prompt_ensembling]), ), row=row, col=col, @@ -93,7 +98,7 @@ def render( fig.update_layout( legend=dict( - title="Ensembling", + title="prompt_ensembling", ), title=f"AUROC Trend: {self.model_name}", ) @@ -115,7 +120,7 @@ class TransferEvalHeatmap: layer: int score_type: str = "auroc_estimate" - ensembling: str = "full" + prompt_ensembling: PromptEnsembling = PromptEnsembling.FULL def render(self, df: pd.DataFrame) -> go.Figure: """Render the heatmap visualization. @@ -245,7 +250,7 @@ def render_and_save( sweep: "SweepVisualization", dataset_names: list[str] | None = None, score_type="auroc_estimate", - ensembling="full", + prompt_ensembling=PromptEnsembling.FULL, ) -> None: """Render and save the visualization for the model. @@ -253,7 +258,7 @@ def render_and_save( sweep: The SweepVisualization instance. dataset_names: List of dataset names to include in the visualization. score_type: The type of score to display. - ensembling: The ensembling option to consider. + prompt_ensembling: The prompt_ensembling option to consider. """ df = self.df model_name = self.model_name @@ -262,9 +267,12 @@ def render_and_save( model_path.mkdir(parents=True, exist_ok=True) if self.is_transfer: for layer in range(layer_min, layer_max + 1): - filtered = df[(df["layer"] == layer) & (df["ensembling"] == ensembling)] + filtered = df[ + (df["layer"] == layer) + & (df["prompt_ensembling"] == prompt_ensembling.value) + ] fig = TransferEvalHeatmap( - layer, score_type=score_type, ensembling=ensembling + layer, score_type=score_type, prompt_ensembling=prompt_ensembling ).render(filtered) fig.write_image(file=model_path / f"{layer}.png") fig = TransferEvalTrend(dataset_names).render(df) @@ -382,7 +390,7 @@ def render_table( Returns: The generated score table as a pandas DataFrame. 
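A hypothetical miniature of the filtering the plots now perform: the CSVs carry the enum's string value in a `prompt_ensembling` column, so rows are selected by comparing against `PromptEnsembling.FULL.value` rather than a bare string. The enum is re-declared locally (mirroring `elk/utils/types.py`) and the DataFrame contents are invented so the snippet runs on its own.

```python
from enum import Enum

import pandas as pd

class PromptEnsembling(Enum):  # mirrors elk/utils/types.py
    FULL = "full"
    PARTIAL = "partial"
    NONE = "none"

df = pd.DataFrame(
    {
        "layer": [0, 0, 1, 1],
        "prompt_ensembling": ["full", "none", "full", "none"],
        "auroc_estimate": [0.71, 0.64, 0.78, 0.69],
    }
)
full_only = df[df["prompt_ensembling"] == PromptEnsembling.FULL.value]
```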
""" - df = self.df[self.df["ensembling"] == "partial"] + df = self.df[self.df["prompt_ensembling"] == PromptEnsembling.PARTIAL.value] # For each model, we use the layer whose mean AUROC is the highest best_layers, model_dfs = [], [] diff --git a/elk/run.py b/elk/run.py index fb8903ccf..55c449540 100644 --- a/elk/run.py +++ b/elk/run.py @@ -21,6 +21,7 @@ from .extraction import Extract, extract from .extraction.dataset_name import DatasetDictWithName from .files import elk_reporter_dir, memorably_named_dir +from .metrics.eval import LayerOutput, layer_ensembling from .utils import ( Color, assert_type, @@ -29,6 +30,56 @@ select_split, select_usable_devices, ) +from .utils.types import PromptEnsembling + +PROMPT_ENSEMBLING = "prompt_ensembling" + + +@dataclass(frozen=True) +class LayerApplied: + layer_outputs: list[LayerOutput] + """The output of the reporter on the layer, should contain credences and ground + truth labels.""" + df_dict: dict[str, pd.DataFrame] + """The evaluation results for the layer.""" + + +def calculate_layer_outputs(layer_outputs: list[LayerOutput], out_path: Path): + """ + Calculate the layer ensembling results for each dataset + and prompt ensembling and save them to a CSV file. + + Args: + layer_outputs: The layer outputs to calculate the results for. + out_path: The path to save the results to. + """ + grouped_layer_outputs = {} + for layer_output in layer_outputs: + dataset_name = layer_output.meta["dataset"] + if dataset_name in grouped_layer_outputs: + grouped_layer_outputs[dataset_name].append(layer_output) + else: + grouped_layer_outputs[dataset_name] = [layer_output] + + dfs = [] + for dataset_name, layer_outputs in grouped_layer_outputs.items(): + for prompt_ensembling in PromptEnsembling.all(): + res = layer_ensembling( + layer_outputs=layer_outputs, + prompt_ensembling=prompt_ensembling, + ) + df = pd.DataFrame( + { + "dataset": dataset_name, + PROMPT_ENSEMBLING: prompt_ensembling.value, + **res.to_dict(), + }, + index=[0], + ).round(4) + dfs.append(df) + + df_concat = pd.concat(dfs) + df_concat.to_csv(out_path, index=False) @dataclass @@ -98,7 +149,7 @@ def execute( devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem) num_devices = len(devices) - func: Callable[[int], dict[str, pd.DataFrame]] = partial( + func: Callable[[int], LayerApplied] = partial( self.apply_to_layer, devices=devices, world_size=num_devices ) self.apply_to_layers(func=func, num_devices=num_devices) @@ -106,7 +157,7 @@ def execute( @abstractmethod def apply_to_layer( self, layer: int, devices: list[str], world_size: int - ) -> dict[str, pd.DataFrame]: + ) -> LayerApplied: """Train or eval a reporter on a single layer.""" def make_reproducible(self, seed: int): @@ -155,7 +206,7 @@ def concatenate(self, layers): def apply_to_layers( self, - func: Callable[[int], dict[str, pd.DataFrame]], + func: Callable[[int], LayerApplied], num_devices: int, ): """Apply a function to each layer of the datasets in parallel @@ -178,15 +229,21 @@ def apply_to_layers( with ctx.Pool(num_devices) as pool: mapper = pool.imap_unordered if num_devices > 1 else map df_buffers = defaultdict(list) - + layer_outputs: list[LayerOutput] = [] try: - for df_dict in tqdm(mapper(func, layers), total=len(layers)): - for k, v in df_dict.items(): + for res in tqdm(mapper(func, layers), total=len(layers)): + layer_outputs.extend(res.layer_outputs) + for k, v in res.df_dict.items(): # type: ignore df_buffers[k].append(v) finally: # Make sure the CSVs are written even if we crash or get interrupted for 
name, dfs in df_buffers.items(): - df = pd.concat(dfs).sort_values(by=["layer", "ensembling"]) + df = pd.concat(dfs).sort_values(by=["layer", PROMPT_ENSEMBLING]) df.round(4).to_csv(self.out_dir / f"{name}.csv", index=False) if self.debug: save_debug_log(self.datasets, self.out_dir) + + calculate_layer_outputs( + layer_outputs=layer_outputs, + out_path=self.out_dir / "layer_ensembling.csv", + ) diff --git a/elk/training/train.py b/elk/training/train.py index 3d00b54c1..882babc35 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -13,8 +13,10 @@ from ..extraction import Extract from ..metrics import evaluate_preds, to_one_hot -from ..run import Run +from ..metrics.eval import LayerOutput +from ..run import LayerApplied, Run from ..training.supervised import train_supervised +from ..utils.types import PromptEnsembling from ..utils.typing import assert_type from .ccs_reporter import CcsConfig, CcsReporter from .common import FitterConfig @@ -63,7 +65,7 @@ def apply_to_layer( layer: int, devices: list[str], world_size: int, - ) -> dict[str, pd.DataFrame]: + ) -> LayerApplied: """Train a single reporter on a single layer.""" self.make_reproducible(seed=self.net.seed + layer) @@ -136,19 +138,30 @@ def apply_to_layer( lr_models = [] row_bufs = defaultdict(list) + layer_output = [] for ds_name in val_dict: val_h, val_gt, val_lm_preds = val_dict[ds_name] train_h, train_gt, train_lm_preds = train_dict[ds_name] meta = {"dataset": ds_name, "layer": layer} val_credences = reporter(val_h) + layer_output.append( + LayerOutput( + val_gt=val_gt.detach(), + val_credences=val_credences.detach(), + meta=meta, + ) + ) + train_credences = reporter(train_h) - for mode in ("none", "partial", "full"): + for prompt_ensembling in PromptEnsembling.all(): row_bufs["eval"].append( { **meta, - "ensembling": mode, - **evaluate_preds(val_gt, val_credences, mode).to_dict(), + "prompt_ensembling": prompt_ensembling.value, + **evaluate_preds( + val_gt, val_credences, prompt_ensembling + ).to_dict(), "train_loss": train_loss, } ) @@ -156,8 +169,10 @@ def apply_to_layer( row_bufs["train_eval"].append( { **meta, - "ensembling": mode, - **evaluate_preds(train_gt, train_credences, mode).to_dict(), + "prompt_ensembling": prompt_ensembling.value, + **evaluate_preds( + train_gt, train_credences, prompt_ensembling + ).to_dict(), "train_loss": train_loss, } ) @@ -166,8 +181,10 @@ def apply_to_layer( row_bufs["lm_eval"].append( { **meta, - "ensembling": mode, - **evaluate_preds(val_gt, val_lm_preds, mode).to_dict(), + "prompt_ensembling": prompt_ensembling.value, + **evaluate_preds( + val_gt, val_lm_preds, prompt_ensembling + ).to_dict(), } ) @@ -175,8 +192,10 @@ def apply_to_layer( row_bufs["train_lm_eval"].append( { **meta, - "ensembling": mode, - **evaluate_preds(train_gt, train_lm_preds, mode).to_dict(), + "prompt_ensembling": prompt_ensembling.value, + **evaluate_preds( + train_gt, train_lm_preds, prompt_ensembling + ).to_dict(), } ) @@ -184,10 +203,14 @@ def apply_to_layer( row_bufs["lr_eval"].append( { **meta, - "ensembling": mode, + "prompt_ensembling": prompt_ensembling.value, "inlp_iter": i, - **evaluate_preds(val_gt, model(val_h), mode).to_dict(), + **evaluate_preds( + val_gt, model(val_h), prompt_ensembling + ).to_dict(), } ) - return {k: pd.DataFrame(v) for k, v in row_bufs.items()} + return LayerApplied( + layer_output, {k: pd.DataFrame(v) for k, v in row_bufs.items()} + ) diff --git a/elk/utils/types.py b/elk/utils/types.py new file mode 100644 index 000000000..eadeb81ad --- /dev/null +++ 
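To make the new return contract concrete, here is a self-contained sketch of how `apply_to_layers` consumes a `LayerApplied`: the metric tables go into per-name CSV buffers, while the raw credences are pooled and grouped by dataset, which is what `calculate_layer_outputs` hands to `layer_ensembling`. The dataclasses are re-declared to mirror the diff and the data (dataset name, shapes, table contents) is invented.

```python
from collections import defaultdict
from dataclasses import dataclass

import pandas as pd
import torch

@dataclass
class LayerOutput:                 # mirrors elk/metrics/eval.py
    val_gt: torch.Tensor
    val_credences: torch.Tensor
    meta: dict

@dataclass(frozen=True)
class LayerApplied:                # mirrors elk/run.py
    layer_outputs: list[LayerOutput]
    df_dict: dict[str, pd.DataFrame]

res = LayerApplied(
    layer_outputs=[
        LayerOutput(torch.tensor([0, 1]), torch.randn(2, 3, 2), {"dataset": "ds_a", "layer": 5})
    ],
    df_dict={"eval": pd.DataFrame([{"layer": 5, "prompt_ensembling": "full"}])},
)

df_buffers, layer_outputs = defaultdict(list), []
layer_outputs.extend(res.layer_outputs)
for name, frame in res.df_dict.items():
    df_buffers[name].append(frame)

# Group the pooled outputs by dataset, as calculate_layer_outputs does.
grouped = defaultdict(list)
for output in layer_outputs:
    grouped[output.meta["dataset"]].append(output)
```

Using `defaultdict(list)` for the grouping would also let the explicit if/else in `calculate_layer_outputs` be dropped.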
b/elk/utils/types.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class PromptEnsembling(Enum): + FULL = "full" + PARTIAL = "partial" + NONE = "none" + + @staticmethod + def all() -> tuple["PromptEnsembling"]: + return tuple(PromptEnsembling) diff --git a/pyproject.toml b/pyproject.toml index f3f16504a..b0a078cde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ # We upstreamed bugfixes for Literal types in 0.1.1 "simple-parsing>=0.1.1", # Version 1.11 introduced Fully Sharded Data Parallel, which we plan to use soon - "torch>=1.11.0", + "torch==2.0", # Doesn't really matter but versions < 4.0 are very very old (pre-2016) "tqdm>=4.0.0", # 4.0 introduced the breaking change of using return_dict=True by default @@ -37,7 +37,7 @@ dependencies = [ # For visualization of results "plotly==5.14.1", "kaleido==0.2.1", - "rich==13.3.5" + "rich" ] version = "0.1.1" diff --git a/tests/test_smoke_elicit.py b/tests/test_smoke_elicit.py index bac0f3989..4ae8b22b0 100644 --- a/tests/test_smoke_elicit.py +++ b/tests/test_smoke_elicit.py @@ -31,6 +31,7 @@ def test_smoke_elicit_run_tiny_gpt2_ccs(tmp_path: Path): "lr_models", "reporters", "eval.csv", + "layer_ensembling.csv", ] for file in expected_files: assert file in created_file_names @@ -62,6 +63,7 @@ def test_smoke_elicit_run_tiny_gpt2_eigen(tmp_path: Path): "lr_models", "reporters", "eval.csv", + "layer_ensembling.csv", ] for file in expected_files: assert file in created_file_names diff --git a/tests/test_smoke_eval.py b/tests/test_smoke_eval.py index 4efd7112d..7f0bad7ea 100644 --- a/tests/test_smoke_eval.py +++ b/tests/test_smoke_eval.py @@ -11,6 +11,7 @@ "cfg.yaml", "fingerprints.yaml", "eval.csv", + "layer_ensembling.csv", ]
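Finally, a standalone usage sketch of the new enum, re-declared here so it runs without the package installed: `all()` yields the members whose `.value` strings end up in the `prompt_ensembling` column of `eval.csv` and `layer_ensembling.csv`. One nit worth noting: since `all()` returns a three-element tuple, the variadic annotation `tuple["PromptEnsembling", ...]` used below describes it more precisely than the `tuple["PromptEnsembling"]` added in `elk/utils/types.py`, which denotes a one-tuple.

```python
from enum import Enum

class PromptEnsembling(Enum):
    FULL = "full"
    PARTIAL = "partial"
    NONE = "none"

    @staticmethod
    def all() -> tuple["PromptEnsembling", ...]:
        return tuple(PromptEnsembling)

for pe in PromptEnsembling.all():
    print(pe.value)  # -> full, partial, none
```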