-
Notifications
You must be signed in to change notification settings - Fork 9
/
dataloader.py
28 lines (24 loc) · 922 Bytes
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from torch.utils import data
import torch
import os
from PIL import Image
class EvalDataset(data.Dataset):
    """Dataset pairing prediction images with label images by file name.

    Only file names present in BOTH directories are kept; the kept names are
    ordered by the sorted listing of ``label_root``.

    Attributes are populated in ``__init__``:
      * ``lst``        -- the shared file names, sorted.
      * ``image_path`` -- ``img_root`` joined with each name in ``lst``.
      * ``label_path`` -- ``label_root`` joined with each name in ``lst``.
    """

    def __init__(self, img_root: str, label_root: str) -> None:
        """Collect the file names that exist in both directories.

        Args:
            img_root: Directory holding the prediction images.
            label_root: Directory holding the label images.

        Raises:
            OSError: Propagated from ``os.listdir`` if a directory cannot
                be listed.
        """
        label_names = sorted(os.listdir(label_root))
        # A set gives O(1) membership tests; the output order comes from the
        # sorted label listing, so the prediction listing needs no sorting.
        pred_names = set(os.listdir(img_root))

        self.lst = [name for name in label_names if name in pred_names]
        self.image_path = [os.path.join(img_root, name) for name in self.lst]
        self.label_path = [os.path.join(label_root, name) for name in self.lst]

    def __getitem__(self, item):
        """Load the ``item``-th prediction/label pair.

        Both files are opened with PIL and converted to mode ``'L'``. If the
        two images differ in size, the prediction is resized to the label's
        size with bilinear resampling.

        Returns:
            A tuple ``(pred, gt, img_name)`` of the two PIL images and the
            file name they share.
        """
        # convert() returns a new in-memory image, so the underlying file
        # handles can be released as soon as the conversion is done.
        with Image.open(self.image_path[item]) as pred_file:
            pred = pred_file.convert('L')
        with Image.open(self.label_path[item]) as gt_file:
            gt = gt_file.convert('L')

        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)

        img_name = self.lst[item]
        return pred, gt, img_name

    def __len__(self) -> int:
        """Return the number of matched prediction/label pairs."""
        return len(self.image_path)