TD-BERT #147

Open · wants to merge 6 commits into master (changes shown from 2 commits)
3 changes: 3 additions & 0 deletions README.md
@@ -70,6 +70,9 @@ Xu, Hu, et al. "Bert post-training for review reading comprehension and aspect-based sentiment analysis."

Sun, Chi, Luyao Huang, and Xipeng Qiu. "Utilizing bert for aspect-based sentiment analysis via constructing auxiliary sentence." arXiv preprint arXiv:1903.09588 (2019). [[pdf](https://arxiv.org/pdf/1903.09588.pdf)]

### TD-BERT ([td_bert.py](./models/td_bert.py))
Z. Gao, A. Feng, X. Song and X. Wu, "Target-Dependent Sentiment Classification With BERT," in IEEE Access, vol. 7, pp. 154290-154299, 2019, doi: 10.1109/ACCESS.2019.2946594. [[pdf]](https://ieeexplore.ieee.org/abstract/document/8864964)

### LCF-BERT ([lcf_bert.py](./models/lcf_bert.py)) ([official](https://github.com/yangheng95/LCF-ABSA))

Zeng Biqing, Yang Heng, et al. "LCF: A Local Context Focus Mechanism for Aspect-Based Sentiment Classification." Applied Sciences. 2019, 9, 3389. [[pdf]](https://www.mdpi.com/2076-3417/9/16/3389/pdf)
2 changes: 2 additions & 0 deletions data_utils.py
@@ -167,6 +167,8 @@ def __init__(self, fname, tokenizer):
'aspect_indices': aspect_indices,
'aspect_in_text': aspect_in_text,
'polarity': polarity,
'left_context_len': left_context_len,
'aspect_len': aspect_len,
}

all_data.append(data)
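The hunk shows only the two new dictionary entries; `left_context_len` and `aspect_len` are evidently computed earlier in the same loop (the existing `aspect_in_text` field is built from the same quantities). For readers outside the repo, a minimal sketch of how such values could be derived, assuming `text_left` and `aspect_indices` are in scope and numpy is imported as `np`, as elsewhere in data_utils.py:

```python
# Hypothetical sketch -- these lines are not part of the diff above.
left_indices = tokenizer.text_to_sequence(text_left)   # token ids of the left context
left_context_len = np.sum(left_indices != 0)           # non-padding tokens before the target
aspect_len = np.sum(aspect_indices != 0)               # non-padding tokens in the target span
```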
50 changes: 50 additions & 0 deletions models/td_bert.py
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# file: td_bert.py
# author: xiangpan <[email protected]>
# Copyright (C) 2020. All Rights Reserved.
import torch
import torch.nn as nn


class TD_BERT(nn.Module):
    def __init__(self, bert, opt):
        super(TD_BERT, self).__init__()
        self.bert = bert
        self.dropout = nn.Dropout(opt.dropout)
        self.opt = opt
        self.dense = nn.Linear(opt.bert_dim, opt.polarities_dim)

    def forward(self, inputs):
        text_bert_indices, bert_segments_ids, left_context_len, aspect_len = (
            inputs[0],
            inputs[1],
            inputs[2],
            inputs[3],
        )

        # encoded_layers: (batch_size, seq_len, bert_dim) token representations
        encoded_layers, _ = self.bert(text_bert_indices, bert_segments_ids)

        pooled_list = []
        for i in range(encoded_layers.shape[0]):  # one sample of the batch at a time
            encoded_layers_i = encoded_layers[i]
            left_context_len_i = int(left_context_len[i])
            aspect_len_i = int(aspect_len[i])
            e_list = []
            if aspect_len_i == 0:
                # empty target span: fall back to the [CLS] representation
                e_list.append(encoded_layers_i[0])
            # target tokens start right after [CLS] and the left context
            for j in range(left_context_len_i + 1, left_context_len_i + 1 + aspect_len_i):
                e_list.append(encoded_layers_i[j])
            e = torch.stack(e_list, 0)  # (target_len, bert_dim)
            embed = e.unsqueeze(0)      # (1, target_len, bert_dim)
            # max-pool over the target tokens -> (1, bert_dim)
            pooled = nn.functional.max_pool2d(embed, (embed.size(1), 1)).squeeze(1)
            pooled_list.append(pooled)
        pooled_output = torch.cat(pooled_list)  # (batch_size, bert_dim)
        pooled_output = self.dropout(pooled_output)

        logits = self.dense(pooled_output)

        return logits
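A self-contained smoke test for the shapes above (not part of the PR): `FakeBert` is a hypothetical stand-in that mimics the `(sequence_output, pooled_output)` return contract the model expects, and all dimensions here are arbitrary.

```python
import torch
import torch.nn as nn
from types import SimpleNamespace

from models.td_bert import TD_BERT

class FakeBert(nn.Module):
    """Hypothetical BERT stand-in: random hidden states, same return contract."""
    def forward(self, input_ids, token_type_ids):
        hidden = torch.randn(input_ids.size(0), input_ids.size(1), 768)
        return hidden, hidden[:, 0]

opt = SimpleNamespace(dropout=0.1, bert_dim=768, polarities_dim=3)
model = TD_BERT(FakeBert(), opt)

ids = torch.zeros(2, 20, dtype=torch.long)   # batch of 2, sequence length 20
segs = torch.zeros_like(ids)                 # segment ids (all sentence A)
left_len = torch.tensor([3, 5])              # tokens left of each target
asp_len = torch.tensor([2, 1])               # target span lengths

logits = model([ids, segs, left_len, asp_len])
print(logits.shape)                          # torch.Size([2, 3])
```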
3 changes: 3 additions & 0 deletions train.py
@@ -23,6 +23,7 @@
from models import LSTM, IAN, MemNet, RAM, TD_LSTM, TC_LSTM, Cabasc, ATAE_LSTM, TNet_LF, AOA, MGAN, LCF_BERT
from models.aen import CrossEntropyLoss_LSR, AEN_BERT
from models.bert_spc import BERT_SPC
from models.td_bert import TD_BERT

logger = logging.getLogger()
logger.setLevel(logging.INFO)
@@ -225,6 +226,7 @@ def main():
'bert_spc': BERT_SPC,
'aen_bert': AEN_BERT,
'lcf_bert': LCF_BERT,
'td_bert': TD_BERT,
# default hyper-parameters for the LCF-BERT model are as follows:
# lr: 2e-5
# l2: 1e-5
@@ -260,6 +262,7 @@ def main():
'bert_spc': ['text_bert_indices', 'bert_segments_ids'],
'aen_bert': ['text_raw_bert_indices', 'aspect_bert_indices'],
'lcf_bert': ['text_bert_indices', 'bert_segments_ids', 'text_raw_bert_indices', 'aspect_bert_indices'],
'td_bert': ['text_bert_indices', 'bert_segments_ids', 'left_context_len', 'aspect_len'],
}
initializers = {
'xavier_uniform_': torch.nn.init.xavier_uniform_,
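With both registrations in place (the model class and its input columns), `td_bert` should be trainable like the other BERT variants, e.g. `python train.py --model_name td_bert --dataset restaurant`, assuming the repository's usual CLI flags; the two extra input columns, `left_context_len` and `aspect_len`, are supplied by the data_utils.py change above.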