
Commit
Merge pull request #136 from rsepassi/push
v1.0.13
lukaszkaiser authored Jul 12, 2017
2 parents d827bb2 + 64defb7 commit 006ecb5
Showing 28 changed files with 893 additions and 755 deletions.
8 changes: 6 additions & 2 deletions setup.py
@@ -5,14 +5,18 @@

 setup(
     name='tensor2tensor',
-    version='1.0.12',
+    version='1.0.13',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',
     url='http://github.com/tensorflow/tensor2tensor',
     license='Apache 2.0',
     packages=find_packages(),
-    scripts=['tensor2tensor/bin/t2t-trainer', 'tensor2tensor/bin/t2t-datagen'],
+    scripts=[
+        'tensor2tensor/bin/t2t-trainer',
+        'tensor2tensor/bin/t2t-datagen',
+        'tensor2tensor/bin/t2t-make-tf-configs',
+    ],
     install_requires=[
         'numpy',
         'sympy',
44 changes: 35 additions & 9 deletions tensor2tensor/bin/t2t-datagen
@@ -37,8 +37,10 @@ from tensor2tensor.data_generators import algorithmic_math
 from tensor2tensor.data_generators import audio
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import image
+from tensor2tensor.data_generators import lm1b
 from tensor2tensor.data_generators import ptb
 from tensor2tensor.data_generators import snli
+from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing

@@ -138,6 +140,14 @@ _SUPPORTED_PROBLEM_GENERATORS = {
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
         lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
     ),
+    "lm1b_32k": (
+        lambda: lm1b.generator(FLAGS.tmp_dir, True),
+        lambda: lm1b.generator(FLAGS.tmp_dir, False)
+    ),
+    "wiki_32k": (
+        lambda: wiki.generator(FLAGS.tmp_dir, True),
+        1000
+    ),
     "image_mnist_tune": (
         lambda: image.mnist_generator(FLAGS.tmp_dir, True, 55000),
         lambda: image.mnist_generator(FLAGS.tmp_dir, True, 5000, 55000)),
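Note the asymmetry the new entries introduce: "lm1b_32k" registers the usual (train generator, dev generator) pair, while "wiki_32k" pairs its single training generator with the integer 1000. main() dispatches on that integer in the next hunk. A hypothetical registration following the same convention (problem name and module invented for illustration) would look like:

    "my_lm_64k": (
        lambda: my_lm.generator(FLAGS.tmp_dir, True),  # one generator feeds all splits
        500  # number of training shards; dev and test each get one extra shard
    ),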
@@ -335,17 +345,33 @@ def main(_):

   training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]

-  tf.logging.info("Generating training data for %s.", problem)
-  train_output_files = generator_utils.generate_files(
-      training_gen(), problem + UNSHUFFLED_SUFFIX + "-train",
-      FLAGS.data_dir, FLAGS.num_shards, FLAGS.max_cases)
-
-  tf.logging.info("Generating development data for %s.", problem)
-  dev_output_files = generator_utils.generate_files(
-      dev_gen(), problem + UNSHUFFLED_SUFFIX + "-dev", FLAGS.data_dir, 1)
+  if isinstance(dev_gen, int):
+    # The dev set and test sets are generated as extra shards using the
+    # training generator. The integer specifies the number of training
+    # shards. FLAGS.num_shards is ignored.
+    num_training_shards = dev_gen
+    tf.logging.info("Generating data for %s.", problem)
+    all_output_files = generator_utils.combined_data_filenames(
+        problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards)
+    generator_utils.generate_files(
+        training_gen(), all_output_files, FLAGS.max_cases)
+  else:
+    # usual case - train data and dev data are generated using separate
+    # generators.
+    tf.logging.info("Generating training data for %s.", problem)
+    train_output_files = generator_utils.train_data_filenames(
+        problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards)
+    generator_utils.generate_files(
+        training_gen(), train_output_files, FLAGS.max_cases)
+    tf.logging.info("Generating development data for %s.", problem)
+    dev_shards = 10 if "coco" in problem else 1
+    dev_output_files = generator_utils.dev_data_filenames(
+        problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
+    generator_utils.generate_files(dev_gen(), dev_output_files)
+    all_output_files = train_output_files + dev_output_files

   tf.logging.info("Shuffling data...")
-  for fname in train_output_files + dev_output_files:
+  for fname in all_output_files:
     records = generator_utils.read_records(fname)
     random.shuffle(records)
     out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
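Concretely, assuming UNSHUFFLED_SUFFIX is "-unshuffled" (its value lies outside this diff) and given the _data_filenames naming scheme added below in generator_utils.py, the combined path for wiki_32k writes files named like these before the shuffling pass strips the suffix:

    wiki_32k-unshuffled-train-00000-of-01000  (through -00999-of-01000)
    wiki_32k-unshuffled-dev-00000-of-00001
    wiki_32k-unshuffled-test-00000-of-00001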
tensor2tensor/bin/t2t-make-tf-configs
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
 # Copyright 2017 Google Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,7 +17,7 @@
 Usage:
-`make_tf_configs.py --workers="server1:1234" --ps="server3:2134,server4:2334"`
+`t2t-make-tf-configs --workers="server1:1234" --ps="server3:2134,server4:2334"`

 Outputs 1 line per job to stdout, first the workers, then the parameter servers.
 Each line has the TF_CONFIG, then a tab, then the command line flags for that

@@ -74,7 +75,8 @@ def main(_):
           "task": {
               "type": task_type,
               "index": idx
-          }
+          },
+          "environment": "cloud",
       })
       print("'%s'\t%s" % (tf_config, cmd_line_flags))
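For context, each printed TF_CONFIG now carries the environment field. A rough sketch of the resulting JSON for the first worker, assuming the standard TF_CONFIG cluster layout (the cluster-building code is outside this hunk, so the "cluster" value here is illustrative):

    {
      "cluster": {"ps": ["server3:2134", "server4:2334"], "worker": ["server1:1234"]},
      "task": {"type": "worker", "index": 0},
      "environment": "cloud"
    }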
61 changes: 37 additions & 24 deletions tensor2tensor/data_generators/generator_utils.py
@@ -18,6 +18,7 @@
 from __future__ import division
 from __future__ import print_function

+from collections import defaultdict
 import gzip
 import io
 import os
@@ -30,7 +31,7 @@
 import six.moves.urllib_request as urllib  # Imports urllib on Python2, urllib.request on Python3

 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.data_generators.tokenizer import Tokenizer
+from tensor2tensor.data_generators import tokenizer

 import tensorflow as tf
@@ -84,10 +85,34 @@ def generate_files_distributed(generator,
   return output_file


+def _data_filenames(output_name, output_dir, num_shards):
+  return [os.path.join(
+      output_dir, "%s-%.5d-of-%.5d" % (output_name, shard, num_shards))
+      for shard in xrange(num_shards)]
+
+
+def train_data_filenames(problem, output_dir, num_shards):
+  return _data_filenames(
+      problem + "-train", output_dir, num_shards)
+
+
+def dev_data_filenames(problem, output_dir, num_shards):
+  return _data_filenames(problem + "-dev", output_dir, num_shards)
+
+
+def test_data_filenames(problem, output_dir, num_shards):
+  return _data_filenames(problem + "-test", output_dir, num_shards)
+
+
+def combined_data_filenames(problem, output_dir, num_training_shards):
+  return (
+      train_data_filenames(problem, output_dir, num_training_shards) +
+      dev_data_filenames(problem, output_dir, 1) +
+      test_data_filenames(problem, output_dir, 1))
+
+
 def generate_files(generator,
-                   output_name,
-                   output_dir,
-                   num_shards=1,
+                   output_filenames,
                    max_cases=None):
   """Generate cases from a generator and save as TFRecord files.
@@ -96,27 +121,16 @@ def generate_files(generator,
   Args:
     generator: a generator yielding (string -> int/float/str list) dictionaries.
-    output_name: the file name prefix under which output will be saved.
-    output_dir: directory to save the output to.
-    num_shards: how many shards to use (defaults to 1).
+    output_filenames: List of output file paths.
     max_cases: maximum number of cases to get from the generator;
       if None (default), we use the generator until StopIteration is raised.
-
-  Returns:
-    List of output file paths.
   """
-  writers = []
-  output_files = []
-  for shard in xrange(num_shards):
-    output_filename = "%s-%.5d-of-%.5d" % (output_name, shard, num_shards)
-    output_file = os.path.join(output_dir, output_filename)
-    output_files.append(output_file)
-    writers.append(tf.python_io.TFRecordWriter(output_file))
-
+  num_shards = len(output_filenames)
+  writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
   counter, shard = 0, 0
   for case in generator:
     if counter > 0 and counter % 100000 == 0:
-      tf.logging.info("Generating case %d for %s." % (counter, output_name))
+      tf.logging.info("Generating case %d." % counter)
     counter += 1
     if max_cases and counter > max_cases:
       break
@@ -127,8 +141,6 @@ def generate_files(generator,
   for writer in writers:
     writer.close()

-  return output_files
-

 def download_report_hook(count, block_size, total_size):
   """Report hook for download progress.
@@ -235,7 +247,7 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):

   sources = sources or _DATA_FILE_URLS
   tf.logging.info("Generating vocab from: %s", str(sources))
-  tokenizer = Tokenizer()
+  token_counts = defaultdict(int)
   for source in sources:
     url = source[0]
     filename = os.path.basename(url)
@@ -269,10 +281,11 @@ def get_or_generate_vocab(tmp_dir, vocab_filename, vocab_size, sources=None):
             break
           line = line.strip()
           file_byte_budget -= len(line)
-          _ = tokenizer.encode(text_encoder.native_to_unicode(line))
+          for tok in tokenizer.encode(text_encoder.native_to_unicode(line)):
+            token_counts[tok] += 1

   vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
-      vocab_size, tokenizer.token_counts, 1, 1e3)
+      vocab_size, token_counts, 1, 1e3)
   vocab.store_to_file(vocab_filepath)
   return vocab
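The vocab path thus no longer depends on a stateful Tokenizer accumulating counts internally; counts live in a local defaultdict handed to the builder. A minimal sketch of the new flow (corpus lines invented; the trailing 1 and 1e3 arguments are copied from the diff above):

    token_counts = defaultdict(int)
    for line in ["hello world", "hello again"]:
      for tok in tokenizer.encode(text_encoder.native_to_unicode(line)):
        token_counts[tok] += 1
    vocab = text_encoder.SubwordTextEncoder.build_to_target_size(
        2**15, token_counts, 1, 1e3)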
7 changes: 4 additions & 3 deletions tensor2tensor/data_generators/generator_utils_test.py
@@ -41,11 +41,12 @@ def testGenerateFiles(self):
     def test_generator():
       yield {"inputs": [1], "target": [1]}

-    generator_utils.generate_files(test_generator(), tmp_file_name, tmp_dir)
-    self.assertTrue(tf.gfile.Exists(tmp_file_path + "-00000-of-00001"))
+    filenames = generator_utils.train_data_filenames(tmp_file_name, tmp_dir, 1)
+    generator_utils.generate_files(test_generator(), filenames)
+    self.assertTrue(tf.gfile.Exists(tmp_file_path + "-train-00000-of-00001"))

     # Clean up.
-    os.remove(tmp_file_path + "-00000-of-00001")
+    os.remove(tmp_file_path + "-train-00000-of-00001")
     os.remove(tmp_file_path)

   def testMaybeDownload(self):
81 changes: 81 additions & 0 deletions tensor2tensor/data_generators/inspect.py
@@ -0,0 +1,81 @@
+# Copyright 2017 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""Inspect a TFRecord file of tensorflow.Example and show tokenizations.
+
+python data_generators/inspect.py \
+    --logtostderr \
+    --print_targets \
+    --subword_text_encoder_filename=$DATA_DIR/tokens.vocab.8192 \
+    --input_filename=$DATA_DIR/wmt_ende_tokens_8k-train-00000-of-00100
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Dependency imports
+
+from tensor2tensor.data_generators import text_encoder
+
+import tensorflow as tf
+
+tf.app.flags.DEFINE_string("subword_text_encoder_filename", "",
+                           "SubwordTextEncoder vocabulary file")
+tf.app.flags.DEFINE_string("input_filename", "", "input filename")
+tf.app.flags.DEFINE_bool("print_inputs", False,
+                         "Print decoded inputs to stdout")
+tf.app.flags.DEFINE_bool("print_targets", False,
+                         "Print decoded targets to stdout")
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(_):
+  """Convert a file to examples."""
+  if FLAGS.subword_text_encoder_filename:
+    encoder = text_encoder.SubwordTextEncoder(
+        FLAGS.subword_text_encoder_filename)
+  else:
+    encoder = None
+  reader = tf.python_io.tf_record_iterator(FLAGS.input_filename)
+  total_sequences = 0
+  total_input_tokens = 0
+  total_target_tokens = 0
+  max_input_length = 0
+  max_target_length = 0
+  for record in reader:
+    x = tf.train.Example()
+    x.ParseFromString(record)
+    inputs = [int(i) for i in x.features.feature["inputs"].int64_list.value]
+    targets = [int(i) for i in x.features.feature["targets"].int64_list.value]
+    if FLAGS.print_inputs:
+      print(encoder.decode(inputs) if encoder else inputs)
+    if FLAGS.print_targets:
+      print(encoder.decode(targets) if encoder else targets)
+    total_input_tokens += len(inputs)
+    total_target_tokens += len(targets)
+    total_sequences += 1
+    max_input_length = max(max_input_length, len(inputs))
+    max_target_length = max(max_target_length, len(targets))
+
+  tf.logging.info("total_sequences: %d", total_sequences)
+  tf.logging.info("total_input_tokens: %d", total_input_tokens)
+  tf.logging.info("total_target_tokens: %d", total_target_tokens)
+  tf.logging.info("max_input_length: %d", max_input_length)
+  tf.logging.info("max_target_length: %d", max_target_length)
+
+
+if __name__ == "__main__":
+  tf.app.run()