model.py
# -*- coding:utf-8 -*-
# @author: 木子川
# @Email: [email protected]
# @WeChat (VX): fylaicai
import torch
import torch.nn as nn


class BiLSTMModel(nn.Module):
    """BiLSTM sequence-labelling model: embedding -> (Bi)LSTM -> per-token linear classifier."""

    def __init__(self, corpus_num, class_num, embedding_num, hidden_num, bi=True):
        super().__init__()
        self.pred = None  # cached per-token predictions from the most recent forward pass
        self.embedding = nn.Embedding(corpus_num, embedding_num)
        self.lstm = nn.LSTM(embedding_num, hidden_num, batch_first=True, bidirectional=bi)
        # A bidirectional LSTM concatenates the forward and backward hidden states,
        # so the classifier's input size doubles.
        if bi:
            self.classifier = nn.Linear(hidden_num * 2, class_num)
        else:
            self.classifier = nn.Linear(hidden_num, class_num)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, text, label=None):
        # text: [batch_size, seq_len] token indices
        embedding = self.embedding(text)  # [batch_size, seq_len, embedding_num]
        out, _ = self.lstm(embedding)     # [batch_size, seq_len, hidden_num * num_directions]
        pred = self.classifier(out)       # [batch_size, seq_len, class_num]
        self.pred = torch.argmax(pred, dim=-1).reshape(-1)
        if label is not None:
            # Flatten the batch and sequence dimensions for token-level cross-entropy.
            loss = self.loss(pred.reshape(-1, pred.shape[-1]), label.reshape(-1))
            return loss
        return torch.argmax(pred, dim=-1).reshape(-1)
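

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file: the vocabulary size, tag count,
    # embedding size, and hidden size below are illustrative assumptions, not values from the repo.
    model = BiLSTMModel(corpus_num=5000, class_num=10, embedding_num=128, hidden_num=256, bi=True)

    batch_size, seq_len = 4, 20
    text = torch.randint(0, 5000, (batch_size, seq_len))   # fake token indices
    label = torch.randint(0, 10, (batch_size, seq_len))    # fake per-token tags

    loss = model(text, label)  # with labels: returns a scalar cross-entropy loss
    preds = model(text)        # without labels: returns flattened per-token predictions
    print(loss.item(), preds.shape)  # preds.shape == torch.Size([batch_size * seq_len])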