-
Notifications
You must be signed in to change notification settings - Fork 1
/
training_config.py
71 lines (56 loc) · 2.13 KB
/
training_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch,os,sys,datetime,time
from tqdm import tqdm
import torch.optim as optim
# Best validation accuracy (in percent) observed so far. `test` reads and
# overwrites this through `global best_acc` and only checkpoints on improvement.
best_acc = 0
def train(net, trainloader, optim, criterion, epoch, device):
    """Run one training epoch over ``trainloader`` and print loss/accuracy.

    Args:
        net: Model. It is called as ``outputs, _ = net(inputs)``, so it must
            return a 2-tuple. The first element is passed to ``criterion`` and
            arg-maxed over dim 1 to get predictions. Presumably these are
            (batch, num_classes) logits; confirm against the model.
        trainloader: Iterable of ``(inputs, targets)`` batches. Batches are
            counted while iterating, so ``__len__`` is not required.
        optim: Optimizer; ``zero_grad()`` and ``step()`` are called per batch.
            The parameter name shadows the module-level ``optim`` import inside
            this function and is kept for caller compatibility.
        criterion: Loss function, called as ``criterion(outputs, targets)``.
        epoch: Epoch index, printed as ``epoch + 1`` (so presumably zero-based).
        device: Device that each batch is moved to.

    Prints the mean per-batch loss and the accuracy in percent. An empty
    loader reports 0 for both instead of raising ZeroDivisionError.
    """
    print("Training")
    net.train()
    running_loss = 0.0
    num_batches = 0
    num_seen = 0
    num_correct = 0
    for inputs, targets in tqdm(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optim.zero_grad()
        outputs, _ = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optim.step()
        running_loss += loss.item()
        num_batches += 1
        # Predictions need no autograd history; detach() is the supported
        # replacement for reaching into `.data`.
        _, predicted = torch.max(outputs.detach(), 1)
        num_correct += (predicted == targets).sum().item()
        num_seen += targets.size(0)
    # Guard both divisions so an empty loader does not crash the epoch report.
    avg_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = num_correct * 100 / num_seen if num_seen else 0.0
    print("Epoch: [{}] loss: [{:.2f}] Accuracy [{:.2f}] ".format(
        epoch + 1, avg_loss, accuracy))
def test(net, testloader, optim, criterion, epoch, device, filename):
    """Evaluate ``net`` on ``testloader`` and checkpoint it when accuracy improves.

    Args:
        net: Model. It is called as ``outputs, _ = net(inputs)``; the first
            element is passed to ``criterion`` and arg-maxed over dim 1.
        testloader: Iterable of ``(inputs, targets)`` batches.
        optim: Unused; kept so the signature mirrors ``train``.
        criterion: Loss function, called as ``criterion(outputs, targets)``.
        epoch: Epoch index. It is printed as ``epoch + 1`` and stored unchanged
            in the checkpoint.
        device: Device that each batch is moved to.
        filename: Prefix for the checkpoint file
            ``./checkpoint/<filename>model.t7``.

    Returns:
        The module-level ``best_acc`` after this call, i.e. the best validation
        accuracy (percent) seen so far. An empty loader yields an accuracy of
        0, which never beats ``best_acc``, so nothing is saved.
    """
    global best_acc
    print("validation")
    net.eval()
    loss_sum = 0.0
    num_batches = 0
    num_seen = 0
    num_correct = 0
    # eval() only switches layer behaviour (e.g. dropout/batch-norm).
    # no_grad() is what stops autograd from recording a graph per batch.
    with torch.no_grad():
        for inputs, targets in tqdm(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = net(inputs)
            loss_sum += criterion(outputs, targets).item()
            num_batches += 1
            _, predicted = torch.max(outputs, 1)
            num_seen += targets.size(0)
            num_correct += (predicted == targets).sum().item()
    # Report the mean per-batch loss, the same statistic train() prints, so the
    # value does not scale with the number of validation batches.
    avg_loss = loss_sum / num_batches if num_batches else 0.0
    acc = 100.0 * num_correct / num_seen if num_seen else 0.0
    print("\nValidation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" % (epoch + 1, avg_loss, acc))
    # Save a checkpoint only when this is the best model so far.
    if acc > best_acc:
        print('Saving Best model...\t\t\tTop1 = %.2f%%' % (acc))
        # NOTE(review): this pickles the whole `net` object rather than its
        # state_dict, so loading it needs the model class importable (and
        # weights_only=False on recent torch.load). Confirm this is intended.
        state = {
            'net': net,
            'acc': acc,
            'epoch': epoch,
        }
        save_point = './checkpoint/'
        # One idempotent call; exist_ok avoids a check-then-create race.
        os.makedirs(save_point, exist_ok=True)
        torch.save(state, save_point + filename + 'model.t7')
        best_acc = acc
    return best_acc