diff --git a/perceiver_ar_pytorch/perceiver_ar_pytorch.py b/perceiver_ar_pytorch/perceiver_ar_pytorch.py index ac1d4a1..045568c 100644 --- a/perceiver_ar_pytorch/perceiver_ar_pytorch.py +++ b/perceiver_ar_pytorch/perceiver_ar_pytorch.py @@ -108,12 +108,16 @@ def __init__( inner_dim = heads * dim_head self.norm = nn.LayerNorm(dim) + self.context_norm = nn.LayerNorm(dim) + self.to_q = nn.Linear(dim, inner_dim, bias = False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) self.to_out = nn.Linear(inner_dim, dim) def forward(self, x, context, context_mask = None, rotary_pos_emb = None): x = self.norm(x) + context = self.context_norm(context) + q = self.to_q(x) k_input, v_input = self.to_kv(x).chunk(2, dim = -1) diff --git a/setup.py b/setup.py index e0a1f12..36384dc 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'perceiver-ar-pytorch', packages = find_packages(exclude=[]), - version = '0.0.4', + version = '0.0.5', license='MIT', description = 'Perceiver AR', author = 'Phil Wang',