-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
32 lines (27 loc) · 1015 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch.nn as nn
class LSTMClassifier(nn.Module):
    """
    This is the simple RNN model we will be using to perform Sentiment Analysis.

    Pipeline: token ids -> embedding -> single-layer LSTM -> linear layer
    with one output -> sigmoid, producing one score in (0, 1) per review.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int, vocab_size: int) -> None:
        """
        Initialize the model by setting up the various layers.

        Args:
            embedding_dim: Size of each token embedding vector.
            hidden_dim: Size of the LSTM hidden state.
            vocab_size: Number of distinct token ids (rows of the embedding
                table). Id 0 is treated as padding.
        """
        super().__init__()

        # padding_idx=0: token id 0 maps to a zero vector that is not updated
        # by gradients.
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        # Default nn.LSTM layout is (seq_len, batch, features); forward()
        # transposes its input to match.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.dense = nn.Linear(in_features=hidden_dim, out_features=1)
        self.sig = nn.Sigmoid()

        # Not used by this class; presumably set by the caller to the
        # word -> id vocabulary after construction -- TODO confirm.
        self.word_dict = None

    def forward(self, x):
        """
        Perform a forward pass of our model on some input.

        Args:
            x: 2-D integer tensor with one row per review. Column 0 holds the
                review's length (number of real tokens) and the remaining
                columns hold the token ids, padded with 0.

        Returns:
            Tensor of sigmoid scores, one per review. Because ``squeeze()``
            removes every size-1 axis, a batch of one yields a 0-d tensor
            rather than a tensor of shape (1,).
        """
        # Move the batch to the second axis: row 0 is now the lengths and
        # rows 1.. are the time steps, i.e. (seq_len, batch) for the LSTM.
        x = x.t()
        lengths = x[0, :]
        reviews = x[1:, :]

        embeds = self.embedding(reviews)
        lstm_out, _ = self.lstm(embeds)
        out = self.dense(lstm_out)

        # For each review pick the output at its last real token
        # (time step length - 1), ignoring the outputs over trailing padding.
        # NOTE: a length of 0 gives index -1, which selects the final time
        # step instead of raising.
        out = out[lengths - 1, range(len(lengths))]
        return self.sig(out.squeeze())