-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
trainer.py
86 lines (68 loc) · 2.91 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import tqdm
import torch
from Core.Functional import accuracy
from Core.Functional import AverageMeter
class BackboneTrainer():
    """Train and evaluate backbone models for one epoch at a time.

    Both methods expect a data loader yielding ``(images, target)`` batches
    and a model whose forward pass returns a ``(logits, extra)`` tuple; only
    the logits are used for the loss and accuracy computation.
    """

    def train(self,
              train_loader,
              model,
              criterion,
              optimizer: torch.optim.Optimizer,
              device: torch.device):
        """Run one epoch of training.

        :param train_loader: Data loader yielding ``(images, target)`` batches.
        :param model: Model to be trained; its forward returns ``(logits, extra)``.
        :param criterion: Loss criterion module.
        :param optimizer: A torch optimizer object.
        :param device: Device the model is on.
        :return: average of top-1, top-5, and loss on current epoch.
        """
        losses = AverageMeter("Loss", ":.3f")
        top1 = AverageMeter("Acc@1", ":6.2f")
        top5 = AverageMeter("Acc@5", ":6.2f")
        model.train()
        # Iterate the loader directly: the original enumerate() index was unused.
        for images, target in tqdm.tqdm(
            train_loader, ascii=True, total=len(train_loader)
        ):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output, _ = model(images)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, top_k=(1, 5))
            # Meters are weighted by batch size so partial final batches
            # do not skew the epoch averages.
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return top1.avg, top5.avg, losses.avg

    def validate(self,
                 val_loader,
                 model,
                 criterion,
                 device: torch.device):
        """Run validation.

        :param val_loader: Data loader yielding ``(images, target)`` batches.
        :param model: Model to be evaluated; its forward returns ``(logits, extra)``.
        :param criterion: Loss criterion module.
        :param device: Device the model is on.
        :return: average of top-1, top-5, and loss on current epoch.
        """
        losses = AverageMeter("Loss", ":.3f")
        top1 = AverageMeter("Acc@1", ":6.2f")
        top5 = AverageMeter("Acc@5", ":6.2f")
        model.eval()
        # No gradients are needed for evaluation; disable autograd to save
        # memory and compute.
        with torch.no_grad():
            for images, target in tqdm.tqdm(
                val_loader, ascii=True, total=len(val_loader)
            ):
                images = images.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                output, _ = model(images)
                loss = criterion(output, target)
                acc1, acc5 = accuracy(output, target, top_k=(1, 5))
                batch_size = images.size(0)
                losses.update(loss.item(), batch_size)
                top1.update(acc1.item(), batch_size)
                top5.update(acc5.item(), batch_size)
        return top1.avg, top5.avg, losses.avg