diff --git a/gqn/gqn_encoder.py b/gqn/gqn_encoder.py
index ab58168..9158fe8 100644
--- a/gqn/gqn_encoder.py
+++ b/gqn/gqn_encoder.py
@@ -60,3 +60,72 @@ def pool_encoder(frames: tf.Tensor, poses: tf.Tensor, scope="PoolEncoder"):
   net = tf.reduce_mean(net, axis=[1, 2], keepdims=True)
 
   return net, endpoints
+
+
+def self_attention(layer: tf.Tensor, name: str):
+  """
+  Self-Attention as described in Self-Attention Generative Adversarial Networks by Zhang, Goodfellow, Metaxas, and Odena
+  """
+  m = tf.shape(layer)[0]  # dynamic batch size; static dims below come from layer.shape
+  _, X, Y, C = layer.shape
+  W_f = tf.get_variable(f'W_{name}_f', shape=[C, C], initializer=tf.contrib.layers.xavier_initializer())
+  f = tf.tensordot(layer, W_f, [[-1], [0]])
+  W_g = tf.get_variable(f'W_{name}_g', shape=[C, C], initializer=tf.contrib.layers.xavier_initializer())
+  g = tf.tensordot(layer, W_g, [[-1], [0]])
+  W_h = tf.get_variable(f'W_{name}_h', shape=[C, C], initializer=tf.contrib.layers.xavier_initializer())
+  h = tf.tensordot(layer, W_h, [[-1], [0]])
+  f = tf.reshape(f, shape=[-1, X * Y, C])
+  g = tf.reshape(g, shape=[-1, X * Y, C])
+  s = tf.matmul(f, g, transpose_b=True)  # attention logits over spatial locations: [batch, X*Y, X*Y]
+  s = tf.nn.softmax(s, -1)  # row-normalized attention map
+  h = tf.reshape(h, shape=[-1, X * Y, C])
+  o = tf.matmul(s, h)  # attention-weighted values: [batch, X*Y, C]
+  o = tf.reshape(o, [-1, X, Y, C])  # fixed: was reshaping h, silently discarding the attention output
+  gamma = tf.get_variable(f'gamma_{name}', shape=[1, X, Y, C], initializer=tf.zeros_initializer())
+  return gamma * o + layer
+
+
+def sa_encoder(frames: tf.Tensor, poses: tf.Tensor, scope="SAEncoder"):
+  """
+  Feed-forward convolutional architecture with self-attention (modified from tower+pool to add self-attention.)
+ """ + with tf.variable_scope(scope): + endpoints = {} + net = tf.layers.conv2d(frames, filters=256, kernel_size=2, strides=2, + padding="VALID", activation=tf.nn.relu) + net = self_attention(net, "l1") + skip1 = tf.layers.conv2d(net, filters=128, kernel_size=1, strides=1, + padding="SAME", activation=None) + net = tf.layers.conv2d(net, filters=128, kernel_size=3, strides=1, + padding="SAME", activation=tf.nn.relu) + net = net + skip1 + net = self_attention(net, "l2") + net = tf.layers.conv2d(net, filters=256, kernel_size=2, strides=2, + padding="VALID", activation=tf.nn.relu) + net = self_attention(net, "l3") + + # tile the poses to match the embedding shape + height, width = tf.shape(net)[1], tf.shape(net)[2] + poses = broadcast_pose(poses, height, width) + + # concatenate the poses with the embedding + net = tf.concat([net, poses], axis=3) + + skip2 = tf.layers.conv2d(net, filters=128, kernel_size=1, strides=1, + padding="SAME", activation=None) + net = tf.layers.conv2d(net, filters=128, kernel_size=3, strides=1, + padding="SAME", activation=tf.nn.relu) + net = net + skip2 + net = self_attention(net, "l4") + + net = tf.layers.conv2d(net, filters=256, kernel_size=3, strides=1, + padding="SAME", activation=tf.nn.relu) + net = self_attention(net, "l5") + + net = tf.layers.conv2d(net, filters=256, kernel_size=1, strides=1, + padding="SAME", activation=tf.nn.relu) + net = self_attention(net, "l6") + + net = tf.reduce_mean(net, axis=[1, 2], keepdims=True) + + return net, endpoints diff --git a/gqn/gqn_graph.py b/gqn/gqn_graph.py index 8cdcd31..3865753 100644 --- a/gqn/gqn_graph.py +++ b/gqn/gqn_graph.py @@ -9,7 +9,7 @@ import tensorflow as tf from .gqn_params import GQNConfig -from .gqn_encoder import tower_encoder, pool_encoder +from .gqn_encoder import tower_encoder, pool_encoder, sa_encoder from .gqn_draw import inference_rnn, generator_rnn from .gqn_utils import broadcast_encoding, compute_eta_and_sample_z from .gqn_vae import vae_tower_decoder @@ -20,6 +20,7 @@ 
_ENC_FUNCTIONS = { # switch for different encoding functions 'pool' : pool_encoder, 'tower' : tower_encoder, + 'sa' : sa_encoder, } @@ -105,7 +106,7 @@ def gqn_draw( endpoints.update(endpoints_enc) # broadcast scene representation to 1/4 of targeted frame size - if enc_type == 'pool': + if enc_type == 'pool' or enc_type == 'sa': enc_r_broadcast = broadcast_encoding( vector=enc_r, height=dim_h_enc, width=dim_w_enc) else: diff --git a/train_gqn.py b/train_gqn.py index 8022961..f9dead8 100644 --- a/train_gqn.py +++ b/train_gqn.py @@ -46,6 +46,9 @@ ARGPARSER.add_argument( '--img_size', type=int, default=64, help='Height and width of the squared input images.') +ARGPARSER.add_argument( + '--enc_type', type=str, default='pool', + help='The encoding architecture type.') # solver parameters ARGPARSER.add_argument( '--adam_lr_alpha', type=float, default=5*10e-5, @@ -152,6 +155,7 @@ def main(unparsed_argv): 'IMG_WIDTH' : ARGS.img_size, 'CONTEXT_SIZE' : ARGS.context_size, 'SEQ_LENGTH' : ARGS.seq_length, + 'ENC_TYPE' : ARGS.enc_type, 'ENC_HEIGHT' : ARGS.img_size // 4, # must be 1/4 of target frame height 'ENC_WIDTH' : ARGS.img_size // 4, # must be 1/4 of target frame width 'ADAM_LR_ALPHA' : ARGS.adam_lr_alpha,