forked from wiibrew/pytorch-yolo2
-
Notifications
You must be signed in to change notification settings - Fork 8
/
image.py
125 lines (100 loc) · 3.46 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/python
# encoding: utf-8
import random
import os
from PIL import Image
import numpy as np
def scale_image_channel(im, c, v):
    """Return a copy of *im* with band *c* multiplied by *v*.

    Args:
        im: PIL image to read; it is not modified.
        c: Index of the band (as produced by ``im.split()``) to scale.
        v: Multiplicative factor applied to every level of that band.

    Returns:
        A new image with the same mode as *im*.
    """
    bands = list(im.split())
    # ``point`` maps every pixel level of the chosen band through the callable.
    bands[c] = bands[c].point(lambda level: level * v)
    return Image.merge(im.mode, tuple(bands))
def distort_image(im, hue, sat, val):
    """Shift hue and scale saturation/value of an image in HSV space.

    Args:
        im: PIL image; it is converted to HSV internally.
        hue: Hue offset expressed as a fraction of 255 (may be negative).
        sat: Multiplicative factor for the saturation band.
        val: Multiplicative factor for the value band.

    Returns:
        A new RGB image with the distortion applied.
    """
    hue_offset = hue * 255

    def shift_hue(level):
        # Add the offset and wrap once by 255 in either direction.
        shifted = level + hue_offset
        if shifted > 255:
            shifted -= 255
        if shifted < 0:
            shifted += 255
        return shifted

    hsv = im.convert('HSV')
    bands = list(hsv.split())
    bands[1] = bands[1].point(lambda level: level * sat)
    bands[2] = bands[2].point(lambda level: level * val)
    bands[0] = bands[0].point(shift_hue)

    merged = Image.merge(hsv.mode, tuple(bands))
    return merged.convert('RGB')
def rand_scale(s):
    """Draw a random scale factor from [1, s] or its reciprocal.

    A value is sampled uniformly between 1 and *s*; a second random draw
    decides whether that value or its reciprocal is returned.

    Args:
        s: Upper bound of the uniform draw.

    Returns:
        Either the sampled factor or ``1 / factor``.
    """
    factor = random.uniform(1, s)
    # An even draw selects the reciprocal, an odd draw the factor itself.
    use_reciprocal = random.randint(1, 10000) % 2 == 0
    return 1. / factor if use_reciprocal else factor
def random_distort_image(im, hue, saturation, exposure):
    """Apply a randomly drawn hue/saturation/exposure distortion to *im*.

    Args:
        im: PIL image to distort.
        hue: Maximum absolute hue offset; the offset is drawn uniformly
            from ``[-hue, hue]``.
        saturation: Upper bound passed to ``rand_scale`` for saturation.
        exposure: Upper bound passed to ``rand_scale`` for exposure.

    Returns:
        The image returned by ``distort_image`` for the drawn parameters.
    """
    # Draw order is hue, then saturation, then exposure.
    hue_delta = random.uniform(-hue, hue)
    sat_factor = rand_scale(saturation)
    exp_factor = rand_scale(exposure)
    return distort_image(im, hue_delta, sat_factor, exp_factor)
def data_augmentation(img, shape, jitter, hue, saturation, exposure):
    """Randomly crop, resize, flip and colour-distort an image.

    Each edge of the image is moved inwards or outwards by a random amount
    of up to ``jitter`` times the corresponding dimension, the resulting
    window is resized to *shape*, optionally mirrored horizontally, and
    finally passed through ``random_distort_image``.

    Args:
        img: PIL image to augment.
        shape: ``(width, height)`` passed to ``resize``.
        jitter: Maximum edge displacement as a fraction of the image size.
            Values of 0.5 or more can produce an empty or inverted crop
            window, which is not handled here.
        hue: Forwarded to ``random_distort_image``.
        saturation: Forwarded to ``random_distort_image``.
        exposure: Forwarded to ``random_distort_image``.

    Returns:
        Tuple ``(img, flip, dx, dy, sx, sy)`` where ``flip`` is 1 when the
        image was mirrored, ``sx``/``sy`` are the crop size divided by the
        original size, and ``dx``/``dy`` are the left/top crop offsets
        divided by the crop size.
    """
    ow = img.width
    oh = img.height

    dw = int(ow * jitter)
    dh = int(oh * jitter)

    # Keep this draw order so seeded runs stay reproducible.
    pleft = random.randint(-dw, dw)
    pright = random.randint(-dw, dw)
    ptop = random.randint(-dh, dh)
    pbot = random.randint(-dh, dh)

    swidth = ow - pleft - pright
    sheight = oh - ptop - pbot

    sx = float(swidth) / ow
    sy = float(sheight) / oh

    flip = random.randint(1, 10000) % 2

    # Pillow's crop box is (left, upper, right, lower) with the right and
    # lower edges exclusive, so the box must end at pleft + swidth and
    # ptop + sheight for the crop to be exactly swidth x sheight pixels,
    # matching the sx/sy/dx/dy values reported to the label transform.
    cropped = img.crop((pleft, ptop, pleft + swidth, ptop + sheight))

    dx = (float(pleft) / ow) / sx
    dy = (float(ptop) / oh) / sy

    sized = cropped.resize(shape)
    if flip:
        sized = sized.transpose(Image.FLIP_LEFT_RIGHT)

    img = random_distort_image(sized, hue, saturation, exposure)
    return img, flip, dx, dy, sx, sy
def fill_truth_detection(labpath, w, h, flip, dx, dy, sx, sy, max_boxes=50):
    """Load box labels from a text file and map them into an augmented image.

    Every row of the label file holds five numbers. Column 0 is copied
    through unchanged; columns 1-4 are read as box centre x, centre y,
    width and height. Box corners are scaled by ``sx``/``sy``, shifted by
    ``dx``/``dy``, clamped to ``[0, 0.999]`` and converted back to
    centre/size form. Boxes narrower or shorter than 0.001 after clamping
    are dropped.

    Args:
        labpath: Path of the label file. An empty file yields no boxes.
        w: Unused; kept for interface compatibility.
        h: Unused; kept for interface compatibility.
        flip: When truthy, the centre x is mirrored as ``0.999 - x``.
        dx: Offset subtracted from scaled x coordinates.
        dy: Offset subtracted from scaled y coordinates.
        sx: Factor multiplied onto x coordinates.
        sy: Factor multiplied onto y coordinates.
        max_boxes: Maximum number of boxes kept; extra rows are ignored.

    Returns:
        A flat float array of length ``max_boxes * 5``; unused slots are 0.

    Raises:
        OSError: If *labpath* does not exist or cannot be accessed.
    """
    label = np.zeros((max_boxes, 5))

    if os.path.getsize(labpath):
        # A single-row file loads as a 1-D array; force the (rows, 5) shape.
        boxes = np.reshape(np.loadtxt(labpath), (-1, 5))
        kept = 0
        for box in boxes:
            if kept >= max_boxes:
                break

            half_w = box[3] / 2
            half_h = box[4] / 2
            x1 = min(0.999, max(0, (box[1] - half_w) * sx - dx))
            y1 = min(0.999, max(0, (box[2] - half_h) * sy - dy))
            x2 = min(0.999, max(0, (box[1] + half_w) * sx - dx))
            y2 = min(0.999, max(0, (box[2] + half_h) * sy - dy))

            box[1] = (x1 + x2) / 2
            box[2] = (y1 + y2) / 2
            box[3] = x2 - x1
            box[4] = y2 - y1

            if flip:
                box[1] = 0.999 - box[1]

            # Skip boxes that collapsed after cropping/clamping.
            if box[3] < 0.001 or box[4] < 0.001:
                continue

            label[kept] = box
            kept += 1

    return np.reshape(label, (-1))
def load_data_detection(imgpath, shape, jitter, hue, saturation, exposure):
    """Load an image, augment it, and build the matching label vector.

    The label path is derived from *imgpath* by plain substring
    replacement: ``images`` and ``JPEGImages`` become ``labels``, and
    ``.jpg`` / ``.png`` become ``.txt``.

    Args:
        imgpath: Path of the image file.
        shape: Target size forwarded to ``data_augmentation``.
        jitter: Crop jitter forwarded to ``data_augmentation``.
        hue: Hue range forwarded to ``data_augmentation``.
        saturation: Saturation bound forwarded to ``data_augmentation``.
        exposure: Exposure bound forwarded to ``data_augmentation``.

    Returns:
        Tuple ``(img, label)`` of the augmented image and the flat label
        array produced by ``fill_truth_detection``.
    """
    substitutions = (
        ('images', 'labels'),
        ('JPEGImages', 'labels'),
        ('.jpg', '.txt'),
        ('.png', '.txt'),
    )
    labpath = imgpath
    for old, new in substitutions:
        labpath = labpath.replace(old, new)

    # Data augmentation.
    source = Image.open(imgpath).convert('RGB')
    augmented, flip, dx, dy, sx, sy = data_augmentation(
        source, shape, jitter, hue, saturation, exposure)

    # data_augmentation reports crop/original ratios, while
    # fill_truth_detection multiplies coordinates by its scale arguments,
    # so the reciprocals are passed.
    label = fill_truth_detection(
        labpath, augmented.width, augmented.height,
        flip, dx, dy, 1. / sx, 1. / sy)
    return augmented, label