Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #156 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.0.14
  • Loading branch information
lukaszkaiser authored Jul 14, 2017
2 parents 43bfb9f + c8b7000 commit 0c66117
Show file tree
Hide file tree
Showing 82 changed files with 1,272 additions and 745 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def transformer_my_very_own_hparams_set():

```python
# In ~/usr/t2t_usr/__init__.py
import my_registrations
from . import my_registrations
```

```
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.0.13',
version='1.0.14',
description='Tensor2Tensor',
author='Google Inc.',
author_email='no-reply@google.com',
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017 Google Inc.
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
144 changes: 55 additions & 89 deletions tensor2tensor/bin/t2t-datagen
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright 2017 Google Inc.
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,9 @@ takes 2 arguments - input_directory and mode (one of "train" or "dev") - and
yields for each training example a dictionary mapping string feature names to
lists of {string, int, float}. The generator will be run once for each mode.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random
import tempfile
Expand All @@ -34,6 +37,7 @@ import numpy as np

from tensor2tensor.data_generators import algorithmic
from tensor2tensor.data_generators import algorithmic_math
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
from tensor2tensor.data_generators import audio
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image
Expand All @@ -43,6 +47,7 @@ from tensor2tensor.data_generators import snli
from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing
from tensor2tensor.utils import registry

import tensorflow as tf

Expand All @@ -62,12 +67,6 @@ flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
# Mapping from problems that we can generate data for to their generators.
# pylint: disable=g-long-lambda
_SUPPORTED_PROBLEM_GENERATORS = {
"algorithmic_identity_binary40": (
lambda: algorithmic.identity_generator(2, 40, 100000),
lambda: algorithmic.identity_generator(2, 400, 10000)),
"algorithmic_identity_decimal40": (
lambda: algorithmic.identity_generator(10, 40, 100000),
lambda: algorithmic.identity_generator(10, 400, 10000)),
"algorithmic_shift_decimal40": (
lambda: algorithmic.shift_generator(20, 10, 40, 100000),
lambda: algorithmic.shift_generator(20, 10, 80, 10000)),
Expand Down Expand Up @@ -104,9 +103,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
"ice_parsing_tokens": (
lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
True, "ice", 2**13, 2**8),
True, "ice", 2**13, 2**8),
lambda: wmt.tabbed_parsing_token_generator(FLAGS.tmp_dir,
False, "ice", 2**13, 2**8)),
False, "ice", 2**13, 2**8)),
"ice_parsing_characters": (
lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, True),
lambda: wmt.tabbed_parsing_character_generator(FLAGS.tmp_dir, False)),
Expand All @@ -118,11 +117,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
2**14, 2**9),
lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
2**14, 2**9)),
"wsj_parsing_tokens_32k": (
lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, True,
2**15, 2**9),
lambda: wsj_parsing.parsing_token_generator(FLAGS.tmp_dir, False,
2**15, 2**9)),
"wmt_enfr_characters": (
lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, True),
lambda: wmt.enfr_character_generator(FLAGS.tmp_dir, False)),
Expand All @@ -140,14 +134,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
"wmt_ende_bpe32k": (
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, True),
lambda: wmt.ende_bpe_token_generator(FLAGS.tmp_dir, False)),
"wmt_ende_tokens_8k": (
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**13),
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**13)
),
"wmt_ende_tokens_32k": (
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, True, 2**15),
lambda: wmt.ende_wordpiece_token_generator(FLAGS.tmp_dir, False, 2**15)
),
"wmt_zhen_tokens_32k": (
lambda: wmt.zhen_wordpiece_token_generator(FLAGS.tmp_dir, True,
2**15, 2**15),
Expand All @@ -174,26 +160,9 @@ _SUPPORTED_PROBLEM_GENERATORS = {
"image_cifar10_test": (
lambda: image.cifar10_generator(FLAGS.tmp_dir, True, 50000),
lambda: image.cifar10_generator(FLAGS.tmp_dir, False, 10000)),
"image_mscoco_characters_tune": (
lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 70000),
lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 10000, 70000)),
"image_mscoco_characters_test": (
lambda: image.mscoco_generator(FLAGS.tmp_dir, True, 80000),
lambda: image.mscoco_generator(FLAGS.tmp_dir, False, 40000)),
"image_mscoco_tokens_8k_tune": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
70000,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13),
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
10000,
70000,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13)),
"image_mscoco_tokens_8k_test": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
Expand All @@ -207,20 +176,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
40000,
vocab_filename="tokens.vocab.%d" % 2**13,
vocab_size=2**13)),
"image_mscoco_tokens_32k_tune": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
70000,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15),
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
True,
10000,
70000,
vocab_filename="tokens.vocab.%d" % 2**15,
vocab_size=2**15)),
"image_mscoco_tokens_32k_test": (
lambda: image.mscoco_generator(
FLAGS.tmp_dir,
Expand Down Expand Up @@ -308,8 +263,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {

# pylint: enable=g-long-lambda

UNSHUFFLED_SUFFIX = "-unshuffled"


def set_random_seed():
"""Set the random seed from flag everywhere."""
Expand All @@ -322,13 +275,15 @@ def main(_):
tf.logging.set_verbosity(tf.logging.INFO)

# Calculate the list of problems to generate.
problems = list(sorted(_SUPPORTED_PROBLEM_GENERATORS))
problems = sorted(
list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())
if FLAGS.problem and FLAGS.problem[-1] == "*":
problems = [p for p in problems if p.startswith(FLAGS.problem[:-1])]
elif FLAGS.problem:
problems = [p for p in problems if p == FLAGS.problem]
else:
problems = []

# Remove TIMIT if paths are not given.
if not FLAGS.timit_paths:
problems = [p for p in problems if "timit" not in p]
Expand All @@ -340,7 +295,8 @@ def main(_):
problems = [p for p in problems if "ende_bpe" not in p]

if not problems:
problems_str = "\n * ".join(sorted(_SUPPORTED_PROBLEM_GENERATORS))
problems_str = "\n * ".join(
sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems()))
error_msg = ("You must specify one of the supported problems to "
"generate data for:\n * " + problems_str + "\n")
error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with "
Expand All @@ -357,40 +313,50 @@ def main(_):
for problem in problems:
set_random_seed()

training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]

if isinstance(dev_gen, int):
# The dev set and test sets are generated as extra shards using the
# training generator. The integer specifies the number of training
# shards. FLAGS.num_shards is ignored.
num_training_shards = dev_gen
tf.logging.info("Generating data for %s.", problem)
all_output_files = generator_utils.combined_data_filenames(
problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_training_shards)
generator_utils.generate_files(
training_gen(), all_output_files, FLAGS.max_cases)
if problem in _SUPPORTED_PROBLEM_GENERATORS:
generate_data_for_problem(problem)
else:
# usual case - train data and dev data are generated using separate
# generators.
tf.logging.info("Generating training data for %s.", problem)
train_output_files = generator_utils.train_data_filenames(
problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, FLAGS.num_shards)
generator_utils.generate_files(
training_gen(), train_output_files, FLAGS.max_cases)
tf.logging.info("Generating development data for %s.", problem)
dev_shards = 10 if "coco" in problem else 1
dev_output_files = generator_utils.dev_data_filenames(
problem + UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards)
generator_utils.generate_files(dev_gen(), dev_output_files)
all_output_files = train_output_files + dev_output_files
generate_data_for_registered_problem(problem)


def generate_data_for_problem(problem):
  """Generate, then shuffle, the data files for one table-driven problem.

  Args:
    problem: key into _SUPPORTED_PROBLEM_GENERATORS. The entry is a pair
      whose first element is a callable returning the training generator and
      whose second element is either a callable returning the dev generator
      or an int giving the number of training shards.
  """
  train_gen_fn, dev_spec = _SUPPORTED_PROBLEM_GENERATORS[problem]
  # Files are first written under a temporary "unshuffled" name; the final
  # shuffle step below consumes them.
  unshuffled_name = problem + generator_utils.UNSHUFFLED_SUFFIX

  if not isinstance(dev_spec, int):
    # Usual case: separate generators produce the train and dev data.
    tf.logging.info("Generating training data for %s.", problem)
    train_files = generator_utils.train_data_filenames(
        unshuffled_name, FLAGS.data_dir, FLAGS.num_shards)
    generator_utils.generate_files(
        train_gen_fn(), train_files, FLAGS.max_cases)

    tf.logging.info("Generating development data for %s.", problem)
    # Problems with "coco" in their name get 10 dev shards, all others 1.
    dev_shard_count = 10 if "coco" in problem else 1
    dev_files = generator_utils.dev_data_filenames(
        unshuffled_name, FLAGS.data_dir, dev_shard_count)
    generator_utils.generate_files(dev_spec(), dev_files)

    output_files = train_files + dev_files
  else:
    # The dev and test sets are produced as extra shards by the training
    # generator. The integer is the number of training shards, so
    # FLAGS.num_shards is ignored here.
    tf.logging.info("Generating data for %s.", problem)
    output_files = generator_utils.combined_data_filenames(
        unshuffled_name, FLAGS.data_dir, dev_spec)
    generator_utils.generate_files(
        train_gen_fn(), output_files, FLAGS.max_cases)

  tf.logging.info("Shuffling data...")
  generator_utils.shuffle_dataset(output_files)


tf.logging.info("Shuffling data...")
for fname in all_output_files:
records = generator_utils.read_records(fname)
random.shuffle(records)
out_fname = fname.replace(UNSHUFFLED_SUFFIX, "")
generator_utils.write_records(records, out_fname)
tf.gfile.Remove(fname)
def generate_data_for_registered_problem(problem_name):
  """Generate data for a problem that is looked up by name in the registry.

  Args:
    problem_name: name handed to registry.problem(); the returned object's
      generate_data() is called with FLAGS.data_dir and FLAGS.tmp_dir.
  """
  registry.problem(problem_name).generate_data(FLAGS.data_dir, FLAGS.tmp_dir)


if __name__ == "__main__":
Expand Down
27 changes: 14 additions & 13 deletions tensor2tensor/bin/t2t-make-tf-configs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright 2017 Google Inc.
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,13 +17,13 @@
Usage:
`t2t-make-tf-configs --workers="server1:1234" --ps="server3:2134,server4:2334"`
`t2t-make-tf-configs --masters="server1:1234" --ps="server3:2134,server4:2334"`
Outputs 1 line per job to stdout, first the workers, then the parameter servers.
Outputs 1 line per job to stdout, first the masters, then the parameter servers.
Each line has the TF_CONFIG, then a tab, then the command line flags for that
job.
If there is a single worker, workers will have the `--sync` flag.
If there is a single master, it will have the `--sync` flag.
"""
from __future__ import absolute_import
from __future__ import division
Expand All @@ -38,31 +38,32 @@ import tensorflow as tf
flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("workers", "", "Comma-separated list of worker addresses")
flags.DEFINE_string("masters", "", "Comma-separated list of master addresses")
flags.DEFINE_string("ps", "", "Comma-separated list of ps addresses")


def main(_):
if not (FLAGS.workers and FLAGS.ps):
raise ValueError("Must provide --workers and --ps")
if not (FLAGS.masters and FLAGS.ps):
raise ValueError("Must provide --masters and --ps")

workers = FLAGS.workers.split(",")
masters = FLAGS.masters.split(",")
ps = FLAGS.ps.split(",")

cluster = {"ps": ps, "worker": workers}
cluster = {"ps": ps, "master": masters}

for task_type, jobs in (("worker", workers), ("ps", ps)):
for task_type, jobs in (("master", masters), ("ps", ps)):
for idx, job in enumerate(jobs):
if task_type == "worker":
if task_type == "master":
cmd_line_flags = " ".join([
"--master=grpc://%s" % job,
"--ps_replicas=%d" % len(ps),
"--worker_replicas=%d" % len(workers),
"--worker_replicas=%d" % len(masters),
"--worker_gpu=1",
"--worker_id=%d" % idx,
"--worker_job='/job:master'",
"--ps_gpu=1",
"--schedule=train",
"--sync" if len(workers) == 1 else "",
"--sync" if len(masters) == 1 else "",
])
else:
cmd_line_flags = " ".join([
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/bin/t2t-trainer
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# Copyright 2017 Google Inc.
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017 Google Inc.
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit 0c66117

Please sign in to comment.