From 80dba8903f053314510ae4042e0171ec970e2f83 Mon Sep 17 00:00:00 2001 From: Niki Parmar Date: Tue, 18 Jul 2017 21:52:38 -0700 Subject: [PATCH 1/7] Add celeba dataset, add to problems PiperOrigin-RevId: 162444538 --- tensor2tensor/bin/t2t-datagen | 3 ++ tensor2tensor/data_generators/image.py | 38 +++++++++++++++++++ .../data_generators/problem_hparams.py | 13 +++++++ tensor2tensor/utils/data_reader.py | 4 ++ 4 files changed, 58 insertions(+) diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index b0fd816a2..f0aa26ceb 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -120,6 +120,9 @@ _SUPPORTED_PROBLEM_GENERATORS = { "image_mscoco_characters_test": ( lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000), lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)), + "image_celeba_tune": ( + lambda: image.celeba_generator(FLAGS.tmp_dir, 162770), + lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)), "image_mscoco_tokens_8k_test": ( lambda: image.mscoco_generator( FLAGS.tmp_dir, diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 0cba1800b..79bb51f3c 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -347,3 +347,41 @@ def hparams(self, defaults, model_hparams): p.target_modality = (registry.Modalities.SYMBOL, vocab_size) p.input_space_id = problem.SpaceID.DIGIT_0 p.target_space_id = problem.SpaceID.DIGIT_1 + + +# Filename for CELEBA data. +_CELEBA_NAME = "img_align_celeba" + + +def _get_celeba(directory): + """Download and extract CELEBA to directory unless it is there.""" + path = os.path.join(directory, _CELEBA_NAME) + if not tf.gfile.Exists(path): + # We expect that this file has been downloaded from: + # https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM + # and placed in `directory`. + zipfile.ZipFile(path+".zip", "r").extractall(directory) + + +def celeba_generator(tmp_dir, how_many, start_from=0): + """Image generator for CELEBA dataset. + + Args: + tmp_dir: path to temporary storage directory. + how_many: how many images and labels to generate. + start_from: from which image to start. + + Yields: + A dictionary representing the images with the following fields: + * image/encoded: the string encoding the image as JPEG, + * image/format: the string "jpeg" representing image format, + """ + _get_celeba(tmp_dir) + image_files = tf.gfile.Glob(tmp_dir + "/*.jpg") + for filename in image_files[start_from:start_from+how_many]: + with tf.gfile.Open(filename, "r") as f: + encoded_image_data = f.read() + yield { + "image/encoded": [encoded_image_data], + "image/format": ["jpeg"], + } diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 5922ab59a..3347fe4f6 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -596,6 +596,18 @@ def img2img_imagenet(unused_model_hparams): return p +def image_celeba(unused_model_hparams): + """Image CelebA dataset.""" + p = default_problem_hparams() + p.input_modality = {"inputs": ("image:identity_no_pad", None)} + p.target_modality = ("image:identity_no_pad", None) + p.batch_size_multiplier = 256 + p.max_expected_batch_size_per_shard = 4 + p.input_space_id = 1 + p.target_space_id = 1 + return p + + # Dictionary of named hyperparameter settings for various problems. # This is only accessed through the problem_hparams function below. 
PROBLEM_HPARAMS_MAP = { @@ -620,6 +632,7 @@ def img2img_imagenet(unused_model_hparams): "image_cifar10_test": image_cifar10, "image_mnist_tune": image_mnist, "image_mnist_test": image_mnist, + "image_celeba_tune": image_celeba, "image_mscoco_characters_tune": image_mscoco_characters, "image_mscoco_characters_test": image_mscoco_characters, "image_mscoco_tokens_8k_test": lambda p: image_mscoco_tokens(p, 2**13), diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index cb84b9e3e..cd8e6c2d3 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -161,6 +161,10 @@ def preprocess(img): inputs = examples["inputs"] examples["inputs"] = resize(inputs, 16) examples["targets"] = resize(inputs, 64) + elif "image_celeba" in data_file_pattern: + inputs = examples["inputs"] + examples["inputs"] = resize(inputs, 8) + examples["targets"] = resize(inputs, 32) elif "audio" in data_file_pattern: # Reshape audio to proper shape From 78acdb4f3b0908bbdf32fea8b98eee5b65641ef9 Mon Sep 17 00:00:00 2001 From: T2T Team Date: Wed, 19 Jul 2017 15:13:13 -0700 Subject: [PATCH 2/7] Fix a bug in text_encoder. "self._UNESCAPE_REGEX -> _UNESCAPE_REGEX" PiperOrigin-RevId: 162542600 --- tensor2tensor/data_generators/text_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index e0ac1901e..8be22ce0b 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -530,4 +530,4 @@ def match(m): # Convert '\u' to '_' and '\\' to '\' return u"_" if m.group(0) == u"\\u" else u"\\" # Cut off the trailing underscore and apply the regex substitution - return self._UNESCAPE_REGEX.sub(match, escaped_token[:-1]) + return _UNESCAPE_REGEX.sub(match, escaped_token[:-1]) From 293b5f6ef63a7a6f5ae546e050967cf79c74b4d2 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 19 Jul 2017 16:36:48 -0700 Subject: [PATCH 3/7] Add genetics dataset - data generation only PiperOrigin-RevId: 162553677 --- tensor2tensor/bin/t2t-datagen | 5 +- tensor2tensor/data_generators/all_problems.py | 1 + tensor2tensor/data_generators/genetics.py | 212 ++++++++++++++++++ tensor2tensor/data_generators/problem.py | 6 + 4 files changed, 223 insertions(+), 1 deletion(-) create mode 100644 tensor2tensor/data_generators/genetics.py diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index f0aa26ceb..1ba354695 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -28,6 +28,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import random import tempfile @@ -320,7 +321,9 @@ def generate_data_for_problem(problem): def generate_data_for_registered_problem(problem_name): problem = registry.problem(problem_name) - problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir, FLAGS.num_shards) + problem.generate_data(os.path.expanduser(FLAGS.data_dir), + os.path.expanduser(FLAGS.tmp_dir), + FLAGS.num_shards) if __name__ == "__main__": diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 364c252a7..0a2503bd2 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -21,6 +21,7 @@ from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import algorithmic_math from 
tensor2tensor.data_generators import audio +from tensor2tensor.data_generators import genetics from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import ptb diff --git a/tensor2tensor/data_generators/genetics.py b/tensor2tensor/data_generators/genetics.py new file mode 100644 index 000000000..255e0caf9 --- /dev/null +++ b/tensor2tensor/data_generators/genetics.py @@ -0,0 +1,212 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Genetics problems. + +Inputs are bases ACTG (with indices assigned in that order). + +Requires the h5py library. + +File format expected: + * h5 file + * h5 datasets should include {train, valid, test}_{in, na, out}, which will + map to inputs, targets mask, and targets for the train, dev, and test + datasets. + * Each record in *_in is a bool 2-D numpy array with one-hot encoded base + pairs with shape [num_input_timesteps, 4]. The base order is ACTG. + * Each record in *_na is a bool 1-D numpy array with shape + [num_output_timesteps]. + * Each record in *_out is a float 2-D numpy array with shape + [num_output_timesteps, num_predictions]. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import multiprocessing as mp +import os + +# Dependency imports + +import h5py +import numpy as np + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import registry + +_bases = list("ACTG") +BASE_TO_ID = dict(zip(_bases, range(len(_bases)))) +ID_TO_BASE = dict(zip(range(len(_bases)), _bases)) +UNK_ID = len(_bases) + + +# TODO(rsepassi): +# * DataEncoder for genetic bases +# * GeneticModality and problem hparams +# * Training preprocessing + + +class GeneticsProblem(problem.Problem): + + @property + def download_url(self): + raise NotImplementedError() + + @property + def h5_file(self): + raise NotImplementedError() + + def generate_data(self, data_dir, tmp_dir, num_shards=None): + if num_shards is None: + num_shards = 100 + + # Download source data + h5_filepath = generator_utils.maybe_download(tmp_dir, self.h5_file, + self.download_url) + with h5py.File(h5_filepath, "r") as h5_file: + num_train_examples = h5_file["train_in"].len() + num_dev_examples = h5_file["valid_in"].len() + num_test_examples = h5_file["test_in"].len() + + # Collect all_filepaths to later shuffle + all_filepaths = [] + # Collect created shard processes to start and join + processes = [] + + datasets = [(self.training_filepaths, num_shards, "train", + num_train_examples), (self.dev_filepaths, 1, "valid", + num_dev_examples), + (self.test_filepaths, 1, "test", num_test_examples)] + for fname_fn, nshards, key_prefix, num_examples in datasets: + outfiles = fname_fn(data_dir, nshards, shuffled=False) + 
all_filepaths.extend(outfiles) + for start_idx, end_idx, outfile in generate_shard_args( + outfiles, num_examples): + p = mp.Process( + target=generate_dataset, + args=(h5_filepath, key_prefix, [outfile], start_idx, end_idx)) + processes.append(p) + + # Start and wait for processes + assert len(processes) == num_shards + 2 # 1 per training shard + dev + test + for p in processes: + p.start() + for p in processes: + p.join() + + # Shuffle + generator_utils.shuffle_dataset(all_filepaths) + + +@registry.register_problem("genetics_cage10") +class GeneticsCAGE10(GeneticsProblem): + + @property + def download_url(self): + return "https://storage.googleapis.com/262k_binned/cage10_l262k_w128.h5" + + @property + def h5_file(self): + return "cage10.h5" + + +@registry.register_problem("genetics_gm12878") +class GeneticsGM12878(GeneticsProblem): + + @property + def download_url(self): + return "https://storage.googleapis.com/262k_binned/gm12878_l262k_w128.h5" + + @property + def h5_file(self): + return "gm12878.h5" + + +def generate_shard_args(outfiles, num_examples): + """Generate start and end indices per outfile.""" + num_shards = len(outfiles) + num_examples_per_shard = num_examples // num_shards + start_idxs = [i * num_examples_per_shard for i in xrange(num_shards)] + end_idxs = list(start_idxs) + end_idxs.pop(0) + end_idxs.append(num_examples) + return zip(start_idxs, end_idxs, outfiles) + + +def generate_dataset(h5_filepath, + key_prefix, + out_filepaths, + start_idx=None, + end_idx=None): + print("PID: %d, Key: %s, (Start, End): (%s, %s)" % (os.getpid(), key_prefix, + start_idx, end_idx)) + generator_utils.generate_files( + dataset_generator(h5_filepath, key_prefix, start_idx, end_idx), + out_filepaths) + + +def dataset_generator(filepath, dataset, start_idx=None, end_idx=None): + with h5py.File(filepath, "r") as h5_file: + # Get input keys from h5_file + src_keys = [s % dataset for s in ["%s_in", "%s_na", "%s_out"]] + src_values = [h5_file[k] for k in src_keys] + inp_data, mask_data, out_data = src_values + assert len(set([v.len() for v in src_values])) == 1 + + if start_idx is None: + start_idx = 0 + if end_idx is None: + end_idx = inp_data.len() + + for i in xrange(start_idx, end_idx): + if i % 100 == 0: + print("Generating example %d for %s" % (i, dataset)) + inputs, mask, outputs = inp_data[i], mask_data[i], out_data[i] + yield to_example_dict(inputs, mask, outputs) + + +def to_example_dict(inputs, mask, outputs): + """Convert single h5 record to an example dict.""" + # Inputs + input_ids = [] + last_idx = -1 + for row in np.argwhere(inputs): + idx, base_id = row + idx, base_id = int(idx), int(base_id) + assert idx > last_idx # if not, means 2 True values in 1 row + # Some rows are all False. Those rows are mapped to UNK_ID. + while idx != last_idx + 1: + input_ids.append(UNK_ID + text_encoder.NUM_RESERVED_TOKENS) + last_idx += 1 + input_ids.append(base_id + text_encoder.NUM_RESERVED_TOKENS) + last_idx = idx + assert len(inputs) == len(input_ids) + input_ids.append(text_encoder.EOS_ID) + + # Targets: mask and output + targets_mask = [float(v) for v in mask] + # The output is (n, m); store targets_shape so that it can be reshaped + # properly on the other end. 
+ targets = [float(v) for v in outputs.flatten()] + targets_shape = [int(dim) for dim in outputs.shape] + assert mask.shape[0] == outputs.shape[0] + + example_keys = ["inputs", "targets_mask", "targets", "targets_shape"] + ex_dict = dict( + zip(example_keys, [input_ids, targets_mask, targets, targets_shape])) + return ex_dict diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 1182ed7d1..e93039b71 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -146,6 +146,12 @@ def dev_filepaths(self, data_dir, num_shards, shuffled): file_basename += utils.UNSHUFFLED_SUFFIX return utils.dev_data_filenames(file_basename, data_dir, num_shards) + def test_filepaths(self, data_dir, num_shards, shuffled): + file_basename = self.dataset_filename() + if not shuffled: + file_basename += utils.UNSHUFFLED_SUFFIX + return utils.test_data_filenames(file_basename, data_dir, num_shards) + def __init__(self, was_reversed=False, was_copy=False): """Create a Problem. From 84445cc6eaabc338285b6a96135c78b0e1e4b26c Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Wed, 19 Jul 2017 19:31:47 -0700 Subject: [PATCH 4/7] Change the format of generated vocab files to include languages, put them in data_dir, add --generate_data option. PiperOrigin-RevId: 162569315 --- README.md | 17 ++- tensor2tensor/bin/t2t-datagen | 129 ++++++------------ tensor2tensor/bin/t2t-trainer | 24 +++- tensor2tensor/data_generators/audio.py | 6 +- .../data_generators/generator_utils.py | 14 +- tensor2tensor/data_generators/image.py | 6 +- tensor2tensor/data_generators/inspect.py | 2 +- .../data_generators/problem_hparams.py | 16 +-- tensor2tensor/data_generators/wmt.py | 73 +++++----- tensor2tensor/data_generators/wsj_parsing.py | 11 +- tensor2tensor/models/transformer.py | 2 +- tensor2tensor/utils/trainer_utils.py | 8 +- 12 files changed, 150 insertions(+), 158 deletions(-) diff --git a/README.md b/README.md index 059fbe429..0564a9c99 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,21 @@ issues](https://github.com/tensorflow/tensor2tensor/issues). And chat with us and other users on [Gitter](https://gitter.im/tensor2tensor/Lobby). +Here is a one-command version that installs tensor2tensor, downloads the data, +trains an English-German translation model, and lets you use it interactively: +``` +pip install tensor2tensor && t2t-trainer \ + --generate_data \ + --data_dir=~/t2t_data \ + --problems=wmt_ende_tokens_32k \ + --model=transformer \ + --hparams_set=transformer_base_single_gpu \ + --output_dir=~/t2t_train/base \ + --decode_interactive +``` + +See the [Walkthrough](#walkthrough) below for more details on each step. + ### Contents * [Walkthrough](#walkthrough) @@ -72,8 +87,6 @@ t2t-datagen \ --num_shards=100 \ --problem=$PROBLEM -cp $TMP_DIR/tokens.vocab.* $DATA_DIR - # Train # * If you run out of memory, add --hparams='batch_size=2048' or even 1024. 
t2t-trainer \ diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index 1ba354695..af5b47f8c 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -80,24 +80,30 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000), lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)), "ice_parsing_tokens": ( - lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir, - True, "ice", 2**13, 2**8), - lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir, - False, "ice", 2**13, 2**8)), + lambda: wmt.tabbed_parsing_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8), + lambda: wmt.tabbed_parsing_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)), "ice_parsing_characters": ( - lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True), - lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)), + lambda: wmt.tabbed_parsing_character_generator( + FLAGS.data_dir, FLAGS.tmp_dir, True), + lambda: wmt.tabbed_parsing_character_generator( + FLAGS.data_dir, FLAGS.tmp_dir, False)), "wmt_parsing_tokens_8k": ( - lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, True, 2**13), - lambda: wmt.parsing_token_generator(FLAGS.tmp_dir, False, 2**13)), + lambda: wmt.parsing_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13), + lambda: wmt.parsing_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)), "wsj_parsing_tokens_16k": ( - lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True, - 2**14, 2**9), - lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False, - 2**14, 2**9)), + lambda: wsj_parsing.parsing_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, True, 2**14, 2**9), + lambda: wsj_parsing.parsing_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, False, 2**14, 2**9)), "wmt_ende_bpe32k": ( - lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True), - lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)), + lambda: wmt.ende_bpe_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, True), + lambda: wmt.ende_bpe_token_generator( + FLAGS.data_dir, FLAGS.tmp_dir, False)), "lm1b_32k": ( lambda: lm1b.generator(FLAGS.tmp_dir, True), lambda: lm1b.generator(FLAGS.tmp_dir, False) @@ -119,101 +125,50 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000), lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)), "image_mscoco_characters_test": ( - lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000), - lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)), + lambda: image.mscoco_generator( + FLAGS.data_dir, FLAGS.tmp_dir, True, 80000), + lambda: image.mscoco_generator( + FLAGS.data_dir, FLAGS.tmp_dir, False, 40000)), "image_celeba_tune": ( lambda: image.celeba_generator(FLAGS.tmp_dir, 162770), lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)), "image_mscoco_tokens_8k_test": ( lambda: image.mscoco_generator( - FLAGS.tmp_dir, - True, - 80000, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13), + FLAGS.data_dir, FLAGS.tmp_dir, True, 80000, + vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13), lambda: image.mscoco_generator( - FLAGS.tmp_dir, - False, - 40000, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13)), + FLAGS.data_dir, FLAGS.tmp_dir, False, 40000, + vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)), "image_mscoco_tokens_32k_test": ( lambda: image.mscoco_generator( - FLAGS.tmp_dir, - True, - 80000, - 
vocab_filename="tokens.vocab.%d" % 2**15, - vocab_size=2**15), + FLAGS.data_dir, FLAGS.tmp_dir, True, 80000, + vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15), lambda: image.mscoco_generator( - FLAGS.tmp_dir, - False, - 40000, - vocab_filename="tokens.vocab.%d" % 2**15, - vocab_size=2**15)), + FLAGS.data_dir, FLAGS.tmp_dir, False, 40000, + vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)), "snli_32k": ( lambda: snli.snli_token_generator(FLAGS.tmp_dir, True, 2**15), lambda: snli.snli_token_generator(FLAGS.tmp_dir, False, 2**15), ), - "audio_timit_characters_tune": ( - lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1374), - lambda: audio.timit_generator(FLAGS.tmp_dir, True, 344, 1374)), "audio_timit_characters_test": ( - lambda: audio.timit_generator(FLAGS.tmp_dir, True, 1718), - lambda: audio.timit_generator(FLAGS.tmp_dir, False, 626)), - "audio_timit_tokens_8k_tune": ( lambda: audio.timit_generator( - FLAGS.tmp_dir, - True, - 1374, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13), + FLAGS.data_dir, FLAGS.tmp_dir, True, 1718), lambda: audio.timit_generator( - FLAGS.tmp_dir, - True, - 344, - 1374, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13)), + FLAGS.data_dir, FLAGS.tmp_dir, False, 626)), "audio_timit_tokens_8k_test": ( lambda: audio.timit_generator( - FLAGS.tmp_dir, - True, - 1718, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13), - lambda: audio.timit_generator( - FLAGS.tmp_dir, - False, - 626, - vocab_filename="tokens.vocab.%d" % 2**13, - vocab_size=2**13)), - "audio_timit_tokens_32k_tune": ( - lambda: audio.timit_generator( - FLAGS.tmp_dir, - True, - 1374, - vocab_filename="tokens.vocab.%d" % 2**15, - vocab_size=2**15), + FLAGS.data_dir, FLAGS.tmp_dir, True, 1718, + vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13), lambda: audio.timit_generator( - FLAGS.tmp_dir, - True, - 344, - 1374, - vocab_filename="tokens.vocab.%d" % 2**15, - vocab_size=2**15)), + FLAGS.data_dir, FLAGS.tmp_dir, False, 626, + vocab_filename="vocab.endefr.%d" % 2**13, vocab_size=2**13)), "audio_timit_tokens_32k_test": ( lambda: audio.timit_generator( - FLAGS.tmp_dir, - True, - 1718, - vocab_filename="tokens.vocab.%d" % 2**15, - vocab_size=2**15), + FLAGS.data_dir, FLAGS.tmp_dir, True, 1718, + vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15), lambda: audio.timit_generator( - FLAGS.tmp_dir, - False, - 626, - vocab_filename="tokens.vocab.%d" % 2**15, - vocab_size=2**15)), + FLAGS.data_dir, FLAGS.tmp_dir, False, 626, + vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)), "lmptb_10k": ( lambda: ptb.train_generator( FLAGS.tmp_dir, diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 8a801e70e..a37767258 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -31,7 +31,8 @@ from __future__ import print_function # Dependency imports -from tensor2tensor.utils import trainer_utils as utils +from tensor2tensor.utils import registry +from tensor2tensor.utils import trainer_utils from tensor2tensor.utils import usr_dir import tensorflow as tf @@ -45,14 +46,29 @@ flags.DEFINE_string("t2t_usr_dir", "", "The imported files should contain registrations, " "e.g. 
@registry.register_model calls, that will then be " "available to the t2t-trainer.") +flags.DEFINE_string("tmp_dir", "/tmp/t2t_datagen", + "Temporary storage directory.") +flags.DEFINE_bool("generate_data", False, "Generate data before training?") def main(_): tf.logging.set_verbosity(tf.logging.INFO) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) - utils.log_registry() - utils.validate_flags() - utils.run( + trainer_utils.log_registry() + trainer_utils.validate_flags() + tf.gfile.MakeDirs(FLAGS.output_dir) + + # Generate data if requested. + if FLAGS.generate_data: + tf.gfile.MakeDirs(FLAGS.data_dir) + tf.gfile.MakeDirs(FLAGS.tmp_dir) + for problem_name in FLAGS.problems.split("-"): + tf.logging.info("Generating data for %s" % problem_name) + problem = registry.problem(problem_name) + problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir) + + # Run the trainer. + trainer_utils.run( data_dir=FLAGS.data_dir, model=FLAGS.model, output_dir=FLAGS.output_dir, diff --git a/tensor2tensor/data_generators/audio.py b/tensor2tensor/data_generators/audio.py index 81cfde008..4f8c096a5 100644 --- a/tensor2tensor/data_generators/audio.py +++ b/tensor2tensor/data_generators/audio.py @@ -97,7 +97,8 @@ def _get_text_data(filepath): return " ".join(words) -def timit_generator(tmp_dir, +def timit_generator(data_dir, + tmp_dir, training, how_many, start_from=0, @@ -107,6 +108,7 @@ def timit_generator(tmp_dir, """Data generator for TIMIT transcription problem. Args: + data_dir: path to the data directory. tmp_dir: path to temporary storage directory. training: a Boolean; if true, we use the train set, otherwise the test set. how_many: how many inputs and labels to generate. @@ -128,7 +130,7 @@ def timit_generator(tmp_dir, eos_list = [1] if eos_list is None else eos_list if vocab_filename is not None: vocab_symbolizer = generator_utils.get_or_generate_vocab( - tmp_dir, vocab_filename, vocab_size) + data_dir, tmp_dir, vocab_filename, vocab_size) _get_timit(tmp_dir) datasets = (_TIMIT_TRAIN_DATASETS if training else _TIMIT_TEST_DATASETS) i = 0 diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index b34a87138..5c0c94bce 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -244,16 +244,13 @@ def gunzip_file(gz_path, new_path): "http://www.statmt.org/wmt13/training-parallel-un.tgz", ["un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr"] ], - [ - "https://github.com/stefan-it/nmt-mk-en/raw/master/data/setimes.mk-en.train.tgz", # pylint: disable=line-too-long - ["train.mk", "train.en"] - ], ] -def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None): +def get_or_generate_vocab(data_dir, tmp_dir, + vocab_filename, vocab_size, sources=None): """Generate a vocabulary from the datasets in sources (_DATA_FILE_URLS).""" - vocab_filepath = os.path.join(tmp_dir, vocab_filename) + vocab_filepath = os.path.join(data_dir, vocab_filename) if tf.gfile.Exists(vocab_filepath): tf.logging.info("Found vocab file: %s", vocab_filepath) vocab = text_encoder.SubwordTextEncoder(vocab_filepath) @@ -304,7 +301,7 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None): return vocab -def get_or_generate_tabbed_vocab(tmp_dir, source_filename, +def get_or_generate_tabbed_vocab(data_dir, tmp_dir, source_filename, index, vocab_filename, vocab_size): r"""Generate a vocabulary from a tabbed source file. 
@@ -313,6 +310,7 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename, The index parameter specifies 0 for the source or 1 for the target. Args: + data_dir: path to the data directory. tmp_dir: path to the temporary directory. source_filename: the name of the tab-separated source file. index: index. @@ -322,7 +320,7 @@ def get_or_generate_tabbed_vocab(tmp_dir, source_filename, Returns: The vocabulary. """ - vocab_filepath = os.path.join(tmp_dir, vocab_filename) + vocab_filepath = os.path.join(data_dir, vocab_filename) if os.path.exists(vocab_filepath): vocab = text_encoder.SubwordTextEncoder(vocab_filepath) return vocab diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 79bb51f3c..e3567d78f 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -230,7 +230,8 @@ def _get_mscoco(directory): zipfile.ZipFile(path, "r").extractall(directory) -def mscoco_generator(tmp_dir, +def mscoco_generator(data_dir, + tmp_dir, training, how_many, start_from=0, @@ -240,6 +241,7 @@ def mscoco_generator(tmp_dir, """Image generator for MSCOCO captioning problem with token-wise captions. Args: + data_dir: path to the data directory. tmp_dir: path to temporary storage directory. training: a Boolean; if true, we use the train set, otherwise the test set. how_many: how many images and labels to generate. @@ -261,7 +263,7 @@ def mscoco_generator(tmp_dir, eos_list = [1] if eos_list is None else eos_list if vocab_filename is not None: vocab_symbolizer = generator_utils.get_or_generate_vocab( - tmp_dir, vocab_filename, vocab_size) + data_dir, tmp_dir, vocab_filename, vocab_size) _get_mscoco(tmp_dir) caption_filepath = (_MSCOCO_TRAIN_CAPTION_FILE if training else _MSCOCO_EVAL_CAPTION_FILE) diff --git a/tensor2tensor/data_generators/inspect.py b/tensor2tensor/data_generators/inspect.py index fba3c6492..dad0c1c83 100644 --- a/tensor2tensor/data_generators/inspect.py +++ b/tensor2tensor/data_generators/inspect.py @@ -17,7 +17,7 @@ python data_generators/inspect.py \ --logtostderr \ --print_targets \ - --subword_text_encoder_filename=$DATA_DIR/tokens.vocab.8192 \ + --subword_text_encoder_filename=$DATA_DIR/vocab.endefr.8192 \ --input_filename=$DATA_DIR/wmt_ende_tokens_8k-train-00000-of-00100 """ diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 3347fe4f6..8e6d032d5 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -249,7 +249,7 @@ def audio_timit_tokens(model_hparams, wrong_vocab_size): p = default_problem_hparams() # This vocab file must be present within the data directory. vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.%d" % wrong_vocab_size) + "vocab.endefr.%d" % wrong_vocab_size) subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) p.input_modality = { "inputs": (registry.Modalities.AUDIO, None), @@ -298,7 +298,7 @@ def audio_wsj_tokens(model_hparams, wrong_vocab_size): p = default_problem_hparams() # This vocab file must be present within the data directory. 
vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.%d" % wrong_vocab_size) + "vocab.endefr.%d" % wrong_vocab_size) subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) p.input_modality = { "inputs": (registry.Modalities.AUDIO, None), @@ -412,7 +412,7 @@ def wmt_parsing_tokens(model_hparams, wrong_vocab_size): p = default_problem_hparams() # This vocab file must be present within the data directory. vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.%d" % wrong_vocab_size) + "vocab.endefr.%d" % wrong_vocab_size) subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) p.input_modality = { "inputs": (registry.Modalities.SYMBOL, subtokenizer.vocab_size) @@ -449,10 +449,10 @@ def wsj_parsing_tokens(model_hparams, # This vocab file must be present within the data directory. source_vocab_filename = os.path.join( model_hparams.data_dir, - prefix + "_source.tokens.vocab.%d" % wrong_source_vocab_size) + prefix + "_source.vocab.%d" % wrong_source_vocab_size) target_vocab_filename = os.path.join( model_hparams.data_dir, - prefix + "_target.tokens.vocab.%d" % wrong_target_vocab_size) + prefix + "_target.vocab.%d" % wrong_target_vocab_size) source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) p.input_modality = { @@ -485,10 +485,10 @@ def ice_parsing_tokens(model_hparams, wrong_source_vocab_size): # This vocab file must be present within the data directory. source_vocab_filename = os.path.join( model_hparams.data_dir, - "ice_source.tokens.vocab.%d" % wrong_source_vocab_size) + "ice_source.vocab.%d" % wrong_source_vocab_size) target_vocab_filename = os.path.join( model_hparams.data_dir, - "ice_target.tokens.vocab.256") + "ice_target.vocab.256") source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) p.input_modality = { @@ -573,7 +573,7 @@ def image_mscoco_tokens(model_hparams, vocab_count): p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)} # This vocab file must be present within the data directory. 
vocab_filename = os.path.join(model_hparams.data_dir, - "tokens.vocab.%d" % vocab_count) + "vocab.endefr.%d" % vocab_count) subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) p.target_modality = (registry.Modalities.SYMBOL, subtokenizer.vocab_size) p.vocabulary = { diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index 2e1f1e8af..4d134caf1 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -43,7 +43,8 @@ def _default_token_feature_encoders(data_dir, target_vocab_size): - vocab_filename = os.path.join(data_dir, "tokens.vocab.%d" % target_vocab_size) + vocab_filename = os.path.join(data_dir, + "vocab.endefr.%d" % target_vocab_size) subtokenizer = text_encoder.SubwordTextEncoder(vocab_filename) return { "inputs": subtokenizer, @@ -71,7 +72,7 @@ def targeted_vocab_size(self): @property def train_generator(self): - """Generator; takes tmp_dir, is_training, possibly targeted_vocab_size.""" + """Generator; takes data_dir, tmp_dir, is_training, targeted_vocab_size.""" raise NotImplementedError() @property @@ -101,9 +102,11 @@ def generate_data(self, data_dir, tmp_dir, num_shards=None): self.dev_filepaths(data_dir, 1, shuffled=False)) else: generator_utils.generate_dataset_and_shuffle( - self.train_generator(tmp_dir, True, self.targeted_vocab_size), + self.train_generator(data_dir, tmp_dir, True, + self.targeted_vocab_size), self.training_filepaths(data_dir, num_shards, shuffled=False), - self.dev_generator(tmp_dir, False, self.targeted_vocab_size), + self.dev_generator(data_dir, tmp_dir, False, + self.targeted_vocab_size), self.dev_filepaths(data_dir, 1, shuffled=False)) def feature_encoders(self, data_dir): @@ -351,12 +354,14 @@ def _get_wmt_ende_dataset(directory, filename): return train_path -def ende_bpe_token_generator(tmp_dir, train): +def ende_bpe_token_generator(data_dir, tmp_dir, train): """Instance of token generator for the WMT en->de task, training set.""" dataset_path = ("train.tok.clean.bpe.32000" if train else "newstest2013.tok.bpe.32000") train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path) - token_path = os.path.join(tmp_dir, "vocab.bpe.32000") + token_tmp_path = os.path.join(tmp_dir, "vocab.bpe.32000") + token_path = os.path.join(data_dir, "vocab.bpe.32000") + tf.gfile.Copy(token_tmp_path, token_path, overwrite=True) token_vocab = text_encoder.TokenTextEncoder(vocab_filename=token_path) return token_generator(train_path + ".en", train_path + ".de", token_vocab, EOS) @@ -402,9 +407,9 @@ def _compile_data(tmp_dir, datasets, filename): return filename -def ende_wordpiece_token_generator(tmp_dir, train, vocab_size): +def ende_wordpiece_token_generator(data_dir, tmp_dir, train, vocab_size): symbolizer_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size) + data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size) datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "wmt_ende_tok_%s" % tag) @@ -471,26 +476,26 @@ def target_space_id(self): return problem.SpaceID.DE_CHR -def zhen_wordpiece_token_bigenerator(tmp_dir, train, source_vocab_size, - target_vocab_size): +def zhen_wordpiece_token_bigenerator(data_dir, tmp_dir, train, + source_vocab_size, target_vocab_size): """Wordpiece generator for the WMT'17 zh-en dataset.""" datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS - source_datasets = [[item[0], [item[1][0]]] for item in 
datasets] - target_datasets = [[item[0], [item[1][1]]] for item in datasets] + source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS] + target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS] source_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.zh.%d" % source_vocab_size, source_vocab_size, - source_datasets) + data_dir, tmp_dir, "vocab.zh.%d" % source_vocab_size, + source_vocab_size, source_datasets) target_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.en.%d" % target_vocab_size, target_vocab_size, - target_datasets) + data_dir, tmp_dir, "vocab.en.%d" % target_vocab_size, + target_vocab_size, target_datasets) tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag) return bi_vocabs_token_generator(data_path + ".lang1", data_path + ".lang2", source_vocab, target_vocab, EOS) -def zhen_wordpiece_token_generator(tmp_dir, train, vocab_size): - return zhen_wordpiece_token_bigenerator(tmp_dir, train, +def zhen_wordpiece_token_generator(data_dir, tmp_dir, train, vocab_size): + return zhen_wordpiece_token_bigenerator(data_dir, tmp_dir, train, vocab_size, vocab_size) @@ -517,9 +522,9 @@ def target_space_id(self): def feature_encoders(self, data_dir): vocab_size = self.targeted_vocab_size source_vocab_filename = os.path.join(data_dir, - "tokens.vocab.zh.%d" % vocab_size) + "vocab.zh.%d" % vocab_size) target_vocab_filename = os.path.join(data_dir, - "tokens.vocab.en.%d" % vocab_size) + "vocab.en.%d" % vocab_size) source_token = text_encoder.SubwordTextEncoder(source_vocab_filename) target_token = text_encoder.SubwordTextEncoder(target_vocab_filename) return { @@ -536,10 +541,10 @@ def targeted_vocab_size(self): return 2**15 # 32768 -def enfr_wordpiece_token_generator(tmp_dir, train, vocab_size): +def enfr_wordpiece_token_generator(data_dir, tmp_dir, train, vocab_size): """Instance of token generator for the WMT en->fr task.""" symbolizer_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size) + data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size) datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag) @@ -607,13 +612,13 @@ def target_space_id(self): return problem.SpaceID.FR_CHR -def mken_wordpiece_token_generator(tmp_dir, train, vocab_size): +def mken_wordpiece_token_generator(data_dir, tmp_dir, train, vocab_size): """Wordpiece generator for the SETimes Mk-En dataset.""" datasets = _MKEN_TRAIN_DATASETS if train else _MKEN_TEST_DATASETS - source_datasets = [[item[0], [item[1][0]]] for item in datasets] - target_datasets = [[item[0], [item[1][1]]] for item in datasets] + source_datasets = [[item[0], [item[1][0]]] for item in _MKEN_TRAIN_DATASETS] + target_datasets = [[item[0], [item[1][1]]] for item in _MKEN_TRAIN_DATASETS] symbolizer_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size, + data_dir, tmp_dir, "vocab.mken.%d" % vocab_size, vocab_size, source_datasets + target_datasets) tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "setimes_mken_tok_%s" % tag) @@ -650,15 +655,15 @@ def parsing_character_generator(tmp_dir, train): return character_generator(text_filepath, tags_filepath, character_vocab, EOS) -def tabbed_parsing_token_generator(tmp_dir, train, prefix, source_vocab_size, - target_vocab_size): +def 
tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix, + source_vocab_size, target_vocab_size): """Generate source and target data from a single file.""" source_vocab = generator_utils.get_or_generate_tabbed_vocab( - tmp_dir, "parsing_train.pairs", 0, - prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size) + data_dir, tmp_dir, "parsing_train.pairs", 0, + prefix + "_source.vocab.%d" % source_vocab_size, source_vocab_size) target_vocab = generator_utils.get_or_generate_tabbed_vocab( - tmp_dir, "parsing_train.pairs", 1, - prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size) + data_dir, tmp_dir, "parsing_train.pairs", 1, + prefix + "_target.vocab.%d" % target_vocab_size, target_vocab_size) filename = "parsing_%s" % ("train" if train else "dev") pair_filepath = os.path.join(tmp_dir, filename + ".pairs") return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) @@ -672,9 +677,9 @@ def tabbed_parsing_character_generator(tmp_dir, train): return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS) -def parsing_token_generator(tmp_dir, train, vocab_size): +def parsing_token_generator(data_dir, tmp_dir, train, vocab_size): symbolizer_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "tokens.vocab.%d" % vocab_size, vocab_size) + data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size) filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev") tree_filepath = os.path.join(tmp_dir, filename) return wsj_parsing.token_generator(tree_filepath, symbolizer_vocab, diff --git a/tensor2tensor/data_generators/wsj_parsing.py b/tensor2tensor/data_generators/wsj_parsing.py index 7734db646..200754e16 100644 --- a/tensor2tensor/data_generators/wsj_parsing.py +++ b/tensor2tensor/data_generators/wsj_parsing.py @@ -86,7 +86,7 @@ def token_generator(tree_path, source_token_vocab, target_token_vocab, tree_line = tree_file.readline() -def parsing_token_generator(tmp_dir, train, source_vocab_size, +def parsing_token_generator(data_dir, tmp_dir, train, source_vocab_size, target_vocab_size): """Generator for parsing as a sequence-to-sequence task that uses tokens. @@ -94,8 +94,9 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size, trees in wsj format. Args: - tmp_dir: path to the file with source sentences. - train: path to the file with target sentences. + data_dir: path to the data directory. + tmp_dir: path to temporary storage directory. + train: whether we're training or not. source_vocab_size: source vocab size. target_vocab_size: target vocab size. @@ -103,10 +104,10 @@ def parsing_token_generator(tmp_dir, train, source_vocab_size, A generator to a dictionary of inputs and outputs. 
""" source_symbolizer_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "wsj_source.tokens.vocab.%d" % source_vocab_size, + data_dir, tmp_dir, "wsj_source.vocab.%d" % source_vocab_size, source_vocab_size) target_symbolizer_vocab = generator_utils.get_or_generate_vocab( - tmp_dir, "wsj_target.tokens.vocab.%d" % target_vocab_size, + data_dir, tmp_dir, "wsj_target.vocab.%d" % target_vocab_size, target_vocab_size) filename = "%s_%s.trees" % (FLAGS.parsing_path, "train" if train else "dev") tree_filepath = os.path.join(tmp_dir, filename) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index b24f7fa50..c693d1ca3 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -324,7 +324,7 @@ def transformer_big_single_gpu(): def transformer_base_single_gpu(): """HParams for transformer base model for single gpu.""" hparams = transformer_base() - hparams.batch_size = 8192 + hparams.batch_size = 2048 hparams.learning_rate_warmup_steps = 16000 hparams.batching_mantissa_bits = 2 return hparams diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py index f7d3010a9..9b0e10fcb 100644 --- a/tensor2tensor/utils/trainer_utils.py +++ b/tensor2tensor/utils/trainer_utils.py @@ -126,10 +126,10 @@ def _save_until_eos(hyp): """Strips everything after the first token, which is normally 1.""" try: - index = list(hyp).index(text_encoder.EOS_TOKEN) + index = list(hyp).index(text_encoder.EOS_ID) return hyp[0:index] except ValueError: - # No EOS_TOKEN: return the array as-is. + # No EOS_ID: return the array as-is. return hyp @@ -745,7 +745,7 @@ def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs, for inputs in sorted_inputs[b * FLAGS.decode_batch_size: (b + 1) * FLAGS.decode_batch_size]: input_ids = vocabulary.encode(inputs) - input_ids.append(text_encoder.EOS_TOKEN) + input_ids.append(text_encoder.EOS_ID) batch_inputs.append(input_ids) if len(input_ids) > batch_length: batch_length = len(input_ids) @@ -838,7 +838,7 @@ def _interactive_input_fn(hparams): if input_type == "text": input_ids = vocabulary.encode(input_string) if has_input: - input_ids.append(text_encoder.EOS_TOKEN) + input_ids.append(text_encoder.EOS_ID) x = [num_samples, decode_length, len(input_ids)] + input_ids assert len(x) < const_array_size x += [0] * (const_array_size - len(x)) From 60dd5e0c333b4631db392745ddcdab23b95f4da0 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 19 Jul 2017 19:35:47 -0700 Subject: [PATCH 5/7] Add tests for genetics problems PiperOrigin-RevId: 162569505 --- .../data_generators/genetics_test.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 tensor2tensor/data_generators/genetics_test.py diff --git a/tensor2tensor/data_generators/genetics_test.py b/tensor2tensor/data_generators/genetics_test.py new file mode 100644 index 000000000..70b4fe495 --- /dev/null +++ b/tensor2tensor/data_generators/genetics_test.py @@ -0,0 +1,65 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Genetics problems.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +import numpy as np + +from tensor2tensor.data_generators import genetics + +import tensorflow as tf + + +class GeneticsTest(tf.test.TestCase): + + def _oneHotBases(self, bases): + one_hots = [] + for base_id in bases: + one_hot = [False] * 4 + if base_id < 4: + one_hot[base_id] = True + one_hots.append(one_hot) + return np.array(one_hots) + + def testRecordToExample(self): + inputs = self._oneHotBases([0, 1, 3, 4, 1, 0]) + mask = np.array([True, False, True]) + outputs = np.array([[1.0, 2.0, 3.0], [5.0, 1.0, 0.2], [5.1, 2.3, 2.3]]) + ex_dict = genetics.to_example_dict(inputs, mask, outputs) + + self.assertAllEqual([2, 3, 5, 6, 3, 2, 1], ex_dict["inputs"]) + self.assertAllEqual([1.0, 0.0, 1.0], ex_dict["targets_mask"]) + self.assertAllEqual([1.0, 2.0, 3.0, 5.0, 1.0, 0.2, 5.1, 2.3, 2.3], + ex_dict["targets"]) + self.assertAllEqual([3, 3], ex_dict["targets_shape"]) + + def testGenerateShardArgs(self): + num_examples = 37 + num_shards = 4 + outfiles = [str(i) for i in range(num_shards)] + shard_args = genetics.generate_shard_args(outfiles, num_examples) + + starts, ends, fnames = zip(*shard_args) + self.assertAllEqual([0, 9, 18, 27], starts) + self.assertAllEqual([9, 18, 27, 37], ends) + self.assertAllEqual(fnames, outfiles) + + +if __name__ == "__main__": + tf.test.main() From a7339cdf81d1dc134a6116e2ca1413731eb5eddd Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 19 Jul 2017 19:36:10 -0700 Subject: [PATCH 6/7] v1.1.1 PiperOrigin-RevId: 162569525 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d8fd19cf4..9da5293b9 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.1.0', + version='1.1.1', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', From 2fd79ec8b708101956b03890ac8d760b309e2683 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Wed, 19 Jul 2017 20:01:18 -0700 Subject: [PATCH 7/7] Update readme and make genetics module optional PiperOrigin-RevId: 162570620 --- README.md | 6 ++++-- tensor2tensor/data_generators/all_problems.py | 10 +++++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 0564a9c99..c0e34e0fe 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,10 @@ send along a pull request to add your dataset or model. See [our contribution doc](CONTRIBUTING.md) for details and our [open issues](https://github.com/tensorflow/tensor2tensor/issues). -And chat with us and other users on -[Gitter](https://gitter.im/tensor2tensor/Lobby). +You can chat with us and other users on +[Gitter](https://gitter.im/tensor2tensor/Lobby) and please join our +[Google Group](https://groups.google.com/forum/#!forum/tensor2tensor) to keep up +with T2T announcements. 
Here is a one-command version that installs tensor2tensor, downloads the data, trains an English-German translation model, and lets you use it interactively: diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 0a2503bd2..93a8a06a2 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -21,7 +21,6 @@ from tensor2tensor.data_generators import algorithmic from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import audio -from tensor2tensor.data_generators import genetics from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import ptb @@ -29,4 +28,13 @@ from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing + +# Problem modules that require optional dependencies +# pylint: disable=g-import-not-at-top +try: + # Requires h5py + from tensor2tensor.data_generators import genetics +except ImportError: + pass +# pylint: enable=g-import-not-at-top # pylint: enable=unused-import
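
As a quick sanity check on the genetics encoding introduced in patches 3 and 5, here is a minimal standalone sketch (not part of the patch series) of what `to_example_dict` does to the input bases: each one-hot row over ACTG maps to its base id, all-False rows map to `UNK_ID`, ids are shifted by the reserved tokens, and an EOS id terminates the sequence. The reserved-token count of 2 and `EOS_ID = 1` are assumptions inferred from the expected values in `testRecordToExample`; the helper `encode_bases` below is illustrative and does not exist in the codebase.

```
import numpy as np

NUM_RESERVED_TOKENS = 2  # assumed: matches the offset implied by the test (base 0 -> id 2)
UNK_ID = 4               # id following the four bases A, C, T, G
EOS_ID = 1               # assumed EOS id, per the trailing 1 in the test's expected inputs

def encode_bases(one_hot):
    """Map bool one-hot rows over ACTG to token ids; all-False rows become UNK."""
    ids = []
    for row in one_hot:
        hits = np.flatnonzero(row)
        base_id = int(hits[0]) if len(hits) else UNK_ID
        ids.append(base_id + NUM_RESERVED_TOKENS)
    return ids + [EOS_ID]

# Same bases as testRecordToExample: [0, 1, 3, UNK, 1, 0]
bases = [0, 1, 3, 4, 1, 0]
one_hot = np.zeros((len(bases), 4), dtype=bool)
for i, b in enumerate(bases):
    if b < 4:
        one_hot[i, b] = True

print(encode_bases(one_hot))  # -> [2, 3, 5, 6, 3, 2, 1]
```

The printed ids match the `inputs` expected by the test, which is what pins down the assumed offsets above.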