distil.py
"""Distil a model into a vessel adapter"""
import os
import hydra
from omegaconf import DictConfig, OmegaConf
import transformers
import pytorch_lightning as pl
from claficle.data.oscar import OSCARDataModule
from claficle.utils.general import run_script_preamble
from claficle.models.vessel import Vessel
@hydra.main(version_base=None, config_path="../conf", config_name="distil")
def main(cfg: DictConfig):
# setting seed and initializing model
model: Vessel
model, cfg = run_script_preamble(cfg)
# additional post initialization (activating adapter, freezing gpt2)
model.post_init(seed=cfg.seed)
# we are only doing vessel distillation in english
lang = "en"
# data
oscar = OSCARDataModule(config=cfg.data, lang=lang, seed=cfg.seed)
tokenizer = transformers.AutoTokenizer.from_pretrained(cfg.model.causalLM_variant)
oscar.prepare_data()
oscar.set_tokenizer(tokenizer)
oscar.setup("distillation")
    # trainer
    log_save_dir = os.path.join(
        cfg.trainer.log_dir, cfg.model.name, f"seed_{cfg.seed}", "distillation", lang
    )
    os.makedirs(log_save_dir, exist_ok=True)
    # group W&B runs by where the script is executed (SLURM cluster vs. local machine)
    script_host = "slurm" if "SLURM_JOB_ID" in os.environ else "local"
    logger = pl.loggers.WandbLogger(
        save_dir=log_save_dir,
        entity="giulio-uva",
        project="claficle",
        job_type="distillation",
        mode="disabled" if cfg.disable_wandb else "online",
        group=script_host,
        config=OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True),
        log_model=False,  # don't log or upload artifacts
    )
    checkpoint_callback = pl.callbacks.ModelCheckpoint(  # save best checkpoints
        dirpath=cfg.model.checkpoint_dir,
        filename=f"adapter_distillation_-v{cfg.seed}",
        monitor=f"{cfg.trainer.train_mode}/val/perplexity",
        mode="min",
        auto_insert_metric_name=False,
        save_on_train_epoch_end=False,
    )
    early_stopping_callback = pl.callbacks.EarlyStopping(
        monitor=f"{cfg.trainer.train_mode}/val/perplexity", patience=4, mode="min"
    )
    lr_monitor_callback = pl.callbacks.LearningRateMonitor(logging_interval="step")
    trainer = pl.Trainer(
        max_epochs=1,
        logger=logger,
        enable_progress_bar=cfg.trainer.progress_bar,
        accelerator=cfg.trainer.accelerator,
        devices=cfg.trainer.devices,
        gradient_clip_algorithm="norm",
        gradient_clip_val=cfg.trainer.clip_grad_norm,
        accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
        val_check_interval=cfg.trainer.val_check_interval,
        log_every_n_steps=cfg.trainer.log_every_n_steps,  # log every batch
        callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor_callback],
        precision=16,
        deterministic=True,
    )
    model.train_mode = cfg.trainer.train_mode
    trainer.fit(model, oscar)


if __name__ == "__main__":
    main()
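
# Example invocation (a sketch only; the actual override keys live in ../conf/distil.yaml,
# which is not shown here -- only the fields accessed above are assumed to exist):
#   python distil.py seed=1 disable_wandb=true trainer.accelerator=gpu trainer.devices=1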