This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Merge pull request #361 from rsepassi/push
v1.2.5
lukaszkaiser authored Oct 16, 2017
2 parents 3a9c950 + fa9ad63 commit 3c5823f
Showing 30 changed files with 1,992 additions and 541 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -24,6 +24,6 @@ script:
- mkdir $T2T_TRAIN_DIR
- t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR
- t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR
- t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR
- t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10'
git:
depth: 3
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.2.4',
version='1.2.5',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/cnn_dailymail.py
@@ -74,7 +74,7 @@ def story_generator(tmp_dir):
for path in paths:
for story_file in tf.gfile.Glob(path + "*"):
story = u""
for line in tf.gfile.Open(story_file, 'rb'):
for line in tf.gfile.Open(story_file, "rb"):
line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
story += line
yield story
4 changes: 3 additions & 1 deletion tensor2tensor/data_generators/generator_utils.py
@@ -355,6 +355,8 @@ def generate():
for lang_file in source[1]:
tf.logging.info("Reading file: %s" % lang_file)
filepath = os.path.join(tmp_dir, lang_file)

# Extract from tar if needed.
if not tf.gfile.Exists(filepath):
read_type = "r:gz" if filename.endswith("tgz") else "r"
with tarfile.open(compressed_file, read_type) as corpus_tar:
@@ -411,7 +413,7 @@ def generate():
for line in source_file:
line = line.strip()
if line and "\t" in line:
parts = line.split("\t", maxsplit=1)
parts = line.split("\t", 1)
part = parts[index].strip()
yield part
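
Replacing line.split("\t", maxsplit=1) with the positional line.split("\t", 1) reads as a Python 2 compatibility fix, since str.split on Python 2 accepts no keyword arguments; the same change appears in wmt.py below. A small illustration with a made-up line:

line = "source sentence\ttarget sentence\twith an extra tab"

parts = line.split("\t", 1)  # positional maxsplit works on Python 2 and 3
assert parts == ["source sentence", "target sentence\twith an extra tab"]

# line.split("\t", maxsplit=1) would raise
# "TypeError: split() takes no keyword arguments" under Python 2.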

59 changes: 47 additions & 12 deletions tensor2tensor/data_generators/image.py
@@ -42,6 +42,12 @@
import tensorflow as tf


def resize_by_area(img, size):
"""image resize function used by quite a few image problems."""
return tf.to_int64(
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
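
For context, the new resize_by_area helper above replaces the per-problem resize closures that this commit deletes from ImageCeleba and Img2Img below. A minimal sketch of what it computes, assuming TF 1.x graph mode (the input tensor and its 32x32x3 shape are illustrative, not taken from the diff):

import tensorflow as tf

def resize_by_area(img, size):
  # Copied from the diff above: area-interpolated resize, cast to int64.
  return tf.to_int64(
      tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))

img = tf.zeros([32, 32, 3], dtype=tf.uint8)  # stand-in for a decoded image
small = resize_by_area(img, 8)               # shape [8, 8, 3], dtype int64

with tf.Session() as sess:
  print(sess.run(small).shape)               # (8, 8, 3)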


class ImageProblem(problem.Problem):

def example_reading_spec(self, label_key=None):
@@ -93,16 +99,12 @@ class ImageCeleba(ImageProblem):

def preprocess_example(self, example, unused_mode, unused_hparams):

def resize(img, size):
return tf.to_int64(
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))

inputs = example["inputs"]
# Remove boundaries in CelebA images. Remove 40 pixels each side
# vertically and 20 pixels each side horizontally.
inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40)
example["inputs"] = resize(inputs, 8)
example["targets"] = resize(inputs, 32)
example["inputs"] = resize_by_area(inputs, 8)
example["targets"] = resize_by_area(inputs, 32)
return example

def hparams(self, defaults, unused_model_hparams):
@@ -388,14 +390,10 @@ def dataset_filename(self):

def preprocess_example(self, example, unused_mode, unused_hparams):

def resize(img, size):
return tf.to_int64(
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))

inputs = example["inputs"]
# For Img2Img resize input and output images as desired.
example["inputs"] = resize(inputs, 8)
example["targets"] = resize(inputs, 32)
example["inputs"] = resize_by_area(inputs, 8)
example["targets"] = resize_by_area(inputs, 32)
return example

def hparams(self, defaults, unused_model_hparams):
@@ -654,6 +652,43 @@ def preprocess_example(self, example, mode, unused_hparams):
return example


@registry.register_problem
class ImageCifar10Plain8(ImageCifar10):
"""CIFAR-10 rescaled to 8x8 for output: Conditional image generation."""

def dataset_filename(self):
return "image_cifar10_plain" # Reuse CIFAR-10 plain data.

def preprocess_example(self, example, mode, unused_hparams):
example["inputs"] = resize_by_area(example["inputs"], 8)
return example


@registry.register_problem
class Img2imgCifar10(ImageCifar10):
"""CIFAR-10 rescaled to 8x8 for input and 32x32 for output."""

def dataset_filename(self):
return "image_cifar10_plain" # Reuse CIFAR-10 plain data.

def preprocess_example(self, example, unused_mode, unused_hparams):

inputs = example["inputs"]
# For Img2Img resize input and output images as desired.
example["inputs"] = resize_by_area(inputs, 8)
example["targets"] = resize_by_area(inputs, 32)
return example

def hparams(self, defaults, unused_model_hparams):
p = defaults
p.input_modality = {"inputs": ("image:identity_no_pad", None)}
p.target_modality = ("image:identity_no_pad", None)
p.batch_size_multiplier = 256
p.max_expected_batch_size_per_shard = 4
p.input_space_id = 1
p.target_space_id = 1
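
Both classes above carry the @registry.register_problem decorator, so after this commit they should be selectable by name from the t2t-datagen / t2t-trainer --problem(s) flags, the way .travis.yml in this diff selects $T2T_PROBLEM. A minimal lookup sketch; the snake_case names are an assumption based on the registry's usual CamelCase-to-snake_case naming, not something this diff states:

from tensor2tensor.data_generators import image  # importing runs the register_problem decorators
from tensor2tensor.utils import registry

# Hypothetical names derived from the class names ImageCifar10Plain8 and Img2imgCifar10.
plain8 = registry.problem("image_cifar10_plain8")
img2img = registry.problem("img2img_cifar10")
print(plain8.dataset_filename())  # expected: "image_cifar10_plain" (reuses the plain CIFAR-10 data)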


# URLs and filenames for MSCOCO data.
_MSCOCO_ROOT_URL = "http://msvocds.blob.core.windows.net/"
_MSCOCO_URLS = [
38 changes: 18 additions & 20 deletions tensor2tensor/data_generators/wmt.py
@@ -19,9 +19,7 @@
from __future__ import division
from __future__ import print_function

import glob
import os
import stat
import tarfile

# Dependency imports
@@ -115,7 +113,7 @@ def tabbed_generator(source_path, source_vocab, target_vocab, eos=None):
with tf.gfile.GFile(source_path, mode="r") as source_file:
for line in source_file:
if line and "\t" in line:
parts = line.split("\t", maxsplit=1)
parts = line.split("\t", 1)
source, target = parts[0].strip(), parts[1].strip()
source_ints = source_vocab.encode(source) + eos_list
target_ints = target_vocab.encode(target) + eos_list
@@ -267,8 +265,9 @@ def bi_vocabs_token_generator(source_path,
# English-Czech datasets
_ENCS_TRAIN_DATASETS = [
[
"https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar",
('tsv', 3, 2, 'data.plaintext-format/*train.gz')
("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/"
"11234/1-1458/data-plaintext-format.tar"),
("tsv", 3, 2, "data.plaintext-format/*train.gz")
],
[
"http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz", # pylint: disable=line-too-long
@@ -375,25 +374,22 @@ def _compile_data(tmp_dir, datasets, filename):
url = dataset[0]
compressed_filename = os.path.basename(url)
compressed_filepath = os.path.join(tmp_dir, compressed_filename)

generator_utils.maybe_download(tmp_dir, compressed_filename, url)

if dataset[1][0] == 'tsv':
if dataset[1][0] == "tsv":
_, src_column, trg_column, glob_pattern = dataset[1]
filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
if not filenames:
mode = "r:gz" if compressed_filepath.endswith("gz") else "r" # *.tgz *.tar.gz
# Capture *.tgz and *.tar.gz too.
mode = "r:gz" if compressed_filepath.endswith("gz") else "r"
with tarfile.open(compressed_filepath, mode) as corpus_tar:
corpus_tar.extractall(tmp_dir)
filenames = glob.glob(os.path.join(tmp_dir, glob_pattern))
filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
for tsv_filename in filenames:
if tsv_filename.endswith(".gz"):
new_filename = tsv_filename.strip(".gz")
try:
generator_utils.gunzip_file(tsv_filename, new_filename)
except PermissionError:
tsvdir = os.path.dirname(tsv_filename)
os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE)
generator_utils.gunzip_file(tsv_filename, new_filename)
generator_utils.gunzip_file(tsv_filename, new_filename)
tsv_filename = new_filename
with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file:
for line in tsv_file:
@@ -663,17 +659,19 @@ def vocab_name(self):
def generator(self, data_dir, tmp_dir, train):
datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS
tag = "train" if train else "dev"
data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
vocab_datasets = []
data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag)
# CzEng contains 100 gz files with tab-separated columns, so let's expect
# it is the first dataset in datasets and use the newly created *.lang{1,2} files instead.
# it is the first dataset in datasets and use the newly created *.lang{1,2}
# files for vocab construction.
if datasets[0][0].endswith("data-plaintext-format.tar"):
vocab_datasets.append([datasets[0][0],
["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]])
vocab_datasets.append([datasets[0][0], ["wmt_encs_tok_%s.lang1" % tag,
"wmt_encs_tok_%s.lang2" % tag]])
datasets = datasets[1:]
vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets]
symbolizer_vocab = generator_utils.get_or_generate_vocab(
data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets)
data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
vocab_datasets)
return token_generator(data_path + ".lang1", data_path + ".lang2",
symbolizer_vocab, EOS)
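
Two smaller points about the _compile_data changes earlier in this file: glob.glob is swapped for tf.gfile.Glob, which does the same pattern matching but goes through TensorFlow's file API, and the PermissionError/chmod fallback around gunzip_file is dropped, presumably because PermissionError is a Python 3-only builtin whose except clause could itself fail under Python 2. A minimal sketch of the glob call, with an illustrative tmp_dir:

import os
import tensorflow as tf

tmp_dir = "/tmp/t2t_datagen"                      # illustrative path
glob_pattern = "data.plaintext-format/*train.gz"  # pattern from the CzEng entry above
filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern))
print(filenames)  # matching paths, or [] if the tar has not been extracted yet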

(Diffs for the remaining changed files are not shown.)
