-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_sample_torch.py
62 lines (53 loc) · 2.11 KB
/
train_sample_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Train a sample network on the Tiny ImageNet dataset. Your final goal is to
improve on this model's performance by replacing large portions of the code.
We provide this model so you can test the full pipeline and validate your own
code submission.
"""
import pathlib
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from model import Net
from torch import nn
def _make_train_loader(data_dir, batch_size, pin_memory=True):
    """Build the shuffled training set and DataLoader from ``data_dir / 'train'``.

    Returns a ``(train_set, train_loader)`` tuple.
    """
    data_transforms = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): ToTensor already scales pixels to [0, 1], so dividing
        # by sqrt(255) (~15.97) leaves inputs in roughly [0, 0.063]. Kept
        # unchanged because evaluation code presumably applies the same
        # transform -- confirm before changing it.
        transforms.Normalize((0, 0, 0), tuple(np.sqrt((255, 255, 255)))),
    ])
    train_set = torchvision.datasets.ImageFolder(data_dir / 'train', data_transforms)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=pin_memory,
    )
    return train_set, train_loader


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``, printing running accuracy in place."""
    model.train()
    train_total, train_correct = 0, 0
    num_batches = len(loader)
    for idx, (inputs, targets) in enumerate(loader):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        predicted = outputs.argmax(dim=1)
        train_total += targets.size(0)
        train_correct += predicted.eq(targets).sum().item()
        # "\r" rewinds to the start of the line so progress updates in place.
        # Using idx + 1 makes the final batch report 100%.
        print("\r", end='')
        print(f'training {100 * (idx + 1) / num_batches:.2f}%: {train_correct / train_total:.3f}', end='')
    # End the in-place progress line so subsequent output starts on a new line.
    print()


def main():
    """Train the sample ``Net`` on Tiny ImageNet and save it to ``latest.pt``.

    Reads images from ``./data/tiny-imagenet-200/train`` through
    ``ImageFolder``. The checkpoint written is ``{'net': state_dict}`` with
    all tensors on the CPU.
    """
    data_dir = pathlib.Path('./data/tiny-imagenet-200')
    # Counts every JPEG anywhere under data_dir (all splits), not just train/.
    image_count = len(list(data_dir.glob('**/*.JPEG')))
    print('Discovered {} images'.format(image_count))

    batch_size = 32
    im_height = 64
    im_width = 64
    num_epochs = 1

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Pinned host memory only helps host-to-GPU copies.
    train_set, train_loader = _make_train_loader(
        data_dir, batch_size, pin_memory=device.type == 'cuda')

    # Size the classifier from ImageFolder's own class list. It contains class
    # directories only, so it always matches the label indices that
    # ImageFolder assigns.
    model = Net(len(train_set.classes), im_height, im_width).to(device)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()

    for _ in range(num_epochs):
        _train_one_epoch(model, train_loader, optimizer, criterion, device)

    # Save CPU tensors so the checkpoint loads anywhere without map_location.
    torch.save({
        'net': model.cpu().state_dict(),
    }, 'latest.pt')


if __name__ == '__main__':
    main()