diff --git a/2-1.TextCNN/TextCNN.py b/2-1.TextCNN/TextCNN.py index f19c139..b7fb373 100644 --- a/2-1.TextCNN/TextCNN.py +++ b/2-1.TextCNN/TextCNN.py @@ -16,7 +16,7 @@ def __init__(self): self.filter_list = nn.ModuleList([nn.Conv2d(1, num_filters, (size, embedding_size)) for size in filter_sizes]) def forward(self, X): - embedded_chars = self.W(X) # [batch_size, sequence_length, sequence_length] + embedded_chars = self.W(X) # [batch_size, sequence_length, embedding_size] embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size] pooled_outputs = [] @@ -81,4 +81,4 @@ def forward(self, X): if predict[0][0] == 0: print(test_text,"is Bad Mean...") else: - print(test_text,"is Good Mean!!") \ No newline at end of file + print(test_text,"is Good Mean!!")