refactor reporter training
derpyplops committed Jul 13, 2023
1 parent 78549f5 commit daec121
Showing 1 changed file with 43 additions and 22 deletions.
65 changes: 43 additions & 22 deletions elk/training/train.py
@@ -17,10 +17,16 @@
 from ..training.supervised import train_supervised
 from ..utils.typing import assert_type
 from .ccs_reporter import CcsConfig, CcsReporter
-from .common import FitterConfig
+from .common import FitterConfig, Reporter
 from .eigen_reporter import EigenFitter, EigenFitterConfig
 
 
+@dataclass
+class ReporterTrainResult:
+    reporter: CcsReporter | Reporter
+    train_loss: float | None
+
+
 @dataclass
 class Elicit(Run):
     """Full specification of a reporter training run."""
@@ -69,22 +75,11 @@ def make_eval(self, model, eval_dataset):
             disable_cache=self.disable_cache,
         )
 
-    def apply_to_layer(
-        self,
-        layer: int,
-        devices: list[str],
-        world_size: int,
-        probe_per_prompt: bool,
-    ) -> dict[str, pd.DataFrame]:
-        """Train a single reporter on a single layer."""
-
-        self.make_reproducible(seed=self.net.seed + layer)
-        device = self.get_device(devices, world_size)
-
+    # Create a separate function to handle the reporter training.
+    def train_reporter(self, device, layer, out_dir) -> ReporterTrainResult:
         train_dict = self.prepare_data(device, layer, "train")
-        val_dict = self.prepare_data(device, layer, "val")
 
-        (first_train_h, train_gt, _), *rest = train_dict.values()
+        (first_train_h, train_gt, _), *rest = train_dict.values()  # TODO can remove?
         (_, v, k, d) = first_train_h.shape
         if not all(other_h.shape[-1] == d for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same hidden state size")
@@ -96,16 +91,12 @@ def apply_to_layer(
         if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same number of classes")
 
-        reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
         train_loss = None
 
         if isinstance(self.net, CcsConfig):
             assert len(train_dict) == 1, "CCS only supports single-task training"
-
             reporter = CcsReporter(self.net, d, device=device, num_variants=v)
             train_loss = reporter.fit(first_train_h)
-
-            (_, v, k, _) = first_train_h.shape
             reporter.platt_scale(
                 to_one_hot(repeat(train_gt, "n -> (n v)", v=v), k).flatten(),
                 rearrange(first_train_h, "n v k d -> (n v k) d"),
@@ -137,20 +128,50 @@ def apply_to_layer(
raise ValueError(f"Unknown reporter config type: {type(self.net)}")

# Save reporter checkpoint to disk
torch.save(reporter, reporter_dir / f"layer_{layer}.pt")
torch.save(reporter, out_dir / f"layer_{layer}.pt")

# Fit supervised logistic regression model
return ReporterTrainResult(reporter, train_loss)

def train_lr_model(self, train_dict, device, layer, out_dir):
if self.supervised != "none":
lr_models = train_supervised(
train_dict,
device=device,
mode=self.supervised,
)
with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
with open(out_dir / f"layer_{layer}.pt", "wb") as file:
torch.save(lr_models, file)
else:
lr_models = []

return lr_models

def apply_to_layer(
self,
layer: int,
devices: list[str],
world_size: int,
probe_per_prompt: bool,
) -> dict[str, pd.DataFrame]:
"""Train a single reporter on a single layer."""

self.make_reproducible(seed=self.net.seed + layer)
device = self.get_device(devices, world_size)

train_dict = self.prepare_data(device, layer, "train")
val_dict = self.prepare_data(device, layer, "val")

(first_train_h, train_gt, _), *rest = train_dict.values()
(_, v, k, d) = first_train_h.shape

reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))

reporter_train_result = self.train_reporter(device, layer, reporter_dir)
reporter = reporter_train_result.reporter
train_loss = reporter_train_result.train_loss

lr_models = self.train_lr_model(train_dict, device, layer, lr_dir)

row_bufs = defaultdict(list)
for ds_name in val_dict:
val_h, val_gt, val_lm_preds = val_dict[ds_name]
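Taken together, the commit splits the old monolithic apply_to_layer into two helpers, train_reporter and train_lr_model, with apply_to_layer reduced to an orchestrator. A minimal sketch of the resulting call flow, assuming an already-constructed Elicit instance named elicit (that variable and the literal device/layer/output values are hypothetical; the method names, signatures, and return types come from the diff above):

    from pathlib import Path

    device = "cuda:0"  # hypothetical device string
    layer = 12         # hypothetical layer index

    # create_models_dir returns (reporter_dir, lr_dir), per the diff
    reporter_dir, lr_dir = elicit.create_models_dir(Path("out"))

    # fit the CCS or eigen reporter and checkpoint it under reporter_dir
    result = elicit.train_reporter(device, layer, reporter_dir)
    reporter, train_loss = result.reporter, result.train_loss

    # optionally fit supervised logistic-regression baselines under lr_dir
    train_dict = elicit.prepare_data(device, layer, "train")
    lr_models = elicit.train_lr_model(train_dict, device, layer, lr_dir)

apply_to_layer chains exactly these calls and then scores reporter and lr_models on the validation split.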
