-
Notifications
You must be signed in to change notification settings - Fork 0
/
aen.py
130 lines (103 loc) · 5.36 KB
/
aen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# -*- coding: utf-8 -*-
# file: aen.py
# author: songyouwei <[email protected]>
# Copyright (C) 2018. All Rights Reserved.
from layers.dynamic_rnn import DynamicLSTM
from layers.squeeze_embedding import SqueezeEmbedding
from layers.attention import Attention, NoQueryAttention
from layers.point_wise_feed_forward import PositionwiseFeedForward
import torch
import torch.nn as nn
import torch.nn.functional as F
# CrossEntropyLoss for Label Smoothing Regularization
class CrossEntropyLoss_LSR(nn.Module):
    """Cross-entropy loss with Label Smoothing Regularization (LSR).

    Each hard label is replaced by a smoothed target distribution that puts
    ``para_LSR / classes`` on every class plus an extra ``1 - para_LSR`` on
    the true class, so every target row sums to 1. With ``para_LSR == 0``
    this reduces to ordinary cross-entropy on one-hot targets.

    Args:
        device: device the smoothed targets are moved to before the loss is
            computed (should match the device of the predictions).
        para_LSR: smoothing strength, normally in ``[0, 1]``.
    """

    def __init__(self, device, para_LSR=0.2):
        super(CrossEntropyLoss_LSR, self).__init__()
        self.para_LSR = para_LSR
        self.device = device
        self.logSoftmax = nn.LogSoftmax(dim=-1)

    def _toOneHot_smooth(self, label, batchsize, classes):
        """Return a ``(batchsize, classes)`` float tensor of smoothed targets.

        ``label`` holds one class index per sample (tensor or sequence). The
        result is built on the CPU; ``forward`` moves it to ``self.device``.
        """
        prob = self.para_LSR * 1.0 / classes
        one_hot_label = torch.full((batchsize, classes), prob)
        # Vectorized form of "for each row i: add (1 - para_LSR) at label[i]".
        # The index is moved to the target's device so CUDA labels also work.
        index = torch.as_tensor(label).to(device=one_hot_label.device, dtype=torch.long).reshape(-1, 1)
        bonus = torch.full((batchsize, 1), 1.0 - self.para_LSR, dtype=one_hot_label.dtype)
        one_hot_label.scatter_add_(1, index, bonus)
        return one_hot_label

    def forward(self, pre, label, size_average=True):
        """Compute the label-smoothed cross-entropy.

        Args:
            pre: 2-D tensor of unnormalized scores, ``(batch, classes)``.
            label: one true class index per sample.
            size_average: if True return the mean over the batch, otherwise
                the sum.

        Returns:
            A scalar loss tensor.
        """
        b, c = pre.size()
        one_hot_label = self._toOneHot_smooth(label, b, c).to(self.device)
        loss = torch.sum(-one_hot_label * self.logSoftmax(pre), dim=1)
        if size_average:
            return torch.mean(loss)
        return torch.sum(loss)
class AEN_GloVe(nn.Module):
    """Attentional Encoder Network (AEN) over pretrained word embeddings.

    The context and target sequences are embedded, encoded with multi-head
    attention followed by position-wise feed-forward layers, and combined by
    a second attention layer. The three resulting sequences are mean-pooled,
    concatenated and fed to a linear classifier.

    Args:
        embedding_matrix: pretrained embedding weights; converted to a float
            tensor and loaded with ``nn.Embedding.from_pretrained`` (which
            freezes the weights by default).
        opt: options object; this class reads ``embed_dim``, ``hidden_dim``,
            ``dropout``, ``polarities_dim`` and ``device``.
    """

    def __init__(self, embedding_matrix, opt):
        # super() must name this class; ``AEN`` is not defined in this module,
        # so ``super(AEN, self)`` raised NameError on construction.
        super(AEN_GloVe, self).__init__()
        self.opt = opt
        self.embed = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))
        self.squeeze_embedding = SqueezeEmbedding()
        self.attn_k = Attention(opt.embed_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.attn_q = Attention(opt.embed_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.ffn_c = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
        self.ffn_t = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
        self.attn_s1 = Attention(opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        # Three pooled hidden_dim vectors (hc, s1, ht) are concatenated.
        self.dense = nn.Linear(opt.hidden_dim * 3, opt.polarities_dim)

    @staticmethod
    def _mean_pool(seq, length):
        """Sum ``seq`` over dim 1 and divide each row by its float ``length``."""
        return torch.div(torch.sum(seq, dim=1), length.view(-1, 1))

    def forward(self, inputs):
        """Return the output of the final linear layer (no softmax applied).

        Args:
            inputs: sequence whose first two items are the context and target
                token-index tensors. Index 0 is treated as padding when
                counting lengths. Presumably shaped (batch, seq_len) — confirm.
        """
        text_raw_indices, target_indices = inputs[0], inputs[1]
        # Number of non-padding (non-zero) tokens per sample.
        context_len = torch.sum(text_raw_indices != 0, dim=-1)
        target_len = torch.sum(target_indices != 0, dim=-1)
        context = self.squeeze_embedding(self.embed(text_raw_indices), context_len)
        target = self.squeeze_embedding(self.embed(target_indices), target_len)

        hc, _ = self.attn_k(context, context)
        hc = self.ffn_c(hc)
        ht, _ = self.attn_q(context, target)
        ht = self.ffn_t(ht)
        s1, _ = self.attn_s1(hc, ht)

        # Float lengths on the model device. Tensor.to avoids the copy and
        # warning that torch.tensor(existing_tensor) produces.
        context_len = context_len.to(device=self.opt.device, dtype=torch.float)
        target_len = target_len.to(device=self.opt.device, dtype=torch.float)

        hc_mean = self._mean_pool(hc, context_len)
        ht_mean = self._mean_pool(ht, target_len)
        # NOTE(review): s1 is divided by context_len (unchanged behavior). If
        # Attention returns query-length outputs, s1 has the target's length,
        # so confirm this divisor is intended.
        s1_mean = self._mean_pool(s1, context_len)
        x = torch.cat((hc_mean, s1_mean, ht_mean), dim=-1)
        return self.dense(x)
class AEN_BERT(nn.Module):
    """Attentional Encoder Network (AEN) on top of a BERT encoder.

    The context and target token sequences are each encoded by ``bert`` (with
    dropout), then by multi-head attention followed by position-wise
    feed-forward layers, and combined by a second attention layer. The three
    resulting sequences are mean-pooled, concatenated and fed to a linear
    classifier.

    Args:
        bert: encoder module; it is called on token ids and the first item of
            its returned pair is used as the sequence output.
        opt: options object; this class reads ``dropout``, ``bert_dim``,
            ``hidden_dim``, ``polarities_dim`` and ``device``.
    """

    def __init__(self, bert, opt):
        super(AEN_BERT, self).__init__()
        self.opt = opt
        self.bert = bert
        self.squeeze_embedding = SqueezeEmbedding()
        self.dropout = nn.Dropout(opt.dropout)
        self.attn_k = Attention(opt.bert_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.attn_q = Attention(opt.bert_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        self.ffn_c = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
        self.ffn_t = PositionwiseFeedForward(opt.hidden_dim, dropout=opt.dropout)
        self.attn_s1 = Attention(opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
        # Three pooled hidden_dim vectors (hc, s1, ht) are concatenated.
        self.dense = nn.Linear(opt.hidden_dim * 3, opt.polarities_dim)

    @staticmethod
    def _mean_pool(seq, length):
        """Sum ``seq`` over dim 1 and divide each row by its float ``length``."""
        return torch.div(torch.sum(seq, dim=1), length.view(-1, 1))

    def forward(self, inputs):
        """Return the output of the final linear layer (no softmax applied).

        Args:
            inputs: sequence whose first two items are the context and target
                token-id tensors. Id 0 is treated as padding when counting
                lengths. Presumably shaped (batch, seq_len) — confirm.
        """
        context, target = inputs[0], inputs[1]
        # Number of non-padding (non-zero) tokens per sample.
        context_len = torch.sum(context != 0, dim=-1)
        target_len = torch.sum(target != 0, dim=-1)

        context = self.squeeze_embedding(context, context_len)
        context, _ = self.bert(context)
        context = self.dropout(context)
        target = self.squeeze_embedding(target, target_len)
        target, _ = self.bert(target)
        target = self.dropout(target)

        hc, _ = self.attn_k(context, context)
        hc = self.ffn_c(hc)
        ht, _ = self.attn_q(context, target)
        ht = self.ffn_t(ht)
        s1, _ = self.attn_s1(hc, ht)
        # (A leftover debug print(s1.shape) that ran on every step was removed.)

        # Float lengths on the model device. Tensor.to avoids the copy and
        # warning that torch.tensor(existing_tensor) produces.
        context_len = context_len.to(device=self.opt.device, dtype=torch.float)
        target_len = target_len.to(device=self.opt.device, dtype=torch.float)

        hc_mean = self._mean_pool(hc, context_len)
        ht_mean = self._mean_pool(ht, target_len)
        # NOTE(review): s1 is divided by context_len (unchanged behavior). If
        # Attention returns query-length outputs, s1 has the target's length,
        # so confirm this divisor is intended.
        s1_mean = self._mean_pool(s1, context_len)
        x = torch.cat((hc_mean, s1_mean, ht_mean), dim=-1)
        return self.dense(x)