diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py
index 027449854..3f7a2e387 100644
--- a/tensor2tensor/layers/common_attention.py
+++ b/tensor2tensor/layers/common_attention.py
@@ -3608,8 +3608,7 @@ def local_attention_2d(q,
   Args:
     q: a Tensor with shape [batch, heads, h, w, depth_k]
     k: a Tensor with shape [batch, heads, h, w, depth_k]
-    v: a Tensor with shape [batch, heads, h, w, depth_v]. In the current
-      implementation, depth_v must be equal to depth_k.
+    v: a Tensor with shape [batch, heads, h, w, depth_v]
     query_shape: an tuple indicating the height and width of each query block.
     memory_flange: an integer indicating how much to look in height and width
       from each query block.
@@ -3652,9 +3651,12 @@ def local_attention_2d(q,
       dropout_rate=0.,
       name="local_2d",
       make_image_summary=False)
-  # Put representations back into original shapes.
+
+  # Form padded output shape: [batch, heads, h_padded, w_padded, depth_v].
   padded_q_shape = common_layers.shape_list(q)
-  output = scatter_blocks_2d(output, q_indices, padded_q_shape)
+  output_shape = padded_q_shape[:4] + v_shape[4:]
+  # Put representations back into original shapes.
+  output = scatter_blocks_2d(output, q_indices, output_shape)
 
   # Remove the padding if introduced.
   output = tf.slice(output, [0, 0, 0, 0, 0],
@@ -3935,8 +3937,7 @@ def masked_local_attention_2d(q,
   Args:
     q: a Tensor with shape [batch, heads, h, w, depth_k]
     k: a Tensor with shape [batch, heads, h, w, depth_k]
-    v: a Tensor with shape [batch, heads, h, w, depth_v]. In the current
-      implementation, depth_v must be equal to depth_k.
+    v: a Tensor with shape [batch, heads, h, w, depth_v]
     query_shape: an tuple indicating the height and width of each query block.
       query_shape = block_shape
     memory_flange: an integer indicating how much to look in height and width
@@ -4000,9 +4001,12 @@ def masked_local_attention_2d(q,
       dropout_rate=0.,
       name="masked_local_2d",
       make_image_summary=False)
-  # Put representations back into original shapes.
+
+  # Form padded output shape: [batch, heads, h_padded, w_padded, depth_v].
   padded_q_shape = common_layers.shape_list(q)
-  output = scatter_blocks_2d(output, q_indices, padded_q_shape)
+  padded_output_shape = padded_q_shape[:4] + v_shape[4:]
+  # Put representations back into original shapes.
+  output = scatter_blocks_2d(output, q_indices, padded_output_shape)
 
   # Remove the padding if introduced.
   output = tf.slice(output, [0, 0, 0, 0, 0],
diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py
index ebc286c6d..87a9398ab 100644
--- a/tensor2tensor/layers/common_attention_test.py
+++ b/tensor2tensor/layers/common_attention_test.py
@@ -517,8 +517,7 @@ def testMaskedLocalAttention1D(self, batch, heads, length, depth_k, depth_v,
       ("", 1, 1, 8, 4, 4, (2, 2)),
       ("dynamic_batch", None, 1, 8, 4, 4, (2, 2)),
       ("batches", 3, 2, 8, 4, 4, (2, 2)),
-      # TODO(trandustin): Extend function to enable depth_k != depth_v.
-      # ("depth_v", 1, 1, 8, 4, 1, (2, 2)),
+      ("depth_v", 1, 1, 8, 4, 1, (2, 2)),
       ("query_shape", 1, 1, 8, 4, 4, (4, 4)),
   )
   def testMaskedLocalAttention2D(self, batch, heads, length, depth_k, depth_v,
@@ -567,8 +566,7 @@ def testLocalUnmaskedAttention1D(self, batch, heads, length,
       ("matching_block_length", 3, 4, 25, 16, 16, (4, 4)),
       ("unmatching_block_length", 3, 4, 25, 16, 16, (5, 5)),
       ("dynamic_batch", None, 4, 25, 16, 16, (4, 4)),
-      # TODO(trandustin): Extend function to enable depth_k != depth_v.
-      # ("different_depth_v", 3, 4, 25, 16, 17, (4, 4)),
+      ("different_depth_v", 3, 4, 25, 16, 17, (4, 4)),
   )
   def testLocalUnmaskedAttention2D(self, batch, heads, length, depth_k,
                                    depth_v, query_shape):