-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
132 lines (107 loc) · 3.5 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import pickle
import queue
import threading
import zipfile
import cv2
import numpy as np
import tqdm_utils
def image_center_crop(img):
    """Return the largest centered square crop of ``img``.

    Parameters
    ----------
    img : np.ndarray
        Image array whose first two axes are height and width. Both
        grayscale ``(H, W)`` and multi-channel ``(H, W, C)`` arrays are
        accepted; any trailing axes are preserved untouched.

    Returns
    -------
    np.ndarray
        A view of ``img`` cropped to a square of side ``min(H, W)``,
        centered along the longer axis (when the excess is odd, the
        extra pixel is removed from the top/left side).
    """
    h, w = img.shape[0], img.shape[1]
    pad_left = 0
    pad_right = 0
    pad_top = 0
    pad_bottom = 0
    if h > w:
        diff = h - w
        pad_top = diff - diff // 2
        pad_bottom = diff // 2
    else:
        diff = w - h
        pad_left = diff - diff // 2
        pad_right = diff // 2
    # Slice only the first two axes so 2-D (grayscale) inputs work too;
    # the previous trailing ", :" required exactly three dimensions and
    # raised IndexError on (H, W) arrays.
    return img[pad_top : h - pad_bottom, pad_left : w - pad_right]  # noqa: E203
def decode_image_from_buf(buf):
    """Decode an in-memory encoded image (e.g. JPEG bytes) into an RGB array.

    ``buf`` is the raw file content; OpenCV decodes it to BGR, which is
    then converted to the RGB channel order expected downstream.
    """
    raw = np.asarray(bytearray(buf), dtype=np.uint8)
    bgr = cv2.imdecode(raw, 1)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
def crop_and_preprocess(img, input_shape, preprocess_for_model):
    """Center-crop, resize to ``input_shape``, cast to float32, then apply
    the model-specific ``preprocess_for_model`` normalization."""
    cropped = image_center_crop(img)
    resized = cv2.resize(cropped, input_shape)
    return preprocess_for_model(resized.astype("float32"))
def apply_model(
    zip_fn,
    model,
    preprocess_for_model,
    extensions=(".jpg",),
    input_shape=(224, 224),
    batch_size=32,
):
    """Run ``model`` over every matching image inside a zip archive.

    A background reader thread streams images out of the zip, decodes and
    preprocesses them, and feeds a bounded queue; the main thread drains
    the queue, forms batches and calls ``model.predict``.

    Parameters
    ----------
    zip_fn : str
        Path to the zip archive containing the images.
    model : object
        Anything exposing ``predict(batch)`` over a stacked image batch
        (e.g. a Keras model).
    preprocess_for_model : callable
        Per-image normalization applied after crop/resize.
    extensions : tuple of str
        File extensions (including the dot) to process; other entries in
        the archive are skipped.
    input_shape : tuple of int
        Target size passed to ``cv2.resize`` (OpenCV treats it as
        (width, height)).
    batch_size : int
        Number of images per ``predict`` call.

    Returns
    -------
    (img_embeddings, img_fns) : (np.ndarray, list of str)
        Vertically stacked embeddings and the corresponding file
        basenames, in the order images were read from the archive.
    """
    # Bounded queue of preprocessed images; backpressure keeps memory in
    # check when the reader outpaces the model.
    q = queue.Queue(maxsize=batch_size * 10)
    # Set by the reader once every image has been enqueued.
    read_thread_completed = threading.Event()
    # Set by the consumer to tell the reader to stop early (shutdown/error).
    kill_read_thread = threading.Event()

    def reading_thread(zip_fn):
        # Producer: stream matching images out of the zip into the queue.
        zf = zipfile.ZipFile(zip_fn)
        for fn in tqdm_utils.tqdm_notebook_failsafe(zf.namelist()):
            if kill_read_thread.is_set():
                break
            if os.path.splitext(fn)[-1] in extensions:
                buf = zf.read(fn)  # read raw bytes from zip for fn
                img = decode_image_from_buf(buf)  # decode raw bytes
                img = crop_and_preprocess(img, input_shape, preprocess_for_model)
                # put() with a timeout so we can periodically re-check the
                # kill flag instead of blocking forever on a full queue.
                while True:
                    try:
                        q.put((os.path.split(fn)[-1], img), timeout=1)  # put in queue
                    except queue.Full:
                        if kill_read_thread.is_set():
                            break
                        continue
                    break
        read_thread_completed.set()  # read all images

    # start reading thread
    t = threading.Thread(target=reading_thread, args=(zip_fn,))
    t.daemon = True
    t.start()

    img_fns = []
    img_embeddings = []
    batch_imgs = []

    def process_batch(batch_imgs):
        # Run one forward pass over a full batch and record the result.
        batch_imgs = np.stack(batch_imgs, axis=0)
        batch_embeddings = model.predict(batch_imgs)
        img_embeddings.append(batch_embeddings)

    try:
        while True:
            # get() with a timeout: lets us notice when the reader has
            # finished and the queue has drained.
            try:
                fn, img = q.get(timeout=1)
            except queue.Empty:
                if read_thread_completed.is_set():
                    break
                continue
            img_fns.append(fn)
            batch_imgs.append(img)
            if len(batch_imgs) == batch_size:
                process_batch(batch_imgs)
                batch_imgs = []
            q.task_done()
        # process last (possibly partial) batch
        if len(batch_imgs):
            process_batch(batch_imgs)
    finally:
        # Always stop the reader and settle queue bookkeeping, even if
        # prediction raised.
        kill_read_thread.set()
        t.join()
        q.join()

    img_embeddings = np.vstack(img_embeddings)
    return img_embeddings, img_fns
def save_pickle(obj, fn):
    """Serialize ``obj`` to file ``fn`` using the highest pickle protocol."""
    with open(fn, "wb") as fout:
        pickle.dump(obj, fout, protocol=pickle.HIGHEST_PROTOCOL)
def read_pickle(fn):
    """Deserialize and return the object stored in pickle file ``fn``."""
    with open(fn, "rb") as fin:
        return pickle.load(fin)