forked from MKFMIKU/RAW2RGBNet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
data.py
90 lines (64 loc) · 2.72 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from PIL import Image
import torch.utils.data as data
from os import listdir
from os.path import join
import random
import numpy as np
import torch
def is_image_file(filename):
    """Return True when *filename* carries a recognised image suffix.

    The check is case-insensitive and accepts ``.png``, ``.jpg``,
    ``.jpeg`` and ``.tif``.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.tif')
    # str.endswith accepts a tuple and matches if any suffix fits.
    return filename.lower().endswith(image_suffixes)
def get_patch(*args, patch_size):
    """Crop the same randomly placed square window out of every array.

    The window position is drawn once, from the height and width of the
    first array, and then applied to all arrays so that paired images stay
    aligned. Only the two leading axes are sliced, so both ``(H, W)`` and
    ``(H, W, C)`` arrays are accepted.

    Args:
        *args: Arrays to crop. They are expected to share the height and
            width of the first one; this is not checked here.
        patch_size: Side length of the square window. ``0`` disables
            cropping.

    Returns:
        The untouched ``args`` tuple when ``patch_size`` is ``0``, otherwise
        a list with one cropped view per input array.

    Raises:
        ValueError: If ``patch_size`` is negative or larger than the height
            or width of the first array.
    """
    if patch_size == 0:
        return args
    if patch_size < 0:
        raise ValueError(
            "patch_size must be non-negative, got {}".format(patch_size)
        )

    ih, iw = args[0].shape[:2]
    if patch_size > ih or patch_size > iw:
        # Fail with an explicit message instead of randrange's
        # "empty range" error.
        raise ValueError(
            "patch_size {} does not fit into a {}x{} image".format(
                patch_size, ih, iw
            )
        )

    # Draw order (x first, then y) is kept so seeded runs stay reproducible.
    ix = random.randrange(0, iw - patch_size + 1)
    iy = random.randrange(0, ih - patch_size + 1)

    # Trailing axes are left implicit so single-channel 2-D arrays work too.
    return [a[iy:iy + patch_size, ix:ix + patch_size] for a in args]
def augment(*args, hflip=True, rot=False):
    """Apply one shared random flip/transpose combination to every array.

    Up to three transforms are selected, each with probability 0.5:
    a left-right flip (only when ``hflip`` is set), and an up-down flip
    plus a swap of the first two axes (both only when ``rot`` is set).
    The selection is made once and applied identically to all inputs.

    Args:
        *args: Three-dimensional arrays to transform.
        hflip: Allow the left-right flip.
        rot: Allow the up-down flip and the axis swap.

    Returns:
        A list with one transformed array per input.
    """
    chosen = []
    # Each random draw only happens when its switch is enabled, and the
    # draws occur in the order: left-right, up-down, axis swap.
    if hflip and random.random() < 0.5:
        chosen.append(lambda arr: arr[:, ::-1, :])
    if rot and random.random() < 0.5:
        chosen.append(lambda arr: arr[::-1, :, :])
    if rot and random.random() < 0.5:
        chosen.append(lambda arr: arr.transpose(1, 0, 2))

    results = []
    for arr in args:
        for transform in chosen:
            arr = transform(arr)
        results.append(arr)
    return results
def np2Tensor(*args, rgb_range=1.):
    """Convert channel-last arrays into scaled float tensors.

    Each array has its last axis moved to the front (axes ``(2, 0, 1)``),
    is made contiguous, turned into a float tensor and multiplied in place
    by ``rgb_range / 255``.

    Args:
        *args: Three-dimensional NumPy arrays. The ``/ 255`` factor suggests
            8-bit pixel values, but that is not enforced here.
        rgb_range: Target upper bound of the value range after scaling.

    Returns:
        A list with one tensor per input array.
    """
    def _to_tensor(array):
        channel_first = array.transpose((2, 0, 1))
        # from_numpy needs a contiguous buffer; the transpose is only a view.
        packed = np.ascontiguousarray(channel_first)
        return torch.from_numpy(packed).float().mul_(rgb_range / 255)

    return list(map(_to_tensor, args))
class RAW2RGBData(data.Dataset):
    """Paired image dataset read from ``<dataset_dir>/RAW`` and ``<dataset_dir>/RGB``.

    Image files in both folders are sorted by path and paired by position.
    Every 200th pair (indices 0, 200, 400, ...) forms the test split; all
    remaining pairs form the training split.

    Args:
        dataset_dir: Directory containing the ``RAW`` and ``RGB`` folders.
        patch_size: Side length of the random crop taken from each pair;
            ``0`` keeps the full images.
        test: Select the test split and disable random augmentation.

    Raises:
        ValueError: If the two folders hold a different number of image
            files, since positional pairing would then be wrong.
    """

    def __init__(self, dataset_dir, patch_size=0, test=False):
        super().__init__()
        self.patch_size = patch_size
        self.test = test

        data_filenames = self._list_images(join(dataset_dir, "RAW"))
        label_filenames = self._list_images(join(dataset_dir, "RGB"))
        if len(data_filenames) != len(label_filenames):
            # Pairing is purely positional, so unequal counts would
            # silently mis-pair inputs and labels.
            raise ValueError(
                "RAW and RGB folders differ in image count: {} vs {}".format(
                    len(data_filenames), len(label_filenames)
                )
            )

        self.data_filenames = self._select_split(data_filenames, test)
        self.label_filenames = self._select_split(label_filenames, test)

    @staticmethod
    def _list_images(directory):
        """Return the sorted paths of all image files inside *directory*."""
        return sorted(
            join(directory, name)
            for name in listdir(directory)
            if is_image_file(name)
        )

    @staticmethod
    def _select_split(filenames, test):
        """Keep every 200th entry for the test split, the rest otherwise."""
        if test:
            return filenames[::200]
        # Filtering by index keeps the sorted order, so no re-sort is needed.
        return [name for i, name in enumerate(filenames) if i % 200 != 0]

    @staticmethod
    def _read_image(path):
        """Load the image at *path* into an array and close the file."""
        with Image.open(path) as img:
            # The pixel data is materialised before the file is closed.
            return np.asarray(img)

    def __getitem__(self, index):
        """Return the ``(input, label)`` tensor pair stored at *index*."""
        raw = self._read_image(self.data_filenames[index])
        rgb = self._read_image(self.label_filenames[index])
        raw, rgb = get_patch(raw, rgb, patch_size=self.patch_size)
        if not self.test:
            raw, rgb = augment(raw, rgb)
        raw, rgb = np2Tensor(raw, rgb)
        return raw, rgb

    def __len__(self):
        """Return the number of image pairs in the selected split."""
        return len(self.data_filenames)