-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
56 lines (48 loc) · 1.88 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import pathlib
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from resnet import resnet152
from torch import nn
def _build_train_loader(data_dir, batch_size):
    """Return a shuffled DataLoader over the ImageFolder at ``data_dir / 'train'``.

    Each sub-directory of ``data_dir / 'train'`` is treated as one class
    (torchvision ``ImageFolder`` convention).

    Args:
        data_dir: ``pathlib.Path`` of the dataset root.
        batch_size: number of images per batch.
    """
    data_transforms = transforms.Compose([
        # ToTensor converts an RGB PIL image to a float tensor scaled to [0, 1].
        transforms.ToTensor(),
        # NOTE(review): after ToTensor, dividing by sqrt(255) (~15.97) squeezes
        # inputs into roughly [0, 0.063]. Kept unchanged so checkpoints stay
        # consistent with any evaluation code that reproduces this transform;
        # confirm whether per-channel dataset mean/std was intended instead.
        transforms.Normalize((0, 0, 0), tuple(np.sqrt((255, 255, 255)))),
    ])
    train_set = torchvision.datasets.ImageFolder(data_dir / 'train', data_transforms)
    return torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        # Page-locked host memory only helps when copying to a CUDA device.
        pin_memory=torch.cuda.is_available(),
    )


def _train_one_epoch(model, loader, optim, criterion, device):
    """Run one optimisation pass over ``loader``, updating ``model`` in place.

    Prints a single self-overwriting progress line (percent of batches done
    and running training accuracy) and ends it with a newline.
    """
    model.train()
    num_batches = len(loader)
    train_total, train_correct = 0, 0
    for idx, (inputs, targets) in enumerate(loader):
        # non_blocking overlaps the copy with compute when memory is pinned.
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optim.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optim.step()

        predicted = outputs.argmax(dim=1)
        train_total += targets.size(0)
        train_correct += predicted.eq(targets).sum().item()
        # '\r' rewinds to the start of the line so progress overwrites itself;
        # flush because the line has no newline and stdout may be
        # line-buffered. (idx + 1) so the final batch reports 100%.
        print(f'\rtraining {100 * (idx + 1) / num_batches:.2f}%: '
              f'{train_correct / train_total:.3f}', end='', flush=True)
    # Terminate the progress line so subsequent output starts on a new line.
    print()


def main():
    """Train the ``resnet152`` model on ./data/tiny-imagenet-200/train.

    After every epoch the weights are written to ``latest.pt`` as
    ``{'net': state_dict}``, overwriting the previous checkpoint.
    """
    data_dir = pathlib.Path('./data/tiny-imagenet-200')
    # Counts every JPEG under data_dir (all splits), not just train/.
    image_count = sum(1 for _ in data_dir.glob('**/*.JPEG'))
    print('Discovered {} images'.format(image_count))

    batch_size = 32
    num_epochs = 1
    # Use the GPU when present; otherwise train on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader = _build_train_loader(data_dir, batch_size)

    # NOTE(review): resnet152 is built without a class count. If the local
    # `resnet` module mirrors torchvision's, the head has 1000 outputs while
    # ImageFolder finds len(train_loader.dataset.classes) classes; consider
    # passing num_classes=... if the local factory accepts it.
    model = resnet152(pretrained=False).to(device)
    optim = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()

    for _ in range(num_epochs):
        _train_one_epoch(model, train_loader, optim, criterion, device)
        # Copy tensors to the CPU so the checkpoint loads on machines
        # without CUDA.
        torch.save({
            'net': {name: tensor.cpu() for name, tensor in model.state_dict().items()},
        }, 'latest.pt')
if __name__ == "__main__":
    # Start training only when run as a script, not when imported
    # (also required for DataLoader worker processes on spawn-based platforms).
    main()