This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Merge pull request #377 from kolloldas/master
Update LSTM Attention Model to use tf.contrib.seq2seq.AttentionWrapper
lukaszkaiser authored Nov 2, 2017
2 parents 172a1b1 + f67483e commit 9e7d03f
Showing 1 changed file with 63 additions and 148 deletions.
211 changes: 63 additions & 148 deletions tensor2tensor/models/lstm.py
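In short, this change deletes the hand-rolled ExternalAttentionCellWrapper below and builds the decoder on TensorFlow's built-in attention API instead. A minimal sketch of the pattern the new code follows (it mirrors the added lines in the diff; the function and variable names attention_decoder_sketch, decoder_inputs, encoder_outputs and encoder_final_state are illustrative, not part of the change):

import tensorflow as tf

def attention_decoder_sketch(decoder_inputs, encoder_outputs,
                             encoder_final_state, hparams):
  """Decoder sketch: multi-layer LSTM wrapped with seq2seq attention."""
  # Pick the attention mechanism from hparams, as the new decoder does.
  mechanism_cls = (tf.contrib.seq2seq.LuongAttention
                   if hparams.attention_mechanism == "luong"
                   else tf.contrib.seq2seq.BahdanauAttention)
  mechanism = mechanism_cls(hparams.hidden_size, encoder_outputs)

  # Wrap a plain multi-layer LSTM cell with the attention mechanism.
  cell = tf.contrib.seq2seq.AttentionWrapper(
      tf.nn.rnn_cell.MultiRNNCell(
          [tf.nn.rnn_cell.BasicLSTMCell(hparams.hidden_size)
           for _ in range(hparams.num_hidden_layers)]),
      mechanism,
      attention_layer_size=hparams.attention_layer_size)

  # AttentionWrapper keeps its own state; seed its cell_state with the
  # encoder's final state (which must match the wrapped cell's structure).
  batch_size = tf.shape(decoder_inputs)[0]
  initial_state = cell.zero_state(batch_size, tf.float32).clone(
      cell_state=encoder_final_state)
  return tf.nn.dynamic_rnn(cell, decoder_inputs,
                           initial_state=initial_state, dtype=tf.float32)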
@@ -31,144 +31,6 @@
import tensorflow as tf
from tensorflow.python.util import nest

# Track Tuple of state and attention values
AttentionTuple = collections.namedtuple("AttentionTuple", ("state",
"attention"))


class ExternalAttentionCellWrapper(tf.contrib.rnn.RNNCell):
"""Wrapper for external attention states for an encoder-decoder setup."""

def __init__(self,
cell,
attn_states,
attn_vec_size=None,
input_size=None,
state_is_tuple=True,
reuse=None):
"""Create a cell with attention.
Args:
cell: an RNNCell, an attention is added to it.
attn_states: External attention states, typically the encoder output, in
the form [batch_size, time_steps, hidden_size].
attn_vec_size: integer, the number of convolutional features calculated
on the attention state and the size of the hidden layer built from the
base cell state. Equal to attn_size by default.
input_size: integer, the size of a hidden linear layer,
built from inputs and attention. Derived from the input tensor
by default.
state_is_tuple: If True, accepted and returned states are n-tuples, where
`n = len(cells)`. Must be set to True; otherwise an exception is raised.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
Raises:
TypeError: if cell is not an RNNCell.
ValueError: if the flag `state_is_tuple` is `False`, if `attn_states` is not
of rank 3, or if its innermost dimension (hidden size) is None.
"""
super(ExternalAttentionCellWrapper, self).__init__(_reuse=reuse)
if not state_is_tuple:
raise ValueError("Only tuple state is supported")

self._cell = cell
self._input_size = input_size

# Validate attn_states shape.
attn_shape = attn_states.get_shape()
if not attn_shape or len(attn_shape) != 3:
raise ValueError("attn_shape must be rank 3")

self._attn_states = attn_states
self._attn_size = attn_shape[2].value
if self._attn_size is None:
raise ValueError("Hidden size of attn_states cannot be None")

self._attn_vec_size = attn_vec_size
if self._attn_vec_size is None:
self._attn_vec_size = self._attn_size

self._reuse = reuse

@property
def state_size(self):
return AttentionTuple(self._cell.state_size, self._attn_size)

@property
def output_size(self):
return self._attn_size

def combine_state(self, previous_state):
"""Combines previous state (from encoder) with internal attention values.
You must use this function to derive the initial state passed into
this cell as it expects a named tuple (AttentionTuple).
Args:
previous_state: State from another block that will be fed into this cell;
must have the same structure as the state of the cell wrapped by this one.
Returns:
Combined state (AttentionTuple).
"""
batch_size = self._attn_states.get_shape()[0].value
if batch_size is None:
batch_size = tf.shape(self._attn_states)[0]
zeroed_state = self.zero_state(batch_size, self._attn_states.dtype)
return AttentionTuple(previous_state, zeroed_state.attention)
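  # A usage sketch for combine_state (this mirrors the decoder code removed
  # further down in this diff; the names wrapped_cell, encoder_outputs,
  # encoder_final_state and decoder_inputs are illustrative):
  #   cell = ExternalAttentionCellWrapper(wrapped_cell, encoder_outputs)
  #   initial_state = cell.combine_state(encoder_final_state)
  #   outputs, _ = tf.nn.dynamic_rnn(cell, decoder_inputs,
  #                                  initial_state=initial_state,
  #                                  dtype=tf.float32)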

def call(self, inputs, state):
"""Long short-term memory cell with attention (LSTMA)."""

if not isinstance(state, AttentionTuple):
raise TypeError("State must be of type AttentionTuple")

state, attns = state
attn_states = self._attn_states
attn_length = attn_states.get_shape()[1].value
if attn_length is None:
attn_length = tf.shape(attn_states)[1]

input_size = self._input_size
if input_size is None:
input_size = inputs.get_shape().as_list()[1]
if attns is not None:
inputs = tf.layers.dense(tf.concat([inputs, attns], axis=1), input_size)
lstm_output, new_state = self._cell(inputs, state)

new_state_cat = tf.concat(nest.flatten(new_state), 1)
new_attns = self._attention(new_state_cat, attn_states, attn_length)

with tf.variable_scope("attn_output_projection"):
output = tf.layers.dense(
tf.concat([lstm_output, new_attns], axis=1), self._attn_size)

new_state = AttentionTuple(new_state, new_attns)

return output, new_state

def _attention(self, query, attn_states, attn_length):
conv2d = tf.nn.conv2d
reduce_sum = tf.reduce_sum
softmax = tf.nn.softmax
tanh = tf.tanh

with tf.variable_scope("attention"):
k = tf.get_variable("attn_w",
[1, 1, self._attn_size, self._attn_vec_size])
v = tf.get_variable("attn_v", [self._attn_vec_size, 1])
hidden = tf.reshape(attn_states, [-1, attn_length, 1, self._attn_size])
hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
y = tf.layers.dense(query, self._attn_vec_size)
y = tf.reshape(y, [-1, 1, 1, self._attn_vec_size])
s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
a = softmax(s)
d = reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
new_attns = tf.reshape(d, [-1, self._attn_size])

return new_attns
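  # In equation form, _attention above computes standard additive
  # (Bahdanau-style) attention over the encoder states h_i given the decoder
  # query q:
  #   e_i = v^T tanh(W_h h_i + W_q q),  a = softmax(e),  context = sum_i a_i h_i
  # which is why the class can be replaced by tf.contrib.seq2seq's
  # BahdanauAttention / AttentionWrapper, as this commit does.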


def lstm(inputs, hparams, train, name, initial_state=None):
"""Run LSTM cell on inputs, assuming they are [batch x time x size]."""
@@ -189,7 +51,7 @@ def dropout_lstm_cell():


def lstm_attention_decoder(inputs, hparams, train, name, initial_state,
attn_states):
encoder_outputs):
"""Run LSTM cell with attention on inputs of shape [batch x time x size]."""

def dropout_lstm_cell():
@@ -198,18 +60,36 @@ def dropout_lstm_cell():
input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))

layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)]
cell = ExternalAttentionCellWrapper(
AttentionMechanism = (tf.contrib.seq2seq.LuongAttention if hparams.attention_mechanism == "luong"
else tf.contrib.seq2seq.BahdanauAttention)
attention_mechanism = AttentionMechanism(hparams.hidden_size, encoder_outputs)

cell = tf.contrib.seq2seq.AttentionWrapper(
tf.nn.rnn_cell.MultiRNNCell(layers),
attn_states,
attn_vec_size=hparams.attn_vec_size)
initial_state = cell.combine_state(initial_state)
[attention_mechanism]*hparams.num_heads,
attention_layer_size=[hparams.attention_layer_size]*hparams.num_heads,
output_attention=(hparams.output_attention==1))


batch_size = inputs.get_shape()[0].value
if batch_size is None:
batch_size = tf.shape(inputs)[0]

initial_state = cell.zero_state(batch_size, tf.float32).clone(cell_state=initial_state)
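  # The clone above seeds AttentionWrapper's own AttentionWrapperState:
  # cell_state comes from the encoder's final state, while the attention and
  # alignment slots start out zeroed.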

with tf.variable_scope(name):
return tf.nn.dynamic_rnn(
output, state = tf.nn.dynamic_rnn(
cell,
inputs,
initial_state=initial_state,
dtype=tf.float32,
time_major=False)

# For multi-head attention, project the output back to the hidden size
if hparams.output_attention == 1 and hparams.num_heads > 1:
output = tf.layers.dense(output, hparams.hidden_size)

return output, state


def lstm_seq2seq_internal(inputs, targets, hparams, train):
@@ -273,14 +153,49 @@ def lstm_seq2seq():
hparams.hidden_size = 128
hparams.num_hidden_layers = 2
hparams.initializer = "uniform_unit_scaling"
hparams.initializer_gain = 1.0
hparams.weight_decay = 0.0

return hparams

def lstm_attention_base():
""" Base attention params. """
hparams = lstm_seq2seq()
hparams.add_hparam("attention_layer_size", hparams.hidden_size)
hparams.add_hparam("output_attention", int(True))
hparams.add_hparam("num_heads", 1)
return hparams


@registry.register_hparams
def lstm_bahdanau_attention():
"""hparams for LSTM with bahdanau attention."""
hparams = lstm_attention_base()
hparams.add_hparam("attention_mechanism", "bahdanau")
return hparams

@registry.register_hparams
def lstm_luong_attention():
"""hparams for LSTM with luong attention."""
hparams = lstm_attention_base()
hparams.add_hparam("attention_mechanism", "luong")
return hparams

@registry.register_hparams
def lstm_attention():
"""hparams for LSTM with attention."""
hparams = lstm_seq2seq()
""" For backwards compatibility, Defaults to bahdanau """
return lstm_bahdanau_attention()

# Attention
hparams.add_hparam("attn_vec_size", hparams.hidden_size)
@registry.register_hparams
def lstm_bahdanau_attention_multi():
""" Multi-head Luong attention """
hparams = lstm_bahdanau_attention()
hparams.num_heads = 4
return hparams

@registry.register_hparams
def lstm_luong_attention_multi():
""" Multi-head Luong attention """
hparams = lstm_luong_attention()
hparams.num_heads = 4
return hparams

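For reference, a quick way to inspect the new hyperparameter sets registered above; a sketch, assuming this file is importable as tensor2tensor.models.lstm and that the registry decorators return the functions unchanged:

from tensor2tensor.models import lstm

hp = lstm.lstm_luong_attention()
print(hp.attention_mechanism)   # "luong"
print(hp.num_heads)             # 1
print(hp.attention_layer_size)  # defaults to hparams.hidden_size, i.e. 128

hp_multi = lstm.lstm_bahdanau_attention_multi()
print(hp_multi.attention_mechanism)  # "bahdanau"
print(hp_multi.num_heads)            # 4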