minor fix ScaledDotProductAttention
yusugomori committed Mar 21, 2019
1 parent f0c18d4 commit 8db9197
Showing 2 changed files with 3 additions and 3 deletions.
models/layers/MultiHeadAttention.py (1 addition, 1 deletion)
@@ -31,7 +31,7 @@ def __init__(self,
         nn.init.xavier_normal_(self.W_k)
         nn.init.xavier_normal_(self.W_v)
 
-        self.attn = ScaledDotProductAttention(d_model)
+        self.attn = ScaledDotProductAttention(d_k)
         self.linear = nn.Linear((h * d_v), d_model)
         nn.init.xavier_normal_(self.linear.weight)

models/layers/ScaledDotProductAttention.py (2 additions, 2 deletions)
@@ -5,11 +5,11 @@

 class ScaledDotProductAttention(nn.Module):
     def __init__(self,
-                 d_model,
+                 d_k,
                  device='cpu'):
         super().__init__()
         self.device = device
-        self.scaler = np.sqrt(d_model)
+        self.scaler = np.sqrt(d_k)
 
     def forward(self, q, k, v, mask=None):
         '''
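For context on why this is the right fix: scaled dot-product attention divides the query/key scores by the square root of the key dimension, and in multi-head attention each head projects queries and keys down to d_k dimensions (typically d_k = d_model / h), so the scaling factor must be sqrt(d_k), not sqrt(d_model), matching softmax(QK^T / sqrt(d_k)) V from "Attention Is All You Need". Below is a minimal sketch of the module after this commit; the forward body is an assumption, since the diff cuts off at the docstring, and the -1e9 mask fill value is likewise assumed rather than taken from the repo.

import numpy as np
import torch
import torch.nn as nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self,
                 d_k,
                 device='cpu'):
        super().__init__()
        self.device = device
        # Scale by sqrt(d_k), the per-head key dimension (not d_model).
        self.scaler = np.sqrt(d_k)

    def forward(self, q, k, v, mask=None):
        # q, k: (batch, seq_len, d_k); v: (batch, seq_len, d_v)
        score = torch.matmul(q, k.transpose(-2, -1)) / self.scaler
        if mask is not None:
            # Assumed convention: give masked positions a large negative
            # score so softmax assigns them near-zero weight.
            score = score.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(score, dim=-1)
        return torch.matmul(attn, v)

With the constructor taking d_k, MultiHeadAttention builds the layer as self.attn = ScaledDotProductAttention(d_k), which is exactly the change in the first hunk above.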
