From daec1218e92cefcc0b67d17d8e55c72d5faec1cc Mon Sep 17 00:00:00 2001 From: jon Date: Thu, 13 Jul 2023 15:23:58 +0100 Subject: [PATCH] refactor reporter training --- elk/training/train.py | 65 ++++++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/elk/training/train.py b/elk/training/train.py index 570d49ad..3eb018bc 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -17,10 +17,16 @@ from ..training.supervised import train_supervised from ..utils.typing import assert_type from .ccs_reporter import CcsConfig, CcsReporter -from .common import FitterConfig +from .common import FitterConfig, Reporter from .eigen_reporter import EigenFitter, EigenFitterConfig +@dataclass +class ReporterTrainResult: + reporter: CcsReporter | Reporter + train_loss: float | None + + @dataclass class Elicit(Run): """Full specification of a reporter training run.""" @@ -69,22 +75,11 @@ def make_eval(self, model, eval_dataset): disable_cache=self.disable_cache, ) - def apply_to_layer( - self, - layer: int, - devices: list[str], - world_size: int, - probe_per_prompt: bool, - ) -> dict[str, pd.DataFrame]: - """Train a single reporter on a single layer.""" - - self.make_reproducible(seed=self.net.seed + layer) - device = self.get_device(devices, world_size) - + # Create a separate function to handle the reporter training. + def train_reporter(self, device, layer, out_dir) -> ReporterTrainResult: train_dict = self.prepare_data(device, layer, "train") - val_dict = self.prepare_data(device, layer, "val") - (first_train_h, train_gt, _), *rest = train_dict.values() + (first_train_h, train_gt, _), *rest = train_dict.values() # TODO can remove? (_, v, k, d) = first_train_h.shape if not all(other_h.shape[-1] == d for other_h, _, _ in rest): raise ValueError("All datasets must have the same hidden state size") @@ -96,16 +91,12 @@ def apply_to_layer( if not all(other_h.shape[-2] == k for other_h, _, _ in rest): raise ValueError("All datasets must have the same number of classes") - reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) train_loss = None - if isinstance(self.net, CcsConfig): assert len(train_dict) == 1, "CCS only supports single-task training" - reporter = CcsReporter(self.net, d, device=device, num_variants=v) train_loss = reporter.fit(first_train_h) - (_, v, k, _) = first_train_h.shape reporter.platt_scale( to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(), rearrange(first_train_h, "n v k d -> (n v k) d"), @@ -137,20 +128,50 @@ def apply_to_layer( raise ValueError(f"Unknown reporter config type: {type(self.net)}") # Save reporter checkpoint to disk - torch.save(reporter, reporter_dir / f"layer_{layer}.pt") + torch.save(reporter, out_dir / f"layer_{layer}.pt") - # Fit supervised logistic regression model + return ReporterTrainResult(reporter, train_loss) + + def train_lr_model(self, train_dict, device, layer, out_dir): if self.supervised != "none": lr_models = train_supervised( train_dict, device=device, mode=self.supervised, ) - with open(lr_dir / f"layer_{layer}.pt", "wb") as file: + with open(out_dir / f"layer_{layer}.pt", "wb") as file: torch.save(lr_models, file) else: lr_models = [] + return lr_models + + def apply_to_layer( + self, + layer: int, + devices: list[str], + world_size: int, + probe_per_prompt: bool, + ) -> dict[str, pd.DataFrame]: + """Train a single reporter on a single layer.""" + + self.make_reproducible(seed=self.net.seed + layer) + device = self.get_device(devices, world_size) + + train_dict = self.prepare_data(device, layer, "train") + val_dict = self.prepare_data(device, layer, "val") + + (first_train_h, train_gt, _), *rest = train_dict.values() + (_, v, k, d) = first_train_h.shape + + reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir)) + + reporter_train_result = self.train_reporter(device, layer, reporter_dir) + reporter = reporter_train_result.reporter + train_loss = reporter_train_result.train_loss + + lr_models = self.train_lr_model(train_dict, device, layer, lr_dir) + row_bufs = defaultdict(list) for ds_name in val_dict: val_h, val_gt, val_lm_preds = val_dict[ds_name]