allow for processing heads in chunks in the initial cross attention layer, to save on peak memory
lucidrains committed Jun 22, 2022
1 parent 081aa5e commit be37653
Showing 1 changed file with 33 additions and 11 deletions.
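In short: instead of forming the attention scores for all heads at once, the cross attention now loops over slices of the head dimension, so only a few heads' worth of the (batch, heads, q_len, k_len) similarity tensor is ever resident. A minimal standalone sketch of that pattern (hypothetical tensor names, masking and dropout omitted), not the library code itself:

import torch
from torch import einsum

def chunked_head_attention(q, k, v, max_heads = 2):
    # q, k, v: (batch, heads, seq, dim_head); assumes q is already scaled
    out = []
    for qc, kc, vc in zip(q.split(max_heads, dim = 1), k.split(max_heads, dim = 1), v.split(max_heads, dim = 1)):
        sim = einsum('b h i d, b h j d -> b h i j', qc, kc)    # scores for at most max_heads heads at a time
        attn = sim.softmax(dim = -1)
        out.append(einsum('b h i j, b h j d -> b h i d', attn, vc))
    return torch.cat(out, dim = 1)                              # stitch the heads back together

The actual change in the diff below follows the same shape, with the context and causal masking plus dropout applied inside the loop.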
44 changes: 33 additions & 11 deletions perceiver_ar_pytorch/perceiver_ar_pytorch.py
@@ -103,11 +103,14 @@ def __init__(
         dim,
         dim_head = 64,
         heads = 8,
+        max_heads_process = 2,
         dropout = 0.
     ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
+        self.max_heads_process = max_heads_process
+
         inner_dim = heads * dim_head
 
         self.norm = nn.LayerNorm(dim)
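The new max_heads_process knob is consumed in forward below by splitting q, k, v along the head dimension with torch.Tensor.split. A quick shape-level illustration with hypothetical sizes:

import torch

q = torch.randn(1, 8, 1024, 64)           # (batch, heads = 8, seq, dim_head)
chunks = q.split(2, dim = 1)              # max_heads_process = 2 -> 4 chunks
print([tuple(c.shape) for c in chunks])   # four chunks of shape (1, 2, 1024, 64)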
@@ -138,25 +141,43 @@ def forward(self, x, context, context_mask = None, rotary_pos_emb = None):
             q = apply_rotary_pos_emb(rotary_pos_emb, q)
             k = apply_rotary_pos_emb(rotary_pos_emb, k)
 
-        sim = einsum('b h i d, b h j d -> b h i j', q, k)
-
-        i, j = sim.shape[-2:]
+        # take care of masking
 
-        mask_value = -torch.finfo(sim.dtype).max
+        i, j = q.shape[-2], k.shape[-2]
+        mask_value = -torch.finfo(q.dtype).max
 
         if exists(context_mask):
             mask_len = context_mask.shape[-1]
             context_mask = F.pad(context_mask, (0, max(j - mask_len, 0)), value = True)
             context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
-            sim = sim.masked_fill(~context_mask, mask_value)
 
         causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
-        sim = sim.masked_fill(causal_mask, mask_value)
 
-        attn = sim.softmax(dim = -1)
-        attn = self.dropout(attn)
+        # process in chunks of heads
 
-        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = []
 
+        max_heads = self.max_heads_process
+
+        for q_chunk, k_chunk, v_chunk in zip(q.split(max_heads, dim = 1), k.split(max_heads, dim = 1), v.split(max_heads, dim = 1)):
+            sim = einsum('b h i d, b h j d -> b h i j', q_chunk, k_chunk)
+
+            if exists(context_mask):
+                sim = sim.masked_fill(~context_mask, mask_value)
+
+            sim = sim.masked_fill(causal_mask, mask_value)
+
+            attn = sim.softmax(dim = -1)
+            attn = self.dropout(attn)
+
+            out_chunk = einsum('b h i j, b h j d -> b h i d', attn, v_chunk)
+            out.append(out_chunk)
+
+        # concat all the heads together
+
+        out = torch.cat(out, dim = 1)
+
+        # merge heads and then combine with linear
 
         out = rearrange(out, 'b h n d -> b n (h d)')
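The saving comes from never materializing the full (batch, heads, i, j) score tensor: at any moment only max_heads_process heads of scores (plus the much smaller per-head output chunks) are live, and the masking and softmax temporaries shrink by the same factor. A rough back-of-the-envelope check with hypothetical sizes:

# hypothetical: batch 1, 8 heads, 1024 latents attending over 4096 prefix+latent positions, fp32
full_scores    = 1 * 8 * 1024 * 4096 * 4   # bytes for the full similarity tensor
chunked_scores = 1 * 2 * 1024 * 4096 * 4   # bytes live per chunk with max_heads_process = 2
print(full_scores // 2**20, chunked_scores // 2**20)   # 128 MiB vs 32 MiB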

@@ -175,7 +196,8 @@ def __init__(
         heads = 8,
         dropout = 0.,
         ff_mult = 4,
-        perceive_depth = 1
+        perceive_depth = 1,
+        perceive_max_heads_process = 2 # processes the heads in the perceiver layer in chunks to lower peak memory, in the case the prefix is really long
     ):
         super().__init__()
         assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the sequence for which to cross attend to "perceiver" style'
@@ -191,7 +213,7 @@ def __init__(
 
         for _ in range(perceive_depth):
             self.perceive_layers.append(nn.ModuleList([
-                CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
+                CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, max_heads_process = perceive_max_heads_process, dropout = dropout),
                 FeedForward(dim, mult = ff_mult, dropout = dropout)
             ]))
 
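Downstream, the knob surfaces as perceive_max_heads_process on PerceiverAR. A hedged usage sketch: only max_seq_len, cross_attn_seq_len, heads, dim_head and perceive_max_heads_process appear in this diff; the remaining constructor arguments and the output shape follow the repository README and are assumptions here:

import torch
from perceiver_ar_pytorch import PerceiverAR

model = PerceiverAR(
    num_tokens = 20000,              # assumed, per README
    dim = 512,                       # assumed, per README
    depth = 8,                       # assumed, per README
    heads = 8,
    dim_head = 64,
    cross_attn_seq_len = 3072,
    max_seq_len = 4096,
    perceive_max_heads_process = 1   # new: trade a little speed for lower peak memory on long prefixes
)

x = torch.randint(0, 20000, (1, 4096))
logits = model(x)   # logits for the positions past the cross-attended prefix, per the README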
