diff --git a/setup.py b/setup.py index 94f44c137..8870809ae 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.3.1', + version='1.3.2', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/data_generators/translate_enfr.py b/tensor2tensor/data_generators/translate_enfr.py index b09fca90e..921834000 100644 --- a/tensor2tensor/data_generators/translate_enfr.py +++ b/tensor2tensor/data_generators/translate_enfr.py @@ -143,6 +143,14 @@ def use_small_dataset(self): return False +@registry.register_problem +class TranslateEnfrWmt32kPacked(TranslateEnfrWmt32k): + + @property + def packed_length(self): + return 256 + + @registry.register_problem class TranslateEnfrWmtSmallCharacters(translate.TranslateProblem): """Problem spec for WMT En-Fr translation.""" diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 23cf074af..304cb49be 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -1182,7 +1182,8 @@ def dot_product_attention(q, dropout_rate=0.0, image_shapes=None, name=None, - make_image_summary=True): + make_image_summary=True, + save_weights_to=None): """dot-product attention. Args: @@ -1195,17 +1196,22 @@ def dot_product_attention(q, see comments for attention_image_summary() name: an optional string make_image_summary: True if you want an image summary. + save_weights_to: an optional dictionary to capture attention weights + for visualization; the weights tensor will be appended there under + a string key created from the variable scope (including name). Returns: A Tensor. """ with tf.variable_scope( - name, default_name="dot_product_attention", values=[q, k, v]): + name, default_name="dot_product_attention", values=[q, k, v]) as scope: # [batch, num_heads, query_length, memory_length] logits = tf.matmul(q, k, transpose_b=True) if bias is not None: logits += bias weights = tf.nn.softmax(logits, name="attention_weights") + if save_weights_to is not None: + save_weights_to[scope.name] = weights # dropping out the attention links for each of the heads weights = tf.nn.dropout(weights, 1.0 - dropout_rate) if (not tf.get_variable_scope().reuse and @@ -2245,6 +2251,7 @@ def multihead_attention(query_antecedent, gap_size=0, num_memory_blocks=2, name=None, + save_weights_to=None, **kwargs): """Multihead scaled-dot-product attention with input/output transformations. @@ -2284,7 +2291,10 @@ def multihead_attention(query_antecedent, memory blocks. num_memory_blocks: Integer option to indicate how many memory blocks to look at. - name: an optional string + name: an optional string. + save_weights_to: an optional dictionary to capture attention weights + for visualization; the weights tensor will be appended there under + a string key created from the variable scope (including name). 
**kwargs (dict): Parameters for the attention function Caching: @@ -2345,7 +2355,8 @@ def multihead_attention(query_antecedent, if isinstance(x, tuple): x, additional_returned_value = x # Unpack elif attention_type == "dot_product": - x = dot_product_attention(q, k, v, bias, dropout_rate, image_shapes) + x = dot_product_attention(q, k, v, bias, dropout_rate, image_shapes, + save_weights_to=save_weights_to) elif attention_type == "dot_product_relative": x = dot_product_attention_relative(q, k, v, bias, max_relative_position, dropout_rate, image_shapes) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index ffe5fcb52..8fd3edd21 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -45,6 +45,10 @@ class Transformer(t2t_model.T2TModel): """Attention net. See file docstring.""" + def __init__(self, *args, **kwargs): + super(Transformer, self).__init__(*args, **kwargs) + self.attention_weights = dict() # For visualizing attention heads. + def encode(self, inputs, target_space, hparams, features=None): """Encode transformer inputs. @@ -73,7 +77,8 @@ def encode(self, inputs, target_space, hparams, features=None): encoder_output = transformer_encoder( encoder_input, self_attention_bias, - hparams, nonpadding=_features_to_nonpadding(features, "inputs")) + hparams, nonpadding=_features_to_nonpadding(features, "inputs"), + save_weights_to=self.attention_weights) return encoder_output, encoder_decoder_attention_bias @@ -114,7 +119,8 @@ def decode(self, encoder_decoder_attention_bias, hparams, cache=cache, - nonpadding=nonpadding) + nonpadding=nonpadding, + save_weights_to=self.attention_weights) if hparams.use_tpu and hparams.mode == tf.estimator.ModeKeys.TRAIN: # TPU does not react kindly to extra dimensions. @@ -507,7 +513,8 @@ def transformer_encoder(encoder_input, encoder_self_attention_bias, hparams, name="encoder", - nonpadding=None): + nonpadding=None, + save_weights_to=None): """A stack of transformer layers. Args: @@ -522,6 +529,9 @@ def transformer_encoder(encoder_input, encoder_self_attention_bias. The knowledge about padding is used for pad_remover(efficiency) and to mask out padding in convoltutional layers. + save_weights_to: an optional dictionary to capture attention weights + for visualization; the weights tensor will be appended there under + a string key created from the variable scope (including name). Returns: y: a Tensors @@ -551,6 +561,7 @@ def transformer_encoder(encoder_input, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, + save_weights_to=save_weights_to, max_relative_position=hparams.max_relative_position) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): @@ -571,7 +582,8 @@ def transformer_decoder(decoder_input, hparams, cache=None, name="decoder", - nonpadding=None): + nonpadding=None, + save_weights_to=None): """A stack of transformer layers. Args: @@ -590,6 +602,9 @@ def transformer_decoder(decoder_input, to mask out padding in convoltutional layers. We generally only need this mask for "packed" datasets, because for ordinary datasets, no padding is ever followed by nonpadding. + save_weights_to: an optional dictionary to capture attention weights + for visualization; the weights tensor will be appended there under + a string key created from the variable scope (including name). 
Returns: y: a Tensors @@ -612,6 +627,7 @@ def transformer_decoder(decoder_input, hparams.num_heads, hparams.attention_dropout, attention_type=hparams.self_attention_type, + save_weights_to=save_weights_to, max_relative_position=hparams.max_relative_position, cache=layer_cache) x = common_layers.layer_postprocess(x, y, hparams) @@ -624,7 +640,8 @@ def transformer_decoder(decoder_input, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, - hparams.attention_dropout) + hparams.attention_dropout, + save_weights_to=save_weights_to) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( diff --git a/tensor2tensor/notebooks/hello_t2t.ipynb b/tensor2tensor/notebooks/hello_t2t.ipynb index fd08175c6..797b0b98b 100644 --- a/tensor2tensor/notebooks/hello_t2t.ipynb +++ b/tensor2tensor/notebooks/hello_t2t.ipynb @@ -55,7 +55,8 @@ }, "source": [ "# Install deps\n", - "!pip install -q \"tensor2tensor-dev==1.3.1.dev7\" tf-nightly" + "# We're using some new features from tensorflow so we install tf-nightly\n", + "!pip install -q tensor2tensor tf-nightly" ], "cell_type": "code", "execution_count": 0, @@ -77,8 +78,10 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import os\n", + "import collections\n", "\n", "from tensor2tensor import problems\n", + "from tensor2tensor.layers import common_layers\n", "from tensor2tensor.utils import t2t_model\n", "from tensor2tensor.utils import trainer_utils\n", "from tensor2tensor.utils import registry\n", @@ -109,17 +112,17 @@ }, { "metadata": { - "id": "gXL7_bVH49Kl", + "id": "0a69r1KDiZDe", "colab_type": "text" }, "source": [ - "# Translate from English to German with a pre-trained model" + "# Download MNIST and inspect it" ], "cell_type": "markdown" }, { "metadata": { - "id": "Q2CYCYjZTlZs", + "id": "RYDMO4zArgkz", "colab_type": "code", "colab": { "autoexec": { @@ -128,18 +131,18 @@ }, "output_extras": [ { - "item_id": 2 + "item_id": 1 } ], "base_uri": "https://localhost:8080/", - "height": 68 + "height": 1224 }, - "outputId": "b13d53a3-feba-4d74-fc1e-951bef99ecb0", + "outputId": "2edd5f47-1ebb-4d18-e57c-741c966afc10", "executionInfo": { "status": "ok", - "timestamp": 1512165746671, + "timestamp": 1512173990900, "user_tz": 480, - "elapsed": 2799, + "elapsed": 272, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -148,128 +151,162 @@ } }, "source": [ - "# Translation\n", - "ende_problem = registry.problem(\"translate_ende_wmt32k\")\n", - "\n", - "# Copy the vocab file locally\n", - "vocab_file = os.path.join(gs_data_dir, \"vocab.ende.32768\")\n", - "!gsutil cp {vocab_file} {data_dir}" + "# A Problem is a dataset together with some fixed pre-processing.\n", + "# It could be a translation dataset with a specific tokenization,\n", + "# or an image dataset with a specific resolution.\n", + "#\n", + "# There are many problems available in Tensor2Tensor\n", + "problems.available()" ], "cell_type": "code", "execution_count": 4, "outputs": [ { - "output_type": "stream", - "text": [ - "Copying gs://tensor2tensor-data/vocab.ende.32768...\n", - "/ [1 files][316.4 KiB/316.4 KiB] \n", - "Operation completed over 1 objects/316.4 KiB. 
\n" - ], - "name": "stdout" + "output_type": "execute_result", + "data": { + "text/plain": [ + "['algorithmic_addition_binary40',\n", + " 'algorithmic_addition_decimal40',\n", + " 'algorithmic_cipher_shift200',\n", + " 'algorithmic_cipher_shift5',\n", + " 'algorithmic_cipher_vigenere200',\n", + " 'algorithmic_cipher_vigenere5',\n", + " 'algorithmic_identity_binary40',\n", + " 'algorithmic_identity_decimal40',\n", + " 'algorithmic_multiplication_binary40',\n", + " 'algorithmic_multiplication_decimal40',\n", + " 'algorithmic_reverse_binary40',\n", + " 'algorithmic_reverse_binary40_test',\n", + " 'algorithmic_reverse_decimal40',\n", + " 'algorithmic_reverse_nlplike32k',\n", + " 'algorithmic_reverse_nlplike8k',\n", + " 'algorithmic_shift_decimal40',\n", + " 'audio_timit_characters_tune',\n", + " 'audio_timit_tokens8k_test',\n", + " 'audio_timit_tokens8k_tune',\n", + " 'image_celeba_tune',\n", + " 'image_cifar10',\n", + " 'image_cifar10_plain',\n", + " 'image_cifar10_plain8',\n", + " 'image_cifar10_tune',\n", + " 'image_fsns',\n", + " 'image_imagenet',\n", + " 'image_imagenet224',\n", + " 'image_imagenet32',\n", + " 'image_imagenet64',\n", + " 'image_mnist',\n", + " 'image_mnist_tune',\n", + " 'image_ms_coco_characters',\n", + " 'image_ms_coco_tokens32k',\n", + " 'image_ms_coco_tokens8k',\n", + " 'img2img_cifar10',\n", + " 'img2img_imagenet',\n", + " 'languagemodel_lm1b32k',\n", + " 'languagemodel_lm1b8k_packed',\n", + " 'languagemodel_lm1b_characters',\n", + " 'languagemodel_ptb10k',\n", + " 'languagemodel_ptb_characters',\n", + " 'languagemodel_wiki_full32k',\n", + " 'languagemodel_wiki_scramble128',\n", + " 'languagemodel_wiki_scramble1k50',\n", + " 'languagemodel_wiki_scramble8k50',\n", + " 'librispeech',\n", + " 'multinli_matched',\n", + " 'multinli_mismatched',\n", + " 'ocr_test',\n", + " 'parsing_english_ptb16k',\n", + " 'parsing_english_ptb8k',\n", + " 'parsing_icelandic16k',\n", + " 'programming_desc2code_cpp',\n", + " 'programming_desc2code_py',\n", + " 'sentiment_imdb',\n", + " 'summarize_cnn_dailymail32k',\n", + " 'translate_encs_wmt32k',\n", + " 'translate_encs_wmt_characters',\n", + " 'translate_ende_wmt32k',\n", + " 'translate_ende_wmt32k_packed',\n", + " 'translate_ende_wmt8k',\n", + " 'translate_ende_wmt_bpe32k',\n", + " 'translate_ende_wmt_characters',\n", + " 'translate_enfr_wmt32k',\n", + " 'translate_enfr_wmt8k',\n", + " 'translate_enfr_wmt_characters',\n", + " 'translate_enfr_wmt_small32k',\n", + " 'translate_enfr_wmt_small8k',\n", + " 'translate_enfr_wmt_small_characters',\n", + " 'translate_enmk_setimes32k',\n", + " 'translate_enzh_wmt8k']" + ] + }, + "metadata": { + "tags": [] + }, + "execution_count": 4 } ] }, { "metadata": { - "id": "EB4MP7_y_SuQ", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "source": [ - "encoders = ende_problem.feature_encoders(data_dir)\n", - "\n", - "def encode(input_str):\n", - " \"\"\"Input str to features dict, ready for inference\"\"\"\n", - " inputs = encoders[\"inputs\"].encode(input_str) + [1] # add EOS id\n", - " batch_inputs = tf.reshape(inputs, [1, -1, 1]) # Make it 3D.\n", - " return {\"inputs\": batch_inputs}\n", - "\n", - "def decode(integers):\n", - " \"\"\"List of ints to str\"\"\"\n", - " integers = list(np.squeeze(integers))\n", - " if 1 in integers:\n", - " integers = integers[:integers.index(1)]\n", - " return encoders[\"inputs\"].decode(np.squeeze(integers))" - ], - "cell_type": "code", - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": 
"g2aQW7Z6TOEu", + "id": "JKc2uSk6WX5e", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 + }, + "output_extras": [ + { + "item_id": 3 + } + ], + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "outputId": "0ea990ae-6715-4ada-d3a2-a5312faaaa39", + "executionInfo": { + "status": "ok", + "timestamp": 1512173992544, + "user_tz": 480, + "elapsed": 955, + "user": { + "displayName": "Ryan Sepassi", + "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", + "userId": "107877449274830904926" } } }, "source": [ - "# # Generate and view the data\n", - "# # This cell is commented out because data generation can take hours\n", - "\n", - "# ende_problem.generate_data(data_dir, tmp_dir)\n", - "# example = tfe.Iterator(ende_problem.dataset(Modes.TRAIN, data_dir)).next()\n", - "# inputs = [int(x) for x in example[\"inputs\"].numpy()] # Cast to ints.\n", - "# targets = [int(x) for x in example[\"targets\"].numpy()] # Cast to ints.\n", - "\n", - "\n", - "\n", - "# # Example inputs as int-tensor.\n", - "# print(\"Inputs, encoded:\")\n", - "# print(inputs)\n", - "# print(\"Inputs, decoded:\")\n", - "# # Example inputs as a sentence.\n", - "# print(decode(inputs))\n", - "# # Example targets as int-tensor.\n", - "# print(\"Targets, encoded:\")\n", - "# print(targets)\n", - "# # Example targets as a sentence.\n", - "# print(\"Targets, decoded:\")\n", - "# print(decode(targets))" + "# Fetch the MNIST problem\n", + "mnist_problem = problems.problem(\"image_mnist\")\n", + "# The generate_data method of a problem will download data and process it into\n", + "# a standard format ready for training and evaluation.\n", + "mnist_problem.generate_data(data_dir, tmp_dir)" ], "cell_type": "code", - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "9l6hDQbrRUYV", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "text": [ + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/train-images-idx3-ubyte.gz\n", + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/train-labels-idx1-ubyte.gz\n", + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-images-idx3-ubyte.gz\n", + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-labels-idx1-ubyte.gz\n", + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/train-images-idx3-ubyte.gz\n", + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/train-labels-idx1-ubyte.gz\n", + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-images-idx3-ubyte.gz\n", + "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-labels-idx1-ubyte.gz\n", + "INFO:tensorflow:Skipping generator because outputs files exist\n", + "INFO:tensorflow:Skipping generator because outputs files exist\n", + "INFO:tensorflow:Skipping shuffle because output files exist\n" + ], + "name": "stdout" } - }, - "source": [ - "# Create hparams and the T2TModel object.\n", - "model_name = \"transformer\"\n", - "hparams_set = \"transformer_base\"\n", - "\n", - "hparams = trainer_utils.create_hparams(hparams_set, data_dir)\n", - "trainer_utils.add_problem_hparams(hparams, \"translate_ende_wmt32k\")\n", - "\n", - "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n", - "# Layer and so subsequent 
instantiations will have different variable scopes\n", - "# that will not match the checkpoint.\n", - "translate_model = registry.model(model_name)(hparams, Modes.PREDICT)" - ], - "cell_type": "code", - "execution_count": 0, - "outputs": [] + ] }, { "metadata": { - "id": "FEwNUVlMYOJi", + "id": "VW6HCRANFPYV", "colab_type": "code", "colab": { "autoexec": { @@ -278,18 +315,21 @@ }, "output_extras": [ { - "item_id": 1 + "item_id": 2 + }, + { + "item_id": 3 } ], "base_uri": "https://localhost:8080/", - "height": 34 + "height": 381 }, - "outputId": "fc15a59a-7ea7-4baa-85c1-2a94528eb157", + "outputId": "121d463f-adaf-4340-a5cb-12e931fd0fdb", "executionInfo": { "status": "ok", - "timestamp": 1512165760778, + "timestamp": 1512173993175, "user_tz": 480, - "elapsed": 12527, + "elapsed": 561, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -298,33 +338,52 @@ } }, "source": [ - "# Copy the pretrained checkpoint locally\n", - "ckpt_name = \"transformer_ende_test\"\n", - "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n", - "!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}\n", - "ckpt_path = tf.train.latest_checkpoint(os.path.join(checkpoint_dir, ckpt_name))\n", - "ckpt_path" + "# Now let's see the training MNIST data as Tensors.\n", + "mnist_example = tfe.Iterator(mnist_problem.dataset(Modes.TRAIN, data_dir)).next()\n", + "image = mnist_example[\"inputs\"]\n", + "label = mnist_example[\"targets\"]\n", + "\n", + "plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap('gray'))\n", + "print(\"Label: %d\" % label.numpy())" ], "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "outputs": [ { - "output_type": "execute_result", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Reading data files from /content/t2t/data/image_mnist-train*\n", + "Label: 6\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAE4hJREFUeJzt3X1MlfX/x/HXCWLC1KEklq27OZ1M\ncKvUic4bFC3amje1VERzc02XOm9Gxpyo5SaKaN61RFO3ZK3T+CdXLsjMcoo4aVMP/6D+YcwMQZnp\nRFM6vz9++7KQczhvjpyb6/h8bPzB5/qcz/V+72IvrnOuc53j8nq9XgEAOvVUpAsAACcgLAHAgLAE\nAAPCEgAMCEsAMCAsAcDCGwaSfP5cuHDB7zan/sRiT7HaFz055ydcfXXGFY73WbpcLp/jXq/X7zan\nisWepNjsi56cI1x9dRaH8cEuunHjRp07d04ul0urV6/WsGHDgl0KAKJeUGF55swZXblyRW63W5cv\nX9bq1avldru7uzYAiBpBXeCpqqpSdna2JGngwIG6deuW7ty5062FAUA0CerMsqmpSUOHDm37vW/f\nvmpsbFTPnj19zr9w4YLS09N9bgvDS6ZhF4s9SbHZFz05R6T7Cvo1y/8K1ERGRobfx8Xai9Gx2JMU\nm33Rk3NEwwWeoJ6Gp6amqqmpqe3369evq1+/fsEsBQCOEFRYjhkzRhUVFZKk2tpapaam+n0KDgCx\nIKin4a+99pqGDh2qWbNmyeVyad26dd1dFwBEFd6U3s1isScpNvuiJ+dw7GuWAPCkISwBwICwBAAD\nwhIADAhLADAgLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhL\nADAgLAHAgLAEAAPCEgAMCEsAMAjqq3CBWDV8+HDTvMrKSvOaf/zxh3ludna2eW5TU5N5Lh4fZ5YA\nYEBYAoABYQkABoQlABgQlgBgQFgCgAFhCQAGhCUAGBCWAGBAWAKAAbc7Av/x/vvvm+YlJyeb1+zK\n3Pnz55vnlpSUmOfi8XFmCQAGhCUAGBCWAGBAWAKAAWEJAAaEJQAYEJYAYEBYAoABYQkABtzBg5j3\n3HPPmbfl5uaa1rxx44Z5/4WFhea5p06dMs9FeHFmCQAGQZ1ZVldXa9myZRo0aJAkafDgwV367wkA\nThP00/CRI0dq586d3VkLAEQtnoYDgEHQYXnp0iUtWrRIs2fP1smTJ7uzJgCIOi6v1+vt6oMaGhpU\nU1OjnJwc1dfXa968eaqsrFRCQoLP+R6PR+np6Y9dLABESlBh+ah3331Xn332mV544QXfO3G5fI57\nvV6/25wqFnuSnN2Xv7cO/fnnnxowYEC7MY/HY1rz33//Ne8/VG8dOn/+fIcxJx+nzoSrr87iMKin\n4YcPH9b+/fslSY2Njbpx44b69+8fXHUA4ABBXQ2fOHGi8vPz9fPPP+vBgwdav36936fgABALggrL\nnj17as+ePd1dCwBELW53RMwrKCgwb+vTp49pze3bt5v3z4lFbOB9lgBgQFgCgAFhCQAGhCUAGBCW\nAGBAWAKAAWEJAAaEJQAYEJYAYEBYAoBBt3xEW8Cd8BFtjhdtfc2dO9c89+DBgz7H4+Li1Nra2m7s\n5s2bpjVfffVV8/6vXr1qnvu4ou04dRfHfkQbADxpCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICw\nBAADwhIADPjCMjjSe++9Z5771FP+zwke3fbVV1+Z1gznXTmIDpxZAoABYQkABoQlABgQlgBgQFgC\ngAFhCQAGhCUAGBCWAGBAWAKAAWEJAAbc7oiokpeXZ5r35ptvmte8d++ez/HExMQO26y3O+LJw5kl\nABgQlgBgQFgCgAFhCQAGhCUAGBCWAGBAWAKAAWEJAAaEJQAYEJYAYMDtjogqs2fPNs2Li4szr7l9\n+3af4/n5+fr888/bjZ0/f968Lp4spjPLuro6ZWdnq6ysTJJ07do1zZ07V7m5uVq2bJn++eefkBYJ\nAJEWMCzv3r2rDRs2KDMzs21s586dys3N1ddff62XXnpJ5eXlIS0SACItYFgmJCRo3759Sk1NbRur\nrq7WpEmTJElZWVmqqqoKXYUAEAUCvmYZHx+v+Pj201paWpSQkCBJSklJUWNjY2iqA4Ao8dgXeLxe\nb8A5Fy5cUHp6etCPd5pY7Elybl/5+fnmbZ3NdQqnHqdAIt1XUGGZlJSke/fuqUePHmpoaGj3FN2X\njIwMn+Ner1culyuYEqJWLPYkha+vH374wTQvJyfHvObWrVt9jufn56ukpKTd2EcffWReNxrx9/f4\n+/EnqPdZjh49WhUVFZKkyspKjR07NrjKAMAhAp5Zejwebd68WVevXlV8fLwqKipUUlKigoICud1u\nDRgwQNOmTQtHrQAQMQHDMj09XYcOHeowfvDgwZAUBADRiDt4EHJz5swxz508ebJpnr8vIfPl119/\n9Tmen5/fYVufPn1MazY3N5v3j9jAveEAYEBYAoABYQkABoQlABgQlgBgQFgCgAFhCQAGhCUAGBCW\nAGBAWAKAAbc7IuTGjRtnnvvoB0378/3335vXHD58uHlbaWmpac0dO3aY919cXGyei+jFmSUAGBCW\nAGBAWAKAAWEJAAaEJQAYEJYAYEBYAoABYQkABoQlABgQlgBg4PJ6vd6Q78Tl8jnu9Xr9bnOqWOxJ\n6thXYmKi+bGXL182z3322WdN82pra81rDh061Oe4y+VSsH/+J0+eNM8dO3ZsUPsIxpPy9xfK/fjD\nmSUAGBCWAGBAWAKAAWEJAAaEJQAYEJYAYEBYAoABYQkABoQlABjwhWUIyqxZs8xzrXfldIW/u3LC\n5dSpUxHdP8KPM0sAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhLADAgLAHAgNsdEZRR\no0ZFugSzn376yef4lClTOmybPHmyac2WlpbHrgvOwpklABiYwrKurk7Z2dkqKyuTJBUUFOjtt9/W\n3LlzNXfuXB0/fjyUNQJAxAV8Gn737l1t2LBBmZmZ7cZXrlyprKyskBUGANEk4JllQkKC9u3bp9TU\n1HDUAwBRyeX1er2Wibt27VKfPn2Ul5engoICNTY26sGDB0pJSVFhYaH69u3r97Eej0fp6endVjQA\nhFtQV8OnTp2q5ORkpaWlae/evdq9e7fWrl3rd35GRobPca/XK5fLFUwJUSsWe5I69lVaWmp+7Acf\nfBCKksw6uxpeWVnZbsx6NfzTTz8173/9+vXmuY/rSfn7C+V+/AnqanhmZqbS0tIkSRMnTlRdXV1w\nlQGAQwQVlkuXLlV9fb0kqbq6WoMGDerWogAg2gR8Gu7xeLR582ZdvXpV8fHxqqioUF5enpYvX67E\nxEQlJSWpqKgoHLUCQMQEDMv09HQdOnSow/gbb7wRkoIA
IBpxuyPaef75503bZs6cGY5y/Prxxx/N\nc8+dO+dzfMqUKfr999/bjfm7GPmo/fv3m/eP2MDtjgBgQFgCgAFhCQAGhCUAGBCWAGBAWAKAAWEJ\nAAaEJQAYEJYAYEBYAoABtzuinc4+e/K/23r37h2S/T98+NA07+DBg+Y1Fy9e7Hfbo1+XYl33f5+6\nhScHZ5YAYEBYAoABYQkABoQlABgQlgBgQFgCgAFhCQAGhCUAGBCWAGDAHTxoJyUlJaht3eW7774z\nzbtz5455zddff9287fjx4+Z18WThzBIADAhLADAgLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQAA8IS\nAAwISwAw4HZHhNz9+/fNc4cMGWKa9+2335rXrKmp8Tk+fvz4Dtu2bt1qXhdPFs4sAcCAsAQAA8IS\nAAwISwAwICwBwICwBAADwhIADAhLADAgLAHAgLAEAANud0TIWb+xUZLS0tJM85KSksxrfvLJJz7H\njx071mHb7du3zeviyWIKy+LiYtXU1Ojhw4dauHChMjIytGrVKrW2tqpfv37asmWLEhISQl0rAERM\nwLA8ffq0Ll68KLfbrebmZk2fPl2ZmZnKzc1VTk6Otm3bpvLycuXm5oajXgCIiICvWY4YMUI7duyQ\nJPXu3VstLS2qrq7WpEmTJElZWVmqqqoKbZUAEGEBwzIuLq7t9aHy8nKNGzdOLS0tbU+7U1JS1NjY\nGNoqASDCXF6v12uZePToUZWWlurAgQOaMmVK29nklStX9PHHH+ubb77x+1iPx6P09PTuqRgAIsB0\ngefEiRPas2ePvvzyS/Xq1UtJSUm6d++eevTooYaGBqWmpnb6+IyMDJ/jXq9XLper61VHMaf3tGvX\nLp/jS5Ys0e7du9t+X7x4sXnNrnxQr/VqeFf++WZnZ/scP3bsmCZOnNhu7JdffjGvG42c/vfnT7j6\n6uzcMeDT8Nu3b6u4uFilpaVKTk6WJI0ePVoVFRWSpMrKSo0dO7abSgWA6BTwzPLIkSNqbm7W8uXL\n28Y2bdqkNWvWyO12a8CAAZo2bVpIiwSASAsYljNnztTMmTM7jB88eDAkBQFANOIOHrRz5cqVoLZ1\nxtc/W3+M1xv1xRdfmNfs7HVIp79GifDh3nAAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIA\nDAhLADAgLAHAwPx5lo+1Ez8frRSLHycViz1JHfvaunWr+bErVqwwz920aZNpXlFRkXlNf19CFovH\nKhZ7khzyEW0AAMISAEwISwAwICwBwICwBAADwhIADAhLADAgLAHAgLAEAAPCEgAMuN2xm8ViT1Js\n9kVPzsHtjgDgEIQlABgQlgBgQFgCgAFhCQAGhCUAGBCWAGBAWAKAAWEJAAaEJQAYEJYAYEBYAoAB\nYQkABoQlABgQlgBgQFgCgAFhCQAGhCUAGBCWAGBAWAKAAWEJAAbxlknFxcWqqanRw4cPtXDhQh07\ndky1tbVKTk6WJC1YsEATJkwIZZ0AEFEBw/L06dO6ePGi3G63mpubNX36dI0aNUorV65UVlZWOGoE\ngIgLGJYjRozQsGHDJEm9e/dWS0uLWltbQ14YAEQTl7ezbxV/hNvt1tmzZxUXF6fGxkY9ePBAKSkp\nKiwsVN++ff3vxM+Xo8fiF8LHYk9SbPZFT84Rrr46i0NzWB49elSlpaU6cOCAPB6PkpOTlZaWpr17\n9+qvv/7S2rVr/T7W4/EoPT2965UDQLTwGvz222/ed955x9vc3Nxh28WLF71z5szp9PGSfP50ts2p\nP7HYU6z2RU/O+QlXX50J+Nah27dvq7i4WKWlpW1Xv5cuXar6+npJUnV1tQYNGhRoGQBwtIAXeI4c\nOaLm5mYtX768bWzGjBlavny5EhMTlZSUpKKiopAWCQCR1qULPEHvhAs8jheLfdGTc4Srr87ikDt4\nAMCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhLADAgLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQA\nA8ISAAwISwAwICwBwICwBAADwhIADAhLADAgLAHAgLAEAIOwfBUuADgdZ5YAYEBYAoABYQkABoQl\nABgQlgBgQFgCgEF8JHa6ceNGnTt3Ti6XS6tXr9awYcMiUUa3qq6u1rJlyzRo0CBJ0uDBg1VYWBjh\nqoJXV1enDz/8UPPnz1deXp6uXbumVatWqbW1Vf369dOWLVuUkJAQ6TK75NGeCgoKVFtbq+TkZEnS\nggULNGHChMgW2UXFxcWqqanRw4cPtXDhQmVkZDj+OEkd+zp27FjEj1XYw/LMmTO6cuWK3G63Ll++\nrNWrV8vtdoe7jJAYOXKkdu7cGekyHtvdu3e1YcMGZWZmto3t3LlTubm5ysnJ0bZt21ReXq7c3NwI\nVtk1vnqSpJUrVyorKytCVT2e06dP6+LFi3K73Wpubtb06dOVmZnp6OMk+e5r1KhRET9WYX8aXlVV\npezsbEnSwIEDdevWLd25cyfcZaATCQkJ2rdvn1JTU9vGqqurNWnSJElSVlaWqqqqIlVeUHz15HQj\nRozQjh07JEm9e/dWS0uL44+T5Luv1tbWCFcVgbBsampSnz592n7v27evGhsbw11GSFy6dEmLFi3S\n7NmzdfLkyUiXE7T4+Hj16NGj3VhLS0vb07mUlBTHHTNfPUlSWVmZ5s2bpxUrVujmzZsRqCx4cXFx\nSkpKkiSVl5dr3Lhxjj9Oku++4uLiIn6sIvKa5X/Fyt2WL7/8spYsWaKcnBzV19dr3rx5qqysdOTr\nRYHEyjGbOnWqkpOTlZaWpr1792r37t1au3ZtpMvqsqNHj6q8vFwHDhzQlClT2sadfpz+25fH44n4\nsQr7mWVqaqqamprafr9+/br69esX7jK6Xf/+/fXWW2/J5XLpxRdf1DPPPKOGhoZIl9VtkpKSdO/e\nPUlSQ0NDTDydzczMVFpamiRp4sSJqquri3BFXXfixAnt2bNH+/btU69evWLmOD3aVzQcq7CH5Zgx\nY1RRUSFJqq2tVWpqqnr27BnuMrrd4cOHtX//fklSY2Ojbty4of79+0e4qu4zevTotuNWWVmpsWPH\nRriix7d06VLV19dL+v/XZP/3TganuH37toqLi1VaWtp2lTgWjpOvvqLhWEXkU4dKSkp09uxZuVwu\nrVu3TkOGDAl3Cd3uzp07ys/P199//60HDx5oyZIlGj9+fKTLCorH49HmzZt19epVxcfHq3///iop\nKVFBQYHu37+vAQMGqKioSE8//XSkSzXz1VNeXp727t2rxMREJSUlqaioSCkpKZEu1cztdmvXrl16\n5ZVX2sY2bdqkNWvWOPY4Sb77mjFjhsrKyiJ6rPiINgAw4A4eADAgLAHAgLAEAAPCEgAMCEsAMCAs\nAcCAsAQAA8ISAAz+D2GuR1qUzSXkAAAAAElFTkSuQmCC\n", "text/plain": [ - 
"u'/content/t2t/checkpoints/transformer_ende_test/model.ckpt-350855'" + "" ] }, "metadata": { "tags": [] - }, - "execution_count": 8 + } } ] }, { "metadata": { - "id": "3O-8E9d6TtuJ", + "id": "gXL7_bVH49Kl", + "colab_type": "text" + }, + "source": [ + "# Translate from English to German with a pre-trained model" + ], + "cell_type": "markdown" + }, + { + "metadata": { + "id": "EB4MP7_y_SuQ", "colab_type": "code", "colab": { "autoexec": { @@ -333,18 +392,18 @@ }, "output_extras": [ { - "item_id": 3 + "item_id": 2 } ], "base_uri": "https://localhost:8080/", - "height": 119 + "height": 68 }, - "outputId": "24231c95-99cb-421b-d961-5a21322be945", + "outputId": "db79aefe-d9a6-437b-aaf8-4174a1f3c643", "executionInfo": { "status": "ok", - "timestamp": 1512165773424, + "timestamp": 1512173998055, "user_tz": 480, - "elapsed": 12593, + "elapsed": 2988, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -353,32 +412,40 @@ } }, "source": [ - "# Restore and translate!\n", + "# Fetch the problem\n", + "ende_problem = problems.problem(\"translate_ende_wmt32k\")\n", "\n", - "def translate(inputs):\n", - " encoded_inputs = encode(inputs)\n", - " with tfe.restore_variables_on_create(ckpt_path):\n", - " model_output = translate_model.infer(encoded_inputs)\n", - " return decode(model_output)\n", + "# Copy the vocab file locally so we can encode inputs and decode model outputs\n", + "# All vocabs are stored on GCS\n", + "vocab_file = os.path.join(gs_data_dir, \"vocab.ende.32768\")\n", + "!gsutil cp {vocab_file} {data_dir}\n", "\n", - "inputs = \"This is a cat.\"\n", - "outputs = translate(inputs)\n", + "# Get the encoders from the problem\n", + "encoders = ende_problem.feature_encoders(data_dir)\n", "\n", - "print(\"Inputs: %s\" % inputs)\n", - "print(\"Outputs: %s\" % outputs)" + "# Setup helper functions for encoding and decoding\n", + "def encode(input_str):\n", + " \"\"\"Input str to features dict, ready for inference\"\"\"\n", + " inputs = encoders[\"inputs\"].encode(input_str) + [1] # add EOS id\n", + " batch_inputs = tf.reshape(inputs, [1, -1, 1]) # Make it 3D.\n", + " return {\"inputs\": batch_inputs}\n", + "\n", + "def decode(integers):\n", + " \"\"\"List of ints to str\"\"\"\n", + " integers = list(np.squeeze(integers))\n", + " if 1 in integers:\n", + " integers = integers[:integers.index(1)]\n", + " return encoders[\"inputs\"].decode(np.squeeze(integers))" ], "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "outputs": [ { "output_type": "stream", "text": [ - "INFO:tensorflow:Greedy Decoding\n", - "WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensor2tensor/layers/common_layers.py:487: calling reduce_mean (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n", - "Instructions for updating:\n", - "keep_dims is deprecated, use keepdims instead\n", - "Inputs: This is a cat.\n", - "Outputs: Das ist eine Katze.\n" + "Copying gs://tensor2tensor-data/vocab.ende.32768...\n", + "/ [1 files][316.4 KiB/316.4 KiB] \n", + "Operation completed over 1 objects/316.4 KiB. 
\n" ], "name": "stdout" } @@ -386,17 +453,46 @@ }, { "metadata": { - "id": "i7BZuO7T5BB4", - "colab_type": "text" + "id": "g2aQW7Z6TOEu", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } }, "source": [ - "# Train a custom model on MNIST" + "# # Generate and view the data\n", + "# # This cell is commented out because WMT data generation can take hours\n", + "\n", + "# ende_problem.generate_data(data_dir, tmp_dir)\n", + "# example = tfe.Iterator(ende_problem.dataset(Modes.TRAIN, data_dir)).next()\n", + "# inputs = [int(x) for x in example[\"inputs\"].numpy()] # Cast to ints.\n", + "# targets = [int(x) for x in example[\"targets\"].numpy()] # Cast to ints.\n", + "\n", + "\n", + "\n", + "# # Example inputs as int-tensor.\n", + "# print(\"Inputs, encoded:\")\n", + "# print(inputs)\n", + "# print(\"Inputs, decoded:\")\n", + "# # Example inputs as a sentence.\n", + "# print(decode(inputs))\n", + "# # Example targets as int-tensor.\n", + "# print(\"Targets, encoded:\")\n", + "# print(targets)\n", + "# # Example targets as a sentence.\n", + "# print(\"Targets, decoded:\")\n", + "# print(decode(targets))" ], - "cell_type": "markdown" + "cell_type": "code", + "execution_count": 0, + "outputs": [] }, { "metadata": { - "id": "RYDMO4zArgkz", + "id": "WkFUEs7ZOA79", "colab_type": "code", "colab": { "autoexec": { @@ -409,14 +505,14 @@ } ], "base_uri": "https://localhost:8080/", - "height": 1224 + "height": 408 }, - "outputId": "3b62dff4-7bfa-436e-a9f5-ecf66616e93a", + "outputId": "7283214e-af66-4f16-b203-3b209643484f", "executionInfo": { "status": "ok", - "timestamp": 1512165773875, + "timestamp": 1512174000121, "user_tz": 480, - "elapsed": 423, + "elapsed": 321, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -425,158 +521,79 @@ } }, "source": [ - "# Lots of problems available\n", - "problems.available()" + "# There are many models available in Tensor2Tensor\n", + "registry.list_models()" ], "cell_type": "code", - "execution_count": 10, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "['algorithmic_addition_binary40',\n", - " 'algorithmic_addition_decimal40',\n", - " 'algorithmic_cipher_shift200',\n", - " 'algorithmic_cipher_shift5',\n", - " 'algorithmic_cipher_vigenere200',\n", - " 'algorithmic_cipher_vigenere5',\n", - " 'algorithmic_identity_binary40',\n", - " 'algorithmic_identity_decimal40',\n", - " 'algorithmic_multiplication_binary40',\n", - " 'algorithmic_multiplication_decimal40',\n", - " 'algorithmic_reverse_binary40',\n", - " 'algorithmic_reverse_binary40_test',\n", - " 'algorithmic_reverse_decimal40',\n", - " 'algorithmic_reverse_nlplike32k',\n", - " 'algorithmic_reverse_nlplike8k',\n", - " 'algorithmic_shift_decimal40',\n", - " 'audio_timit_characters_tune',\n", - " 'audio_timit_tokens8k_test',\n", - " 'audio_timit_tokens8k_tune',\n", - " 'image_celeba_tune',\n", - " 'image_cifar10',\n", - " 'image_cifar10_plain',\n", - " 'image_cifar10_plain8',\n", - " 'image_cifar10_tune',\n", - " 'image_fsns',\n", - " 'image_imagenet',\n", - " 'image_imagenet224',\n", - " 'image_imagenet32',\n", - " 'image_imagenet64',\n", - " 'image_mnist',\n", - " 'image_mnist_tune',\n", - " 'image_ms_coco_characters',\n", - " 'image_ms_coco_tokens32k',\n", - " 'image_ms_coco_tokens8k',\n", - " 'img2img_cifar10',\n", - " 'img2img_imagenet',\n", - " 'languagemodel_lm1b32k',\n", - " 'languagemodel_lm1b8k_packed',\n", - " 
'languagemodel_lm1b_characters',\n", - " 'languagemodel_ptb10k',\n", - " 'languagemodel_ptb_characters',\n", - " 'languagemodel_wiki_full32k',\n", - " 'languagemodel_wiki_scramble128',\n", - " 'languagemodel_wiki_scramble1k50',\n", - " 'languagemodel_wiki_scramble8k50',\n", - " 'librispeech',\n", - " 'multinli_matched',\n", - " 'multinli_mismatched',\n", - " 'ocr_test',\n", - " 'parsing_english_ptb16k',\n", - " 'parsing_english_ptb8k',\n", - " 'parsing_icelandic16k',\n", - " 'programming_desc2code_cpp',\n", - " 'programming_desc2code_py',\n", - " 'sentiment_imdb',\n", - " 'summarize_cnn_dailymail32k',\n", - " 'translate_encs_wmt32k',\n", - " 'translate_encs_wmt_characters',\n", - " 'translate_ende_wmt32k',\n", - " 'translate_ende_wmt32k_packed',\n", - " 'translate_ende_wmt8k',\n", - " 'translate_ende_wmt_bpe32k',\n", - " 'translate_ende_wmt_characters',\n", - " 'translate_enfr_wmt32k',\n", - " 'translate_enfr_wmt8k',\n", - " 'translate_enfr_wmt_characters',\n", - " 'translate_enfr_wmt_small32k',\n", - " 'translate_enfr_wmt_small8k',\n", - " 'translate_enfr_wmt_small_characters',\n", - " 'translate_enmk_setimes32k',\n", - " 'translate_enzh_wmt8k']" + "execution_count": 9, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "['resnet50',\n", + " 'lstm_seq2seq',\n", + " 'transformer_encoder',\n", + " 'attention_lm',\n", + " 'vanilla_gan',\n", + " 'transformer',\n", + " 'gene_expression_conv',\n", + " 'transformer_moe',\n", + " 'attention_lm_moe',\n", + " 'transformer_revnet',\n", + " 'lstm_seq2seq_attention',\n", + " 'shake_shake',\n", + " 'transformer_ae',\n", + " 'diagonal_neural_gpu',\n", + " 'xception',\n", + " 'aligned',\n", + " 'multi_model',\n", + " 'neural_gpu',\n", + " 'slice_net',\n", + " 'byte_net',\n", + " 'cycle_gan',\n", + " 'transformer_sketch',\n", + " 'blue_net']" ] }, "metadata": { "tags": [] }, - "execution_count": 10 + "execution_count": 9 } ] }, { "metadata": { - "id": "JKc2uSk6WX5e", + "id": "9l6hDQbrRUYV", "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - }, - "output_extras": [ - { - "item_id": 3 - } - ], - "base_uri": "https://localhost:8080/", - "height": 204 - }, - "outputId": "f9fa17c1-ed3f-474e-8bd8-f764c3b00000", - "executionInfo": { - "status": "ok", - "timestamp": 1512165774930, - "user_tz": 480, - "elapsed": 977, - "user": { - "displayName": "Ryan Sepassi", - "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", - "userId": "107877449274830904926" } } }, "source": [ - "# Create the MNIST problem and generate the data\n", + "# Create hparams and the model\n", + "model_name = \"transformer\"\n", + "hparams_set = \"transformer_base\"\n", "\n", - "mnist_problem = problems.problem(\"image_mnist\")\n", - "# Generate data\n", - "mnist_problem.generate_data(data_dir, tmp_dir)" + "hparams = trainer_utils.create_hparams(hparams_set, data_dir)\n", + "trainer_utils.add_problem_hparams(hparams, \"translate_ende_wmt32k\")\n", + "\n", + "# NOTE: Only create the model once when restoring from a checkpoint; it's a\n", + "# Layer and so subsequent instantiations will have different variable scopes\n", + "# that will not match the checkpoint.\n", + "translate_model = registry.model(model_name)(hparams, Modes.PREDICT)" ], "cell_type": "code", - "execution_count": 11, - "outputs": [ - { - "output_type": "stream", - "text": [ - "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/train-images-idx3-ubyte.gz\n", - "INFO:tensorflow:Not 
downloading, file already found: /content/t2t/tmp/train-labels-idx1-ubyte.gz\n", - "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-images-idx3-ubyte.gz\n", - "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-labels-idx1-ubyte.gz\n", - "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/train-images-idx3-ubyte.gz\n", - "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/train-labels-idx1-ubyte.gz\n", - "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-images-idx3-ubyte.gz\n", - "INFO:tensorflow:Not downloading, file already found: /content/t2t/tmp/t10k-labels-idx1-ubyte.gz\n", - "INFO:tensorflow:Skipping generator because outputs files exist\n", - "INFO:tensorflow:Skipping generator because outputs files exist\n", - "INFO:tensorflow:Skipping shuffle because output files exist\n" - ], - "name": "stdout" - } - ] + "execution_count": 0, + "outputs": [] }, { "metadata": { - "id": "VW6HCRANFPYV", + "id": "FEwNUVlMYOJi", "colab_type": "code", "colab": { "autoexec": { @@ -585,21 +602,18 @@ }, "output_extras": [ { - "item_id": 2 - }, - { - "item_id": 3 + "item_id": 1 } ], "base_uri": "https://localhost:8080/", - "height": 381 + "height": 34 }, - "outputId": "93dea49c-dbca-4856-998f-8bcbb621abea", + "outputId": "ec8569a0-ee0e-4520-c9c6-06f3c7582ecc", "executionInfo": { "status": "ok", - "timestamp": 1512165775597, + "timestamp": 1512174015202, "user_tz": 480, - "elapsed": 622, + "elapsed": 12781, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -608,42 +622,33 @@ } }, "source": [ - "# Get the tf.data.Dataset from Problem.dataset\n", - "mnist_example = tfe.Iterator(mnist_problem.dataset(Modes.TRAIN, data_dir)).next()\n", - "image = mnist_example[\"inputs\"]\n", - "label = mnist_example[\"targets\"]\n", - "\n", - "plt.imshow(image.numpy()[:, :, 0].astype(np.float32), cmap=plt.get_cmap('gray'))\n", - "print(\"Label: %d\" % label.numpy())" + "# Copy the pretrained checkpoint locally\n", + "ckpt_name = \"transformer_ende_test\"\n", + "gs_ckpt = os.path.join(gs_ckpt_dir, ckpt_name)\n", + "!gsutil -q cp -R {gs_ckpt} {checkpoint_dir}\n", + "ckpt_path = tf.train.latest_checkpoint(os.path.join(checkpoint_dir, ckpt_name))\n", + "ckpt_path" ], "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "outputs": [ { - "output_type": "stream", - "text": [ - "INFO:tensorflow:Reading data files from /content/t2t/data/image_mnist-train*\n", - "Label: 6\n" - ], - "name": "stdout" - }, - { - "output_type": "display_data", + "output_type": "execute_result", "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAUsAAAFKCAYAAACU6307AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAFK1JREFUeJzt3X9MVfUfx/HXDSQgJJSEzS2rNS0m\nuFWzxB8Vymx8y1JrsxCdzT/shyaZK8ZEWzZ/oP2Qfomm/iG53cYfzj90MLNWKuBk1YR/0NqMWREY\nGSYU2P3+0WIhF3hzufeee67Px8Yf93M+nPN+fw+9vuee4znH4/P5fAIADOoGpwsAADcgLAHAgLAE\nAAPCEgAMCEsAMCAsAcDCFwaS/P6cOXNmwGVu/YnGnqK1L3pyz0+4+hqMJxz/ztLj8fgd9/l8Ay5z\nq2jsSYrOvujJPcLV12BxGBvoSjdt2qRvv/1WHo9HxcXFmjJlSqCrAoCIF1BYnjp1SufPn5fX69V3\n332n4uJieb3eYNcGABEjoAs8NTU1ys3NlSTdeeedunTpki5fvhzUwgAgkgR0ZNnW1qbJkyf3fh47\ndqxaW1uVlJTkd/6ZM2eUmZnpd1kYTpmGXTT2JEVnX/TkHk73FfA5y/8aqomsrKwBfy/aTkZHY09S\ndPZFT+4RCRd4AvoanpaWpra2tt7Pv/zyi8aNGxfIqgDAFQIKyxkzZqiqqkqS1NjYqLS0tAG/ggNA\nNAjoa/i9996ryZMn6+mnn5bH49GGDRuCXRcARBT+UXqQRWNPUnT2RU/u4dpzlgBwvSEsAcCAsAQA\nA8ISAAwISwAwICwBwICwBAADwhIADAhLADAgLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwI\nSwAwICwBwICwBAADwhIADAhLADAI6FW4QLSaNGmSad7JkyfN6/zss8/McxctWmSei/DiyBIADAhL\nADAgLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAw4HZHRL2EhATzspKSEtM6x4wZY97+\nN998Y56LyMWRJQAYEJYAYEBYAoABYQkABoQlABgQlgBgQFgCgAFhCQAGhCUAGHAHD6Jebm6ueVl+\nfn7Qt19RURH0dSL8OLIEAIOAjizr6uq0evVqTZw4UdI/rw+13lMLAG4U8Nfw+++/X2VlZcGsBQAi\nFl/DAcAg4LA8d+6cnnvuOT3zzDM6ceJEMGsCgIjj8fl8vuH+UktLi+rr65WXl6fm5mYtXbpU1dXV\niouL8zu/oaFBmZmZIy4WAJwSUFhe66mnntI777yjW2+91f9GPB6/4z6fb8BlbhWNPUnu7mvevHl+\nxw8dOqTHH3+8z9jBgweDvv3bb7/dPLe5uXlE23LzfhpMuPoaLA4D+hp+6NAh7dmzR5LU2tqqixcv\nKj09PbDqAMAFAroaPnv2bK1du1afffaZuru79frrrw/4FRwAokFAYZmUlKSdO3cGuxYAiFhBOWc5\n5EY4Z+l6kdbXYLcwXuvw4cN+x0eNGqXu7u4+YzExMaZ1fvnll+btz5071zz32nqGK9L2U7C49pwl\nAFxvCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhLADDg7Y6IKI899php3oEDB8zr\njI0d+M/82mWXLl0yrXP58uXm7Y/0FkZEBo4sAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhL\nADAgLAHAgDt4EJDB7oq51vPPP2+e+8Ybb5jm3XTTTeZ1/vHHH37Hk5KS+i3Lz883rfP77783bx/R\ngSNLADAgLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwIDbHRGQ2bNnm+e+++67\nQd9+T0+Pee7Bgwf9jhcUFPRbduTIkRHVhejFkSUAGBCWAGBAWAKAAWEJAAaEJQAYEJYAYEBYAoAB\nYQkABoQlABgQlgBg4PH5fL6Qb8Tj8Tvu8/kGXOZWbu/pgQce8DteW1uradOm9X4+duyYeZ3x8fEj\nrutaGzZsMM998803/Y67fV/5E409SeHra7A4NB1ZNjU1KTc3VxUVFZKkn376SUuWLFF+fr5Wr16t\nv/76KziVAkCEGjIsr1y5oo0bNyo7O7t3rKysTPn5+Tpw4IBuu+02VVZWhrRIAHDakGEZFxen3bt3\nKy0trXesrq5Oc+bMkSTl5OSopqYmdBUCQAQY8hFtsbGxio3tO62zs1NxcXGSpNTUVLW2toamOgCI\nECN+nqXl+tCZM2eUmZkZ8O+7TTT2JP1zkSdSbNy4MShzo3FfRWNPkvN9BRSWiYmJ6urqUnx8vFpa\nWvp8RfcnKyvL73g0Xrlze09cDXe3aOxJctHV8GtNnz5dVVVVkqTq6mrNmjUrsMoAwCWGPLJsaGjQ\n1q1bdeHCBcXGxqqqqkrbt29XUVGRvF6vxo8fr/nz54ejVgBwzJBhmZmZqf379/cb37dvX0gKAoBI\nxAvL0MfatWtNy0JxHlKSPvnkE9O8t956KyTbD4VHH33UPDcmJsY898KFC37H77vvvj6f6+vrzevE\nwLg3HAAMCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhLADDghWVBFok9Pfvss+a5\nH374od/x+Ph4dXV19X7+9+HPFhcvXjTPnT59umneuXPnzOu8+eab/Y7/9ttvSklJ6TO2bt060zoX\nLlxo3v5tt91mnjucv50//vij39jo0aPV0dHRZ+yOO+4wr3M4+yqcXPuINgC43hCWAGBAWAKAAWEJ\nAAaEJQAYEJYAYEBYAoABYQkABoQlABgQlgBgwNsdXWo4bwFcsGCBee6NN95oWjacu2RXrVplnmu9\njXGgWxj9WbZsmXnZK6+8Yl6v1XBu0xvO/65JSUmm8cLCQvM6S0pKzHOvNxxZAoABYQkABoQlABgQ\nlgBgQFgCgAFhCQAGhCUAGBCWAGBAWAKAAS8sC7Jw9ZSenm6e++OPP454ezfccIP+/vvv3s/+XpY1\nkOTkZPPce++91zRvx44d5nXOmDHD77jH4+l3x0wo/nM4efKkea71hW0DuXY/DXeddXV1I9p+qPDC\nMgBwCcISAAwISwAwICwBwICwBAADwhIADAhLADAgLAHAgLAEAAPCEgAMeGEZAtLY2GieO3bsWPPc\niooK07y77rrLvM6Ojg6/48nJyf2Web1e0zo/+eQT8/ZHjRplnltVVWWeu23btn5jr732Wr/x+vp6\n8zoxMI4sAcDAFJZNTU3Kzc3t/X/9oqIizZs3T0uWLNGSJUv0xRdfhLJGAHDckF/Dr1y5oo0bNyo7\nO7vP+Jo1a5STkxOywgAgkgx5ZBkXF6fdu3crLS0tHPUAQEQyP8/yvffe05gxY1RQUKCioiK1traq\nu7tbqampKikpGfQkfkNDgzIzM4NWNACEW0BXw5944gml
pKQoIyNDu3bt0vvvv6/169cPOD8rK8vv\nOA//DZzTD/89deqU+XcfffRR89zjx4+b5g3navjly5f9jicnJ+v333/vMxYNV8O3bt3aZ2zdunXm\ndfb09JjnhpNrH/6bnZ2tjIwMSdLs2bPV1NQUWGUA4BIBheWqVavU3Nws6Z/H0E+cODGoRQFApBny\na3hDQ4O2bt2qCxcuKDY2VlVVVSooKFBhYaESEhKUmJiozZs3h6NWAHDMkGGZmZmp/fv39xt/5JFH\nQlIQAEQibnd0qWeffdbR7Q/nZPu/p2ws4uPjTfOuXLliXmdeXp7f8RMnTvRbZn0TY0JCgnn7R44c\nMc8dzgWWTz/9tN/Ya6+91m88Ui/auA23OwKAAWEJAAaEJQAYEJYAYEBYAoABYQkABoQlABgQlgBg\nQFgCgAFhCQAG5of/jmgjA9wax/MsA+f08yydVlxcbJ67Y8cOv+OdnZ39blucN2+eaZ2vvPKKefv3\n3HOPee6aNWvMcz/44IN+Y9H435Tk4udZAsD1hrAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAw\nICwBwIAXlrlUV1eXee4PP/xgnjthwoRAygmalStXmuadOHHCvM6PPvrIvGzp0qWmdX7//ffm7b/0\n0kvmueXl5ea5CC+OLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADXlgW\nZJHY04EDB8xzFy1a5Hc8XC8su3TpkmlebKz9Tt2kpCS/4x6Pp98Lqqy3kc6aNcu8/fr6evPckYrE\nv79g4IVlAOAShCUAGBCWAGBAWAKAAWEJAAaEJQAYEJYAYEBYAoABYQkABoQlABjwdsfrQFtbm9Ml\nmN18881BX+dAb4KcOXNmv2WLFy82rXM4b8xEdDCFZWlpqerr69XT06MVK1YoKytLr776qq5evapx\n48Zp27ZtiouLC3WtAOCYIcOytrZWZ8+eldfrVXt7uxYsWKDs7Gzl5+crLy9Pb7/9tiorK5Wfnx+O\negHAEUOes5w6dap27NghSUpOTlZnZ6fq6uo0Z84cSVJOTo5qampCWyUAOGzIsIyJiVFiYqIkqbKy\nUg8++KA6Ozt7v3anpqaqtbU1tFUCgMPMF3iOHj2qyspK7d27V3Pnzu0dtzwO88yZM8rMzPS7LAyP\n0wy7aOxJ+ueZlm40c+ZM87Lz58+HupyQi9a/P6f7MoXlV199pZ07d+rjjz/W6NGjlZiYqK6uLsXH\nx6ulpUVpaWmD/n5WVpbf8Wh8UGkk9lRWVmae++KLL/odD9fDf0Ph5MmTfsdnzpyp48eP9xlz+9Xw\nSPz7CwZXPPy3o6NDpaWlKi8vV0pKiiRp+vTpqqqqkiRVV1cP66nRAOBGQx5ZHj58WO3t7SosLOwd\n27Jli9atWyev16vx48dr/vz5IS0SAJw2ZFguWrTI73tZ9u3bF5KCACAS8cKyIIvEnqZNm2aeO9Dd\nLiM5Z/nWW2+Z5x45ciSgbQzm888/9zseiftqpKKxJ8kl5ywBAIQlAJgQlgBgQFgCgAFhCQAGhCUA\nGBCWAGBAWAKAAWEJAAaEJQAYcLtjkEViT/Hx8ea5Az3O7J577tHXX3/d+3ny5MnmdT700EPmubW1\ntea5IxWJ+2qkorEnidsdAcA1CEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhLADDg\ndscgi8aepOjsi57cg9sdAcAlCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhLADAg\nLAHAgLAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADGItk0pLS1VfX6+e\nnh6tWLFCx44dU2Njo1JSUiRJy5cv18MPPxzKOgHAUUOGZW1trc6ePSuv16v29nYtWLBA06ZN05o1\na5STkxOOGgHAcUOG5dSpUzVlyhRJUnJysjo7O3X16tWQFwYAkcTjG+yt4tfwer06ffq0YmJi1Nra\nqu7ubqWmpqqkpERjx44deCMDvBw9Gl8IH409SdHZFz25R7j6GiwOzWF59OhRlZeXa+/evWpoaFBK\nSooyMjK0a9cu/fzzz1q/fv2Av9vQ0KDMzMzhVw4AkcJn8OWXX/qefPJJX3t7e79lZ8+e9S1evHjQ\n35fk92ewZW79icaeorUvenLPT7j6GsyQ/3Soo6NDpaWlKi8v7736vWrVKjU3N0uS6urqNHHixKFW\nAwCuNuQFnsOHD6u9vV2FhYW9YwsXLlRhYaESEhKUmJiozZs3h7RIAHDasC7wBLwRLvC4XjT2RU/u\nEa6+BotD7uABAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADAhLADAgLAHA\ngLAEAAPCEgAMCEsAMCAsAcCAsAQAA8ISAAwISwAwICwBwICwBAADwhIADMLyKlwAcDuOLAHAgLAE\nAAPCEgAMCEsAMCAsAcCAsAQAg1gnNrpp0yZ9++238ng8Ki4u1pQpU5woI6jq6uq0evVqTZw4UZI0\nadIklZSUOFxV4JqamvTCCy9o2bJlKigo0E8//aRXX31VV69e1bhx47Rt2zbFxcU5XeawXNtTUVGR\nGhsblZKSIklavny5Hn74YWeLHKbS0lLV19erp6dHK1asUFZWluv3k9S/r2PHjjm+r8IelqdOndL5\n8+fl9Xr13Xffqbi4WF6vN9xlhMT999+vsrIyp8sYsStXrmjjxo3Kzs7uHSsrK1N+fr7y8vL09ttv\nq7KyUvn5+Q5WOTz+epKkNWvWKCcnx6GqRqa2tlZnz56V1+tVe3u7FixYoOzsbFfvJ8l/X9OmTXN8\nX4X9a3hNTY1yc3MlSXfeeacuXbqky5cvh7sMDCIuLk67d+9WWlpa71hdXZ3mzJkjScrJyVFNTY1T\n5QXEX09uN3XqVO3YsUOSlJycrM7OTtfvJ8l/X1evXnW4KgfCsq2tTWPGjOn9PHbsWLW2toa7jJA4\nd+6cnnvuOT3zzDM6ceKE0+UELDY2VvHx8X3GOjs7e7/Opaamum6f+etJkioqKrR06VK9/PLL+vXX\nXx2oLHAxMTFKTEyUJFVWVurBBx90/X6S/PcVExPj+L5y5Jzlf0XL3Za33367Vq5cqby8PDU3N2vp\n0qWqrq525fmioUTLPnviiSeUkpKijIwM7dq1S++//77Wr1/vdFnDdvToUVVWVmrv3r2aO3du77jb\n99N/+2poaHB8X4X9yDItLU1tbW29n3/55ReNGzcu3GUEXXp6uv73v//J4/FowoQJuuWWW9TS0uJ0\nWUGTmJiorq4uSVJLS0tUfJ3Nzs5WRkaGJGn27NlqampyuKLh++qrr7Rz507t3r1bo0ePjpr9dG1f\nkbCvwh6WM2bMUFVVlSSpsbFRaWlpSkpKCncZQXfo0CHt2bNHktTa2qqLFy8qPT3d4aqCZ/r06b37\nrbq6WrNmzXK
4opFbtWqVmpubJf1zTvbff8ngFh0dHSotLVV5eXnvVeJo2E/++oqEfeXIU4e2b9+u\n06dPy+PxaMOGDbr77rvDXULQXb58WWvXrtXvv/+u7u5urVy5Ug899JDTZQWkoaFBW7du1YULFxQb\nG6v09HRt375dRUVF+vPPPzV+/Hht3rxZo0aNcrpUM389FRQUaNeuXUpISFBiYqI2b96s1NRUp0s1\n83q9eu+993THHXf0jm3ZskXr1q1z7X6S/Pe1cOFCVVRUOLqveEQbABhwBw8AGBCWAGBAWAKAAWEJ\nAAaEJQAYEJYAYEBYAoABYQkABv8HkbgWVGnLsmMAAAAASUVORK5CYII=\n", "text/plain": [ - "" + "u'/content/t2t/checkpoints/transformer_ende_test/model.ckpt-350855'" ] }, "metadata": { "tags": [] - } + }, + "execution_count": 11 } ] }, { "metadata": { - "id": "WkFUEs7ZOA79", + "id": "3O-8E9d6TtuJ", "colab_type": "code", "colab": { "autoexec": { @@ -652,18 +657,18 @@ }, "output_extras": [ { - "item_id": 1 + "item_id": 3 } ], "base_uri": "https://localhost:8080/", - "height": 408 + "height": 119 }, - "outputId": "3d0c50f2-9c18-4d4b-8455-1aabe9e28190", + "outputId": "306d8df1-70c4-43f5-fc15-54ff66ec58ed", "executionInfo": { "status": "ok", - "timestamp": 1512165775887, + "timestamp": 1512174026517, "user_tz": 480, - "elapsed": 242, + "elapsed": 11277, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -672,48 +677,47 @@ } }, "source": [ - "# Lots of models available\n", - "registry.list_models()" + "# Restore and translate!\n", + "\n", + "def translate(inputs):\n", + " encoded_inputs = encode(inputs)\n", + " with tfe.restore_variables_on_create(ckpt_path):\n", + " model_output = translate_model.infer(encoded_inputs)\n", + " return decode(model_output)\n", + "\n", + "inputs = \"This is a cat.\"\n", + "outputs = translate(inputs)\n", + "\n", + "print(\"Inputs: %s\" % inputs)\n", + "print(\"Outputs: %s\" % outputs)" ], "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "outputs": [ { - "output_type": "execute_result", - "data": { - "text/plain": [ - "['resnet50',\n", - " 'lstm_seq2seq',\n", - " 'transformer_encoder',\n", - " 'attention_lm',\n", - " 'vanilla_gan',\n", - " 'transformer',\n", - " 'gene_expression_conv',\n", - " 'transformer_moe',\n", - " 'attention_lm_moe',\n", - " 'transformer_revnet',\n", - " 'lstm_seq2seq_attention',\n", - " 'shake_shake',\n", - " 'transformer_ae',\n", - " 'diagonal_neural_gpu',\n", - " 'xception',\n", - " 'aligned',\n", - " 'multi_model',\n", - " 'neural_gpu',\n", - " 'slice_net',\n", - " 'byte_net',\n", - " 'cycle_gan',\n", - " 'transformer_sketch',\n", - " 'blue_net']" - ] - }, - "metadata": { - "tags": [] - }, - "execution_count": 13 + "output_type": "stream", + "text": [ + "INFO:tensorflow:Greedy Decoding\n", + "WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensor2tensor/layers/common_layers.py:487: calling reduce_mean (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n", + "Instructions for updating:\n", + "keep_dims is deprecated, use keepdims instead\n", + "Inputs: This is a cat.\n", + "Outputs: Das ist eine Katze.\n" + ], + "name": "stdout" } ] }, + { + "metadata": { + "id": "i7BZuO7T5BB4", + "colab_type": "text" + }, + "source": [ + "# Train a custom model on MNIST" + ], + "cell_type": "markdown" + }, { "metadata": { "id": "-H25oG91YQj3", @@ -751,7 +755,7 @@ }, { "metadata": { - "id": "AWVd2I7PYz6H", + "id": "7GEmpYQ2ZMnB", "colab_type": "code", "colab": { "autoexec": { @@ -760,18 +764,18 @@ }, "output_extras": [ { - "item_id": 12 + "item_id": 1 } ], "base_uri": "https://localhost:8080/", - "height": 357 + "height": 34 }, - "outputId": 
"19abcffa-6a56-4633-90c1-71a59a104ace", + "outputId": "9535b122-d663-470b-fb03-15541769a8d6", "executionInfo": { "status": "ok", - "timestamp": 1512165882231, + "timestamp": 1512174027233, "user_tz": 480, - "elapsed": 105926, + "elapsed": 372, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -780,27 +784,72 @@ } }, "source": [ - "# Train\n", + "# Prepare for the training loop\n", "\n", - "# In Eager mode, opt.minimize must be passed a function that produces the loss\n", - "def loss_function(features):\n", + "# In Eager mode, opt.minimize must be passed a loss function wrapped with\n", + "# implicit_value_and_gradients\n", + "@tfe.implicit_value_and_gradients\n", + "def loss_fn(features):\n", " _, losses = model(features)\n", " return losses[\"training\"]\n", "\n", - "tfe_loss_fn = tfe.implicit_value_and_gradients(loss_function)\n", - "optimizer = tf.train.AdamOptimizer()\n", - "\n", - "NUM_STEPS = 500\n", + "# Setup the training data\n", "BATCH_SIZE = 128\n", - "\n", - "# Repeat and batch the data\n", "mnist_train_dataset = mnist_problem.dataset(Modes.TRAIN, data_dir)\n", "mnist_train_dataset = mnist_train_dataset.repeat(None).batch(BATCH_SIZE)\n", "\n", - "# Training loop\n", + "optimizer = tf.train.AdamOptimizer()" + ], + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "text": [ + "INFO:tensorflow:Reading data files from /content/t2t/data/image_mnist-train*\n" + ], + "name": "stdout" + } + ] + }, + { + "metadata": { + "id": "AWVd2I7PYz6H", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "output_extras": [ + { + "item_id": 11 + } + ], + "base_uri": "https://localhost:8080/", + "height": 340 + }, + "outputId": "adfe2262-ca2a-4d74-ef6f-4caaf5531824", + "executionInfo": { + "status": "ok", + "timestamp": 1512174129153, + "user_tz": 480, + "elapsed": 101898, + "user": { + "displayName": "Ryan Sepassi", + "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", + "userId": "107877449274830904926" + } + } + }, + "source": [ + "# Train\n", + "\n", + "NUM_STEPS = 500\n", + "\n", "for count, example in enumerate(tfe.Iterator(mnist_train_dataset)):\n", " example[\"targets\"] = tf.reshape(example[\"targets\"], [BATCH_SIZE, 1, 1, 1]) # Make it 4D.\n", - " loss, gv = tfe_loss_fn(example)\n", + " loss, gv = loss_fn(example)\n", " optimizer.apply_gradients(gv)\n", "\n", " if count % 50 == 0:\n", @@ -814,7 +863,6 @@ { "output_type": "stream", "text": [ - "INFO:tensorflow:Reading data files from /content/t2t/data/image_mnist-train*\n", "WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensor2tensor/layers/common_layers.py:1671: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "\n", @@ -823,22 +871,71 @@ "\n", "See tf.nn.softmax_cross_entropy_with_logits_v2.\n", "\n", - "Step: 0, Loss: 5.430\n", - "Step: 50, Loss: 0.833\n", - "Step: 100, Loss: 0.722\n", - "Step: 150, Loss: 0.529\n", - "Step: 200, Loss: 0.349\n", - "Step: 250, Loss: 0.293\n", - "Step: 300, Loss: 0.303\n", - "Step: 350, Loss: 0.295\n", - "Step: 400, Loss: 0.275\n", - "Step: 450, Loss: 0.290\n", - "Step: 500, Loss: 0.334\n" + "Step: 0, Loss: 5.357\n", + "Step: 50, Loss: 0.746\n", + "Step: 100, Loss: 0.618\n", + "Step: 150, Loss: 0.502\n", + "Step: 200, Loss: 
0.395\n", + "Step: 250, Loss: 0.345\n", + "Step: 300, Loss: 0.338\n", + "Step: 350, Loss: 0.175\n", + "Step: 400, Loss: 0.345\n", + "Step: 450, Loss: 0.373\n", + "Step: 500, Loss: 0.292\n" ], "name": "stdout" } ] }, + { + "metadata": { + "id": "a2cL8UwLaSYG", + "colab_type": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + } + }, + "source": [ + "# This will eventually be available at\n", + "# tensor2tensor.metrics.create_eager_metrics\n", + "def create_eager_metrics(metric_names):\n", + " \"\"\"Create metrics accumulators and averager for Eager mode.\n", + "\n", + " Args:\n", + " metric_names: list from tensor2tensor.metrics.Metrics\n", + "\n", + " Returns:\n", + " (accum_fn(predictions, targets) => None,\n", + " result_fn() => dict\n", + " \"\"\"\n", + " metric_fns = dict(\n", + " [(name, metrics.METRICS_FNS[name]) for name in metric_names])\n", + " tfe_metrics = dict()\n", + "\n", + " for name in metric_names:\n", + " tfe_metrics[name] = tfe.metrics.Mean(name=name)\n", + "\n", + " def metric_accum(predictions, targets):\n", + " for name, metric_fn in metric_fns.items():\n", + " val, weight = metric_fn(predictions, targets,\n", + " weights_fn=common_layers.weights_all)\n", + " tfe_metrics[name](np.squeeze(val), np.squeeze(weight))\n", + "\n", + " def metric_means():\n", + " avgs = {}\n", + " for name in metric_names:\n", + " avgs[name] = tfe_metrics[name].result().numpy()\n", + " return avgs\n", + "\n", + " return metric_accum, metric_means" + ], + "cell_type": "code", + "execution_count": 0, + "outputs": [] + }, { "metadata": { "id": "CIFlkiVOd8jO", @@ -854,14 +951,14 @@ } ], "base_uri": "https://localhost:8080/", - "height": 51 + "height": 68 }, - "outputId": "70b92db9-9ec0-466c-e5c2-c5a39f13447d", + "outputId": "95ec4064-d884-4ea8-acdf-ffe83dc0c230", "executionInfo": { "status": "ok", - "timestamp": 1512165950748, + "timestamp": 1512174132643, "user_tz": 480, - "elapsed": 2772, + "elapsed": 3097, "user": { "displayName": "Ryan Sepassi", "photoUrl": "//lh4.googleusercontent.com/-dcHmhQy1Y2A/AAAAAAAAAAI/AAAAAAAABEw/if_k14yF4KI/s50-c-k-no/photo.jpg", @@ -872,25 +969,29 @@ "source": [ "model.set_mode(Modes.EVAL)\n", "mnist_eval_dataset = mnist_problem.dataset(Modes.EVAL, data_dir)\n", - "all_perplexities = []\n", - "all_accuracies = []\n", + "\n", + "# Create eval metric accumulators for accuracy (ACC) and accuracy in\n", + "# top 5 (ACC_TOP5)\n", + "metrics_accum, metrics_result = create_eager_metrics(\n", + " [metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5])\n", + "\n", "for count, example in enumerate(tfe.Iterator(mnist_eval_dataset)):\n", - " if count >= 100:\n", + " if count >= 200:\n", " break\n", "\n", - " batch_inputs = tf.reshape(example[\"inputs\"], [1, 28, 28, 3]) # Make it 4D.\n", - " batch_targets = tf.reshape(example[\"targets\"], [1, 1, 1, 1]) # Make it 4D.\n", - " features = {\"inputs\": batch_inputs, \"targets\": batch_targets}\n", + " # Make the inputs and targets 4D\n", + " example[\"inputs\"] = tf.reshape(example[\"inputs\"], [1, 28, 28, 3])\n", + " example[\"targets\"] = tf.reshape(example[\"targets\"], [1, 1, 1, 1])\n", "\n", - " # Call the model.\n", - " predictions, _ = model(features)\n", + " # Call the model\n", + " predictions, _ = model(example)\n", "\n", - " # Calculate and append the metrics\n", - " all_perplexities.extend(metrics.padded_neg_log_perplexity(predictions, features[\"targets\"]))\n", - " all_accuracies.extend(metrics.padded_accuracy(predictions, features[\"targets\"]))\n", + " # Compute and accumulate metrics\n", 
+ " metrics_accum(predictions, example[\"targets\"])\n", "\n", - "# Print out metrics on the dataset\n", - "print(\"Accuracy: %.2f\" % tf.reduce_mean(tf.concat(all_accuracies, axis=1)).numpy())" + "# Print out the averaged metric values on the eval data\n", + "for name, val in metrics_result().items():\n", + " print(\"%s: %.2f\" % (name, val))" ], "cell_type": "code", "execution_count": 17, @@ -899,7 +1000,8 @@ "output_type": "stream", "text": [ "INFO:tensorflow:Reading data files from /content/t2t/data/image_mnist-dev*\n", - "Accuracy: 0.98\n" + "accuracy_top5: 1.00\n", + "accuracy: 0.98\n" ], "name": "stdout" }