diff --git a/TCN/tcn.py b/TCN/tcn.py index 37dfd43..3a332c2 100644 --- a/TCN/tcn.py +++ b/TCN/tcn.py @@ -15,27 +15,25 @@ def forward(self, x): class TemporalBlock(nn.Module): def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): super(TemporalBlock, self).__init__() - self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, - stride=stride, padding=padding, dilation=dilation)) - self.chomp1 = Chomp1d(padding) - self.relu1 = nn.ReLU() - self.dropout1 = nn.Dropout(dropout) - - self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, - stride=stride, padding=padding, dilation=dilation)) - self.chomp2 = Chomp1d(padding) - self.relu2 = nn.ReLU() - self.dropout2 = nn.Dropout(dropout) - - self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, - self.conv2, self.chomp2, self.relu2, self.dropout2) + conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + chomp1 = Chomp1d(padding) + relu1 = nn.ReLU() + dropout1 = nn.Dropout(dropout) + + conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, + stride=stride, padding=padding, dilation=dilation)) + chomp2 = Chomp1d(padding) + relu2 = nn.ReLU() + dropout2 = nn.Dropout(dropout) + + self.net = nn.Sequential(conv1, chomp1, relu1, dropout1, + conv2, chomp2, relu2, dropout2) self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None self.relu = nn.ReLU() - self.init_weights() - def init_weights(self): - self.conv1.weight.data.normal_(0, 0.01) - self.conv2.weight.data.normal_(0, 0.01) + conv1.weight.data.normal_(0, 0.01) + conv2.weight.data.normal_(0, 0.01) if self.downsample is not None: self.downsample.weight.data.normal_(0, 0.01)