-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmake_model.py
47 lines (33 loc) · 1.49 KB
/
make_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
import copy
import torch.nn as nn
from models.model.EncoderDecoder import EncoderDecoder
from models.model.encoder import Encoder
from models.model.decoder import Decoder
from models.layers.EncoderLayer import EncoderLayer
from models.layers.DecoderLayer import DecoderLayer
from models.layers.PositionalEncoding import PositionalEncoding
from models.layers.MultiHeadAttention import MultiHeadAttention
from models.layers.FeedForwardLayer import PositionwiseFeedForward
from models.func.embeddings import Embeddings
from models.func.generator import Generator
def make_model(
    src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Build a complete encoder-decoder Transformer from hyperparameters.

    Args:
        src_vocab: source vocabulary size (rows of the source embedding).
        tgt_vocab: target vocabulary size (rows of the target embedding
            and output dimension of the generator).
        N: number of stacked encoder and decoder layers.
        d_model: model / embedding dimension.
        d_ff: inner dimension of the position-wise feed-forward sublayer.
        h: number of attention heads.
        dropout: dropout rate for feed-forward and positional encoding.

    Returns:
        An ``EncoderDecoder`` with all weight matrices (dim > 1)
        initialized via Glorot / Xavier uniform.
    """
    clone = copy.deepcopy
    # Prototype sublayers: every place one is used below receives its
    # own deep copy, so no parameters are shared between layers.
    attention = MultiHeadAttention(h, d_model)
    feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
    positional = PositionalEncoding(d_model, dropout)

    encoder = Encoder(
        EncoderLayer(d_model, clone(attention), clone(feed_forward), dropout), N)
    decoder = Decoder(
        DecoderLayer(d_model, clone(attention), clone(attention),
                     clone(feed_forward), dropout), N)
    src_embed = nn.Sequential(Embeddings(d_model, src_vocab), clone(positional))
    tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab), clone(positional))

    model = EncoderDecoder(encoder, decoder, src_embed, tgt_embed,
                           Generator(d_model, tgt_vocab))

    # This was important from their code.
    # Glorot / fan_avg initialization on all weight matrices;
    # 1-D parameters (biases, layer-norm gains) keep their defaults.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    return model