Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Self-Attention as a variant of tower+pool #30

Open
wants to merge 1 commit into
base: attention_encoder
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions gqn/gqn_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,72 @@ def pool_encoder(frames: tf.Tensor, poses: tf.Tensor, scope="PoolEncoder"):
net = tf.reduce_mean(net, axis=[1, 2], keepdims=True)

return net, endpoints


def self_attention(layer: tf.Tensor, name: str):
    """
    Self-attention block as described in "Self-Attention Generative Adversarial
    Networks" (Zhang, Goodfellow, Metaxas, Odena, 2018).

    Computes attention over the X*Y spatial positions of the feature map and
    adds the attended features back onto the input through a zero-initialized
    gate, so the block is an identity mapping at initialization.

    Args:
      layer: feature map of static shape [batch, X, Y, C].
      name: unique suffix for the variables created by this block.

    Returns:
      Tensor of the same shape as `layer`: gamma * attention(layer) + layer.
    """
    _, X, Y, C = layer.shape
    n = X * Y  # number of spatial positions attended over

    # 1x1-conv equivalents (channel-mixing matrices) for query/key/value.
    W_f = tf.get_variable(f'W_{name}_f', shape=[C, C], initializer=tf.contrib.layers.xavier_initializer())
    f = tf.tensordot(layer, W_f, [[-1], [0]])  # query: [B, X, Y, C]
    W_g = tf.get_variable(f'W_{name}_g', shape=[C, C], initializer=tf.contrib.layers.xavier_initializer())
    g = tf.tensordot(layer, W_g, [[-1], [0]])  # key:   [B, X, Y, C]
    W_h = tf.get_variable(f'W_{name}_h', shape=[C, C], initializer=tf.contrib.layers.xavier_initializer())
    h = tf.tensordot(layer, W_h, [[-1], [0]])  # value: [B, X, Y, C]

    # Flatten spatial dims so attention runs over positions, not channels.
    # (The previous version flattened to [B, X*Y*C], which produced a [B, B]
    # similarity matrix and an invalid reshape.)
    f = tf.reshape(f, shape=[-1, n, C])
    g = tf.reshape(g, shape=[-1, n, C])
    h = tf.reshape(h, shape=[-1, n, C])

    # Attention map over spatial positions: [B, n, n], rows sum to 1.
    s = tf.matmul(f, g, transpose_b=True)
    beta = tf.nn.softmax(s, -1)

    # Attention-weighted values, restored to the input's spatial layout.
    # (Bug fix: the result of the attention matmul is now actually used —
    # previously `h` was reshaped instead of `o`, discarding the attention.)
    o = tf.matmul(beta, h)  # [B, n, C]
    o = tf.reshape(o, [-1, X, Y, C])

    # Zero-initialized gate: the block starts as an identity and learns to
    # blend in attention. (SAGAN uses a scalar gamma; the per-position shape
    # is kept here for checkpoint compatibility with the original code.)
    gamma = tf.get_variable(f'gamma_{name}', shape=[1, X, Y, C], initializer=tf.zeros_initializer())
    return gamma * o + layer


def sa_encoder(frames: tf.Tensor, poses: tf.Tensor, scope="SAEncoder"):
    """
    Feed-forward convolutional architecture with self-attention (modified from
    tower+pool to add self-attention.)

    Mirrors the pool encoder: two residual conv stages with pose injection in
    the middle, a self-attention block after every conv layer, and a final
    spatial average pool.
    """
    with tf.variable_scope(scope):
        endpoints = {}

        # Stage 1: downsample, attend, then a residual 3x3 conv.
        x = tf.layers.conv2d(frames, filters=256, kernel_size=2, strides=2,
                             padding="VALID", activation=tf.nn.relu)
        x = self_attention(x, "l1")
        shortcut = tf.layers.conv2d(x, filters=128, kernel_size=1, strides=1,
                                    padding="SAME", activation=None)
        x = tf.layers.conv2d(x, filters=128, kernel_size=3, strides=1,
                             padding="SAME", activation=tf.nn.relu)
        x += shortcut
        x = self_attention(x, "l2")

        # Second downsampling step.
        x = tf.layers.conv2d(x, filters=256, kernel_size=2, strides=2,
                             padding="VALID", activation=tf.nn.relu)
        x = self_attention(x, "l3")

        # Broadcast the poses over the spatial grid and append as channels.
        height, width = tf.shape(x)[1], tf.shape(x)[2]
        poses = broadcast_pose(poses, height, width)
        x = tf.concat([x, poses], axis=3)

        # Stage 2: residual conv over the pose-augmented features.
        shortcut = tf.layers.conv2d(x, filters=128, kernel_size=1, strides=1,
                                    padding="SAME", activation=None)
        x = tf.layers.conv2d(x, filters=128, kernel_size=3, strides=1,
                             padding="SAME", activation=tf.nn.relu)
        x += shortcut
        x = self_attention(x, "l4")

        x = tf.layers.conv2d(x, filters=256, kernel_size=3, strides=1,
                             padding="SAME", activation=tf.nn.relu)
        x = self_attention(x, "l5")

        x = tf.layers.conv2d(x, filters=256, kernel_size=1, strides=1,
                             padding="SAME", activation=tf.nn.relu)
        x = self_attention(x, "l6")

        # Collapse the spatial grid into a single representation vector.
        x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)

        return x, endpoints
5 changes: 3 additions & 2 deletions gqn/gqn_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import tensorflow as tf

from .gqn_params import GQNConfig
from .gqn_encoder import tower_encoder, pool_encoder
from .gqn_encoder import tower_encoder, pool_encoder, sa_encoder
from .gqn_draw import inference_rnn, generator_rnn
from .gqn_utils import broadcast_encoding, compute_eta_and_sample_z
from .gqn_vae import vae_tower_decoder
Expand All @@ -20,6 +20,7 @@
_ENC_FUNCTIONS = { # switch for different encoding functions
'pool' : pool_encoder,
'tower' : tower_encoder,
'sa' : sa_encoder,
}


Expand Down Expand Up @@ -105,7 +106,7 @@ def gqn_draw(
endpoints.update(endpoints_enc)

# broadcast scene representation to 1/4 of targeted frame size
if enc_type == 'pool':
if enc_type == 'pool' or enc_type == 'sa':
enc_r_broadcast = broadcast_encoding(
vector=enc_r, height=dim_h_enc, width=dim_w_enc)
else:
Expand Down
4 changes: 4 additions & 0 deletions train_gqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
ARGPARSER.add_argument(
'--img_size', type=int, default=64,
help='Height and width of the squared input images.')
ARGPARSER.add_argument(
'--enc_type', type=str, default='pool',
help='The encoding architecture type.')
# solver parameters
ARGPARSER.add_argument(
'--adam_lr_alpha', type=float, default=5*10e-5,
Expand Down Expand Up @@ -152,6 +155,7 @@ def main(unparsed_argv):
'IMG_WIDTH' : ARGS.img_size,
'CONTEXT_SIZE' : ARGS.context_size,
'SEQ_LENGTH' : ARGS.seq_length,
'ENC_TYPE' : ARGS.enc_type,
'ENC_HEIGHT' : ARGS.img_size // 4, # must be 1/4 of target frame height
'ENC_WIDTH' : ARGS.img_size // 4, # must be 1/4 of target frame width
'ADAM_LR_ALPHA' : ARGS.adam_lr_alpha,
Expand Down