diff --git a/setup.py b/setup.py index 821a88ee2..b70966986 100644 --- a/setup.py +++ b/setup.py @@ -5,14 +5,18 @@ setup( name='tensor2tensor', - version='1.0.12', + version='1.0.13', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', url='http://github.com/tensorflow/tensor2tensor', license='Apache 2.0', packages=find_packages(), - scripts=['tensor2tensor/bin/t2t-trainer', 'tensor2tensor/bin/t2t-datagen'], + scripts=[ + 'tensor2tensor/bin/t2t-trainer', + 'tensor2tensor/bin/t2t-datagen', + 'tensor2tensor/bin/t2t-make-tf-configs', + ], install_requires=[ 'numpy', 'sympy', diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index 4e7e4529a..0367fce94 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -37,8 +37,10 @@ from tensor2tensor.data_generators import algorithmic_math from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import image +from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import ptb from tensor2tensor.data_generators import snli +from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing @@ -138,6 +140,14 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15), lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15) ), + "lm1b_32k": ( + lambda: lm1b.generator(FLAGS.tmp_dir, True), + lambda: lm1b.generator(FLAGS.tmp_dir, False) + ), + "wiki_32k": ( + lambda: wiki.generator(FLAGS.tmp_dir, True), + 1000 + ), "image_mnist_tune": ( lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000), lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)), @@ -335,17 +345,33 @@ def main(_): training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem] - tf.logging.info("Generating training data for %s.", problem) - train_output_files = generator_utils.generate_files( - training_gen(), problem + UNSHUFFLED_SUFFIX + "-train", - FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases) - - tf.logging.info("Generating development data for %s.", problem) - dev_output_files = generator_utils.generate_files( - dev_gen(), problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1) + if isinstance(dev_gen, int): + # The dev set and test sets are generated as extra shards using the + # training generator. The integer specifies the number of training + # shards. FLAGS.num_shards is ignored. + num_training_shards = dev_gen + tf.logging.info("Generating data for %s.", problem) + all_output_files = generator_utils.combined_data_filenames( + problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards) + generator_utils.generate_files( + training_gen(), all_output_files, FLAGS.max_cases) + else: + # usual case - train data and dev data are generated using separate + # generators. 
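For orientation, here is a minimal sketch (hypothetical problem name, paths, and helper names) of the two calling conventions this refactor supports: the usual `(train_gen, dev_gen)` pair of callables, and `(train_gen, num_training_shards)` for problems such as `wiki_32k` whose dev/test shards are carved out of the training generator. The helper signatures follow the `generator_utils.py` changes later in this diff; this is an illustration, not the shipped script.

```python
import os

from tensor2tensor.data_generators import generator_utils

UNSHUFFLED_SUFFIX = "-unshuffled"


def toy_generator():
  # Minimal generator in the dict format generate_files() expects.
  yield {"inputs": [1, 2, 3], "targets": [4, 5, 6]}


def generate_problem(problem, spec, data_dir, num_shards=10, max_cases=None):
  """spec is (train_gen, dev_gen) or (train_gen, num_training_shards)."""
  training_gen, dev_gen = spec
  if isinstance(dev_gen, int):
    # Train, dev and test shards all come from the training generator.
    filenames = generator_utils.combined_data_filenames(
        problem + UNSHUFFLED_SUFFIX, data_dir, dev_gen)
    generator_utils.generate_files(training_gen(), filenames, max_cases)
  else:
    train_files = generator_utils.train_data_filenames(
        problem + UNSHUFFLED_SUFFIX, data_dir, num_shards)
    dev_files = generator_utils.dev_data_filenames(
        problem + UNSHUFFLED_SUFFIX, data_dir, 1)
    generator_utils.generate_files(training_gen(), train_files, max_cases)
    generator_utils.generate_files(dev_gen(), dev_files)


data_dir = "/tmp/t2t_data"  # hypothetical output directory
if not os.path.exists(data_dir):
  os.makedirs(data_dir)
generate_problem("toy_problem", (toy_generator, toy_generator), data_dir)
```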
+ tf.logging.info("Generating training data for %s.", problem) + train_output_files = generator_utils.train_data_filenames( + problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards) + generator_utils.generate_files( + training_gen(), train_output_files, FLAGS.max_cases) + tf.logging.info("Generating development data for %s.", problem) + dev_shards = 10 if "coco" in problem else 1 + dev_output_files = generator_utils.dev_data_filenames( + problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards) + generator_utils.generate_files(dev_gen(), dev_output_files) + all_output_files = train_output_files + dev_output_files tf.logging.info("Shuffling data...") - for fname in train_output_files + dev_output_files: + for fname in all_output_files: records = generator_utils.read_records(fname) random.shuffle(records) out_fname = fname.replace(UNSHUFFLED_SUFFIX, "") diff --git a/tensor2tensor/bin/make_tf_configs.py b/tensor2tensor/bin/t2t-make-tf-configs similarity index 93% rename from tensor2tensor/bin/make_tf_configs.py rename to tensor2tensor/bin/t2t-make-tf-configs index 005f638c0..ae87ffbd8 100644 --- a/tensor2tensor/bin/make_tf_configs.py +++ b/tensor2tensor/bin/t2t-make-tf-configs @@ -1,3 +1,4 @@ +#!/usr/bin/env python # Copyright 2017 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,7 +17,7 @@ Usage: -`make_tf_configs.py --workers="server1:1234" --ps="server3:2134,server4:2334"` +`t2t-make-tf-configs --workers="server1:1234" --ps="server3:2134,server4:2334"` Outputs 1 line per job to stdout, first the workers, then the parameter servers. Each line has the TF_CONFIG, then a tab, then the command line flags for that @@ -74,7 +75,8 @@ def main(_): "task": { "type": task_type, "index": idx - } + }, + "environment": "cloud", }) print("'%s'\t%s" % (tf_config, cmd_line_flags)) diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index 6a3475456..a0dd7c101 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +from collections import defaultdict import gzip import io import os @@ -30,7 +31,7 @@ import six.moves.urllib_request as urllib # Imports urllib on Python2, urllib.request on Python3 from tensor2tensor.data_generators import text_encoder -from tensor2tensor.data_generators.tokenizer import Tokenizer +from tensor2tensor.data_generators import tokenizer import tensorflow as tf @@ -84,10 +85,34 @@ def generate_files_distributed(generator, return output_file +def _data_filenames(output_name, output_dir, num_shards): + return [os.path.join( + output_dir, "%s-%.5d-of-%.5d" % (output_name, shard, num_shards)) + for shard in xrange(num_shards)] + + +def train_data_filenames(problem, output_dir, num_shards): + return _data_filenames( + problem + "-train", output_dir, num_shards) + + +def dev_data_filenames(problem, output_dir, num_shards): + return _data_filenames(problem + "-dev", output_dir, num_shards) + + +def test_data_filenames(problem, output_dir, num_shards): + return _data_filenames(problem + "-test", output_dir, num_shards) + + +def combined_data_filenames(problem, output_dir, num_training_shards): + return ( + train_data_filenames(problem, output_dir, num_training_shards) + + dev_data_filenames(problem, output_dir, 1) + + test_data_filenames(problem, output_dir, 1)) + + def generate_files(generator, - output_name, - output_dir, - num_shards=1, + 
output_filenames, max_cases=None): """Generate cases from a generator and save as TFRecord files. @@ -96,27 +121,16 @@ def generate_files(generator, Args: generator: a generator yielding (string -> int/float/str list) dictionaries. - output_name: the file name prefix under which output will be saved. - output_dir: directory to save the output to. - num_shards: how many shards to use (defaults to 1). + output_filenames: List of output file paths. max_cases: maximum number of cases to get from the generator; if None (default), we use the generator until StopIteration is raised. - - Returns: - List of output file paths. """ - writers = [] - output_files = [] - for shard in xrange(num_shards): - output_filename = "%s-%.5d-of-%.5d" % (output_name, shard, num_shards) - output_file = os.path.join(output_dir, output_filename) - output_files.append(output_file) - writers.append(tf.python_io.TFRecordWriter(output_file)) - + num_shards = len(output_filenames) + writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames] counter, shard = 0, 0 for case in generator: if counter > 0 and counter % 100000 == 0: - tf.logging.info("Generating case %d for %s." % (counter, output_name)) + tf.logging.info("Generating case %d." % counter) counter += 1 if max_cases and counter > max_cases: break @@ -127,8 +141,6 @@ def generate_files(generator, for writer in writers: writer.close() - return output_files - def download_report_hook(count, block_size, total_size): """Report hook for download progress. @@ -235,7 +247,7 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None): sources = sources or _DATA_FILE_URLS tf.logging.info("Generating vocab from: %s", str(sources)) - tokenizer = Tokenizer() + token_counts = defaultdict(int) for source in sources: url = source[0] filename = os.path.basename(url) @@ -269,10 +281,11 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None): break line = line.strip() file_byte_budget -= len(line) - _ = tokenizer.encode(text_encoder.native_to_unicode(line)) + for tok in tokenizer.encode(text_encoder.native_to_unicode(line)): + token_counts[tok] += 1 vocab = text_encoder.SubwordTextEncoder.build_to_target_size( - vocab_size, tokenizer.token_counts, 1, 1e3) + vocab_size, token_counts, 1, 1e3) vocab.store_to_file(vocab_filepath) return vocab diff --git a/tensor2tensor/data_generators/generator_utils_test.py b/tensor2tensor/data_generators/generator_utils_test.py index 726763f7a..320d1a02d 100644 --- a/tensor2tensor/data_generators/generator_utils_test.py +++ b/tensor2tensor/data_generators/generator_utils_test.py @@ -41,11 +41,12 @@ def testGenerateFiles(self): def test_generator(): yield {"inputs": [1], "target": [1]} - generator_utils.generate_files(test_generator(), tmp_file_name, tmp_dir) - self.assertTrue(tf.gfile.Exists(tmp_file_path + "-00000-of-00001")) + filenames = generator_utils.train_data_filenames(tmp_file_name, tmp_dir, 1) + generator_utils.generate_files(test_generator(), filenames) + self.assertTrue(tf.gfile.Exists(tmp_file_path + "-train-00000-of-00001")) # Clean up. - os.remove(tmp_file_path + "-00000-of-00001") + os.remove(tmp_file_path + "-train-00000-of-00001") os.remove(tmp_file_path) def testMaybeDownload(self): diff --git a/tensor2tensor/data_generators/inspect.py b/tensor2tensor/data_generators/inspect.py new file mode 100644 index 000000000..a0da09150 --- /dev/null +++ b/tensor2tensor/data_generators/inspect.py @@ -0,0 +1,81 @@ +# Copyright 2017 Google Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Inspect a TFRecord file of tensorflow.Example and show tokenizations. + +python data_generators/inspect.py \ + --logtostderr \ + --print_targets \ + --subword_text_encoder_filename=$DATA_DIR/tokens.vocab.8192 \ + --input_filename=$DATA_DIR/wmt_ende_tokens_8k-train-00000-of-00100 +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Dependency imports + +from tensor2tensor.data_generators import text_encoder + +import tensorflow as tf + +tf.app.flags.DEFINE_string("subword_text_encoder_filename", "", + "SubwordTextEncoder vocabulary file") +tf.app.flags.DEFINE_string("input_filename", "", "input filename") +tf.app.flags.DEFINE_bool("print_inputs", False, + "Print decoded inputs to stdout") +tf.app.flags.DEFINE_bool("print_targets", False, + "Print decoded targets to stdout") + +FLAGS = tf.app.flags.FLAGS + + +def main(_): + """Convert a file to examples.""" + if FLAGS.subword_text_encoder_filename: + encoder = text_encoder.SubwordTextEncoder( + FLAGS.subword_text_encoder_filename) + else: + encoder = None + reader = tf.python_io.tf_record_iterator(FLAGS.input_filename) + total_sequences = 0 + total_input_tokens = 0 + total_target_tokens = 0 + max_input_length = 0 + max_target_length = 0 + for record in reader: + x = tf.train.Example() + x.ParseFromString(record) + inputs = [int(i) for i in x.features.feature["inputs"].int64_list.value] + targets = [int(i) for i in x.features.feature["targets"].int64_list.value] + if FLAGS.print_inputs: + print(encoder.decode(inputs) if encoder else inputs) + if FLAGS.print_targets: + print(encoder.decode(targets) if encoder else targets) + total_input_tokens += len(inputs) + total_target_tokens += len(targets) + total_sequences += 1 + max_input_length = max(max_input_length, len(inputs)) + max_target_length = max(max_target_length, len(targets)) + + tf.logging.info("total_sequences: %d", total_sequences) + tf.logging.info("total_input_tokens: %d", total_input_tokens) + tf.logging.info("total_target_tokens: %d", total_target_tokens) + tf.logging.info("max_input_length: %d", max_input_length) + tf.logging.info("max_target_length: %d", max_target_length) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensor2tensor/data_generators/lm1b.py b/tensor2tensor/data_generators/lm1b.py new file mode 100644 index 000000000..66a3d52a0 --- /dev/null +++ b/tensor2tensor/data_generators/lm1b.py @@ -0,0 +1,161 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generators for LM1B data-set.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + +import os +import tarfile + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import tokenizer + +import tensorflow as tf + + +# End-of-sentence marker (should correspond to the position of EOS in the +# RESERVED_TOKENS list in text_encoder.py) +EOS = 1 + + +def _original_vocab(tmp_dir): + """Returns a set containing the original vocabulary. + + This is important for comparing with published results. + + Args: + tmp_dir: directory containing dataset. + + Returns: + a set of strings + """ + vocab_url = ("http://download.tensorflow.org/models/LM_LSTM_CNN/" + "vocab-2016-09-10.txt") + vocab_filename = os.path.basename(vocab_url) + vocab_filepath = os.path.join(tmp_dir, vocab_filename) + if not os.path.exists(vocab_filepath): + generator_utils.maybe_download(tmp_dir, vocab_filename, vocab_url) + return set( + [text_encoder.native_to_unicode(l.strip()) for l in + tf.gfile.Open(vocab_filepath)]) + + +def _replace_oov(original_vocab, line): + """Replace out-of-vocab words with "UNK". + + This maintains compatability with published results. + + Args: + original_vocab: a set of strings (The standard vocabulary for the dataset) + line: a unicode string - a space-delimited sequence of words. + + Returns: + a unicode string - a space-delimited sequence of words. + """ + return u" ".join( + [word if word in original_vocab else u"UNK" for word in line.split()]) + + +def _train_data_filenames(tmp_dir): + return [os.path.join( + tmp_dir, + "1-billion-word-language-modeling-benchmark-r13output", + "training-monolingual.tokenized.shuffled", + "news.en-%05d-of-00100" % i) for i in xrange(1, 100)] + + +def _dev_data_filename(tmp_dir): + return os.path.join( + tmp_dir, + "1-billion-word-language-modeling-benchmark-r13output", + "heldout-monolingual.tokenized.shuffled", + "news.en.heldout-00000-of-00050") + + +def _maybe_download_corpus(tmp_dir): + """Download and unpack the corpus. + + Args: + tmp_dir: directory containing dataset. + """ + corpus_url = ("http://www.statmt.org/lm-benchmark/" + "1-billion-word-language-modeling-benchmark-r13output.tar.gz") + corpus_filename = os.path.basename(corpus_url) + corpus_filepath = os.path.join(tmp_dir, corpus_filename) + if not os.path.exists(corpus_filepath): + generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url) + with tarfile.open(corpus_filepath, "r:gz") as corpus_tar: + corpus_tar.extractall(tmp_dir) + + +def _get_or_build_subword_text_encoder(tmp_dir): + """Builds a SubwordTextEncoder based on the corpus. + + Args: + tmp_dir: directory containing dataset. + Returns: + a SubwordTextEncoder. 
+ """ + filepath = os.path.join(tmp_dir, "lm1b_32k.subword_text_encoder") + if tf.gfile.Exists(filepath): + return text_encoder.SubwordTextEncoder(filepath) + _maybe_download_corpus(tmp_dir) + original_vocab = _original_vocab(tmp_dir) + token_counts = defaultdict(int) + line_count = 0 + max_lines = 63000 + for line in tf.gfile.Open(_train_data_filenames(tmp_dir)[0]): + tokens = tokenizer.encode( + _replace_oov(original_vocab, text_encoder.native_to_unicode(line))) + for tok in tokens: + token_counts[tok] += 1 + line_count += 1 + if line_count >= max_lines: + break + ret = text_encoder.SubwordTextEncoder() + ret.build_from_token_counts(token_counts, min_count=5) + ret.store_to_file(filepath) + return ret + + +def generator(tmp_dir, train): + """Generator for lm1b sentences. + + Args: + tmp_dir: a string. + train: a boolean. + + Yields: + A dictionary {"inputs": [0], "targets": []} + """ + _maybe_download_corpus(tmp_dir) + original_vocab = _original_vocab(tmp_dir) + files = (_train_data_filenames(tmp_dir) if train + else [_dev_data_filename(tmp_dir)]) + encoder = _get_or_build_subword_text_encoder(tmp_dir) + for filepath in files: + tf.logging.info("filepath = %s", filepath) + for line in tf.gfile.Open(filepath): + tokens = encoder.encode( + _replace_oov(original_vocab, text_encoder.native_to_unicode(line))) + tokens.append(EOS) + yield {"inputs": [0], "targets": tokens} diff --git a/tensor2tensor/data_generators/lm_example.py b/tensor2tensor/data_generators/lm_example.py deleted file mode 100644 index d8a76baeb..000000000 --- a/tensor2tensor/data_generators/lm_example.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Convert language modeling data to tf.Example format. - -Uses SubwordTextEncoder. - -For each line, we generate a tf.Example, with "targets" equal to a sequence -of subwords (integers), ending in subword id 1 for end-of-sequence. We add -a dummy feature "inputs"=[0] for compatability with seq-to-seq models. - -If FLAGS.combine_to_length is nonzero, then we combine multiple sequences into -examples of a constant length, possibly with some padding at the end. - - -How to preprocess lm1b - billion word benchmark -TODO(noam): should these instructions be made into a script and moved elsewhere? 
- - -# Download data into $DATADIR/ -http://www.statmt.org/lm-benchmark/\ -1-billion-word-language-modeling-benchmark-r13output.tar.gz -http://download.tensorflow.org/models/LM_LSTM_CNN/vocab-2016-09-10.txt - -# unpack data -cd $DATADIR -tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz - -# replace oov words with UNK -$BINARYDIR/replace_oov \ ---vocab_file=$DATADIR/vocab-2016-09-10.txt \ ---in_filepattern=\ -$DATADIR/1-billion-word-language-modeling-benchmark-r13output/\ -heldout-monolingual.tokenized.shuffled/news.en.heldout-00000-of-00050 \ ---out_prefix=$DATADIR/dev-unk \ ---logtostderr - -wc $DATADIR/dev-unk-00000-of-00050 -# -> 6075 153583 826189 -# dev set tokens including EOS = 6075 + 153583 = 159658 - -$BINARYDIR/replace_oov \ ---vocab_file=$DATADIR/vocab-2016-09-10.txt \ ---in_filepattern=\ -$DATADIR/1-billion-word-language-modeling-benchmark-r13output/\ -training-monolingual.tokenized.shuffled/news.en-?????-of-00100 \ ---out_prefix=$DATADIR/train-unk \ ---logtostderr - -# build vocabularies -$BINARYDIR/\ -text_encoder_build_subword \ - --corpus_filepattern=$DATADIR/train-unk-* \ - --corpus_max_lines=17500 \ - --output_fn=$DATADIR/lm1b_16k.subword_text_encoder \ - --logtostderr - -$BINARYDIR/\ -text_encoder_build_subword \ - --corpus_filepattern=$DATADIR/train-unk-* \ - --corpus_max_lines=270000 \ - --output_fn=$DATADIR/lm1b_64k.subword_text_encoder \ - --logtostderr - -# generate training and dev data - -# 16k vocab - -$BINARYDIR/lm_example \ ---logtostderr \ ---vocab_file=$DATADIR/lm1b_16k.subword_text_encoder \ ---in_filepattern=$DATADIR/dev-unk* \ ---out_prefix=$DATADIR/lm1b_16k-dev - -# -> total subwords: 189068 -# perplexity exponent = 189068 / 159658 = 1.184206 - -mv $DATADIR/lm1b_16k-dev-00000-of-00050 $DATADIR/lm1b_16k-dev-00000-of-00001 - -$BINARYDIR/\ -text_encoder_inspect_subword \ ---logtostderr \ ---vocab_file=$DATADIR/lm1b_16k.subword_text_encoder \ ---in_file=$DATADIR/lm1b_16k-dev-00000-of-00001 | more - -$BINARYDIR/lm_example \ ---logtostderr \ ---vocab_file=$DATADIR/lm1b_16k.subword_text_encoder \ ---in_filepattern=$DATADIR/train-unk* \ ---out_prefix=$DATADIR/lm1b_16k-train - -# 64k vocab - -$BINARYDIR/lm_example \ ---logtostderr \ ---vocab_file=$DATADIR/lm1b_64k.subword_text_encoder \ ---in_filepattern=$DATADIR/dev-unk* \ ---out_prefix=$DATADIR/lm1b_64k-dev - -# -> total subwords: 170366 -# perplexity exponent = 170366 / 159658 = 1.067068 - -mv $DATADIR/lm1b_64k-dev-00000-of-00050 $DATADIR/lm1b_64k-dev-00000-of-00001 - -$BINARYDIR/\ -text_encoder_inspect_subword \ ---logtostderr \ ---vocab_file=$DATADIR/lm1b_64k.subword_text_encoder \ ---in_file=$DATADIR/lm1b_64k-dev-00000-of-00001 | more - -$BINARYDIR/lm_example \ ---logtostderr \ ---vocab_file=$DATADIR/lm1b_64k.subword_text_encoder \ ---in_filepattern=$DATADIR/train-unk* \ ---out_prefix=$DATADIR/lm1b_64k-train - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports - -from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import text_encoder - -import tensorflow as tf - -tf.app.flags.DEFINE_string( - "vocab_file", "", "SubwordTextEncoder vocabulary file") - -tf.app.flags.DEFINE_integer( - "combine_to_length", 0, - "If positive, concatenate documents to form examples with length exactly" - " equal to this value. Documents are still suffixed with subword id=1. 
" - " Examples are padded with subword id=0.") - -tf.app.flags.DEFINE_string("in_filepattern", "", "input filename") - -tf.app.flags.DEFINE_string( - "out_prefix", "", "The output filename is equal to out_prefix plus " - "the last 15 characters of in_file. (e.g. -00001-of-00100)") - -FLAGS = tf.app.flags.FLAGS - - -def _make_example(ids, raw_num_bytes): - if FLAGS.combine_to_length > 0: - ids += [0] * (FLAGS.combine_to_length - len(ids)) - return generator_utils.to_example({ - "targets": ids, - "inputs": [0], - "raw_num_bytes": [raw_num_bytes] - }).SerializeToString() - - -def convert_file(in_file, encoder): - """Convert a file to examples.""" - total_bytes = 0 - total_subwords = 0 - total_documents = 0 - dropped_documents = 0 - - combined_subwords = [] - combined_num_bytes = 0 - - out_file = FLAGS.out_prefix + in_file[-15:] - writer = tf.python_io.TFRecordWriter(out_file) - out_file = FLAGS.out_prefix + in_file[-15:] - print ("in_file", in_file, "out_file", out_file) - for line in tf.gfile.Open(in_file): - total_documents += 1 - assert line[-1] == "\n" - num_bytes = len(line) - total_bytes += num_bytes - line = line[:-1] - subwords = encoder.encode(line) + [1] - total_subwords += len(subwords) - if FLAGS.combine_to_length: - if len(combined_subwords) + len(subwords) > FLAGS.combine_to_length: - writer.write(_make_example(combined_subwords, combined_num_bytes)) - combined_subwords = [] - combined_num_bytes = 0 - if len(subwords) <= FLAGS.combine_to_length: - combined_subwords.extend(subwords) - combined_num_bytes += num_bytes - else: - dropped_documents += 1 - else: - writer.write(_make_example(subwords, num_bytes)) - if combined_subwords: - writer.write(_make_example(combined_subwords, combined_num_bytes)) - writer.close() - - tf.logging.info("total bytes: %d", total_bytes) - tf.logging.info("total subwords: %d", total_subwords) - tf.logging.info("bytes per subword: %f", total_bytes / total_subwords) - tf.logging.info("total documents: %d", total_documents) - tf.logging.info("dropped documents: %d", dropped_documents) - - -def main(_): - """Convert a file to examples.""" - encoder = text_encoder.SubwordTextEncoder(FLAGS.vocab_file) - - in_files = tf.gfile.Glob(FLAGS.in_filepattern) - assert in_files, "No matching input files" - for in_file in in_files: - convert_file(in_file, encoder) - - -if __name__ == "__main__": - tf.app.run() diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index 7ad0a57ad..203dba852 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -325,33 +325,35 @@ def audio_wsj_tokens(model_hparams, wrong_vocab_size): return p -def lm1b_16k(model_hparams): - """Billion-word language-modeling benchmark, 16k subtoken vocabulary.""" +def lm1b_32k(model_hparams): + """Billion-word language-modeling benchmark, 32k subword vocabulary.""" p = default_problem_hparams() - p.perplexity_exponent = 1.184206 + # ratio of dev tokens (including eos) to dev words (including eos) + # 176884 / 159658 = 1.107893 + p.perplexity_exponent = 1.107893 p.input_modality = {} - p.target_modality = (registry.Modalities.SYMBOL, 16384) + encoder = text_encoder.SubwordTextEncoder( + os.path.join(model_hparams.data_dir, "lm1b_32k.subword_text_encoder")) + p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size) p.vocabulary = { - "targets": - text_encoder.SubwordTextEncoder( - os.path.join(model_hparams.data_dir, - "lm1b_16k.subword_text_encoder")) + "targets": 
encoder } p.target_space_id = 3 return p -def lm1b_64k(model_hparams): - """Billion-word language-modeling benchmark, 64k subtoken vocabulary.""" +def wiki_32k(model_hparams): + """Wikipedia title to article. 32k subtoken vocabulary.""" p = default_problem_hparams() - p.perplexity_exponent = 1.067068 - p.input_modality = {} - p.target_modality = (registry.Modalities.SYMBOL, 65536) + encoder = text_encoder.SubwordTextEncoder( + os.path.join(model_hparams.data_dir, "wiki_32k.subword_text_encoder")) + p.input_modality = { + "inputs": (registry.Modalities.SYMBOL, encoder.vocab_size) + } + p.target_modality = (registry.Modalities.SYMBOL, encoder.vocab_size) p.vocabulary = { - "targets": - text_encoder.SubwordTextEncoder( - os.path.join(model_hparams.data_dir, - "lm1b_64k.subword_text_encoder")) + "inputs": encoder, + "targets": encoder } p.target_space_id = 3 return p @@ -700,8 +702,8 @@ def img2img_imagenet(unused_model_hparams): "audio_wsj_characters_test": audio_wsj_characters, "audio_wsj_tokens_8k_tune": lambda p: audio_wsj_tokens(p, 2**13), "audio_wsj_tokens_8k_test": lambda p: audio_wsj_tokens(p, 2**13), - "lm1b_16k": lm1b_16k, - "lm1b_64k": lm1b_64k, + "lm1b_32k": lm1b_32k, + "wiki_32k": wiki_32k, "lmptb_10k": lmptb_10k, "wmt_parsing_characters": wmt_parsing_characters, "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13), diff --git a/tensor2tensor/data_generators/replace_oov.py b/tensor2tensor/data_generators/replace_oov.py deleted file mode 100644 index 7e2c8dc50..000000000 --- a/tensor2tensor/data_generators/replace_oov.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Data preprocessor for lm1b benchmark. - -Process the raw text file to replace out-of-vocab words with "". - -The input consists of a tokenized text file, where tokens are separated with -whitespace. - -Outputs a similar text file where the OOV words have been repalced with UNK. -The whitespace in the output may be different. - -This maintains compatibility with the benchmark, which does the same thing. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports - -from six.moves import xrange # pylint: disable=redefined-builtin - -import tensorflow as tf - -tf.app.flags.DEFINE_string("vocab_file", "", - "text file containing one word per line") - -tf.app.flags.DEFINE_string("in_filepattern", "", "input filename") - -tf.app.flags.DEFINE_string( - "out_prefix", "", "The output filename is equal to out_prefix plus " - "the last 15 characters of in_file. (e.g. 
-00001-of-00100)") - -FLAGS = tf.app.flags.FLAGS - - -def replace_oov(vocab, in_file): - """Replace out-of-vocab words with .""" - out_file = FLAGS.out_prefix + in_file[-15:] - print ("in_file", in_file, "out_file", out_file) - with tf.gfile.Open(out_file, "w") as out: - for line in tf.gfile.Open(in_file): - words = line.split() - for i in xrange(len(words)): - if not vocab.get(words[i]): - words[i] = "UNK" - out_line = " ".join(words) + "\n" - out.write(out_line) - - -def main(_): - vocab = {} - with tf.gfile.Open(FLAGS.vocab_file) as vocab_file: - for line in vocab_file: - vocab[line.strip()] = True - - in_files = tf.gfile.Glob(FLAGS.in_filepattern) - assert in_files, "No matching input files" - for in_file in in_files: - replace_oov(vocab, in_file) - -if __name__ == "__main__": - tf.app.run() diff --git a/tensor2tensor/data_generators/snli.py b/tensor2tensor/data_generators/snli.py index 1d21d94ac..1d3acd356 100644 --- a/tensor2tensor/data_generators/snli.py +++ b/tensor2tensor/data_generators/snli.py @@ -25,6 +25,7 @@ from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import tokenizer import tensorflow as tf @@ -139,7 +140,7 @@ def _get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size): return gs example_file = os.path.join(tmp_dir, _EXAMPLES_FILE) gs = text_encoder.SubwordTextEncoder() - token_counts = text_encoder.SubwordTextEncoder.get_token_counts( + token_counts = tokenizer.corpus_token_counts( example_file, corpus_max_lines=1000000) gs = gs.build_to_target_size( vocab_size, token_counts, min_val=1, max_val=1e3) diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index 7934dca34..5d628fa4a 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -28,7 +28,8 @@ # Dependency imports import six -from six import PY2, unichr # pylint: disable=redefined-builtin +from six import PY2 +from six import unichr # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin from tensor2tensor.data_generators import tokenizer @@ -36,7 +37,8 @@ # Conversion between Unicode and UTF-8, if required (on Python2) -native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s) +def native_to_unicode(s): + return s.decode("utf-8") if (PY2 and not isinstance(s, unicode)) else s unicode_to_native = (lambda s: s.encode("utf-8")) if PY2 else (lambda s: s) @@ -203,13 +205,11 @@ class SubwordTextEncoder(TextEncoder): """ - def __init__(self, filename=None, num_reserved_ids=2): + def __init__(self, filename=None): """Initialize and read from a file, if provided.""" - self._tokenizer = tokenizer.Tokenizer() if filename is not None: self._load_from_file(filename) - - super(SubwordTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) + super(SubwordTextEncoder, self).__init__(num_reserved_ids=None) def encode(self, raw_text): """Converts a native string to a list of subtoken ids. 
@@ -219,7 +219,7 @@ def encode(self, raw_text): Returns: a list of integers in the range [0, vocab_size) """ - return self._tokens_to_subtokens(self._tokenizer.encode( + return self._tokens_to_subtokens(tokenizer.encode( native_to_unicode(raw_text))) def decode(self, subtokens): @@ -230,7 +230,7 @@ def decode(self, subtokens): Returns: a native string """ - return unicode_to_native(self._tokenizer.decode( + return unicode_to_native(tokenizer.decode( self._subtokens_to_tokens(subtokens))) @property @@ -260,19 +260,15 @@ def _subtokens_to_tokens(self, subtokens): a list of strings. """ concatenated = "".join( - [self.subtoken_to_subtoken_string(s) for s in subtokens]) + [self._subtoken_to_subtoken_string(s) for s in subtokens]) split = concatenated.split("_") return [self._unescape_token(t + "_") for t in split if t] - def subtoken_to_subtoken_string(self, subtoken): + def _subtoken_to_subtoken_string(self, subtoken): """Subtoken_String (string) corresponding to the given subtoken (id).""" if 0 <= subtoken < self.vocab_size: - subtoken_string = self._all_subtoken_strings[subtoken] - if subtoken_string: - return subtoken_string - if 0 <= subtoken < self._num_reserved_ids: - return u"%s_" % RESERVED_TOKENS[subtoken] - return u"ID%d_" % subtoken + return self._all_subtoken_strings[subtoken] + return u"" def _escaped_token_to_subtokens(self, escaped_token): """Converts an escaped token string to a list of subtokens. @@ -286,7 +282,7 @@ def _escaped_token_to_subtokens(self, escaped_token): pos = 0 lesc = len(escaped_token) while pos < lesc: - end = lesc + end = min(lesc, pos + self._max_subtoken_len) while end > pos: subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1) if subtoken != -1: @@ -348,13 +344,15 @@ def bisect(min_val, max_val): def build_from_token_counts(self, token_counts, min_count, - num_iterations=4): + num_iterations=4, + num_reserved_ids=2): """Train a SubwordTextEncoder based on a dictionary of word counts. Args: token_counts: a dictionary of Unicode strings to int. min_count: an integer - discard subtokens with lower counts. num_iterations: an integer. how many iterations of refinement. + num_reserved_ids: an integer. how many ids to reserve for special tokens. """ # first determine the alphabet to include all characters with count at # least min_count in the dataset. @@ -420,7 +418,7 @@ def build_from_token_counts(self, new_subtoken_strings.sort(reverse=True) # Now we have a candidate vocabulary old_alphabet = self._alphabet - self._init_from_list([u""] * self._num_reserved_ids + + self._init_from_list([u""] * num_reserved_ids + [p[1] for p in new_subtoken_strings]) assert old_alphabet == self._alphabet tf.logging.info("vocab_size = %d" % self.vocab_size) @@ -428,7 +426,7 @@ def build_from_token_counts(self, original = "This sentence was encoded by the SubwordTextEncoder." encoded = self.encode(original) print(encoded) - print([self.subtoken_to_subtoken_string(s) for s in encoded]) + print([self._subtoken_to_subtoken_string(s) for s in encoded]) decoded = self.decode(encoded) print(decoded) assert decoded == original @@ -443,6 +441,9 @@ def dump(self): def _init_from_list(self, subtoken_strings): """Initialize from a list of subtoken strings.""" self._all_subtoken_strings = subtoken_strings + # we remember the maximum length of any subtoken to avoid having to + # check arbitrarily long strings. 
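Two related changes in this hunk: `_subtoken_to_subtoken_string` becomes private and no longer invents names for reserved ids, and `_escaped_token_to_subtokens` bounds its greedy search by the new `_max_subtoken_len`, so candidates longer than any vocabulary entry are never tried. A standalone sketch of that bounded longest-match loop follows; it is simplified (the real encoder escapes characters it cannot match rather than skipping them), and `greedy_subtokenize` is an illustrative name, not part of the library.

```python
def greedy_subtokenize(escaped_token, subtoken_to_id, max_subtoken_len):
  """Greedy longest-match segmentation of an escaped token into subtoken ids."""
  ids = []
  pos = 0
  while pos < len(escaped_token):
    # Never try candidates longer than the longest known subtoken string.
    end = min(len(escaped_token), pos + max_subtoken_len)
    while end > pos and escaped_token[pos:end] not in subtoken_to_id:
      end -= 1
    if end == pos:
      pos += 1  # no match; the real encoder escapes such characters instead
      continue
    ids.append(subtoken_to_id[escaped_token[pos:end]])
    pos = end
  return ids


vocab = {u"the_": 0, u"re_": 1, u"t": 2, u"h": 3, u"e": 4, u"r": 5, u"_": 6}
print(greedy_subtokenize(u"there_", vocab, max(len(s) for s in vocab)))
# -> [2, 3, 4, 1], i.e. "t" + "h" + "e" + "re_"
```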
+ self._max_subtoken_len = max([len(s) for s in subtoken_strings]) self._subtoken_string_to_id = { s: i for i, s in enumerate(subtoken_strings) if s} self._alphabet = set([c for c in subtoken_strings if len(c) == 1]) @@ -472,10 +473,11 @@ def _escape_token(self, token): Returns: escaped_token: a unicode string """ - token = token.replace("\\", "\\\\").replace("_", "\\u") + "_" + assert isinstance(token, unicode) + token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u") + u"_" ret = u"" for c in token: - if c in self._alphabet: + if c in self._alphabet and c != u"\n": ret += c else: ret += u"\\%d;" % ord(c) @@ -496,12 +498,14 @@ def _unescape_token(self, escaped_token): c = escaped_token[pos] if c == "\\": pos += 1 + if pos >= len(escaped_token): + break c = escaped_token[pos] if c == u"u": ret += u"_" pos += 1 elif c == "\\": - ret += u"_" + ret += u"\\" pos += 1 else: semicolon_pos = escaped_token.find(u";", pos) @@ -516,19 +520,3 @@ def _unescape_token(self, escaped_token): ret += c pos += 1 return ret - - @classmethod - def get_token_counts(cls, text_filepattern, corpus_max_lines): - """Read the corpus and compute a dictionary of token counts.""" - tok = tokenizer.Tokenizer() - lines_read = 0 - filenames = tf.gfile.Glob(text_filepattern) - for text_filename in filenames: - with tf.gfile.Open(text_filename) as f: - for line in f: - # The tokenizer updates token_counts in encode() - tok.encode(native_to_unicode(line.strip())) - lines_read += 1 - if corpus_max_lines > 0 and lines_read > corpus_max_lines: - return tok.token_counts - return tok.token_counts diff --git a/tensor2tensor/data_generators/text_encoder_build_subword.py b/tensor2tensor/data_generators/text_encoder_build_subword.py index 659e9da14..df8aa73eb 100644 --- a/tensor2tensor/data_generators/text_encoder_build_subword.py +++ b/tensor2tensor/data_generators/text_encoder_build_subword.py @@ -21,16 +21,11 @@ Example usage: python data_generators/text_encoder_build_subword.py \ - --corpus_filepattern=$LM1B_DIR/train-unk-* \ - --corpus_max_lines=17500 \ - --output_fn=$DATA_DIR/lm1b16k.subword_text_encoder \ + --corpus_filepattern=$DATA_DIR/my_problem-train-* \ + --corpus_max_lines=12345 \ + --output_fn=$DATA_DIR/my_problem.subword_text_encoder \ --logtostderr -python data_generators/text_encoder_build_subword.py \ - --corpus_filepattern=$LM1B_DIR/train-unk-* \ - --corpus_max_lines=270000 \ - --output_fn=$DATA_DIR/lm1b64k.subword_text_encoder \ - --logtostderr """ from __future__ import absolute_import from __future__ import division @@ -39,6 +34,7 @@ # Dependency imports from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import tokenizer import tensorflow as tf @@ -50,6 +46,7 @@ tf.app.flags.DEFINE_integer('corpus_max_lines', 10000, 'How many lines of corpus to read') tf.app.flags.DEFINE_integer('num_iterations', 4, 'Number of iterations') +tf.app.flags.DEFINE_bool('split_on_newlines', True, 'Break corpus into lines.') FLAGS = tf.app.flags.FLAGS @@ -57,8 +54,9 @@ def main(unused_argv): gs = text_encoder.SubwordTextEncoder() if not FLAGS.corpus_filepattern: raise ValueError('Must provide --corpus_filepattern') - token_counts = text_encoder.SubwordTextEncoder.get_token_counts( - FLAGS.corpus_filepattern, FLAGS.corpus_max_lines) + token_counts = tokenizer.corpus_token_counts( + FLAGS.corpus_filepattern, FLAGS.corpus_max_lines, + split_on_newlines=FLAGS.split_on_newlines) gs.build_from_token_counts(token_counts, FLAGS.min_count, FLAGS.num_iterations) diff --git 
a/tensor2tensor/data_generators/text_encoder_inspect_subword.py b/tensor2tensor/data_generators/text_encoder_inspect_subword.py deleted file mode 100644 index 0ad9a2701..000000000 --- a/tensor2tensor/data_generators/text_encoder_inspect_subword.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2017 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Inspect a TFRecord file of tensorflow.Example and show tokenizations. - -python data_generators/text_encoder_inspect_subword.py \ - --logtostderr \ - --vocab_file=$DATA_DIR/tokens.vocab.8192 \ - --in_file=$DATA_DIR/wmt_ende_tokens_8k-train-00000-of-00100 -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports - -from tensor2tensor.data_generators import text_encoder - -import tensorflow as tf - -tf.app.flags.DEFINE_string("vocab_file", "", - "SubwordTextEncoder vocabulary file") - -tf.app.flags.DEFINE_string("in_file", "", "input filename") - -FLAGS = tf.app.flags.FLAGS - - -def ShowSequence(subtokenizer, subtokens, label): - print("%s decoded = %s" % (label, subtokenizer.decode(subtokens))) - print("%s subtoken ids = %s" % (label, subtokens)) - print("%s subtoken strings = %s" % - (label, - [subtokenizer.subtoken_to_subtoken_string(s) for s in subtokens])) - print("") - - -def main(_): - """Convert a file to examples.""" - subtokenizer = text_encoder.SubwordTextEncoder(FLAGS.vocab_file) - reader = tf.python_io.tf_record_iterator(FLAGS.in_file) - for record in reader: - x = tf.train.Example() - x.ParseFromString(record) - inputs = [int(i) for i in x.features.feature["inputs"].int64_list.value] - targets = [int(i) for i in x.features.feature["targets"].int64_list.value] - ShowSequence(subtokenizer, inputs, "inputs") - ShowSequence(subtokenizer, targets, "targets") - - -if __name__ == "__main__": - tf.app.run() diff --git a/tensor2tensor/data_generators/tokenizer.py b/tensor2tensor/data_generators/tokenizer.py index 8490ead19..df6ef6470 100644 --- a/tensor2tensor/data_generators/tokenizer.py +++ b/tensor2tensor/data_generators/tokenizer.py @@ -49,61 +49,101 @@ # Dependency imports +from six import PY2 from six import unichr # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf -class Tokenizer(object): - """Vocab for breaking words into Unicode wordpieces. + +# Conversion between Unicode and UTF-8, if required (on Python2) +_native_to_unicode = (lambda s: s.decode("utf-8")) if PY2 else (lambda s: s) + + +# This set contains all letter and number characters. +_ALPHANUMERIC_CHAR_SET = set( + unichr(i) for i in xrange(sys.maxunicode) + if (unicodedata.category(unichr(i)).startswith("L") or + unicodedata.category(unichr(i)).startswith("N"))) + + +def encode(text): + """Encode a unicode string as a list of tokens. 
+ + Args: + text: a unicode string + Returns: + a list of tokens as Unicode strings + """ + if not text: + return [] + ret = [] + token_start = 0 + # Classify each character in the input string + is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text] + for pos in xrange(1, len(text)): + if is_alnum[pos] != is_alnum[pos - 1]: + token = text[token_start:pos] + if token != u" " or token_start == 0: + ret.append(token) + token_start = pos + final_token = text[token_start:] + ret.append(final_token) + return ret + + +def decode(tokens): + """Decode a list of tokens to a unicode string. + + Args: + tokens: a list of Unicode strings + Returns: + a unicode string + """ + ret = u"" + token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens] + for i, token in enumerate(tokens): + if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: + ret += u" " + ret += token + return ret + + +def corpus_token_counts(text_filepattern, corpus_max_lines, + split_on_newlines=True): + """Read the corpus and compute a dictionary of token counts. + + Args: + text_filepattern: a pattern matching one or more files + corpus_max_lines: an integer - maximum total lines to read. + split_on_newlines: a boolean. If true, then split files by lines and strip + leading and trailing whitespace from each line. + + Returns: + a dictionary from token to count. """ + def read_corpus(): + """Read the corpus.""" + docs = [] + lines_read = 0 + filenames = tf.gfile.Glob(text_filepattern) + for text_filename in filenames: + with tf.gfile.Open(text_filename) as f: + if not split_on_newlines: + docs.append("") + for line in f: + if split_on_newlines: + # The tokenizer updates token_counts in encode() + docs.append(line.strip()) + else: + docs[-1] += line + lines_read += 1 + if corpus_max_lines > 0 and lines_read > corpus_max_lines: + return docs + return docs + counts = defaultdict(int) + for doc in read_corpus(): + for tok in encode(_native_to_unicode(doc)): + counts[tok] += 1 + return counts - # This set contains all letter and number characters. - _ALPHANUMERIC_CHAR_SET = set( - unichr(i) for i in xrange(sys.maxunicode) - if (unicodedata.category(unichr(i)).startswith("L") or - unicodedata.category(unichr(i)).startswith("N"))) - - def __init__(self): - self.token_counts = defaultdict(int) - - def encode(self, text): - """Encode a unicode string as a list of tokens. - - Args: - text: a unicode string - Returns: - a list of tokens as Unicode strings - """ - if not text: - return [] - ret = [] - token_start = 0 - # Classify each character in the input string - is_alnum = [c in self._ALPHANUMERIC_CHAR_SET for c in text] - for pos in xrange(1, len(text)): - if is_alnum[pos] != is_alnum[pos - 1]: - token = text[token_start:pos] - if token != u" " or token_start == 0: - ret.append(token) - self.token_counts[token] += 1 - token_start = pos - final_token = text[token_start:] - ret.append(final_token) - self.token_counts[final_token] += 1 - return ret - - def decode(self, tokens): - """Decode a list of tokens to a unicode string. 
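With this change `tokenizer` is a stateless module: `encode` and `decode` are plain functions, and corpus statistics live in `corpus_token_counts` instead of the old `Tokenizer.token_counts` side effect. A short usage sketch; the file pattern is hypothetical, and the expected token list is taken from the updated test later in this diff.

```python
from tensor2tensor.data_generators import tokenizer

toks = tokenizer.encode(u"Dude - that's so cool.")
# [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]
assert tokenizer.decode(toks) == u"Dude - that's so cool."

# Count tokens over a corpus. split_on_newlines=False treats each file as a
# single document instead of one document per line.
counts = tokenizer.corpus_token_counts(
    "/tmp/corpus/*.txt", corpus_max_lines=10000, split_on_newlines=True)
```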
- - Args: - tokens: a list of Unicode strings - Returns: - a unicode string - """ - ret = u"" - token_is_alnum = [t[0] in self._ALPHANUMERIC_CHAR_SET for t in tokens] - for i, token in enumerate(tokens): - if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]: - ret += u" " - ret += token - return ret diff --git a/tensor2tensor/data_generators/tokenizer_test.py b/tensor2tensor/data_generators/tokenizer_test.py index a85e244ca..404a11396 100644 --- a/tensor2tensor/data_generators/tokenizer_test.py +++ b/tensor2tensor/data_generators/tokenizer_test.py @@ -33,31 +33,30 @@ class TokenizerTest(tf.test.TestCase): def testEncode(self): - t = tokenizer.Tokenizer() self.assertEqual( - t.encode(u"Dude - that's so cool."), + tokenizer.encode(u"Dude - that's so cool."), [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]) self.assertEqual( - t.encode(u"Łukasz est né en 1981."), + tokenizer.encode(u"Łukasz est né en 1981."), [u"Łukasz", u"est", u"né", u"en", u"1981", u"."]) self.assertEqual( - t.encode(u" Spaces at the ends "), + tokenizer.encode(u" Spaces at the ends "), [u" ", u"Spaces", u"at", u"the", u"ends", u" "]) - self.assertEqual(t.encode(u"802.11b"), [u"802", u".", u"11b"]) - self.assertEqual(t.encode(u"two. \nlines"), [u"two", u". \n", u"lines"]) + self.assertEqual(tokenizer.encode(u"802.11b"), [u"802", u".", u"11b"]) + self.assertEqual(tokenizer.encode(u"two. \nlines"), + [u"two", u". \n", u"lines"]) def testDecode(self): - t = tokenizer.Tokenizer() self.assertEqual( - t.decode([u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]), + tokenizer.decode( + [u"Dude", u" - ", u"that", u"'", u"s", u"so", u"cool", u"."]), u"Dude - that's so cool.") def testInvertibilityOnRandomStrings(self): - t = tokenizer.Tokenizer() random.seed(123) for _ in xrange(1000): s = u"".join([unichr(random.randint(0, 65535)) for _ in xrange(10)]) - self.assertEqual(s, t.decode(t.encode(s))) + self.assertEqual(s, tokenizer.decode(tokenizer.encode(s))) if __name__ == "__main__": diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py new file mode 100644 index 000000000..5ccbf14d9 --- /dev/null +++ b/tensor2tensor/data_generators/wiki.py @@ -0,0 +1,128 @@ +# Copyright 2017 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data generator for Wikipedia title to article dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import bz2 +from collections import defaultdict +import os + +# Dependency imports + +import six +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators import tokenizer + +import tensorflow as tf + + +# End-of-sentence marker (should correspond to the position of EOS in the +# RESERVED_TOKENS list in text_encoder.py) +EOS = 1 + + +def _maybe_download_corpus(tmp_dir): + """Download corpus if necessary. + + Args: + tmp_dir: directory containing dataset. 
+ + Returns: + filepath of the downloaded corpus file. + """ + corpus_url = ("https://dumps.wikimedia.org/enwiki/20170620/" + "enwiki-20170620-pages-articles-multistream.xml.bz2") + corpus_filename = os.path.basename(corpus_url) + corpus_filepath = os.path.join(tmp_dir, corpus_filename) + if not os.path.exists(corpus_filepath): + generator_utils.maybe_download(tmp_dir, corpus_filename, corpus_url) + return corpus_filepath + + +def page_generator(tmp_dir, max_docs=None): + doc = u"" + count = 0 + corpus_filepath = _maybe_download_corpus(tmp_dir) + for line in bz2.BZ2File(corpus_filepath, "r"): + line = unicode(line, "utf-8") + if not doc and line != u" \n": + continue + doc += line + if line == u" \n": + yield doc + doc = u"" + count += 1 + if max_docs and count >= max_docs: + break + + +def _page_title(page): + start_pos = page.find(u"") + end_pos = page.find(u"") + assert start_pos != -1 + assert end_pos != -1 + start_pos += len(u"") + return page[start_pos:end_pos] + + +def _get_or_build_subword_text_encoder(tmp_dir): + """Builds a SubwordTextEncoder based on the corpus. + + Args: + tmp_dir: a string + + Returns: + a SubwordTextEncoder. + """ + filename = os.path.join(tmp_dir, "wiki_32k.subword_text_encoder") + if tf.gfile.Exists(filename): + return text_encoder.SubwordTextEncoder(filename) + token_counts = defaultdict(int) + for page in page_generator(tmp_dir, max_docs=1000): + tokens = tokenizer.encode(page) + tokens = set(tokens) + for tok in tokens: + token_counts[tok] += 1 + new_token_counts = defaultdict(int) + for token, count in six.iteritems(token_counts): + if count >= 3: + new_token_counts[token] = count + ret = text_encoder.SubwordTextEncoder() + ret.build_from_token_counts(new_token_counts, min_count=10) + ret.store_to_file(filename) + return ret + + +def generator(tmp_dir, train): + """Generator for lm1b sentences. + + Args: + tmp_dir: a string. + train: a boolean. + + Yields: + A dictionary {"inputs": [<subword ids>], "targets": [<subword ids>]} + """ + assert train + encoder = _get_or_build_subword_text_encoder(tmp_dir) + for page in page_generator(tmp_dir): + title = _page_title(page) + encoded = encoder.encode(page) + [EOS] + encoded_title = encoder.encode(title) + [EOS] + yield {"inputs": encoded_title, "targets": encoded} diff --git a/tensor2tensor/docs/distributed_training.md b/tensor2tensor/docs/distributed_training.md index e7ddd7294..f41197fc4 100644 --- a/tensor2tensor/docs/distributed_training.md +++ b/tensor2tensor/docs/distributed_training.md @@ -51,7 +51,7 @@ Parameter servers only need `--schedule=run_std_server`. ## Utility to produce `TF_CONFIG` and flags -[`bin/make_tf_configs.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/make_tf_configs.py)) +[`t2t-make-tf-configs`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-make-tf-configs)) generates the `TF_CONFIG` json strings and the above-mentioned command-line flags for the workers and parameter servers. diff --git a/tensor2tensor/models/common_hparams.py b/tensor2tensor/models/common_hparams.py index 41ca6f4b0..9bb3af4eb 100644 --- a/tensor2tensor/models/common_hparams.py +++ b/tensor2tensor/models/common_hparams.py @@ -65,6 +65,8 @@ def basic_params1(): sampling_method="argmax", # "argmax" or "random" problem_choice="adaptive", # "uniform", "adaptive", "distributed" multiply_embedding_mode="sqrt_depth", + norm_type="none", # "batch", layer", "noam", "none". 
+ layer_norm_epsilon=1e-6, symbol_modality_num_shards=16, # setting the max length in a minibatch. 0 means default behavior, # max_length = hparams.batch_size * length_multiplier diff --git a/tensor2tensor/models/common_layers.py b/tensor2tensor/models/common_layers.py index 4c63ce8ba..15a712ef2 100644 --- a/tensor2tensor/models/common_layers.py +++ b/tensor2tensor/models/common_layers.py @@ -292,7 +292,8 @@ def conv_internal(conv_fn, inputs, filters, kernel_size, **kwargs): """Conditional conv_fn making kernel 1d or 2d depending on inputs shape.""" static_shape = inputs.get_shape() if not static_shape or len(static_shape) != 4: - raise ValueError("Inputs to conv must have statically known rank 4. Shape:" +str(static_shape)) + raise ValueError("Inputs to conv must have statically known rank 4. " + "Shape: " + str(static_shape)) # Add support for left padding. if "padding" in kwargs and kwargs["padding"] == "LEFT": dilation_rate = (1, 1) @@ -433,24 +434,48 @@ def noam_norm(x, name=None): tf.sqrt(tf.to_float(shape[-1]))) -def residual_function(hparams): +def get_norm(norm_type): + """Get the normalizer function.""" + if norm_type == "layer": + return lambda x, name, filters=None, epsilon=1e-6: layer_norm( # pylint: disable=g-long-lambda + x, filters=filters, epsilon=epsilon, name=name) + if norm_type == "batch": + return tf.layers.batch_normalization + if norm_type == "noam": + return noam_norm + if norm_type == "none": + return lambda x, name: x + raise ValueError("Parameter normalizer_fn must be one of: 'layer', 'batch'," + "'noam', 'none'.") + + +def residual_fn(x, y, norm_type, residual_dropout, + filters=None, + epsilon=1e-16, + name="residual"): """Returns a function for combining layer input and layer output. The returned function on x (layer input) and y (layer output) computes: - norm_function(x + t + norm_function(x + dropout(y)) Args: - hparams: model hyperparameters + x: tensor, input layer + y: tensor, output layer + norm_type: string, type of normalizer function + residual_dropout: integer, dropout value for residual connection + filters: integer, dimension for layer norm, optional + epsilon: integer, value of layer norm epsilon + name: string, name Returns: - a function from x=<layer input> and y=<layer output> to computed output + residual layer output with applied norm_fn. """ - - def residual_fn(x, y): - return hparams.norm_function(x + tf.nn.dropout( - y, 1.0 - hparams.residual_dropout)) - - return residual_fn + norm_fn = get_norm(norm_type) + res = x + tf.nn.dropout(y, 1.0 - residual_dropout) + if norm_type == "layer": + return norm_fn(res, name=name, filters=filters, epsilon=epsilon) + else: + return norm_fn(res, name=name) def conv_block_internal(conv_fn, @@ -1379,127 +1404,126 @@ def smoothing_cross_entropy(logits, labels, vocab_size, confidence): logits=logits, labels=soft_targets) return xentropy - normalizing - -def global_pool_1d(inputs, pooling_type='MAX', mask=None): - """ - Pools elements across the last dimension. Useful to a list of vectors into a - single vector to get a representation of a set. - - Args - inputs: A tensor of dimensions batch_size x sequence_length x input_dims - containing the sequences of input vectors. - pooling_type: the pooling type to use, MAX or AVR - mask: A tensor of dimensions batch_size x sequence_length containing a - mask for the inputs with 1's for existing elements, and 0's elsewhere. - Outputs - output: A tensor of dimensions batch_size x input_dims - dimension containing the sequences of transformed vectors. 
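`residual_function(hparams)` is replaced here by `residual_fn(x, y, norm_type, residual_dropout, ...)`, driven by the new `norm_type` hparam and the `get_norm` dispatcher above. A hedged usage sketch, mirroring the tensor shapes used in the new tests at the end of this diff; the inputs are random placeholders, not part of any model.

```python
import numpy as np
import tensorflow as tf

from tensor2tensor.models import common_layers

x = tf.constant(np.random.rand(5, 2, 1, 11), dtype=tf.float32)
y = tf.constant(np.random.rand(5, 2, 1, 11), dtype=tf.float32)
# Combines layer input and output as norm(x + dropout(y)), here with layer norm.
out = common_layers.residual_fn(
    x, y, norm_type="layer", residual_dropout=0.1, filters=11)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(out).shape)  # (5, 2, 1, 11)
```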
+ +def global_pool_1d(inputs, pooling_type="MAX", mask=None): + """Pool elements across the last dimension. + + Useful to convert a list of vectors into a single vector so as + to get a representation of a set. + + Args: + inputs: A tensor of dimensions batch_size x sequence_length x input_dims + containing the sequences of input vectors. + pooling_type: the pooling type to use, MAX or AVR + mask: A tensor of dimensions batch_size x sequence_length containing a + mask for the inputs with 1's for existing elements, and 0's elsewhere. + + Returns: + output: A tensor of dimensions batch_size x input_dims + dimension containing the sequences of transformed vectors. """ - with tf.name_scope("global_pool", [inputs]): if mask is not None: mask = tf.expand_dims(mask, axis=2) inputs = tf.multiply(inputs, mask) - - if pooling_type == 'MAX': + + if pooling_type == "MAX": # A tf.pool can be used here, but reduce is cleaner output = tf.reduce_max(inputs, axis=1) - elif pooling_type == 'AVR': + elif pooling_type == "AVR": if mask is not None: - # Some elems are dummy elems so we can't just reduce the average + # Some elems are dummy elems so we can't just reduce the average. output = tf.reduce_sum(inputs, axis=1) num_elems = tf.reduce_sum(mask, axis=1, keep_dims=True) - output = tf.div(output, num_elems) - #N.B: this will cause a NaN if one batch contains no elements + output = tf.div(output, tf.maximum(num_elems, 1)) else: - output = tf.reduce_mean(inputs, axis=1) - + output = tf.reduce_mean(inputs, axis=1) + return output - - + + def linear_set_layer(layer_size, inputs, context=None, activation_fn=tf.nn.relu, dropout=0.0, name=None): - """ - Basic layer type for doing funky things with sets. + """Basic layer type for doing funky things with sets. + Applies a linear transformation to each element in the input set. If a context is supplied, it is concatenated with the inputs. e.g. One can use global_pool_1d to get a representation of the set which can then be used as the context for the next layer. - - Args - layer_size: Dimension to transform the input vectors to - inputs: A tensor of dimensions batch_size x sequence_length x input_dims - containing the sequences of input vectors. - context: A tensor of dimensions batch_size x context_dims - containing a global statistic about the set. - dropout: Dropout probability. - activation_fn: The activation function to use. - Outputs - output: A tensor of dimensions batch_size x sequence_length x output_dims - dimension containing the sequences of transformed vectors. - - TODO: Add bias add. + + TODO: Add bias add (or control the biases used). + + Args: + layer_size: Dimension to transform the input vectors to. + inputs: A tensor of dimensions batch_size x sequence_length x input_dims + containing the sequences of input vectors. + context: A tensor of dimensions batch_size x context_dims + containing a global statistic about the set. + activation_fn: The activation function to use. + dropout: Dropout probability. + name: name. + + Returns: + output: A tensor of dimensions batch_size x sequence_length x output_dims + dimension containing the sequences of transformed vectors. """ - with tf.variable_scope(name, "linear_set_layer", [inputs]): - # Apply 1D convolution to apply linear filter to each element along the 2nd - # dimension - #in_size = inputs.get_shape().as_list()[-1] + # Apply 1D convolution to apply linear filter to each element + # along the 2nd dimension. 
outputs = conv1d(inputs, layer_size, 1, activation=None, name="set_conv") - # Apply the context if it exists + # Apply the context if it exists. if context is not None: # Unfortunately tf doesn't support broadcasting via concat, but we can - # simply add the transformed context to get the same effect + # simply add the transformed context to get the same effect. context = tf.expand_dims(context, axis=1) - #context_size = context.get_shape().as_list()[-1] cont_tfm = conv1d(context, layer_size, 1, - activation=None, name="cont_conv") + activation=None, name="cont_conv") outputs += cont_tfm - + if activation_fn is not None: outputs = activation_fn(outputs) - + if dropout != 0.0: - output = tf.nn.dropout(output, 1.0 - dropout) - + outputs = tf.nn.dropout(outputs, 1.0 - dropout) + return outputs - - + + def ravanbakhsh_set_layer(layer_size, inputs, mask=None, activation_fn=tf.nn.tanh, dropout=0.0, name=None): - """ - Layer from Deep Sets paper: https://arxiv.org/abs/1611.04500 + """Layer from Deep Sets paper: https://arxiv.org/abs/1611.04500 . + More parameter-efficient verstion of a linear-set-layer with context. - - - Args - layer_size: Dimension to transform the input vectors to. - inputs: A tensor of dimensions batch_size x sequence_length x vector - containing the sequences of input vectors. - mask: A tensor of dimensions batch_size x sequence_length containing a - mask for the inputs with 1's for existing elements, and 0's elsewhere. - activation_fn: The activation function to use. - Outputs - output: A tensor of dimensions batch_size x sequence_length x vector - dimension containing the sequences of transformed vectors. + + Args: + layer_size: Dimension to transform the input vectors to. + inputs: A tensor of dimensions batch_size x sequence_length x vector + containing the sequences of input vectors. + mask: A tensor of dimensions batch_size x sequence_length containing a + mask for the inputs with 1's for existing elements, and 0's elsewhere. + activation_fn: The activation function to use. + dropout: dropout. + name: name. + + Returns: + output: A tensor of dimensions batch_size x sequence_length x vector + dimension containing the sequences of transformed vectors. 
""" - with tf.variable_scope(name, "ravanbakhsh_set_layer", [inputs]): output = linear_set_layer( layer_size, inputs - tf.expand_dims(global_pool_1d(inputs, mask=mask), axis=1), activation_fn=activation_fn, + dropout=dropout, name=name) - - return output - + return output diff --git a/tensor2tensor/models/common_layers_test.py b/tensor2tensor/models/common_layers_test.py index 04d428884..a87776bfb 100644 --- a/tensor2tensor/models/common_layers_test.py +++ b/tensor2tensor/models/common_layers_test.py @@ -50,7 +50,7 @@ def testSaturatingSigmoid(self): self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0]) def testFlatten4D3D(self): - x = np.random.randint(1, 9, size=(3, 5, 2)) + x = np.random.random_integers(1, high=8, size=(3, 5, 2)) with self.test_session() as session: y = common_layers.flatten4d3d(common_layers.embedding(x, 10, 7)) session.run(tf.global_variables_initializer()) @@ -58,7 +58,7 @@ def testFlatten4D3D(self): self.assertEqual(res.shape, (3, 5 * 2, 7)) def testEmbedding(self): - x = np.random.randint(1, 9, size=(3, 5)) + x = np.random.random_integers(1, high=8, size=(3, 5)) with self.test_session() as session: y = common_layers.embedding(x, 10, 16) session.run(tf.global_variables_initializer()) @@ -81,7 +81,7 @@ def testConv(self): session.run(tf.global_variables_initializer()) res = session.run(y) self.assertEqual(res.shape, (5, 5, 1, 13)) - + def testConv1d(self): x = np.random.rand(5, 7, 11) with self.test_session() as session: @@ -301,66 +301,125 @@ def testDeconvStride2MultiStep(self): session.run(tf.global_variables_initializer()) actual = session.run(a) self.assertEqual(actual.shape, (5, 32, 1, 16)) - + + def testGetNormLayerFn(self): + norm_type = "layer" + with self.test_session() as session: + a = common_layers.get_norm(norm_type) + x1 = np.random.rand(5, 2, 1, 11) + x2 = a(tf.constant(x1, dtype=tf.float32), name="layer", filters=11) + session.run(tf.global_variables_initializer()) + actual = session.run(x2) + self.assertEqual(actual.shape, (5, 2, 1, 11)) + + def testGetNormNoamFn(self): + norm_type = "noam" + with self.test_session() as session: + a = common_layers.get_norm(norm_type) + x1 = np.random.rand(5, 2, 1, 11) + x2 = a(tf.constant(x1, dtype=tf.float32), name="noam") + session.run(tf.global_variables_initializer()) + actual = session.run(x2) + self.assertEqual(actual.shape, (5, 2, 1, 11)) + + def testGetNormBatchFn(self): + norm_type = "batch" + with self.test_session() as session: + a = common_layers.get_norm(norm_type) + x1 = np.random.rand(5, 2, 1, 11) + x2 = a(tf.constant(x1, dtype=tf.float32), name="batch") + session.run(tf.global_variables_initializer()) + actual = session.run(x2) + self.assertEqual(actual.shape, (5, 2, 1, 11)) + + def testGetNormNoneFn(self): + norm_type = "none" + with self.test_session() as session: + a = common_layers.get_norm(norm_type) + x1 = np.random.rand(5, 2, 1, 11) + x2 = a(tf.constant(x1, dtype=tf.float32), name="none") + session.run(tf.global_variables_initializer()) + actual = session.run(x2) + self.assertEqual(actual.shape, (5, 2, 1, 11)) + self.assertAllClose(actual, x1, atol=1e-03) + + def testResidualFn(self): + norm_type = "batch" + with self.test_session() as session: + x1 = np.random.rand(5, 2, 1, 11) + x2 = np.random.rand(5, 2, 1, 11) + x3 = common_layers.residual_fn( + tf.constant(x1, dtype=tf.float32), + tf.constant(x2, dtype=tf.float32), + norm_type, 0.1) + session.run(tf.global_variables_initializer()) + actual = session.run(x3) + self.assertEqual(actual.shape, (5, 2, 1, 11)) + + def 
testResidualFnWithLayerNorm(self): + norm_type = "layer" + with self.test_session() as session: + x1 = np.random.rand(5, 2, 1, 11) + x2 = np.random.rand(5, 2, 1, 11) + x3 = common_layers.residual_fn( + tf.constant(x1, dtype=tf.float32), + tf.constant(x2, dtype=tf.float32), + norm_type, 0.1, epsilon=0.1) + session.run(tf.global_variables_initializer()) + actual = session.run(x3) + self.assertEqual(actual.shape, (5, 2, 1, 11)) + def testGlobalPool1d(self): - shape = (5, 4) - x1 = np.random.rand(5,4,11) - #mask = np.random.randint(2, size=shape) - no_mask = np.ones((5,4)) - full_mask = np.zeros((5,4)) - + x1 = np.random.rand(5, 4, 11) + no_mask = np.ones((5, 4)) + full_mask = np.zeros((5, 4)) + with self.test_session() as session: x1_ = tf.Variable(x1, dtype=tf.float32) no_mask_ = tf.Variable(no_mask, dtype=tf.float32) full_mask_ = tf.Variable(full_mask, dtype=tf.float32) - + none_mask_max = common_layers.global_pool_1d(x1_) no_mask_max = common_layers.global_pool_1d(x1_, mask=no_mask_) result1 = tf.reduce_sum(none_mask_max - no_mask_max) - + full_mask_max = common_layers.global_pool_1d(x1_, mask=full_mask_) result2 = tf.reduce_sum(full_mask_max) - - none_mask_avr = common_layers.global_pool_1d(x1_, 'AVR') - no_mask_avr = common_layers.global_pool_1d(x1_, 'AVR', no_mask_) + + none_mask_avr = common_layers.global_pool_1d(x1_, "AVR") + no_mask_avr = common_layers.global_pool_1d(x1_, "AVR", no_mask_) result3 = tf.reduce_sum(none_mask_avr - no_mask_avr) - - full_mask_avr = common_layers.global_pool_1d(x1_, 'AVR', full_mask_) + + full_mask_avr = common_layers.global_pool_1d(x1_, "AVR", full_mask_) result4 = tf.reduce_sum(full_mask_avr) - + session.run(tf.global_variables_initializer()) actual = session.run([result1, result2, result3, result4]) - # N.B: Last result will give a NaN. 
self.assertAllEqual(actual[:3], [0.0, 0.0, 0.0]) - def testLinearSetLayer(self): - x1 = np.random.rand(5,4,11) - cont = np.random.rand(5,13) + x1 = np.random.rand(5, 4, 11) + cont = np.random.rand(5, 13) with self.test_session() as session: x1_ = tf.Variable(x1, dtype=tf.float32) cont_ = tf.Variable(cont, dtype=tf.float32) - + simple_ff = common_layers.linear_set_layer(32, x1_) cont_ff = common_layers.linear_set_layer(32, x1_, context=cont_) - + session.run(tf.global_variables_initializer()) actual = session.run([simple_ff, cont_ff]) - self.assertEqual(actual[0].shape, (5,4,32)) - self.assertEqual(actual[1].shape, (5,4,32)) - + self.assertEqual(actual[0].shape, (5, 4, 32)) + self.assertEqual(actual[1].shape, (5, 4, 32)) + def testRavanbakhshSetLayer(self): - x1 = np.random.rand(5,4,11) - cont = np.random.rand(5,13) + x1 = np.random.rand(5, 4, 11) with self.test_session() as session: x1_ = tf.Variable(x1, dtype=tf.float32) - cont_ = tf.Variable(cont, dtype=tf.float32) - layer = common_layers.ravanbakhsh_set_layer(32, x1_) - session.run(tf.global_variables_initializer()) actual = session.run(layer) - self.assertEqual(actual.shape, (5,4,32)) + self.assertEqual(actual.shape, (5, 4, 32)) if __name__ == "__main__": diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index b8f0811e5..ae0e0da61 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -32,5 +32,6 @@ from tensor2tensor.models import neural_gpu from tensor2tensor.models import slicenet from tensor2tensor.models import transformer +from tensor2tensor.models import transformer_alternative from tensor2tensor.models import xception # pylint: enable=unused-import diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index b42d71cb3..26e7469c2 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -19,6 +19,8 @@ # Dependency imports +from six.moves import xrange # pylint: disable=redefined-builtin + from tensor2tensor.models import common_attention from tensor2tensor.models import common_hparams from tensor2tensor.models import common_layers @@ -27,7 +29,6 @@ from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model -from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 0b9efc2c3..77659e8ef 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -31,21 +31,6 @@ import tensorflow as tf -def get_norm(hparams): - """Get the normalizer function.""" - if hparams.normalizer_fn == "layer": - return lambda x, name: common_layers.layer_norm( # pylint: disable=g-long-lambda - x, hparams.hidden_size, name=name) - if hparams.normalizer_fn == "batch": - return tf.layers.batch_normalization - if hparams.normalizer_fn == "noam": - return common_layers.noam_norm - if hparams.normalizer_fn == "none": - return lambda x, name: x - raise ValueError("Parameter normalizer_fn must be one of: 'layer', 'batch'," - "'noam', 'none'.") - - def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None): """Complete attention layer with preprocessing.""" separabilities = [hparams.separability, hparams.separability] @@ -128,7 +113,7 @@ def multi_conv_res(x, padding, name, layers, hparams, hparams.separability - i for i in reversed(range(len(dilations_and_kernels2))) ] - norm_fn = get_norm(hparams) + norm_fn = common_layers.get_norm(hparams.norm_type) for layer in 
xrange(layers): with tf.variable_scope("layer_%d" % layer): y = common_layers.subseparable_conv_block( @@ -188,7 +173,7 @@ def similarity_cost(inputs_encoded, targets_encoded): def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams): """Middle part of slicenet, connecting encoder and decoder.""" - norm_fn = get_norm(hparams) + norm_fn = common_layers.get_norm(hparams.norm_type) # Flatten targets and embed target_space_id. targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2) @@ -311,7 +296,7 @@ def slicenet_params1(): hparams.num_hidden_layers = 4 hparams.kernel_height = 3 hparams.kernel_width = 1 - hparams.add_hparam("normalizer_fn", "layer") # New ones are added like this. + hparams.norm_type = "layer" hparams.learning_rate_decay_scheme = "exp50k" hparams.learning_rate = 0.05 hparams.learning_rate_warmup_steps = 3000 @@ -322,7 +307,7 @@ def slicenet_params1(): hparams.optimizer_adam_epsilon = 1e-6 hparams.optimizer_adam_beta1 = 0.85 hparams.optimizer_adam_beta2 = 0.997 - hparams.add_hparam("large_kernel_size", 15) + hparams.add_hparam("large_kernel_size", 15) # New ones are added like this. hparams.add_hparam("separability", -2) # A dilation scheme, one of _DILATION_SCHEMES. hparams.add_hparam("dilation_scheme", "1.1.1.1") diff --git a/tensor2tensor/models/transformer_alternative.py b/tensor2tensor/models/transformer_alternative.py index 90fea6139..e50cba86f 100644 --- a/tensor2tensor/models/transformer_alternative.py +++ b/tensor2tensor/models/transformer_alternative.py @@ -12,27 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" - Alternative transformer network using different layer types to demonstrate - alternatives to self attention. +"""Alternative transformer network. - Code is mostly copied from original Transformer source (if that wasn't - already obvious). +Using different layer types to demonstrate alternatives to self attention. +Code is mostly copied from original Transformer source. 
""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy - # Dependency imports from six.moves import xrange # pylint: disable=redefined-builtin from tensor2tensor.models import common_attention -from tensor2tensor.models import common_hparams from tensor2tensor.models import common_layers from tensor2tensor.models import transformer from tensor2tensor.utils import registry @@ -45,10 +41,7 @@ class TransformerAlt(t2t_model.T2TModel): def model_fn_body(self, features): - # - - # Remove dropout if not training - hparams = copy.copy(self._hparams) + hparams = self._hparams targets = features["targets"] inputs = features.get("inputs") target_space = features.get("target_space_id") @@ -56,16 +49,17 @@ def model_fn_body(self, features): inputs = common_layers.flatten4d3d(inputs) targets = common_layers.flatten4d3d(targets) - (encoder_input, encoder_attention_bias, _) = (transformer.\ - transformer_prepare_encoder(inputs, target_space, hparams) ) - (decoder_input, decoder_self_attention_bias) = transformer.\ - transformer_prepare_decoder(targets, hparams) - + (encoder_input, encoder_attention_bias, + _) = transformer.transformer_prepare_encoder(inputs, target_space, hparams) + (decoder_input, + decoder_self_attention_bias) = transformer.transformer_prepare_decoder( + targets, hparams) + # We need masks of the form batch size x input sequences # Biases seem to be of the form batch_size x 1 x input sequences x vec dim - # Squeeze out dim one, and get the first element of each vector - encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:,:,0] - decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:,:,0] + # Squeeze out dim one, and get the first element of each vector. + encoder_mask = tf.squeeze(encoder_attention_bias, [1])[:, :, 0] + decoder_mask = tf.squeeze(decoder_self_attention_bias, [1])[:, :, 0] def residual_fn(x, y): return common_layers.layer_norm(x + tf.nn.dropout( @@ -79,47 +73,45 @@ def residual_fn(x, y): decoder_output = alt_transformer_decoder( decoder_input, encoder_output, residual_fn, decoder_mask, encoder_attention_bias, hparams) - + decoder_output = tf.expand_dims(decoder_output, 2) return decoder_output - - - + + def composite_layer(inputs, mask, hparams): + """Composite layer.""" x = inputs - - # Applies ravanbakhsh on top of each other + + # Applies ravanbakhsh on top of each other. if hparams.composite_layer_type == "ravanbakhsh": for layer in xrange(hparams.layers_per_layer): with tf.variable_scope(".%d" % layer): x = common_layers.ravanbakhsh_set_layer( - hparams.hidden_size, - x, - mask=mask, - dropout=0.0) - - # Transforms elements to get a context, and then uses this in a final layer + hparams.hidden_size, + x, + mask=mask, + dropout=0.0) + + # Transforms elements to get a context, and then uses this in a final layer. elif hparams.composite_layer_type == "reembedding": - initial_elems = x - # Transform elements n times and then pool + # Transform elements n times and then pool. for layer in xrange(hparams.layers_per_layer): with tf.variable_scope(".%d" % layer): x = common_layers.linear_set_layer( - hparams.hidden_size, - x, - dropout=0.0) - context = common_layers.global_pool_1d(x, mask=mask) - - #Final layer - x = common_layers.linear_set_layer( hparams.hidden_size, x, - context=context, dropout=0.0) - + context = common_layers.global_pool_1d(x, mask=mask) + + # Final layer. 
+ x = common_layers.linear_set_layer( + hparams.hidden_size, + x, + context=context, + dropout=0.0) + return x - def alt_transformer_encoder(encoder_input, @@ -127,17 +119,14 @@ def alt_transformer_encoder(encoder_input, mask, hparams, name="encoder"): - + """Alternative encoder.""" x = encoder_input - - # Summaries don't work in multi-problem setting yet. - summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 - + with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): x = residual_fn(x, composite_layer(x, mask, hparams)) - + return x @@ -148,34 +137,31 @@ def alt_transformer_decoder(decoder_input, encoder_decoder_attention_bias, hparams, name="decoder"): - + """Alternative decoder.""" x = decoder_input - + # Summaries don't work in multi-problem setting yet. summaries = "problems" not in hparams.values() or len(hparams.problems) == 1 with tf.variable_scope(name): for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): - + x_ = common_attention.multihead_attention( - x, - encoder_output, - encoder_decoder_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout, - summaries=summaries, - name="encdec_attention") + x, + encoder_output, + encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + summaries=summaries, + name="encdec_attention") x_ = residual_fn(x_, composite_layer(x_, mask, hparams)) x = residual_fn(x, x_) - - return x - - + return x @registry.register_hparams @@ -184,7 +170,5 @@ def transformer_alt(): hparams = transformer.transformer_base() hparams.batch_size = 64 hparams.add_hparam("layers_per_layer", 4) - #hparams.add_hparam("composite_layer_type", "ravanbakhsh") #ravanbakhsh or reembedding hparams.add_hparam("composite_layer_type", "reembedding") return hparams - diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py index 52c1d1ba5..9535558a4 100644 --- a/tensor2tensor/models/transformer_test.py +++ b/tensor2tensor/models/transformer_test.py @@ -24,7 +24,6 @@ from tensor2tensor.data_generators import problem_hparams from tensor2tensor.models import transformer -from tensor2tensor.models import transformer_alternative import tensorflow as tf diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index 0022081ae..7386d3ea0 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -45,8 +45,9 @@ def test_generator(): for i in xrange(100): yield {"inputs": [i], "targets": [i], "floats": [i + 0.5]} - generator_utils.generate_files(test_generator(), tmp_file_name, tmp_dir) - self.assertTrue(tf.gfile.Exists(tmp_file_path + "-00000-of-00001")) + filenames = generator_utils.train_data_filenames(tmp_file_name, tmp_dir, 1) + generator_utils.generate_files(test_generator(), filenames) + self.assertTrue(tf.gfile.Exists(tmp_file_path + "-train-00000-of-00001")) examples_train = data_reader.examples_queue( [tmp_file_path + "*"], { @@ -82,7 +83,7 @@ def test_generator(): self.assertTrue(is_shuffled) # Clean up. 
- os.remove(tmp_file_path + "-00000-of-00001") + os.remove(tmp_file_path + "-train-00000-of-00001") os.remove(tmp_file_path) # TODO(rsepassi): fix and reenable test @@ -97,8 +98,9 @@ def test_generator(): for i in xrange(100): yield {"inputs": [i + 1 for _ in xrange(i + 1)], "targets": [i + 1]} - generator_utils.generate_files(test_generator(), tmp_file_name, tmp_dir) - self.assertTrue(tf.gfile.Exists(tmp_file_path + "-00000-of-00001")) + filenames = generator_utils.train_data_filenames(tmp_file_name, tmp_dir, 1) + generator_utils.generate_files(test_generator(), filenames) + self.assertTrue(tf.gfile.Exists(tmp_file_path + "-train-00000-of-00001")) examples_train = data_reader.examples_queue([tmp_file_path + "*"], { "inputs": tf.VarLenFeature(tf.int64), @@ -140,7 +142,7 @@ def test_generator(): # Clean up. coord.request_stop() coord.join() - os.remove(tmp_file_path + "-00000-of-00001") + os.remove(tmp_file_path + "-train-00000-of-00001") os.remove(tmp_file_path) diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py index fd1c6885c..d621b6fbc 100644 --- a/tensor2tensor/utils/trainer_utils_test.py +++ b/tensor2tensor/utils/trainer_utils_test.py @@ -38,10 +38,12 @@ def setUpClass(cls): FLAGS.problems = "algorithmic_addition_binary40" TrainerUtilsTest.data_dir = tf.test.get_temp_dir() gen = algorithmic.identity_generator(2, 10, 300) - generator_utils.generate_files(gen, FLAGS.problems + "-train", - TrainerUtilsTest.data_dir, 1, 100) - generator_utils.generate_files(gen, FLAGS.problems + "-dev", - TrainerUtilsTest.data_dir, 1, 100) + train_filenames = generator_utils.train_data_filenames( + FLAGS.problems, TrainerUtilsTest.data_dir, 1) + dev_filenames = generator_utils.dev_data_filenames( + FLAGS.problems, TrainerUtilsTest.data_dir, 1) + generator_utils.generate_files(gen, train_filenames, 100) + generator_utils.generate_files(gen, dev_filenames, 100) def testModelsImported(self): models = registry.list_models()
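
For reference, a minimal usage sketch of the common_layers helpers added in this change (residual_fn, get_norm, global_pool_1d, linear_set_layer); the tensor shapes, dropout rate, and epsilon below are illustrative assumptions, not values taken from the change itself.

    import tensorflow as tf
    from tensor2tensor.models import common_layers

    # Residual connection: dropout on the layer output, add the layer input,
    # then apply the normalizer selected by norm_type (here layer norm),
    # mirroring residual_fn(x, y, norm_type, residual_dropout, ...).
    x = tf.random_normal([8, 10, 1, 512])   # layer input (illustrative shape)
    y = tf.random_normal([8, 10, 1, 512])   # layer output
    out = common_layers.residual_fn(x, y, "layer", 0.1, epsilon=1e-6)

    # get_norm returns the bare normalizer function for a given norm_type.
    noam = common_layers.get_norm("noam")
    out_normed = noam(out, name="noam_norm")

    # Set layers: pool a set of vectors into a single context vector, then
    # transform each element conditioned on that context.
    elements = tf.random_normal([8, 20, 512])   # batch x set_size x dims
    mask = tf.ones([8, 20])                     # 1.0 for real elements, 0.0 for padding
    context = common_layers.global_pool_1d(elements, pooling_type="MAX", mask=mask)
    transformed = common_layers.linear_set_layer(512, elements, context=context)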