From 8be7f0eba7b4043ab8cc3e7f6eeba488f1783f56 Mon Sep 17 00:00:00 2001
From: prasys
Date: Thu, 22 Jul 2021 12:11:48 +1200
Subject: [PATCH] TD-BERT Implementation

TD-BERT implementation from
https://github.com/songyouwei/ABSA-PyTorch/pull/147
---
 data_utils.py     |  2 ++
 models/td_bert.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++
 train.py          |  4 ++++
 3 files changed, 56 insertions(+)
 create mode 100644 models/td_bert.py

diff --git a/data_utils.py b/data_utils.py
index 4ce1ad2..8c49c8a 100644
--- a/data_utils.py
+++ b/data_utils.py
@@ -205,6 +205,8 @@ def __init__(self, fname, tokenizer):
                 'aspect_boundary': aspect_boundary,
                 'dependency_graph': dependency_graph,
                 'polarity': polarity,
+                'left_context_len': left_context_len,
+                'aspect_len': aspect_len,
             }
 
             all_data.append(data)
diff --git a/models/td_bert.py b/models/td_bert.py
new file mode 100644
index 0000000..b605827
--- /dev/null
+++ b/models/td_bert.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# file: td_bert.py
+# author: xiangpan
+# Copyright (C) 2020. All Rights Reserved.
+import torch
+import torch.nn as nn
+
+
+class TD_BERT(nn.Module):
+    def __init__(self, bert, opt):
+        super(TD_BERT, self).__init__()
+        self.bert = bert
+        self.dropout = nn.Dropout(opt.dropout)
+        self.opt = opt
+        self.dense = nn.Linear(opt.bert_dim, opt.polarities_dim)
+
+    def forward(self, inputs):
+        text_bert_indices, bert_segments_ids, left_context_len, aspect_len = (
+            inputs[0],
+            inputs[1],
+            inputs[2],
+            inputs[3],
+        )
+
+        # encoded_layers: (batch_size, seq_len, bert_dim)
+        encoded_layers, cls_output = self.bert(
+            text_bert_indices, bert_segments_ids
+        )
+
+        # max-pool over the aspect-token embeddings; +1 skips the [CLS] token
+        pooled_list = []
+        for i in range(encoded_layers.shape[0]):  # i-th example in the batch
+            encoded_layers_i = encoded_layers[i]
+            left_context_len_i = left_context_len[i]
+            aspect_len_i = aspect_len[i]
+            e_list = []
+            if aspect_len_i == 0:  # empty aspect span: fall back to [CLS]
+                e_list.append(encoded_layers_i[0])
+            for j in range(left_context_len_i + 1, left_context_len_i + 1 + aspect_len_i):
+                e_list.append(encoded_layers_i[j])
+            e = torch.stack(e_list, 0)
+            embed = e.unsqueeze(0)  # (1, aspect_len, bert_dim)
+            pooled = nn.functional.max_pool2d(embed, (embed.size(1), 1)).squeeze(1)
+            pooled_list.append(pooled)
+        pooled_output = torch.cat(pooled_list)  # (batch_size, bert_dim)
+        pooled_output = self.dropout(pooled_output)
+
+        logits = self.dense(pooled_output)
+
+        return logits
diff --git a/train.py b/train.py
index bf32348..a737786 100644
--- a/train.py
+++ b/train.py
@@ -26,6 +26,7 @@
 from models.bert_spc import BERT_SPC
 from models.albert_spc import ALBERT_SPC
 from models.roberta_spc import ROBERTA_SPC
+from models.td_bert import TD_BERT
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
@@ -249,6 +250,8 @@ def main():
         'lcf_bert': LCF_BERT,
         'albert_spc': ALBERT_SPC,
         'roberta_spc': ROBERTA_SPC,
+        'td_bert': TD_BERT,
+
         # default hyper-parameters for the LCF-BERT model are as follows:
         # lr: 2e-5
         # l2: 1e-5
@@ -298,6 +301,7 @@
         'aen_bert': ['text_bert_indices', 'aspect_bert_indices'],
         'aen_bert': ['text_bert_indices', 'aspect_bert_indices'],
         'lcf_bert': ['concat_bert_indices', 'concat_segments_indices', 'text_bert_indices', 'aspect_bert_indices'],
+        'td_bert': ['text_bert_indices', 'bert_segments_ids', 'left_context_len', 'aspect_len'],
     }
     initializers = {
         'xavier_uniform_': torch.nn.init.xavier_uniform_,
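
Note on the data_utils.py hunk: it only adds the two new dictionary keys, so
it assumes left_context_len and aspect_len are computed earlier in the same
method. A minimal sketch of one plausible way to derive them, counting
non-padding ids the way the upstream repo does (the array names and values
here are illustrative assumptions, not part of the patch):

    import numpy as np
    import torch

    # padded token-id arrays as produced by the tokenizer; 0 is the pad id
    text_left_indices = np.array([2023, 2003, 1037, 0, 0])  # 3 real tokens
    aspect_indices = np.array([4013, 2482, 0, 0, 0])        # 2 real tokens

    left_context_len = torch.tensor(int(np.sum(text_left_indices != 0)))  # tensor(3)
    aspect_len = torch.tensor(int(np.sum(aspect_indices != 0)))           # tensor(2)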
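
For review, a standalone sketch of the target-dependent pooling that
models/td_bert.py implements: for each example, the hidden states of the
aspect tokens are max-pooled element-wise. Tensor .max over the span is
equivalent to the max_pool2d call in the patch; all names and shapes below
are illustrative:

    import torch

    batch_size, seq_len, bert_dim = 2, 10, 8
    encoded_layers = torch.randn(batch_size, seq_len, bert_dim)  # BERT hidden states
    left_context_len = torch.tensor([2, 4])  # tokens before the aspect
    aspect_len = torch.tensor([3, 1])        # tokens inside the aspect

    pooled = []
    for i in range(batch_size):
        start = int(left_context_len[i]) + 1  # +1 skips [CLS] at position 0
        end = start + int(aspect_len[i])
        # fall back to [CLS] when the aspect span is empty, as the patch does
        span = encoded_layers[i, start:end] if end > start else encoded_layers[i, :1]
        pooled.append(span.max(dim=0).values)  # element-wise max over the span
    pooled_output = torch.stack(pooled)        # (batch_size, bert_dim)
    print(pooled_output.shape)                 # torch.Size([2, 8])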
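
Assuming this fork keeps the upstream command-line interface (--model_name
and --dataset are the upstream train.py flags; the dataset name below is
only an example), the new model should then be selectable with:

    python train.py --model_name td_bert --dataset restaurant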