# utils.py (forked from jack-willturner/deep-compression)
from __future__ import print_function
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
# project-local configuration module
import config

device = config.DEVICE

# records the top-1 validation error after each call to validate()
error_history = []
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
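
# Usage sketch (illustrative, not part of the original module): track a running
# average of per-batch loss, weighting each batch by its size.
#
#   meter = AverageMeter()
#   meter.update(0.9, n=128)  # batch 1: mean loss 0.9 over 128 samples
#   meter.update(0.7, n=64)   # batch 2: mean loss 0.7 over 64 samples
#   print(meter.avg)          # (0.9*128 + 0.7*64) / 192 = 0.833...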
def get_cifar_loaders(data_loc=config.DATA, batch_size=128, cutout=True, n_holes=1, length=16):
    """Build CIFAR-10 train and test DataLoaders.

    Args:
        data_loc (str): location of the CIFAR-10 dataset on disk.
        cutout (bool): if True, randomly black out n_holes square regions
            in every training image (see the Cutout transform below).
        n_holes (int): number of squares to black out; only used when cutout is True.
        length (int): side length, in pixels, of each square.
    """
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
if cutout:
transform_train.transforms.append(Cutout(n_holes=n_holes, length=length))
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])
train_set = torchvision.datasets.CIFAR10(root=data_loc, train=True, download=False, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
test_set = torchvision.datasets.CIFAR10(root=data_loc, train=False, download=False, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1)
return train_loader, test_loader
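
# Usage sketch (assumes CIFAR-10 is already on disk under config.DATA, since
# the loaders above are built with download=False):
#
#   train_loader, test_loader = get_cifar_loaders(batch_size=128)
#   images, labels = next(iter(train_loader))
#   print(images.shape)  # torch.Size([128, 3, 32, 32])
#   print(labels.shape)  # torch.Size([128])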
def load_model(model, sd, old_format=False):
    """Load a pretrained checkpoint into `model`.

    Args:
        sd (str): name of the checkpoint file to load ("sd" for state dict).
        old_format (bool): set when the checkpoint was saved without masks
            and/or with differently named keys.
    Returns:
        model: the model with the checkpoint weights loaded.
        sd: the raw contents of the checkpoint file.
    """
    sd = torch.load('checkpoints/%s.t7' % sd, map_location='cpu')
    new_sd = model.state_dict()
    if 'state_dict' in sd.keys():
        old_sd = sd['state_dict']
    else:
        old_sd = sd['net']
    # copy the loaded weights into the new model's state dict
    if old_format:
        # the checkpoint has no masks and/or is named incorrectly, so match
        # its keys positionally against the non-mask keys of the new model
        keys_without_masks = [k for k in new_sd.keys() if 'mask' not in k]
        for old_k, new_k in zip(old_sd.keys(), keys_without_masks):
            new_sd[new_k] = old_sd[old_k]
    else:
        new_names = [v for v in new_sd]
        old_names = [v for v in old_sd]
        for i, j in enumerate(new_names):
            if 'mask' not in j:
                new_sd[j] = old_sd[old_names[i]]
    try:
        model.load_state_dict(new_sd)
    except RuntimeError:
        # key mismatch: retry, also skipping num_batches_tracked buffers
        new_sd = model.state_dict()
        old_sd = sd['state_dict']
        k_new = [k for k in new_sd.keys() if 'mask' not in k]
        k_new = [k for k in k_new if 'num_batches_tracked' not in k]
        for o, n in zip(old_sd.keys(), k_new):
            new_sd[n] = old_sd[o]
        model.load_state_dict(new_sd)
    return model, sd
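
# Usage sketch (the checkpoint name 'resnet18' and the build_model() helper
# are hypothetical; note that load_model reads from the hard-coded
# 'checkpoints/' directory, while validate() below saves under config.CP):
#
#   model = build_model()                      # any nn.Module matching the checkpoint
#   model, sd = load_model(model, 'resnet18')  # pass the name, not the path
#   print(sd['epoch'], sd['error_history'][-1])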
def get_error(output, labels, topk=(1,)):
"""Computes the error@k for the specified values of k"""
max_k = max(topk)
batch_size = labels.size(0)
_, pred = output.topk(max_k, 1, True, True)
pred = pred.t()
correct = pred.eq(labels.view(1, -1).expand_as(pred))
res = []
    for k in topk:
        # reshape rather than view: correct[:k] may be non-contiguous for k > 1
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(100.0 - correct_k.mul_(100.0 / batch_size))
return res
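
# Worked example (illustrative): with a batch of two samples,
#
#   output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # predicted classes: 1, 0
#   labels = torch.tensor([1, 1])                    # true classes: 1, 1
#   err1, = get_error(output, labels, topk=(1,))
#   print(err1)  # tensor([50.]) -- one of the two predictions is wrong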
def get_no_params(net):
    """Count the nonzero weights in all convolutional layers."""
    params = net.state_dict()  # iterate named tensors, not the module itself
    total = 0
    for p in params:
        if 'conv' in p:
            total += torch.sum(params[p] != 0)
    return total
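
# Usage sketch (illustrative): compare the surviving conv-weight count before
# and after pruning with sparsify() below.
#
#   before = get_no_params(model)
#   model = sparsify(model, prune_rate=50.)
#   after = get_no_params(model)  # roughly half of `before`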
def train(model, train_loader, criterion, optimizer):
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
model.train()
for i, (images, labels) in enumerate(train_loader):
images, labels = images.to(device), labels.to(device)
output = model(images)
loss = criterion(output, labels)
        # compute the top-1 and top-5 error
err1, err5 = get_error(output.detach(), labels, topk=(1, 5))
losses.update(loss.item(), images.size(0))
top1.update(err1.item(), images.size(0))
top5.update(err5.item(), images.size(0))
optimizer.zero_grad()
loss.backward()
optimizer.step()
def validate(model, epoch, val_loader, criterion, checkpoint=None):
    """
    Args:
        checkpoint (str): name under which to save a checkpoint file;
            if None (the default), nothing is saved.
    """
    global error_history
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    model.eval()
    with torch.no_grad():  # no gradients are needed during evaluation
        for i, (images, labels) in enumerate(val_loader):
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            loss = criterion(output, labels)
            err1, err5 = get_error(output, labels, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(err1.item(), images.size(0))
            top5.update(err5.item(), images.size(0))
error_history.append(top1.avg)
if checkpoint:
state = {
'net': model.state_dict(),
'masks': [w for name, w in model.named_parameters() if 'mask' in name],
'epoch': epoch,
'error_history': error_history,
}
torch.save(state, config.CP + '/%s.t7' % checkpoint)
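
# Usage sketch (the model and hyperparameters are hypothetical; error_history
# accumulates the top-1 validation error after every call to validate):
#
#   criterion = nn.CrossEntropyLoss()
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
#   for epoch in range(num_epochs):
#       train(model, train_loader, criterion, optimizer)
#       validate(model, epoch, test_loader, criterion, checkpoint='baseline')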
def finetune(model, train_loader, criterion, optimizer, steps=100):
    model.train()
    data_iter = iter(train_loader)
    for i in range(steps):
        try:
            images, labels = next(data_iter)  # Python 3: next(it), not it.next()
        except StopIteration:
            # the loader is exhausted; restart it and keep going
            data_iter = iter(train_loader)
            images, labels = next(data_iter)
images, labels = images.to(device), labels.to(device)
output = model(images)
loss = criterion(output, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
def expand_model(model, layers=torch.Tensor()):
    """Recursively flatten the weights of every unmasked Conv2d into one 1-D tensor."""
    for layer in model.children():
        if len(list(layer.children())) > 0:
            layers = expand_model(layer, layers)  # recurse down to the leaf modules
        else:
            if isinstance(layer, nn.Conv2d) and 'mask' not in layer._get_name():
                layers = torch.cat((layers.view(-1), layer.weight.view(-1)))
    return layers
def calculate_threshold(model, rate):
    """Compute the weight-magnitude threshold for pruning.

    Args:
        rate (float): pruning percentage, between 0 and 100.
    Returns:
        float: the threshold; weights whose magnitude is below it are pruned.
    """
    empty = torch.Tensor()  # empty tensor onto which the conv weights are concatenated
    if torch.cuda.is_available():
        empty = empty.cuda()
    pre_abs = expand_model(model, empty)  # all unpruned conv weights as one 1-D tensor
    weights = torch.abs(pre_abs)
    # take the rate-th percentile of the absolute weights as the threshold
    return np.percentile(weights.detach().cpu().numpy(), rate)
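
# Worked example (illustrative): if the flattened conv weights were
# [-0.4, -0.1, 0.05, 0.2], the magnitudes are [0.4, 0.1, 0.05, 0.2] and
# calculate_threshold(model, 50) would return their 50th percentile, 0.15;
# pruning at that threshold zeroes the -0.1 and 0.05 weights.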
def sparsify(model, prune_rate=50.):
    """Prune the model at the given rate.

    Args:
        prune_rate (float): percentage of weights to prune.
    Returns:
        The pruned model.
    """
    threshold = calculate_threshold(model, prune_rate)
    try:
        model.__prune__(threshold)
    except AttributeError:
        # the model is wrapped (e.g. in nn.DataParallel), so prune the inner module
        model.module.__prune__(threshold)
    return model
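
# Usage sketch (the iterative prune-and-finetune loop this module supports;
# the rates and step count are illustrative):
#
#   for rate in [10., 30., 50., 70.]:
#       model = sparsify(model, prune_rate=rate)
#       finetune(model, train_loader, criterion, optimizer, steps=100)
#       validate(model, 0, test_loader, criterion)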
class Cutout(object):
    """Randomly black out a number of square regions in an image.

    Args:
        n_holes (int): number of square regions to cut out.
        length (int): side length, in pixels, of each square.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): image tensor of shape (C, H, W).
        Returns:
            Tensor: the image with n_holes square regions zeroed out.
        """
h = img.size(1)
w = img.size(2)
mask = np.ones((h, w), np.float32)
for n in range(self.n_holes):
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img = img * mask
return img
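
# Usage sketch (illustrative): Cutout is appended to the training transforms in
# get_cifar_loaders above, but it can also be applied to any (C, H, W) tensor:
#
#   cutout = Cutout(n_holes=1, length=16)
#   img = torch.rand(3, 32, 32)
#   out = cutout(img)  # same shape, with one 16x16 region zeroed
#                      # (smaller if the square is clipped at the border)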