-
Notifications
You must be signed in to change notification settings - Fork 0
/
forward.py
159 lines (140 loc) · 6.64 KB
/
forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
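'''
Export the binarized one-layer autoencoder (binary linear weights plus folded
BatchNorm thresholds) to numpy files, then check how closely a numpy forward
pass reproduces the torch model's output bits.
'''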
import numpy as np
import matplotlib.pyplot as plt
import torch
from models.ae import FCAutoEncoder, FCAutoEncoder1Layer, MLPClassifer
import util
import pickle
# Some model parameters and experiment parameters
num_sc = 415  # first FCAutoEncoder1Layer argument; also the width of the fake input used in the check
num_ctrl = 100  # second FCAutoEncoder1Layer argument
# Load the model (one-layer variant; the two-layer FCAutoEncoder is kept for reference).
# model = FCAutoEncoder(num_sc, num_ctrl)
model = FCAutoEncoder1Layer(num_sc, num_ctrl)
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; everything in this script runs on the CPU.
ckpt = torch.load('checkpoint/ckpt.pth', map_location='cpu')
model.load_state_dict(ckpt['net'])
bin_op = util.BinOp(model)
print(model)
model.eval()
# Presumably binarizes the wrapped model's weights in place (util.BinOp is not
# shown here), so the state_dict exported below holds the binarized weights.
bin_op.binarization()
# The dictionary to store the weights of BNNs in numpy array format.
# detach() drops autograd tracking and cpu() guarantees a host tensor before numpy().
state_dict_np = {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}
# Store the weight dict
print('Store the weight...')
'''
with open('checkpoint/weight_np.pkl', 'wb') as f:
    pickle.dump(state_dict_np, f)
# Two-layers
np.save('checkpoint/encoder.0.linear.weight.npy', state_dict_np['encoder.0.linear.weight'])
np.save('checkpoint/encoder.0.linear.bias.npy', state_dict_np['encoder.0.linear.bias'])
np.save('checkpoint/encoder.1.bn.running_mean.npy', state_dict_np['encoder.1.bn.running_mean'])
np.save('checkpoint/encoder.1.bn.running_var.npy', state_dict_np['encoder.1.bn.running_var'])
np.save('checkpoint/encoder.1.bn.weight.npy', state_dict_np['encoder.1.bn.weight'])
np.save('checkpoint/encoder.1.bn.bias.npy', state_dict_np['encoder.1.bn.bias'])
np.save('checkpoint/encoder.1.linear.weight.npy',state_dict_np['encoder.1.linear.weight'])
np.save('checkpoint/encoder.1.linear.bias.npy', state_dict_np['encoder.1.linear.bias'])
np.save('checkpoint/decoder.0.linear.weight.npy', state_dict_np['decoder.0.linear.weight'])
np.save('checkpoint/decoder.0.linear.bias.npy', state_dict_np['decoder.0.linear.bias'])
np.save('checkpoint/decoder.1.bn.running_mean.npy', state_dict_np['decoder.1.bn.running_mean'])
np.save('checkpoint/decoder.1.bn.running_var.npy', state_dict_np['decoder.1.bn.running_var'])
np.save('checkpoint/decoder.1.bn.weight.npy', state_dict_np['decoder.1.bn.weight'])
np.save('checkpoint/decoder.1.bn.bias.npy', state_dict_np['decoder.1.bn.bias'])
np.save('checkpoint/decoder.1.linear.weight.npy', state_dict_np['decoder.1.linear.weight'])
np.save('checkpoint/decoder.1.linear.bias.npy', state_dict_np['decoder.1.linear.bias'])
'''
# One-layer: for each of the encoder and decoder, export the binary linear weight
# as a {0, 1} matrix and a per-unit threshold that replaces BatchNorm + sign().
# Looping over the part name keeps every output file tied to its own part
# (previously the decoder threshold was written to encoder.bn.thred.npy,
# overwriting the encoder threshold and never producing decoder.bn.thred.npy).
thresholds = {}
for part in ('encoder', 'decoder'):
    bn_prefix = part + '.0.bn.'
    # Zero crossing of BN(x) = (x - mean) / sqrt(var + eps) * weight + bias with eps = 1e-5,
    # solved for x: -bias * sqrt(var + eps) + mean.
    # NOTE(review): bn.weight (gamma) is not used here, so this is only the zero crossing
    # if gamma == 1 (a negative gamma would also flip the comparison). The gamma-aware
    # form would be -bias / weight * sqrt(var + eps) + mean. Confirm against models.ae.
    thred = (-state_dict_np[bn_prefix + 'bias']
             * np.sqrt(state_dict_np[bn_prefix + 'running_var'] + 1e-5)
             + state_dict_np[bn_prefix + 'running_mean'])
    # NOTE(review): if the pre-activation compared against this is an integer count,
    # `x >= t` is equivalent to `x >= ceil(t)`; floor() also admits x == floor(t) < t.
    # Confirm that floor (not ceil) is intended.
    thresholds[part] = np.floor(thred)
    weight_key = part + '.0.linear.weight'
    # (w + 1) / 2 maps {-1, 1} -> {0, 1}. This assumes binarization() left the weights
    # at exactly +/-1 (TODO confirm util.BinOp does not scale them). It updates
    # state_dict_np in place, so later code sees the {0, 1} weights.
    state_dict_np[weight_key] = (state_dict_np[weight_key] + 1.0) / 2.0
    np.save('checkpoint/' + part + '.linear.weight.npy', state_dict_np[weight_key])
    np.save('checkpoint/' + part + '.bn.thred.npy', thresholds[part])
thred_encoder = thresholds['encoder']
thred_decoder = thresholds['decoder']
# Do a forward in torch on a random {-1, 1} input of width num_sc.
# fake_input = torch.ones(1, num_sc)
fake_input = torch.randn(1, num_sc).sign()
# Inference only: no autograd graph is needed, and the result can be converted
# to numpy without going through .data.
with torch.no_grad():
    output_torch = model(fake_input)
# The numpy path works on {0, 1} bits instead of {-1, 1}.
fake_input_np = (fake_input.numpy() + 1.0) / 2.0  # fake_input_np in {0, 1}
# Do a forward in numpy
# Two-layer encoder[0] linear shapes: input (1, 415), (215, 415), (215, )
'''
# Two-layer
x = np.matmul(fake_input_np, state_dict_np['encoder.0.linear.weight'].T) + state_dict_np['encoder.0.linear.bias']
# encoder[1] bn
x = (x - state_dict_np['encoder.1.bn.running_mean']) / np.sqrt(state_dict_np['encoder.1.bn.running_var'] + 1e-5) * state_dict_np['encoder.1.bn.weight'] + state_dict_np['encoder.1.bn.bias']
# => {-1, 1}
x = np.sign(x)
# encoder[1] linear
x = np.matmul(x, state_dict_np['encoder.1.linear.weight'].T) + state_dict_np['encoder.1.linear.bias']
# get the encoding bits in (45,)
x = np.sign(x)
# decoder[0] linear
x = np.matmul(x, state_dict_np['decoder.0.linear.weight'].T) + state_dict_np['decoder.0.linear.bias']
# decoder[1] bn
x = (x - state_dict_np['decoder.1.bn.running_mean']) / np.sqrt(state_dict_np['decoder.1.bn.running_var'] + 1e-5) * state_dict_np['decoder.1.bn.weight'] + state_dict_np['decoder.1.bn.bias']
x = np.sign(x)
# decoder[1] linear
x = np.matmul(x, state_dict_np['decoder.1.linear.weight'].T) + state_dict_np['decoder.1.linear.bias']
# check the correctness
# print(np.sum(np.abs(output_torch.data.numpy() - x))/np.prod(x.shape))
# This is the output of BNN
x = np.sign(x)
'''
######
# One-layer
print('Check the correctness')
# encoder linear (no bias). Assuming the exported weights are in {0, 1}, each entry
# counts the positions where both the input bit and the weight bit are 1.
x = np.matmul(fake_input_np, state_dict_np['encoder.0.linear.weight'].T)
# encoder bn + sign(), folded into a per-unit threshold -> {0, 1} bits.
# Gamma-aware alternative for the threshold:
# thred_encoder = -state_dict_np['encoder.0.bn.bias']/state_dict_np['encoder.0.bn.weight'] * np.sqrt(state_dict_np['encoder.0.bn.running_var'] + 1e-5) + state_dict_np['encoder.0.bn.running_mean']
x = (x >= thred_encoder).astype(float)
# Alternative: flip bits where bn.weight is negative.
# x = (((x * 2.0 - 1.0)* np.sign(state_dict_np['encoder.0.bn.weight'])) + 1.0) / 2.0
# decoder linear (no bias)
x = np.matmul(x, state_dict_np['decoder.0.linear.weight'].T)
# decoder bn + sign(), folded into a per-unit threshold -> {0, 1} bits.
# Gamma-aware alternative for the threshold:
# thred_decoder = -state_dict_np['decoder.0.bn.bias']/state_dict_np['decoder.0.bn.weight'] * np.sqrt(state_dict_np['decoder.0.bn.running_var'] + 1e-5) + state_dict_np['decoder.0.bn.running_mean']
x = (x >= thred_decoder).astype(float)
# Alternative: flip bits where bn.weight is negative.
# x = (((x * 2.0 - 1.0)* np.sign(state_dict_np['decoder.0.bn.weight'])) + 1.0) / 2.0
# Check the correctness: mean absolute difference between the numpy bits and the
# torch output signs mapped {-1, 1} -> {0, 1}, i.e. the fraction of mismatched bits
# (a torch output of exactly 0 maps to 0.5 and counts as half a mismatch).
print('Error', np.mean(np.abs((output_torch.sign().numpy() + 1.0) / 2.0 - x)))
# Plot:
# mlb = np.load('data/mlb_cell.npy')
# mlb = (np.abs(mlb).sum(axis=2) != 0).astype(float)
# sc_counts = np.zeros(num_sc)
# for row in mlb:
# for (eid, element) in enumerate(row):
# if element == 1:
# sc_counts[eid] += 1
# sc_weights = np.sum(state_dict_np['decoder.0.linear.weight'], axis=1)
# # print(sc_counts.shape)
# # print(sc_weights.shape)
# # print(thred_decoder.shape)
# # exit()
# x = np.arange(num_sc)
# # Do the normalization
# y_multi = [sc_counts/np.sum(sc_counts), sc_weights/np.sum(sc_weights), thred_decoder/np.sum(thred_decoder)]
# labels = ['Frequency', 'A weight', 'Threshold']
# fig, ax = plt.subplots()
# ax.bar(x-0.2, y_multi[0], width=0.2, label=labels[0])
# ax.bar(x, y_multi[1], width=0.2, label=labels[1])
# ax.bar(x+0.2, y_multi[2], width=0.2, label=labels[2])
# ax.legend()
# plt.savefig('figs/hist_all.pdf')