diff --git a/.travis.yml b/.travis.yml index 370682401..744006762 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,6 +24,6 @@ script: - mkdir $T2T_TRAIN_DIR - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR - - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10' + - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10,use_last_position_only=True' git: depth: 3 diff --git a/README.md b/README.md index 0e97770ba..9525e9bcb 100644 --- a/README.md +++ b/README.md @@ -286,7 +286,7 @@ registrations. To add a new dataset, subclass [`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py) and register it with `@registry.register_problem`. See -[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py) for an example. Also see the [data generators diff --git a/docs/new_problem.md b/docs/new_problem.md index ab5dd5e26..48976a61b 100644 --- a/docs/new_problem.md +++ b/docs/new_problem.md @@ -105,7 +105,7 @@ We're almost done. `generator` generates the training and evaluation data and stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully several commonly used methods like `character_generator`, and `token_generator` are already written in the file -[`wmt.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py). +[`translate.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate.py). We will import `character_generator` and [`text_encoder`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/text_encoder.py) to write: diff --git a/docs/walkthrough.md b/docs/walkthrough.md index 0e97770ba..9525e9bcb 100644 --- a/docs/walkthrough.md +++ b/docs/walkthrough.md @@ -286,7 +286,7 @@ registrations. To add a new dataset, subclass [`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py) and register it with `@registry.register_problem`. See -[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py) for an example. 
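For orientation, the registration pattern that `TranslateEndeWmt8k` follows can be sketched as below. This is an illustrative subclass only; the class name, vocab name, and empty dataset list are hypothetical placeholders, not part of this change:

```python
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators import translate
from tensor2tensor.utils import registry

EOS = text_encoder.EOS_ID

# Hypothetical [url, (source_file, target_file)] entries, in the same layout
# as _ENDE_TRAIN_DATASETS in translate_ende.py.
_MYPAIR_TRAIN_DATASETS = []


@registry.register_problem
class TranslateMypair8k(translate.TranslateProblem):
  """Example problem; registered under the snake_cased name translate_mypair8k."""

  @property
  def targeted_vocab_size(self):
    return 2**13  # 8192 subwords.

  @property
  def vocab_name(self):
    return "vocab.mypair"

  @property
  def input_space_id(self):
    return problem.SpaceID.EN_TOK

  @property
  def target_space_id(self):
    return problem.SpaceID.GENERIC

  def generator(self, data_dir, tmp_dir, train):
    # Build or load a shared subword vocabulary, concatenate the raw corpora,
    # then yield {"inputs": [...], "targets": [...]} dicts of token ids.
    symbolizer_vocab = generator_utils.get_or_generate_vocab(
        data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
        _MYPAIR_TRAIN_DATASETS)
    tag = "train" if train else "dev"
    data_path = translate.compile_data(tmp_dir, _MYPAIR_TRAIN_DATASETS,
                                       "mypair_tok_%s" % tag)
    return translate.token_generator(data_path + ".lang1", data_path + ".lang2",
                                     symbolizer_vocab, EOS)
```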
Also see the [data generators diff --git a/setup.py b/setup.py index 5b6f4690e..88ed4a4ea 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.2.5', + version='1.2.6', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen old mode 100755 new mode 100644 index b3016c994..eba408074 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -82,9 +82,9 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000), lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)), "parsing_english_ptb8k": ( - lambda: wmt.parsing_token_generator( + lambda: translate.parsing_token_generator( FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13), - lambda: wmt.parsing_token_generator( + lambda: translate.parsing_token_generator( FLAGS.data_dir, FLAGS.tmp_dir, False, 2**13)), "parsing_english_ptb16k": ( lambda: wsj_parsing.parsing_token_generator( diff --git a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder old mode 100755 new mode 100644 index ff143f5d4..c2bf97f94 --- a/tensor2tensor/bin/t2t-decoder +++ b/tensor2tensor/bin/t2t-decoder @@ -84,6 +84,7 @@ def main(_): decode_hp = decoding.decode_hparams(FLAGS.decode_hparams) decode_hp.add_hparam("shards", FLAGS.decode_shards) + decode_hp.add_hparam("shard_id", FLAGS.worker_id) if FLAGS.decode_interactive: decoding.decode_interactively(estimator, decode_hp) elif FLAGS.decode_from_file: diff --git a/tensor2tensor/bin/t2t-make-tf-configs b/tensor2tensor/bin/t2t-make-tf-configs old mode 100755 new mode 100644 diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer old mode 100755 new mode 100644 diff --git a/tensor2tensor/data_generators/README.md b/tensor2tensor/data_generators/README.md index 0e6d64dd2..04a90a778 100644 --- a/tensor2tensor/data_generators/README.md +++ b/tensor2tensor/data_generators/README.md @@ -23,7 +23,7 @@ All tasks produce TFRecord files of `tensorflow.Example` protocol buffers. To add a new problem, subclass [`Problem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py) and register it with `@registry.register_problem`. See -[`WMTEnDeTokens8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +[`TranslateEndeWmt8k`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py) for an example. `Problem`s support data generation, training, and decoding. @@ -37,7 +37,7 @@ for training/decoding, e.g. a vocabulary file. A particularly easy way to implement `Problem.generate_data` for your dataset is to create 2 Python generators, one for the training data and another for the dev data, and pass them to `generator_utils.generate_dataset_and_shuffle`. See -[`WMTEnDeTokens8k.generate_data`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +[`TranslateEndeWmt8k.generate_data`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py) for an example of usage. 
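Concretely, the two-generator pattern looks like the sketch below, modeled on the `generate_data` added for MultiNLI in this change; the `_train_generator` and `_dev_generator` helpers are hypothetical stand-ins for your own generators:

```python
from tensor2tensor.data_generators import generator_utils


def generate_data(self, data_dir, tmp_dir, task_id=-1):
  # Unshuffled shard paths for the train and dev splits.
  train_paths = self.training_filepaths(data_dir, 10, shuffled=False)
  dev_paths = self.dev_filepaths(data_dir, 1, shuffled=False)
  # Each generator yields dicts like {"inputs": [3, 7, 1], "targets": [12, 1]};
  # generate_dataset_and_shuffle writes them to TFRecord files and shuffles.
  generator_utils.generate_dataset_and_shuffle(
      self._train_generator(data_dir, tmp_dir), train_paths,
      self._dev_generator(data_dir, tmp_dir), dev_paths)
```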
The generators should yield dictionaries with string keys and values being lists @@ -66,5 +66,5 @@ Some examples: * [Algorithmic problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic.py) and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/algorithmic_test.py) -* [WMT problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py) +* [WMT En-De problems](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/translate_ende.py) and their [unit tests](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt_test.py) diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index 1a65c628a..c7f364cf1 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -29,16 +29,16 @@ from tensor2tensor.data_generators import image from tensor2tensor.data_generators import imdb from tensor2tensor.data_generators import lm1b +from tensor2tensor.data_generators import multinli from tensor2tensor.data_generators import problem_hparams from tensor2tensor.data_generators import ptb from tensor2tensor.data_generators import snli -from tensor2tensor.data_generators import wiki -from tensor2tensor.data_generators import translate -from tensor2tensor.data_generators import translate_enfr -from tensor2tensor.data_generators import translate_ende from tensor2tensor.data_generators import translate_encs -from tensor2tensor.data_generators import translate_enzh +from tensor2tensor.data_generators import translate_ende +from tensor2tensor.data_generators import translate_enfr from tensor2tensor.data_generators import translate_enmk +from tensor2tensor.data_generators import translate_enzh +from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wsj_parsing diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py index c0f6756a5..239d1af99 100644 --- a/tensor2tensor/data_generators/cnn_dailymail.py +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -19,9 +19,9 @@ from __future__ import division from __future__ import print_function +import hashlib import os import tarfile -import hashlib # Dependency imports @@ -39,6 +39,7 @@ _DAILYMAIL_STORIES_DRIVE_URL = "https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs" + # Note: using See et al. (2017) as reference for data generation # For more info, use the links below @@ -47,13 +48,17 @@ _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt" _TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt" + # End-of-sentence marker. EOS = text_encoder.EOS_ID + # Techniques for data prep from See et al. (2017) -dm_single_close_quote = u'\u2019' # unicode -dm_double_close_quote = u'\u201d' -END_TOKENS = [u'.', u'!', u'?', u'...', u"'", u"`", u'"', dm_single_close_quote, dm_double_close_quote, u")"] # acceptable ways to end a sentence +dm_single_close_quote = u"\u2019" # unicode +dm_double_close_quote = u"\u201d" +# Acceptable ways to end a sentence. 
+END_TOKENS = [u".", u"!", u"?", u"...", u"'", u"`", u"\"", + dm_single_close_quote, dm_double_close_quote, u")"] def _maybe_download_corpora(tmp_dir, is_training): @@ -61,9 +66,11 @@ def _maybe_download_corpora(tmp_dir, is_training): Args: tmp_dir: directory containing dataset. + is_training: whether we're in training mode or not. Returns: - list of all files generated and path to file containing train/dev/test split info. + List of all files generated and path to file containing + train/dev/test split info. """ cnn_filename = "cnn_stories.tgz" cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/") @@ -85,43 +92,52 @@ def _maybe_download_corpora(tmp_dir, is_training): all_files = cnn_files + dailymail_files if is_training: - urls_path = generator_utils.maybe_download(tmp_dir, "all_train.txt", _TRAIN_URLS) + urls_path = generator_utils.maybe_download( + tmp_dir, "all_train.txt", _TRAIN_URLS) else: - urls_path = generator_utils.maybe_download(tmp_dir, "all_val.txt", _DEV_URLS) + urls_path = generator_utils.maybe_download( + tmp_dir, "all_val.txt", _DEV_URLS) return all_files, urls_path + def example_splits(url_file, all_files): + """Generate splits of the data.""" def generate_hash(inp): - """Generate a sha1 hash to match the raw url to the filename extracted""" - h = hashlib.sha1() - h.update(inp) - return h.hexdigest() + """Generate a sha1 hash to match the raw url to the filename extracted.""" + h = hashlib.sha1() + h.update(inp) + return h.hexdigest() - all_files_map = {f.split("/")[-1]:f for f in all_files} + all_files_map = {f.split("/")[-1]: f for f in all_files} urls = [] for line in tf.gfile.Open(url_file): - urls.append(line.strip().encode('utf-8')) + urls.append(line.strip().encode("utf-8")) filelist = [] for url in urls: - url_hash = generate_hash(url) - filename = url_hash + ".story" - if filename not in all_files_map: - tf.logging.info("Missing file: %s" % url) - continue - filelist.append(all_files_map[filename]) + url_hash = generate_hash(url) + filename = url_hash + ".story" + if filename not in all_files_map: + tf.logging.info("Missing file: %s" % url) + continue + filelist.append(all_files_map[filename]) tf.logging.info("Found %d examples" % len(filelist)) return filelist + def example_generator(tmp_dir, is_training, sum_token): + """Generate examples.""" def fix_run_on_sents(line): - if u"@highlight" in line: return line - if line=="": return line - if line[-1] in END_TOKENS: return line + if u"@highlight" in line: + return line + if not line: + return line + if line[-1] in END_TOKENS: + return line return line + u"." all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training) @@ -133,28 +149,33 @@ def fix_run_on_sents(line): summary = [] reading_highlights = False for line in tf.gfile.Open(story_file, "rb"): - line = unicode(line.strip(), "utf-8") if six.PY2 else line.strip().decode("utf-8") + if six.PY2: + line = unicode(line.strip(), "utf-8") + else: + line = line.strip().decode("utf-8") line = fix_run_on_sents(line) - if line == "": - continue + if not line: + continue elif line.startswith(u"@highlight"): - if len(story) == 0: break # No article text - reading_highlights = True + if not story: + break # No article text. 
+        reading_highlights = True
     elif reading_highlights:
-      summary.append(line)
+        summary.append(line)
     else:
-      story.append(line)
+        story.append(line)

-    if len(story) == 0 or len(summary) == 0:
-      continue
+    if (not story) or not summary:
+      continue

     yield " ".join(story) + story_summary_split_token + " ".join(summary)


+
 def _story_summary_split(story):
   split_str = u" <summary> "
   split_str_len = len(split_str)
   split_pos = story.find(split_str)
-  return story[:split_pos], story[split_pos+split_str_len:] # story, summary
+  return story[:split_pos], story[split_pos+split_str_len:]  # story, summary


 @registry.register_problem
diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py
index 984694e47..55ccf117e 100644
--- a/tensor2tensor/data_generators/generator_utils.py
+++ b/tensor2tensor/data_generators/generator_utils.py
@@ -263,6 +263,7 @@ def gunzip_file(gz_path, new_path):
       for line in gz_file:
         new_file.write(line)

+
 def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
                                 generator):
   """Inner implementation for vocab generators.
@@ -301,10 +302,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
   return vocab


-def get_or_generate_vocab(data_dir,
-                          tmp_dir,
-                          vocab_filename,
-                          vocab_size,
+def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
                           sources):
   """Generate a vocabulary from the datasets in sources."""

diff --git a/tensor2tensor/data_generators/ice_parsing.py b/tensor2tensor/data_generators/ice_parsing.py
index 99586ef83..fdb53430a 100644
--- a/tensor2tensor/data_generators/ice_parsing.py
+++ b/tensor2tensor/data_generators/ice_parsing.py
@@ -32,7 +32,7 @@
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.data_generators.translate import tabbed_generator
+from tensor2tensor.data_generators import translate
 from tensor2tensor.utils import registry
@@ -51,7 +51,8 @@ def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
       data_dir, tmp_dir, filename, 1,
       prefix + "_target.tokens.vocab.%d" % target_vocab_size,
       target_vocab_size)
   pair_filepath = os.path.join(tmp_dir, filename)
-  return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)
+  return translate.tabbed_generator(pair_filepath, source_vocab, target_vocab,
+                                    EOS)


 def tabbed_parsing_character_generator(tmp_dir, train):
@@ -59,7 +60,8 @@ def tabbed_parsing_character_generator(tmp_dir, train):
   character_vocab = text_encoder.ByteTextEncoder()
   filename = "parsing_{0}.pairs".format("train" if train else "dev")
   pair_filepath = os.path.join(tmp_dir, filename)
-  return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)
+  return translate.tabbed_generator(pair_filepath, character_vocab,
+                                    character_vocab, EOS)


 @registry.register_problem
diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py
index df497019a..e9ae45f01 100644
--- a/tensor2tensor/data_generators/image.py
+++ b/tensor2tensor/data_generators/image.py
@@ -227,7 +227,7 @@ def feature_encoders(self, data_dir):
     # This vocab file must be present within the data directory.
vocab_filename = os.path.join(data_dir, "charset_size134.txt") return { - "inputs": text_encoder.TextEncoder(), + "inputs": text_encoder.ImageEncoder(), "targets": text_encoder.SubwordTextEncoder(vocab_filename) } @@ -273,7 +273,7 @@ def class_labels(self): def feature_encoders(self, data_dir): del data_dir return { - "inputs": text_encoder.TextEncoder(), + "inputs": text_encoder.ImageEncoder(), "targets": text_encoder.ClassLabelEncoder(self.class_labels) } diff --git a/tensor2tensor/data_generators/multinli.py b/tensor2tensor/data_generators/multinli.py new file mode 100644 index 000000000..acd3a2c58 --- /dev/null +++ b/tensor2tensor/data_generators/multinli.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for MultiNLI (https://www.nyu.edu/projects/bowman/multinli/). +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os +import zipfile + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.utils import metrics +from tensor2tensor.utils import registry + +import tensorflow as tf + +EOS = text_encoder.EOS_ID + + +class MultinliProblem(problem.Problem): + """Base class for MultiNLI classification problems.""" + + _ZIP = 'multinli_1.0.zip' + _URL = 'https://www.nyu.edu/projects/bowman/multinli/' + _ZIP + _LABEL_DICT = {'contradiction': 0, + 'entailment': 1, + 'neutral': 2} + _LABELS = {'contradiction', 'entailment', 'neutral'} + + @property + def num_shards(self): + return 10 + + @property + def vocab_file(self): + if self._matched: + return 'multinli_matched.vocab' + else: + return 'multinli_mismatched.vocab' + + @property + def targeted_vocab_size(self): + return 2**14 + + @property + def _matched(self): + raise NotImplementedError() + + @property + def _train_file(self): + return 'multinli_1.0/multinli_1.0_train.jsonl' + + @property + def _dev_file(self): + if self._matched: + return 'multinli_1.0/multinli_1.0_dev_matched.jsonl' + else: + return 'multinli_1.0/multinli_1.0_dev_mismatched.jsonl' + + def _examples(self, data_dir, tmp_dir, train): + file_path = generator_utils.maybe_download(tmp_dir, self._ZIP, self._URL) + zip_ref = zipfile.ZipFile(file_path, 'r') + zip_ref.extractall(tmp_dir) + zip_ref.close() + + data_file = self._train_file if train else self._dev_file + examples = [] + with tf.gfile.GFile(os.path.join(tmp_dir, data_file), mode='r') as f: + for line in f: + record = json.loads(line) + try: + label_str = record['gold_label'].encode('ascii') + if label_str != '-': + label = self._LABEL_DICT[label_str] + sentence1 = record['sentence1'].encode('ascii') + sentence2 = record['sentence2'].encode('ascii') + examples.append({'sentence1': sentence1, + 'sentence2': sentence2, + 'label': label}) + except UnicodeEncodeError: + pass + + return 
examples + + def _inputs_and_targets(self, encoder, examples): + for e in examples: + enc_s1 = encoder.encode(e['sentence1']) + enc_s2 = encoder.encode(e['sentence2']) + + yield { + 'inputs': enc_s1 + [EOS] + enc_s2 + [EOS], + 'targets': [e['label']] + } + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + train_paths = self.training_filepaths( + data_dir, self.num_shards, shuffled=False) + dev_paths = self.dev_filepaths(data_dir, 1, shuffled=False) + + train_examples = self._examples(data_dir, tmp_dir, train=True) + dev_examples = self._examples(data_dir, tmp_dir, train=False) + + encoder = generator_utils.get_or_generate_vocab_inner( + data_dir, self.vocab_file, self.targeted_vocab_size, + (e['sentence1'] + ' ' + e['sentence2'] + for e in train_examples + dev_examples) + ) + + generator_utils.generate_dataset_and_shuffle( + self._inputs_and_targets(encoder, train_examples), train_paths, + self._inputs_and_targets(encoder, dev_examples), dev_paths) + + def hparams(self, defaults, unused_model_hparams): + p = defaults + source_vocab_size = self._encoders['inputs'].vocab_size + p.input_modality = { + 'inputs': (registry.Modalities.SYMBOL, source_vocab_size) + } + p.target_modality = (registry.Modalities.CLASS_LABEL, 3) + p.input_space_id = problem.SpaceID.EN_TOK + p.target_space_id = problem.SpaceID.GENERIC + + def feature_encoders(self, data_dir): + vocab_filename = os.path.join(data_dir, self.vocab_file) + encoder = text_encoder.SubwordTextEncoder(vocab_filename) + return { + 'inputs': encoder, + 'targets': text_encoder.ClassLabelEncoder(self._LABELS), + } + + def example_reading_spec(self): + data_fields = { + 'inputs': tf.VarLenFeature(tf.int64), + 'targets': tf.FixedLenFeature([1], tf.int64), + } + data_items_to_decoders = None + return (data_fields, data_items_to_decoders) + + def eval_metrics(self): + return [metrics.Metrics.ACC] + + +@registry.register_problem +class MultinliMatched(MultinliProblem): + """MultiNLI with matched dev set.""" + + @property + def _matched(self): + return True + + +@registry.register_problem +class MultinliMismatched(MultinliProblem): + """MultiNLI with mismatched dev set.""" + + @property + def _matched(self): + return False diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index e46708859..657a5b18b 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -234,7 +234,7 @@ def test_filepaths(self, data_dir, num_shards, shuffled): return generator_utils.test_data_filenames(file_basename, data_dir, num_shards) - def filepattern(self, data_dir, mode): + def filepattern(self, data_dir, mode, shard=None): """Get filepattern for data files for mode. Matches mode to a suffix. @@ -246,12 +246,13 @@ def filepattern(self, data_dir, mode): Args: data_dir: str, data directory. mode: tf.estimator.ModeKeys or "test". + shard: int, if provided, will only read data from the specified shard. Returns: filepattern str """ path = os.path.join(data_dir, self.dataset_filename()) - + shard_str = "-%05d" % shard if shard is not None else "" if mode == tf.estimator.ModeKeys.TRAIN: suffix = "train" elif mode in [tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT]: @@ -260,7 +261,7 @@ def filepattern(self, data_dir, mode): assert mode == "test" suffix = "test" - return "%s-%s*" % (path, suffix) + return "%s-%s%s*" % (path, suffix, shard_str) def __init__(self, was_reversed=False, was_copy=False): """Create a Problem. 
@@ -328,7 +329,8 @@ def dataset(self,
               shuffle_files=None,
               hparams=None,
               preprocess=True,
-              dataset_split=None):
+              dataset_split=None,
+              shard=None):
     """Build a Dataset for this problem.

     Args:
@@ -347,6 +349,7 @@ def dataset(self,
         Problem.preprocess_example.
       dataset_split: tf.estimator.ModeKeys + ["test"], which split to read data
         from (TRAIN:"-train", EVAL:"-dev", "test":"-test"). Defaults to mode.
+      shard: int, if provided, will only read data from the specified shard.

     Returns:
       Dataset containing dict.
@@ -372,7 +375,7 @@ def dataset(self,
     }

     is_training = mode == tf.estimator.ModeKeys.TRAIN
-    data_filepattern = self.filepattern(data_dir, dataset_split)
+    data_filepattern = self.filepattern(data_dir, dataset_split, shard=shard)
     tf.logging.info("Reading data files from %s", data_filepattern)
     data_files = tf.contrib.slim.parallel_reader.get_data_files(
         data_filepattern)
@@ -530,6 +533,11 @@ def _default_hparams():
       # but decrease if your reader uses a lot of memory and increase if slow.
       max_expected_batch_size_per_shard=64,

+      # During inference for autoregressive problems, if the batch_size is 1,
+      # the inference will stop when the model predicts a text_encoder.EOS_ID
+      # token.
+      stop_at_eos=int(False),
+
       # Modalities used to map from input features to a space compatible with
       # chosen model architecture. One modality spec (which is a 2-tuple,
       # (modality_full_name, vocab_size)) per feature key. modality_full_name
@@ -644,6 +652,7 @@ def feature_encoders(self, data_dir):

   def hparams(self, defaults, unused_model_hparams):
     p = defaults
+    p.stop_at_eos = int(True)

     if self.has_inputs:
       source_vocab_size = self._encoders["inputs"].vocab_size
diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py
index 64eef14fe..1c720a6db 100644
--- a/tensor2tensor/data_generators/text_encoder.py
+++ b/tensor2tensor/data_generators/text_encoder.py
@@ -27,6 +27,7 @@
 import collections
 from itertools import chain
 import re
+import tempfile

 # Dependency imports
@@ -464,9 +465,24 @@ def _tokens_to_subtoken_ids(self, tokens):
     """
     ret = []
     for token in tokens:
-      ret.extend(
-          self._escaped_token_to_subtoken_ids(
-              _escape_token(token, self._alphabet)))
+      ret.extend(self._token_to_subtoken_ids(token))
+    return ret
+
+  def _token_to_subtoken_ids(self, token):
+    """Converts token to a list of subtoken ids.
+
+    Args:
+      token: a string.
+    Returns:
+      a list of integers in the range [0, vocab_size)
+    """
+    cache_location = hash(token) % self._cache_size
+    cache_key, cache_value = self._cache[cache_location]
+    if cache_key == token:
+      return cache_value
+    ret = self._escaped_token_to_subtoken_ids(
+        _escape_token(token, self._alphabet))
+    self._cache[cache_location] = (token, ret)
     return ret

   def _subtoken_ids_to_tokens(self, subtokens):
@@ -480,7 +496,13 @@ def _subtoken_ids_to_tokens(self, subtokens):
     concatenated = "".join(
         [self._subtoken_id_to_subtoken_string(s) for s in subtokens])
     split = concatenated.split("_")
-    return [_unescape_token(t + "_") for t in split if t]
+    ret = []
+    for t in split:
+      if t:
+        unescaped = _unescape_token(t + "_")
+        if unescaped:
+          ret.append(unescaped)
+    return ret

   def _subtoken_id_to_subtoken_string(self, subtoken):
     """Converts a subtoken integer ID to a subtoken string."""
@@ -717,6 +739,9 @@ def _init_subtokens_from_list(self, subtoken_strings, reserved=0):
         s: i + reserved for i, s in enumerate(subtoken_strings) if s
     }
+    # Initialize the cache to empty.
+ self._cache_size = 2 ** 20 + self._cache = [(None, None)] * self._cache_size def _init_alphabet_from_tokens(self, tokens): """Initialize alphabet from an iterable of token or subtoken strings.""" @@ -755,3 +780,72 @@ def store_to_file(self, filename): with tf.gfile.Open(filename, "w") as f: for subtoken_string in self._all_subtoken_strings: f.write("'" + unicode_to_native(subtoken_string) + "'\n") + + +class ImageEncoder(object): + """Encoder class for saving and loading images.""" + + def __init__(self, num_reserved_ids=0, height=32, width=32, channels=3): + assert num_reserved_ids == 0 + self._height = height + self._width = width + self._channels = channels + + @property + def num_reserved_ids(self): + return 0 + + def encode(self, s): + """Transform a string with a filename into a list of RGB integers. + + Args: + s: path to the file with an image. + + Returns: + ids: list of integers + """ + # TODO(lukaszkaiser): implement this. + raise NotImplementedError + + def decode(self, ids): + """Transform a sequence of int ids into an image file. + + Args: + ids: list of integers to be converted. + + Returns: + Path to the temporary file where the image was saved. + + Raises: + ValueError: if the ids are not of the appropriate size. + """ + _, tmp_file_path = tempfile.mkstemp() + length = self._height * self._width * self._channels + if len(ids) != length: + raise ValueError("Length of ids (%d) must be height (%d) x width (%d) x " + "channels (%d); %d != %d.\n Ids: %s" + % (len(ids), self._height, self._width, self._channels, + len(ids), length, " ".join([str(i) for i in ids]))) + with tf.Graph().as_default(): + raw = tf.constant(ids, dtype=tf.uint8) + img = tf.reshape(raw, [self._height, self._width, self._channels]) + png = tf.image.encode_png(img) + op = tf.write_file(tmp_file_path, png) + with tf.Session() as sess: + sess.run(op) + return tmp_file_path + + def decode_list(self, ids): + """Transform a sequence of int ids into an image file. + + Args: + ids: list of integers to be converted. + + Returns: + Singleton list: path to the temporary file where the image was saved. + """ + return [self.decode(ids)] + + @property + def vocab_size(self): + return 256 diff --git a/tensor2tensor/data_generators/translate.py b/tensor2tensor/data_generators/translate.py index 1de25bc47..95f5844c1 100644 --- a/tensor2tensor/data_generators/translate.py +++ b/tensor2tensor/data_generators/translate.py @@ -26,9 +26,6 @@ from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import problem -from tensor2tensor.data_generators import text_encoder -from tensor2tensor.data_generators import wsj_parsing -from tensor2tensor.utils import registry import tensorflow as tf @@ -67,7 +64,6 @@ def character_generator(source_path, target_path, character_vocab, eos=None): target_path: path to the file with target sentences. character_vocab: a TextEncoder to encode the characters. eos: integer to append at the end of each sequence (default: None). - Yields: A dictionary {"inputs": source-line, "targets": target-line} where the lines are integer lists converted from characters in the file lines. @@ -97,7 +93,6 @@ def tabbed_generator(source_path, source_vocab, target_vocab, eos=None): source_vocab: a SubwordTextEncoder to encode the source string. target_vocab: a SubwordTextEncoder to encode the target string. eos: integer to append at the end of each sequence (default: None). 
-
   Yields:
     A dictionary {"inputs": source-line, "targets": target-line} where
     the lines are integer lists converted from characters in the file lines.
@@ -126,7 +121,6 @@ def token_generator(source_path, target_path, token_vocab, eos=None):
     target_path: path to the file with target sentences.
     token_vocab: text_encoder.TextEncoder object.
     eos: integer to append at the end of each sequence (default: None).
-
   Yields:
     A dictionary {"inputs": source-line, "targets": target-line} where
     the lines are integer lists converted from tokens in the file lines.
@@ -160,7 +154,6 @@ def bi_vocabs_token_generator(source_path,
     source_token_vocab: text_encoder.TextEncoder object.
     target_token_vocab: text_encoder.TextEncoder object.
     eos: integer to append at the end of each sequence (default: None).
-
   Yields:
     A dictionary {"inputs": source-line, "targets": target-line} where
     the lines are integer lists converted from tokens in the file lines.
@@ -175,6 +168,7 @@ def bi_vocabs_token_generator(source_path,
       yield {"inputs": source_ints, "targets": target_ints}
       source, target = source_file.readline(), target_file.readline()

+
 def _preprocess_sgm(line, is_sgm):
   """Preprocessing to strip tags in SGM files."""
   if not is_sgm:
@@ -192,7 +186,8 @@ def _preprocess_sgm(line, is_sgm):
     i = line.index(">")
     return line[i + 1:-6]  # Strip first <seg> and last </seg>.

-def _compile_data(tmp_dir, datasets, filename):
+
+def compile_data(tmp_dir, datasets, filename):
   """Concatenate all `datasets` and save to `filename`."""
   filename = os.path.join(tmp_dir, filename)
   with tf.gfile.GFile(filename + ".lang1", mode="w") as lang1_resfile:
@@ -229,8 +224,8 @@ def _compile_data(tmp_dir, datasets, filename):
         lang1_filename, lang2_filename = dataset[1]
         lang1_filepath = os.path.join(tmp_dir, lang1_filename)
         lang2_filepath = os.path.join(tmp_dir, lang2_filename)
-        is_sgm = (lang1_filename.endswith("sgm") and
-                  lang2_filename.endswith("sgm"))
+        is_sgm = (
+            lang1_filename.endswith("sgm") and lang2_filename.endswith("sgm"))

         if not (os.path.exists(lang1_filepath) and
                 os.path.exists(lang2_filepath)):
@@ -258,5 +253,3 @@ def _compile_data(tmp_dir, datasets, filename):
               line1, line2 = lang1_file.readline(), lang2_file.readline()

   return filename
-
-
diff --git a/tensor2tensor/data_generators/translate_encs.py b/tensor2tensor/data_generators/translate_encs.py
index 211d27413..ad0fe828d 100644
--- a/tensor2tensor/data_generators/translate_encs.py
+++ b/tensor2tensor/data_generators/translate_encs.py
@@ -19,16 +19,12 @@
 from __future__ import division
 from __future__ import print_function

-import os
-import tarfile
-
 # Dependency imports

 from tensor2tensor.data_generators import generator_utils
-from tensor2tensor.data_generators import translate
 from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.data_generators import wsj_parsing
+from tensor2tensor.data_generators import translate
 from tensor2tensor.utils import registry

 import tensorflow as tf
@@ -39,11 +35,9 @@ EOS = text_encoder.EOS_ID

 _ENCS_TRAIN_DATASETS = [
-    [
-        ("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/"
-         "11234/1-1458/data-plaintext-format.tar"),
-        ("tsv", 3, 2, "data.plaintext-format/*train.gz")
-    ],
+    [("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/"
+      "11234/1-1458/data-plaintext-format.tar"),
+     ("tsv", 3, 2, "data.plaintext-format/*train.gz")],
     [
         "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz",  # pylint: disable=line-too-long
("training/news-commentary-v12.cs-en.en", @@ -82,20 +76,24 @@ def generator(self, data_dir, tmp_dir, train): datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS tag = "train" if train else "dev" vocab_datasets = [] - data_path = translate._compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag) + data_path = translate.compile_data(tmp_dir, datasets, + "wmt_encs_tok_%s" % tag) # CzEng contains 100 gz files with tab-separated columns, so let's expect # it is the first dataset in datasets and use the newly created *.lang{1,2} # files for vocab construction. if datasets[0][0].endswith("data-plaintext-format.tar"): - vocab_datasets.append([datasets[0][0], ["wmt_encs_tok_%s.lang1" % tag, - "wmt_encs_tok_%s.lang2" % tag]]) + vocab_datasets.append([ + datasets[0][0], + ["wmt_encs_tok_%s.lang1" % tag, + "wmt_encs_tok_%s.lang2" % tag] + ]) datasets = datasets[1:] vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets] symbolizer_vocab = generator_utils.get_or_generate_vocab( data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets) return translate.token_generator(data_path + ".lang1", data_path + ".lang2", - symbolizer_vocab, EOS) + symbolizer_vocab, EOS) @property def input_space_id(self): @@ -118,9 +116,10 @@ def generator(self, data_dir, tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS tag = "train" if train else "dev" - data_path = translate._compile_data(tmp_dir, datasets, "wmt_encs_chr_%s" % tag) - return translate.character_generator(data_path + ".lang1", data_path + ".lang2", - character_vocab, EOS) + data_path = translate.compile_data(tmp_dir, datasets, + "wmt_encs_chr_%s" % tag) + return translate.character_generator( + data_path + ".lang1", data_path + ".lang2", character_vocab, EOS) @property def input_space_id(self): @@ -129,5 +128,3 @@ def input_space_id(self): @property def target_space_id(self): return problem.SpaceID.CS_CHR - - diff --git a/tensor2tensor/data_generators/translate_ende.py b/tensor2tensor/data_generators/translate_ende.py index 01fe77b85..7358e9b7e 100644 --- a/tensor2tensor/data_generators/translate_ende.py +++ b/tensor2tensor/data_generators/translate_ende.py @@ -25,10 +25,9 @@ # Dependency imports from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import translate from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder -from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.data_generators import translate from tensor2tensor.utils import registry import tensorflow as tf @@ -103,8 +102,8 @@ def generator(self, data_dir, tmp_dir, train): with tf.gfile.GFile(token_path, mode="a") as f: f.write("UNK\n") # Add UNK to the vocab. 
token_vocab = text_encoder.TokenTextEncoder(token_path, replace_oov="UNK") - return translate.token_generator(train_path + ".en", train_path + ".de", token_vocab, - EOS) + return translate.token_generator(train_path + ".en", train_path + ".de", + token_vocab, EOS) @property def input_space_id(self): @@ -115,7 +114,6 @@ def target_space_id(self): return problem.SpaceID.DE_BPE_TOK - @registry.register_problem class TranslateEndeWmt8k(translate.TranslateProblem): """Problem spec for WMT En-De translation.""" @@ -130,12 +128,14 @@ def vocab_name(self): def generator(self, data_dir, tmp_dir, train): symbolizer_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, _ENDE_TRAIN_DATASETS) + data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, + _ENDE_TRAIN_DATASETS) datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS tag = "train" if train else "dev" - data_path = translate._compile_data(tmp_dir, datasets, "wmt_ende_tok_%s" % tag) + data_path = translate.compile_data(tmp_dir, datasets, + "wmt_ende_tok_%s" % tag) return translate.token_generator(data_path + ".lang1", data_path + ".lang2", - symbolizer_vocab, EOS) + symbolizer_vocab, EOS) @property def input_space_id(self): @@ -170,9 +170,10 @@ def generator(self, _, tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS tag = "train" if train else "dev" - data_path = translate._compile_data(tmp_dir, datasets, "wmt_ende_chr_%s" % tag) - return translate.character_generator(data_path + ".lang1", data_path + ".lang2", - character_vocab, EOS) + data_path = translate.compile_data(tmp_dir, datasets, + "wmt_ende_chr_%s" % tag) + return translate.character_generator( + data_path + ".lang1", data_path + ".lang2", character_vocab, EOS) @property def input_space_id(self): @@ -181,4 +182,3 @@ def input_space_id(self): @property def target_space_id(self): return problem.SpaceID.DE_CHR - diff --git a/tensor2tensor/data_generators/translate_enfr.py b/tensor2tensor/data_generators/translate_enfr.py index 01e4e8f82..152d3d963 100644 --- a/tensor2tensor/data_generators/translate_enfr.py +++ b/tensor2tensor/data_generators/translate_enfr.py @@ -19,16 +19,12 @@ from __future__ import division from __future__ import print_function -import os -import tarfile - # Dependency imports from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import translate from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder -from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.data_generators import translate from tensor2tensor.utils import registry import tensorflow as tf @@ -41,41 +37,45 @@ _ENFR_TRAIN_DATASETS = [ [ "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz", - ("baseline-1M-enfr/baseline-1M_train.en", "baseline-1M-enfr/baseline-1M_train.fr") + ("baseline-1M-enfr/baseline-1M_train.en", + "baseline-1M-enfr/baseline-1M_train.fr") ], -# [ -# "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz", -# ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr") -# ], -# [ -# "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz", -# ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr") -# ], -# [ -# "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz", -# ("training/news-commentary-v9.fr-en.en", -# "training/news-commentary-v9.fr-en.fr") -# ], -# [ -# 
"http://www.statmt.org/wmt10/training-giga-fren.tar", -# ("giga-fren.release2.fixed.en.gz", "giga-fren.release2.fixed.fr.gz") -# ], -# [ -# "http://www.statmt.org/wmt13/training-parallel-un.tgz", -# ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr") -# ], + # [ + # "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz", + # ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr") + # ], + # [ + # "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz", + # ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr") + # ], + # [ + # "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz", + # ("training/news-commentary-v9.fr-en.en", + # "training/news-commentary-v9.fr-en.fr") + # ], + # [ + # "http://www.statmt.org/wmt10/training-giga-fren.tar", + # ("giga-fren.release2.fixed.en.gz", + # "giga-fren.release2.fixed.fr.gz") + # ], + # [ + # "http://www.statmt.org/wmt13/training-parallel-un.tgz", + # ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr") + # ], ] _ENFR_TEST_DATASETS = [ [ "https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz", - ("baseline-1M-enfr/baseline-1M_valid.en", "baseline-1M-enfr/baseline-1M_valid.fr") + ("baseline-1M-enfr/baseline-1M_valid.en", + "baseline-1M-enfr/baseline-1M_valid.fr") ], -# [ -# "http://data.statmt.org/wmt17/translation-task/dev.tgz", -# ("dev/newstest2013.en", "dev/newstest2013.fr") -# ], + # [ + # "http://data.statmt.org/wmt17/translation-task/dev.tgz", + # ("dev/newstest2013.en", "dev/newstest2013.fr") + # ], ] + @registry.register_problem class TranslateEnfrWmt8k(translate.TranslateProblem): """Problem spec for WMT En-Fr translation.""" @@ -90,12 +90,14 @@ def vocab_name(self): def generator(self, data_dir, tmp_dir, train): symbolizer_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, _ENFR_TRAIN_DATASETS) + data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, + _ENFR_TRAIN_DATASETS) datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS tag = "train" if train else "dev" - data_path = translate._compile_data(tmp_dir, datasets, "wmt_enfr_tok_%s" % tag) + data_path = translate.compile_data(tmp_dir, datasets, + "wmt_enfr_tok_%s" % tag) return translate.token_generator(data_path + ".lang1", data_path + ".lang2", - symbolizer_vocab, EOS) + symbolizer_vocab, EOS) @property def input_space_id(self): @@ -130,9 +132,10 @@ def generator(self, data_dir, tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS tag = "train" if train else "dev" - data_path = translate._compile_data(tmp_dir, datasets, "wmt_enfr_chr_%s" % tag) - return translate.character_generator(data_path + ".lang1", data_path + ".lang2", - character_vocab, EOS) + data_path = translate.compile_data(tmp_dir, datasets, + "wmt_enfr_chr_%s" % tag) + return translate.character_generator( + data_path + ".lang1", data_path + ".lang2", character_vocab, EOS) @property def input_space_id(self): @@ -141,6 +144,3 @@ def input_space_id(self): @property def target_space_id(self): return problem.SpaceID.FR_CHR - - - diff --git a/tensor2tensor/data_generators/translate_enmk.py b/tensor2tensor/data_generators/translate_enmk.py index f6c934121..aa1bac8b1 100644 --- a/tensor2tensor/data_generators/translate_enmk.py +++ b/tensor2tensor/data_generators/translate_enmk.py @@ -19,16 +19,12 @@ from __future__ import division from __future__ import print_function -import os -import tarfile - # Dependency imports from 
tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import translate from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder -from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.data_generators import translate from tensor2tensor.utils import registry import tensorflow as tf @@ -53,6 +49,7 @@ ("dev.mk", "dev.en") ]] + @registry.register_problem class TranslateEnmkSetimes32k(translate.TranslateProblem): """Problem spec for SETimes Mk-En translation.""" @@ -73,12 +70,13 @@ def generator(self, data_dir, tmp_dir, train): data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, source_datasets + target_datasets) tag = "train" if train else "dev" - data_path = translate._compile_data(tmp_dir, datasets, "setimes_mken_tok_%s" % tag) + data_path = translate.compile_data(tmp_dir, datasets, + "setimes_mken_tok_%s" % tag) # We generate English->X data by convention, to train reverse translation # just add the "_rev" suffix to the problem name, e.g., like this. # --problems=translate_enmk_setimes32k_rev return translate.token_generator(data_path + ".lang2", data_path + ".lang1", - symbolizer_vocab, EOS) + symbolizer_vocab, EOS) @property def input_space_id(self): @@ -87,5 +85,3 @@ def input_space_id(self): @property def target_space_id(self): return problem.SpaceID.EN_TOK - - diff --git a/tensor2tensor/data_generators/translate_enzh.py b/tensor2tensor/data_generators/translate_enzh.py index d1b7f7c20..7c77a05fc 100644 --- a/tensor2tensor/data_generators/translate_enzh.py +++ b/tensor2tensor/data_generators/translate_enzh.py @@ -20,15 +20,13 @@ from __future__ import print_function import os -import tarfile # Dependency imports from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import translate from tensor2tensor.data_generators import problem from tensor2tensor.data_generators import text_encoder -from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.data_generators import translate from tensor2tensor.utils import registry import tensorflow as tf @@ -48,6 +46,7 @@ ("dev/newsdev2017-zhen-src.zh.sgm", "dev/newsdev2017-zhen-ref.en.sgm") ]] + @registry.register_problem class TranslateEnzhWmt8k(translate.TranslateProblem): """Problem spec for WMT Zh-En translation.""" @@ -79,12 +78,14 @@ def generator(self, data_dir, tmp_dir, train): data_dir, tmp_dir, self.target_vocab_name, self.targeted_vocab_size, target_datasets) tag = "train" if train else "dev" - data_path = translate._compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag) + data_path = translate.compile_data(tmp_dir, datasets, + "wmt_zhen_tok_%s" % tag) # We generate English->X data by convention, to train reverse translation # just add the "_rev" suffix to the problem name, e.g., like this. 
# --problems=translate_enzh_wmt8k_rev - return translate.bi_vocabs_token_generator(data_path + ".lang2", data_path + ".lang1", - source_vocab, target_vocab, EOS) + return translate.bi_vocabs_token_generator(data_path + ".lang2", + data_path + ".lang1", + source_vocab, target_vocab, EOS) @property def input_space_id(self): @@ -103,5 +104,3 @@ def feature_encoders(self, data_dir): "inputs": source_token, "targets": target_token, } - - diff --git a/tensor2tensor/data_generators/translate_test.py b/tensor2tensor/data_generators/translate_test.py index f082c1a85..e357e11fc 100644 --- a/tensor2tensor/data_generators/translate_test.py +++ b/tensor2tensor/data_generators/translate_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""WMT generators test.""" +"""Translate generators test.""" from __future__ import absolute_import from __future__ import division @@ -32,7 +32,7 @@ import tensorflow as tf -class WMTTest(tf.test.TestCase): +class TranslateTest(tf.test.TestCase): def testCharacterGenerator(self): # Generate a trivial source and target file. @@ -62,24 +62,16 @@ def testCharacterGenerator(self): # First check that the results match the encoded original strings; # this is a comparison of integer arrays. self.assertEqual(len(results_src), 2) - self.assertEqual(results_src[0], - character_vocab.encode("source1")) - self.assertEqual(results_src[1], - character_vocab.encode("source2")) - self.assertEqual(results_tgt[0], - character_vocab.encode("target1")) - self.assertEqual(results_tgt[1], - character_vocab.encode("target2")) + self.assertEqual(results_src[0], character_vocab.encode("source1")) + self.assertEqual(results_src[1], character_vocab.encode("source2")) + self.assertEqual(results_tgt[0], character_vocab.encode("target1")) + self.assertEqual(results_tgt[1], character_vocab.encode("target2")) # Then decode the results and compare with the original strings; # this is a comparison of strings - self.assertEqual(character_vocab.decode(results_src[0]), - "source1") - self.assertEqual(character_vocab.decode(results_src[1]), - "source2") - self.assertEqual(character_vocab.decode(results_tgt[0]), - "target1") - self.assertEqual(character_vocab.decode(results_tgt[1]), - "target2") + self.assertEqual(character_vocab.decode(results_src[0]), "source1") + self.assertEqual(character_vocab.decode(results_src[1]), "source2") + self.assertEqual(character_vocab.decode(results_tgt[0]), "target1") + self.assertEqual(character_vocab.decode(results_tgt[1]), "target2") # Clean up. os.remove(tmp_file_path + ".src") diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 792241632..2178e6fe5 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -466,7 +466,7 @@ def attention_bias_batch( coordinates of the batches batch_coordinates_k (tf.Tensor): int32 of shape [length_k, 1] containing the coordinates of the batches. 
        If None, do self attention (q and k identical)
-    condition_fn (fct): A predicat function defining which type of mask build
+    condition_fn (fct): A function defining which type of mask to build

   Returns:
     tf.Tensor: float32 mask of shape [length_q, length_k] containing either 0 or
@@ -501,7 +501,7 @@ def to_float(bc):
 attention_bias_future = functools.partial(
     attention_bias_batch,
     # Elems can attend to themselves (otherwise would use bias_batch + 1.0)
-    # No tf.abs to concider the order
+    # No tf.abs to consider the order
     # tf.maximum and tf.minimum to threshold the values
     condition_fn=lambda bias: tf.maximum(0.0, tf.minimum(1.0, bias)),
 )
@@ -1059,7 +1059,7 @@ def dot_product_attention_relative(q,

 def masked_local_attention_1d(
     q, k, v, block_length=128, name=None):
-  """Attention to the source position and a neigborhood to the left of it.
+  """Attention to the source position and a neighborhood to the left of it.

   The sequence is divided into blocks of length block_size.
   Attention for a given query position can only see memory positions
@@ -2267,7 +2267,7 @@ def length_not_null(x, batch_coordinate):
     bias_batch = attention_bias_coordinates(batch_coordinate)

     def add_or_set_if(prev_bias, new_bias, condition):
-      """Add the bias together while concidering the None case."""
+      """Add the bias together while considering the None case."""
       if not condition:
         return prev_bias
       elif prev_bias is None:
@@ -2776,7 +2776,7 @@ def get_gates_head(x, add_first=False):
       # Each head gets its own dispatcher
       gates = lsh.get_gates(single_x)
       nb_buckets = gates.get_shape().as_list()[-1]
-      # Reshape to [batch, length, depth] but should concider sequence
+      # Reshape to [batch, length, depth] but should consider sequence
       # padding in that case (also dispatch the padding)
       gates = tf.reshape(gates, [batch_size, length, nb_buckets])
       list_gates.append(gates)
@@ -2958,12 +2958,13 @@ def pad_and_reshape(x):

 @expert_utils.add_var_scope()
 def multihead_self_attention_reduced(
-    x, factor, reduction_type, multihead_params):
+    x, factor, nonlinearity, reduction_type, multihead_params):
   """Reduce the length dimension by compressing with conv.

   Args:
     x (tf.Tensor): float32 of shape [batch, length, depth]
     factor (int): compression factor for the memory sequence
+    nonlinearity (str): Add some non-linearity after the memory block
     reduction_type (str): type of compression
     multihead_params (dict): parameters for multihead attention
@@ -2971,13 +2972,13 @@ def multihead_self_attention_reduced(
     (tf.Tensor): float32 of shape [batch, length, depth]

   Raises:
-    ValueError: If reduction_type invalid
+    ValueError: If reduction_type or nonlinearity is invalid
   """
   depth = x.get_shape().as_list()[-1]
   # Could try to have some overlap between the blocks but that would
   # create conv artifacts, would make it difficult to not attend to the future
-  # withing one group and the padding should be handled specially.
+  # within one group and the padding should be handled specially.
# Reduce the memory dimension if reduction_type == "attention": @@ -2988,6 +2989,11 @@ def multihead_self_attention_reduced( else: raise ValueError("Unknown reduction type {}".format(reduction_type)) + if nonlinearity == "silu": + memory_x *= tf.nn.sigmoid(memory_x) + elif nonlinearity != "none": + raise ValueError("Unknown non linearity {}".format(nonlinearity)) + memory_x = tf.concat( # Add the first elem to make it attendable by everyone (otherwise the # first block cannot attend to anything) diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index 8e76c8051..a29aa93b1 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -85,6 +85,7 @@ def bottom_simple(self, x, name, reuse): return ret def bottom(self, x): + self._bottom_was_called = True if self._model_hparams.shared_embedding_and_softmax_weights: return self.bottom_simple(x, "shared", reuse=None) else: @@ -92,7 +93,11 @@ def bottom(self, x): def targets_bottom(self, x): if self._model_hparams.shared_embedding_and_softmax_weights: - return self.bottom_simple(x, "shared", reuse=True) + try: + return self.bottom_simple(x, "shared", reuse=True) + except ValueError: + # perhaps there were no inputs, and this is a new variable. + return self.bottom_simple(x, "shared", reuse=None) else: return self.bottom_simple(x, "target_emb", reuse=None) @@ -172,7 +177,11 @@ def top(self, body_output, _): dim = body_output.get_shape().as_list()[-1] // 3 out = tf.reshape(body_output, [shape[0], shape[1], shape[2], self._channels, dim]) - return tf.layers.dense(out, self.top_dimensionality) + res = tf.layers.dense(out, self.top_dimensionality) + if not tf.get_variable_scope().reuse: + res_argmax = tf.cast(tf.argmax(res, axis=-1), tf.uint8) + tf.summary.image("result", res_argmax, max_outputs=1) + return res def loss(self, top_out, targets, weights_fn=common_layers.weights_all): # Call the default implementation, but weight 1.0 on 0s by default. 
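The `"silu"` option added to `multihead_self_attention_reduced` above is the Sigmoid Linear Unit, `x * sigmoid(x)` (see https://arxiv.org/abs/1710.05941, cited in the hparams comment below). A minimal standalone sketch of the same transformation, for illustration only:

```python
import numpy as np


def silu(x):
  # silu(x) = x * sigmoid(x), equivalently x / (1 + exp(-x)).
  return x / (1.0 + np.exp(-x))


print(silu(np.array([-2.0, 0.0, 2.0])))  # approx. [-0.238, 0.0, 1.762]
```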
diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py index 1eb988c4c..62ed6c6a5 100644 --- a/tensor2tensor/layers/rev_block.py +++ b/tensor2tensor/layers/rev_block.py @@ -346,14 +346,16 @@ def _recompute_grad(fn, args): """See recompute_grad.""" cached_vs = [] + cached_arg_scope = [] def grad_fn(inputs, variables, outputs, output_grads): """Recompute outputs for gradient computation.""" del outputs # Recompute outputs with tf.control_dependencies(output_grads): - with tf.variable_scope(cached_vs[0], reuse=True): - outputs = fn(*inputs) + with tf.contrib.framework.arg_scope(cached_arg_scope[0]): + with tf.variable_scope(cached_vs[0], reuse=True): + outputs = fn(*inputs) if not (isinstance(outputs, list) or isinstance(outputs, tuple)): outputs = [outputs] @@ -366,6 +368,11 @@ def grad_fn(inputs, variables, outputs, output_grads): @common_layers.fn_with_custom_grad(grad_fn) def fn_with_recompute(*args): cached_vs.append(tf.get_variable_scope()) + # TODO(rsepassi): Rm conditional in TF 1.4 + if hasattr(tf.contrib.framework, "current_arg_scope"): + cached_arg_scope.append(tf.contrib.framework.current_arg_scope()) + else: + cached_arg_scope.append({}) return fn(*args) return fn_with_recompute(*args) diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 85c7c9d49..48720cd5d 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -277,6 +277,7 @@ def print_shape(x, suffix, debug=False): preprocess(x), factor=hparams.attention_red_factor, reduction_type=hparams.attention_reduction_type, + nonlinearity=hparams.attention_nonlinearity, multihead_params=dict( total_key_depth= hparams.attention_key_channels or hparams.hidden_size, @@ -368,7 +369,7 @@ def attention_lm_moe_prepare_decoder(targets, hparams): """ targets_pad_mask = common_attention.embedding_to_padding(targets) with tf.name_scope("pad_remover"): - # Because of the shift_right, the token will be concidered as + # Because of the shift_right, the token will be considered as # padding. In practice, it doesn't really matter, due to the triangular # mask, this token should never be attended. pad_remover = expert_utils.PadRemover(targets_pad_mask) @@ -509,6 +510,9 @@ def attention_lm_moe_base(): hparams.add_hparam("attention_red_factor", 3) hparams.add_hparam("attention_block_length", 128) hparams.add_hparam("attention_reduction_type", "conv") + # Non linearity for the attention reduction. Either "none", or "silu" ( + # Sigmoid Linear-Unit described in https://arxiv.org/abs/1710.05941) + hparams.add_hparam("attention_nonlinearity", "none") # If attention_exp_factor is set, each input to local_expert_attention (of # dimensionality hidden size) is projected into attention_exp_factor smaller # inputs, each of dimensionality attention_exp_inputdim. 
@@ -599,6 +603,20 @@ def attention_lm_16k():
   return hparams
 
 
+@registry.register_hparams
+def attention_lm_12k():
+  hparams = attention_lm_hybrid_v2()
+  hparams.batch_size = 12000
+  return hparams
+
+
+@registry.register_hparams
+def attention_lm_11k():
+  hparams = attention_lm_hybrid_v2()
+  hparams.batch_size = 11500
+  return hparams
+
+
 @registry.register_hparams
 def attention_lm_ae_extended():
   """Experiment with the exp_factor params."""
diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index baa85829c..9a090e40f 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -30,12 +30,15 @@
 from tensor2tensor.layers import common_attention
 from tensor2tensor.layers import common_hparams
 from tensor2tensor.layers import common_layers
+from tensor2tensor.utils import beam_search
 from tensor2tensor.utils import expert_utils
 from tensor2tensor.utils import registry
 from tensor2tensor.utils import t2t_model
 
 import tensorflow as tf
 
+from tensorflow.python.util import nest
+
 
 @registry.register_model
 class Transformer(t2t_model.T2TModel):
@@ -159,6 +162,58 @@ def _greedy_infer(
       logits: Not returned
       losses: Not returned
 
+    Raises:
+      ValueError: If last_position_only is False
+      NotImplementedError: If there are multiple data shards.
+    """
+    decoded_ids = self._fast_decode(features, decode_length, last_position_only)
+    return decoded_ids, None, None
+
+  def _beam_decode(self, features, decode_length, beam_size, top_beams,
+                   last_position_only, alpha):
+    """Beam search decoding.
+
+    Args:
+      features: a map of string to `Tensor`
+      decode_length: an integer. How many additional timesteps to decode.
+      beam_size: number of beams.
+      top_beams: an integer. How many of the beams to return.
+      last_position_only: MUST be true for fast decoding!
+      alpha: Float that controls the length penalty. The larger the alpha,
+        the stronger the preference for longer translations.
+
+    Returns:
+      samples: an integer `Tensor`. Top samples from the beam search
+    """
+    return self._fast_decode(
+        features, decode_length, last_position_only, beam_size, top_beams,
+        alpha)
+
+  def _fast_decode(
+      self,
+      features,
+      decode_length,
+      last_position_only=True,
+      beam_size=1,
+      top_beams=1,
+      alpha=1.0):
+    """Fast decoding.
+
+    Implements both greedy and beam search decoding; uses beam search iff
+    beam_size > 1, otherwise beam-search-related arguments are ignored.
+
+    Args:
+      features: a map of string to model features.
+      decode_length: an integer. How many additional timesteps to decode.
+      last_position_only: MUST be true for fast decoding!
+      beam_size: number of beams.
+      top_beams: an integer. How many of the beams to return.
+      alpha: Float that controls the length penalty. The larger the alpha,
+        the stronger the preference for longer translations.
+
+    Returns:
+      samples: an integer `Tensor`. Top samples from the beam search
+
     Raises:
       ValueError: If last_position_only is False
       NotImplementedError: If there are multiple data shards.
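Editor's note: the `symbols_to_logits_fn` that `_fast_decode` builds in the next hunk follows a simple contract: it receives the ids decoded so far, the step index, and a cache of per-layer keys/values, and returns logits for the next position plus the updated cache. A hedged, self-contained sketch of that contract with toy logits (a real model runs its decoder stack here; `VOCAB_SIZE` is hypothetical):

```python
import tensorflow as tf

VOCAB_SIZE = 16  # hypothetical

def symbols_to_logits_fn(ids, i, cache):
  """Toy next-symbol function with the shape contract _fast_decode expects."""
  del i  # The step index is unused in this toy example.
  # Incremental decoding: only the newest id matters; earlier positions are
  # summarized by the cached keys/values, making each step O(1) in length.
  last_ids = ids[:, -1]                      # [batch]
  logits = tf.one_hot(last_ids, VOCAB_SIZE)  # toy: favor repeating last token
  return logits, cache                       # cache is threaded through
```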
@@ -192,6 +247,8 @@ def _greedy_infer(
       with tf.variable_scope("body"):
         encoder_output, encoder_decoder_attention_bias = dp(
             self.encode, inputs, features["target_space_id"], hparams)
+      encoder_output = encoder_output[0]
+      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
 
     if hparams.pos == "timing":
       timing_signal = common_attention.get_timing_signal_1d(
@@ -236,6 +293,7 @@ def preprocess_targets(targets, i):
 
     def symbols_to_logits_fn(ids, i, cache):
       """Go from ids to logits for next symbol."""
+      ids = ids[:, -1:]
       targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
       targets = preprocess_targets(targets, i)
 
@@ -245,8 +303,8 @@ def symbols_to_logits_fn(ids, i, cache):
         body_outputs = dp(
             self.decode,
             targets,
-            encoder_output[0],
-            encoder_decoder_attention_bias[0],
+            cache["encoder_output"],
+            cache["encoder_decoder_attention_bias"],
             bias,
             hparams,
             cache)
@@ -254,13 +312,7 @@ def symbols_to_logits_fn(ids, i, cache):
       with tf.variable_scope(target_modality.name):
         logits = target_modality.top_sharded(body_outputs, None, dp)[0]
 
-      return tf.squeeze(logits, axis=[1, 2, 3])
-
-    def inner_loop(i, next_id, decoded_ids, cache):
-      logits = symbols_to_logits_fn(next_id, i, cache)
-      next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
-      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
-      return i+1, next_id, decoded_ids, cache
+      return tf.squeeze(logits, axis=[1, 2, 3]), cache
 
     key_channels = hparams.attention_key_channels or hparams.hidden_size
     value_channels = hparams.attention_value_channels or hparams.hidden_size
@@ -272,24 +324,53 @@ def inner_loop(i, next_id, decoded_ids, cache):
             "v": tf.zeros([batch_size, 0, value_channels]),
         } for layer in range(num_layers)
     }
-    decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
-    next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
-    _, _, decoded_ids, _ = tf.while_loop(
-        # TODO(llion): Early stopping.
-        lambda i, *_: tf.less(i, decode_length),
-        inner_loop,
-        [tf.constant(0), next_id, decoded_ids, cache],
-        shape_invariants=[
-            tf.TensorShape([]),
-            tf.TensorShape([None, None]),
-            tf.TensorShape([None, None]),
-            {"layer_%d" % layer: {
-                "k": tf.TensorShape([None, None, key_channels]),
-                "v": tf.TensorShape([None, None, value_channels]),
-            } for layer in range(num_layers)}
-        ])
-    return decoded_ids, None, None
+    # Set 2nd dim to None since it's not invariant in the tf.while_loop
+    # Note: Tensor.set_shape() does not work here since it merges shape info.
+    # TODO(llion): Find a more robust solution.
+    # pylint: disable=protected-access
+    for layer in cache:
+      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
+      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
+    # pylint: enable=protected-access
+    cache["encoder_output"] = encoder_output
+    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+
+    if beam_size > 1:  # Beam Search
+      target_modality = (
+          self._hparams.problems[self._problem_idx].target_modality)
+      vocab_size = target_modality.top_dimensionality
+      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
+      decoded_ids, _ = beam_search.beam_search(
+          symbols_to_logits_fn, initial_ids, beam_size, decode_length,
+          vocab_size, alpha, states=cache)
+
+      if top_beams == 1:
+        decoded_ids = decoded_ids[:, 0, 1:]
+      else:
+        decoded_ids = decoded_ids[:, :top_beams, 1:]
+    else:  # Greedy
+      def inner_loop(i, next_id, decoded_ids, cache):
+        logits, cache = symbols_to_logits_fn(next_id, i, cache)
+        next_id = tf.expand_dims(tf.argmax(logits, axis=-1), axis=1)
+        decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
+        return i+1, next_id, decoded_ids, cache
+
+      decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
+      next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
+      _, _, decoded_ids, _ = tf.while_loop(
+          # TODO(llion): Early stopping.
+          lambda i, *_: tf.less(i, decode_length),
+          inner_loop,
+          [tf.constant(0), next_id, decoded_ids, cache],
+          shape_invariants=[
+              tf.TensorShape([]),
+              tf.TensorShape([None, None]),
+              tf.TensorShape([None, None]),
+              nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
+          ])
+
+    return decoded_ids
 
 
 @registry.register_model
@@ -913,13 +994,26 @@ def transformer_parameter_attention_b():
 
 
 @registry.register_hparams
-def transformer_prepend():
-  hparams = transformer_base()
+def transformer_prepend_v2():
+  hparams = transformer_base_v2()
+  hparams.prepend_mode = "prepend_inputs_masked_attention"
+  hparams.max_length = 0
+  return hparams
+
+
+@registry.register_hparams
+def transformer_prepend_v1():
+  hparams = transformer_base_v1()
   hparams.prepend_mode = "prepend_inputs_masked_attention"
   hparams.max_length = 0
   return hparams
 
 
+@registry.register_hparams
+def transformer_prepend():
+  return transformer_prepend_v2()
+
+
 @registry.register_ranged_hparams("transformer_base")
 def transformer_base_range(rhp):
   """Small range of hyperparameters."""
diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py
index e77138eaf..74f563fbb 100644
--- a/tensor2tensor/models/transformer_test.py
+++ b/tensor2tensor/models/transformer_test.py
@@ -112,5 +112,51 @@ def testGreedyVsFast(self):
     self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length))
     self.assertAllClose(greedy_res, fast_res)
 
+  def testBeamVsFast(self):
+    model, features = self.getModel(transformer.transformer_small())
+
+    decode_length = 2
+
+    out_logits, _ = model.model_fn(features)
+    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
+        labels=tf.reshape(features["targets"], [-1]))
+    loss = tf.reduce_mean(loss)
+    apply_grad = tf.train.AdamOptimizer(0.001).minimize(loss)
+
+    with self.test_session():
+      tf.global_variables_initializer().run()
+      for _ in range(100):
+        apply_grad.run()
+
+    model, _ = self.getModel(transformer.transformer_small(),
+                             mode=tf.estimator.ModeKeys.PREDICT)
+
+    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+      beam_result = model._beam_decode_slow(
+          features,
+          decode_length,
+          beam_size=4,
+          top_beams=1,
+          last_position_only=True,
+          alpha=1.0)
+
+      fast_result = model._beam_decode(
+          features,
+          decode_length,
+          beam_size=4,
+          top_beams=1,
+          last_position_only=True,
+          alpha=1.0)
+
+    with self.test_session():
+      beam_res = beam_result.eval()
+      fast_res = fast_result.eval()
+
+    self.assertEqual(fast_res.shape, (BATCH_SIZE, INPUT_LENGTH + decode_length))
+    self.assertAllClose(beam_res, fast_res)
+
 
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensor2tensor/tpu/tpu_trainer_lib.py b/tensor2tensor/tpu/tpu_trainer_lib.py
index c514da2ad..dca9f4de9 100644
--- a/tensor2tensor/tpu/tpu_trainer_lib.py
+++ b/tensor2tensor/tpu/tpu_trainer_lib.py
@@ -15,7 +15,8 @@
 """Library for training on TPU. See tpu_trainer.py.
 
-Currently only supports training and evaluation for text-to-text problems.
+Currently only supports training and evaluation for text-to-text and text
+autoregressive problems.
 """
 
 from __future__ import absolute_import
@@ -91,31 +92,42 @@ def _valid_size(example):
     dataset = dataset.shuffle(100)
     # TODO(rsepassi): In eval mode, should not repeat
     dataset = dataset.repeat(None)
-    dataset = data_reader.padded_batch(dataset,
-                                       batching_scheme["batch_sizes"][0],
+    dataset = data_reader.padded_batch(dataset, batch_size,
                                        batching_scheme["padded_shapes"])
 
     if not is_training:
       dataset = dataset.map(
           lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads)
 
-    dataset.prefetch(1)
+    def shape_def(example):
+      """Set the right shapes for the features."""
+      inputs = example["inputs"]
+      targets = example["targets"]
 
-    train_features = dataset.make_one_shot_iterator().get_next()
+      # Ensure inputs and targets are proper rank.
+      while len(inputs.get_shape()) < 4:
+        inputs = tf.expand_dims(inputs, axis=-1)
+      while len(targets.get_shape()) < 4:
+        targets = tf.expand_dims(targets, axis=-1)
 
-    inputs = train_features["inputs"]
-    targets = train_features["targets"]
+      example["inputs"] = inputs
+      example["targets"] = targets
 
-    # Ensure inputs and targets are proper rank.
-    while len(inputs.get_shape()) != 4:
-      inputs = tf.expand_dims(inputs, axis=-1)
-    while len(targets.get_shape()) != 4:
-      targets = tf.expand_dims(targets, axis=-1)
+      # Ensure batch size is set on all features
+      for _, t in example.iteritems():
+        shape = t.get_shape().as_list()
+        shape[0] = batch_size
+        t.set_shape(t.get_shape().merge_with(shape))
+        # Assert shapes are fully known
+        t.get_shape().assert_is_fully_defined()
 
-    train_features["inputs"] = inputs
-    train_features["targets"] = targets
+      return example
+
+    dataset = dataset.map(shape_def, num_parallel_calls=num_threads)
+    dataset = dataset.prefetch(1)
+    features = dataset.make_one_shot_iterator().get_next()
 
-    return train_features, targets
+    return features, features["targets"]
 
   return input_fn
 
@@ -147,20 +159,26 @@ def model_fn(features, labels, mode, params, config):
   problem_hp = hparams.problems[0]
   orig_features = features
 
-  # Instantiate model and retrieve modalities
+  # Instantiate model and retrieve modalities. Note that autoregressive models
+  # have no input modality.
   model_class = registry.model(model)(hparams, mode, problem_hp)
-  input_modality = problem_hp.input_modality["inputs"]
+  input_modality = problem_hp.input_modality.get("inputs")
   target_modality = problem_hp.target_modality
 
+  # Transform features
+  transformed_features = {}
+  if input_modality is not None:
+    transformed_features["inputs"] = input_modality.bottom(features["inputs"])
+  transformed_features["targets"] = target_modality.targets_bottom(
+      features["targets"])
+  transformed_features["problem_choice"] = tf.constant(0)
+  transformed_features["input_space_id"] = tf.constant(
+      problem_hp.input_space_id)
+  transformed_features["target_space_id"] = tf.constant(
+      problem_hp.target_space_id)
+
   # Model construction
-  features = {
-      "inputs": input_modality.bottom(features["inputs"]),
-      "targets": target_modality.targets_bottom(features["targets"]),
-      "problem_choice": tf.constant(0),
-      "input_space_id": tf.constant(problem_hp.input_space_id),
-      "target_space_id": tf.constant(problem_hp.target_space_id)
-  }
-  outputs = model_class.model_fn_body(features)
+  outputs = model_class.model_fn_body(transformed_features)
   logits = target_modality.top(outputs, labels)
 
   # Ensure the length is known statically
diff --git a/tensor2tensor/utils/beam_search.py b/tensor2tensor/utils/beam_search.py
index 1dd2f87b1..c08416fb8 100644
--- a/tensor2tensor/utils/beam_search.py
+++ b/tensor2tensor/utils/beam_search.py
@@ -30,7 +30,45 @@
 
 INF = 1. * 1e7
 
 
-def expand_to_beam_size(tensor, beam_size):
+def _get_shape(tensor):
+  """Returns static shape if available and dynamic shape otherwise."""
+  static = tensor.shape.as_list()
+  dynamic = tf.unstack(tf.shape(tensor))
+  return [s[1] if s[0] is None else s[0] for s in zip(static, dynamic)]
+
+
+def _merge_beam_dim(tensor):
+  """Reshapes first two dimensions into a single dimension.
+
+  Args:
+    tensor: Tensor to reshape of shape [A, B, ...]
+
+  Returns:
+    Reshaped tensor of shape [A*B, ...]
+  """
+  shape = _get_shape(tensor)
+  shape[0] *= shape[1]  # batch -> batch * beam_size
+  shape.pop(1)  # Remove beam dim
+  return tf.reshape(tensor, shape)
+
+
+def _unmerge_beam_dim(tensor, batch_size, beam_size):
+  """Reshapes first dimension back to [batch_size, beam_size].
+
+  Args:
+    tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
+    batch_size: Tensor, original batch size.
+    beam_size: int, original beam size.
+
+  Returns:
+    Reshaped tensor of shape [batch_size, beam_size, ...]
+  """
+  shape = _get_shape(tensor)
+  new_shape = [batch_size] + [beam_size] + shape[1:]
+  return tf.reshape(tensor, new_shape)
+
+
+def _expand_to_beam_size(tensor, beam_size):
   """Tiles a given tensor by beam_size.
   Args:
@@ -191,11 +229,11 @@
   alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])
 
   # Expand each batch and state to beam_size
-  alive_seq = expand_to_beam_size(initial_ids, beam_size)
+  alive_seq = _expand_to_beam_size(initial_ids, beam_size)
   alive_seq = tf.expand_dims(alive_seq, axis=2)  # (batch_size, beam_size, 1)
   if states:
     states = nest.map_structure(
-        lambda state: expand_to_beam_size(state, beam_size), states)
+        lambda state: _expand_to_beam_size(state, beam_size), states)
   else:
     states = {}
 
@@ -302,12 +340,10 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
     # (batch_size * beam_size, decoded_length)
 
     if states:
-      flat_states = nest.map_structure(
-          lambda state: tf.reshape(state, [batch_size * beam_size, -1]), states)
-      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, flat_states)
+      flat_states = nest.map_structure(_merge_beam_dim, states)
+      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states)
       states = nest.map_structure(
-          lambda state: tf.reshape(state, [batch_size, beam_size, -1]),
-          flat_states)
+          lambda t: _unmerge_beam_dim(t, batch_size, beam_size), flat_states)
     else:
       flat_logits = symbols_to_logits_fn(flat_ids)
     logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])
@@ -478,8 +514,7 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
           finished_scores.get_shape(),
           finished_flags.get_shape(),
           nest.map_structure(
-              lambda tensor: tf.TensorShape([None] * tensor.shape.ndims),
-              states),
+              lambda tensor: tf.TensorShape(tensor.shape), states),
       ],
       parallel_iterations=1,
       back_prop=False)
diff --git a/tensor2tensor/utils/beam_search_test.py b/tensor2tensor/utils/beam_search_test.py
index fc15eb3bc..379411e99 100644
--- a/tensor2tensor/utils/beam_search_test.py
+++ b/tensor2tensor/utils/beam_search_test.py
@@ -289,7 +289,7 @@ def testStates(self):
 
     expected_states = tf.constant([[[0.]], [[1.]]])
 
-    def symbols_to_logits(ids, states):
+    def symbols_to_logits(ids, _, states):
       pos = tf.shape(ids)[1] - 1
       # We have to assert the values of state inline here since we can't fetch
       # them out of the loop!
@@ -303,6 +303,7 @@ def symbols_to_logits(ids, states):
     states = {
         "state": tf.zeros((batch_size, 1)),
     }
+    states["state"]._shape = tf.TensorShape((None, 1))
 
     final_ids, _ = beam_search.beam_search(
         symbols_to_logits,
@@ -336,7 +337,7 @@ def testStateBeamTwo(self):
     # at each position, which is the one that's getting 3 added to it each step.
     expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]])
 
-    def symbols_to_logits(ids, states):
+    def symbols_to_logits(ids, _, states):
       pos = tf.shape(ids)[1] - 1
 
       # We have to assert the values of state inline here since we can't fetch
@@ -351,6 +352,7 @@ def symbols_to_logits(ids, states):
     states = {
         "state": tf.zeros((batch_size, 1)),
     }
+    states["state"]._shape = tf.TensorShape((None, 1))
 
     final_ids, _ = beam_search.beam_search(
         symbols_to_logits,
diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py
index 83f66b985..9ec147e3d 100644
--- a/tensor2tensor/utils/data_reader.py
+++ b/tensor2tensor/utils/data_reader.py
@@ -71,7 +71,8 @@ def input_pipeline(problem,
                    mode,
                    hparams,
                    batching_scheme,
-                   dataset_split=None):
+                   dataset_split=None,
+                   shard=None):
   """Input pipeline, returns a dictionary of batched and padded tensors.
 
   Args:
@@ -88,6 +89,7 @@ def input_pipeline(problem,
       "max_length": an integer.  We drop sequences which are longer.
     dataset_split: tf.estimator.ModeKeys + ["test"], which split of the
       dataset to use. Defaults to mode.
+    shard: int, if provided, will only read data from the specified shard.
 
   Returns:
     dict
@@ -102,7 +104,8 @@ def input_pipeline(problem,
       num_threads=num_threads,
       output_buffer_size=capacity,
       hparams=hparams,
-      dataset_split=dataset_split)
+      dataset_split=dataset_split,
+      shard=shard)
   dataset = dataset.map(cast_int64_to_int32, num_threads=num_threads)
   dataset = dataset.filter(
       functools.partial(
diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py
index 5dac0dd5f..8aa3c0b71 100644
--- a/tensor2tensor/utils/decoding.py
+++ b/tensor2tensor/utils/decoding.py
@@ -69,7 +69,8 @@ def log_decode_results(inputs,
                        model_dir=None,
                        identity_output=False):
   """Log inference results."""
-  if "image" in problem_name and save_images:
+  is_image = "image" in problem_name
+  if is_image and save_images:
     save_path = os.path.join(model_dir, "%s_prediction_%d.jpg" %
                              (problem_name, prediction_idx))
     show_and_save_image(inputs / 255., save_path)
@@ -77,7 +78,7 @@ def log_decode_results(inputs,
     if identity_output:
       decoded_inputs = " ".join(map(str, inputs.flatten()))
     else:
-      decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs.flatten()))
+      decoded_inputs = inputs_vocab.decode(_save_until_eos(inputs, is_image))
 
     tf.logging.info("Inference results INPUT: %s" % decoded_inputs)
 
@@ -87,11 +88,9 @@ def log_decode_results(inputs,
     if targets is not None:
       decoded_targets = " ".join(map(str, targets.flatten()))
   else:
-    decoded_outputs = "".join(
-        map(str, targets_vocab.decode(_save_until_eos(outputs.flatten()))))
+    decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image))
     if targets is not None:
-      decoded_targets = "".join(
-          map(str, targets_vocab.decode(_save_until_eos(targets.flatten()))))
+      decoded_targets = targets_vocab.decode(_save_until_eos(targets, is_image))
 
   tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs)
   if targets is not None:
@@ -107,6 +106,8 @@ def decode_from_dataset(estimator,
   tf.logging.info("Performing local inference from dataset for %s.",
                   str(problem_names))
   hparams = estimator.params
+  # We assume that worker_id corresponds to shard number.
+  shard = decode_hp.shard_id if decode_hp.shards > 1 else None
 
   for problem_idx, problem_name in enumerate(problem_names):
     # Build the inference input function
@@ -117,14 +118,19 @@ def decode_from_dataset(estimator,
         num_datashards=devices.data_parallelism().n,
         fixed_problem=problem_idx,
         batch_size=decode_hp.batch_size,
-        dataset_split=dataset_split)
+        dataset_split=dataset_split,
+        shard=shard)
 
     # Get the predictions as an iterable
     predictions = estimator.predict(infer_input_fn)
 
     # Prepare output file writers if decode_to_file passed
     if decode_to_file:
-      output_filepath = _decode_filename(decode_to_file, problem_name,
+      if decode_hp.shards > 1:
+        decode_filename = decode_to_file + ("%.2d" % decode_hp.shard_id)
+      else:
+        decode_filename = decode_to_file
+      output_filepath = _decode_filename(decode_filename, problem_name,
                                          decode_hp)
       parts = output_filepath.split(".")
       parts[-1] = "targets"
@@ -134,7 +140,11 @@ def decode_from_dataset(estimator,
       target_file = tf.gfile.Open(target_filepath, "w")
 
     problem_hparams = hparams.problems[problem_idx]
-    inputs_vocab = problem_hparams.vocabulary.get("inputs", None)
+    # Inputs vocabulary is set to targets if there are no inputs in the problem,
+    # e.g., for language models where the inputs are just a prefix of targets.
+    has_input = "inputs" in problem_hparams.vocabulary
+    inputs_vocab_key = "inputs" if has_input else "targets"
+    inputs_vocab = problem_hparams.vocabulary[inputs_vocab_key]
     targets_vocab = problem_hparams.vocabulary["targets"]
     for num_predictions, prediction in enumerate(predictions):
       num_predictions += 1
@@ -200,7 +210,11 @@ def decode_from_file(estimator, filename, decode_hp, decode_to_file=None):
   hparams = estimator.params
   problem_id = decode_hp.problem_idx
-  inputs_vocab = hparams.problems[problem_id].vocabulary["inputs"]
+  # Inputs vocabulary is set to targets if there are no inputs in the problem,
+  # e.g., for language models where the inputs are just a prefix of targets.
+  has_input = "inputs" in hparams.problems[problem_id].vocabulary
+  inputs_vocab_key = "inputs" if has_input else "targets"
+  inputs_vocab = hparams.problems[problem_id].vocabulary[inputs_vocab_key]
   targets_vocab = hparams.problems[problem_id].vocabulary["targets"]
   problem_name = FLAGS.problems.split("-")[problem_id]
   tf.logging.info("Performing decoding from a file.")
@@ -246,7 +260,7 @@ def input_fn():
   else:
     output_filename = filename
   if decode_hp.shards > 1:
-    base_filename = output_filename + ("%.2d" % FLAGS.worker_id)
+    base_filename = output_filename + ("%.2d" % decode_hp.shard_id)
   else:
     base_filename = output_filename
   decode_filename = _decode_filename(base_filename, problem_name, decode_hp)
@@ -303,6 +317,7 @@ def input_fn():
   result_iter = estimator.predict(input_fn)
   for result in result_iter:
     problem_idx = result["problem_choice"]
+    is_image = False  # TODO(lukaszkaiser): find out from problem id / class.
     targets_vocab = hparams.problems[problem_idx].vocabulary["targets"]
 
     if decode_hp.return_beams:
@@ -312,7 +327,7 @@ def input_fn():
       scores = np.split(result["scores"], decode_hp.beam_size, axis=0)
       for k, beam in enumerate(beams):
         tf.logging.info("BEAM %d:" % k)
-        beam_string = targets_vocab.decode(_save_until_eos(beam.flatten()))
+        beam_string = targets_vocab.decode(_save_until_eos(beam, is_image))
         if scores is not None:
           tf.logging.info("%s\tScore:%f" % (beam_string, scores[k]))
         else:
@@ -322,7 +337,7 @@ def input_fn():
         tf.logging.info(" ".join(map(str, result["outputs"].flatten())))
       else:
         tf.logging.info(
-            targets_vocab.decode(_save_until_eos(result["outputs"].flatten())))
+            targets_vocab.decode(_save_until_eos(result["outputs"], is_image)))
 
 
 def _decode_batch_input_fn(problem_id, num_decode_batches, sorted_inputs,
@@ -509,8 +524,11 @@ def _get_sorted_inputs(filename, num_shards=1, delimiter="\n"):
   return sorted_inputs, sorted_keys
 
 
-def _save_until_eos(hyp):
+def _save_until_eos(hyp, is_image):
   """Strips everything after the first <EOS> token, which is normally 1."""
+  hyp = hyp.flatten()
+  if is_image:
+    return hyp
   try:
     index = list(hyp).index(text_encoder.EOS_ID)
     return hyp[0:index]
diff --git a/tensor2tensor/utils/devices.py b/tensor2tensor/utils/devices.py
index d532b6d5f..9fa322985 100644
--- a/tensor2tensor/utils/devices.py
+++ b/tensor2tensor/utils/devices.py
@@ -109,8 +109,11 @@ def _replica_device_setter(worker_device):
         ps_tasks=FLAGS.ps_replicas,
         ps_device=FLAGS.ps_job + "/GPU:0" if FLAGS.ps_gpu > 0 else FLAGS.ps_job)
 
-  if FLAGS.schedule == "train_and_evaluate":
+  if FLAGS.schedule in ["train_and_evaluate", "continuous_train_and_eval"]:
     assert not FLAGS.sync
+    tf.logging.warn(
+        "Schedule=%s. Assuming that training is running on a single machine.",
+        FLAGS.schedule)
     datashard_devices = ["gpu:%d" % d for d in _gpu_order(FLAGS.worker_gpu)]
     if FLAGS.locally_shard_to_cpu or FLAGS.worker_gpu < 1:
       datashard_devices += ["cpu:0"]
diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py
index f4a3098ad..fc4a72405 100644
--- a/tensor2tensor/utils/input_fn_builder.py
+++ b/tensor2tensor/utils/input_fn_builder.py
@@ -36,7 +36,8 @@ def build_input_fn(mode,
                    worker_replicas=None,
                    worker_id=None,
                    batch_size=None,
-                   dataset_split=None):
+                   dataset_split=None,
+                   shard=None):
   """Provides input to the graph, either from disk or via a placeholder.
 
   This function produces an input function that will feed data into
@@ -62,6 +63,7 @@ def build_input_fn(mode,
     batch_size: int, if provided, will use a fixed batch size.
     dataset_split: tf.estimator.ModeKeys + ["test"], which split of the
       dataset to use. Defaults to mode.
+    shard: int, if provided, will only read data from the specified shard.
 
   Returns:
     A function that returns a dictionary of features and the target labels.
@@ -99,6 +101,7 @@ def input_fn():
           mode,
           batch_size=batch_size,
           dataset_split=dataset_split,
+          shard=shard,
           name="problem_%d" % problem_idx)
       problem_batches.append(feature_map)
 
@@ -204,6 +207,7 @@ def features_for_problem(problem_instance,
                          mode,
                          batch_size=None,
                          dataset_split=None,
+                         shard=None,
                          name="problem_inputs"):
   """Feature map for Problem."""
   with tf.name_scope(name):
@@ -228,7 +232,8 @@ def features_for_problem(problem_instance,
         mode,
         hparams,
         batching_scheme,
-        dataset_split=dataset_split)
+        dataset_split=dataset_split,
+        shard=shard)
 
   # Ensure inputs and targets are proper rank.
   if problem_instance.has_inputs:
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index c54b38f3f..85f339511 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -26,6 +26,7 @@
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
+from tensor2tensor.data_generators import text_encoder
 from tensor2tensor.layers import common_layers
 from tensor2tensor.utils import beam_search
 from tensor2tensor.utils import expert_utils as eu
@@ -216,6 +217,8 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams,
                    last_position_only, alpha):
     """Beam search decoding.
 
+    Models should ideally implement a more efficient version of this function.
+
     Args:
       features: a map of string to `Tensor`
       decode_length: an integer. How many additional timesteps to decode.
@@ -228,7 +231,27 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams,
     Returns:
       samples: an integer `Tensor`. Top samples from the beam search
     """
+    return self._beam_decode_slow(features, decode_length, beam_size, top_beams,
+                                  last_position_only, alpha)
+
+  def _beam_decode_slow(self, features, decode_length, beam_size, top_beams,
+                        last_position_only, alpha):
+    """Slow version of Beam search decoding.
+
+    Quadratic time in decode_length.
+
+    Args:
+      features: a map of string to `Tensor`
+      decode_length: an integer. How many additional timesteps to decode.
+      beam_size: number of beams.
+      top_beams: an integer. How many of the beams to return.
+      last_position_only: a boolean, speed-up by computing last position only.
+      alpha: Float that controls the length penalty. The larger the alpha,
+        the stronger the preference for longer translations.
+
+    Returns:
+      samples: an integer `Tensor`. Top samples from the beam search
+    """
     batch_size = tf.shape(features["inputs"])[0]
     batch_size = tf.Print(batch_size, [batch_size], "beam_decode batch_size=")
@@ -259,15 +282,16 @@ def symbols_to_logits_fn(ids):
 
     initial_ids = tf.zeros([batch_size], dtype=tf.int32)
 
-    inputs_old = features["inputs"]
-    features["inputs"] = tf.expand_dims(features["inputs"], 1)
-    if len(features["inputs"].shape) < 5:
-      features["inputs"] = tf.expand_dims(features["inputs"], 4)
-    # Expand the inputs in to the beam size.
-    features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
-    s = tf.shape(features["inputs"])
-    features["inputs"] = tf.reshape(features["inputs"],
-                                    [s[0] * s[1], s[2], s[3], s[4]])
+    if self.has_input:
+      inputs_old = features["inputs"]
+      features["inputs"] = tf.expand_dims(features["inputs"], 1)
+      if len(features["inputs"].shape) < 5:
+        features["inputs"] = tf.expand_dims(features["inputs"], 4)
+      # Expand the inputs into the beam size.
+      features["inputs"] = tf.tile(features["inputs"], [1, beam_size, 1, 1, 1])
+      s = tf.shape(features["inputs"])
+      features["inputs"] = tf.reshape(features["inputs"],
+                                      [s[0] * s[1], s[2], s[3], s[4]])
 
     target_modality = self._hparams.problems[self._problem_idx].target_modality
     vocab_size = target_modality.top_dimensionality
@@ -280,7 +304,8 @@ def symbols_to_logits_fn(ids):
         alpha)
 
     # Set inputs back to the unexpanded inputs to not confuse the Estimator!
-    features["inputs"] = inputs_old
+    if self.has_input:
+      features["inputs"] = inputs_old
 
     # Return `top_beams` decodings (also remove initial id from the beam search)
     return_scores = False  # TODO(lukaszkaiser): make it work multi-problem.
@@ -365,8 +390,9 @@ def infer_step(recent_output, recent_logits, unused_loss):
 
     # Create an initial output tensor. This will be passed
     # to the infer_step, which adds one timestep at every iteration.
if "partial_targets" in features: - initial_output = tf.to_int64(tf.expand_dims( - tf.expand_dims(features["partial_targets"], 2), 3)) + initial_output = tf.to_int64(features["partial_targets"]) + while len(initial_output.get_shape().as_list()) < 4: + initial_output = tf.expand_dims(initial_output, 2) batch_size = tf.shape(initial_output)[0] else: batch_size = tf.shape(features["inputs"])[0] @@ -387,8 +413,38 @@ def infer_step(recent_output, recent_logits, unused_loss): logits.set_shape([None, None, None, None, None]) loss = 0.0 + def while_exit_cond(result, logits, loss): # pylint: disable=unused-argument + """Exit the loop either if reach decode_length or EOS.""" + length = tf.shape(result)[1] + + not_overflow = length < decode_length + + if self._problem_hparams.stop_at_eos: + def fn_not_eos(): + return tf.not_equal( # Check if the last predicted element is a EOS + tf.squeeze(result[:, -1, :, :]), + text_encoder.EOS_ID + ) + + not_eos = tf.cond( + # We only check for early stoping if there is at least 1 element ( + # otherwise not_eos will crash) + tf.not_equal(length, 0), + fn_not_eos, + lambda: True, + ) + + return tf.cond( + tf.equal(batch_size, 1), + # If batch_size == 1, we check EOS for early stoping + lambda: tf.logical_and(not_overflow, not_eos), + # Else, just wait for max length + lambda: not_overflow + ) + return not_overflow + result, logits, loss = tf.while_loop( - lambda result, logits, loss: tf.shape(result)[1] < decode_length, + while_exit_cond, infer_step, [result, logits, loss], shape_invariants=[ tf.TensorShape([None, None, None, None]), diff --git a/tensor2tensor/visualization/TransformerVisualization.ipynb b/tensor2tensor/visualization/TransformerVisualization.ipynb index ae3c5809a..ce70bde89 100644 --- a/tensor2tensor/visualization/TransformerVisualization.ipynb +++ b/tensor2tensor/visualization/TransformerVisualization.ipynb @@ -30,7 +30,8 @@ "import numpy as np\n", "\n", "from tensor2tensor.utils import trainer_utils as utils\n", - "from tensor2tensor.visualization import attention" + "from tensor2tensor.visualization import attention\n", + "from tensor2tensor.utils import decoding" ] }, { @@ -84,7 +85,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "/home/llion/t2t_train/wmt_ende_tokens_32k/transformer-transformer_base_single_gpu\n" + "/usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu\n" ] } ], @@ -104,7 +105,9 @@ "FLAGS.problems = PROBLEM\n", "FLAGS.hparams_set = HPARAMS\n", "FLAGS.data_dir = DATA_DIR\n", - "FLAGS.model = MODEL" + "FLAGS.model = MODEL\n", + "\n", + "FLAGS.schedule = 'train_and_evaluate'" ] }, { @@ -120,24 +123,33 @@ "output_type": "stream", "text": [ "INFO:tensorflow:datashard_devices: ['gpu:0']\n", - "INFO:tensorflow:caching_devices: None\n" + "INFO:tensorflow:caching_devices: None\n", + "INFO:tensorflow:batching_scheme = {'min_length': 0, 'window_size': 720, 'shuffle_queue_size': 270, 'boundaries': [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 24, 26, 28, 30, 33, 36, 39, 42, 46, 50, 55, 60, 66, 72, 79, 86, 94, 103, 113, 124, 136, 149, 163, 179, 196, 215, 236], 'max_length': 1000000000, 'batch_sizes': [240, 180, 180, 180, 144, 144, 144, 120, 120, 120, 90, 90, 90, 90, 80, 72, 72, 60, 60, 48, 48, 48, 40, 40, 36, 30, 30, 24, 24, 20, 20, 18, 18, 16, 15, 12, 12, 10, 10, 9, 8, 8]}\n", + "INFO:tensorflow:Updated batching_scheme = {'min_length': 0, 'window_size': 720, 'shuffle_queue_size': 270, 'boundaries': [], 'max_length': 1000000000, 'batch_sizes': [1]}\n", + 
"INFO:tensorflow:Reading data files from /usr/local/google/home/llion/t2t_data/translate_ende_wmt32k-dev*\n" ] } ], "source": [ - "hparams = utils.create_hparams(HPARAMS, PROBLEM, DATA_DIR)\n", + "hparams = utils.create_hparams(FLAGS.hparams_set, FLAGS.data_dir)\n", "\n", "# SET EXTRA HYPER PARAMS HERE!\n", - "# e.g.\n", - "# hparams.batch_size = 1024\n", + "#hparams.null_slot = True\n", + "\n", + "utils.add_problem_hparams(hparams, PROBLEM)\n", "\n", "num_datashards = utils.devices.data_parallelism().n\n", "\n", + "mode = tf.estimator.ModeKeys.EVAL\n", + "\n", "input_fn = utils.input_fn_builder.build_input_fn(\n", - " mode=tf.estimator.ModeKeys.EVAL,\n", - " hparams=hparams,\n", - " data_dir=DATA_DIR,\n", - " num_datashards=num_datashards)\n", + " mode=mode,\n", + " hparams=hparams,\n", + " data_dir=DATA_DIR,\n", + " num_datashards=num_datashards,\n", + " worker_replicas=FLAGS.worker_replicas,\n", + " worker_id=FLAGS.worker_id,\n", + " batch_size=1)\n", "\n", "inputs, target = input_fn()\n", "features = inputs\n", @@ -199,8 +211,15 @@ } ], "source": [ - "spec = utils.model_builder.model_fn(MODEL, features, tf.estimator.ModeKeys.EVAL, hparams, problem_names=[PROBLEM])\n", - "predictions_dict = spec.predictions" + "model_fn=utils.model_builder.build_model_fn(\n", + " MODEL,\n", + " problem_names=[PROBLEM],\n", + " train_steps=FLAGS.train_steps,\n", + " worker_id=FLAGS.worker_id,\n", + " worker_replicas=FLAGS.worker_replicas,\n", + " eval_run_autoregressive=FLAGS.eval_run_autoregressive,\n", + " decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams))\n", + "est_spec = model_fn(features, target, mode, hparams)" ] }, { @@ -224,8 +243,7 @@ ], "source": [ "with tf.variable_scope(tf.get_variable_scope(), reuse=True):\n", - " spec = utils.model_builder.model_fn(MODEL, features, tf.estimator.ModeKeys.PREDICT, hparams, problem_names=[PROBLEM])\n", - " beam_out = spec.predictions['outputs']" + " beam_out = model_fn(features, target, tf.contrib.learn.ModeKeys.INFER, hparams)" ] }, { @@ -246,10 +264,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO:tensorflow:Restoring parameters from /home/llion/t2t_train/wmt_ende_tokens_32k/transformer-transformer_base_single_gpu/model.ckpt-250000\n", + "INFO:tensorflow:Restoring parameters from /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt-1\n", "INFO:tensorflow:Starting standard services.\n", - "INFO:tensorflow:Saving checkpoint to path /home/llion/t2t_train/wmt_ende_tokens_32k/transformer-transformer_base_single_gpu/model.ckpt\n", - "INFO:tensorflow:Starting queue runners.\n" + "INFO:tensorflow:Starting queue runners.\n", + "INFO:tensorflow:Saving checkpoint to path /usr/local/google/home/llion/t2t_train/translate_ende_wmt32k/transformer-transformer_base_single_gpu/model.ckpt\n" ] }, { @@ -337,7 +355,7 @@ } ], "source": [ - "inp, out, logits = sess.run([inputs['inputs'], target, predictions_dict['predictions']])\n", + "inp, out, logits = sess.run([inputs['inputs'], target, est_spec.predictions['predictions']])\n", "\n", "print(\"Input: \", decode(inp[0]))\n", "print(\"Gold: \", decode(out[0]))\n", @@ -381,7 +399,7 @@ ], "source": [ "inp_ids = encode(eng)\n", - "beam_decode = sess.run(beam_out, {\n", + "beam_decode = sess.run(beam_out.predictions['outputs'], {\n", " inputs['inputs']: np.expand_dims(np.expand_dims(inp_ids, axis=2), axis=3),\n", "})\n", "trans = decode(beam_decode[0])\n",