-
Notifications
You must be signed in to change notification settings - Fork 61
/
ctc_utils.py
147 lines (108 loc) · 4.13 KB
/
ctc_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import numpy as np
import cv2
def convert_inputs_to_ctc_format(target_text):
    """Encode a transcript as CTC target labels in sparse-tuple form.

    The text is stripped, lower-cased and cleared of the punctuation
    characters ``. ? , ' ! -``.  Each remaining character is mapped to its
    ``ord`` code point minus ``FIRST_INDEX``, and each space between words is
    mapped to ``SPACE_INDEX``.

    Args:
        target_text: Transcript string to encode.

    Returns:
        A pair ``(train_targets, original)``: ``train_targets`` is whatever
        ``sparse_tuple_from`` returns for the single encoded sequence, and
        ``original`` is the cleaned text.

    Side effects:
        Prints the cleaned text to stdout.
    """
    SPACE_TOKEN = '-'
    SPACE_INDEX = 4
    FIRST_INDEX = 0

    original = target_text.strip().lower()
    for unwanted in ".?,'!-":
        original = original.replace(unwanted, '')
    print(original)

    # Double every space so that splitting on a single space leaves an empty
    # string between consecutive words; that empty string marks where the
    # space token belongs.  Replacing ' ' with ' ' would be a no-op and the
    # word separators would silently disappear from the targets.
    targets = original.replace(' ', '  ').split(' ')

    # Swap each empty string for the space token and explode words into
    # characters.  '-' was removed from the text above, so the token cannot
    # collide with a real character.
    targets = np.hstack([SPACE_TOKEN if x == '' else list(x) for x in targets])

    # Transform char into index
    targets = np.asarray([SPACE_INDEX if x == SPACE_TOKEN else ord(x) - FIRST_INDEX
                          for x in targets])

    # Creating sparse representation to feed the placeholder
    train_targets = sparse_tuple_from([targets])

    return train_targets, original
def sparse_tuple_from(sequences, dtype=np.int32):
    """Build the ``(indices, values, shape)`` triple of a sparse 2-D tensor.

    Item ``i`` of sequence ``n`` is stored at position ``(n, i)``.

    Args:
        sequences: A sized collection of sized sequences of values.
        dtype: NumPy dtype used for the returned ``values`` array.

    Returns:
        A tuple ``(indices, values, shape)`` where

        * ``indices`` is an int64 array of shape ``(total_items, 2)``,
        * ``values`` is a flat array of all items in ``dtype``,
        * ``shape`` is the int64 array
          ``[len(sequences), length of the longest sequence]``.

        When there are no items at all, ``indices`` has shape ``(0, 2)`` and
        the second entry of ``shape`` is ``0``.
    """
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend((n, i) for i in range(len(seq)))
        values.extend(seq)

    # reshape(-1, 2) keeps the (k, 2) layout even when k == 0, where a plain
    # asarray of an empty list would be one-dimensional.
    indices = np.asarray(indices, dtype=np.int64).reshape(-1, 2)
    values = np.asarray(values, dtype=dtype)

    # Derive the width from the sequence lengths rather than from
    # indices.max(), which raises on an empty index array.
    max_len = max([len(seq) for seq in sequences] or [0])
    shape = np.asarray([len(sequences), max_len], dtype=np.int64)

    return indices, values, shape
def sparse_tensor_to_strs(sparse_tensor):
    """Group the values of a sparse tensor by their batch row.

    Args:
        sparse_tensor: An indexable whose first element is a triple
            ``(indices, values, dense_shape)``.  ``indices[k][0]`` is the
            batch row of ``values[k]`` and ``dense_shape[0]`` is the number
            of batch rows.

    Returns:
        A list with ``dense_shape[0]`` entries; entry ``b`` is the list of
        values whose index row is ``b``, in the order they appear.  Rows
        with no values are empty lists, and an empty batch yields ``[]``.
    """
    indices = sparse_tensor[0][0]
    values = sparse_tensor[0][1]
    dense_shape = sparse_tensor[0][2]

    strs = [[] for _ in range(dense_shape[0])]

    # Append each value straight to its own row.  This needs no bookkeeping
    # of the "current" row, so an empty batch no longer raises IndexError and
    # a row is never overwritten if its index shows up again later.
    for index, value in zip(indices, values):
        strs[index[0]].append(value)

    return strs
def pad_sequences(sequences, maxlen=None, dtype=np.float32,
                  padding='post', truncating='post', value=0.):
    """Pack variable-length sequences into one array filled out with `value`.

    Args:
        sequences: Collection of sequences; each non-empty one must share the
            same per-step shape.
        maxlen: Length of the time axis.  Defaults to the longest sequence.
        dtype: dtype of the returned array.
        padding: 'post' writes each sequence at the start of its row,
            'pre' writes it at the end.
        truncating: 'post' keeps the first `maxlen` steps of a sequence,
            'pre' keeps the last `maxlen` steps.
        value: Fill value for positions not covered by a sequence.

    Returns:
        A pair ``(x, lengths)``: ``x`` has shape
        ``(len(sequences), maxlen) + sample_shape`` and ``lengths`` is an
        int64 array with the original (untruncated) length of each sequence.

    Raises:
        ValueError: If `truncating` or `padding` is not 'pre'/'post', or a
            sequence's per-step shape differs from the first non-empty one.
            These checks only run when a non-empty sequence is processed.
    """
    lengths = np.asarray([len(seq) for seq in sequences], dtype=np.int64)
    if maxlen is None:
        maxlen = np.max(lengths)

    # The per-step shape comes from the first non-empty sequence; every other
    # sequence is checked against it while the output is being filled.
    sample_shape = next(
        (np.asarray(seq).shape[1:] for seq in sequences if len(seq) > 0),
        tuple(),
    )

    batch_shape = (len(sequences), maxlen) + sample_shape
    x = (np.ones(batch_shape) * value).astype(dtype)

    for row, seq in enumerate(sequences):
        if len(seq) > 0:
            if truncating == 'post':
                kept = seq[:maxlen]
            elif truncating == 'pre':
                kept = seq[-maxlen:]
            else:
                raise ValueError('Truncating type "%s" not understood' % truncating)

            # Make sure the kept slice has the expected per-step shape.
            kept = np.asarray(kept, dtype=dtype)
            step_shape = kept.shape[1:]
            if step_shape != sample_shape:
                raise ValueError('Shape of sample %s of sequence at position %s is different from expected shape %s' %
                                 (step_shape, row, sample_shape))

            steps = len(kept)
            if padding == 'pre':
                x[row, -steps:] = kept
            elif padding == 'post':
                x[row, :steps] = kept
            else:
                raise ValueError('Padding type "%s" not understood' % padding)

    return x, lengths
def word_separator():
    """Return the character used to separate words: a single tab."""
    separator = '\t'
    return separator
def levenshtein(a, b):
    """Return the Levenshtein (edit) distance between sequences a and b.

    Both arguments must support ``len`` and iteration; their items are
    compared with ``!=``.  Only two rows of the dynamic-programming table are
    kept, and the shorter sequence is used for the row width.
    """
    # Keep `a` as the shorter sequence so each row has len(a) + 1 cells.
    if len(a) > len(b):
        a, b = b, a

    row = list(range(len(a) + 1))
    for i, b_item in enumerate(b, 1):
        prev_row = row
        row = [i]
        for j, a_item in enumerate(a, 1):
            insertion = prev_row[j] + 1
            deletion = row[j - 1] + 1
            if a_item != b_item:
                substitution = prev_row[j - 1] + 1
            else:
                substitution = prev_row[j - 1]
            row.append(min(insertion, deletion, substitution))

    return row[-1]
def edit_distance(a, b, EOS=-1, PAD=-1):
    """Return the Levenshtein distance between a and b, ignoring EOS and PAD.

    Every item equal to `EOS` or `PAD` is dropped from both sequences before
    the distance is computed with `levenshtein`.
    """
    def _is_symbol(item):
        # True for items that are neither the end-of-sequence nor the pad marker.
        return item != EOS and item != PAD

    kept_a = list(filter(_is_symbol, a))
    kept_b = list(filter(_is_symbol, b))
    return levenshtein(kept_a, kept_b)
def normalize(image):
    """Invert `image` against 255 and scale the result by 1/255.

    An input of 255 maps to 0.0 and an input of 0 maps to 1.0.
    """
    inverted = 255. - image
    return inverted / 255.
def resize(image, height):
    """Resize `image` to the given height, scaling the width by the same ratio.

    The new width is ``height * image.shape[1] / image.shape[0]`` truncated
    to an integer; the resizing itself is done by ``cv2.resize``.
    """
    scaled_width = float(height * image.shape[1]) / image.shape[0]
    target_size = (int(scaled_width), height)  # passed to cv2.resize as (width, height)
    return cv2.resize(image, target_size)