-
Notifications
You must be signed in to change notification settings - Fork 8
/
model_causal.py
executable file
·266 lines (204 loc) · 9.01 KB
/
model_causal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
from __future__ import division, print_function
__author__ = "Lauri Juvela, [email protected]"
import os
import sys
import math
import numpy as np
import tensorflow as tf
# Floating-point dtype passed to tf.get_variable by the variable helpers in this module.
_FLOATX = tf.float32  # TODO: move to lib/precision.py
def get_weight_variable(name, shape=None, initial_value=None):
    """
    Creates (or fetches) a weight variable
    Args:
        name: variable name, resolved inside the current variable scope
        shape: shape of the variable; if None, tf.get_variable is called with
            the name only (intended for fetching an already created variable)
        initial_value: optional explicit initial value; if None, the variable
            is created through tf.get_variable with a Xavier (conv2d) initializer
    Returns:
        W: the weight variable
    """
    if shape is None:
        return tf.get_variable(name)

    if initial_value is not None:
        # The variable is built directly with tf.Variable, so it does not go
        # through tf.get_variable; forward the requested name so the variable
        # is not given an auto-generated one.
        return tf.Variable(initial_value, name=name)

    initializer = tf.contrib.layers.xavier_initializer_conv2d()
    return tf.get_variable(name, shape=shape, dtype=_FLOATX, initializer=initializer)
def get_bias_variable(name, shape=None, initializer=tf.constant_initializer(value=0.0, dtype=_FLOATX)):
    """
    Creates (or fetches) a bias variable through tf.get_variable
    Args:
        name: variable name, resolved inside the current variable scope
        shape: shape of the bias variable
        initializer: initializer for the variable, defaults to constant zero
    Returns:
        the bias variable, with dtype _FLOATX
    """
    bias = tf.get_variable(name,
                           shape=shape,
                           dtype=_FLOATX,
                           initializer=initializer)
    return bias
def convolution(X, W, dilation=1, causal=True):
    """
    Applies 1D convolution, optionally dilated and causal
    Args:
        X: input tensor of shape (batch, timesteps, in_channels)
        W: weight tensor of shape (filter_width, in_channels, out_channels)
        dilation: int value for dilation
        causal: bool flag for causal convolution
    Returns:
        Y: output tensor of shape (batch, timesteps, out_channels)
    """
    rate = [dilation]

    if not causal:
        # non-causal case: let TensorFlow pad on both sides
        return tf.nn.convolution(X, W, padding="SAME", dilation_rate=rate)

    # causal case: pad only the start of the time axis so the "VALID"
    # convolution keeps the number of timesteps and never reads future samples
    filter_width = tf.shape(W)[0]
    left_pad = (filter_width - 1) * dilation
    padded = tf.pad(X, paddings=[[0, 0], [left_pad, 0], [0, 0]])
    return tf.nn.convolution(padded, W, padding="VALID", dilation_rate=rate)
class WaveNet():
    """
    TensorFlow WaveNet object

    Initialization Args:
        name: string used for variable namespacing
            user is responsible for unique names if multiple models are used
        residual_channels: number of channels used in the convolution layers
        postnet_channels: number of channels in the hidden layer of the
            post-processing module
        filter_width: filter width of the input layer and of the dilated
            convolutions
        dilations: sequence of integers containing the dilation pattern
            its length determines the number of dilated blocks used
            (copied into a new list, the argument itself is never modified)
        input_channels: number of channels in the main input
        output_channels: number of channels in the network output
        cond_channels: number of channels in the conditioning input,
            None disables conditioning
        cond_embed_dim: number of channels of the embedded conditioning signal
        causal: if True, use causal convolutions in the input layer, the
            conditioning embedding and the dilated blocks
        conv_block_gate: if True, use gated convolutions in the dilated blocks
        conv_block_affine_out: if True, apply a 1x1 convolution in dilated blocks before the residual connection
        add_noise_at_each_layer: if True, add Gaussian noise with a learned
            per-channel scale inside every dilated block

    Functions:
        get_receptive_field: receptive field length computed from the configuration
        get_variable_list: trainable variables under the model's scope
        forward_pass: builds the network graph for the given inputs

    Members:
        the constructor arguments are stored under the same attribute names
        (name is stored as _name); _use_cond tells whether conditioning is enabled
    """

    def __init__(self,
                 name,
                 residual_channels=64,
                 postnet_channels=64,
                 filter_width=3,
                 dilations=(1, 2, 4, 8, 1, 2, 4, 8),
                 input_channels=1,
                 output_channels=1,
                 cond_channels=None,
                 cond_embed_dim=64,
                 causal=True,
                 conv_block_gate=True,
                 conv_block_affine_out=True,
                 add_noise_at_each_layer=False
                 ):

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.filter_width = filter_width
        # keep a private copy so that neither the default value nor a list
        # owned by the caller is shared between model instances
        self.dilations = list(dilations)
        self.residual_channels = residual_channels
        self.postnet_channels = postnet_channels
        self.causal = causal
        self.conv_block_gate = conv_block_gate
        self.conv_block_affine_out = conv_block_affine_out
        self.add_noise_at_each_layer = add_noise_at_each_layer

        # conditioning attributes are always defined, _use_cond tells
        # whether they are actually used
        self._use_cond = cond_channels is not None
        self.cond_channels = cond_channels
        self.cond_embed_dim = cond_embed_dim

        self._name = name

    def get_receptive_field(self):
        """
        Computes the receptive field length from filter width and dilations
        Returns:
            receptive_field: int
        """
        receptive_field = (self.filter_width - 1) * sum(self.dilations) + 1  # due to dilation layers
        receptive_field += self.filter_width - 1  # due to input layer (if not 1x1)
        if not self.causal:
            # NOTE(review): doubling for the non-causal case is kept from the
            # original implementation - confirm it matches the intended
            # definition for "SAME"-padded convolutions
            receptive_field = 2 * receptive_field - 1
        return receptive_field

    def get_variable_list(self):
        """
        Returns the trainable variables collected under the model's name scope
        """
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self._name)

    def _input_layer(self, main_input):
        """
        Convolution + bias + tanh mapping the input to residual_channels channels
        """
        with tf.variable_scope('input_layer'):
            r = self.residual_channels
            fw = self.filter_width
            W = get_weight_variable('W', (fw, self.input_channels, r))
            b = get_bias_variable('b', (r,))
            X = main_input
            Y = convolution(X, W, causal=self.causal)
            Y += b
            Y = tf.tanh(Y)
        return Y

    def _embed_cond(self, cond_input):
        """
        1x1 convolution + bias + tanh mapping the conditioning input to
        cond_embed_dim channels
        """
        with tf.variable_scope('embed_cond'):
            W = get_weight_variable('W', (1, self.cond_channels, self.cond_embed_dim))
            b = get_bias_variable('b', (self.cond_embed_dim,))
            Y = convolution(cond_input, W, causal=self.causal)  # 1x1 convolution
            Y += b
            return tf.tanh(Y)

    def _conv_module(self, main_input, module_idx, dilation, cond_input=None):
        """
        Builds one dilated convolution block
        Args:
            main_input: block input with residual_channels channels
            module_idx: index of the block, used for variable namespacing
            dilation: int dilation of the block's convolution
            cond_input: embedded conditioning signal, only read when
                conditioning is enabled
        Returns:
            Y: block output after the residual connection
            skip_out: activation output (taken before the optional 1x1
                output convolution), collected by the post-processing module
        """
        with tf.variable_scope('conv_modules'):
            with tf.variable_scope('module{}'.format(module_idx)):

                fw = self.filter_width
                r = self.residual_channels
                # a gated block computes filter and gate pre-activations with
                # a single convolution, hence twice the channels
                conv_channels = 2 * r if self.conv_block_gate else r

                X = main_input

                # convolution
                W = get_weight_variable('filter_gate_W', (fw, r, conv_channels))
                b = get_bias_variable('filter_gate_b', (conv_channels,))
                Y = convolution(X, W,
                                dilation=dilation,
                                causal=self.causal)
                Y += b

                # conditioning
                if self._use_cond:
                    V = get_weight_variable('cond_filter_gate_W',
                                            (1, self.cond_embed_dim, conv_channels))
                    b = get_bias_variable('cond_filter_gate_b', (conv_channels,))
                    C = convolution(cond_input, V)  # 1x1 convolution
                    Y += C + b

                # noise injection with a learned per-channel scale
                if self.add_noise_at_each_layer:
                    W = get_weight_variable('noise_scaling_W',
                                            (1, 1, r))
                    if self.conv_block_gate:
                        # noise goes to the filter half only, the gate half
                        # is left untouched
                        Z = tf.random_normal(shape=tf.shape(Y[..., :r]))
                        Y += tf.concat([W * Z, tf.zeros_like(Y[..., r:])], axis=-1)
                    else:
                        Z = tf.random_normal(shape=tf.shape(Y))
                        Y += W * Z

                # activation
                if self.conv_block_gate:
                    # filter and gate
                    Y = tf.tanh(Y[..., :r]) * tf.sigmoid(Y[..., r:])
                else:
                    Y = tf.tanh(Y)

                skip_out = Y

                if self.conv_block_affine_out:
                    W = get_weight_variable('output_W', (1, r, r))
                    b = get_bias_variable('output_b', (r,))
                    Y = convolution(Y, W) + b

                # residual connection
                Y += X

        return Y, skip_out

    def _postproc_module(self, residual_module_outputs):
        """
        Concatenates the skip outputs of all dilated blocks along the channel
        axis and maps them to output_channels channels with two 1x1 convolutions
        Args:
            residual_module_outputs: list with one skip output per dilated block
        Returns:
            Y: network output (no output activation is applied)
        """
        with tf.variable_scope('postproc_module'):

            s = self.postnet_channels
            r = self.residual_channels
            d = len(self.dilations)

            # concat and convolve
            W1 = get_weight_variable('W1', (1, d * r, s))
            b1 = get_bias_variable('b1', (s,))

            X = tf.concat(residual_module_outputs, axis=-1)  # concat along channel dim
            Y = convolution(X, W1)
            Y += b1
            Y = tf.nn.tanh(Y)

            # output layer
            W2 = get_weight_variable('W2', (1, s, self.output_channels))
            b2 = get_bias_variable('b2', (self.output_channels,))

            Y = convolution(Y, W2)
            Y += b2

        return Y

    def forward_pass(self, X_input, cond_input=None):
        """
        Builds the network graph under the model's variable scope
        Variables are reused (tf.AUTO_REUSE) when called more than once
        Args:
            X_input: input tensor of shape (batch, timesteps, input_channels)
            cond_input: conditioning tensor of shape
                (batch, timesteps, cond_channels), required when the model
                was created with cond_channels
        Returns:
            Y: output tensor of shape (batch, timesteps, output_channels)
        Raises:
            ValueError: if conditioning is enabled and cond_input is None
        """
        if self._use_cond and cond_input is None:
            # fail early with a clear message instead of deep inside TensorFlow
            raise ValueError(
                "cond_input is required because the model was created with cond_channels")

        skip_outputs = []
        with tf.variable_scope(self._name, reuse=tf.AUTO_REUSE):

            if self._use_cond:
                C = self._embed_cond(cond_input)
            else:
                C = None

            R = self._input_layer(X_input)
            X = R
            for i, dilation in enumerate(self.dilations):
                X, skip = self._conv_module(X, i, dilation, cond_input=C)
                skip_outputs.append(skip)

            Y = self._postproc_module(skip_outputs)

        return Y