forked from githubharald/SimpleHTR
-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataloader_iam.py
134 lines (107 loc) · 4.68 KB
/
dataloader_iam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import pickle
import random
from collections import namedtuple
from typing import Tuple
import cv2
import lmdb
import numpy as np
from path import Path
# One labelled dataset element: the ground-truth text and the path of its image file.
Sample = namedtuple('Sample', 'gt_text, file_path')
# One batch handed to the model: loaded images, their ground-truth texts, and the
# actual number of elements (the last validation batch may be smaller than batch_size).
Batch = namedtuple('Batch', 'imgs, gt_texts, batch_size')
class DataLoaderIAM:
    """
    Loads data which corresponds to IAM format,
    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database
    """

    def __init__(self,
                 data_dir: Path,
                 batch_size: int,
                 data_split: float = 0.95,
                 fast: bool = True) -> None:
        """Loader for dataset.

        Args:
            data_dir: Root directory of the IAM dataset (must contain 'gt/words.txt' and 'img/').
            batch_size: Number of samples per batch.
            data_split: Fraction of samples assigned to the training set; the rest is validation.
            fast: If True, read pre-pickled images from the LMDB database under data_dir/'lmdb'
                instead of decoding PNG files from disk.

        Raises:
            FileNotFoundError: If data_dir does not exist.
        """
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would turn a missing dataset into an obscure downstream failure.
        if not data_dir.exists():
            raise FileNotFoundError(f'Data directory not found: {data_dir}')

        self.fast = fast
        if fast:
            self.env = lmdb.open(str(data_dir / 'lmdb'), readonly=True)

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = []

        chars = set()
        bad_samples_reference = ['a01-117-05-02', 'r06-022-03-05']  # known broken images in IAM dataset

        # `with` guarantees the ground-truth file is closed even if parsing fails
        # (the original opened it without ever closing it).
        with open(data_dir / 'gt/words.txt') as f:
            for line in f:
                # ignore empty and comment lines
                line = line.strip()
                if not line or line[0] == '#':
                    continue

                line_split = line.split(' ')
                assert len(line_split) >= 9

                # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
                file_name_split = line_split[0].split('-')
                file_name_subdir1 = file_name_split[0]
                file_name_subdir2 = f'{file_name_split[0]}-{file_name_split[1]}'
                file_base_name = line_split[0] + '.png'
                file_name = data_dir / 'img' / file_name_subdir1 / file_name_subdir2 / file_base_name

                if line_split[0] in bad_samples_reference:
                    print('Ignoring known broken image:', file_name)
                    continue

                # GT text are columns starting at 9
                gt_text = ' '.join(line_split[8:])
                # set.update iterates the string's characters directly; no need for
                # the intermediate set(list(...)) copies the original built.
                chars.update(gt_text)

                # put sample into list
                self.samples.append(Sample(gt_text, file_name))

        # split into training and validation set (e.g. 95% - 5% by default)
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(chars)

    def train_set(self) -> None:
        """Switch to randomly chosen subset of training set."""
        self.data_augmentation = True
        self.curr_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = 'train'

    def validation_set(self) -> None:
        """Switch to validation set."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = 'val'

    def get_iterator_info(self) -> Tuple[int, int]:
        """Return (current batch index, overall number of batches) for the active set."""
        if self.curr_set == 'train':
            # train set: only full-sized batches, so the incomplete remainder is dropped
            num_batches = int(np.floor(len(self.samples) / self.batch_size))
        else:
            # val set: allow last batch to be smaller
            num_batches = int(np.ceil(len(self.samples) / self.batch_size))
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self) -> bool:
        """Is there a next element?"""
        if self.curr_set == 'train':
            return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
        return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller

    def _get_img(self, i: int) -> np.ndarray:
        """Load the image of sample `i` from LMDB (fast mode) or from the PNG on disk."""
        if self.fast:
            with self.env.begin() as txn:
                # LMDB keys are the image file basenames, stored as pickled arrays.
                basename = Path(self.samples[i].file_path).basename()
                data = txn.get(basename.encode("ascii"))
                img = pickle.loads(data)
        else:
            img = cv2.imread(self.samples[i].file_path, cv2.IMREAD_GRAYSCALE)
        return img

    def get_next(self) -> Batch:
        """Get next element: a Batch of images and their ground-truth texts."""
        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))
        imgs = [self._get_img(i) for i in batch_range]
        gt_texts = [self.samples[i].gt_text for i in batch_range]
        self.curr_idx += self.batch_size
        return Batch(imgs, gt_texts, len(imgs))