# -*- coding: utf-8 -*-
# file: BERT_SPC.py
# author: songyouwei <[email protected]>
# Copyright (C) 2019. All Rights Reserved.
import torch
import torch.nn as nn
from layers.attention import Attention
# from torch_multi_head_attention import MultiHeadAttention
# import numpy as np
# from text_models.leam import LEAM
# import torchsparseattn
class BERT_SPC(nn.Module):
    """BERT for sentence-pair classification (BERT-SPC) for aspect-based sentiment analysis."""

    def __init__(self, bert, opt):
        super(BERT_SPC, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(opt.dropout)
        self.opt = opt
        # Classification head on top of the pooled [CLS] representation.
        self.dense = nn.Linear(opt.bert_dim, opt.polarities_dim)
        # self.dense = nn.Linear(opt.bert_dim*opt.max_seq_len, opt.polarities_dim)
        # self.attn = nn.MultiheadAttention(embed_dim=opt.bert_dim, num_heads=3, dropout=0.1)
        # self.attn = MultiHeadAttention(in_features=768, head_num=3)
        # Extra attention layers kept from experimentation; they are not used in the
        # forward pass below (only referenced in the commented-out code there).
        self.attn = nn.TransformerEncoderLayer(d_model=opt.bert_dim,
                                               nhead=3,
                                               # dim_feedforward=4*opt.bert_dim,
                                               dropout=0.1)
        self.attn_k = Attention(opt.bert_dim, out_dim=opt.bert_dim, n_head=3,
                                score_function='dot_product', dropout=opt.dropout)
        # self.attn_q = Attention(opt.bert_dim, out_dim=opt.hidden_dim, n_head=8, score_function='mlp', dropout=opt.dropout)
    def forward(self, inputs):
        text_spc_bert_indices, bert_segments_ids = inputs[0], inputs[1]
        # The BERT wrapper used here is expected to return the encoded sequence,
        # the pooled [CLS] output, and the attention weights.
        encoded_layers, pooled_output, attention = self.bert(text_spc_bert_indices, bert_segments_ids)
        # word_output, pooled_output, bert_word_eb, attention
        # Experimental variants kept commented out: re-encoding with the extra
        # attention layers and pooling from their output instead of pooled_output.
        # new_encoded_layers = self.attn(encoded_layers)
        # hc, scores = self.attn_k(encoded_layers, encoded_layers)
        # for i in range(len(scores)):
        #     M_i = scores[i].index_select(1, poss[i])
        #     reg_i = torch.norm(torch.matmul(M_i.t(), M_i) - torch.eye(M_i.shape[1]).to('cuda:1'))
        #     reg = reg_i + reg
        # att_out = torch.reshape(att_out, (-1, self.opt.bert_dim*self.opt.max_seq_len))
        pooled_output = self.dropout(pooled_output)
        # pooled_output = self.dropout(new_encoded_layers[:, 0, :])
        # pooled_output = new_encoded_layers[:, 0, :]
        # pooled_output = hc[:, 0, :]
        # pooled_output = att_out
        logits = self.dense(pooled_output)
        return logits
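

# --- Usage sketch (hypothetical): wiring up BERT_SPC for a forward pass ---
# A minimal sketch, assuming an `opt` namespace carrying the fields the class
# reads (dropout, bert_dim, polarities_dim) plus max_seq_len for shaping the
# inputs, and a `bert` module whose forward returns (sequence_output,
# pooled_output, attention) as unpacked above. A dummy stand-in encoder is used
# instead of a real pretrained BERT so the sketch stays self-contained; in
# practice, pass the repo's actual BERT wrapper.
from types import SimpleNamespace


class _DummyBert(nn.Module):
    """Stand-in that mimics the (sequence, pooled, attention) return shape."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, input_ids, segment_ids):
        batch, seq_len = input_ids.shape
        sequence_output = torch.randn(batch, seq_len, self.dim)
        pooled_output = torch.randn(batch, self.dim)
        attention = torch.randn(batch, seq_len, seq_len)
        return sequence_output, pooled_output, attention


if __name__ == '__main__':
    opt = SimpleNamespace(dropout=0.1, bert_dim=768, polarities_dim=3, max_seq_len=85)
    model = BERT_SPC(_DummyBert(opt.bert_dim), opt)
    input_ids = torch.randint(0, 30522, (4, opt.max_seq_len))        # fake "[CLS] text [SEP] aspect [SEP]" ids
    segment_ids = torch.zeros(4, opt.max_seq_len, dtype=torch.long)  # segment ids for the sentence pair
    logits = model([input_ids, segment_ids])
    print(logits.shape)  # torch.Size([4, 3]), i.e. (batch, polarities_dim)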