diff --git a/.travis.yml b/.travis.yml index 46373f829..370682401 100644 --- a/.travis.yml +++ b/.travis.yml @@ -24,6 +24,6 @@ script: - mkdir $T2T_TRAIN_DIR - t2t-datagen --problem=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR - t2t-trainer --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --train_steps=5 --eval_steps=5 --output_dir=$T2T_TRAIN_DIR - - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR + - t2t-decoder --problems=$T2T_PROBLEM --data_dir=$T2T_DATA_DIR --model=transformer --hparams_set=transformer_tiny --output_dir=$T2T_TRAIN_DIR --decode_hparams='num_samples=10' git: depth: 3 diff --git a/setup.py b/setup.py index d097b91d6..5b6f4690e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.2.4', + version='1.2.5', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py index 8fa1e52d0..09c1645a1 100644 --- a/tensor2tensor/data_generators/cnn_dailymail.py +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -74,7 +74,7 @@ def story_generator(tmp_dir): for path in paths: for story_file in tf.gfile.Glob(path + "*"): story = u"" - for line in tf.gfile.Open(story_file, 'rb'): + for line in tf.gfile.Open(story_file, "rb"): line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8") story += line yield story diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index acd121868..c8fe03564 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -355,6 +355,8 @@ def generate(): for lang_file in source[1]: tf.logging.info("Reading file: %s" % lang_file) filepath = os.path.join(tmp_dir, lang_file) + + # Extract from tar if needed. if not tf.gfile.Exists(filepath): read_type = "r:gz" if filename.endswith("tgz") else "r" with tarfile.open(compressed_file, read_type) as corpus_tar: @@ -411,7 +413,7 @@ def generate(): for line in source_file: line = line.strip() if line and "\t" in line: - parts = line.split("\t", maxsplit=1) + parts = line.split("\t", 1) part = parts[index].strip() yield part diff --git a/tensor2tensor/data_generators/image.py b/tensor2tensor/data_generators/image.py index 5b41c4e19..df497019a 100644 --- a/tensor2tensor/data_generators/image.py +++ b/tensor2tensor/data_generators/image.py @@ -42,6 +42,12 @@ import tensorflow as tf +def resize_by_area(img, size): + """image resize function used by quite a few image problems.""" + return tf.to_int64( + tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) + + class ImageProblem(problem.Problem): def example_reading_spec(self, label_key=None): @@ -93,16 +99,12 @@ class ImageCeleba(ImageProblem): def preprocess_example(self, example, unused_mode, unused_hparams): - def resize(img, size): - return tf.to_int64( - tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) - inputs = example["inputs"] # Remove boundaries in CelebA images. Remove 40 pixels each side # vertically and 20 pixels each side horizontally. 
inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40) - example["inputs"] = resize(inputs, 8) - example["targets"] = resize(inputs, 32) + example["inputs"] = resize_by_area(inputs, 8) + example["targets"] = resize_by_area(inputs, 32) return example def hparams(self, defaults, unused_model_hparams): @@ -388,14 +390,10 @@ def dataset_filename(self): def preprocess_example(self, example, unused_mode, unused_hparams): - def resize(img, size): - return tf.to_int64( - tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA)) - inputs = example["inputs"] # For Img2Img resize input and output images as desired. - example["inputs"] = resize(inputs, 8) - example["targets"] = resize(inputs, 32) + example["inputs"] = resize_by_area(inputs, 8) + example["targets"] = resize_by_area(inputs, 32) return example def hparams(self, defaults, unused_model_hparams): @@ -654,6 +652,43 @@ def preprocess_example(self, example, mode, unused_hparams): return example +@registry.register_problem +class ImageCifar10Plain8(ImageCifar10): + """CIFAR-10 rescaled to 8x8 for output: Conditional image generation.""" + + def dataset_filename(self): + return "image_cifar10_plain" # Reuse CIFAR-10 plain data. + + def preprocess_example(self, example, mode, unused_hparams): + example["inputs"] = resize_by_area(example["inputs"], 8) + return example + + +@registry.register_problem +class Img2imgCifar10(ImageCifar10): + """CIFAR-10 rescaled to 8x8 for input and 32x32 for output.""" + + def dataset_filename(self): + return "image_cifar10_plain" # Reuse CIFAR-10 plain data. + + def preprocess_example(self, example, unused_mode, unused_hparams): + + inputs = example["inputs"] + # For Img2Img resize input and output images as desired. + example["inputs"] = resize_by_area(inputs, 8) + example["targets"] = resize_by_area(inputs, 32) + return example + + def hparams(self, defaults, unused_model_hparams): + p = defaults + p.input_modality = {"inputs": ("image:identity_no_pad", None)} + p.target_modality = ("image:identity_no_pad", None) + p.batch_size_multiplier = 256 + p.max_expected_batch_size_per_shard = 4 + p.input_space_id = 1 + p.target_space_id = 1 + + # URLs and filenames for MSCOCO data. 
_MSCOCO_ROOT_URL = "http://msvocds.blob.core.windows.net/" _MSCOCO_URLS = [ diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index f1b2b7dee..61716d012 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -19,9 +19,7 @@ from __future__ import division from __future__ import print_function -import glob import os -import stat import tarfile # Dependency imports @@ -115,7 +113,7 @@ def tabbed_generator(source_path, source_vocab, target_vocab, eos=None): with tf.gfile.GFile(source_path, mode="r") as source_file: for line in source_file: if line and "\t" in line: - parts = line.split("\t", maxsplit=1) + parts = line.split("\t", 1) source, target = parts[0].strip(), parts[1].strip() source_ints = source_vocab.encode(source) + eos_list target_ints = target_vocab.encode(target) + eos_list @@ -267,8 +265,9 @@ def bi_vocabs_token_generator(source_path, # English-Czech datasets _ENCS_TRAIN_DATASETS = [ [ - "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-1458/data-plaintext-format.tar", - ('tsv', 3, 2, 'data.plaintext-format/*train.gz') + ("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/" + "11234/1-1458/data-plaintext-format.tar"), + ("tsv", 3, 2, "data.plaintext-format/*train.gz") ], [ "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz", # pylint: disable=line-too-long @@ -375,25 +374,22 @@ def _compile_data(tmp_dir, datasets, filename): url = dataset[0] compressed_filename = os.path.basename(url) compressed_filepath = os.path.join(tmp_dir, compressed_filename) + generator_utils.maybe_download(tmp_dir, compressed_filename, url) - if dataset[1][0] == 'tsv': + if dataset[1][0] == "tsv": _, src_column, trg_column, glob_pattern = dataset[1] - filenames = glob.glob(os.path.join(tmp_dir, glob_pattern)) + filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern)) if not filenames: - mode = "r:gz" if compressed_filepath.endswith("gz") else "r" # *.tgz *.tar.gz + # Capture *.tgz and *.tar.gz too. + mode = "r:gz" if compressed_filepath.endswith("gz") else "r" with tarfile.open(compressed_filepath, mode) as corpus_tar: corpus_tar.extractall(tmp_dir) - filenames = glob.glob(os.path.join(tmp_dir, glob_pattern)) + filenames = tf.gfile.Glob(os.path.join(tmp_dir, glob_pattern)) for tsv_filename in filenames: if tsv_filename.endswith(".gz"): new_filename = tsv_filename.strip(".gz") - try: - generator_utils.gunzip_file(tsv_filename, new_filename) - except PermissionError: - tsvdir = os.path.dirname(tsv_filename) - os.chmod(tsvdir, os.stat(tsvdir).st_mode | stat.S_IWRITE) - generator_utils.gunzip_file(tsv_filename, new_filename) + generator_utils.gunzip_file(tsv_filename, new_filename) tsv_filename = new_filename with tf.gfile.GFile(tsv_filename, mode="r") as tsv_file: for line in tsv_file: @@ -663,17 +659,19 @@ def vocab_name(self): def generator(self, data_dir, tmp_dir, train): datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS tag = "train" if train else "dev" - data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag) vocab_datasets = [] + data_path = _compile_data(tmp_dir, datasets, "wmt_encs_tok_%s" % tag) # CzEng contains 100 gz files with tab-separated columns, so let's expect - # it is the first dataset in datasets and use the newly created *.lang{1,2} files instead. + # it is the first dataset in datasets and use the newly created *.lang{1,2} + # files for vocab construction. 
if datasets[0][0].endswith("data-plaintext-format.tar"): - vocab_datasets.append([datasets[0][0], - ["wmt_encs_tok_%s.lang1" % tag, "wmt_encs_tok_%s.lang2" % tag]]) + vocab_datasets.append([datasets[0][0], ["wmt_encs_tok_%s.lang1" % tag, + "wmt_encs_tok_%s.lang2" % tag]]) datasets = datasets[1:] vocab_datasets += [[item[0], [item[1][0], item[1][1]]] for item in datasets] symbolizer_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, vocab_datasets) + data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size, + vocab_datasets) return token_generator(data_path + ".lang1", data_path + ".lang2", symbolizer_vocab, EOS) diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py index 33ce7d4a9..792241632 100644 --- a/tensor2tensor/layers/common_attention.py +++ b/tensor2tensor/layers/common_attention.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import collections import functools import math @@ -36,6 +37,11 @@ from tensorflow.python.framework import function +# Struct containing the sequence ids and order of a batch (sent to the +# experts to allow them to compute the bias mask) +BatchInfo = collections.namedtuple( + "BatchInfo", "coordinates, order") + _expert_count = 0 @@ -239,6 +245,86 @@ def add_positional_embedding_nd(x, max_length, name): return x +class LshGating(object): + """Class to split key/queries into separate buckets.""" + + def __init__(self, depth, nb_hyperplanes, nb_replicat=1, trainable=False): + """Construct the gating function parameters. + + Compute the gates for a single head. + + Args: + depth (int): Dimension of the key/queries to dispatch + nb_hyperplanes (int): Nb of vectors used to split the space. Will determine + the number of buckets (2^nb_hyperplanes). + nb_replicat (int): Redundancy to avoid the edge cases (to be in one bucket + the input should be in a majority) + trainable (bool): If True, a balance loss is added to force the hyperplane + to divide the key/query space evenly + """ + self.depth = depth + self.nb_hyperplanes = nb_hyperplanes + self.nb_buckets = 2**nb_hyperplanes + self.nb_replicat = nb_replicat # Unused for now + self.trainable = trainable # Unused for now + + self.dispatchers = {} + + assert self.nb_replicat == 1 # For now + + with tf.variable_scope("lsh_gating"): + # Vectors defining the hyperplanes + self.t_vectors = tf.get_variable( + "vector", + shape=(self.depth, self.nb_hyperplanes * self.nb_replicat), + dtype=tf.float32, + trainable=self.trainable, + ) + # Projection vector from the bit space to similarity score space + self.t_group = tf.constant([ + self._idx_to_bits(i) + for i in range(self.nb_buckets) + ], dtype=tf.float32, name="group") + + def _idx_to_bits(self, i): + """Convert a group index to its bit representation.""" + bits = bin(i)[2:].zfill(self.nb_hyperplanes) # Pad the bits str with 0 + return [-1.0 if b == "0" else 1.0 for b in bits] + + @expert_utils.add_name_scope("lsh_gating") + def get_gates(self, x): + """Return the bucket id of the given tensor.
+ + Args: + x (tf.Tensor): float32 of shape [length, depth] + + Returns: + tf.Tensor: One-hot vector int64 of shape [heads, length, nb_buckets] + containing the id of the bucket + """ + + # The balance loss doesn't propagate to the rest of the network + x = tf.stop_gradient(x) + # [length, depth] * [depth, nb_vectors * replicat] + x = tf.matmul(x, self.t_vectors) + # [length, nb_vector * replicat] + x = tf.sign(x) # Get on which side of the hyperplane the keys are. + + # x = tf.reshape(x, [-1, nb_replicat, nb_vector]) + # [length, replicat, nb_vector] * [nb_vector, 2^nb_vector - 1] + + x = tf.matmul(x, self.t_group, transpose_b=True) / self.nb_hyperplanes + # We get a similarity score for each of the group between [-1, 1] + # [length, (replicat,) 2^nb_vector - 1] + # Do an argmax to get the most likely group for each replicat + x = tf.argmax(x, axis=-1) + # [length(, replicat)] + # One-hot for compatibility with the sparse dispatcher + x = tf.one_hot(x, self.nb_buckets) + # TODO(epot): Use a loss to force an even distribution + return x + + def embedding_to_padding(emb): """Calculates the padding mask based on which embeddings are all zero. @@ -368,29 +454,59 @@ def attention_bias_proximal(length): @expert_utils.add_name_scope() -def attention_bias_coordinates(batch_coordinate): +def attention_bias_batch( + batch_coordinates_q, + batch_coordinates_k=None, + condition_fn=None, +): """Generate a mask to prevent the batch to attend to each others. Args: - batch_coordinate (tf.Tensor): int32 of shape [length, 1] containing the + batch_coordinates_q (tf.Tensor): int32 of shape [length_q, 1] containing the coordinates of the batches + batch_coordinates_k (tf.Tensor): int32 of shape [length_k, 1] containing the + coordinates of the batches. If None, do self attention (q and k identical) + condition_fn (fct): A predicate function defining which type of mask to build Returns: - tf.Tensor: float32 mask of shape [length, length] containing either 0 or + tf.Tensor: float32 mask of shape [length_q, length_k] containing either 0 or -infinity (-1e9) """ - batch_coord_float = tf.squeeze(batch_coordinate, 1) + if batch_coordinates_k is None: + batch_coordinates_k = batch_coordinates_q + # Convert to float first because of b/25387198 - batch_coord_float = tf.to_float(batch_coord_float) - bc_v = tf.expand_dims(batch_coord_float, 1) - bc_h = tf.expand_dims(batch_coord_float, 0) - bias_batch = bc_v - bc_h # Broadcast to create [length, length] mask + def to_float(bc): + bc = tf.squeeze(bc, 1) + bc = tf.to_float(bc) + return bc + + bc_v = tf.expand_dims(to_float(batch_coordinates_q), 1) + bc_h = tf.expand_dims(to_float(batch_coordinates_k), 0) + bias_batch = bc_h - bc_v # Broadcast to create [length_q, length_k] mask # Theshold non zeros to 1.0 - bias_batch = tf.minimum(1.0, tf.abs(bias_batch)) + bias_batch = condition_fn(bias_batch) bias_batch *= -1e9 # Set non zeros to -infinity return bias_batch +# Mask to prevent individual sequences of the same batch from attending to each other +attention_bias_coordinates = functools.partial( + attention_bias_batch, + condition_fn=lambda bias: tf.minimum(1.0, tf.abs(bias)), +) + + +# Mask similar to upper triangular mask, but allow dispatching +attention_bias_future = functools.partial( + attention_bias_batch, + # Elems can attend to themselves (otherwise would use bias_batch + 1.0) + # No tf.abs to consider the order + # tf.maximum and tf.minimum to threshold the values + condition_fn=lambda bias: tf.maximum(0.0, tf.minimum(1.0, bias)), +) + + def split_last_dimension(x, n):
"""Reshape x so that the last dimension becomes two dimensions. @@ -539,76 +655,6 @@ def attention_image_summary(attn, image_shapes=None): tf.summary.image("attention", image, max_outputs=1) -def grouped_attention_single(num_groups, q, kv, q_gates, m_gates): - """Compute grouped attention for one batch and one head. - - q is a Tensor of queries, and kv is Tensor of keys and values - (concatenated in dimension 1). - - q_gates and m_gates are float32 Tensors containing zeros and ones. - The ones indicate which positions belong to which groups. A - key-value pair can be in zero or more groups. Each query is in one - group. A query can only pay attention to key-value pairs which are - in its group. - - In addition to the usual output, we return two additional Tensors: - q_total and m_total. - - For query position i belonging to group g, q_total[i, g] contains - log(sum(exp(q_i dot k_j))) for all keys k_j in group g. - - For memory position j belonging to group g, m_total[j, g] contains - the sum of the attention weights over all queries and that memory position. - - q_total and m_total contain zeros in positions where the - corresponding query/memory does not belong to the corresponding - group. - - Args: - num_groups: an integer - q: Tensor with shape [length_q, depth_qk] - kv: Tensor with shape [length_kv, depth_qk + depth_v] - q_gates: Tensor with shape [length_q, num_groups] - m_gates: Tensor with shape [length_kv, num_groups] - - Returns: - o: Tensor with shape [length_q, depth_v] - q_total: Tensor with shape [length_q, num_groups] - m_total: Tensor with shape [length_kv, num_groups] - """ - q_dispatcher = expert_utils.SparseDispatcher(num_groups, q_gates) - m_dispatcher = expert_utils.SparseDispatcher(num_groups, m_gates) - q_length_coordinate = q_dispatcher.expert_to_batch_indices() - m_length_coordinate = m_dispatcher.expert_to_batch_indices() - dispatched_q = q_dispatcher.dispatch(q) - dispatched_kv = m_dispatcher.dispatch(kv) - length_q = tf.shape(q)[0] - length_kv = tf.shape(kv)[0] - depth_qk = tf.shape(q)[1] - depth_v = tf.shape(kv)[1] - depth_qk - o = [] - q_totals = [] - m_totals = [] - for e in xrange(num_groups): - k, v = tf.split(dispatched_kv[e], [depth_qk, depth_v], axis=1) - logits = tf.matmul(dispatched_q[e], k, transpose_b=True) - log_weights = tf.nn.log_softmax(logits) - weights = tf.exp(log_weights) - o.append(tf.matmul(weights, v)) - # For each query, this is the log of the sum of the unnormalized weights. - q_total = tf.reshape(logits[:, :1] - log_weights[:, :1], [-1]) - q_totals.append(tf.unsorted_segment_sum( - q_total, q_length_coordinate[e], length_q)) - epsilon = 1e-3 - m_total = tf.log(tf.reduce_sum(tf.stop_gradient(weights), axis=0) + epsilon) - m_totals.append( - tf.unsorted_segment_sum(m_total, m_length_coordinate[e], length_kv)) - o = q_dispatcher.combine(o, multiply_by_gates=False) - q_total = tf.stack(q_totals, axis=1) - m_total = tf.stack(m_totals, axis=1) - return o, q_total, m_total - - def grouped_attention_multihead(query_antecedent, memory_antecedent, total_key_depth, @@ -616,10 +662,31 @@ def grouped_attention_multihead(query_antecedent, output_depth, num_heads, num_groups, - threshold=0.3, - name=None, - make_image_summary=True): - """Dot-product attention with sparsity. + memory_target_density=2.0, + multiplicative_overhead=1.25, + additive_overhead=8.0, + mask_right=False, + make_image_summary=True, + name=None): + """Multi-head dot-product attention with sparsity. + + For each attention head, the queries are partitioned into groups. 
+ For each group, only a subset of the key-value pairs are considered. + + The choices of groups are selected based on trained predictors of + the total attention given the group inclusion. + + memory_target_density indicates, on average, how many groups + a key-value pair should participate in. + + We use auxiliary losses to ensure that each group contains roughly + the same number of queries and the same number of key-value pairs. + If, for a given sequence, the actual number of queries/pairs sent to + an expert exceeds this target by a factor of more than + multiplicative_overhead, then the last ones are dropped. We use + this drop-last policy to avoid bleeding information backwards, which + is necessary when using this function with autoregressive + prediction. Args: query_antecedent: a Tensor with shape [batch, length_q, channels] @@ -629,9 +696,12 @@ def grouped_attention_multihead(query_antecedent, output_depth: an integer num_heads: an integer dividing total_key_depth and total_value_depth num_groups: an integer - threshold: a floating point number - name: an optional string + memory_target_density: a floating point scalar + multiplicative_overhead: a floating point scalar + additive_overhead: a floating point scalar + mask_right: a boolean make_image_summary: a boolean + name: an optional string Returns: A Tensor with shape [batch, length_q, output_depth] @@ -667,13 +737,18 @@ def grouped_attention_multihead(query_antecedent, # These are used to determine group inclusion. # We will train these by auxiliary losses. We use stop_gradient here # to keep these losses from back-propagating to the rest of the model. + # We add biases that help balance the usage of the experts. q_pred = common_layers.conv1d( tf.stop_gradient(query_antecedent), num_heads * num_groups, 1, name="q_pred") q_pred = split_heads(q_pred, num_heads) + q_bias = tf.get_variable("q_bias", [1, num_heads, 1, num_groups]) + q_pred_biased = q_pred + q_bias m_pred = common_layers.conv1d(tf.stop_gradient( memory_antecedent), num_heads * num_groups, 1, name="m_pred") m_pred = split_heads(m_pred, num_heads) + m_bias = tf.get_variable("m_bias", [1, num_heads, 1, num_groups]) + m_pred_biased = m_pred + m_bias q *= depth_qk**-0.5 # q, kv, q_pred, m_pred are all [batch, heads, length_[q/m], ?] # now reshape them all to [batch * heads, length, ?] @@ -681,41 +756,98 @@ def grouped_attention_multihead(query_antecedent, kv = combine_first_two_dimensions(kv) q_pred = combine_first_two_dimensions(q_pred) m_pred = combine_first_two_dimensions(m_pred) - q_group = tf.argmax(q_pred, axis=2) - q_gates = tf.one_hot(q_group, num_groups, axis=-1) - m_gates = tf.to_float(tf.greater(m_pred, math.log(threshold))) - # include first memory position in all groups, to avoid zero-sized tensors. - # TODO(noam): do we need to do this for queries too?
- m_gates = tf.maximum( - m_gates, tf.reshape(tf.one_hot([0], length_kv), [1, length_kv, 1])) - q_group_size = tf.reduce_sum(q_gates, 1) - m_group_size = tf.reduce_sum(m_gates, 1) - - # compute the output - o, q_total, m_total = tf.map_fn( - lambda args: grouped_attention_single(num_groups, *args), - (q, kv, q_gates, m_gates), - dtype=(tf.float32, tf.float32, tf.float32), - parallel_iterations=1) - - # compute auxiliary losses to train the predictions - q_loss = tf.nn.l2_loss((q_total - q_pred) * q_gates) + q_pred_biased = combine_first_two_dimensions(q_pred_biased) + m_pred_biased = combine_first_two_dimensions(m_pred_biased) + q_group = tf.argmax(q_pred_biased, axis=2) + q_requests = tf.one_hot(q_group, num_groups, axis=-1) + m_requests = tf.to_float(tf.greater(m_pred_biased, 0.0)) + # include first memory position in all groups, to avoid division by zero. + m_requests = tf.maximum( + m_requests, tf.reshape(tf.one_hot([0], length_kv), [1, length_kv, 1])) + q_group_size = tf.reduce_sum(q_requests, 1) + m_group_size = tf.reduce_sum(m_requests, 1) + q_group_target_size = tf.to_float(length_q) / tf.to_float(num_groups) + m_group_target_size = ( + tf.to_float(length_kv) * memory_target_density + / tf.to_float(num_groups)) + capacity_q = tf.minimum(length_q, tf.to_int32( + q_group_target_size * multiplicative_overhead + additive_overhead)) + capacity_m = tf.minimum(length_kv, tf.to_int32( + m_group_target_size * multiplicative_overhead + additive_overhead)) + q_dispatcher = expert_utils.TruncatingDispatcher(q_requests, capacity_q) + m_dispatcher = expert_utils.TruncatingDispatcher(m_requests, capacity_m) + q_gates = q_dispatcher.gates() + m_gates = m_dispatcher.gates() + dispatched_q = q_dispatcher.dispatch(q) + dispatched_kv = m_dispatcher.dispatch(kv) + # dispatched_q: [batch * num_heads, num_groups, capacity_q, depth_qk] + # dispatched_kv: + # [batch * num_heads, num_groups, capacity_m, depth_qk + depth_v] + k, v = tf.split(dispatched_kv, [depth_qk, depth_v], axis=3) + logits = tf.matmul(dispatched_q, k, transpose_b=True) + bias = tf.expand_dims((m_dispatcher.nonpadding() - 1.0) * 1e9, 2) + if mask_right: + q_coordinate = tf.to_float( + tf.expand_dims(q_dispatcher.length_coordinate(), 3)) + m_coordinate = tf.to_float( + tf.expand_dims(m_dispatcher.length_coordinate(), 2)) + bias += tf.to_float(tf.greater(m_coordinate, q_coordinate)) * -1e9 + logits += bias + log_weights = tf.nn.log_softmax(logits) + weights = tf.exp(log_weights) + # For each query, this is the log of the sum of the unnormalized weights. + q_total = tf.stop_gradient(logits[:, :, :, :1] - log_weights[:, :, :, :1]) + # For each key, this is the sum of the normalized weights. 
+ m_total = tf.expand_dims( + tf.reduce_sum(tf.stop_gradient(weights), axis=2), -1) + o = tf.matmul(weights, v) + o = q_dispatcher.combine(o) + + o = tf.reshape(o, [batch, num_heads, length_q, depth_v]) + o = combine_heads(o) + o = common_layers.conv1d(o, output_depth, 1, name="output_transform") + + m_total = m_dispatcher.combine(m_total) + q_total = q_dispatcher.combine(q_total) + q_total = tf.squeeze(q_total, -1) + m_total = tf.squeeze(m_total, -1) + # Compute summed m predictions for all groups + m_pred_used = tf.reduce_sum(tf.exp(m_pred) * m_dispatcher.gates(), axis=2) + q_pred_used = tf.reduce_sum(q_pred * q_dispatcher.gates(), axis=2) + epsilon = 1e-3 + m_pred_used = tf.log(m_pred_used + epsilon) + m_total = tf.log(m_total + epsilon) + m_loss = tf.nn.l2_loss(m_total - m_pred_used) + q_loss = tf.nn.l2_loss( + (q_total - q_pred_used) * tf.reduce_sum(q_gates, axis=2)) + q_loss /= tf.to_float(batch * length_q) - m_loss = tf.nn.l2_loss((m_total - m_pred) * m_gates) m_loss /= tf.to_float(batch * length_kv) + # We would like the query groups to be equal sized. The group # size is discrete, so we need some trick here. We add a loss # proportional to the product of the group size and the # predictions for that group. This encourages the predictions to # decrease for groups that are too big. - q_group_deviation = (q_group_size - tf.reduce_mean( - q_group_size, axis=1, keep_dims=True)) / tf.to_float(length_kv) - q_pred_mean = tf.reduce_mean(q_pred, axis=1) - q_pred_mean -= tf.reduce_mean(q_pred_mean, axis=1, keep_dims=True) - q_balance_loss = ( - tf.reduce_sum(q_pred_mean * q_group_deviation) / tf.to_float(batch)) + q_group_deviation = (q_group_size / q_group_target_size) - 1.0 + q_balance_loss = tf.reduce_sum( + tf.reduce_mean(q_pred_biased, axis=1) * q_group_deviation + ) / tf.to_float(batch) + m_group_deviation = (m_group_size / m_group_target_size) - 1.0 + m_balance_loss = tf.reduce_sum( + tf.reduce_mean(m_pred_biased, axis=1) * m_group_deviation + ) / tf.to_float(batch) + + # The losses in this function only propagate back to variables + # defined in this function, and the losses outside of this + # function only propagate back to variables outside of this + # function. Assuming some kind of adaptive learning algorithm, + # it should not matter how much we scale the losses in this function. + # Still we scale them down a lot so that they should not show up + # much in the overall loss for the model. extra_loss_multiplier = 1e-3 - extra_loss = (q_loss + m_loss + q_balance_loss) * extra_loss_multiplier + extra_loss = q_loss + m_loss + q_balance_loss + m_balance_loss + extra_loss *= extra_loss_multiplier # Show a bunch of summaries. if (not tf.get_variable_scope().reuse and @@ -727,32 +859,45 @@ def grouped_attention_multihead(query_antecedent, tf.summary.scalar("q_loss", q_loss) tf.summary.scalar("m_loss", m_loss) tf.summary.scalar("q_balance_loss", q_balance_loss) - density = ( - tf.reduce_sum(tf.to_float(m_group_size) * tf.to_float(q_group_size)) / - tf.to_float(batch * num_heads * length_q * length_kv)) - tf.summary.scalar("density", density) + tf.summary.scalar("m_balance_loss", m_balance_loss) + tf.summary.histogram("m_pred_used", m_pred_used) + tf.summary.histogram("m_total", m_total) + tf.summary.histogram("q_pred_used", q_pred_used) + tf.summary.histogram("q_total", q_total) if make_image_summary: + # image summaries are expensive. + # So we restrict them to head_num<4, query_position<512, batch_index=0. 
+ trunc_heads = min(4, num_heads) + trunc_length_q = tf.minimum(length_q, 512) # We recompute the attention for the first example, in an inefficient # way - masking. This lets us show pretty pictures. - # [num_heads, length_q, group] - q_gates_0 = q_gates[:num_heads, :, :] - # [num_heads, length_kv, group] - m_gates_0 = m_gates[:num_heads, :, :] - mask = tf.matmul(q_gates_0, m_gates_0, transpose_b=True) - q_0 = q[:num_heads, :, :] - k_0 = kv[:num_heads, :, :depth_qk] - att_0 = tf.nn.softmax(tf.matmul(q_0, k_0, transpose_b=True)) - hdr = tf.pow(att_0, 0.2) # for high-dynamic-range - mask_channel = mask * tf.maximum(hdr, 0.3) - image = tf.stack([hdr, mask_channel, mask_channel], axis=3) - tf.summary.image("att", image, max_outputs=num_heads) - mask_coverage = tf.reduce_sum(mask * att_0) / ( - tf.to_float(length_q) * num_heads) + # [trunc_heads, length_q, group] + q_gates_trunc = q_gates[:trunc_heads, :trunc_length_q, :] + # [trunc_heads, length_kv, group] + m_gates_trunc = m_gates[:trunc_heads, :, :] + grouping_mask = tf.matmul( + q_gates_trunc, m_gates_trunc, transpose_b=True) + q_trunc = q[:trunc_heads, :trunc_length_q, :] + k_trunc = kv[:trunc_heads, :, :depth_qk] + logits_trunc = tf.matmul(q_trunc, k_trunc, transpose_b=True) + if mask_right: + band = tf.matrix_band_part( + tf.ones([trunc_length_q, length_kv]), -1, 0) + trunc_bias = tf.expand_dims((1.0 - band) * -1e9, 0) + logits_trunc += trunc_bias + att_trunc = tf.nn.softmax(logits_trunc) + mask_coverage = tf.reduce_sum(grouping_mask * att_trunc) / ( + tf.to_float(trunc_length_q) * trunc_heads) tf.summary.scalar("coverage", mask_coverage) - - o = tf.reshape(o, [batch, num_heads, length_q, depth_v]) - o = combine_heads(o) - o = common_layers.conv1d(o, output_depth, 1, name="output_transform") + att_trunc_hdr = tf.pow(att_trunc, 0.2) # for high-dynamic-range + mask_channel = grouping_mask * tf.maximum(att_trunc_hdr, 0.3) + image = tf.stack([att_trunc_hdr, mask_channel, mask_channel], axis=3) + tf.summary.image("att", image, max_outputs=trunc_heads) + # show one group for each head. + att_per_group = tf.expand_dims(weights[:trunc_heads, 0, :, :], -1) + tf.summary.image( + "att_per_group_%d", tf.pow(att_per_group, 0.2), + max_outputs=trunc_heads) return o, extra_loss @@ -2039,6 +2184,7 @@ def parameter_attention(x, return y +@expert_utils.add_name_scope() def coordinate_tensor(shape, axis): """Return a tensor with given shape containing coordinte along given axis. @@ -2050,6 +2196,8 @@ def coordinate_tensor(shape, axis): A tensor with shape shape and type tf.int32, where each elements its coordinate along the given axis. """ + if axis < 0: + axis = tf.size(shape) + axis # Convert to positive for the one_hot indice r = tf.range(shape[axis]) r_shape = tf.one_hot( @@ -2223,7 +2371,163 @@ def local_expert_attention( @expert_utils.add_name_scope() -def sparse_dot_product_attention(q, k, v, bc, experts_params): +def expert_dot_product(q, k, v, info_q, info_k): + """Perform dot product on a subset of the sequence. + + Can add a mask to the attention to prevent sequences to attend to each other + and to prevent attention to the futur. + + Args: + q (tf.Tensor): Queries of shape [length_expert_q, depth_k] + k (tf.Tensor): Keys of shape [length_expert_k, depth_k] + v (tf.Tensor): Values of shape [length_expert_k, depth_v] + info_q (BatchInfo): Batch info for queries. 
If None, no mask is added + info_k (BatchInfo): Batch info for keys + + Returns: + tf.Tensor: dot product attention output ([length_expert_q, depth_v]) + """ + + length_q = tf.shape(q)[0] + length_k = tf.shape(k)[0] + depth_v = v.get_shape().as_list()[-1] + + # Create the mask + bias = attention_bias_coordinates(info_q.coordinates, info_k.coordinates) + if info_k.order is not None: + bias += attention_bias_future(info_q.order, info_k.order) + + # Restore batch and head dimension + q, k, v = [tf.expand_dims(tf.expand_dims(t, 0), 0) for t in (q, k, v)] + + def is_zero(): + zeros = tf.zeros(shape=[1, 1, length_q, depth_v], dtype=tf.float32) + zeros = tf.Print(zeros, [length_k, length_q], "length_k/length_q: ") + return zeros + + def is_not_zero(): + return dot_product_attention( + q, k, v, + bias=bias, + # No image summary to avoid "Retval[0] does not have value" (because + # inside a condition) + make_image_summary=False, + ) + + # TODO(epot): Should make sure a query gets at least one key. Because the + # different sequences of a batch are merged, it's possible that a + # query from a sequence only receive memory from another sequence, so + # with the mask, the query will perform a softmax on -infinity values. + # A hack could be to add at least one sequence of each batch on each group so + # the query can attend to at least one element. + # Softmax(Q.K)*V + v_out = tf.cond( + tf.logical_or(tf.equal(length_q, 0), tf.equal(length_k, 0)), + is_zero, + is_not_zero, + ) + + # Remove batch and head dimension + v_out = tf.squeeze(v_out, axis=0) + v_out = tf.squeeze(v_out, axis=0) + return v_out + + +@expert_utils.add_name_scope() +def dot_product_single_head(q, k, v, gates_q, gates_k, bi): + """Perform a dot product attention on a single sequence on a single head. + + This function dispatch the q, k, v and loop over the buckets to compute the + attention dot product on each subsequences. + + Args: + q (tf.Tensor): [length_q, depth_q] + k (tf.Tensor): [length_k, depth_q] + v (tf.Tensor): [length_k, depth_v] + gates_q (tf.Tensor): One-hot vector of shape [length_q, nb_buckets] + gates_k (tf.Tensor): One-hot vector of shape [length_k, nb_buckets] + bi (BatchInfo): Contains the batch coordinates and sequence order + + Returns: + tf.Tensor: [length_q, depth_v] + """ + + nb_buckets = gates_q.get_shape().as_list()[-1] + + q_dispatcher = expert_utils.SparseDispatcher(nb_buckets, gates_q) + k_dispatcher = expert_utils.SparseDispatcher(nb_buckets, gates_k) + + def eventually_dispatch(dispatcher, value): + if value is not None: + return dispatcher.dispatch(value) + return [None] * nb_buckets + + # Iterate over every dispatched group + list_v_out = [] + for ( + q, + k, + v, + qbc, + qbo, + kbc, + kbo, + ) in zip( + # Dispatch queries, keys and values + q_dispatcher.dispatch(q), + k_dispatcher.dispatch(k), + k_dispatcher.dispatch(v), + # Also dispatch the sequence positions and batch coordinates + eventually_dispatch(q_dispatcher, bi.coordinates), + eventually_dispatch(q_dispatcher, bi.order), + eventually_dispatch(k_dispatcher, bi.coordinates), + eventually_dispatch(k_dispatcher, bi.order), + ): + list_v_out.append(expert_dot_product( + q, k, v, + info_q=BatchInfo(coordinates=qbc, order=qbo), + info_k=BatchInfo(coordinates=kbc, order=kbo) + )) + + # Combine all buckets together to restore the original length + return q_dispatcher.combine(list_v_out) + + +def map_fn_switch(fn, elems, use_map_fn=True, **kwargs): + """Construct the graph with either tf.map_fn or a python for loop. 
+ + This function is mainly for benchmarking purposes. + + tf.map_fn is dynamic but is much slower than creating a static graph with + a for loop. However, having a for loop makes the graph much longer to build + and can consume too much RAM in a distributed setting. + + Args: + fn (fct): same as for tf.map_fn but for now can only return a single tensor + value (instead of a tuple of tensors for the general case) + elems (tuple): same as for tf.map_fn + use_map_fn (bool): If True, tf.map_fn is used, if False, for _ in _: is used + instead + **kwargs: Additional tf.map_fn arguments (ignored if use_map_fn is False) + + Returns: + tf.Tensor: the output of tf.map_fn + """ + if use_map_fn: + return tf.map_fn(fn, elems, **kwargs) + else: + elems_unpacked = ( + tf.unstack(e) for e in elems + ) + out_unpacked = [ + fn(e) for e in zip(*elems_unpacked) + ] + out = tf.stack(out_unpacked) + return out + + +@expert_utils.add_name_scope() +def sparse_dot_product_attention(q, k, v, bi, use_map_fn, experts_params): """Sparse multihead self attention. Perform an approximation of the full multihead attention by dispatching @@ -2237,100 +2541,488 @@ def sparse_dot_product_attention(q, k, v, bc, experts_params): contains the elements from all different batches) * Right now, only self attention is supported so length_q and length_kv should be identical and the function will add triangular mask. - * The bias is added inside this function to prevent attention to the future. + * If bi.order is not None, the bias is added inside this function to + prevent attention to the future. Args: - q (tf.Tensor): Queries of shape [1, heads, length_q, depth_k] - k (tf.Tensor): Keys of shape [1, heads, length_q, depth_k] - v (tf.Tensor): Values of shape [1, heads, length_kv, depth_v] - bc (tf.Tensor): Batch coordinates of shape [1, length_q, 1] + q (tf.Tensor): Queries of shape [batch, heads, length_q, depth_k] + k (tf.Tensor): Keys of shape [batch, heads, length_q, depth_k] + v (tf.Tensor): Values of shape [batch, heads, length_kv, depth_v] + bi (BatchInfo): Contains the batch coordinates and sequence order + use_map_fn (bool): Use either tf.map_fn or a python for loop to compute the + heads separately experts_params (dict): Additional params for the local expert Returns: tf.Tensor: Approximation of Softmax(Q.K) * V, of shape - [1, heads, length_q, depth_v] + [batch, heads, length_q, depth_v] + """ + batch_size, nb_heads, _, depth = q.get_shape().as_list() + batch_size = batch_size or tf.shape(q)[0] + + @expert_utils.add_name_scope() + def flatten_first_dims(x): + # Case 1: Either constant batch size of size 1 or batch already flattened + if x.get_shape().as_list()[0] == 1: + return tf.squeeze(x, axis=0) + # Case 2: Flatten batch dimension + else: + x = tf.transpose(x, perm=[1, 0, 2, 3]) + x = tf.reshape(x, [nb_heads, -1, depth]) + return x + + def flatten_batch(x): + if x is None: + return x + return expert_utils.flatten_all_but_last(x) + + q = flatten_first_dims(q) + k = flatten_first_dims(k) + v = flatten_first_dims(v) + bi = BatchInfo( + coordinates=flatten_batch(bi.coordinates), + order=flatten_batch(bi.order), + ) + + # Unstack heads + list_q = tf.unstack(q) # list[tf.Tensor(shape=[batch * length, depth])] + list_k = tf.unstack(k) + list_v = tf.unstack(v) + + list_gates_q = [] + list_gates_k = [] + + total_loss = 0.0 + # There might be a more optimized way to compute all heads at once + for single_q, single_k, _ in zip(list_q, list_k, list_v): + # Each head gets its own dispatcher + lhs_gating = LshGating( +
depth=single_q.get_shape().as_list()[-1], + **experts_params + ) + + list_gates_q.append(lhs_gating.get_gates(single_q)) + list_gates_k.append(lhs_gating.get_gates(single_k)) + + gates_q = tf.stack(list_gates_q) + gates_k = tf.stack(list_gates_k) + + # Process each head separately + v_out = map_fn_switch( + lambda args: dot_product_single_head(bi=bi, *args), + elems=(q, k, v, gates_q, gates_k), + dtype=(tf.float32), + parallel_iterations=2, + # back_prop=True, + # swap_memory=False, + # infer_shape=True, + # name=None + use_map_fn=use_map_fn, + ) + + # Restore original shape as expected by multihead_attention + if isinstance(batch_size, int) and batch_size == 1: + v_out = tf.expand_dims(v_out, axis=0) # Restore batch_size = 1 + else: + v_out = tf.reshape(v_out, [nb_heads, batch_size, -1, depth]) + v_out = tf.transpose(v_out, [1, 0, 2, 3]) + return v_out, total_loss / nb_heads + + +@expert_utils.add_name_scope() +def dot_product_batched_head(q, k, v, gates_q, gates_k, mask_right=False): + """Perform a dot product attention on a batch of sequences and heads. + + This function dispatches the q, k, v and loops over the buckets to compute + the attention dot product on each subsequence. + + Args: + q (tf.Tensor): [batch*heads, length_q, depth_q] + k (tf.Tensor): [batch*heads, length_k, depth_q] + v (tf.Tensor): [batch*heads, length_k, depth_v] + gates_q (tf.Tensor): One-hot of shape [batch*heads, length_q, nb_buckets] + gates_k (tf.Tensor): One-hot of shape [batch*heads, length_k, nb_buckets] + mask_right (bool): Add a bias to prevent attention to the future + + Returns: + tf.Tensor: [length_q, depth_v] + """ + nb_buckets = tf.shape(gates_q)[-1] + + @expert_utils.add_name_scope() + def get_dispatcher(gates): + length = tf.shape(gates)[1] + # Count the number of ones per batch (and keep the max value) + nb_elems_to_dispatch = tf.reduce_sum(gates, axis=[1, 2]) + nb_elems_to_dispatch = tf.reduce_max(nb_elems_to_dispatch) + nb_elems_to_dispatch = tf.to_int32(nb_elems_to_dispatch) + capacity = nb_elems_to_dispatch // nb_buckets * 2 # Capacity is hardcoded + capacity = tf.minimum(length, capacity) + tf.summary.scalar("dispatch_capacity", capacity, family="lsh") + return expert_utils.TruncatingDispatcher(gates, capacity) + + def add_summary_capacity(x, prefix): + # Monitor if capacity overflow + x = x[0, ...]
# Take first batch/head + x = tf.reduce_sum(x, axis=0) + tf.summary.scalar(prefix + "_min", tf.reduce_min(x), family="lsh") + tf.summary.scalar(prefix + "_max", tf.reduce_max(x), family="lsh") + tf.summary.histogram(prefix + "capacity_distribution", x, family="lsh") + for i in range(3): # Show the first 3 buckets + tf.summary.scalar("{}_{}".format(prefix, i), x[i], family="lsh") + add_summary_capacity(gates_q, "q") + add_summary_capacity(gates_k, "k") + + q_dispatcher = get_dispatcher(gates_q) + k_dispatcher = get_dispatcher(gates_k) + + q = q_dispatcher.dispatch(q) + k = k_dispatcher.dispatch(k) + v = k_dispatcher.dispatch(v) + + # Bias of shape [batch*heads, nb_buckets, 1, capacity] broadcast to every + # query + bias = tf.expand_dims((k_dispatcher.nonpadding() - 1.0) * 1e9, 2) + if mask_right: + q_coordinate = tf.to_float( + tf.expand_dims(q_dispatcher.length_coordinate(), 3)) + k_coordinate = tf.to_float( + tf.expand_dims(k_dispatcher.length_coordinate(), 2)) + bias += tf.to_float(tf.greater(k_coordinate, q_coordinate)) * -1e9 + # The sequence padding is not masked but is ignored on the next layers + + # q, k, v now have shape [batch*heads, nb_bucket, capacity, depth] + # The buckets can be seen as different heads + v_out = dot_product_attention(q, k, v, bias=bias) + + # Combine all buckets together to restore the original length + return q_dispatcher.combine(v_out) + + +@expert_utils.add_name_scope() +def sparse_dot_product_attention_truncated( + q, k, v, + bi, # Unused + experts_params, + use_map_fn=False, # Unused + mask_right=False, +): # pylint: disable=unused-argument + """Sparse multihead self attention. + + Perform an approximation of the full multihead attention by dispatching + the tokens using their keys/values. Thus the attention matrices are only + computed each time on a subset of the tokens. + + Notes: + * The function doesn't perform scaling here (multihead_attention does + the /sqrt(depth)). + * The padding should have been removed (so batch size should be 1 but length + contains the elements from all different batches) + * Right now, only self attention is supported so length_q and length_kv + should be identical and the function will add a triangular mask. + * If bi.order is not None, the bias is added inside this function to + prevent attention to the future. + + Args: + q (tf.Tensor): Queries of shape [batch, heads, length_q, depth_k] + k (tf.Tensor): Keys of shape [batch, heads, length_q, depth_k] + v (tf.Tensor): Values of shape [batch, heads, length_kv, depth_v] + bi (BatchInfo): Contains the batch coordinates and sequence order + experts_params (dict): Additional params for the local expert + use_map_fn (bool): Use either tf.map_fn or a python for loop to compute the + heads separately + mask_right (bool): + Returns: + tf.Tensor: Approximation of Softmax(Q.K) * V, of shape + [batch, heads, length_q, depth_v] """ + # Currently depth is the same for q and v + batch_size, nb_heads, _, depth = q.get_shape().as_list() + batch_size = batch_size or tf.shape(q)[0] + + total_loss = 0.0 - assert q.get_shape().as_list()[0] == 1 - assert k.get_shape().as_list()[0] == 1 - assert v.get_shape().as_list()[0] == 1 + # Each head gets its own dispatcher + list_lsh = [ + LshGating( + depth=depth, + **experts_params + ) for _ in range(nb_heads) + ] @expert_utils.add_name_scope() - def unpack_heads(x): - # Flatten the batch.
squeeze works because batch_size = 1 (otherwise could - # use tf.transpose and flatten after unpacking) - x = tf.squeeze(x, axis=0) - list_x = tf.unstack(x) - return list_x # list[tf.Tensor(shape=[batch * length, depth])] - - bc = tf.squeeze(bc, axis=0) - list_q = unpack_heads(q) - list_k = unpack_heads(k) - list_v = unpack_heads(v) + def get_gates_head(x, add_first=False): + """Return the gates for each head of the current x. + + Args: + x (tf.Tensor): of shape [batch, heads, length, depth] + add_first (bool): if True, add the first element to each bucket + + Returns: + tf.Tensor: gates of shape [batch, heads, length, num_buckets] + """ + length = tf.shape(x)[2] + + # Invert heads/batch + x = tf.transpose(x, perm=[1, 0, 2, 3]) + x = tf.reshape(x, [nb_heads, batch_size*length, depth]) + + list_x = tf.unstack(x) # list[tf.Tensor(shape=[batch * length, depth])] + + # Unstack heads + list_gates = [] + # There might be a more optimized way to compute all heads at once + for lsh, single_x in zip(list_lsh, list_x): + # Each head gets its own dispatcher + gates = lsh.get_gates(single_x) + nb_buckets = gates.get_shape().as_list()[-1] + # Reshape to [batch, length, depth] but should consider sequence + # padding in that case (also dispatch the padding) + gates = tf.reshape(gates, [batch_size, length, nb_buckets]) + list_gates.append(gates) + + gates = tf.stack(list_gates) + + # Restore original shape + gates = tf.reshape(gates, [nb_heads, batch_size, length, nb_buckets]) + gates = tf.transpose(gates, [1, 0, 2, 3]) + + # Dispatch the first element to every gate to avoid empty buckets + if add_first: + gates = tf.maximum( + gates, + tf.reshape(tf.one_hot([0], length), [1, 1, length, 1]) + ) + + return gates + + gates_q = get_gates_head(q) + gates_k = get_gates_head(k, add_first=True) + + # [batch, heads, length, depth] => [batch*heads, length, depth] + q, k, v, gates_q, gates_k = [ + combine_first_two_dimensions(t) for t in (q, k, v, gates_q, gates_k)] + + v_out = dot_product_batched_head(q, k, v, gates_q, gates_k, mask_right) + + # Restore original dimension + v_out = tf.reshape(v_out, [batch_size, nb_heads, -1, depth]) + + return v_out, total_loss / nb_heads + + +@expert_utils.add_var_scope() +def deconv_elems_1d(x, factor, out_depth=None): + """Increase the length and change the dimensionality. + Expand/project each position of dim depth of the input into + factor*tokens of dim out_depth + + Args: + x (tf.Tensor): shape [batch_size, length, depth] + factor (int): Multiplicative factor of each token. + out_depth (int): Output depth (if None, keep depth constant) + + Returns: + tf.Tensor: shape [batch_size, length*factor, out_depth] + """ + out_depth = out_depth or x.get_shape().as_list()[-1] + x = tf.expand_dims(x, 1) # [batch_size, 1, length, depth] + x = tf.layers.conv2d_transpose( + inputs=x, + filters=out_depth, + kernel_size=(1, factor), + strides=(1, factor), + padding="valid", + data_format="channels_last", + ) # [batch_size, 1, length*factor, out_depth] + x = tf.squeeze(x, 1) # [batch_size, length*factor, depth] + return x + + +@expert_utils.add_var_scope() +def conv_elems_1d(x, factor, out_depth=None): + """Decrease the length and change the dimensionality. + + Merge/restore/compress factor positions of dim depth of the input into + a single position of dim out_depth. + This is basically just a strided convolution without overlap + between strides. + The original length has to be divisible by factor.
+ + Args: + x (tf.Tensor): shape [batch_size, length, depth] + factor (int): Length compression factor. + out_depth (int): Output depth + + Returns: + tf.Tensor: shape [batch_size, length//factor, out_depth] + """ + out_depth = out_depth or x.get_shape().as_list()[-1] + # with tf.control_dependencies( # Dynamic assertion + # [tf.assert_equal(tf.shape(x)[1] % factor, 0)]): + x = tf.expand_dims(x, 1) # [batch_size, 1, length, depth] + x = tf.layers.conv2d( + inputs=x, + filters=out_depth, + kernel_size=(1, factor), + strides=(1, factor), + padding="valid", + data_format="channels_last", + ) # [batch_size, 1, length//factor, out_depth] + x = tf.squeeze(x, 1) # [batch_size, length//factor, depth] + return x + + +@expert_utils.add_var_scope() +def local_reduction_attention(x, block_length, multihead_params): + """Reduce the length dimension using self attention. + + Args: + x (tf.Tensor): float32 of shape [batch, length, depth] + block_length (int): Block length for local attention (Compression factor) + multihead_params (dict): parameters for multihead attention + + Returns: + tf.Tensor: Compressed tensor of shape [batch, length // factor, depth] + """ @expert_utils.add_name_scope() - def expert_dot_product(x, q, k, v, bc): - """Perform dot product on a subset of the sequence. + def dot_product_self_local_attention_flattened(q, k, v): + """Strided block local self-attention. + + No overlapp between the blocks. Args: - x (tf.Tensor): Unused but forwarded by local_moe - q (tf.Tensor): Queries of shape [length_expert, depth_k] - k (tf.Tensor): Queries of shape [length_expert, depth_k] - v (tf.Tensor): Queries of shape [length_expert, depth_v] - bc (tf.Tensor): Batch coordinates of shape [length_expert, 1] + q (tf.Tensor): shape [batch, heads, length, depth_k] + k (tf.Tensor): shape [batch, heads, length, depth_k] + v (tf.Tensor): shape [batch, heads, length, depth_v] Returns: - tf.Tensor: dot product attention output ([length_expert, depth_v]) + tf.Tensor: shape [batch, heads, length, depth_v] """ - length = tf.shape(x)[0] - - # Mask between the sequences - bias_batch = attention_bias_coordinates(bc) - # Mask to prevent sequences of attenting to the future - bias_past = tf.reshape( - attention_bias_lower_triangle(length), [length, length]) - bias = bias_batch + bias_past # bias has shape [length, length] - bias = tf.reshape(bias, [1, 1, length, length]) - - # Restore batch and head dimension - q, k, v = [tf.expand_dims(tf.expand_dims(t, 0), 0) for t in (q, k, v)] - # Softmax(Q.K)*V - v_out = dot_product_attention(q, k, v, bias=bias) - # Remove batch and head dimension - v_out = tf.squeeze(v_out, axis=0) - v_out = tf.squeeze(v_out, axis=0) + _, num_head, _, depth = q.get_shape().as_list() + + # Extract the blocks + def pad_and_reshape(x): + """Split the length dim into [num_block, block_length].""" + length_x = tf.shape(x)[2] + # Add some padding, but won't matter as the last block will never be + # attended by the query (after compression) + x = tf.pad(x, [ + [0, 0], + [0, 0], + [0, -length_x % block_length], + [0, 0] + ]) + x = tf.reshape(x, [ + tf.shape(x)[0], # Batch + num_head, # Head + tf.shape(x)[2] // block_length, # Num blocks + block_length, # Block length + depth, # Depth + ]) + return x + + q, k, v = [pad_and_reshape(t) for t in (q, k, v)] + + # Perform attention on the flattened dot product + logits = tf.matmul(q, k, transpose_b=True) + logits = tf.reshape(logits, [ + tf.shape(logits)[0], # Batch + num_head, # Head + tf.shape(logits)[2], # Num blocks + block_length**2, # Flatten last 
dimension + ]) + weights = tf.nn.softmax(logits) + weights = tf.reshape(weights, [ + tf.shape(weights)[0], # Batch + num_head, # Head + tf.shape(weights)[2], # Num blocks + block_length, + block_length, # Restore the block length dimension + ]) + weights = tf.reduce_sum(weights, axis=3, keep_dims=True) # Compress block + v_out = tf.matmul(weights, v) # [1, block_length] @ [block_length, depth] + v_out = tf.squeeze(v_out, axis=3) return v_out - list_v_out = [] - total_loss = 0.0 - for q, k, v in zip(list_q, list_k, list_v): - # Each head get its own dispatcher + return multihead_attention( + x, + None, + bias=None, + output_depth=x.get_shape().as_list()[-1], + attention_type=dot_product_self_local_attention_flattened, + **multihead_params + ) - # TODO(epot): Choose which dispatcher use here on the k/q pair (either - # noisy_top_k_gating or Locality-sensitive hashing) - - # Concatenate along the depth axis - x = tf.concat([q, k], axis=-1) # Works because q and k lengths are the same - - # Compute the attention on the sparse tokens - v_out, loss = expert_utils.local_moe( - x=x, - expert_fn=expert_dot_product, - additional_dispatch_params=dict( - q=q, - k=k, - v=v, - bc=bc - ), - **experts_params - ) - list_v_out.append(v_out) - total_loss += loss - # Restore original shape as expected by multihead_attention - v_out = tf.stack(list_v_out) # Merge heads - v_out = tf.expand_dims(v_out, axis=0) - return v_out, total_loss / len(list_v_out) +@expert_utils.add_var_scope() +def multihead_self_attention_reduced( + x, factor, reduction_type, multihead_params): + """Reduce the length dimension by compressing with conv. + + Args: + x (tf.Tensor): float32 of shape [batch, length, depth] + factor (int): compression factor for the memory sequence + reduction_type (str): type of compression + multihead_params (dict): parameters for multihead attention + + Returns: + (tf.Tensor): float32 of shape [batch, length, depth] + + Raises: + ValueError: If reduction_type invalid + """ + depth = x.get_shape().as_list()[-1] + + # Could try to have some overlapp between the blocks but that would + # create conv artifacts, would make it difficult to not attend to the future + # withing one group and the padding should be handled specially. 
+
+  # Reduce the memory dimension
+  if reduction_type == "attention":
+    memory_x = local_reduction_attention(x, factor, multihead_params)
+  elif reduction_type == "conv":
+    # With valid padding, the last block won't be computed (not attended anyway)
+    memory_x = conv_elems_1d(x, factor)
+  else:
+    raise ValueError("Unknown reduction type {}".format(reduction_type))
+
+  memory_x = tf.concat(
+      # Add the first elem to make it attendable by everyone (otherwise the
+      # first block cannot attend to anything)
+      [x[:, :1, :], memory_x],
+      axis=1,
+  )
+
+  # Construct the bias
+  @expert_utils.add_name_scope()
+  def construct_bias_vectors(t, axis):
+    length = tf.to_float(tf.shape(t)[1])
+    length_coordinates = tf.range(length, dtype=tf.float32)
+    length_coordinates = tf.expand_dims(length_coordinates, axis=axis)
+    # [1, length_k] or [length_q, 1]
+    return length_coordinates
+
+  bias = tf.to_float(tf.greater(
+      # Because we add the first elem to the memory block and it can be attended
+      # by anyone, we don't need to add +1 anymore to prevent self attention
+      # Use * factor to make sure the last tokens of a block cannot attend to
+      # the block
+      construct_bias_vectors(memory_x, 0) * factor,
+      # +epsilon to avoid float equality
+      construct_bias_vectors(x, 1) + 1e-3,
+  )) * -1e9
+  bias = tf.expand_dims(bias, axis=0)
+  bias = tf.expand_dims(bias, axis=0)  # [1, 1, length_q, length_k]
+
+  return multihead_attention(
+      query_antecedent=x,
+      memory_antecedent=memory_x,
+      bias=bias,
+      output_depth=depth,
+      **multihead_params
+  )
 def scaled_dot_product_attention_simple(q, k, v, bias, name=None):
@@ -2482,3 +3174,6 @@ def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
 multihead_attention_sparse_dot_prod = functools.partial(
     multihead_attention, attention_type=sparse_dot_product_attention)
+
+multihead_attention_sparse_truncated = functools.partial(
+    multihead_attention, attention_type=sparse_dot_product_attention_truncated)
diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py
index ef67b0d8e..6f4a6a37c 100644
--- a/tensor2tensor/layers/common_attention_test.py
+++ b/tensor2tensor/layers/common_attention_test.py
@@ -258,6 +258,64 @@ def testDotProductAttentionRelative(self):
       res = session.run(a)
       self.assertEqual(res.shape, (5, 7, 12, 32))
+
+  def testBiasBatchCoordinates(self):
+    """Testing the batch coordinates mask."""
+    q = tf.constant([0, 0, 1, 1, 1, 1, 2, 2, 2], dtype=tf.int32)
+    q = tf.expand_dims(q, axis=-1)
+
+    k = tf.constant([0, 0, 0, 2, 2, 3, 3, 3], dtype=tf.int32)
+    k = tf.expand_dims(k, axis=-1)
+
+    ground_truth = np.array([
+        [0, 0, 0, 1, 1, 1, 1, 1],  # 0
+        [0, 0, 0, 1, 1, 1, 1, 1],
+        [1, 1, 1, 1, 1, 1, 1, 1],  # 1 (just masked)
+        [1, 1, 1, 1, 1, 1, 1, 1],
+        [1, 1, 1, 1, 1, 1, 1, 1],
+        [1, 1, 1, 1, 1, 1, 1, 1],
+        [1, 1, 1, 0, 0, 1, 1, 1],  # 2
+        [1, 1, 1, 0, 0, 1, 1, 1],
+        [1, 1, 1, 0, 0, 1, 1, 1],
+    ], np.float32) * -1e9
+
+    bias = common_attention.attention_bias_coordinates(q, k)
+
+    with self.test_session() as session:
+      session.run(tf.global_variables_initializer())
+      self.assertAllClose(
+          bias.eval(),
+          ground_truth,
+      )
+
+  def testBiasFuture(self):
+    """Testing the sequence order mask."""
+    q = tf.constant([0, 1, 2, 3, 0, 1, 2, 0, 1], dtype=tf.int32)
+    q = tf.expand_dims(q, axis=-1)
+
+    k = tf.constant([0, 1, 2, 3, 4, 0, 1, 2], dtype=tf.int32)
+    k = tf.expand_dims(k, axis=-1)
+
+    ground_truth = np.array([
+        [0, 1, 1, 1, 1, 0, 1, 1],  # 0
+        [0, 0, 1, 1, 1, 0, 0, 1],  # 1
+        [0, 0, 0, 1, 1, 0, 0, 0],  # 2
+        [0, 0, 0, 0, 1, 0, 0, 0],  # 3
+        [0, 1, 1, 1, 1, 0, 1, 1],  # 0
+        [0, 0, 1, 1, 1, 0, 0, 1],  # 1
+        [0, 0, 0, 1, 1, 0, 0, 0],  # 2
+        [0, 1, 1, 1, 1, 0, 1, 1],  # 0
+        [0, 0, 1, 1, 1, 0, 0, 1],  # 1
+    ], np.float32) * -1e9
+
+    bias = common_attention.attention_bias_future(q, k)
+
+    with self.test_session() as session:
+      session.run(tf.global_variables_initializer())
+      self.assertAllClose(
+          bias.eval(),
+          ground_truth,
+      )
+
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py
index d3ebfdffe..d2d8bb2e5 100644
--- a/tensor2tensor/layers/common_hparams.py
+++ b/tensor2tensor/layers/common_hparams.py
@@ -62,6 +62,7 @@ def basic_params1():
       learning_rate_cosine_cycle_steps=250000,
       learning_rate=0.1,
       sampling_method="argmax",  # "argmax" or "random"
+      sampling_temp=1.0,  # temperature for sampling
       problem_choice="adaptive",  # "uniform", "adaptive", "distributed"
       # expand the logits a piece at a time - saves memory.
       factored_logits=int(False),
@@ -93,6 +94,9 @@ def basic_params1():
       # epsilon parameter to normalization function
       norm_epsilon=1e-6,
       symbol_modality_num_shards=16,
+      # During training, we drop sequences whose inputs and targets are shorter
+      # than min_length.
+      min_length=0,
       # During training, we drop sequences whose inputs or targets are longer
       # than max_length.
       # If max_length==0, we use hparams.batch_size instead.
@@ -155,7 +159,23 @@ def basic_params1():
       # position in the inputs portion can see the
       # entire inputs portion. This removes the challenge of
       # autoregressively predicting the inputs portion.
-      prepend_mode="none",)
+      prepend_mode="none",
+      # Scheduled sampling is interesting for auto-regressive models.
+      # It runs an additional step using the generated output as autoregressive
+      # targets, which can improve the model's inference results later. The
+      # parameter scheduled_sampling_prob determines the probability with which
+      # this additional step is run. It's turned off (0.0) by default.
+      # This probability will exponentially warm up for the number of
+      # steps determined by scheduled_sampling_warmup_steps.
+      # The tensor used for the second step will consist of outputs from
+      # the first step mixed with gold truth, with the proportion of gold
+      # determined by scheduled_sampling_gold_mixin_prob.
+      scheduled_sampling_prob=0.0,
+      scheduled_sampling_warmup_steps=50000,
+      scheduled_sampling_gold_mixin_prob=0.5,
+      # This is the actual batch size, *not* tokens per batch (i.e. for
+      # language models this is the number of sentences in the batch)
+      tpu_batch_size_per_shard=24,)
 class RangedHParams(object):
diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py
index 1923a9e24..08fd2f56b 100644
--- a/tensor2tensor/layers/common_layers.py
+++ b/tensor2tensor/layers/common_layers.py
@@ -1697,7 +1697,10 @@ def body():
 def underlying_variable_ref(t):
-  """Find the underlying variable ref, ignoring Identity ops.
+  """Find the underlying variable ref.
+
+  Traverses through Identity, ReadVariableOp, and Enter ops.
+  Stops when op type has Variable or VarHandle in name.
   Args:
     t: a Tensor
@@ -1705,9 +1708,11 @@ def underlying_variable_ref(t):
   Returns:
     a Tensor that is a variable ref, or None on error.
""" - while t.op.type == "Identity": + while t.op.type in ["Identity", "ReadVariableOp", "Enter"]: t = t.op.inputs[0] - if "Variable" in t.op.type: + + op_type = t.op.type + if "Variable" in op_type or "VarHandle" in op_type: return t else: return None @@ -1938,13 +1943,13 @@ def _fn_with_custom_grad(fn, inputs, grad_fn, use_global_vars=False): Returns: fn(*inputs) """ - with tf.variable_scope(None, default_name="fn_with_custom_grad") as vs: - inputs = list(inputs) - outputs = fn(*inputs) - if use_global_vars: - train_vars = list(vs.global_variables()) - else: - train_vars = list(vs.trainable_variables()) + vs = tf.get_variable_scope() + get_vars_fn = (vs.global_variables if use_global_vars else + vs.trainable_variables) + len_before_vars = len(get_vars_fn()) + inputs = list(inputs) + outputs = fn(*inputs) + train_vars = get_vars_fn()[len_before_vars:] if grad_fn is None: return outputs diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py index 5804e4d8f..1eb988c4c 100644 --- a/tensor2tensor/layers/rev_block.py +++ b/tensor2tensor/layers/rev_block.py @@ -365,8 +365,7 @@ def grad_fn(inputs, variables, outputs, output_grads): @common_layers.fn_with_custom_grad(grad_fn) def fn_with_recompute(*args): - with tf.variable_scope(None, default_name="recompute") as vs: - cached_vs.append(vs) - return fn(*args) + cached_vs.append(tf.get_variable_scope()) + return fn(*args) return fn_with_recompute(*args) diff --git a/tensor2tensor/models/aligned.py b/tensor2tensor/models/aligned.py index abfecbaed..a0e92da94 100644 --- a/tensor2tensor/models/aligned.py +++ b/tensor2tensor/models/aligned.py @@ -69,6 +69,12 @@ def postprocess(x, y): extra_loss = 0.0 ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")] moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] + if hparams.mask_right: + def _bias(x): + return common_attention.attention_bias_lower_triangle(tf.shape(x)[1]) + bias = dp(_bias, x) + else: + bias = tf.zeros([1, 1, 1, 1]) if hparams.diet_experts: hsize, = moe_hidden_sizes @@ -97,13 +103,16 @@ def _diet_expert(x): common_attention.multihead_attention, x, None, - None, # bias + bias, # bias hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) elif layer_type == "att_grouped": + multiplicative_overhead = ( + hparams.multiplicative_overhead if hparams.mode == ModeKeys.TRAIN + else hparams.multiplicative_overhead_eval) y, loss = dp( common_attention.grouped_attention_multihead, x, @@ -113,24 +122,18 @@ def _diet_expert(x): hparams.hidden_size, hparams.num_heads, num_groups=hparams.attention_num_groups, + memory_target_density=hparams.memory_target_density, + multiplicative_overhead=multiplicative_overhead, make_image_summary=hparams.attention_image_summary, + mask_right=hparams.mask_right, ) extra_loss += tf.add_n(loss) / dp.n elif layer_type == "att_memory_efficient": assert hparams.layer_preprocess_sequence == "n" - zero_bias = tf.zeros([1, 1, 1, 1]) - y = dp( - common_attention.multihead_self_attention_memory_efficient, - x, - zero_bias, - hparams.num_heads) - elif layer_type == "att_memory_efficient": - assert hparams.layer_preprocess_sequence == "n" - zero_bias = tf.zeros([1, 1, 1, 1]) y = dp( common_attention.multihead_self_attention_memory_efficient, x, - zero_bias, + bias, hparams.num_heads) elif layer_type == "att_local": y = dp( @@ -143,7 +146,9 @@ def _diet_expert(x): hparams.hidden_size, 
hparams.num_heads, hparams.attention_dropout, - attention_type="local_unmasked", + attention_type=( + "local_mask_right" if hparams.mask_right + else "local_unmasked"), block_length=hparams.local_attention_window, block_width=hparams.local_attention_window) elif layer_type == "att_pseudolocal": @@ -153,7 +158,7 @@ def _pseudolocal_bias(x): return common_attention.attention_bias_local( tf.shape(x)[1], hparams.local_attention_window, - hparams.local_attention_window) + 0 if hparams.mask_right else hparams.local_attention_window) pseudolocal_bias = dp(_pseudolocal_bias, x) y = dp( common_attention.multihead_attention, @@ -174,12 +179,39 @@ def _pseudolocal_bias(x): attention_num_experts=hparams.attention_num_experts, train=hparams.mode == ModeKeys.TRAIN, batch_coordinate=batch_coordinate, - mask_right=False, + mask_right=hparams.mask_right, split_batch=bool(hparams.attention_split_batch), attention_kq_size=hparams.attention_kq_size, attention_v_size=hparams.attention_v_size) # TODO(avaswani, epot, noam): Do we need to divide by num shards ? extra_loss += tf.add_n(loss) / dp.n + elif layer_type == "att_lsh": + if hparams.lsh_truncated: + attention_fn = common_attention.multihead_attention_sparse_truncated + else: + attention_fn = common_attention.multihead_attention_sparse_dot_prod + y, loss = dp( + attention_fn, + x, + None, + None, # Bias is computed inside + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + + # Additional parameters + bi=[common_attention.BatchInfo( + coordinates=batch_coordinate[i], + order=None, # No future mask + ) for i in range(dp.n)], + use_map_fn=False, + experts_params=dict( + nb_hyperplanes=4, + ) + ) + extra_loss += tf.add_n(loss) / dp.n elif layer_type == "moe": y, loss = expert_utils.distributed_moe( dp, @@ -287,7 +319,15 @@ def aligned_base(): hparams.add_hparam("memory_efficient_ffn", int(False)) hparams.add_hparam("local_attention_window", 128) hparams.add_hparam("attention_num_groups", 8) + hparams.add_hparam("memory_target_density", 2.0) + hparams.add_hparam("multiplicative_overhead", 1.25) + hparams.add_hparam("multiplicative_overhead_eval", 2.0) hparams.add_hparam("attention_image_summary", int(True)) + # LSH params + hparams.add_hparam("lsh_truncated", int(True)) + # For testing right-masking. + # This is not implemented in all layers. + hparams.add_hparam("mask_right", int(False)) return hparams @@ -327,10 +367,9 @@ def aligned_local_expert(): def aligned_grouped(): """Use local_expert_attention. - languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.62 - 2.7 steps/sec on P100 - (some problem with map_fn - need to tune this) - 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.02 + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.63 + 10.2 steps/sec on P100 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.04 Returns: a hparams object @@ -468,6 +507,18 @@ def aligned_moe(): return hparams +@registry.register_hparams +def aligned_lsh(): + """Use multihead_attention_sparse_dot_prod. + + Returns: + a hparams object + """ + hparams = aligned_base() + hparams.layers = "timing," + "conv,att_lsh,ffn," * 2 + return hparams + + @registry.register_hparams def aligned_8k(): """version for languagemodel_wiki_scramble8k50. @@ -487,14 +538,16 @@ def aligned_8k(): def aligned_8k_grouped(): """version for languagemodel_wiki_scramble8k50. 
- languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.93 + languagemodel_wiki_scramble1k50, 1gpu, 7k steps: log(ppl)_eval = 2.92 3.3 steps/sec on P100 - 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.18 + 8gpu (8x batch), 7k steps: log(ppl)_eval = 2.15 Returns: a hparams object """ hparams = aligned_grouped() hparams.batch_size = 8192 - hparams.attention_image_summary = int(False) + # hparams.attention_image_summary = int(False) + hparams.num_groups = 16 + hparams.multiplicative_overhead = 1.1 return hparams diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 3a5b73a3e..85c7c9d49 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -46,11 +46,15 @@ class AttentionType(object): + """Enum of the attention layers types.""" MULTIHEAD = "multihead" LOCAL_EXPERTS = "local_experts" GLOBAL_MOE = "global_experts" MEMORY_EFFICIENT = "memory_efficient" SPARSE_MULTIHEAD = "sparse_multihead" + SPARSE_MULTIHEAD_TRUNCATED = "sparse_multihead_truncated" + MULTIHEAD_REDUCED = "multihead_reduced" + MULTIHEAD_FULL = "multihead_full" @staticmethod def get_choices(): @@ -59,6 +63,9 @@ def get_choices(): AttentionType.LOCAL_EXPERTS, AttentionType.MEMORY_EFFICIENT, AttentionType.SPARSE_MULTIHEAD, + AttentionType.SPARSE_MULTIHEAD_TRUNCATED, + AttentionType.MULTIHEAD_REDUCED, + AttentionType.MULTIHEAD_FULL, ] @@ -66,7 +73,10 @@ def get_choices(): "h": AttentionType.MULTIHEAD, # multi-Head "e": AttentionType.LOCAL_EXPERTS, # Experts "m": AttentionType.MEMORY_EFFICIENT, # Memory - "s": AttentionType.SPARSE_MULTIHEAD, # Sparse + "s": AttentionType.SPARSE_MULTIHEAD, # Sparse (Locality sensitive hashing) + "t": AttentionType.SPARSE_MULTIHEAD_TRUNCATED, # Using TruncatedDispatcher + "r": AttentionType.MULTIHEAD_REDUCED, # Reduced + "f": AttentionType.MULTIHEAD_FULL, # Force using full attention } @@ -132,12 +142,12 @@ def _diet_expert(x): x, hparams.attention_exp_factor) dp_expand_x = lambda x: dp( # pylint: disable=g-long-lambda - deconv_elems_1d, + common_attention.deconv_elems_1d, x, hparams.attention_exp_factor, hparams.attention_exp_inputdim) dp_compress_x = lambda x, l: dp( # pylint: disable=g-long-lambda - conv_elems_1d, + common_attention.conv_elems_1d, x, hparams.attention_exp_factor, l) @@ -158,6 +168,9 @@ def print_shape(x, suffix, debug=False): batch_coordinate = dp(get_batch_coordinate, x) batch_coordinate = dp_remove_pad(batch_coordinate) batch_coordinate = dp_expand_bc(batch_coordinate) + batch_order = dp(get_batch_coordinate, x, axis=-1) + batch_order = dp_remove_pad(batch_order) + batch_order = dp_expand_bc(batch_order) x = dp(print_shape, x, "in") @@ -176,7 +189,13 @@ def print_shape(x, suffix, debug=False): with tf.variable_scope( "attention_{}".format(attention_type)): - if attention_type == AttentionType.MULTIHEAD: + if attention_type in [ + AttentionType.MULTIHEAD, AttentionType.MULTIHEAD_FULL]: + attention_dot_type = ( + "local_mask_right" if hparams.attention_local else + "dot_product") + if attention_type == AttentionType.MULTIHEAD_FULL: + attention_dot_type = "dot_product" y = dp( common_attention.multihead_attention, preprocess(x), @@ -187,8 +206,8 @@ def print_shape(x, suffix, debug=False): hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, - attention_type=("local_mask_right" if hparams.attention_local - else "dot_product"), + attention_type=attention_dot_type, + block_length=hparams.attention_block_length, name="decoder_self_attention") elif attention_type == 
AttentionType.SPARSE_MULTIHEAD: x_in = preprocess(x) @@ -205,15 +224,43 @@ def print_shape(x, suffix, debug=False): hparams.attention_dropout, # Additional parameters - bc=batch_coordinate, + bi=[common_attention.BatchInfo( + coordinates=batch_coordinate[i], + order=batch_order[i], # No future mask + ) for i in range(dp.n)], + use_map_fn=hparams.lsh_use_map_fn, experts_params=dict( - train=hparams.mode == ModeKeys.TRAIN, - num_experts=hparams.attention_num_experts, - k=hparams.attention_moe_k, + nb_hyperplanes=hparams.lsh_num_hyperplanes, ), ) y = dp_restore_pad(y) + # TODO(avaswani, epot, noam): Do we need to divide by num shards ? + extra_loss += tf.add_n(loss_experts) / dp.n + elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED: + x_in = preprocess(x) + y, loss_experts = dp( + common_attention.multihead_attention_sparse_truncated, + x_in, + None, + None, # Bias is computed inside + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, + hparams.num_heads, + hparams.attention_dropout, + + # Additional parameters + bi=[common_attention.BatchInfo( + coordinates=batch_coordinate[i], + order=batch_order[i], # No future mask + ) for i in range(dp.n)], + mask_right=True, + experts_params=dict( + nb_hyperplanes=hparams.lsh_num_hyperplanes, + ), + ) + # TODO(avaswani, epot, noam): Do we need to divide by num shards ? extra_loss += tf.add_n(loss_experts) / dp.n elif attention_type == AttentionType.MEMORY_EFFICIENT: @@ -224,6 +271,20 @@ def print_shape(x, suffix, debug=False): decoder_self_attention_bias, hparams.num_heads, name="decoder_self_attention") + elif attention_type == AttentionType.MULTIHEAD_REDUCED: + y = dp( + common_attention.multihead_self_attention_reduced, + preprocess(x), + factor=hparams.attention_red_factor, + reduction_type=hparams.attention_reduction_type, + multihead_params=dict( + total_key_depth= + hparams.attention_key_channels or hparams.hidden_size, + total_value_depth= + hparams.attention_value_channels or hparams.hidden_size, + num_heads=hparams.num_heads, + dropout_rate=hparams.attention_dropout, + )) elif attention_type == AttentionType.LOCAL_EXPERTS: x_in = preprocess(x) x_in = dp_remove_pad(x_in) @@ -324,75 +385,16 @@ def attention_lm_moe_prepare_decoder(targets, hparams): return (decoder_input, decoder_self_attention_bias, pad_remover) -def get_batch_coordinate(x): +@expert_utils.add_name_scope() +def get_batch_coordinate(x, axis=0): """Return a flat int32 tensor of shape [1, batch_size*length, 1].""" # Compute the batch coordinate before flattening all batches batch_coordinate = tf.expand_dims( - common_attention.coordinate_tensor(tf.shape(x)[:-1], axis=0), axis=-1) + common_attention.coordinate_tensor(tf.shape(x)[:-1], axis=axis), axis=-1) return batch_coordinate -@expert_utils.add_var_scope() -def deconv_elems_1d(x, factor, out_depth): - """Increase the length and change the dimensionality. - - Expand/project each positions of dim depth of the input into - factor*tokens of dim out_depth - - Args: - x (tf.Tensor): shape [batch_size, length, depth] - factor (int): Multiplicative factor of each tokens. 
- out_depth (int): Output depth - - Returns: - tf.Tensor: shape [batch_size, length*factor, out_depth] - """ - x = tf.expand_dims(x, 1) # [batch_size, 1, length, depth] - x = tf.layers.conv2d_transpose( - inputs=x, - filters=out_depth, - kernel_size=(1, factor), - strides=(1, factor), - padding="valid", - data_format="channels_last", - ) # [batch_size, 1, length*factor, out_depth] - x = tf.squeeze(x, 1) # [batch_size, 1, length, depth] - return x - - -@expert_utils.add_var_scope() -def conv_elems_1d(x, factor, out_depth): - """Decrease the length and change the dimensionality. - - Merge/restore/compress factors positions of dim depth of the input into - a single position of dim out_depth. - This is basically just a strided convolution without overlapp - between each strides. - The original length has to be divided by factor. - - Args: - x (tf.Tensor): shape [batch_size, length, depth] - factor (int): Length compression factor. - out_depth (int): Output depth - - Returns: - tf.Tensor: shape [batch_size, length//factor, out_depth] - """ - with tf.control_dependencies( # Dynamic assertion - [tf.assert_equal(tf.shape(x)[1] % factor, 0)]): - x = tf.expand_dims(x, 1) # [batch_size, 1, length, depth] - x = tf.layers.conv2d( - inputs=x, - filters=out_depth, - kernel_size=(1, factor), - strides=(1, factor), - padding="valid", - data_format="channels_last", - ) # [batch_size, 1, length//factor, out_depth] - x = tf.squeeze(x, 1) # [batch_size, 1, length, depth] - return x - - +@expert_utils.add_name_scope() def expand_batch_coordinates(bc, length_factor): """Duplicate elements of bc by length_factor. @@ -413,6 +415,7 @@ def expand_batch_coordinates(bc, length_factor): return bc +@expert_utils.add_name_scope() def remove_pad(x, pad_remover, mode): """Remove padding by concatenating all dimension into one. @@ -440,6 +443,7 @@ def remove_pad(x, pad_remover, mode): return x +@expert_utils.add_name_scope() def restore_pad(x, ref_x, pad_remover, mode): x = tf.squeeze(x, axis=0) if mode != ModeKeys.PREDICT: @@ -502,6 +506,9 @@ def attention_lm_moe_base(): hparams.add_hparam("attention_num_head", 1) hparams.add_hparam("attention_num_experts", 16) hparams.add_hparam("attention_split_batch", int(False)) + hparams.add_hparam("attention_red_factor", 3) + hparams.add_hparam("attention_block_length", 128) + hparams.add_hparam("attention_reduction_type", "conv") # If attention_exp_factor is set, each input to local_expert_attention (of # dimensionality hidden size) is projected into attention_exp_factor smaller # inputs, each of dimensionality attention_exp_inputdim. 
(otherwise @@ -513,6 +520,10 @@ def attention_lm_moe_base(): hparams.add_hparam("attention_v_size", 256) # Loss coef for load balancing hparams.add_hparam("attention_load_balance", 2e-2) + # Locality-sensitive hashing params + hparams.add_hparam("lsh_num_hyperplanes", 4) + hparams.add_hparam("lsh_use_map_fn", int(False)) + hparams.add_hparam("use_sepconv", int(False)) hparams.add_hparam("diet_experts", int(False)) hparams.add_hparam("memory_efficient_ffn", int(False)) @@ -581,6 +592,13 @@ def attention_lm_hybrid_v2(): return hparams +@registry.register_hparams +def attention_lm_16k(): + hparams = attention_lm_hybrid_v2() + hparams.batch_size = 16384 + return hparams + + @registry.register_hparams def attention_lm_ae_extended(): """Experiment with the exp_factor params.""" diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index e0f619805..baa85829c 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -95,7 +95,7 @@ def decode( attentions, used for fast decoding. Returns: - Final decoder representaiton. [batch_size, decoder_length, hidden_dim] + Final decoder representation. [batch_size, decoder_length, hidden_dim] """ decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) @@ -112,7 +112,7 @@ def decode( return tf.expand_dims(decoder_output, axis=2) def model_fn_body(self, features): - """Transformet main model_fn. + """Transformer main model_fn. Args: features: Map of features to the model. Should contain the following: @@ -122,7 +122,7 @@ def model_fn_body(self, features): "target_space_id" Returns: - Final decoder representaiton. [batch_size, decoder_length, hidden_dim] + Final decoder representation. [batch_size, decoder_length, hidden_dim] """ hparams = self._hparams @@ -563,14 +563,13 @@ def transformer_ffn_layer(x, hparams, pad_remover=None): @registry.register_hparams -def transformer_base(): +def transformer_base_v1(): """Set of hyperparameters.""" hparams = common_hparams.basic_params1() hparams.norm_type = "layer" hparams.hidden_size = 512 hparams.batch_size = 4096 hparams.max_length = 256 - hparams.dropout = 0.0 hparams.clip_grad_norm = 0. # i.e. no gradient clipping hparams.optimizer_adam_epsilon = 1e-9 hparams.learning_rate_decay_scheme = "noam" @@ -611,6 +610,24 @@ def transformer_base(): return hparams +@registry.register_hparams +def transformer_base_v2(): + hparams = transformer_base_v1() + hparams.layer_preprocess_sequence = "n" + hparams.layer_postprocess_sequence = "da" + hparams.layer_prepostprocess_dropout = 0.1 + hparams.attention_dropout = 0.1 + hparams.relu_dropout = 0.1 + hparams.learning_rate_warmup_steps = 8000 + hparams.learning_rate = 0.2 + return hparams + + +@registry.register_hparams +def transformer_base(): + return transformer_base_v2() + + @registry.register_hparams def transformer_n_da(): """Normalize on layer input, instead of after residual connection. diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index d2b1bf631..67ec86ef5 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -129,7 +129,7 @@ def dae(x, hparams, name): gumbel_samples *= common_layers.inverse_exp_decay(steps // 5) * 0.5 temperature = 1.2 - common_layers.inverse_lin_decay(steps) # 30% of the time keep reasonably high temperature to keep learning. 
-  temperature = tf.cond(tf.less(tf.random_uniform([]), 0.7),
+  temperature = tf.cond(tf.less(tf.random_uniform([]), 0.9),
                         lambda: temperature,
                         lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
   s = tf.nn.softmax((logsm + gumbel_samples) / temperature)
@@ -144,22 +144,56 @@ def dae(x, hparams, name):
     d_mean = tf.reduce_mean(distrib, axis=[0], keep_dims=True)
     d_variance = tf.reduce_mean(tf.square(distrib - d_mean), axis=[0])
     d_dev = - tf.reduce_mean(d_variance)
-    ret = s  # If we want just hot, do tf.reshape(maxvhot, tf.shape(s))
+    ret = s
+    if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN:
+      ret = tf.reshape(maxvhot, tf.shape(s))  # Just hot on eval/infer.
     return m, ret, d_dev * 5.0 + tf.reduce_mean(kl) * 0.002
-def vae(x, hparams, name):
+def vae(x, z_size, name):
   with tf.variable_scope(name):
-    mu = tf.layers.dense(x, hparams.z_size, name="mu")
-    log_sigma = tf.layers.dense(x, hparams.z_size, name="log_sigma")
+    mu = tf.layers.dense(x, z_size, name="mu")
+    log_sigma = tf.layers.dense(x, z_size, name="log_sigma")
     shape = tf.shape(x)
-    epsilon = tf.random_normal([shape[0], shape[1], 1, hparams.z_size])
+    epsilon = tf.random_normal([shape[0], shape[1], 1, z_size])
     z = mu + tf.exp(log_sigma / 2) * epsilon
     kl = 0.5 * tf.reduce_mean(
         tf.exp(log_sigma) + tf.square(mu) - 1. - log_sigma, axis=-1)
     return z, tf.reduce_mean(kl), mu, log_sigma
+def bit_vae(x, hparams, name):
+  with tf.variable_scope(name):
+    bity = tf.layers.dense(x, hparams.z_size, name="bity")
+    dev = common_layers.inverse_lin_decay(hparams.startup_steps) * 1.5
+    noise = tf.random_normal(tf.shape(bity), mean=0.0, stddev=dev)
+    y = common_layers.saturating_sigmoid(bity + noise)
+    tf.summary.histogram("bit", tf.reshape(y, [-1]))
+    def discrete_y():
+      d = tf.to_float(tf.less(0.5, y))
+      return tf.stop_gradient(d) + y - tf.stop_gradient(y)
+    y = tf.cond(tf.less(tf.train.get_global_step(), hparams.startup_steps),
+                lambda: y, discrete_y)
+    # Flatten and predict for loss.
+    y_flat = tf.reshape(y, [-1, hparams.z_size, 1, 1])
+    hsize = hparams.hidden_size
+    hparams.hidden_size = hsize // 2
+    emb0 = tf.get_variable("emb0", [hparams.hidden_size])
+    emb1 = tf.get_variable("emb1", [hparams.hidden_size])
+    emb0 = tf.reshape(emb0, [1, 1, 1, hparams.hidden_size])
+    emb1 = tf.reshape(emb1, [1, 1, 1, hparams.hidden_size])
+    y_emb = y_flat * emb1 + (1 - y_flat) * emb0
+    y_logit = decode(None, None, y_emb, None, None, hparams, "dbit")
+    hparams.hidden_size = hsize
+    y_pred = tf.nn.log_softmax(tf.layers.dense(y_logit, 2, name="y_pred"))
+    y_flat = tf.reshape(y_flat, [-1])
+    y_pred = tf.reshape(y_pred, [-1, 2])
+    loss = - (y_flat * y_pred[:, 1] + (1 - y_flat) * y_pred[:, 0])
+    # Get the final z and return.
+ z = tf.layers.dense(y, hparams.z_size, name="after_bit") + return z, tf.reduce_mean(loss) + + def nearest(x, means, hparams): """Find the nearest means to elements in x.""" x, means = tf.stop_gradient(x), tf.stop_gradient(means) @@ -223,18 +257,19 @@ def encode(x, x_space, hparams, name): encoder_input, encoder_self_attention_bias, hparams), ed -def decode(cond_vec, cond_add, gold, c, ed, hparams): +def decode(cond_vec, cond_add, gold, c, ed, hparams, name): """Transformer decoder.""" - drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout) - decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec) - if cond_add is not None: - decoder_input += cond_add - decoder_input = tf.squeeze(decoder_input, axis=2) - decoder_input = common_attention.add_timing_signal_1d(decoder_input) - bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1]) - if c is not None and len(c.get_shape()) > 3: - c = tf.squeeze(c, axis=2) - return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams) + with tf.variable_scope(name): + drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout) + decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec) + if cond_add is not None: + decoder_input += cond_add + decoder_input = tf.squeeze(decoder_input, axis=2) + decoder_input = common_attention.add_timing_signal_1d(decoder_input) + bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1]) + if c is not None and len(c.get_shape()) > 3: + c = tf.squeeze(c, axis=2) + return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams) def expand_batch(x, mul): @@ -256,9 +291,26 @@ def ae_compress(x, is_2d, hparams, name, reuse=None): # Convolve and ReLu to get state. cur = common_layers.conv_block( cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv") - # To put a standard VAE use the line below. - # cur, vae_kl, _, _ = vae(cur, hparams, "kmeans_vae") - means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size]) + means_size = hparams.z_size if hparams.do_vae else hparams.v_size + means = tf.get_variable("z_to_dense", [means_size, hparams.hidden_size]) + if hparams.do_vae: + if hparams.bit_vae: + hot, loss = bit_vae(cur, hparams, "bvae") + else: + hot, loss, _, _ = vae(cur, hparams.z_size, "vae") + # Do a second level vae with some probability. 
+ if hparams.z_size2 > 0: + prob_z2 = common_layers.inverse_exp_decay(hparams.startup_steps*2) * 0.8 + if hparams.mode != tf.contrib.learn.ModeKeys.TRAIN: + prob_z2 = 1.0 + def vae2(): + hot2, loss2, _, _ = vae(hot, hparams.z_size2, "vae2") + ret = tf.layers.dense(hot2, hparams.z_size) + return mix(ret, hot, hparams.startup_steps * 2), loss2 + hot, loss2 = tf.cond(tf.less(tf.random_uniform([]), prob_z2), + vae2, lambda: (hot, tf.constant(0.0))) + loss += loss2 * 0.1 + return cur, hot, loss if hparams.use_gumbel_softmax: _, hot, loss = dae(cur, hparams, "dae") return cur, hot, loss @@ -275,12 +327,13 @@ def ae_compress(x, is_2d, hparams, name, reuse=None): def ae_embed(hot, hparams, name, reuse=None): with tf.variable_scope(name, reuse=reuse): - means = tf.get_variable("z_to_dense", [hparams.v_size, hparams.hidden_size]) - hot_flat = tf.reshape(hot, [-1, hparams.v_size]) + means_size = hparams.z_size if hparams.do_vae else hparams.v_size + means = tf.get_variable("z_to_dense", [means_size, hparams.hidden_size]) + hot_flat = tf.reshape(hot, [-1, means_size]) emb = tf.matmul(hot_flat, means) emb = tf.reshape(emb, [tf.shape(hot)[0], tf.shape(hot)[1], tf.shape(hot)[2], hparams.hidden_size]) - if hparams.use_gumbel_softmax: + if hparams.use_gumbel_softmax or hparams.do_vae: return emb return tf.layers.dense(emb, hparams.hidden_size, name="unnormalize", reuse=reuse) @@ -289,14 +342,14 @@ def ae_embed(hot, hparams, name, reuse=None): def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None): """Decompress from z, leaking from ae.""" with tf.variable_scope(name + "_decompress", reuse=reuse): - if hparams.use_gumbel_softmax: + if hparams.use_gumbel_softmax or hparams.do_vae: # Leak at the beginning to help train. z = mix(z, ae, hparams.startup_steps) else: # Gradients flow to ae while the value is z. z = tf.stop_gradient(z) + ae - tf.stop_gradient(ae) # Leak during training to keep the full dense autoencoder. - prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.6 + prob_z = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.8 prob_z = prob_z if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN else 1.0 z = tf.cond(tf.less(tf.random_uniform([]), prob_z), lambda: z, lambda: ae) @@ -319,7 +372,7 @@ def ae_decompress(z, ae, x, is_2d, hparams, name, reuse=None): x_batch = tf.stop_gradient(x_batch) z_batch = tf.reshape(z, [-1, 1, 1, hparams.hidden_size]) d_batch = tf.reshape(d, [-1, k, 1, hparams.hidden_size]) - dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams) + dec_batch = decode(z_batch, d_batch, x_batch, None, None, hparams, "dar") else: # For non-autoregressive. dec_batch = d z = tf.reshape(dec_batch, [-1, tf.shape(x)[1], tf.shape(x)[2], @@ -352,21 +405,25 @@ def ae_transformer_internal(inputs, targets, target_space, hparams): emb = ae_embed(hot, hparams, "ae", reuse=True) # Compress context and run autoregressive decoder on emb-hot. - emb_flat = tf.expand_dims(common_layers.flatten4d3d(emb), axis=2) - emb_flat = tf.stop_gradient(emb_flat) - dec_c = decode(None, None, emb_flat, inputs, ed, hparams) - dec_c = tf.reshape(dec_c, tf.shape(emb)) - c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context") - reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits( - labels=hot, logits=c_z) - # If not training, use the predicted z instead of the autoregressive one. 
- if hparams.mode == tf.estimator.ModeKeys.PREDICT: - hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size) + if hparams.do_vae: + reconstruct_loss = 0.0 + else: + emb_flat = tf.expand_dims(common_layers.flatten4d3d(emb), axis=2) + emb_flat = tf.stop_gradient(emb_flat) + dec_c = decode(None, None, emb_flat, inputs, ed, hparams, "dgold") + dec_c = tf.reshape(dec_c, tf.shape(emb)) + c_z = tf.layers.dense(dec_c, hparams.v_size, name="mask_context") + reconstruct_loss = tf.nn.softmax_cross_entropy_with_logits( + labels=hot, logits=c_z) + # If not training, use the predicted z instead of the autoregressive one. + if hparams.mode == tf.estimator.ModeKeys.PREDICT: + hot = tf.one_hot(tf.argmax(c_z, axis=-1), hparams.v_size) # Decompress, pass for ae loss. z = ae_decompress(emb, ae, targets, hparams.is_2d, hparams, "ae") - kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.8), - min_value=0.0001) + if not (hparams.use_gumbel_softmax and hparams.softmax_k > 0): + kl *= common_layers.inverse_exp_decay(int(hparams.startup_steps * 0.8), + min_value=0.0001) reconstruct_loss *= common_layers.inverse_exp_decay(hparams.startup_steps) losses = {"kl": kl, "reconstruction": reconstruct_loss * 0.1} return z, losses @@ -425,6 +482,7 @@ def transformer_ae_small(): hparams.batch_size = 2048 hparams.learning_rate_warmup_steps = 4000 hparams.add_hparam("z_size", 128) + hparams.add_hparam("z_size2", 0) hparams.add_hparam("v_size", 1024*32) hparams.add_hparam("num_compress_steps", 4) hparams.add_hparam("kl_warmup_steps", 60000) @@ -433,8 +491,10 @@ def transformer_ae_small(): hparams.add_hparam("z_dropout", 0.1) hparams.add_hparam("is_2d", 0) hparams.add_hparam("use_gumbel_softmax", int(True)) - hparams.add_hparam("softmax_k", 4) + hparams.add_hparam("softmax_k", 0) hparams.add_hparam("decode_autoregressive", int(True)) + hparams.add_hparam("do_vae", int(True)) + hparams.add_hparam("bit_vae", int(True)) return hparams @@ -442,15 +502,19 @@ def transformer_ae_small(): def transformer_ae_cifar(): """Hyperparameters for CIFAR-10 experiments.""" hparams = transformer_ae_small() - hparams.hidden_size = 384 - hparams.z_size = 256 - hparams.batch_size = 1024 * 16 + hparams.hidden_size = 256 + hparams.filter_size = 512 + hparams.z_size = 256 # 64 + hparams.z_size2 = 0 # 16 + hparams.batch_size = 1024 * 4 hparams.num_compress_steps = 2 hparams.v_size = 1024 * 16 hparams.kl_warmup_steps = 150000 - hparams.startup_steps = 30000 + hparams.startup_steps = 20000 hparams.kmeans_lr_factor = 0.0 hparams.is_2d = 1 + hparams.learning_rate_warmup_steps = 8000 + hparams.learning_rate = 0.2 return hparams diff --git a/tensor2tensor/tpu/tpu_trainer.py b/tensor2tensor/tpu/tpu_trainer.py index 2c6292405..8cda597d4 100644 --- a/tensor2tensor/tpu/tpu_trainer.py +++ b/tensor2tensor/tpu/tpu_trainer.py @@ -36,6 +36,8 @@ flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("master", "", "Address of TensorFlow master.") flags.DEFINE_integer("eval_steps", 10, "Number of steps in evaluation.") +flags.DEFINE_integer("iterations_per_loop", 1000, + "Number of iterations in a TPU training loop.") def main(unused_argv): @@ -58,14 +60,17 @@ def main(unused_argv): output_dir=FLAGS.output_dir, master=FLAGS.master, num_shards=FLAGS.tpu_num_shards, - batch_size=hparams.batch_size_per_shard * FLAGS.tpu_num_shards, - log_device_placement=FLAGS.log_device_placement) - estimator.train( - lambda params: input_fn(tf.estimator.ModeKeys.TRAIN, params), - steps=FLAGS.train_steps) - estimator.evaluate( - lambda 
params: input_fn(tf.estimator.ModeKeys.EVAL, params), - steps=FLAGS.eval_steps) + batch_size=hparams.tpu_batch_size_per_shard * FLAGS.tpu_num_shards, + log_device_placement=FLAGS.log_device_placement, + iterations_per_loop=FLAGS.iterations_per_loop) + if FLAGS.train_steps: + estimator.train( + lambda params: input_fn(tf.estimator.ModeKeys.TRAIN, params), + steps=FLAGS.train_steps) + if FLAGS.eval_steps: + estimator.evaluate( + lambda params: input_fn(tf.estimator.ModeKeys.EVAL, params), + steps=FLAGS.eval_steps) if __name__ == "__main__": diff --git a/tensor2tensor/tpu/tpu_trainer_lib.py b/tensor2tensor/tpu/tpu_trainer_lib.py index c6bba9d41..c514da2ad 100644 --- a/tensor2tensor/tpu/tpu_trainer_lib.py +++ b/tensor2tensor/tpu/tpu_trainer_lib.py @@ -13,13 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Library for training on TPU. See tpu_trainer.py.""" +"""Library for training on TPU. See tpu_trainer.py. -# TODO(rsepassi): -# * Fix EVAL (breaks when loading from checkpoint) -# * Support all decoders -# * Share more code with Problem.dataset and input_pipeline -# * Support PREDICT +Currently only supports training and evaluation for text-to-text problems. +""" from __future__ import absolute_import from __future__ import division @@ -38,6 +35,7 @@ from tensor2tensor.utils import registry import tensorflow as tf +from tensorflow.python.util import nest def get_input_fn(data_dir, problem, hparams): @@ -49,11 +47,10 @@ def input_fn(mode, params): num_threads = 4 if is_training else 1 batch_size = params["batch_size"] - data_file_patterns = [problem.filepattern(data_dir, mode)] - batching_scheme = { "boundaries": [], "batch_sizes": [batch_size], + "min_length": hparams.min_length, "max_length": hparams.max_length, "window_size": batch_size, "padded_shapes": { @@ -71,9 +68,9 @@ def decode_record(record): return decoded data_files = tf.contrib.slim.parallel_reader.get_data_files( - data_file_patterns) - dataset = tf.contrib.data.TFRecordDataset(data_files) - dataset = dataset.map(decode_record, num_threads=num_threads) + problem.filepattern(data_dir, mode)) + dataset = tf.data.TFRecordDataset(data_files) + dataset = dataset.map(decode_record, num_parallel_calls=num_threads) def _preprocess(example, problem, hparams, mode): example = problem.preprocess_example(example, mode, hparams) @@ -83,19 +80,25 @@ def _preprocess(example, problem, hparams, mode): dataset = dataset.map( lambda ex: _preprocess(ex, problem, hparams, mode), - num_threads=num_threads) + num_parallel_calls=num_threads) def _valid_size(example): - return data_reader.example_valid_size(example, - batching_scheme["max_length"]) + return data_reader.example_valid_size( + example, batching_scheme["min_length"], batching_scheme["max_length"]) dataset = dataset.filter(_valid_size) if is_training: dataset = dataset.shuffle(100) - dataset = dataset.repeat(None) + # TODO(rsepassi): In eval mode, should not repeat + dataset = dataset.repeat(None) dataset = data_reader.padded_batch(dataset, batching_scheme["batch_sizes"][0], batching_scheme["padded_shapes"]) + + if not is_training: + dataset = dataset.map( + lambda f: pad_batch(f, batch_size), num_parallel_calls=num_threads) + dataset.prefetch(1) train_features = dataset.make_one_shot_iterator().get_next() @@ -109,13 +112,6 @@ def _valid_size(example): while len(targets.get_shape()) != 4: targets = tf.expand_dims(targets, axis=-1) - inputs_shape = inputs.get_shape().as_list() - inputs_shape[0] = batch_size - 
inputs.set_shape(inputs_shape) - targets_shape = targets.get_shape().as_list() - targets_shape[0] = batch_size - targets.set_shape(targets_shape) - train_features["inputs"] = inputs train_features["targets"] = targets @@ -124,6 +120,23 @@ def _valid_size(example): return input_fn +def pad_batch(features, batch_size): + """Pad each feature in features to batch_size on dim 0.""" + ts = [] + for t in nest.flatten(features): + before_pads = [0] * t.get_shape().ndims + after_pads = [0] * t.get_shape().ndims + batch_pad = tf.convert_to_tensor(batch_size) - tf.shape(t)[0] + after_pads[0] = batch_pad + pads = list(zip(before_pads, after_pads)) + old_shape = t.get_shape().as_list() + old_shape[0] = batch_size + t = tf.pad(t, pads) + t.set_shape(old_shape) + ts.append(t) + return nest.pack_sequence_as(features, ts) + + def get_model_fn(model, hp, use_tpu=True): """Get simple T2T model fn.""" @@ -150,6 +163,11 @@ def model_fn(features, labels, mode, params, config): outputs = model_class.model_fn_body(features) logits = target_modality.top(outputs, labels) + # Ensure the length is known statically + shape = [None] * logits.get_shape().ndims + shape[1] = hparams.max_length + logits.set_shape(logits.get_shape().merge_with(shape)) + # Loss loss_num, loss_den = target_modality.loss(logits, labels) loss = loss_num / tf.maximum(1.0, loss_den) @@ -157,6 +175,7 @@ def model_fn(features, labels, mode, params, config): if mode == tf.estimator.ModeKeys.EVAL: problem = hp.problem_instances[0] eval_metrics_fn = create_eval_metrics_fn(problem) + _remove_summaries() return tf.contrib.tpu.TPUEstimatorSpec( mode, eval_metrics=(eval_metrics_fn, [logits, orig_features["targets"]]), @@ -171,16 +190,7 @@ def model_fn(features, labels, mode, params, config): lr /= math.sqrt(float(num_shards)) # Optimizer - opt_name = hparams.optimizer - if opt_name == "Momentum": - opt = tf.train.MomentumOptimizer( - lr, momentum=hparams.optimizer_momentum_momentum) - else: - if hparams.optimizer not in ["RMSProp", "SGD"]: - tf.logging.warn( - "Only Momentum, RMSProp, and SGD are known to work on TPU.") - opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[opt_name](lr) - + opt = model_builder.ConditionalOptimizer(hparams.optimizer, lr, hparams) if use_tpu: opt = tf.contrib.tpu.CrossShardOptimizer(opt) @@ -199,6 +209,13 @@ def model_fn(features, labels, mode, params, config): return model_fn +TPU_METRIC_BLACKLIST = set([ + metrics.Metrics.APPROX_BLEU, + metrics.Metrics.ROUGE_2_F, + metrics.Metrics.ROUGE_L_F, +]) + + def create_eval_metrics_fn(problem): """Create the metrics_fn that TPUEstimatorSpec expects.""" @@ -213,7 +230,11 @@ def wrapped_metric_fn(logits, labels): metric_fns = [] eval_metrics = problem.eval_metrics() + for metric in eval_metrics: + if metric in TPU_METRIC_BLACKLIST: + tf.logging.warn("Skipping eval metric %s in TPU_METRIC_BLACKLIST", metric) + continue name = "metrics-%s/%s" % (problem.name, metric) metric_fns.append((name, make_metric_fn(metrics.METRICS_FNS[metric]))) @@ -246,7 +267,7 @@ def make_estimator(model_fn, output_dir, master="", batch_size=16, - iterations_per_loop=100, + iterations_per_loop=1000, num_shards=8, per_host_input_for_training=True, use_tpu=True, @@ -264,7 +285,8 @@ def make_estimator(model_fn, save_summary_steps=0, save_checkpoints_steps=save_checkpoints_steps, tpu_config=tpu_config, - master=master) + master=master, + evaluation_master=master) return tf.contrib.tpu.TPUEstimator( model_fn=model_fn, @@ -280,16 +302,12 @@ def transformer_tpu(): """HParams for Transformer model on TPU.""" hp = 
transformer.transformer_base() hp.use_pad_remover = int(False) # where op not supported + hp.optimizer = "TrueAdam" + hp.learning_rate = 0.4 # Inputs - hp.add_hparam("batch_size_per_shard", 24) # Each example in the batch will be of (padded) length hp.max_length hp.max_length = 64 + hp.tpu_batch_size_per_shard = 20 - hp.optimizer = "Momentum" # can be SGD, Momentum, RMSProp - hp.norm_type = "none" # seem to get nans with layer norm - hp.clip_grad_norm = 2. - hp.norm_epsilon = 1e-3 - hp.layer_preprocess_sequence = "n" - hp.layer_postprocess_sequence = "da" return hp diff --git a/tensor2tensor/utils/beam_search.py b/tensor2tensor/utils/beam_search.py index 9c26579af..1dd2f87b1 100644 --- a/tensor2tensor/utils/beam_search.py +++ b/tensor2tensor/utils/beam_search.py @@ -22,12 +22,31 @@ # Dependency imports import tensorflow as tf +from tensorflow.python.util import nest + # Assuming EOS_ID is 1 EOS_ID = 1 # Default value for INF INF = 1. * 1e7 +def expand_to_beam_size(tensor, beam_size): + """Tiles a given tensor by beam_size. + + Args: + tensor: tensor to tile [batch_size, ...] + beam_size: How much to tile the tensor by. + + Returns: + Tiled tensor [batch_size, beam_size, ...] + """ + tensor = tf.expand_dims(tensor, axis=1) + tile_dims = [1] * tensor.shape.ndims + tile_dims[1] = beam_size + + return tf.tile(tensor, tile_dims) + + def log_prob_from_logits(logits): return logits - tf.reduce_logsumexp(logits, axis=2, keep_dims=True) @@ -51,7 +70,8 @@ def compute_batch_indices(batch_size, beam_size): def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags, - beam_size, batch_size, prefix="default"): + beam_size, batch_size, prefix="default", + states_to_gather=None): """Given sequences and scores, will gather the top k=beam size sequences. This function is used to grow alive, and finished. It takes sequences, @@ -79,6 +99,7 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags, beam_size: int batch_size: int prefix: string that will prefix unique names for the ops run. + states_to_gather: dict (possibly nested) of decoding states. Returns: Tuple of (topk_seq [batch_size, beam_size, decode_length], @@ -101,13 +122,17 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags, # Gather up the highest scoring sequences. For each operation added, give it # a concrete name to simplify observing these operations with tfdbg. Clients # can capture these tensors by watching these node names. - topk_seq = tf.gather_nd( - sequences, top_coordinates, name=(prefix + "_topk_seq")) - topk_flags = tf.gather_nd( - flags, top_coordinates, name=(prefix + "_topk_flags")) - topk_gathered_scores = tf.gather_nd( - scores_to_gather, top_coordinates, name=(prefix + "_topk_scores")) - return topk_seq, topk_gathered_scores, topk_flags + def gather(tensor, name): + return tf.gather_nd(tensor, top_coordinates, name=(prefix + name)) + topk_seq = gather(sequences, "_topk_seq") + topk_flags = gather(flags, "_topk_flags") + topk_gathered_scores = gather(scores_to_gather, "_topk_scores") + if states_to_gather: + topk_gathered_states = nest.map_structure( + lambda state: gather(state, "_topk_states"), states_to_gather) + else: + topk_gathered_states = states_to_gather + return topk_seq, topk_gathered_scores, topk_flags, topk_gathered_states def beam_search(symbols_to_logits_fn, @@ -116,6 +141,7 @@ def beam_search(symbols_to_logits_fn, decode_length, vocab_size, alpha, + states=None, eos_id=EOS_ID): """Beam search with length penalties. 
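A minimal usage sketch of the new `states` threading in beam_search (illustrative only: the `step_count` cache entry and the constant logits are placeholders, not part of this patch; the unit tests added further down exercise the same pattern and assert on the state values):

import tensorflow as tf
from tensor2tensor.utils import beam_search

batch_size, beam_size, vocab_size, decode_length = 2, 4, 10, 5

def symbols_to_logits_fn(ids, states):
  # ids and every state tensor arrive flattened to [batch_size * beam_size, ...].
  logits = tf.zeros([tf.shape(ids)[0], vocab_size])  # placeholder scores
  # Update the cache; beam_search gathers it per surviving beam each step.
  states["step_count"] += 1.0
  return logits, states

initial_ids = tf.zeros([batch_size], dtype=tf.int32)
# Any (possibly nested) dict of [batch_size, ...] tensors can be carried along.
states = {"step_count": tf.zeros([batch_size, 1])}

final_ids, final_scores = beam_search.beam_search(
    symbols_to_logits_fn, initial_ids, beam_size, decode_length,
    vocab_size, alpha=0.0, states=states, eos_id=1)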
@@ -150,6 +176,7 @@ def beam_search(symbols_to_logits_fn, vocab_size: Size of the vocab, must equal the size of the logits returned by symbols_to_logits_fn alpha: alpha for length penalty. + states: dict (possibly nested) of decoding states. eos_id: ID for end of sentence. Returns: Tuple of @@ -163,9 +190,14 @@ def beam_search(symbols_to_logits_fn, # Expand to beam_size (batch_size, beam_size) alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1]) - # Expand each batch to beam_size - alive_seq = tf.tile(tf.expand_dims(initial_ids, 1), [1, beam_size]) - alive_seq = tf.expand_dims(alive_seq, 2) # (batch_size, beam_size, 1) + # Expand each batch and state to beam_size + alive_seq = expand_to_beam_size(initial_ids, beam_size) + alive_seq = tf.expand_dims(alive_seq, axis=2) # (batch_size, beam_size, 1) + if states: + states = nest.map_structure( + lambda state: expand_to_beam_size(state, beam_size), states) + else: + states = {} # Finished will keep track of all the sequences that have finished so far # Finished log probs will be negative infinity in the beginning @@ -214,7 +246,7 @@ def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq, curr_finished_seq, curr_finished_scores, curr_finished_scores, curr_finished_flags, beam_size, batch_size, "grow_finished") - def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished): + def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states): """Given sequences and scores, will gather the top k=beam size sequences. Args: @@ -225,6 +257,7 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished): [batch_size, beam_size] curr_finished: Finished flags for each of these sequences. [batch_size, beam_size] + states: dict (possibly nested) of decoding states. Returns: Tuple of (Topk sequences based on scores, @@ -236,9 +269,9 @@ def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished): curr_scores += tf.to_float(curr_finished) * -INF return compute_topk_scores_and_seq(curr_seq, curr_scores, curr_log_probs, curr_finished, beam_size, batch_size, - "grow_alive") + "grow_alive", states) - def grow_topk(i, alive_seq, alive_log_probs): + def grow_topk(i, alive_seq, alive_log_probs, states): r"""Inner beam seach loop. This function takes the current alive sequences, and grows them to topk @@ -255,19 +288,29 @@ def grow_topk(i, alive_seq, alive_log_probs): i: loop index alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1] alive_log_probs: probabilities of these sequences. [batch_size, beam_size] + states: dict (possibly nested) of decoding states. 
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
-         Flags indicating which of these sequences have finished decoding)
+         Flags indicating which of these sequences have finished decoding,
+         dict of transformed decoding states)
    """
    # Get the logits for all the possible next symbols
    flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])
    # (batch_size * beam_size, decoded_length)
-    flat_logits = symbols_to_logits_fn(flat_ids)
-    logits = tf.reshape(flat_logits, (batch_size, beam_size, -1))
+    if states:
+      flat_states = nest.map_structure(
+          lambda state: tf.reshape(state, [batch_size * beam_size, -1]), states)
+      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, flat_states)
+      states = nest.map_structure(
+          lambda state: tf.reshape(state, [batch_size, beam_size, -1]),
+          flat_states)
+    else:
+      flat_logits = symbols_to_logits_fn(flat_ids)
+    logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])
    # Convert logits to normalized log probs
    candidate_log_probs = log_prob_from_logits(logits)
@@ -305,16 +348,19 @@ def grow_topk(i, alive_seq, alive_log_probs):
    # Gather up the most probable 2*beams both for the ids and finished_in_alive
    # bools
    topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
+    if states:
+      states = nest.map_structure(
+          lambda state: tf.gather_nd(state, topk_coordinates), states)
    # Append the most probable alive
    topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
    topk_finished = tf.equal(topk_ids, eos_id)
-    return topk_seq, topk_log_probs, topk_scores, topk_finished
+    return topk_seq, topk_log_probs, topk_scores, topk_finished, states
  def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
-                 finished_flags):
+                 finished_flags, states):
    """Inner beam seach loop.
    There are three groups of tensors, alive, finished, and topk.
@@ -346,6 +392,7 @@ def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
        [batch_size, beam_size]
      finished_flags: finished bools for each of these sequences.
        [batch_size, beam_size]
+      states: dict (possibly nested) of decoding states.
    Returns:
      Tuple of
@@ -354,26 +401,27 @@ def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
        Log probs of the alive sequences,
        New finished sequences,
        Scores of the new finished sequences,
-        Flags inidicating which sequence in finished as reached EOS)
+        Flags indicating which sequence in finished has reached EOS,
+        dict of final decoding states)
    """
    # Each inner loop, we carry out three steps:
    # 1. Get the current topk items.
    # 2. Extract the ones that have finished and haven't finished
    # 3. Recompute the contents of finished based on scores.
- topk_seq, topk_log_probs, topk_scores, topk_finished = grow_topk( - i, alive_seq, alive_log_probs) - alive_seq, alive_log_probs, _ = grow_alive(topk_seq, topk_scores, - topk_log_probs, topk_finished) - finished_seq, finished_scores, finished_flags = grow_finished( + topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk( + i, alive_seq, alive_log_probs, states) + alive_seq, alive_log_probs, _, states = grow_alive( + topk_seq, topk_scores, topk_log_probs, topk_finished, states) + finished_seq, finished_scores, finished_flags, _ = grow_finished( finished_seq, finished_scores, finished_flags, topk_seq, topk_scores, topk_finished) return (i + 1, alive_seq, alive_log_probs, finished_seq, finished_scores, - finished_flags) + finished_flags, states) def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq, - finished_scores, finished_in_finished): + finished_scores, finished_in_finished, unused_states): """Checking termination condition. We terminate when we decoded up to decode_length or the lowest scoring item @@ -416,11 +464,11 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq, tf.less(i, decode_length), tf.logical_not(bound_is_met)) (_, alive_seq, alive_log_probs, finished_seq, finished_scores, - finished_flags) = tf.while_loop( + finished_flags, _) = tf.while_loop( _is_finished, inner_loop, [ tf.constant(0), alive_seq, alive_log_probs, finished_seq, - finished_scores, finished_flags + finished_scores, finished_flags, states ], shape_invariants=[ tf.TensorShape([]), @@ -428,7 +476,10 @@ def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq, alive_log_probs.get_shape(), tf.TensorShape([None, None, None]), finished_scores.get_shape(), - finished_flags.get_shape() + finished_flags.get_shape(), + nest.map_structure( + lambda tensor: tf.TensorShape([None] * tensor.shape.ndims), + states), ], parallel_iterations=1, back_prop=False) diff --git a/tensor2tensor/utils/beam_search_test.py b/tensor2tensor/utils/beam_search_test.py index 5223989ea..fc15eb3bc 100644 --- a/tensor2tensor/utils/beam_search_test.py +++ b/tensor2tensor/utils/beam_search_test.py @@ -61,8 +61,9 @@ def testComputeTopkScoresAndSeq(self): flags = tf.constant([[True, False, False, True], [False, False, False, True]]) - topk_seq, topk_scores, topk_flags = beam_search.compute_topk_scores_and_seq( - sequences, scores, scores, flags, beam_size, batch_size) + topk_seq, topk_scores, topk_flags, _ = ( + beam_search.compute_topk_scores_and_seq( + sequences, scores, scores, flags, beam_size, batch_size)) with self.test_session(): topk_seq = topk_seq.eval() @@ -277,6 +278,96 @@ def symbols_to_logits(ids): ]], scores) self.assertAllEqual([[[0, 2, 0, 1], [0, 2, 1, 0]]], ids) + def testStates(self): + batch_size = 1 + beam_size = 1 + vocab_size = 2 + decode_length = 3 + + initial_ids = tf.constant([0] * batch_size) # GO + probabilities = tf.constant([[[0.7, 0.3]], [[0.4, 0.6]], [[0.5, 0.5]]]) + + expected_states = tf.constant([[[0.]], [[1.]]]) + + def symbols_to_logits(ids, states): + pos = tf.shape(ids)[1] - 1 + # We have to assert the values of state inline here since we can't fetch + # them out of the loop! 
+ with tf.control_dependencies( + [tf.assert_equal(states["state"], expected_states[pos])]): + logits = tf.to_float(tf.log(probabilities[pos, :])) + + states["state"] += 1 + return logits, states + + states = { + "state": tf.zeros((batch_size, 1)), + } + + final_ids, _ = beam_search.beam_search( + symbols_to_logits, + initial_ids, + beam_size, + decode_length, + vocab_size, + 0.0, + eos_id=1, + states=states) + + with self.test_session() as sess: + # Catch and fail so that the testing framework doesn't think it's an error + try: + sess.run(final_ids) + except tf.errors.InvalidArgumentError as e: + raise AssertionError(e.message) + + def testStateBeamTwo(self): + batch_size = 1 + beam_size = 2 + vocab_size = 3 + decode_length = 3 + + initial_ids = tf.constant([0] * batch_size) # GO + probabilities = tf.constant([[[0.1, 0.1, 0.8], [0.1, 0.1, 0.8]], + [[0.4, 0.5, 0.1], [0.2, 0.4, 0.4]], + [[0.05, 0.9, 0.05], [0.4, 0.4, 0.2]]]) + + # The top beam is always selected so we should see the top beam's state + # at each position, which is the one thats getting 3 added to it each step. + expected_states = tf.constant([[[0.], [0.]], [[3.], [3.]], [[6.], [6.]]]) + + def symbols_to_logits(ids, states): + pos = tf.shape(ids)[1] - 1 + + # We have to assert the values of state inline here since we can't fetch + # them out of the loop! + with tf.control_dependencies( + [tf.assert_equal(states["state"], expected_states[pos])]): + logits = tf.to_float(tf.log(probabilities[pos, :])) + + states["state"] += tf.constant([[3.], [7.]]) + return logits, states + + states = { + "state": tf.zeros((batch_size, 1)), + } + + final_ids, _ = beam_search.beam_search( + symbols_to_logits, + initial_ids, + beam_size, + decode_length, + vocab_size, + 0.0, + eos_id=1, + states=states) + + with self.test_session() as sess: + # Catch and fail so that the testing framework doesn't think it's an error + try: + sess.run(final_ids) + except tf.errors.InvalidArgumentError as e: + raise AssertionError(e.message) if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/utils/data_reader.py b/tensor2tensor/utils/data_reader.py index cfe37c379..83f66b985 100644 --- a/tensor2tensor/utils/data_reader.py +++ b/tensor2tensor/utils/data_reader.py @@ -18,6 +18,8 @@ from __future__ import division from __future__ import print_function +import functools + # Dependency imports import numpy as np @@ -82,6 +84,7 @@ def input_pipeline(problem, "boundaries": a list of integers for the boundaries that will be used for bucketing; see bucket_by_sequence_length for more details. "batch_sizes": a list of batch sizes corresponding to the buckets + "min_length": an integer. We drop sequences which are shorter. "max_length": an integer. We drop sequences which are longer. dataset_split: tf.estimator.ModeKeys + ["test"], which split of the dataset to use. Defaults to mode. 
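A toy illustration of the length filter this change introduces (the `keep_example` helper and the numbers are made up; `example_valid_size` applies the same min/max rule to an example's computed length):

import tensorflow as tf

def keep_example(length, min_length, max_length):
  # Mirrors the predicate added to example_valid_size: keep an example only
  # when min_length <= length <= max_length.
  return tf.logical_and(length >= min_length, length <= max_length)

with tf.Session() as sess:
  print(sess.run(keep_example(tf.constant(5), 8, 256)))   # False: too short, dropped
  print(sess.run(keep_example(tf.constant(64), 8, 256)))  # True: within bounds, kept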
@@ -102,7 +105,11 @@ def input_pipeline(problem, dataset_split=dataset_split) dataset = dataset.map(cast_int64_to_int32, num_threads=num_threads) dataset = dataset.filter( - lambda ex: example_valid_size(ex, batching_scheme["max_length"])) + functools.partial( + example_valid_size, + min_length=batching_scheme["min_length"], + max_length=batching_scheme["max_length"], + )) if is_training: dataset = dataset.shuffle(capacity) dataset = dataset.repeat(None) @@ -143,8 +150,12 @@ def _example_length(example): return length -def example_valid_size(example, max_length): - return tf.less_equal(_example_length(example), max_length) +def example_valid_size(example, min_length, max_length): + length = _example_length(example) + return tf.logical_and( + length >= min_length, + length <= max_length, + ) def bucket_by_sequence_length(dataset, @@ -232,7 +243,8 @@ def _batching_scheme(batch_size, length_bucket_step, drop_long_sequences=False, shard_multiplier=1, - length_multiplier=1): + length_multiplier=1, + min_length=0): """A batching scheme based on model hyperparameters. Every batch containins a number of sequences divisible by `shard_multiplier`. @@ -251,18 +263,26 @@ def _batching_scheme(batch_size, across datashards. length_multiplier: an integer multiplier that is used to increase the batch sizes and sequence length tolerance. + min_length: int, sequences shorter than this will be skipped. Returns: A dictionary with parameters that can be passed to input_pipeline: * boundaries: list of bucket boundaries * batch_sizes: list of batch sizes for each length bucket * max_length: int, maximum length of an example + + Raises: + ValueError: If min_length > max_length """ max_length = max_length or batch_size + if max_length < min_length: + raise ValueError("max_length must be greater or equal to min_length") + boundaries = _bucket_boundaries(max_length, min_length_bucket, length_bucket_step) boundaries = [boundary * length_multiplier for boundary in boundaries] max_length *= length_multiplier + batch_sizes = [ max(1, batch_size // length) for length in boundaries + [max_length] ] @@ -293,9 +313,11 @@ def _batching_scheme(batch_size, # number of batches per window. 
max_batches_per_window = window_size // min(batch_sizes) shuffle_queue_size = max_batches_per_window * 3 + ret = { "boundaries": boundaries, "batch_sizes": batch_sizes, + "min_length": min_length, "max_length": (max_length if drop_long_sequences else 10**9), "shuffle_queue_size": shuffle_queue_size, "window_size": window_size, @@ -311,6 +333,7 @@ def hparams_to_batching_scheme(hparams, """Wrapper around _batching_scheme with hparams.""" return _batching_scheme( batch_size=hparams.batch_size, + min_length=hparams.min_length, max_length=hparams.max_length, min_length_bucket=hparams.min_length_bucket, length_bucket_step=hparams.length_bucket_step, @@ -333,6 +356,7 @@ def constant_batching_scheme(constant_batch_size_in_sequences): return { "boundaries": boundaries, "batch_sizes": batch_sizes, + "min_length": 0, "max_length": 10**9, "shuffle_queue_size": None, "window_size": constant_batch_size_in_sequences, diff --git a/tensor2tensor/utils/data_reader_test.py b/tensor2tensor/utils/data_reader_test.py index 0dccfaedf..bf2aa872e 100644 --- a/tensor2tensor/utils/data_reader_test.py +++ b/tensor2tensor/utils/data_reader_test.py @@ -120,7 +120,7 @@ def testLengthFilter(self): dataset = self.problem.dataset( tf.estimator.ModeKeys.TRAIN, data_dir=self.data_dir) dataset = dataset.filter( - lambda ex: data_reader.example_valid_size(ex, max_len)) + lambda ex: data_reader.example_valid_size(ex, 0, max_len)) examples = dataset.make_one_shot_iterator().get_next() with tf.train.MonitoredSession() as sess: ex_lens = [] diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index f1a3bf0bc..5dac0dd5f 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -52,7 +52,8 @@ def decode_hparams(overrides=""): return_beams=False, max_input_size=-1, identity_output=False, - num_samples=-1) + num_samples=-1, + delimiter="\n") hp = hp.parse(overrides) return hp @@ -86,10 +87,10 @@ def log_decode_results(inputs, if targets is not None: decoded_targets = " ".join(map(str, targets.flatten())) else: - decoded_outputs = " ".join( + decoded_outputs = "".join( map(str, targets_vocab.decode(_save_until_eos(outputs.flatten())))) if targets is not None: - decoded_targets = " ".join( + decoded_targets = "".join( map(str, targets_vocab.decode(_save_until_eos(targets.flatten())))) tf.logging.info("Inference results OUTPUT: %s" % decoded_outputs) @@ -176,8 +177,8 @@ def decode_from_dataset(estimator, # Write out predictions if decode_to_file passed if decode_to_file: for decoded_output, decoded_target in decoded_outputs: - output_file.write(str(decoded_output) + "\n") - target_file.write(str(decoded_target) + "\n") + output_file.write(str(decoded_output) + decode_hp.delimiter) + target_file.write(str(decoded_target) + decode_hp.delimiter) if (decode_hp.num_samples >= 0 and num_predictions >= decode_hp.num_samples): @@ -203,7 +204,8 @@ def decode_from_file(estimator, filename, decode_hp, decode_to_file=None): targets_vocab = hparams.problems[problem_id].vocabulary["targets"] problem_name = FLAGS.problems.split("-")[problem_id] tf.logging.info("Performing decoding from a file.") - sorted_inputs, sorted_keys = _get_sorted_inputs(filename, decode_hp.shards) + sorted_inputs, sorted_keys = _get_sorted_inputs(filename, decode_hp.shards, + decode_hp.delimiter) num_decode_batches = (len(sorted_inputs) - 1) // decode_hp.batch_size + 1 def input_fn(): @@ -251,7 +253,7 @@ def input_fn(): tf.logging.info("Writing decodes into %s" % decode_filename) outfile = tf.gfile.Open(decode_filename, 
"w") for index in range(len(sorted_inputs)): - outfile.write("%s\n" % (decodes[sorted_keys[index]])) + outfile.write("%s%s" % (decodes[sorted_keys[index]], decode_hp.delimiter)) def _decode_filename(base_filename, problem_name, decode_hp): @@ -472,13 +474,14 @@ def show_and_save_image(img, save_path): plt.savefig(save_path) -def _get_sorted_inputs(filename, num_shards=1): +def _get_sorted_inputs(filename, num_shards=1, delimiter="\n"): """Returning inputs sorted according to length. Args: filename: path to file with inputs, 1 per line. num_shards: number of input shards. If > 1, will read from file filename.XX, where XX is FLAGS.worker_id. + delimiter: str, delimits records in the file. Returns: a sorted list of inputs @@ -490,8 +493,12 @@ def _get_sorted_inputs(filename, num_shards=1): decode_filename = filename + ("%.2d" % FLAGS.worker_id) else: decode_filename = filename - inputs = [line.strip() for line in tf.gfile.Open(decode_filename)] - input_lens = [(i, len(line.strip().split())) for i, line in enumerate(inputs)] + + with tf.gfile.Open(decode_filename) as f: + text = f.read() + records = text.split(delimiter) + inputs = [record.strip() for record in records] + input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)] sorted_input_lens = sorted(input_lens, key=operator.itemgetter(1)) # We'll need the keys to rearrange the inputs back into their original order sorted_keys = {} @@ -553,8 +560,8 @@ def input_fn(problem_choice, x=inputs): # pylint: disable=missing-docstring feature_map["problem_choice"]) features["input_space_id"] = input_space_id features["target_space_id"] = target_space_id - features["decode_length"] = (IMAGE_DECODE_LENGTH - if input_is_image else inputs[1]) + features["decode_length"] = ( + IMAGE_DECODE_LENGTH if input_is_image else inputs[1]) features["inputs"] = x return features @@ -588,7 +595,7 @@ def input_fn(problem_choice, x=inputs): # pylint: disable=missing-docstring features["problem_choice"] = feature_map["problem_choice"] features["input_space_id"] = input_space_id features["target_space_id"] = target_space_id - features["decode_length"] = (IMAGE_DECODE_LENGTH - if input_is_image else tf.shape(x)[1] + 50) + features["decode_length"] = ( + IMAGE_DECODE_LENGTH if input_is_image else tf.shape(x)[1] + 50) features["inputs"] = x return features diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index eb513d0e8..5005cdb50 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -677,6 +677,7 @@ def __init__(self, num_experts, gates): tf.reshape(self._gates, [-1]), self._batch_index * num_experts + self._expert_index) + @add_name_scope() def dispatch(self, inp): """Create one input Tensor for each expert. @@ -692,6 +693,7 @@ def dispatch(self, inp): inp = tf.gather(inp, self._batch_index) return tf.split(inp, self._part_sizes_tensor, 0, num=self._num_experts) + @add_name_scope() def combine(self, expert_out, multiply_by_gates=True): """Sum together the expert output, weighted by the gates. @@ -1019,3 +1021,143 @@ def local_moe(x, importance = tf.reduce_sum(gates, 0) loss = loss_coef * (cv_squared(importance) + cv_squared(load)) return y, loss + + +class TruncatingDispatcher(object): + """Helper for implementing a mixture of experts. + + A TruncatingDispatcher is useful when you need to deal with + fixed-sized Tensors. 
As opposed to a SparseDispatcher, which + produces batches of different sizes for the different experts, the + TruncatingDispatcher always produces batches of the same given size, + and the results are returned stacked in one big tensor. + + In the case where an expert is over-capacity, the last items that + should have gone to that expert are dropped. + + Confusingly, the inputs to a TruncatingDispatcher have both a + "batch" and a "length" dimension. Not only does each expert receive + the same total number of examples, it also receives the same number + of examples for each element of "batch". This behavior is necessary + for applications such as grouped attention, where we have a batch of + sequences, and we want each sequence to be divided evenly among + experts. For simpler applications like mixture-of-experts, you can + reshape the input so that the "batch" dimension is 1, and only the + "length" dimension is used. + """ + + @add_name_scope("truncating_dispatcher") + def __init__(self, requests, expert_capacity): + """Create a TruncatingDispatcher. + + Args: + requests: a boolean `Tensor` of shape `[batch, length, num_experts]`. + Alternatively, a float or int Tensor containing zeros and ones. + expert_capacity: a Scalar - maximum number of examples per expert per + batch element. + + Returns: + a TruncatingDispatcher + """ + self._requests = tf.to_float(requests) + self._expert_capacity = expert_capacity + expert_capacity_f = tf.to_float(expert_capacity) + self._batch, self._length, self._num_experts = tf.unstack( + tf.shape(self._requests), num=3) + + # [batch, length, num_experts] + position_in_expert = tf.cumsum(self._requests, axis=1, exclusive=True) + # [batch, length, num_experts] + self._gates = self._requests * tf.to_float( + tf.less(position_in_expert, expert_capacity_f)) + batch_index = tf.reshape( + tf.to_float(tf.range(self._batch)), [self._batch, 1, 1]) + length_index = tf.reshape( + tf.to_float(tf.range(self._length)), [1, self._length, 1]) + expert_index = tf.reshape( + tf.to_float(tf.range(self._num_experts)), [1, 1, self._num_experts]) + # position in a Tensor with shape [batch * num_experts * expert_capacity] + flat_position = ( + position_in_expert + + batch_index * (tf.to_float(self._num_experts) * expert_capacity_f) + + expert_index * expert_capacity_f) + # Tensor of shape [batch * num_experts * expert_capacity]. + # each element is an integer in [0, length) + self._indices = tf.unsorted_segment_sum( + data=tf.reshape((length_index + 1.0) * self._gates, [-1]), + segment_ids=tf.to_int32(tf.reshape(flat_position, [-1])), + num_segments=self._batch * self._num_experts * expert_capacity) + self._indices = tf.reshape( + self._indices, + [self._batch, self._num_experts, expert_capacity]) + # Tensors of shape [batch, num_experts, expert_capacity]. + # each element is 0.0 or 1.0 + self._nonpadding = tf.minimum(self._indices, 1.0) + # each element is an integer in [0, length) + self._indices = tf.nn.relu(self._indices - 1.0) + # self._flat_indices is [batch, num_experts, expert_capacity], with values + # in [0, batch * length) + self._flat_indices = tf.to_int32( + self._indices + + (tf.reshape(tf.to_float(tf.range(self._batch)), [-1, 1, 1]) + * tf.to_float(self._length))) + self._indices = tf.to_int32(self._indices) + + @add_name_scope("truncating_dispatcher_dispatch") + def dispatch(self, inp): + """Send the inputs to the experts. 
+ + Args: + inp: a `Tensor` of shape "[batch, length, depth]` + Returns: + a tensor with shape [batch, num_experts, expert_capacity, depth] + """ + inp = tf.reshape(inp, [self._batch * self._length, -1]) + # [batch, num_experts, expert_capacity, depth] + ret = tf.gather(inp, self._flat_indices) + return ret + + @add_name_scope("truncating_dispatcher_combine") + def combine(self, x): + """Return the output from the experts. + + When one example goes to multiple experts, the outputs are summed. + + Args: + x: a Tensor with shape [batch, num_experts, expert_capacity, depth] + + Returns: + a `Tensor` with shape `[batch, length, depth] + """ + depth = tf.shape(x)[-1] + x *= tf.expand_dims(self._nonpadding, -1) + ret = tf.unsorted_segment_sum( + x, self._flat_indices, num_segments=self._batch * self._length) + ret = tf.reshape(ret, [self._batch, self._length, depth]) + return ret + + def nonpadding(self): + """Which elements of a dispatched Tensor are not padding. + + Returns: + a Zero/One float tensor with shape [batch, num_experts, expert_capacity]. + """ + return self._nonpadding + + def gates(self): + """A Tensor indicating which examples go to which experts. + + Returns: + A float32 Tensor with shape [batch, length, num_experts], where each value + is 0.0 or 1.0. + """ + return self._gates + + def length_coordinate(self): + """Length coordinate of dispatched tensor. + + Returns: + a tensor with shape [batch, num_experts, expert_capacity] containing + integers in the range [0, length) + """ + return self._indices diff --git a/tensor2tensor/utils/expert_utils_test.py b/tensor2tensor/utils/expert_utils_test.py index 93af9c78c..f9abc72c1 100644 --- a/tensor2tensor/utils/expert_utils_test.py +++ b/tensor2tensor/utils/expert_utils_test.py @@ -138,6 +138,74 @@ def testPadRemover(self): 0., # pad ]) + def testTruncatingDispatcher(self): + """Check that the TruncatingDispatcher is working correctly.""" + # batch = 1 + # length = 3 + # num_experts = 2 + expert_capacity = 2 + requests = tf.constant([ + [[True, False], + [True, True], + [True, False]], + [[False, False], + [False, True], + [True, False]] + ], dtype=tf.float32) + dispatcher = expert_utils.TruncatingDispatcher(requests, expert_capacity) + x = tf.constant([ + [[3, 4], + [5, 6], + [7, 8]], + [[2, 3], + [4, 5], + [6, 7]] + ], dtype=tf.float32) + dispatched = dispatcher.dispatch(x) + dispatched_expected = [ + [[[3, 4], [5, 6]], + [[5, 6], [3, 4]]], + [[[6, 7], [2, 3]], + [[4, 5], [2, 3]]] + ] + y = [ + [[[7, 12], [11, 30]], + [[-1, 30], [9, 9]]], + [[[13, 42], [9, 9]], + [[-1, 20], [9, 9]]] + ] + combined = dispatcher.combine(y) + combined_expected = [ + [[7, 12], + [10, 60], + [0, 0]], + [[0, 0], + [-1, 20], + [13, 42]] + ] + nonpadding = dispatcher.nonpadding() + nonpadding_expected = [ + [[1, 1], + [1, 0]], + [[1, 0], + [1, 0]] + ] + gates = dispatcher.gates() + gates_expected = [ + [[1, 0], + [1, 1], + [0, 0]], + [[0, 0], + [0, 1], + [1, 0]] + ] + + with self.test_session() as sess: + self._verify_value(sess, dispatched, dispatched_expected) + self._verify_value(sess, combined, combined_expected) + self._verify_value(sess, nonpadding, nonpadding_expected) + self._verify_value(sess, gates, gates_expected) + if __name__ == '__main__': tf.test.main() diff --git a/tensor2tensor/utils/input_fn_builder.py b/tensor2tensor/utils/input_fn_builder.py index c21dd973d..f4a3098ad 100644 --- a/tensor2tensor/utils/input_fn_builder.py +++ b/tensor2tensor/utils/input_fn_builder.py @@ -175,14 +175,13 @@ def _problem_choice(choice_mode, mode, 
problem_count, loss_moving_avgs, def cond_on_index(fn, index_tensor, max_idx, cur_idx=0): """Call fn(index_tensor) using tf.cond in [cur_id, max_idx].""" - if cur_idx == max_idx: return fn(cur_idx) return tf.cond( - tf.equal(index_tensor, cur_idx), - lambda: fn(cur_idx), - lambda: cond_on_index(fn, index_tensor, max_idx, cur_idx + 1) + tf.equal(index_tensor, cur_idx), + lambda: fn(cur_idx), + lambda: cond_on_index(fn, index_tensor, max_idx, cur_idx + 1) ) diff --git a/tensor2tensor/utils/input_fn_builder_test.py b/tensor2tensor/utils/input_fn_builder_test.py index 34b60c47a..ec2e6147e 100644 --- a/tensor2tensor/utils/input_fn_builder_test.py +++ b/tensor2tensor/utils/input_fn_builder_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function +# Dependency imports + from tensor2tensor.utils import input_fn_builder import tensorflow as tf @@ -26,13 +28,13 @@ class InputFnBuilderTest(tf.test.TestCase): def testCondOnIndex(self): - """Smoke tests of cond_on_index()""" + """Smoke tests of cond_on_index().""" z = tf.constant(1., dtype=tf.float32) def f(n): return { - "a": z * n, - "b": z * n * n + "a": z * n, + "b": z * n * n } index = tf.placeholder(shape=[], dtype=tf.int32) @@ -41,19 +43,19 @@ def f(n): with self.test_session() as sess: # Check dispatching to the correct branch result = sess.run(out, feed_dict={ - index: 2 + index: 2 }) self.assertAllClose(result["a"], 2.) self.assertAllClose(result["b"], 4.) result = sess.run(out, feed_dict={ - index: 3 + index: 3 }) self.assertAllClose(result["a"], 3.) self.assertAllClose(result["b"], 9.) -if __name__ == '__main__': +if __name__ == "__main__": tf.test.main() diff --git a/tensor2tensor/utils/metrics.py b/tensor2tensor/utils/metrics.py index 173ffb194..b4d82d97d 100644 --- a/tensor2tensor/utils/metrics.py +++ b/tensor2tensor/utils/metrics.py @@ -43,8 +43,9 @@ class Metrics(object): ROUGE_2_F = "rouge_2_fscore" ROUGE_L_F = "rouge_L_fscore" EDIT_DISTANCE = "edit_distance" - SET_PRECISION = 'set_precision' - SET_RECALL = 'set_recall' + SET_PRECISION = "set_precision" + SET_RECALL = "set_recall" + def padded_rmse(predictions, labels, weights_fn=common_layers.weights_all): predictions, labels = common_layers.pad_with_zeros(predictions, labels) @@ -189,47 +190,53 @@ def padded_accuracy(predictions, padded_labels = tf.to_int32(padded_labels) return tf.to_float(tf.equal(outputs, padded_labels)), weights + def set_precision(predictions, labels, weights_fn=common_layers.weights_nonzero): """Precision of set predictions. Args: - predictions : A Tensor of scores of shape (batch, nlabels) - labels: A Tensor of int32s giving true set elements of shape (batch, seq_length) + predictions : A Tensor of scores of shape [batch, nlabels]. + labels: A Tensor of int32s giving true set elements, + of shape [batch, seq_length]. + weights_fn: A function to weight the elements. Returns: - hits: A Tensor of shape (batch, nlabels) - weights: A Tensor of shape (batch, nlabels) + hits: A Tensor of shape [batch, nlabels]. + weights: A Tensor of shape [batch, nlabels]. 
""" with tf.variable_scope("set_precision", values=[predictions, labels]): labels = tf.squeeze(labels, [2, 3]) + weights = weights_fn(labels) labels = tf.one_hot(labels, predictions.shape[-1]) labels = tf.reduce_max(labels, axis=1) labels = tf.cast(labels, tf.bool) - predictions = predictions > 0 - return tf.to_float(tf.equal(labels, predictions)), tf.to_float(predictions) - + return tf.to_float(tf.equal(labels, predictions)), weights + + def set_recall(predictions, - labels, - weights_fn=common_layers.weights_nonzero): + labels, + weights_fn=common_layers.weights_nonzero): """Recall of set predictions. Args: - predictions : A Tensor of scores of shape (batch, nlabels) - labels: A Tensor of int32s giving true set elements of shape (batch, seq_length) + predictions : A Tensor of scores of shape [batch, nlabels]. + labels: A Tensor of int32s giving true set elements, + of shape [batch, seq_length]. + weights_fn: A function to weight the elements. Returns: - hits: A Tensor of shape (batch, nlabels) - weights: A Tensor of shape (batch, nlabels) + hits: A Tensor of shape [batch, nlabels]. + weights: A Tensor of shape [batch, nlabels]. """ with tf.variable_scope("set_recall", values=[predictions, labels]): labels = tf.squeeze(labels, [2, 3]) + weights = weights_fn(labels) labels = tf.one_hot(labels, predictions.shape[-1]) labels = tf.reduce_max(labels, axis=1) labels = tf.cast(labels, tf.bool) - predictions = predictions > 0 - return tf.to_float(tf.equal(labels, predictions)), tf.to_float(labels) + return tf.to_float(tf.equal(labels, predictions)), weights def create_evaluation_metrics(problems, model_hparams): @@ -299,7 +306,10 @@ def wrapped_metric_fn(): metric_fn = METRICS_FNS[metric] problem_metric_fn = make_problem_specific_metric_fn( metric_fn, problem_idx, weights_fn) - eval_metrics["metrics-%s/%s" % (problem_name, metric)] = problem_metric_fn + + metric_name = "metrics-%s/%s" % (problem_name, metric) + + eval_metrics[metric_name] = problem_metric_fn return eval_metrics diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py index 370104907..44a6f5208 100644 --- a/tensor2tensor/utils/model_builder.py +++ b/tensor2tensor/utils/model_builder.py @@ -292,7 +292,7 @@ def nth_model(n): # Optimize total_loss = tf.identity(total_loss, name="total_loss") - opt = _ConditionalOptimizer(hparams.optimizer, learning_rate, hparams) + opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams) opt_summaries = ["learning_rate", "loss"] if hparams.summarize_grads: opt_summaries.extend(["gradients", "gradient_norm"]) @@ -350,7 +350,7 @@ def wrapping_model_fn(features, labels, mode, params): return wrapping_model_fn -class _ConditionalOptimizer(tf.train.Optimizer): +class ConditionalOptimizer(tf.train.Optimizer): """Conditional optimizer.""" def __init__(self, optimizer_name, lr, hparams): @@ -369,16 +369,21 @@ def __init__(self, optimizer_name, lr, hparams): tf.logging.info("Init YellowFin Optimizer.") self._opt = yellowfin.YellowFinOptimizer( learning_rate=lr, momentum=hparams.optimizer_momentum_momentum) + elif optimizer_name == "TrueAdam": + self._opt = tf.train.AdamOptimizer( + lr / 500.0, + beta1=hparams.optimizer_adam_beta1, + beta2=hparams.optimizer_adam_beta2, + epsilon=hparams.optimizer_adam_epsilon) else: self._opt = tf.contrib.layers.OPTIMIZER_CLS_NAMES[optimizer_name](lr) - def compute_gradients(self, loss, var_list, colocate_gradients_with_ops): - return self._opt.compute_gradients( - loss, var_list, 
colocate_gradients_with_ops=colocate_gradients_with_ops)
+  def compute_gradients(self, loss, var_list=None, **kwargs):
+    return self._opt.compute_gradients(loss, var_list, **kwargs)

-  def apply_gradients(self, gradients, global_step=None, name=None):
+  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
     return self._opt.apply_gradients(
-        gradients, global_step=global_step, name=name)
+        grads_and_vars, global_step=global_step, name=name)


 def _sqrt_decay(step):
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 72e2ea602..c54b38f3f 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -26,6 +26,7 @@
 import six
 from six.moves import xrange  # pylint: disable=redefined-builtin

+from tensor2tensor.layers import common_layers
 from tensor2tensor.utils import beam_search
 from tensor2tensor.utils import expert_utils as eu
 from tensor2tensor.utils import registry
@@ -332,6 +333,10 @@ def _slow_greedy_infer(self, features, decode_length, last_position_only):
       features["inputs"] = tf.expand_dims(features["inputs"], 2)
     if not self.has_input:
       features["partial_targets"] = tf.to_int64(features["inputs"])
+    # Save the targets in a var and reassign it after the tf.while loop to
+    # avoid having targets live in a 'while' frame. This ensures that targets,
+    # when used in metric functions, stay in the same frame as other vars.
+    targets_old = features.get("targets", None)

     def infer_step(recent_output, recent_logits, unused_loss):
       """Inference step."""
@@ -394,6 +399,9 @@ def infer_step(recent_output, recent_logits, unused_loss):
         parallel_iterations=1)
     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
+    # Reassign targets back to the previous value.
+    if targets_old is not None:
+      features["targets"] = targets_old
     losses = {"training": loss}
     if "partial_targets" in features:
       partial_target_length = tf.shape(features["partial_targets"])[1]
@@ -420,15 +428,17 @@ def sample(self, features, last_position_only=False):
     else:
       assert self._hparams.sampling_method == "random"

-      def _multinomial_squeeze(logits):
-        reshaped_logits = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
+      def _multinomial_squeeze(logits, temperature=1.0):
+        reshaped_logits = (
+            tf.reshape(logits, [-1, tf.shape(logits)[-1]]) / temperature)
         choices = tf.multinomial(reshaped_logits, 1)
         choices = tf.reshape(choices,
                              tf.shape(logits)[:logits.get_shape().ndims - 1])
         return choices

       sharded_samples = self._data_parallelism(_multinomial_squeeze,
-                                               sharded_logits)
+                                               sharded_logits,
+                                               self._hparams.sampling_temp)
     return tf.concat(sharded_samples, 0), sharded_logits, losses

   def _shard_features(self, features):  # pylint: disable=missing-docstring
@@ -514,9 +524,9 @@ def model_fn(self, features, skip=False, last_position_only=False):
       with tf.variable_scope(target_modality.name, reuse=target_reuse):
         if not last_position_only:
           sharded_logits = target_modality.top_sharded(
-              body_outputs, sharded_features["targets"], self._data_parallelism)
+              body_outputs, sharded_features["targets"], dp)
           training_loss = target_modality.loss_sharded(
-              sharded_logits, sharded_features["targets"], self._data_parallelism)
+              sharded_logits, sharded_features["targets"], dp)

           training_loss *= self._problem_hparams.loss_multiplier
         else:
@@ -534,9 +544,60 @@ def model_fn(self, features, skip=False, last_position_only=False):
             last_position_targets,
             self._data_parallelism)
         training_loss = None
+      losses["training"] = training_loss
+
+      # Scheduled sampling.
+      do_scheduled_sampling = (  # Only done when training and enabled.
+          self._hparams.scheduled_sampling_prob > 0.0 and
+          self._hparams.mode == tf.estimator.ModeKeys.TRAIN and
+          not skip)
+      if do_scheduled_sampling:
+
+        def sample(x):
+          """Multinomial sampling from an n-dimensional tensor."""
+          vocab_size = target_modality.top_dimensionality
+          samples = tf.multinomial(tf.reshape(x, [-1, vocab_size]), 1)
+          reshaped_samples = tf.reshape(samples, tf.shape(x)[:-1])
+          return tf.to_int32(reshaped_samples)
+
+        def mix_gold_sampled(gold_targets, sampled_targets):
+          return tf.where(
+              tf.less(tf.random_uniform(tf.shape(sampled_targets)),
+                      self._hparams.scheduled_sampling_gold_mixin_prob),
+              gold_targets, sampled_targets)
+
+        def sampled_results():
+          """Generate scheduled sampling results."""
+          sampled_targets = dp(sample, sharded_logits)
+          new_targets = dp(mix_gold_sampled,
+                           sharded_features["targets"], sampled_targets)
+          new_features = transformed_features
+          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            with tf.variable_scope(target_modality.name):
+              new_features["targets"] = target_modality.targets_bottom_sharded(
+                  new_targets, dp)
+            with tf.variable_scope("body"):
+              body_outputs, losses = self.model_fn_body_sharded(new_features)
+              if not isinstance(losses, dict):  # If it's a single extra loss.
+                losses = {"extra": losses}
+            with tf.variable_scope(target_modality.name):
+              new_sharded_logits = target_modality.top_sharded(
+                  body_outputs, sharded_features["targets"], dp)
+              training_loss = target_modality.loss_sharded(
+                  sharded_logits, sharded_features["targets"], dp)
+              training_loss *= self._problem_hparams.loss_multiplier
+          losses["training"] = training_loss
+          return new_sharded_logits, losses
+        # Run the above conditionally.
+        prob = self._hparams.scheduled_sampling_prob
+        prob *= common_layers.inverse_exp_decay(
+            self._hparams.scheduled_sampling_warmup_steps, min_value=0.001)
+        sharded_logits, losses = tf.cond(
+            tf.less(tf.random_uniform([]), prob),
+            sampled_results,
+            lambda: (sharded_logits, losses))

     tf.logging.info("This model_fn took %.3f sec." % (time.time() - start_time))
-    losses["training"] = training_loss
     return sharded_logits, losses

   def model_fn_body_sharded(self, sharded_features):
diff --git a/tensor2tensor/utils/trainer_utils.py b/tensor2tensor/utils/trainer_utils.py
index fcdf5a463..e90e2dd10 100644
--- a/tensor2tensor/utils/trainer_utils.py
+++ b/tensor2tensor/utils/trainer_utils.py
@@ -167,6 +167,7 @@ def create_experiment(data_dir, model_name, train_steps, eval_steps, hparams,
       min_eval_frequency=FLAGS.local_eval_frequency,
       train_monitors=train_monitors,
       eval_hooks=eval_hooks,
+      eval_delay_secs=0,
       **optional_kwargs)


@@ -353,6 +354,7 @@ def run(data_dir, model, output_dir, train_steps, eval_steps, schedule):


 def validate_flags():
+  """Validate command line flags."""
   if not FLAGS.model:
     raise ValueError("Must specify a model with --model.")
   if not FLAGS.problems:
@@ -365,6 +367,8 @@ def validate_flags():
     FLAGS.output_dir = "/tmp/tensor2tensor"
     tf.logging.warning("It is strongly recommended to specify --output_dir. "
                        "Using default output_dir=%s.", FLAGS.output_dir)
+  if not FLAGS.data_dir:
+    raise ValueError("Must specify --data_dir.")


 def is_chief():
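For orientation, here is a compact, standalone rendition of the sample-and-mix step that the scheduled-sampling block above applies per data shard. It uses TF 1.x ops with made-up shapes and a hard-coded gold_mixin_prob constant standing in for the scheduled_sampling_gold_mixin_prob hparam; it is a sketch, not the code path added in this change, which additionally re-runs the model body on the mixed targets, recomputes the loss, and gates the whole branch with tf.cond on a probability warmed up by common_layers.inverse_exp_decay.

    import tensorflow as tf

    vocab_size = 8
    logits = tf.random_normal([4, 10, vocab_size])         # [batch, length, vocab]
    gold_targets = tf.random_uniform(
        [4, 10], maxval=vocab_size, dtype=tf.int32)        # [batch, length]

    # Sample replacement token ids from the model's own distribution.
    samples = tf.multinomial(tf.reshape(logits, [-1, vocab_size]), 1)
    sampled_targets = tf.to_int32(tf.reshape(samples, tf.shape(gold_targets)))

    # Keep each gold token with probability gold_mixin_prob; otherwise feed
    # the model its own sampled token.
    gold_mixin_prob = 0.5  # stand-in for the hparam
    mixed_targets = tf.where(
        tf.less(tf.random_uniform(tf.shape(sampled_targets)), gold_mixin_prob),
        gold_targets, sampled_targets)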