Skip to content

Commit

Permalink
TD-BERT Implementation
Browse files Browse the repository at this point in the history
TD-BERT implementation from songyouwei#147
  • Loading branch information
prasys committed Jul 22, 2021
1 parent 1584032 commit 8be7f0e
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 0 deletions.
2 changes: 2 additions & 0 deletions data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def __init__(self, fname, tokenizer):
'aspect_boundary': aspect_boundary,
'dependency_graph': dependency_graph,
'polarity': polarity,
'left_context_len': left_context_len,
'aspect_len': aspect_len,
}

all_data.append(data)
Expand Down
50 changes: 50 additions & 0 deletions models/td_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# file: td_bert.py
# author: xiangpan <[email protected]>
# Copyright (C) 2020. All Rights Reserved.
import torch
import torch.nn as nn
from layers.attention import Attention


class TD_BERT(nn.Module):
def __init__(self, bert, opt):
super(TD_BERT, self).__init__()
self.bert = bert
self.dropout = nn.Dropout(opt.dropout)
self.opt = opt
self.dense = nn.Linear(opt.bert_dim, opt.polarities_dim)

def forward(self, inputs):
text_bert_indices, bert_segments_ids, left_context_len, aspect_len = (
inputs[0],
inputs[1],
inputs[2],
inputs[3],
)

encoded_layers, cls_output = self.bert(
text_bert_indices, bert_segments_ids
)


pooled_list = []
for i in range(0, encoded_layers.shape[0]): # batch_size i th batch
encoded_layers_i = encoded_layers[i]
left_context_len_i = left_context_len[i]
aspect_len_i = aspect_len[i]
e_list = []
if (left_context_len_i + 1) == (left_context_len_i + 1 + aspect_len_i):
e_list.append(encoded_layers_i[0])
for j in range(left_context_len_i + 1, left_context_len_i + 1 + aspect_len_i):
e_list.append(encoded_layers_i[j])
e = torch.stack(e_list, 0)
embed = torch.stack([e], 0)
pooled = nn.functional.max_pool2d(embed, (embed.size(1), 1)).squeeze(1)
pooled_list.append(pooled)
pooled_output = torch.cat(pooled_list)
pooled_output = self.dropout(pooled_output)

logits = self.dense(pooled_output)

return logits
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from models.bert_spc import BERT_SPC
from models.albert_spc import ALBERT_SPC
from models.roberta_spc import ROBERTA_SPC
from models.td_bert import TD_BERT

logger = logging.getLogger()
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -249,6 +250,8 @@ def main():
'lcf_bert': LCF_BERT,
'albert_spc': ALBERT_SPC,
'roberta_spc': ROBERTA_SPC,
'td_bert': TD_BERT,

# default hyper-parameters for LCF-BERT model is as follws:
# lr: 2e-5
# l2: 1e-5
Expand Down Expand Up @@ -298,6 +301,7 @@ def main():
'aen_bert': ['text_bert_indices', 'aspect_bert_indices'],
'aen_bert': ['text_bert_indices', 'aspect_bert_indices'],
'lcf_bert': ['concat_bert_indices', 'concat_segments_indices', 'text_bert_indices', 'aspect_bert_indices'],
'td_bert': ['text_bert_indices', 'bert_segments_ids','left_context_len','aspect_len'],
}
initializers = {
'xavier_uniform_': torch.nn.init.xavier_uniform_,
Expand Down

0 comments on commit 8be7f0e

Please sign in to comment.