forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 1
/
data_utils.py
614 lines (502 loc) · 22.5 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities used for data preparation."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import json
import os
from absl import logging
import tensorflow as tf
# Reserved vocabulary entries. Each symbol's id is its position in the tuple
# below, so the order of the entries must never change.
special_symbols = {
    symbol: symbol_id
    for symbol_id, symbol in enumerate((
        "<unk>",
        "<s>",
        "</s>",
        "<cls>",
        "<sep>",
        "<pad>",
        "<mask>",
        "<eod>",
        "<eop>",
    ))
}

# Size of the vocabulary.
VOCAB_SIZE = 32000

# Shortcuts for the ids of the special symbols referenced by name in code.
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
def file_based_input_fn_builder(input_file, name_to_features, batch_size,
                                is_training):
  """Returns an `input_fn` closure that reads and batches TFRecord examples.

  Args:
    input_file: a single TFRecord path (`str`) or a sequence of such paths.
    name_to_features: feature spec handed to `tf.io.parse_single_example`.
    batch_size: number of examples per emitted batch; a trailing partial
      batch is always dropped.
    is_training: enables example-level shuffling. For a single path it also
      enables repetition; for several paths it lets the interleaved reader
      produce records in non-deterministic order (`sloppy`).

  Returns:
    A zero-argument callable that builds the `tf.data.Dataset`.
  """
  logging.info("Input tfrecord file %s", input_file)

  def _parse_example(serialized):
    """Parses one serialized example, downcasting int64 features to int32."""
    features = tf.io.parse_single_example(serialized, name_to_features)
    # tf.Example can only store integers as int64, while TPUs only support
    # int32, so every int64 feature is cast down.
    for key, tensor in list(features.items()):
      if tensor.dtype == tf.int64:
        features[key] = tf.cast(tensor, tf.int32)
    return features

  def input_fn():
    """Builds the dataset for training/evaluation."""
    reader_parallelism = 8
    single_path = isinstance(input_file, str)

    if single_path:
      dataset = tf.data.TFRecordDataset(input_file)
      # Training wants shuffling and endless data; evaluation reads the file
      # once, in order.
      if is_training:
        dataset = dataset.shuffle(2048).repeat()
    else:
      num_files = len(input_file)
      file_dataset = tf.data.Dataset.from_tensor_slices(input_file)
      # File-level shuffle. NOTE: this shuffle/repeat is applied regardless
      # of `is_training`.
      file_dataset = file_dataset.shuffle(num_files).repeat()
      interleave = tf.data.experimental.parallel_interleave(
          tf.data.TFRecordDataset,
          sloppy=is_training,
          cycle_length=min(reader_parallelism, num_files))
      dataset = file_dataset.apply(interleave)
      if is_training:
        # Sample-level shuffle.
        dataset = dataset.shuffle(buffer_size=2048)

    # TODO(b/138223458): Hard-code drop_remainder=True to get around the bug
    # that under TPU strategy, setting drop_remainder=False in
    # tf.data.Dataset.batch() while data_size can be divided by global
    # batch_size will trigger dynamic_dimension related TPU compilation error.
    batcher = tf.data.experimental.map_and_batch(
        _parse_example,
        batch_size=batch_size,
        num_parallel_batches=reader_parallelism,
        drop_remainder=True)
    dataset = dataset.apply(batcher)

    # With exactly one input file (a bare path or a one-element list), turn
    # auto sharding off so that every worker receives that same file.
    if single_path or len(input_file) == 1:
      options = tf.data.Options()
      options.experimental_distribute.auto_shard = False
      dataset = dataset.with_options(options)

    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

  return input_fn
def create_classification_dataset(file_path, seq_length, batch_size,
                                  is_training):
  """Creates the classification input dataset from (tf)records files.

  Args:
    file_path: TFRecord path(s), forwarded to `file_based_input_fn_builder`.
    seq_length: length of the per-token features in each record.
    batch_size: batch size, forwarded to `file_based_input_fn_builder`.
    is_training: training flag, forwarded to `file_based_input_fn_builder`.

  Returns:
    The dataset produced by calling the built `input_fn`.
  """

  def per_token(dtype):
    return tf.io.FixedLenFeature([seq_length], dtype)

  def scalar(dtype):
    return tf.io.FixedLenFeature([], dtype)

  feature_spec = {
      "input_ids": per_token(tf.int64),
      "input_mask": per_token(tf.float32),
      "segment_ids": per_token(tf.int64),
      "label_ids": scalar(tf.int64),
      "is_real_example": scalar(tf.int64),
  }

  build_dataset = file_based_input_fn_builder(file_path, feature_spec,
                                              batch_size, is_training)
  return build_dataset()
def create_squad_dataset(file_path, seq_length, batch_size, is_training):
  """Creates the SQuAD input dataset from (tf)records files.

  Args:
    file_path: TFRecord path(s), forwarded to `file_based_input_fn_builder`.
    seq_length: length of the per-token features in each record.
    batch_size: batch size, forwarded to `file_based_input_fn_builder`.
    is_training: when true, the answer features (`start_positions`,
      `end_positions`, `is_impossible`) are parsed as well.

  Returns:
    The dataset produced by calling the built `input_fn`.
  """

  def per_token(dtype):
    return tf.io.FixedLenFeature([seq_length], dtype)

  def scalar(dtype):
    return tf.io.FixedLenFeature([], dtype)

  feature_spec = {
      "unique_ids": scalar(tf.int64),
      "input_ids": per_token(tf.int64),
      "input_mask": per_token(tf.float32),
      "segment_ids": per_token(tf.int64),
      "cls_index": scalar(tf.int64),
      "p_mask": per_token(tf.float32),
  }

  # Answer annotations are only parsed from training records.
  if is_training:
    feature_spec.update({
        "start_positions": scalar(tf.int64),
        "end_positions": scalar(tf.int64),
        "is_impossible": scalar(tf.float32),
    })

  build_dataset = file_based_input_fn_builder(file_path, feature_spec,
                                              batch_size, is_training)
  return build_dataset()
def _get_input_iterator(input_fn, strategy):
  """Returns an iterator over the input distributed by `strategy`.

  Args:
    input_fn: zero-argument callable returning either a dataset or a
      callable that creates one.
    strategy: distribution strategy used to distribute the input.

  Returns:
    An iterator over the distributed dataset.
  """
  # When training with TPU pods, datasets need to be cloned across workers.
  # A Dataset instance cannot be cloned in eager mode, so in that case
  # `input_fn` hands back a dataset-creating callable instead.
  dataset_or_fn = input_fn()
  if callable(dataset_or_fn):
    distributed = strategy.experimental_distribute_datasets_from_function(
        dataset_or_fn)
  else:
    distributed = strategy.experimental_distribute_dataset(dataset_or_fn)
  return iter(distributed)
def get_classification_input_data(batch_size, seq_len, strategy, is_training,
                                  file_path):
  """Returns classification input built from the given record file path.

  Args:
    batch_size: global batch size. Under `TPUStrategy` it must be divisible
      by the number of replicas and is converted to a per-replica size.
    seq_len: sequence length of the records.
    strategy: distribution strategy in use.
    is_training: whether the data is used for training.
    file_path: TFRecord path(s) to read.

  Returns:
    Under `TPUStrategy`, a function creating the dataset; otherwise the
    dataset itself.

  Raises:
    ValueError: if, under `TPUStrategy`, `batch_size` is not divisible by
      the number of replicas.
  """
  # On TPU pods the dataset has to be cloned across workers, so callers get
  # a dataset-creating function rather than a dataset instance.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
  if use_dataset_fn:
    num_replicas = strategy.num_replicas_in_sync
    if batch_size % num_replicas != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas : {}".format(
              num_replicas))
    # `experimental_distribute_datasets_from_function()` does not rebatch
    # automatically, so the dataset is built with the per-replica batch size.
    batch_size = int(batch_size / num_replicas)

  def _dataset_fn(ctx=None):
    del ctx  # Unused.
    return create_classification_dataset(
        file_path=file_path,
        seq_length=seq_len,
        batch_size=batch_size,
        is_training=is_training)

  if use_dataset_fn:
    return _dataset_fn
  return _dataset_fn()
def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
                         file_path):
  """Returns SQuAD input built from the given input location.

  Args:
    batch_size: global batch size. Under `TPUStrategy` it must be divisible
      by the number of replicas and is converted to a per-replica size.
    seq_len: sequence length of the records; for training it is also part
      of the record file name pattern.
    q_len: query length; only used in the training file name pattern.
    strategy: distribution strategy in use.
    is_training: when true, `file_path` is a directory that is globbed for
      training records; otherwise `file_path` is passed on unchanged.
    file_path: training record directory, or the evaluation input path(s).

  Returns:
    Under `TPUStrategy`, a function creating the dataset; otherwise the
    dataset itself.

  Raises:
    ValueError: if, under `TPUStrategy`, `batch_size` is not divisible by
      the number of replicas, or if no training record file matches the
      expected pattern.
  """
  # When using TPU pods, we need to clone dataset across
  # workers and need to pass in function that returns the dataset rather
  # than passing dataset instance itself.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas : {}".format(
              strategy.num_replicas_in_sync))

    # As auto rebatching is not supported in
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning dataset to multiple workers in eager mode,
    # we use per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  if is_training:
    input_glob = os.path.join(
        file_path,
        "spiece.model.*.slen-{}.qlen-{}.train.tf_record".format(seq_len, q_len))

    global_input_paths = tf.io.gfile.glob(input_glob)
    # An empty match list would otherwise only surface later as an obscure
    # tf.data failure (zero-sized shuffle buffer / cycle length), so fail
    # here with the offending pattern instead.
    if not global_input_paths:
      raise ValueError(
          "No training tf_record files match pattern: {}".format(input_glob))
  else:
    global_input_paths = file_path

  def _dataset_fn(ctx=None):
    del ctx  # Unused.
    train_dataset = create_squad_dataset(
        file_path=global_input_paths,
        seq_length=seq_len,
        batch_size=batch_size,
        is_training=is_training)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()
def create_pretrain_dataset(file_names,
                            bsz_per_core,
                            seq_len,
                            reuse_len,
                            perm_size,
                            num_predict=None,
                            input_pipeline_context=None):
  """Creates pretrain dataset.

  Args:
    file_names: TFRecord file paths, forwarded to `parse_files_to_dataset`.
    bsz_per_core: batch size, forwarded to `parse_files_to_dataset`.
    seq_len: int, length of every per-token feature in a record.
    reuse_len: int, number of leading tokens treated as the first
      permutation span; the remaining `seq_len - reuse_len` tokens form the
      second span.
    perm_size: int, permutation length passed to `_local_perm`. Must not
      exceed `reuse_len` or `seq_len - reuse_len` (checked with `assert`).
    num_predict: optional int. When set, "target" and "target_mask" are
      emitted with length `num_predict` (zero padded) and a
      "target_mapping" of shape [num_predict, seq_len] is added. When None,
      "target" and "target_mask" keep length `seq_len` and no
      "target_mapping" is produced.
    input_pipeline_context: optional context forwarded to
      `parse_files_to_dataset`.

  Returns:
    The dataset built by `parse_files_to_dataset` using the nested `parser`.
  """

  def parser(record):
    """Parses one serialized record into the pretraining feature dict.

    The returned dict keeps the parsed "seg_id" and "label" features and adds
    "perm_mask", "target", "target_mask", "input_k", "input_q" and, when
    `num_predict` is set, "target_mapping". All int64 values are cast to
    int32 before returning.
    """

    record_spec = {
        "input": tf.io.FixedLenFeature([seq_len], tf.int64),
        "target": tf.io.FixedLenFeature([seq_len], tf.int64),
        "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
        "label": tf.io.FixedLenFeature([1], tf.int64),
        "is_masked": tf.io.FixedLenFeature([seq_len], tf.int64),
    }

    # retrieve serialized example
    example = tf.io.parse_single_example(
        serialized=record, features=record_spec)

    # These three features are consumed here and re-emitted in processed
    # form, so they are removed from the dict that is returned.
    inputs = example.pop("input")
    target = example.pop("target")
    is_masked = tf.cast(example.pop("is_masked"), tf.bool)

    non_reuse_len = seq_len - reuse_len
    # perm_size should not be larger than reuse_len or non_reuse_len otherwise
    # there will be data leaks.
    assert perm_size <= reuse_len and perm_size <= non_reuse_len

    # Creates permutation mask and target mask for the first reuse_len tokens.
    # The tokens in this part are reused from the last sequence.
    perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
        inputs[:reuse_len], target[:reuse_len], is_masked[:reuse_len],
        perm_size, reuse_len)

    # Creates permutation mask and target mask for the rest of tokens in
    # current example, which are concatenation of two new segments.
    perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
        inputs[reuse_len:], target[reuse_len:], is_masked[reuse_len:],
        perm_size, non_reuse_len)

    # Widen both per-span masks to `seq_len` columns. In `perm_mask` a 1 means
    # "cannot attend": rows of the first span are padded with ones over the
    # second span's columns, rows of the second span with zeros over the
    # first span's columns.
    perm_mask_0 = tf.concat(
        [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
    perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
                            axis=1)
    # Stack the two spans back into full-sequence tensors.
    perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
    target = tf.concat([target_0, target_1], axis=0)
    target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
    input_k = tf.concat([input_k_0, input_k_1], axis=0)
    input_q = tf.concat([input_q_0, input_q_1], axis=0)

    if num_predict is not None:
      # Positions (in original order) whose target_mask is set.
      indices = tf.range(seq_len, dtype=tf.int64)
      bool_target_mask = tf.cast(target_mask, tf.bool)
      indices = tf.boolean_mask(indices, bool_target_mask)

      ##### extra padding due to CLS/SEP introduced after prepro
      # NOTE(review): `pad_len` is only valid when at most `num_predict`
      # positions are selected; nothing here enforces that - confirm the
      # preprocessing guarantees it.
      actual_num_predict = tf.shape(indices)[0]
      pad_len = num_predict - actual_num_predict

      ##### target_mapping
      # One one-hot row per predicted position, followed by all-zero rows up
      # to `num_predict`.
      target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
      paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
      target_mapping = tf.concat([target_mapping, paddings], axis=0)
      example["target_mapping"] = tf.reshape(target_mapping,
                                             [num_predict, seq_len])

      ##### target
      # Keep only the target ids at predicted positions, zero padded.
      target = tf.boolean_mask(target, bool_target_mask)
      paddings = tf.zeros([pad_len], dtype=target.dtype)
      target = tf.concat([target, paddings], axis=0)
      example["target"] = tf.reshape(target, [num_predict])

      ##### target mask
      # 1 for each real prediction slot, 0 for each padding slot.
      target_mask = tf.concat([
          tf.ones([actual_num_predict], dtype=tf.float32),
          tf.zeros([pad_len], dtype=tf.float32)
      ],
                              axis=0)
      example["target_mask"] = tf.reshape(target_mask, [num_predict])
    else:
      example["target"] = tf.reshape(target, [seq_len])
      example["target_mask"] = tf.reshape(target_mask, [seq_len])

    # reshape back to fixed shape
    example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
    example["input_k"] = tf.reshape(input_k, [seq_len])
    example["input_q"] = tf.reshape(input_q, [seq_len])

    # Densify any sparse value and cast int64 to int32 for every feature.
    for key in list(example.keys()):
      val = example[key]
      if tf.keras.backend.is_sparse(val):
        val = tf.sparse.to_dense(val)
      if val.dtype == tf.int64:
        val = tf.cast(val, tf.int32)
      example[key] = val

    # Log every output feature together with its tensor description.
    for k, v in example.items():
      logging.info("%s: %s", k, v)

    return example

  dataset = parse_files_to_dataset(
      parser=parser,
      file_paths=file_names,
      bsz_per_core=bsz_per_core,
      input_pipeline_context=input_pipeline_context)

  return dataset
def format_filename(prefix,
                    bsz_per_host,
                    seq_len,
                    bi_data,
                    suffix,
                    mask_alpha=5,
                    mask_beta=1,
                    reuse_len=None,
                    uncased=False,
                    fixed_num_predict=None):
  """Generates input file name pattern.

  Args:
    prefix: leading part of the file name.
    bsz_per_host: batch size per host, written after "bsz-".
    seq_len: sequence length, written after "seqlen-".
    bi_data: truthy selects the "bi" tag, falsy selects "uni".
    suffix: trailing part of the file name (e.g. the extension).
    mask_alpha: value written after "alpha-".
    mask_beta: value written after "beta-".
    reuse_len: when not None, adds a "reuse-<reuse_len>." component.
    uncased: when truthy, adds an "uncased." component.
    fixed_num_predict: when not None, adds a "fnp-<fixed_num_predict>."
      component.

  Returns:
    The assembled file name string.
  """
  # Optional components carry their own trailing dot, so they vanish cleanly
  # from the name when they are not requested.
  reuse_part = "" if reuse_len is None else "reuse-{}.".format(reuse_len)
  case_part = "uncased." if uncased else ""
  direction_part = "bi" if bi_data else "uni"
  fnp_part = (
      "" if fixed_num_predict is None else "fnp-{}.".format(fixed_num_predict))

  return "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
      prefix, bsz_per_host, seq_len, reuse_part, case_part, direction_part,
      mask_alpha, mask_beta, fnp_part, suffix)
def get_pretrain_input_data(batch_size,
                            seq_len,
                            strategy,
                            file_path,
                            reuse_len,
                            perm_size,
                            mask_alpha,
                            mask_beta,
                            num_predict,
                            bi_data,
                            uncased,
                            num_hosts=1):
  """Returns pretraining input built from comma-separated record dirs.

  Every directory in `file_path` is searched for "record_info" JSON files
  whose names match the given data settings. Each JSON file contributes a
  "num_batch" count and a list of "filenames"; the listed files are looked
  up by base name inside the directory being searched.

  Args:
    batch_size: global batch size. The per-host value
      (`batch_size / num_hosts`) is part of the record-info file name. Under
      `TPUStrategy` it must be divisible by the number of replicas and the
      dataset is built with the per-replica size.
    seq_len: sequence length of the records.
    strategy: distribution strategy in use.
    file_path: comma-separated list of tfrecord directories.
    reuse_len: reuse length; used in the file name and by the dataset.
    perm_size: permutation size passed to `create_pretrain_dataset`.
    mask_alpha: value of the "alpha-" file name component.
    mask_beta: value of the "beta-" file name component.
    num_predict: number of predictions; used in the file name and by the
      dataset.
    bi_data: selects the "bi"/"uni" file name component.
    uncased: selects the "uncased." file name component.
    num_hosts: number of hosts the global batch is divided across.

  Returns:
    Under `TPUStrategy`, a function creating the dataset; otherwise the
    dataset itself.

  Raises:
    ValueError: if, under `TPUStrategy`, `batch_size` is not divisible by
      the number of replicas, or if no tfrecord file is listed by any
      matching record-info file.
  """
  # When using TPU pods, we need to clone dataset across
  # workers and need to pass in function that returns the dataset rather
  # than passing dataset instance itself.
  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)

  # The file name pattern uses the per-host batch size, so it has to be built
  # from the global `batch_size` before that is rescaled per replica below.
  split = "train"
  record_glob_base = format_filename(
      prefix="record_info-{}-*".format(split),
      bsz_per_host=int(batch_size / num_hosts),
      seq_len=seq_len,
      bi_data=bi_data,
      suffix="json",
      mask_alpha=mask_alpha,
      mask_beta=mask_beta,
      reuse_len=reuse_len,
      uncased=uncased,
      fixed_num_predict=num_predict)

  if use_dataset_fn:
    if batch_size % strategy.num_replicas_in_sync != 0:
      raise ValueError(
          "Batch size must be divisible by number of replicas : {}".format(
              strategy.num_replicas_in_sync))

    # As auto rebatching is not supported in
    # `experimental_distribute_datasets_from_function()` API, which is
    # required when cloning dataset to multiple workers in eager mode,
    # we use per-replica batch size.
    batch_size = int(batch_size / strategy.num_replicas_in_sync)

  record_info = {"num_batch": 0, "filenames": []}

  tfrecord_dirs = file_path.split(",")
  logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)

  for idx, record_dir in enumerate(tfrecord_dirs):
    record_glob = os.path.join(record_dir, record_glob_base)
    logging.info("[%d] Record glob: %s", idx, record_glob)

    record_paths = sorted(tf.io.gfile.glob(record_glob))
    logging.info("[%d] Num of record info path: %d", idx, len(record_paths))

    dir_num_batch = 0
    dir_filenames = []
    for record_info_path in record_paths:
      # Keep the file open only for as long as it takes to read it.
      with tf.io.gfile.GFile(record_info_path, "r") as fp:
        info = json.load(fp)
      dir_num_batch += info["num_batch"]
      # Re-root every listed file under `record_dir`, keeping only its base
      # name.
      dir_filenames.extend(
          os.path.join(record_dir, os.path.basename(filename))
          for filename in info["filenames"])

    logging.info("[Dir %d] Number of chosen batches: %s", idx, dir_num_batch)
    logging.info("[Dir %d] Number of chosen files: %s", idx,
                 len(dir_filenames))
    logging.info(dir_filenames)

    # Fold this directory's totals into the global `record_info`.
    record_info["num_batch"] += dir_num_batch
    record_info["filenames"] += dir_filenames

  logging.info("Total number of batches: %d", record_info["num_batch"])
  logging.info("Total number of files: %d", len(record_info["filenames"]))
  logging.info(record_info["filenames"])

  # Without any input file the pipeline would only fail later, inside
  # tf.data, with an error that does not mention the search pattern.
  if not record_info["filenames"]:
    raise ValueError(
        "No pretraining tfrecord files found; searched dirs {} for "
        "pattern {}".format(tfrecord_dirs, record_glob_base))

  def _dataset_fn(ctx=None):
    """Function that can create a pretrain dataset."""
    train_dataset = create_pretrain_dataset(
        file_names=record_info["filenames"],
        bsz_per_core=batch_size,
        seq_len=seq_len,
        reuse_len=reuse_len,
        perm_size=perm_size,
        num_predict=num_predict,
        input_pipeline_context=ctx)
    return train_dataset

  return _dataset_fn if use_dataset_fn else _dataset_fn()
def parse_files_to_dataset(parser,
                           file_paths,
                           bsz_per_core,
                           input_pipeline_context=None):
  """Creates the batched, endlessly repeating dataset for the given files.

  Args:
    parser: function mapped over every raw record.
    file_paths: sequence of TFRecord file paths.
    bsz_per_core: batch size; a trailing partial batch is dropped.
    input_pipeline_context: optional context; when it reports more than one
      input pipeline, the file list is sharded by pipeline id.

  Returns:
    The resulting `tf.data.Dataset`.
  """
  ctx = input_pipeline_context
  files = tf.data.Dataset.from_tensor_slices(file_paths)

  # Note: sample-level shuffling is not possible here, because it would break
  # the requirement that the data stream stays consecutive. Only the file
  # list is sharded and shuffled.
  if ctx and ctx.num_input_pipelines > 1:
    files = files.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)

  # File-level shuffle.
  num_files = len(file_paths)
  if num_files > 1:
    files = files.shuffle(num_files)

  records = tf.data.TFRecordDataset(files)

  # (zihang): preprocessing happens online, so parsing the same input gives a
  # different result each time. Caching parsed data is therefore not helpful;
  # it would use a lot of memory and lead to container OOM. The raw,
  # non-parsed records are cached instead.
  records = records.cache()
  examples = records.map(parser).repeat()
  batches = examples.batch(bsz_per_core, drop_remainder=True)
  return batches.prefetch(tf.data.experimental.AUTOTUNE)
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
  """Samples a permutation of the factorization order.

  Creates perm_mask and target_mask accordingly.

  Args:
    inputs: int64 Tensor in shape [seq_len], input ids.
    targets: int64 Tensor in shape [seq_len], target ids.
    is_masked: bool Tensor in shape [seq_len]. True means being selected for
      partial prediction.
    perm_size: the length of longest permutation. Could be set to be
      reuse_len. Should not be larger than reuse_len or there will be data
      leaks.
    seq_len: int, sequence length.

  Returns:
    perm_mask: float32 Tensor in shape [seq_len, seq_len] consisting of 0 and
      1. If perm_mask[i][j] == 1, the ith token (in original order) cannot
      attend to the jth token (in original order). This happens only when the
      ith token's permuted position <= the jth token's permuted position and
      the jth token is masked or is a functional token. If
      perm_mask[i][j] == 0, the ith token (in original order) can attend to
      the jth token (in original order). Note that non-masked tokens can be
      attended to by all other tokens, which differs from the description in
      the original paper.
    new_targets: int64 Tensor in shape [seq_len], target token ids to be
      predicted in XLNet. In XLNet, the target doesn't need to be shifted one
      position.
    target_mask: float32 Tensor in shape [seq_len] consisting of 0 and 1. If
      target_mask[i] == 1, the ith token needs to be predicted and the mask
      will be used as input; this token counts for the loss. If
      target_mask[i] == 0, the token (or [SEP], [CLS]) will be used as input;
      this token does not count for the loss.
    inputs_k: int64 Tensor in shape [seq_len], input ids (`inputs` returned
      unchanged).
    inputs_q: float32 Tensor in shape [seq_len], the same as target_mask.
  """

  # Generate permutation indices
  # The positions are laid out as rows of `perm_size` (the reshape needs
  # `seq_len` to be a multiple of `perm_size`), transposed, shuffled along the
  # first axis and transposed back. Every chunk of `perm_size` positions
  # therefore receives the same random within-chunk ordering.
  index = tf.range(seq_len, dtype=tf.int64)
  index = tf.transpose(tf.reshape(index, [-1, perm_size]))
  index = tf.random.shuffle(index)
  index = tf.reshape(tf.transpose(index), [-1])

  # `perm_mask` and `target_mask`
  # non-functional tokens
  non_func_tokens = tf.logical_not(
      tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
  # Tokens that are neither masked nor functional, and their complement.
  non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
  masked_or_func_tokens = tf.logical_not(non_mask_tokens)

  # Set the permutation indices of non-masked (& non-functional) tokens to the
  # smallest index (-1):
  # (1) they can be seen by all other positions
  # (2) they cannot see masked positions, so there won't be information leak
  smallest_index = -tf.ones([seq_len], dtype=tf.int64)
  rev_index = tf.where(non_mask_tokens, smallest_index, index)

  # Create `target_mask`: non-functional and masked tokens
  # 1: use mask as input and have loss
  # 0: use token (or [SEP], [CLS]) as input and do not have loss
  target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
  target_mask = tf.cast(target_tokens, tf.float32)

  # Create `perm_mask`
  # `target_tokens` cannot see themselves
  # Every other token gets index + 1 on the "query" side, which makes the
  # `<=` comparison below false against its own index.
  self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

  # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
  # 0: can attend if i > j or j is non-masked
  # Rows index the attending token i, columns the attended token j;
  # `masked_or_func_tokens` broadcasts along the column (j) axis.
  perm_mask = tf.logical_and(self_rev_index[:, None] <= rev_index[None, :],
                             masked_or_func_tokens)
  perm_mask = tf.cast(perm_mask, tf.float32)

  # new target: [next token] for LM and [curr token] (self) for PLM
  # i.e. `targets` shifted right by one position, with the first input id
  # placed in front.
  new_targets = tf.concat([inputs[0:1], targets[:-1]], axis=0)

  # construct inputs_k
  inputs_k = inputs

  # construct inputs_q
  inputs_q = target_mask

  return perm_mask, new_targets, target_mask, inputs_k, inputs_q