# model.py
import torch
import torch.nn as nn
from scipy.stats import ortho_group
# simple linear head, e.g. for probing intermediate representations
class LinearClassifier(nn.Module):
    def __init__(self, hidden_dim, num_classes, bias=True):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, num_classes, bias=bias)

    def forward(self, x):
        return self.fc(x)
# main network: a nonlinear feature extractor followed by a stack of purely
# linear layers (the "tunnel"), capped by a linear classifier
class TunnelNetwork(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, num_nonlinear_layers=0,
                 bias=False, init_method="default", eps=0.0):
        super().__init__()
        self.num_layers = num_layers
        num_linear_layers = num_layers - num_nonlinear_layers
        self.num_nonlinear_layers = num_nonlinear_layers
        self.eps = eps
        # first layer maps input_dim to hidden_dim; note it is always nonlinear,
        # so num_nonlinear_layers is effectively at least 1
        layers = [nn.Sequential(nn.Linear(input_dim, hidden_dim, bias=bias),
                                nn.BatchNorm1d(hidden_dim), nn.ReLU())]
        # remaining nonlinear hidden layers
        for i in range(1, num_nonlinear_layers):
            layers.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim, bias=bias),
                                        nn.BatchNorm1d(hidden_dim), nn.ReLU()))
        # purely linear hidden layers
        for i in range(num_linear_layers):
            layers.append(nn.Sequential(nn.Linear(hidden_dim, hidden_dim, bias=bias)))
        self.layers = nn.ModuleList(layers)
        # final classifier
        self.fc = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.init_method = init_method
        self.init_weight(input_dim, hidden_dim, output_dim)
    def init_weight(self, input_dim, hidden_dim, output_dim):
        print(f"initializing weights using {self.init_method}")
        if self.init_method == "default":
            pass  # keep PyTorch's default initialization
        elif self.init_method == "identity":
            # scaled (possibly rectangular) identity for every layer, including
            # the first layer, whose weight has shape (hidden_dim, input_dim)
            for i in range(self.num_layers):
                h, w = self.layers[i][0].weight.data.shape
                self.layers[i][0].weight.data = torch.eye(h, w) * self.eps
            H, W = self.fc.weight.data.shape
            self.fc.weight.data = torch.eye(H, W) * self.eps
        elif self.init_method == "gaussian":
            # Kaiming-normal initialization for all layers
            for i in range(self.num_layers):
                nn.init.kaiming_normal_(self.layers[i][0].weight)
            nn.init.kaiming_normal_(self.fc.weight)
        elif self.init_method == "orthogonal":
            for i in range(self.num_layers):
                if i == 0:
                    # orthonormal rows, zero-padded to (hidden_dim, input_dim);
                    # assumes input_dim >= hidden_dim
                    weight = torch.randn(hidden_dim, input_dim)
                    weight = torch.linalg.svd(weight, full_matrices=False)[0]
                    weight = torch.cat([weight, torch.zeros(hidden_dim, input_dim - hidden_dim)], dim=1)
                    self.layers[i][0].weight.data = weight * self.eps
                else:
                    # square orthogonal matrix from the SVD of a Gaussian matrix
                    weight = torch.randn(hidden_dim, hidden_dim)
                    weight = torch.linalg.svd(weight)[0]
                    self.layers[i][0].weight.data = weight * self.eps
            # random orthogonal classifier weight, zero-padded to
            # (output_dim, hidden_dim); assumes hidden_dim >= output_dim
            fc_weight = torch.from_numpy(ortho_group.rvs(output_dim)).float()
            fc_weight = torch.cat([fc_weight, torch.zeros(output_dim, hidden_dim - output_dim)], dim=1)
            self.fc.weight.data = fc_weight * self.eps
        else:
            raise ValueError(f"undefined init_method: {self.init_method}")
    def forward(self, x):
        # collect each hidden layer's output (detached, so downstream probes
        # do not backpropagate into the network)
        out_list = []
        for layer in self.layers:
            x = layer(x)
            out_list.append(x.clone().detach())
        out = self.fc(x)
        return out, out_list
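
# Below is a minimal usage sketch (not from the original file): the dimensions
# and hyperparameters are illustrative assumptions, chosen only to exercise both
# the nonlinear and linear layers and the shape constraints noted above.
if __name__ == "__main__":
    net = TunnelNetwork(input_dim=32, hidden_dim=16, output_dim=10, num_layers=5,
                        num_nonlinear_layers=2, init_method="orthogonal", eps=1.0)
    x = torch.randn(8, 32)  # batch of 8 inputs; BatchNorm1d needs batch size > 1
    logits, hidden = net(x)
    print(logits.shape)  # torch.Size([8, 10])
    print(len(hidden))   # 5: one detached activation per hidden layer
    # probe an intermediate representation with the linear head
    probe = LinearClassifier(hidden_dim=16, num_classes=10)
    print(probe(hidden[0]).shape)  # torch.Size([8, 10])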