-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
executable file
·206 lines (175 loc) · 8.06 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import os
import sys
import robustbench as rb
import data
#from autopgd_train import apgd_train
import utils
from model_zoo.fast_models import PreActResNet18
from model_zoo.wide_resnet import WideResNet
import other_utils
import autoattack
from autopgd_train import apgd_train
# Default perturbation budgets, indexed first by dataset name and then by
# threat model ('Linf', 'L2', 'L1'). The script looks these up for every norm
# listed in --l_norms; a value passed through --l_eps overwrites one entry.
eps_dict = {
    'cifar10': {
        'Linf': 8. / 255.,
        'L2': 0.5,
        'L1': 12.,
    },
    'imagenet': {
        'Linf': 4. / 255.,
        'L2': 2.,
        'L1': 255.,
    },
}
def eval_single_norm(model, x, y, norm='Linf', eps=8. / 255., bs=1000,
                     log_path=None, verbose=True):
    """Attack `model` on (x, y) with AutoAttack under a single threat model.

    Only the 'apgd-ce' and 'apgd-t' attacks are scheduled. Once the attack
    has finished, the accuracy on the adversarial points is measured with
    RobustBench's `clean_accuracy` (on device 'cuda'),
    `other_utils.check_imgs` is called on the adversarial and original
    inputs, and the accuracy is printed.

    Args:
        model: classifier handed to AutoAttack and to the accuracy check.
        x: inputs to attack.
        y: labels of `x`.
        norm: threat model name forwarded to AutoAttack.
        eps: perturbation budget forwarded to AutoAttack.
        bs: batch size passed to `run_standard_evaluation`.
        log_path: log file path forwarded to AutoAttack.
        verbose: accepted for compatibility; it is not read, so the report
            is always printed.

    Returns:
        The adversarial examples returned by `run_standard_evaluation`.
    """
    attacker = autoattack.AutoAttack(
        model, norm=norm, eps=eps, log_path=log_path)
    # Restrict the attack suite to the two APGD variants.
    attacker.attacks_to_run = ['apgd-ce', 'apgd-t']

    with torch.no_grad():
        x_adv = attacker.run_standard_evaluation(x, y, bs=bs)

    robust_acc = rb.utils.clean_accuracy(model, x_adv, y, device='cuda')
    other_utils.check_imgs(x_adv, x, norm)
    print('robust accuracy: {:.1%}'.format(robust_acc))
    return x_adv
def eval_norms(model, x, y, l_norms, l_epss, bs=1000, log_path=None, n_cls=10):
    """Attack `model` under several threat models and report the accuracies.

    For every (norm, eps) pair adversarial examples are generated with
    `eval_single_norm`. Afterwards the clean accuracy, the accuracy on each
    set of adversarial examples, and the accuracy in the union of the threat
    models (points classified correctly on the clean input and under every
    attack) are logged to `log_path`.

    Args:
        model: classifier to evaluate.
        x: clean inputs.
        y: labels of `x`.
        l_norms: list of threat model names.
        l_epss: list of perturbation budgets, aligned with `l_norms`.
        bs: batch size for the attacks and the accuracy computation.
        log_path: path handed to `other_utils.Logger` and to the attacks.
        n_cls: number of classes passed to `utils.get_accuracy_and_logits`.

    Returns:
        A pair `(l_x_adv, acc_dets)`: the adversarial examples for each norm
        (moved to CPU), and a list of `(name, accuracy)` tuples for 'clean',
        each norm, and 'union'.
    """
    def _is_correct(logits):
        # Mask of points whose highest-scoring class matches the label.
        return logits.to(y.device).max(1)[1] == y

    l_x_adv = []
    acc_dets = []
    logger = other_utils.Logger(log_path)

    # Stage 1: generate the adversarial examples for every threat model.
    for norm, eps in zip(l_norms, l_epss):
        adv = eval_single_norm(model, x, y, norm=norm, eps=eps, bs=bs,
                               log_path=log_path, verbose=False)
        l_x_adv.append(adv.cpu())

    # Stage 2: clean accuracy; `pred` starts as the clean-correct mask.
    clean_acc, clean_logits = utils.get_accuracy_and_logits(
        model, x, y, batch_size=bs, n_classes=n_cls)
    pred = _is_correct(clean_logits)
    logger.log('')
    logger.log('clean accuracy: {:.1%}'.format(pred.float().mean()))
    print('clean accuracy: {:.1%}'.format(clean_acc))
    acc_dets.append(('clean', clean_acc + 0.))

    # Stage 3: accuracy under each attack, folded into the running mask.
    for norm, _, x_adv in zip(l_norms, l_epss, l_x_adv):
        rob_acc, adv_logits = utils.get_accuracy_and_logits(
            model, x_adv, y, batch_size=bs, n_classes=n_cls)
        other_utils.check_imgs(x_adv, x.cpu(), norm)
        pred_curr = _is_correct(adv_logits)
        logger.log('robust accuracy {}: {:.1%}'.format(
            norm, pred_curr.float().mean()))
        print('robust accuracy: {:.1%}'.format(rob_acc))
        # A point survives only while it is correct under every attack so far.
        pred.mul_(pred_curr)
        acc_dets.append((norm, rob_acc + 0.))

    logger.log('robust accuracy {}: {:.1%}'.format(
        '+'.join(l_norms), pred.float().mean()))
    acc_dets.append(('union', pred.float().mean()))
    return l_x_adv, acc_dets
def eval_norms_fast(model, x, y, l_norms, l_epss, n_iter=100, n_cls=10):
    """Quick multi-norm robustness estimate using `apgd_train` as the attack.

    The whole input `x` is processed as one batch. The model must already be
    in evaluation mode (checked with an assertion).

    Args:
        model: classifier to evaluate; `model.training` must be False.
        x: clean inputs.
        y: labels of `x`.
        l_norms: list of threat model names.
        l_epss: list of perturbation budgets, aligned with `l_norms`.
        n_iter: number of attack iterations passed to `apgd_train`.
        n_cls: number of classes passed to `utils.get_accuracy_and_logits`.

    Returns:
        A dict with the accuracy under the keys 'clean', each norm name, and
        'union' (points correct on the clean input and under every attack).
    """
    assert not model.training

    n_points = x.shape[0]  # single batch covering all points
    clean_acc, logits = utils.get_accuracy_and_logits(
        model, x, y, batch_size=n_points, n_classes=n_cls)
    still_correct = logits.to(y.device).max(1)[1] == y
    acc_dict = {'clean': clean_acc + 0.}

    for norm, eps in zip(l_norms, l_epss):
        # Only the last of the five returned values is used.
        _, _, _, _, x_adv = apgd_train(model, x, y, norm=norm, eps=eps,
                                       n_iter=n_iter, is_train=False)
        norm_acc, logits = utils.get_accuracy_and_logits(
            model, x_adv, y, batch_size=n_points, n_classes=n_cls)
        acc_dict[norm] = norm_acc
        still_correct.mul_(logits.to(y.device).max(1)[1] == y)

    acc_dict['union'] = still_correct.float().mean()
    return acc_dict
def parse_args():
    """Parse the command-line options of the evaluation script.

    Besides the options that were already available, `--act` is defined so
    that the 'rand' model branch of the script, which reads `args.act`, does
    not fail with an AttributeError. Its default matches the activation the
    script uses when it builds a PreActResNet18 itself.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='Wong2020Fast',
                        help="path to a checkpoint file, or a model identifier "
                             "('rand', 'RB_<name>', 'pretr...')")
    parser.add_argument('--n_ex', type=int, default=10000,
                        help='number of examples to evaluate on')
    parser.add_argument('--batch_size_eval', type=int, default=100,
                        help='batch size for evaluation')
    parser.add_argument('--data_dir', type=str, default='/share/datasets/cifar10',
                        help='where to store downloaded datasets')
    parser.add_argument('--model_dir', type=str, default='./models',
                        help='where to store downloaded models')
    parser.add_argument('--save_dir', type=str, default='./trained_models',
                        help='base directory for logs and saved adversarial examples')
    parser.add_argument('--load_model', type=str, default=None,
                        help='path of an additional checkpoint to read')
    parser.add_argument('--l_norms', type=str, default='Linf L2 L1',
                        help='space-separated list of threat models')
    parser.add_argument('--l_eps', type=str,
                        help="override of one default budget, written as "
                             "'<norm>_<value>' (e.g. L2_0.5)")
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help="'cifar10' or 'imagenet'")
    parser.add_argument('--only_clean', action='store_true',
                        help='compute the clean accuracy only, then exit')
    parser.add_argument('--wide', action='store_true',
                        help='build a WideResNet instead of a PreActResNet18')
    # Read by the script when --model_name is 'rand'.
    parser.add_argument('--act', type=str, default='softplus1',
                        help='activation of a randomly initialised PreActResNet18')
    return parser.parse_args()
if __name__ == '__main__':
    # Script entry point: load the test data and the model, report the clean
    # accuracy, run the attacks for every requested threat model and save the
    # resulting adversarial examples under args.save_dir.
    '''
    x_test, y_test = data.load_cifar10(1000, data_dir='/home/scratch/datasets/CIFAR10')
    model = utils.load_pretrained_models('pretr_L2') #args.model_name
    ckpt = torch.load('./trained_models/model_2021-04-21 19:57:33.710832 cifar10 lr=0.05000 piecewise-ft ep=3 attack=apgd fts=pretr_L1 seed=0 at=Linf L1 eps=default iter=10/ep_3.pth')
    model.load_state_dict(ckpt)
    model.eval()
    #eval_single_norm(model, x_test, y_test, norm='L2', eps=.5, bs=256)
    eval_norms(model, x_test, y_test, l_norms=['L2', 'L1'], l_epss=[.5, 12.], bs=256)
    '''
    args = parse_args()

    # load data
    if args.dataset == 'cifar10':
        x_test, y_test = data.load_cifar10(args.n_ex, data_dir=args.data_dir,
                                           device='cpu')
    elif args.dataset == 'imagenet':
        x_test, y_test = data.load_imagenet(args.n_ex, device='cpu')
    else:
        # Fail here with a clear message rather than with a NameError on
        # x_test further down.
        raise ValueError('unknown dataset: {}'.format(args.dataset))

    # --model_name is either the path of a checkpoint file or a model
    # identifier; the two cases store their outputs in different places.
    not_pretr = os.path.isfile(args.model_name)
    if not_pretr:
        # pretr_model = args.model_name.split('fts=')[1].split(' ')[0]
        # Placeholder that matches none of the named branches below, so the
        # architecture is picked by the final `else`.
        pretr_model = 'qetqet'
        args.save_dir, ckpt_name = os.path.split(args.model_name)
        ckpt = torch.load(args.model_name)
    else:
        pretr_model = args.model_name
        args.save_dir = '{}/{}'.format(args.save_dir, args.model_name)
        other_utils.makedir(args.save_dir)
        ckpt_name = 'pretrained'
    log_path = '{}/log_eval_{}.txt'.format(args.save_dir, ckpt_name)

    # load model
    if pretr_model == 'rand':
        # Use --act when the parser provides it; otherwise fall back to the
        # activation used for the default PreActResNet18 below.
        model = PreActResNet18(
            10, activation=getattr(args, 'act', 'softplus1')).cuda()
        # model.eval()
    elif pretr_model.startswith('RB'):
        # 'RB_<name>': everything after the first underscore is the
        # RobustBench model name; the threat model is fixed to Linf.
        rb_model_name = '_'.join(args.model_name.split('_')[1:])
        rb_threat_model = 'Linf'
        model = rb.utils.load_model(rb_model_name, model_dir=args.model_dir,
                                    dataset=args.dataset,
                                    threat_model=rb_threat_model)
        if args.load_model:
            # NOTE(review): this checkpoint is read but never applied to the
            # model -- confirm whether model.load_state_dict(ckpt) is missing.
            ckpt = torch.load(args.load_model)
        model.cuda()
        model.eval()
        # Previously the model-name string itself was unpacked into format(),
        # which printed its first two characters.
        print('{} ({}) loaded'.format(rb_model_name, rb_threat_model))
    elif pretr_model.startswith('pretr'):
        model = utils.load_pretrained_models(pretr_model)
        print('pretrained model loaded')
    else:
        if args.wide:
            model = WideResNet().cuda()
        else:
            model = PreActResNet18(10, activation='softplus1').cuda()
    # NOTE(review): applying `ckpt` and switching to eval mode are disabled
    # below, so a checkpoint passed as --model_name is read but not used --
    # confirm this is intended.
    # if not_pretr:
    #     model.load_state_dict(ckpt)
    #     model.eval()

    # clean acc
    acc = rb.utils.clean_accuracy(model, x_test, y_test, device='cuda')
    print('clean accuracy {:.1%}'.format(acc))
    if args.only_clean:
        sys.exit()

    # set norms and eps
    args.l_norms = args.l_norms.split(' ')
    if args.l_eps is not None:
        # --l_eps is '<norm>_<value>': overwrite that norm's default budget.
        override = args.l_eps.split('_')
        norm, size = override[0], float(override[1])
        eps_dict[args.dataset][norm] = size
    args.l_eps = [eps_dict[args.dataset][c] for c in args.l_norms]

    # run attacks (10 classes for cifar10, 1000 otherwise)
    n_cls = 10 if args.dataset == 'cifar10' else 1000
    l_x_adv, _ = eval_norms(model, x_test, y_test, l_norms=args.l_norms,
                            l_epss=args.l_eps, bs=args.batch_size_eval,
                            log_path=log_path, n_cls=n_cls)

    # saving
    for norm, eps, v in zip(args.l_norms, args.l_eps, l_x_adv):
        torch.save(v, '{}/eval_{}_{}_1_{}_eps_{:.5f}.pth'.format(
            args.save_dir, ckpt_name, norm, args.n_ex, eps))