forked from pengzhiliang/MAE-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_mae_vis.py
228 lines (177 loc) · 8.5 KB
/
run_mae_vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# -*- coding: utf-8 -*-
# @Time : 2021/11/18 22:40
# @Author : zhao pengfei
# @Email : [email protected]
# @File : run_mae_vis.py
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from timm.models import create_model
import utils
import modeling_pretrain
from datasets import DataAugmentationForMAE
from torchvision.transforms import ToPILImage
from einops import rearrange
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def get_args(argv=None):
    """Parse the command-line options of the MAE reconstruction visualiser.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry uses.

    Returns:
        argparse.Namespace with the positional paths/labels and model options.
    """
    parser = argparse.ArgumentParser('MAE visualization reconstruction script', add_help=False)
    parser.add_argument('img_path', type=str, help='input image path')
    parser.add_argument('img_type', type=str, help='original/attacked')
    parser.add_argument('save_path', type=str, help='save image path')
    parser.add_argument('model_path', type=str, help='checkpoint path of model')
    parser.add_argument('--input_size', default=224, type=int,
                        help='images input size for backbone')
    parser.add_argument('--device', default='cuda:0',
                        help='device to use for training / testing')
    # NOTE(review): store_true combined with default=True means this option is
    # always True; it cannot be switched off from the command line.
    parser.add_argument('--imagenet_default_mean_and_std', default=True, action='store_true')
    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='ratio of the visual tokens/patches need be masked')

    # Model parameters
    parser.add_argument('--model', default='pretrain_mae_base_patch16_224', type=str, metavar='MODEL',
                        help='Name of model to vis')
    parser.add_argument('--drop_path', type=float, default=0.0, metavar='PCT',
                        help='Drop path rate (default: 0.0)')

    return parser.parse_args(argv)
def get_model(args):
    """Instantiate the network named by ``args.model`` via timm's ``create_model``.

    No pretrained weights are requested from timm (``pretrained=False``); the
    drop-path rate comes from ``args.drop_path`` and drop-block is disabled.

    Args:
        args: namespace providing ``model`` and ``drop_path``.

    Returns:
        The freshly constructed model.
    """
    print(f"Creating model: {args.model}")
    build_options = dict(
        pretrained=False,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )
    return create_model(args.model, **build_options)
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: value passed to every random number generator (default 42).
    """
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Prefer reproducible cuDNN kernels over autotuned (possibly varying) ones.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def compute_pixelwise_accuracy(original, reconstructed, threshold=0.05):
    """Return the fraction of tensor elements reconstructed within ``threshold``.

    An element counts as correct when ``|original - reconstructed| < threshold``
    (strict inequality). The count is divided by the number of elements of
    ``original``.

    Args:
        original: reference tensor.
        reconstructed: tensor broadcast-compatible with ``original``.
        threshold: absolute-error tolerance (default 0.05).

    Returns:
        The accuracy as a Python float.
    """
    within_tolerance = (original - reconstructed).abs().lt(threshold)
    n_correct = within_tolerance.float().sum()
    n_total = original.numel()
    return (n_correct / n_total).item()
def calculate_patchwise_mse(img, rec_img, bool_masked_pos, patch_size):
    """Calculate the Mean Squared Error of every patch, keeping only masked ones.

    The images are tiled into non-overlapping ``patch_size[0] x patch_size[1]``
    patches numbered row-major (``index = row * n_cols + col``). Height and
    width are handled independently, so non-square images/patches work too.

    Args:
        img: original image tensor indexed (batch, channel, height, width).
        rec_img: reconstructed tensor with the same shape as ``img``.
        bool_masked_pos: mask indexed ``[batch, patch_index]``; only batch 0 is
            consulted, although each patch MSE is averaged over the whole batch.
        patch_size: (patch_height, patch_width) in pixels.

    Returns:
        ``(mse_losses, avg_masked_mse)``: a list with one entry per patch (its
        MSE when masked, ``float('nan')`` otherwise) and the mean MSE over the
        masked patches, or ``None`` when no patch is masked.
    """
    patch_h, patch_w = patch_size[0], patch_size[1]
    n_rows = img.shape[2] // patch_h
    n_cols = img.shape[3] // patch_w
    # Copy batch 0's mask to host memory once instead of indexing the
    # (possibly device-resident) tensor for every patch.
    is_masked = bool_masked_pos[0].tolist()

    mse_losses = []
    masked_losses = []  # MSE of masked patches only
    for idx in range(n_rows * n_cols):
        if not is_masked[idx]:
            mse_losses.append(float('nan'))  # placeholder for visible patches
            continue
        top = (idx // n_cols) * patch_h
        left = (idx % n_cols) * patch_w
        patch_ori = img[:, :, top:top + patch_h, left:left + patch_w]
        patch_rec = rec_img[:, :, top:top + patch_h, left:left + patch_w]
        mse = ((patch_ori - patch_rec) ** 2).mean().item()
        mse_losses.append(mse)
        masked_losses.append(mse)

    avg_masked_mse = sum(masked_losses) / len(masked_losses) if masked_losses else None
    return mse_losses, avg_masked_mse
def plot_mse_per_patch(mse_losses, save_path, title, y_range):
    """Plot Mean Squared Error against patch index and save the figure.

    Args:
        mse_losses: sequence of per-patch MSE values; NaN entries are not drawn.
        save_path: output file path handed to ``plt.savefig``.
        title: figure title.
        y_range: ``(bottom, top)`` limits applied to the y axis.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(mse_losses, marker='o')
        plt.title(title)
        plt.xlabel('Patch Index')
        plt.ylabel('MSE')
        plt.grid(True)
        plt.ylim(y_range)  # limits chosen by the caller
        plt.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # close it even if saving fails so repeated calls do not leak.
        plt.close(fig)
def assessment(img, rec_img, bool_masked_pos, img_type, patch_size, save_dir='out'):
    """Print reconstruction metrics for one image and save a per-patch MSE plot.

    Args:
        img: original image tensor indexed (batch, channel, height, width).
        rec_img: reconstructed tensor with the same shape as ``img``.
        bool_masked_pos: patch mask indexed ``[batch, patch_index]``.
        img_type: label (e.g. ``original`` / ``attacked``) used in the printed
            messages and in the plot file name.
        patch_size: (patch_height, patch_width) in pixels.
        save_dir: directory for the plot, created if missing (default ``out``).
    """
    tag = f"rec_{img_type}_vs_ori_{img_type}"
    print(f"Accuracy of {tag}: ", compute_pixelwise_accuracy(img, rec_img))

    mse_losses, avg_masked_mse = calculate_patchwise_mse(img, rec_img, bool_masked_pos, patch_size)
    if avg_masked_mse is not None:
        print(f"Average MSE loss on masked patches ({tag}):", avg_masked_mse)
    else:
        print("No masked patches detected.")

    # mse_losses holds NaN placeholders for visible patches, so averaging that
    # list is NaN whenever any patch is unmasked; report whole-image MSE instead.
    print(f"Average MSE loss ({tag}):", ((img - rec_img) ** 2).mean().item())

    # Take the y-axis limits from the measured (non-NaN) entries only: min()/max()
    # over a list containing NaN depend on element order and can return NaN,
    # which matplotlib refuses as an axis limit.
    measured = [v for v in mse_losses if not np.isnan(v)]
    y_range = (min(measured), max(measured)) if measured else (0.0, 1.0)

    os.makedirs(save_dir, exist_ok=True)
    plot_mse_per_patch(
        mse_losses,
        save_path=os.path.join(save_dir, f'{tag}.png'),
        title=f'Mean Squared Error for Masked Patch: rec_{img_type}_img vs ori_{img_type}_img',
        y_range=y_range,
    )
def main(args):
    """Reconstruct one image with a pretrained MAE, save the images and assess them.

    Writes ``ori_<type>_img.jpg``, ``rec_<type>_img.jpg`` and
    ``mask_<type>_img.jpg`` into ``args.save_path`` (created if missing), then
    calls ``assessment`` on the de-normalised input and its reconstruction.

    Args:
        args: namespace from ``get_args``. ``window_size`` and ``patch_size``
            are added to it here, presumably for ``DataAugmentationForMAE``
            (TODO confirm in datasets.py).
    """
    print(args)
    device = torch.device(args.device)
    # NOTE(review): this re-enables cuDNN autotuning after set_seed() disabled
    # it, so conv results may differ slightly between runs.
    cudnn.benchmark = True
    img_type = args.img_type

    model = get_model(args)
    patch_size = model.encoder.patch_embed.patch_size
    print("Patch size = %s" % str(patch_size))
    args.window_size = (args.input_size // patch_size[0], args.input_size // patch_size[1])
    args.patch_size = patch_size
    grid_h, grid_w = args.window_size  # patches per column / per row

    model.to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(args.model_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()

    with open(args.img_path, 'rb') as f:
        # convert() returns a new image rather than converting in place, so its
        # result must be kept; it also reads the pixels before the file closes.
        pil_img = Image.open(f).convert('RGB')
        if img_type == 'original':
            pil_img = pil_img.resize((args.input_size, args.input_size))
        print("img path:", args.img_path)

    transforms = DataAugmentationForMAE(args)
    img, bool_masked_pos = transforms(pil_img)
    bool_masked_pos = torch.from_numpy(bool_masked_pos)
    os.makedirs(args.save_path, exist_ok=True)

    with torch.no_grad():
        img = img[None, :].to(device, non_blocking=True)
        bool_masked_pos = bool_masked_pos[None, :].to(device, non_blocking=True).flatten(1).to(torch.bool)
        outputs = model(img, bool_masked_pos)

        # Save the original image after undoing the ImageNet normalisation.
        mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device)[None, :, None, None]
        std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device)[None, :, None, None]
        ori_img = img * std + mean  # in [0, 1]
        ToPILImage()(ori_img[0, :]).save(f"{args.save_path}/ori_{img_type}_img.jpg")

        # Split into patches and normalise each one by its own mean/std; the
        # predicted (masked) patches are written into this normalised space.
        img_squeeze = rearrange(ori_img, 'b c (h p1) (w p2) -> b (h w) (p1 p2) c',
                                p1=patch_size[0], p2=patch_size[1])
        patch_mean = img_squeeze.mean(dim=-2, keepdim=True)
        patch_std = img_squeeze.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
        img_patch = rearrange((img_squeeze - patch_mean) / patch_std, 'b n p c -> b n (p c)')
        img_patch[bool_masked_pos] = outputs

        # Make the visibility mask: 1 for visible patches, 0 for masked ones.
        mask = torch.ones_like(img_patch)
        mask[bool_masked_pos] = 0
        mask = rearrange(mask, 'b n (p c) -> b n p c', c=3)
        mask = rearrange(mask, 'b (h w) (p1 p2) c -> b c (h p1) (w p2)',
                         p1=patch_size[0], p2=patch_size[1], h=grid_h, w=grid_w)

        # Save the reconstruction image.
        rec_img = rearrange(img_patch, 'b n (p c) -> b n p c', c=3)
        # Notice: To visualize the reconstruction image, we add the predict and the original mean and var of each patch. Issue #40
        rec_img = rec_img * patch_std + patch_mean
        rec_img = rearrange(rec_img, 'b (h w) (p1 p2) c -> b c (h p1) (w p2)',
                            p1=patch_size[0], p2=patch_size[1], h=grid_h, w=grid_w)
        # ToPILImage scales float tensors by 255 and casts to uint8, so values
        # outside [0, 1] must be clipped first or they may wrap around.
        ToPILImage()(rec_img[0, :].clip(0, 0.996)).save(f"{args.save_path}/rec_{img_type}_img.jpg")

        # Save the randomly masked image (masked patches blacked out).
        img_mask = rec_img * mask
        ToPILImage()(img_mask[0, :].clip(0, 0.996)).save(f"{args.save_path}/mask_{img_type}_img.jpg")

        assessment(ori_img, rec_img, bool_masked_pos, img_type, patch_size)
    print("----------------------------------------------------------")
if __name__ == '__main__':
    cli_args = get_args()
    # Seed every RNG before main() runs, so any random masking/augmentation that
    # draws from these generators repeats across runs (TODO confirm which RNG
    # DataAugmentationForMAE uses).
    set_seed(42)
    main(cli_args)