# discLSTM.py (forked from Sreyan88/DiscLSTM)
import math

import torch
import torch.nn as nn

class DiscLSTM(nn.Module):
    """LSTM cell variant that, at each step, combines the main input x_t
    with an auxiliary per-step input m_t (feature size g_sz) through a
    second input gate and candidate update."""

    def __init__(self, input_sz, hidden_sz, g_sz):
        super(DiscLSTM, self).__init__()
        self.input_sz = input_sz
        self.hidden_sz = hidden_sz
        self.g_sz = g_sz
        self.all1 = nn.Linear(self.hidden_sz + self.input_sz, self.hidden_sz)               # input gate i, from [h; x]
        self.all2 = nn.Linear(self.hidden_sz + self.input_sz + self.g_sz, self.hidden_sz)   # output gate o, from [h; x; m]
        self.all3 = nn.Linear(self.hidden_sz + self.input_sz + self.g_sz, self.hidden_sz)   # forget gate f, from [h; x; m]
        self.all4 = nn.Linear(self.hidden_sz + self.input_sz, self.hidden_sz)               # candidate u, from [h; x]
        self.all11 = nn.Linear(self.hidden_sz + self.g_sz, self.hidden_sz)                  # auxiliary input gate ii, from [h; m]
        self.all44 = nn.Linear(self.hidden_sz + self.g_sz, self.hidden_sz)                  # auxiliary candidate uu, from [h; m]
        self.init_weights()
        self.drop = nn.Dropout(0.5)  # defined but never applied in forward()
    def init_weights(self):
        # Uniform initialization in [-1/sqrt(hidden_sz), 1/sqrt(hidden_sz)],
        # matching the default initialization of PyTorch's built-in LSTM.
        stdv = 1.0 / math.sqrt(self.hidden_sz)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)
    def node_forward(self, xt, ht, Ct_x, mt, Ct_m):
        # One recurrence step: a standard LSTM step on xt, extended with an
        # extra gated contribution from the auxiliary input mt.
        hx_concat = torch.cat((ht, xt), dim=1)
        hm_concat = torch.cat((ht, mt), dim=1)
        hxm_concat = torch.cat((ht, xt, mt), dim=1)

        i = self.all1(hx_concat)    # input gate (x path)
        o = self.all2(hxm_concat)   # output gate
        f = self.all3(hxm_concat)   # forget gate
        u = self.all4(hx_concat)    # candidate cell (x path)
        ii = self.all11(hm_concat)  # input gate (m path)
        uu = self.all44(hm_concat)  # candidate cell (m path)

        i, f, o, u = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o), torch.tanh(u)
        ii, uu = torch.sigmoid(ii), torch.tanh(uu)

        # Cell update: c_t = i * u + ii * uu + f * c_{t-1}
        Ct_x = i * u + ii * uu + f * Ct_x
        ht = o * torch.tanh(Ct_x)
        # Ct_m is carried through unchanged.
        return ht, Ct_x, Ct_m
    def forward(self, x, m, init_stat=None):
        # x: (batch, seq_len, input_sz); m: (batch, seq_len, g_sz)
        batch_sz, seq_sz, _ = x.size()
        hidden_seq = []
        cell_seq = []
        if init_stat is None:
            ht = torch.zeros((batch_sz, self.hidden_sz)).to(x.device)
            Ct_x = torch.zeros((batch_sz, self.hidden_sz)).to(x.device)
            Ct_m = torch.zeros((batch_sz, self.hidden_sz)).to(x.device)
        else:
            ht, Ct_x, Ct_m = init_stat
        for t in range(seq_sz):  # iterate over the time steps
            xt = x[:, t, :]
            mt = m[:, t, :]
            ht, Ct_x, Ct_m = self.node_forward(xt, ht, Ct_x, mt, Ct_m)
            hidden_seq.append(ht)
            cell_seq.append(Ct_x)
            # Element-wise max-pool over all states seen so far; computed at
            # every step but not used in the returned value.
            if t == 0:
                mht = ht
                mct = Ct_x
            else:
                mht = torch.max(torch.stack(hidden_seq), dim=0)[0]
                mct = torch.max(torch.stack(cell_seq), dim=0)[0]
        hidden_seq = torch.stack(hidden_seq).permute(1, 0, 2)  # (batch, seq_len, hidden_sz)
        return hidden_seq
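

# A minimal smoke test, not part of the original file. The dimensions below
# are illustrative assumptions; the only constraints inferred from the layer
# definitions are that x carries input_sz features and m carries g_sz
# features per time step.
if __name__ == "__main__":
    batch, seq_len, input_sz, hidden_sz, g_sz = 4, 10, 32, 64, 16
    model = DiscLSTM(input_sz, hidden_sz, g_sz)
    x = torch.randn(batch, seq_len, input_sz)  # main input sequence
    m = torch.randn(batch, seq_len, g_sz)      # auxiliary per-step features
    out = model(x, m)
    print(out.shape)  # expected: torch.Size([4, 10, 64])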