-
Notifications
You must be signed in to change notification settings - Fork 14
/
helper.py
39 lines (27 loc) · 1.35 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
def load_data(path):
    """
    Read the text file at ``path`` and return its full contents as one string.

    The file is decoded as UTF-8; byte sequences that cannot be decoded are
    dropped rather than raising an error.
    """
    file_path = os.path.join(path)
    with open(file_path, mode="r", encoding="utf-8", errors="ignore") as handle:
        return handle.read()
def extract_vocab(data):
    """
    Build word <-> id lookup tables from whitespace-separated text.

    The four special tokens always receive ids 0-3, in the order they are
    listed below. Every other distinct word in ``data`` is numbered from 4
    upwards in sorted order, so the mapping is identical on every run.

    :param data: Corpus text; words are separated by any whitespace.
    :return: Tuple ``(int_to_vocab, vocab_to_int)`` of dicts that are exact
        inverses of each other.
    """
    # The last token is the literal four characters  < \ s >  -- the backslash
    # is escaped so the value is unchanged but no invalid-escape warning fires.
    special_words = ['<pad>', '<unk>', '<s>', '<\\s>']

    # split() with no arguments splits on newlines as well as spaces/tabs, so
    # it yields the same words as splitting into lines first.
    # Special tokens found in the corpus are removed so they are not given a
    # second id that would shadow their reserved one in vocab_to_int.
    corpus_words = set(data.split()) - set(special_words)

    # Sorting makes ids independent of set iteration order, which changes
    # between interpreter runs because of string hash randomization.
    vocab = special_words + sorted(corpus_words)

    int_to_vocab = dict(enumerate(vocab))
    vocab_to_int = {word: word_i for word_i, word in int_to_vocab.items()}
    return int_to_vocab, vocab_to_int
def pad_id_sequences(source_ids, source_vocab_to_int, target_ids, target_vocab_to_int, sequence_length):
    """
    Right-pad every id sequence with its vocabulary's '<pad>' id up to
    ``sequence_length``; source sequences are additionally reversed.

    Source sentences are padded first and then reversed, so their padding ends
    up at the front. Sentences already longer than ``sequence_length`` receive
    no padding and are not truncated.

    :return: Tuple ``(new_source_ids, new_target_ids)`` of new lists of lists.
    """
    padded_sources = []
    for src_sentence in source_ids:
        # A non-positive repeat count produces an empty filler list.
        src_fill = [source_vocab_to_int['<pad>']] * (sequence_length - len(src_sentence))
        padded = src_sentence + src_fill
        padded_sources.append(list(reversed(padded)))

    padded_targets = []
    for tgt_sentence in target_ids:
        tgt_fill = [target_vocab_to_int['<pad>']] * (sequence_length - len(tgt_sentence))
        padded_targets.append(tgt_sentence + tgt_fill)

    return padded_sources, padded_targets
def batch_data(source, target, batch_size):
    """
    Yield aligned (source, target) slices of ``batch_size`` items each.

    Only full batches are produced: leftover items at the end of ``source``
    that do not fill a whole batch are skipped. The number of batches is
    determined by the length of ``source`` alone.
    """
    full_batches = len(source) // batch_size
    # Walk the start offset of each full batch directly.
    for start in range(0, full_batches * batch_size, batch_size):
        stop = start + batch_size
        yield source[start:stop], target[start:stop]