# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Modifications Copyright 2017 Abigail See
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This file contains code to read the train/eval/test data from file and process it, and read the vocab data from file and process it"""
import glob
import json
import os
import random
import tensorflow as tf
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
PAD_TOKEN = '[PAD]' # This has a vocab id, which is used to pad the encoder input, decoder input and target sequence
UNKNOWN_TOKEN = '[UNK]' # This has a vocab id, which is used to represent out-of-vocabulary words
START_DECODING = '[START]' # This has a vocab id, which is used at the start of every decoder input sequence
STOP_DECODING = '[STOP]' # This has a vocab id, which is used at the end of untruncated target sequences
# Note: none of <s>, </s>, [PAD], [UNK], [START], [STOP] should appear in the vocab file.
FLAGS = tf.flags.FLAGS
class Vocab(object):
  """Vocabulary class for mapping between words and ids (integers)"""

  def __init__(self, vocab_file, max_size):
    """Creates a vocab of up to max_size words, reading from the vocab_file. If max_size is 0, reads the entire vocab file.

    Args:
      vocab_file: path to the vocab file, which is assumed to contain "<word> <frequency>" on each line, sorted with most frequent word first. This code doesn't actually use the frequencies, though.
      max_size: integer. The maximum size of the resulting Vocabulary."""
    self._word_to_id = {}
    self._id_to_word = {}
    self._count = 0  # keeps track of total number of words in the Vocab

    # [UNK], [PAD], [START] and [STOP] get the ids 0,1,2,3.
    for w in [UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
      self._word_to_id[w] = self._count
      self._id_to_word[self._count] = w
      self._count += 1

    # Read the vocab file and add words up to max_size
    with open(vocab_file, 'r', encoding='utf8') as vocab_f:
      for line in vocab_f:
        pieces = line.split()
        if len(pieces) != 2:
          continue
        w = pieces[0]
        if w in [SENTENCE_START, SENTENCE_END, UNKNOWN_TOKEN, PAD_TOKEN, START_DECODING, STOP_DECODING]:
          raise Exception('<s>, </s>, [UNK], [PAD], [START] and [STOP] shouldn\'t be in the vocab file, but %s is' % w)
        w = w.lower()
        if w in self._word_to_id:
          # Skip duplicated words instead of raising:
          # raise Exception('Duplicated word in vocabulary file: %s' % w)
          continue
        self._word_to_id[w] = self._count
        self._id_to_word[self._count] = w
        self._count += 1
        if max_size != 0 and self._count >= max_size:
          print("max_size of vocab was specified as %i; we now have %i words. Stopping reading." % (max_size, self._count))
          break

    print("Finished constructing vocabulary of %i total words. Last word added: %s" % (self._count, self._id_to_word[self._count - 1]))

  def word2id(self, word):
    """Returns the id (integer) of a word (string). Returns [UNK] id if word is OOV."""
    if word not in self._word_to_id:
      return self._word_to_id[UNKNOWN_TOKEN]
    return self._word_to_id[word]

  def id2word(self, word_id):
    """Returns the word (string) corresponding to an id (integer)."""
    if word_id not in self._id_to_word:
      raise ValueError('Id not found in vocab: %d' % word_id)
    return self._id_to_word[word_id]

  def size(self):
    """Returns the total size of the vocabulary"""
    return self._count
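
# Example usage of Vocab (a minimal sketch; 'vocab.txt' is a hypothetical path
# to a file whose lines are "<word> <frequency>", as described in the docstring
# above):
#
#   vocab = Vocab('vocab.txt', max_size=50000)
#   unk_id = vocab.word2id('some-rare-word')  # OOV words map to the [UNK] id (0)
#   assert vocab.id2word(unk_id) == UNKNOWN_TOKEN
#   print(vocab.size())  # word count, including the four special tokens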
def json_generator(data_path, single_pass):
  """Generates parsed JSON objects from the datafiles matched by data_path.

  Args:
    data_path: glob pattern matching one or more datafiles, each containing one JSON object per line.
    single_pass: boolean. If True, read the files once in sorted order and stop; otherwise shuffle the files and loop over them forever."""
  data_counter = 0
  while True:
    filelist = glob.glob(data_path)  # get the list of datafiles
    assert filelist, ('Error: Empty filelist at %s' % data_path)  # check filelist isn't empty
    if single_pass:
      filelist = sorted(filelist)
    else:
      random.shuffle(filelist)
    for f in filelist:
      with open(f, 'r', encoding='utf8') as reader:
        for line in reader:
          try:
            obj = json.loads(line)
          except Exception as e:
            print('Failed to parse JSON line: %s' % e)
            continue
          data_counter += 1
          yield obj
    data_counter = 0
    if single_pass:
      print("example_generator completed reading all datafiles. No more data.")
      break
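
# Example usage of json_generator (a minimal sketch; the glob pattern and the
# loop body are hypothetical placeholders):
#
#   for example in json_generator('finished_files/train_*.json', single_pass=True):
#       do_something(example)  # e.g. build a batch from the parsed dict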
def get_config():
  """Returns a config for tf.Session that allows soft device placement and GPU memory growth."""
  config = tf.ConfigProto(allow_soft_placement=True)
  config.gpu_options.allow_growth = True
  return config
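
# Example usage of get_config (a minimal sketch using the TF1 Session API this
# file targets):
#
#   with tf.Session(config=get_config()) as sess:
#       sess.run(tf.global_variables_initializer())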
class bcolors:
  """ANSI escape codes for colorizing terminal output."""
  HEADER = '\033[95m'
  OKBLUE = '\033[94m'
  OKGREEN = '\033[92m'
  WARNING = '\033[93m'
  FAIL = '\033[91m'
  ENDC = '\033[0;37;40m'
  BOLD = '\033[1m'
  UNDERLINE = '\033[4m'
  BLINK = '\033[5;41;42m'
  GREENBACK = '\033[0;40;42m'
  REDBACK = '\033[0;42;101m'
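
# Example usage of bcolors (a minimal sketch; these codes only render on
# terminals with ANSI support, and note that ENDC here restores a
# white-on-black style rather than issuing a bare reset):
#
#   print(bcolors.WARNING + 'coverage loss is high' + bcolors.ENDC)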
def load_ckpt(saver, sess, ckpt_dir="train"):
  """Loads the latest checkpoint from ckpt_dir (relative to FLAGS.log_root) into sess and returns its path, or None if loading fails."""
  try:
    latest_filename = "checkpoint_best" if ckpt_dir == "eval" else None
    ckpt_dir = os.path.join(FLAGS.log_root, ckpt_dir)
    ckpt_state = tf.train.get_checkpoint_state(ckpt_dir, latest_filename=latest_filename)
    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)
    return ckpt_state.model_checkpoint_path
  except Exception as e:
    print('Failed to load checkpoint: %s' % e)
    tf.logging.info("Failed to load checkpoint from %s.", ckpt_dir)