
Commit

update mask operation
yusugomori committed Mar 18, 2019
1 parent cc48692 commit 74fe26c
Showing 5 changed files with 10 additions and 7 deletions.
4 changes: 2 additions & 2 deletions models/encoder_decoder_attention.py
@@ -16,9 +16,9 @@ def __init__(self,
                  input_dim,
                  hidden_dim,
                  output_dim,
-                 device='cpu',
                  bos_value=1,
-                 max_len=20):
+                 max_len=20,
+                 device='cpu'):
         super().__init__()
         self.device = device
         self.encoder = Encoder(input_dim, hidden_dim, device=device)
1 change: 0 additions & 1 deletion models/encoder_decoder_lstm.py
@@ -15,7 +15,6 @@ def __init__(self,
                  input_dim,
                  hidden_dim,
                  output_dim,
-                 device,
                  bos_value=1,
                  max_len=20,
                  device='cpu'):
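With this reordering both encoder-decoder models take device as the last keyword argument with a 'cpu' default, and the LSTM variant no longer declares device twice (a duplicate parameter name is a SyntaxError in Python). A hypothetical instantiation under these signatures; the class name and argument values are assumptions, not shown in this diff:

# Hypothetical usage sketch: class name and sizes are placeholders,
# only the keyword order matches the updated signatures.
model = EncoderDecoder(input_dim=10000,
                       hidden_dim=256,
                       output_dim=10000,
                       bos_value=1,
                       max_len=20,
                       device='cpu')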
6 changes: 4 additions & 2 deletions models/layers/Attention.py
@@ -38,8 +38,10 @@ def forward(self, ht, hs, source=None, pad_value=0):

         score = torch.exp(score)
         if source is not None:
-            mask_source = (source.t() != pad_value).unsqueeze(0)
-            score = score * mask_source.float().to(self.device)
+            # mask_source = (source.t() != pad_value).unsqueeze(0)
+            # score = score * mask_source.float().to(self.device)
+            mask_source = source.t().eq(pad_value).unsqueeze(0)
+            score.data.masked_fill_(mask_source, 0)

         a = score / torch.sum(score, dim=-1, keepdim=True)
         c = torch.einsum('jik,kil->jil', (a, hs))
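In Attention.forward the effect of the masking is unchanged: scores at padded source positions still end up as 0. The new code builds a boolean pad mask with eq(pad_value) and writes the zeros in place with masked_fill_, instead of multiplying by a float "keep" mask and converting it with .float().to(self.device). A small sketch comparing the two formulations; the shapes and values below are assumptions, not taken from the repository:

import torch

# Assumed shapes: source is (source_len, batch) token ids,
# score is (target_len, batch, source_len) attention scores.
pad_value = 0
source = torch.tensor([[4, 9],
                       [7, 5],
                       [2, 0]])            # the second sequence ends in padding
score = torch.rand(2, 2, 3)                # (target_len=2, batch=2, source_len=3)

# Old formulation: float "keep" mask, multiplied into the scores.
keep = (source.t() != pad_value).unsqueeze(0)   # True at real tokens
old = score * keep.float()

# New formulation: boolean "pad" mask, zeros written in place.
pad = source.t().eq(pad_value).unsqueeze(0)     # True at padding
new = score.clone()
new.masked_fill_(pad, 0)

assert torch.equal(old, new)                    # both zero out padded positions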
3 changes: 2 additions & 1 deletion models/layers/DotProductAttention.py
@@ -21,7 +21,8 @@ def forward(self, q, k, v, mask=None):
             # in source-target-attention, source is `k` and `v`
             if len(mask.size()) == 2:
                 mask = mask.unsqueeze(0)
-            score = score * mask.type(torch.Tensor).to(self.device)
+            # score = score * mask.float().to(self.device)
+            score.data.masked_fill_(mask, 0)

         a = score / torch.sum(score, dim=-1, keepdim=True)
         c = torch.einsum('jik,kil->jil', (a, v))
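In both dot-product attention layers the multiplicative mask is replaced with an in-place masked_fill_. masked_fill_ writes zeros wherever the mask is True, so the mask handed to forward() is presumably True at the positions to suppress (e.g. padded keys), the opposite convention of the old "keep" mask that was multiplied in. A minimal sketch under that assumption; shapes and values are made up, and the same reading applies to ScaledDotProductAttention below:

import torch

# Assumed shape: score is (target_len, batch, source_len).
score = torch.rand(1, 2, 4)
# Assumed mask convention after this commit: True marks positions to zero out.
mask = torch.tensor([[False, False, True, True],
                     [False, True,  True, True]]).unsqueeze(0)
score.masked_fill_(mask, 0)    # scores at masked positions become 0 in place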
3 changes: 2 additions & 1 deletion models/layers/ScaledDotProductAttention.py
@@ -23,7 +23,8 @@ def forward(self, q, k, v, mask=None):
             # in source-target-attention, source is `k` and `v`
             if len(mask.size()) == 2:
                 mask = mask.unsqueeze(0)
-            score = score * mask.type(torch.Tensor).to(self.device)
+            # score = score * mask.float().to(self.device)
+            score.data.masked_fill_(mask, 0)

         a = score / torch.sum(score, dim=-1, keepdim=True)
         c = torch.einsum('jik,kil->jil', (a, v))
