"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
References:
0) Karpathy's code from https://github.com/karpathy/minGPT
"""
import time
from collections import defaultdict

import torch
from torch.utils.data.dataloader import DataLoader

from dataset import DatasetOf24Game
from model import GPT
from utils import ConfigNode


class Trainer:
    @staticmethod
    def get_default_config():
        C = ConfigNode()
        # device to train on ('auto' picks CUDA when available)
        C.device = 'auto'
        # training loop parameters
        C.max_iters = None
        C.batch_size = 64
        # optimizer parameters
        C.learning_rate = 3e-4
        C.betas = (0.9, 0.95)
        C.weight_decay = 0.1  # only applied on matmul weights
        C.grad_norm_clip = 1.0  # gradient clipping threshold used in run()
        return C
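
    # A typical override pattern (a sketch; attribute assignment is the only
    # ConfigNode API assumed here):
    #
    #   config = Trainer.get_default_config()
    #   config.max_iters = 2000   # run() loops forever unless this is set
    #   config.batch_size = 32
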
    def __init__(self, config: ConfigNode, model: GPT, train_dataset: DatasetOf24Game):
        self.config = config
        self.model = model
        self.optimizer = None  # created lazily in run() via model.configure_optimizers
        self.train_dataset = train_dataset
        self.callbacks = defaultdict(list)

        # determine the device we'll train on
        if config.device == 'auto':
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = config.device
        self.model = self.model.to(self.device)
        print("running on device", self.device)

        # variables that will be assigned to the trainer later, for logging etc.
        self.loss = None
        self.iter_num = 0
        self.iter_time = 0.0
        self.iter_dt = 0.0

    def add_callback(self, onevent: str, callback):
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)
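
    # Example (a sketch): run() fires 'on_batch_end' after every optimizer
    # step, so a periodic logging hook can be registered as below; the
    # callback name and message format are illustrative, not part of the API:
    #
    #   def batch_end_callback(trainer):
    #       if trainer.iter_num % 100 == 0:
    #           print(f"iter {trainer.iter_num}: loss {trainer.loss.item():.4f}, "
    #                 f"{trainer.iter_dt * 1000:.1f} ms/iter")
    #
    #   trainer.add_callback('on_batch_end', batch_end_callback)
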
    def run(self) -> None:
        model, config = self.model, self.config

        # setup the optimizer
        self.optimizer = model.configure_optimizers(config)

        # setup the dataloader
        train_loader = DataLoader(
            self.train_dataset,
            sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)),
            shuffle=False,
            pin_memory=True,
            batch_size=config.batch_size,
        )
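
        # Design note: the RandomSampler above draws with replacement and a
        # very large num_samples, turning the loader into an effectively
        # infinite stream of batches; shuffle=False because shuffling is
        # delegated to the sampler, and pin_memory=True speeds up
        # host-to-GPU copies.
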
        model.train()
        self.iter_num = 0
        self.iter_time = time.time()
        data_iter = iter(train_loader)
        while True:

            # fetch the next batch (x, y) and re-init iterator if needed
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            batch = [t.to(self.device) for t in batch]
            x, y = batch

            # forward the model
            logits, self.loss = model(x, y)

            # backprop and update the parameters
            model.zero_grad(set_to_none=True)
            self.loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
            self.optimizer.step()

            self.trigger_callbacks('on_batch_end')
            self.iter_num += 1
            tnow = time.time()
            self.iter_dt = tnow - self.iter_time
            self.iter_time = tnow

            # termination conditions
            if config.max_iters is not None and self.iter_num >= config.max_iters:
                break
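

# Minimal end-to-end usage sketch. The GPT and DatasetOf24Game constructors
# are defined elsewhere in this repo; the calls below follow the minGPT
# convention and are assumptions, not this repo's confirmed API:
#
#   if __name__ == "__main__":
#       train_dataset = DatasetOf24Game()        # assumed no-arg constructor
#       model = GPT(GPT.get_default_config())    # assumed, minGPT-style
#
#       config = Trainer.get_default_config()
#       config.max_iters = 1000                  # otherwise run() never stops
#       Trainer(config, model, train_dataset).run()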