From 081aa5e0af2e8f1afb7a6ad61ceded1795439fb5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 21 Jun 2022 22:06:54 +0000
Subject: [PATCH] add dropout

---
 perceiver_ar_pytorch/perceiver_ar_pytorch.py | 6 ++++++
 setup.py                                     | 2 +-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/perceiver_ar_pytorch/perceiver_ar_pytorch.py b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
index 045568c..5576aab 100644
--- a/perceiver_ar_pytorch/perceiver_ar_pytorch.py
+++ b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
@@ -66,6 +66,7 @@ def __init__(
         inner_dim = heads * dim_head
 
         self.norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
@@ -88,6 +89,8 @@ def forward(self, x, rotary_pos_emb = None):
         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
 
         attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)
+
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
@@ -109,6 +112,7 @@ def __init__(
 
         self.norm = nn.LayerNorm(dim)
         self.context_norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(dropout)
 
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
@@ -150,6 +154,8 @@ def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
         sim = sim.masked_fill(causal_mask, mask_value)
 
         attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)
+
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
diff --git a/setup.py b/setup.py
index 36384dc..5803de6 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'perceiver-ar-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'Perceiver AR',
   author = 'Phil Wang',
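
The pattern this patch applies, dropout on the post-softmax attention weights before they are used to weight the values, is shown in the minimal, self-contained sketch below. The class name, hyperparameters, and the omission of the repository's rotary embeddings and prefix/cross-attention details are illustrative assumptions, not the library's actual classes; only the placement of nn.Dropout mirrors the hunks above.

# Minimal sketch (assumed names) of attention-weight dropout, as added in this patch.
import torch
from torch import nn, einsum
from einops import rearrange

class AttentionWithDropout(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)   # dropout module, as in the patched __init__
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # causal mask, so each position only attends to itself and earlier positions
        causal_mask = torch.ones(sim.shape[-2:], dtype = torch.bool, device = x.device).triu(1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)             # dropout applied to attention weights, as in the patched forward

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# usage: active only in training mode; nn.Dropout is a no-op under model.eval()
x = torch.randn(1, 128, 512)
layer = AttentionWithDropout(dim = 512, dropout = 0.1)
out = layer(x)   # (1, 128, 512)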