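"""01_hbp_single_layer_mnist.py: train a single linear layer on MNIST with HBP.

A single HBPLinear layer (softmax regression) is trained with the CG-Newton
optimizer from bpexts: besides the usual gradient backward pass, an
approximate Hessian of the loss is propagated back through the network and
used for a Newton-style parameter update solved by conjugate gradients.
"""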
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from bpexts.hbp.crossentropy import HBPCrossEntropyLoss
from bpexts.hbp.linear import HBPLinear
from bpexts.optim.cg_newton import CGNewton
from bpexts.utils import set_seeds
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
set_seeds(0)
batch_size = 500
# download directory
data_dir = "~/tmp/MNIST"
# training set loader
train_set = torchvision.datasets.MNIST(
    root=data_dir, train=True, transform=transforms.ToTensor(), download=True
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True
)
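# (transforms.ToTensor converts the images to float tensors scaled to [0, 1];
# no further normalization is applied in this example)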
# layer parameters: flattened 28x28 MNIST images mapped to 10 class logits
in_features = 784
out_features = 10
bias = True
# single linear layer; HBPLinear behaves like torch.nn.Linear in the forward
# and backward pass, but additionally supports Hessian backpropagation
model = HBPLinear(in_features=in_features, out_features=out_features, bias=bias)
# load to device
model.to(device)
print(model)
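# cross-entropy loss with Hessian support: HBPCrossEntropyLoss behaves like
# torch.nn.CrossEntropyLoss, but can additionally provide the Hessian of the
# loss w.r.t. the model output (step 3 in the training loop below)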
loss_func = HBPCrossEntropyLoss()
# learning rate
lr = 0.1
# regularization
alpha = 0.02
# convergence criteria for CG
cg_maxiter = 50
cg_atol = 0.0
cg_tol = 0.1
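# (assumed convention, mirroring scipy.sparse.linalg.cg: CG terminates after
# cg_maxiter iterations, or once the residual satisfies
# norm(residual) <= max(cg_tol * norm(b), cg_atol))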
# construct the optimizer
optimizer = CGNewton(
    model.parameters(),
    lr=lr,
    alpha=alpha,
    cg_atol=cg_atol,
    cg_tol=cg_tol,
    cg_maxiter=cg_maxiter,
)
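# Hedged sketch of the update rule (the exact damping convention is an
# assumption, not taken from the bpexts source): CGNewton approximately
# solves the regularized Newton system
#
#     [(1 - alpha) * H + alpha * I] d = -g
#
# with conjugate gradients, where H is the backpropagated curvature block of
# a parameter and g its gradient, then applies theta <- theta + lr * d.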
# use the generalized Gauss-Newton (GGN): setting the modules' second-order
# terms to zero during Hessian backpropagation yields the GGN curvature
modify_2nd_order_terms = "zero"
# train for two epochs
num_epochs = 2
# log some metrics
train_epoch = []
batch_loss = []
batch_acc = []
samples = 0
samples_per_epoch = 60000.0  # size of the MNIST training set
for epoch in range(num_epochs):
    iters = len(train_loader)
    for i, (images, labels) in enumerate(train_loader):
        # reshape and load to device
        images = images.reshape(-1, in_features).to(device)
        labels = labels.to(device)
        # 1) forward pass
        outputs = model(images)
        loss = loss_func(outputs, labels)
        # set gradients to zero
        optimizer.zero_grad()
        # Hessian backpropagation and backward pass
        # 2) compute gradients
        loss.backward()
        # 3) batch average of Hessian of loss w.r.t. model output
        output_hessian = loss_func.batch_summed_hessian(loss, outputs)
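        # (background, not part of the original script: for softmax
        # cross-entropy, the Hessian of the per-sample loss w.r.t. the
        # logits is diag(p) - p p^T with p = softmax(outputs); the seed
        # Hessian used below is assumed to be the batch average of these)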
        # 4) propagate Hessian back through the graph
        model.backward_hessian(
            output_hessian, modify_2nd_order_terms=modify_2nd_order_terms
        )
        # 5) second-order optimization step
        optimizer.step()
        # compute statistics
        total = labels.size(0)
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / total
        # update lists
        samples += total
        train_epoch.append(samples / samples_per_epoch)
        batch_loss.append(loss.item())
        batch_acc.append(accuracy)
        # print every 5 iterations
        if i % 5 == 0:
            print(
                "Epoch [{}/{}], Iter. [{}/{}], Loss: {:.4f}, Acc.: {:.4f}".format(
                    epoch + 1, num_epochs, i + 1, iters, loss.item(), accuracy
                )
            )
plt.figure(figsize=(7, 3))
# plot batch loss
plt.subplot(121)
plt.plot(train_epoch, batch_loss, color="darkorange")
plt.xlabel("epoch")
plt.ylabel("batch loss")
# plot batch accuracy
plt.subplot(122)
plt.plot(train_epoch, batch_acc, color="darkblue")
plt.xlabel("epoch")
plt.ylabel("batch accuracy")
# save plot
plt.tight_layout()
plt.savefig("01_hbp_single_layer_mnist_metrics.png")
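# Running this script downloads MNIST to data_dir (if not already present),
# prints batch loss/accuracy during training, and writes the metrics figure
# to 01_hbp_single_layer_mnist_metrics.png in the working directory.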