Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Separate CLI t2t_decoder
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 166920562
  • Loading branch information
Ryan Sepassi committed Aug 29, 2017
1 parent a3be70a commit f715f85
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 34 deletions.
19 changes: 14 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,26 @@ You can chat with us and other users on
with T2T announcements.

Here is a one-command version that installs tensor2tensor, downloads the data,
trains an English-German translation model, and lets you use it interactively:
trains an English-German translation model, and evaluates it:
```
pip install tensor2tensor && t2t-trainer \
--generate_data \
--data_dir=~/t2t_data \
--problems=translate_ende_wmt32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
--output_dir=~/t2t_train/base \
--output_dir=~/t2t_train/base
```

You can decode from the model interactively:

```
t2t-decoder \
--data_dir=~/t2t_data \
--problems=translate_ende_wmt32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
--output_dir=~/t2t_train/base
--decode_interactive
```

Expand Down Expand Up @@ -106,14 +117,12 @@ echo "Goodbye world" >> $DECODE_FILE
BEAM_SIZE=4
ALPHA=0.6
t2t-trainer \
t2t-decoder \
--data_dir=$DATA_DIR \
--problems=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--train_steps=0 \
--eval_steps=0 \
--decode_beam_size=$BEAM_SIZE \
--decode_alpha=$ALPHA \
--decode_from_file=$DECODE_FILE
Expand Down
15 changes: 13 additions & 2 deletions docs/walkthrough.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,26 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

Here is a one-command version that installs tensor2tensor, downloads the data,
trains an English-German translation model, and lets you use it interactively:
trains an English-German translation model, and evaluates it:
```
pip install tensor2tensor && t2t-trainer \
--generate_data \
--data_dir=~/t2t_data \
--problems=translate_ende_wmt32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
--output_dir=~/t2t_train/base \
--output_dir=~/t2t_train/base
```

You can decode from the model interactively:

```
t2t-decoder \
--data_dir=~/t2t_data \
--problems=translate_ende_wmt32k \
--model=transformer \
--hparams_set=transformer_base_single_gpu \
--output_dir=~/t2t_train/base
--decode_interactive
```

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
scripts=[
'tensor2tensor/bin/t2t-trainer',
'tensor2tensor/bin/t2t-datagen',
'tensor2tensor/bin/t2t-decoder',
'tensor2tensor/bin/t2t-make-tf-configs',
],
install_requires=[
Expand Down
90 changes: 90 additions & 0 deletions tensor2tensor/bin/t2t-decoder
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Decode from trained T2T models.
This binary performs inference using the Estimator API.
Example usage to decode from dataset:
t2t-decoder \
--data_dir ~/data \
--problems=algorithmic_identity_binary40 \
--model=transformer
--hparams_set=transformer_base
Set FLAGS.decode_interactive or FLAGS.decode_from_file for alternative decode
sources.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from tensor2tensor.utils import decoding
from tensor2tensor.utils import trainer_utils
from tensor2tensor.utils import usr_dir

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("t2t_usr_dir", "",
"Path to a Python module that will be imported. The "
"__init__.py file should include the necessary imports. "
"The imported files should contain registrations, "
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-decoder.")


def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
trainer_utils.log_registry()
trainer_utils.validate_flags()
data_dir = os.path.expanduser(FLAGS.data_dir)
output_dir = os.path.expanduser(FLAGS.output_dir)

hparams = trainer_utils.create_hparams(
FLAGS.hparams_set, FLAGS.problems, data_dir, passed_hparams=FLAGS.hparams)
estimator, _ = trainer_utils.create_experiment_components(
hparams=hparams,
output_dir=output_dir,
data_dir=data_dir,
model_name=FLAGS.model)

if FLAGS.decode_interactive:
decoding.decode_interactively(estimator)
elif FLAGS.decode_from_file:
decoding.decode_from_file(estimator, FLAGS.decode_from_file)
else:
decoding.decode_from_dataset(
estimator,
FLAGS.problems.split("-"),
return_beams=FLAGS.decode_return_beams,
beam_size=FLAGS.decode_beam_size,
max_predictions=FLAGS.decode_num_samples,
decode_to_file=FLAGS.decode_to_file,
save_images=FLAGS.decode_save_images,
identity_output=FLAGS.identity_output)


if __name__ == "__main__":
tf.app.run()
32 changes: 5 additions & 27 deletions tensor2tensor/utils/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from tensor2tensor.data_generators import problem_hparams
from tensor2tensor.models import models # pylint: disable=unused-import
from tensor2tensor.utils import data_reader
from tensor2tensor.utils import decoding
from tensor2tensor.utils import devices
from tensor2tensor.utils import input_fn_builder
from tensor2tensor.utils import metrics
Expand Down Expand Up @@ -101,16 +100,13 @@
flags.DEFINE_string("ps_job", "/job:ps", "name of ps job")
flags.DEFINE_integer("ps_replicas", 0, "How many ps replicas.")

# Decode flags
# Set one of {decode_from_dataset, decode_interactive, decode_from_file} to
# decode.
flags.DEFINE_bool("decode_from_dataset", False, "Decode from dataset on disk.")
flags.DEFINE_bool("decode_use_last_position_only", False,
"In inference, use last position only for speedup.")
# Decoding flags
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
flags.DEFINE_bool("decode_interactive", False,
"Interactive local inference mode.")
flags.DEFINE_bool("decode_use_last_position_only", False,
"In inference, use last position only for speedup.")
flags.DEFINE_bool("decode_save_images", False, "Save inference input images.")
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
flags.DEFINE_string("decode_to_file", None, "Path to inference output file")
flags.DEFINE_integer("decode_shards", 1, "How many shards to decode.")
flags.DEFINE_integer("decode_problem_id", 0, "Which problem to decode.")
Expand All @@ -128,7 +124,7 @@
"Maximum number of ids in input. Or <= 0 for no max.")
flags.DEFINE_bool("identity_output", False, "To print the output as identity")
flags.DEFINE_integer("decode_num_samples", -1,
"Number of samples to decode. Currently used in"
"Number of samples to decode. Currently used in "
"decode_from_dataset. Use -1 for all.")


Expand Down Expand Up @@ -303,7 +299,6 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule):
if exp.train_steps > 0 or exp.eval_steps > 0:
tf.logging.info("Performing local training and evaluation.")
exp.train_and_evaluate()
decode(exp.estimator)
else:
# Perform distributed training/evaluation.
learn_runner.run(
Expand Down Expand Up @@ -350,20 +345,3 @@ def session_config():

def get_data_filepatterns(data_dir, mode):
return data_reader.get_data_filepatterns(FLAGS.problems, data_dir, mode)


def decode(estimator):
if FLAGS.decode_interactive:
decoding.decode_interactively(estimator)
elif FLAGS.decode_from_file is not None and FLAGS.decode_from_file is not "":
decoding.decode_from_file(estimator, FLAGS.decode_from_file)
elif FLAGS.decode_from_dataset:
decoding.decode_from_dataset(
estimator,
FLAGS.problems.split("-"),
return_beams=FLAGS.decode_return_beams,
beam_size=FLAGS.decode_beam_size,
max_predictions=FLAGS.decode_num_samples,
decode_to_file=FLAGS.decode_to_file,
save_images=FLAGS.decode_save_images,
identity_output=FLAGS.identity_output)

0 comments on commit f715f85

Please sign in to comment.