This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Merge pull request #396 from rsepassi/push
v1.2.7
lukaszkaiser authored Nov 3, 2017
2 parents 9e7d03f + f564d6c commit 097ea5f
Showing 24 changed files with 1,430 additions and 469 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.2.6',
version='1.2.7',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
13 changes: 8 additions & 5 deletions tensor2tensor/data_generators/generator_utils.py
@@ -21,9 +21,9 @@

from collections import defaultdict
import gzip
import io
import os
import random
import stat
import tarfile

# Dependency imports
@@ -190,8 +190,8 @@ def maybe_download(directory, filename, url):
print()
tf.gfile.Rename(inprogress_filepath, filepath)
statinfo = os.stat(filepath)
tf.logging.info("Successfully downloaded %s, %s bytes." % (filename,
statinfo.st_size))
tf.logging.info("Successfully downloaded %s, %s bytes." %
(filename, statinfo.st_size))
else:
tf.logging.info("Not downloading, file already found: %s" % filepath)
return filepath
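For reference, a hedged usage sketch of the helper whose log formatting changed above; the directory path is illustrative and the URL is taken from the En-Fr dataset list later in this commit:

from tensor2tensor.data_generators import generator_utils

# Returns the local path, downloading the file first if it is not already
# present in the given directory.
filepath = generator_utils.maybe_download(
    directory="/tmp/t2t_datagen",  # illustrative path
    filename="baseline-1M-enfr.tgz",
    url="https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz")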
@@ -243,7 +243,7 @@ def maybe_download_from_drive(directory, filename, url):
print()
statinfo = os.stat(filepath)
tf.logging.info("Successfully downloaded %s, %s bytes." % (filename,
statinfo.st_size))
statinfo.st_size))
return filepath


@@ -258,8 +258,11 @@ def gunzip_file(gz_path, new_path):
tf.logging.info("File %s already exists, skipping unpacking" % new_path)
return
tf.logging.info("Unpacking %s to %s" % (gz_path, new_path))
# We may be unpacking into a newly created directory, add write mode.
mode = stat.S_IRWXU or stat.S_IXGRP or stat.S_IRGRP or stat.S_IROTH
os.chmod(os.path.dirname(new_path), mode)
with gzip.open(gz_path, "rb") as gz_file:
with io.open(new_path, "wb") as new_file:
with tf.gfile.GFile(new_path, mode="wb") as new_file:
for line in gz_file:
new_file.write(line)
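One caveat about the hunk above: Python's "or" is a short-circuiting boolean operator, so the mode expression evaluates to stat.S_IRWXU (0o700) alone and the group/other bits are never applied; combining permission flags normally uses bitwise OR. A minimal sketch of the difference:

import stat

# "or" returns its first truthy operand, so only the owner bits survive.
boolean_mode = stat.S_IRWXU or stat.S_IXGRP or stat.S_IRGRP or stat.S_IROTH
# Bitwise OR actually combines the flags: owner rwx, group r-x, others r--.
bitwise_mode = stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP | stat.S_IROTH

assert boolean_mode == stat.S_IRWXU  # 0o700
assert bitwise_mode == 0o754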

56 changes: 56 additions & 0 deletions tensor2tensor/data_generators/image.py
@@ -24,6 +24,7 @@
import json
import os
import random
import struct
import tarfile
import zipfile

@@ -925,3 +926,58 @@ class ImageMsCocoTokens32k(ImageMsCocoTokens8k):
  @property
  def targeted_vocab_size(self):
    return 2**15  # 32768


@registry.register_problem
class OcrTest(Image2TextProblem):
  """OCR test problem."""

  @property
  def is_small(self):
    return True

  @property
  def is_character_level(self):
    return True

  @property
  def target_space_id(self):
    return problem.SpaceID.EN_CHR

  @property
  def train_shards(self):
    return 1

  @property
  def dev_shards(self):
    return 1

  def preprocess_example(self, example, mode, _):
    # Resize from usual size ~1350x60 to 90x4 in this test.
    img = example["inputs"]
    example["inputs"] = tf.to_int64(
        tf.image.resize_images(img, [90, 4], tf.image.ResizeMethod.AREA))
    return example

  def generator(self, data_dir, tmp_dir, is_training):
    # In this test problem, we assume that the data is in tmp_dir/ocr/ in
    # files named 0.png, 0.txt, 1.png, 1.txt and so on until num_examples.
    num_examples = 2
    ocr_dir = os.path.join(tmp_dir, "ocr/")
    tf.logging.info("Looking for OCR data in %s." % ocr_dir)
    for i in xrange(num_examples):
      image_filepath = os.path.join(ocr_dir, "%d.png" % i)
      text_filepath = os.path.join(ocr_dir, "%d.txt" % i)
      with tf.gfile.Open(text_filepath, "rb") as f:
        label = f.read()
      with tf.gfile.Open(image_filepath, "rb") as f:
        encoded_image_data = f.read()
      # In PNG files width and height are stored in these bytes.
      width, height = struct.unpack(">ii", encoded_image_data[16:24])
      yield {
          "image/encoded": [encoded_image_data],
          "image/format": ["png"],
          "image/class/label": label.strip(),
          "image/height": [height],
          "image/width": [width]
      }
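Some background on the struct.unpack(">ii", encoded_image_data[16:24]) call above: a PNG file begins with an 8-byte signature followed by the IHDR chunk header (a 4-byte length plus the 4-byte type "IHDR"), so the 4-byte big-endian width and height fields start at byte offset 16. A small standalone helper, not part of this commit, that performs the same parse (the commit's signed ">ii" format gives the same result for realistic image sizes):

import struct


def png_dimensions(png_bytes):
  """Returns (width, height) read from the IHDR chunk of a PNG byte string."""
  assert png_bytes[:8] == b"\x89PNG\r\n\x1a\n", "not a PNG file"
  # The IHDR payload starts at offset 16: 4-byte width, then 4-byte height,
  # both unsigned big-endian integers.
  return struct.unpack(">II", png_bytes[16:24])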
106 changes: 72 additions & 34 deletions tensor2tensor/data_generators/translate_enfr.py
@@ -34,50 +34,54 @@
# End-of-sentence marker.
EOS = text_encoder.EOS_ID

_ENFR_TRAIN_DATASETS = [
_ENFR_TRAIN_SMALL_DATA = [
[
"https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz",
("baseline-1M-enfr/baseline-1M_train.en",
"baseline-1M-enfr/baseline-1M_train.fr")
],
# [
# "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
# ("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
# ],
# [
# "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
# ("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
# ],
# [
# "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
# ("training/news-commentary-v9.fr-en.en",
# "training/news-commentary-v9.fr-en.fr")
# ],
# [
# "http://www.statmt.org/wmt10/training-giga-fren.tar",
# ("giga-fren.release2.fixed.en.gz",
# "giga-fren.release2.fixed.fr.gz")
# ],
# [
# "http://www.statmt.org/wmt13/training-parallel-un.tgz",
# ("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
# ],
]
_ENFR_TEST_DATASETS = [
_ENFR_TEST_SMALL_DATA = [
[
"https://s3.amazonaws.com/opennmt-trainingdata/baseline-1M-enfr.tgz",
("baseline-1M-enfr/baseline-1M_valid.en",
"baseline-1M-enfr/baseline-1M_valid.fr")
],
# [
# "http://data.statmt.org/wmt17/translation-task/dev.tgz",
# ("dev/newstest2013.en", "dev/newstest2013.fr")
# ],
]
_ENFR_TRAIN_LARGE_DATA = [
[
"http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
("commoncrawl.fr-en.en", "commoncrawl.fr-en.fr")
],
[
"http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
("training/europarl-v7.fr-en.en", "training/europarl-v7.fr-en.fr")
],
[
"http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
("training/news-commentary-v9.fr-en.en",
"training/news-commentary-v9.fr-en.fr")
],
[
"http://www.statmt.org/wmt10/training-giga-fren.tar",
("giga-fren.release2.fixed.en.gz",
"giga-fren.release2.fixed.fr.gz")
],
[
"http://www.statmt.org/wmt13/training-parallel-un.tgz",
("un/undoc.2000.fr-en.en", "un/undoc.2000.fr-en.fr")
],
]
_ENFR_TEST_LARGE_DATA = [
[
"http://data.statmt.org/wmt17/translation-task/dev.tgz",
("dev/newstest2013.en", "dev/newstest2013.fr")
],
]
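Each entry above pairs an archive URL with the (source, target) file paths expected inside it once unpacked; translate.compile_data consumes lists in exactly this shape. A hedged sketch of walking one of the lists, assuming the module-level names defined above:

for url, (en_path, fr_path) in _ENFR_TRAIN_LARGE_DATA:
  print("archive %s provides %s / %s" % (url, en_path, fr_path))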


@registry.register_problem
class TranslateEnfrWmt8k(translate.TranslateProblem):
class TranslateEnfrWmtSmall8k(translate.TranslateProblem):
"""Problem spec for WMT En-Fr translation."""

@property
@@ -88,11 +92,18 @@ def targeted_vocab_size(self):
def vocab_name(self):
return "vocab.enfr"

@property
def use_small_dataset(self):
return True

def generator(self, data_dir, tmp_dir, train):
symbolizer_vocab = generator_utils.get_or_generate_vocab(
data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size,
_ENFR_TRAIN_DATASETS)
datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
_ENFR_TRAIN_SMALL_DATA)
if self.use_small_dataset:
datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
else:
datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
tag = "train" if train else "dev"
data_path = translate.compile_data(tmp_dir, datasets,
"wmt_enfr_tok_%s" % tag)
@@ -109,15 +120,31 @@ def target_space_id(self):


@registry.register_problem
class TranslateEnfrWmt32k(TranslateEnfrWmt8k):
class TranslateEnfrWmtSmall32k(TranslateEnfrWmtSmall8k):

@property
def targeted_vocab_size(self):
return 2**15 # 32768


@registry.register_problem
class TranslateEnfrWmtCharacters(translate.TranslateProblem):
class TranslateEnfrWmt8k(TranslateEnfrWmtSmall8k):

@property
def use_small_dataset(self):
return False


@registry.register_problem
class TranslateEnfrWmt32k(TranslateEnfrWmtSmall32k):

@property
def use_small_dataset(self):
return False


@registry.register_problem
class TranslateEnfrWmtSmallCharacters(translate.TranslateProblem):
"""Problem spec for WMT En-Fr translation."""

@property
@@ -130,7 +157,10 @@ def vocab_name(self):

def generator(self, data_dir, tmp_dir, train):
character_vocab = text_encoder.ByteTextEncoder()
datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS
if self.use_small_dataset:
datasets = _ENFR_TRAIN_SMALL_DATA if train else _ENFR_TEST_SMALL_DATA
else:
datasets = _ENFR_TRAIN_LARGE_DATA if train else _ENFR_TEST_LARGE_DATA
tag = "train" if train else "dev"
data_path = translate.compile_data(tmp_dir, datasets,
"wmt_enfr_chr_%s" % tag)
@@ -144,3 +174,11 @@ def input_space_id(self):
@property
def target_space_id(self):
return problem.SpaceID.FR_CHR


@registry.register_problem
class TranslateEnfrWmtCharacters(TranslateEnfrWmtSmallCharacters):

@property
def use_small_dataset(self):
return False
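The classes above switch between the small OpenNMT corpus and the full WMT corpora through the single use_small_dataset property. As a hedged illustration of the pattern (this class is hypothetical and not part of the commit), a user-defined variant only needs to override the properties it cares about:

from tensor2tensor.data_generators import translate_enfr
from tensor2tensor.utils import registry


@registry.register_problem
class TranslateEnfrWmtSmall16k(translate_enfr.TranslateEnfrWmtSmall8k):
  """En-Fr on the small corpus with a 16k subword vocabulary."""

  @property
  def targeted_vocab_size(self):
    return 2**14  # 16384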
8 changes: 6 additions & 2 deletions tensor2tensor/data_generators/translate_enzh.py
@@ -35,9 +35,13 @@

# End-of-sentence marker.
EOS = text_encoder.EOS_ID

# End-of-sentence marker.
EOS = text_encoder.EOS_ID

# This is far from being the real WMT17 task - only toyset here
# you need to register to get UN data and CWT data
# also by convention this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
# you need to register to get UN data and CWT data. Also, by convention,
# this is EN to ZH - use translate_enzh_wmt8k_rev for ZH to EN task
_ENZH_TRAIN_DATASETS = [[("http://data.statmt.org/wmt17/translation-task/"
"training-parallel-nc-v12.tgz"),
("training/news-commentary-v12.zh-en.en",
9 changes: 7 additions & 2 deletions tensor2tensor/layers/common_attention.py
@@ -2958,15 +2958,20 @@ def pad_and_reshape(x):

@expert_utils.add_var_scope()
def multihead_self_attention_reduced(
x, factor, nonlinearity, reduction_type, multihead_params):
x,
factor,
multihead_params,
nonlinearity="none",
reduction_type="conv",
):
"""Reduce the length dimension by compressing with conv.
Args:
x (tf.Tensor): float32 of shape [batch, length, depth]
factor (int): compression factor for the memory sequence
multihead_params (dict): parameters for multihead attention
nonlinearity (str): Add some non-linearity after the memory block
reduction_type (str): type of compression
multihead_params (dict): parameters for multihead attention
Returns:
(tf.Tensor): float32 of shape [batch, length, depth]
3 changes: 3 additions & 0 deletions tensor2tensor/layers/common_hparams.py
@@ -116,12 +116,15 @@ def basic_params1():
# If set to True, drop sequences longer than max_length during eval.
# This affects the validity of the evaluation metrics.
eval_drop_long_sequences=int(False),
# TODO(lukaszkaiser): these parameters should probably be set elsewhere.
# in SymbolModality, share the output embeddings and the softmax
# variables.
# You can also share the input embeddings with the output embeddings
# by using a problem_hparams that uses the same modality object for
# the input_modality and target_modality.
shared_embedding_and_softmax_weights=int(False),
# In SymbolModality, skip the top layer, assume we're providing logits.
symbol_modality_skip_top=int(False),
# For each feature for which you want to override the default input
# modality, add an entry to this semicolon-separated string. Entries are
# formatted "feature_name:modality_type:modality_name", e.g.
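As a hedged sketch of how the two flags documented above would be flipped in practice (the registered hparams name is hypothetical):

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry


@registry.register_hparams
def basic_params_shared_softmax():
  hparams = common_hparams.basic_params1()
  # Share the output embedding and softmax variables in SymbolModality.
  hparams.shared_embedding_and_softmax_weights = int(True)
  # Keep the top (logit) layer; flip to int(True) only when the model body
  # already produces logits.
  hparams.symbol_modality_skip_top = int(False)
  return hparams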