diff --git a/perceiver_ar_pytorch/perceiver_ar_pytorch.py b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
index 5576aab..c82aede 100644
--- a/perceiver_ar_pytorch/perceiver_ar_pytorch.py
+++ b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
@@ -103,11 +103,14 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
+        max_heads_process = 2,
         dropout = 0.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
+        self.max_heads_process = max_heads_process
+
         inner_dim = heads * dim_head
 
         self.norm = nn.LayerNorm(dim)
@@ -138,25 +141,43 @@ def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
             q = apply_rotary_pos_emb(rotary_pos_emb, q)
             k = apply_rotary_pos_emb(rotary_pos_emb, k)
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
-
-        i, j = sim.shape[-2:]
+        # take care of masking
 
-        mask_value = -torch.finfo(sim.dtype).max
+        i, j = q.shape[-2], k.shape[-2]
+        mask_value = -torch.finfo(q.dtype).max
 
         if exists(context_mask):
             mask_len = context_mask.shape[-1]
             context_mask = F.pad(context_mask, (0, max(j - mask_len, 0)), value = True)
             context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
-            sim = sim.masked_fill(~context_mask, mask_value)
 
         causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
-        sim = sim.masked_fill(causal_mask, mask_value)
 
-        attn = sim.softmax(dim = -1)
-        attn = self.dropout(attn)
+        # process in chunks of heads
 
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = []
+
+        max_heads = self.max_heads_process
+
+        for q_chunk, k_chunk, v_chunk in zip(q.split(max_heads, dim = 1), k.split(max_heads, dim = 1), v.split(max_heads, dim = 1)):
+            sim = einsum('b h i d, b h j d -> b h i j', q_chunk, k_chunk)
+
+            if exists(context_mask):
+                sim = sim.masked_fill(~context_mask, mask_value)
+
+            sim = sim.masked_fill(causal_mask, mask_value)
+
+            attn = sim.softmax(dim = -1)
+            attn = self.dropout(attn)
+
+            out_chunk = einsum('b h i j, b h j d -> b h i d', attn, v_chunk)
+            out.append(out_chunk)
+
+        # concat all the heads together
+
+        out = torch.cat(out, dim = 1)
+
+        # merge heads and then combine with linear
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -175,7 +196,8 @@ def __init__(
         heads = 8,
         dropout = 0.,
         ff_mult = 4,
-        perceive_depth = 1
+        perceive_depth = 1,
+        perceive_max_heads_process = 2 # processes the heads in the perceiver layer in chunks to lower peak memory, in the case the prefix is really long
     ):
         super().__init__()
         assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the sequence for which to cross attend to "perceiver" style'
@@ -191,7 +213,7 @@
 
         for _ in range(perceive_depth):
             self.perceive_layers.append(nn.ModuleList([
-                CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
+                CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout),
                 FeedForward(dim, mult = ff_mult, dropout = dropout)
             ]))
 
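
For readers who want the memory-saving idea in isolation, below is a minimal, standalone sketch of the head-chunking pattern the diff introduces. The tensor layout ('b h n d'), the causal mask construction, and the default of 2 heads per chunk follow the diff; the helper name chunked_causal_attention, the toy shapes, and the omission of the context (prefix) mask are illustrative assumptions, not part of the library.

# Standalone sketch of the chunked-heads attention used in CausalPrefixAttention.
# Hypothetical helper for illustration; the context-mask handling from the diff is omitted.

import torch
from torch import einsum

def chunked_causal_attention(q, k, v, max_heads_process = 2):
    # q: (b, h, i, d) queries (assumed pre-scaled, as in the diff)
    # k, v: (b, h, j, d) keys / values over prefix + input positions
    i, j = q.shape[-2], k.shape[-2]
    mask_value = -torch.finfo(q.dtype).max

    # the causal mask is built once and reused for every chunk of heads
    causal_mask = torch.ones((i, j), device = q.device, dtype = torch.bool).triu(j - i + 1)

    out = []

    # only max_heads_process heads materialize an (i x j) attention matrix
    # at a time, which caps peak memory when the prefix length j is large
    for q_chunk, k_chunk, v_chunk in zip(
        q.split(max_heads_process, dim = 1),
        k.split(max_heads_process, dim = 1),
        v.split(max_heads_process, dim = 1)
    ):
        sim = einsum('b h i d, b h j d -> b h i j', q_chunk, k_chunk)
        sim = sim.masked_fill(causal_mask, mask_value)
        attn = sim.softmax(dim = -1)
        out.append(einsum('b h i j, b h j d -> b h i d', attn, v_chunk))

    return torch.cat(out, dim = 1)

# chunking is numerically equivalent to attending with all heads at once
b, h, i, j, d = 1, 8, 64, 256, 32
q = torch.randn(b, h, i, d) * (d ** -0.5)
k, v = torch.randn(b, h, j, d), torch.randn(b, h, j, d)

assert torch.allclose(
    chunked_causal_attention(q, k, v, max_heads_process = h),  # all heads in one chunk
    chunked_causal_attention(q, k, v, max_heads_process = 2),
    atol = 1e-6
)

The trade-off is a few extra kernel launches per forward pass in exchange for an attention-matrix footprint of roughly batch x max_heads_process x i x j instead of batch x heads x i x j, which is why the diff exposes the knob as perceive_max_heads_process on the perceiver cross-attention layers, where the long prefix makes j largest.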