Skip to content

Commit

Permalink
allow for variable number of perceiver layers
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 21, 2022
1 parent 256d405 commit 3c50fa0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
21 changes: 12 additions & 9 deletions perceiver_ar_pytorch/perceiver_ar_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def __init__(
dim_head = 64,
heads = 8,
dropout = 0.,
ff_mult = 4
ff_mult = 4,
perceive_depth = 1
):
super().__init__()
assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the sequence for which to cross attend to "perceiver" style'
Expand All @@ -168,10 +169,13 @@ def __init__(

self.rotary_pos_emb = RotaryEmbedding(dim = max(32, dim_head // 2))

self.perceive_layer = nn.ModuleList([
CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout)
])
self.perceive_layers = nn.ModuleList([])

for _ in range(perceive_depth):
self.perceive_layers.append(nn.ModuleList([
CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout)
]))

self.layers = nn.ModuleList([])
for _ in range(depth):
Expand Down Expand Up @@ -203,10 +207,9 @@ def forward(

# initial perceiver cross-attention and feedforward layers (perceive_depth of them)

cross_attn, ff = self.perceive_layer

x = cross_attn(x, prefix, rotary_pos_emb = rotary_pos_emb) + x
x = ff(x) + x
for cross_attn, ff in self.perceive_layers:
x = cross_attn(x, prefix, rotary_pos_emb = rotary_pos_emb) + x
x = ff(x) + x

# layers

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'perceiver-ar-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.2',
version = '0.0.3',
license='MIT',
description = 'Perceiver AR',
author = 'Phil Wang',
Expand Down

0 comments on commit 3c50fa0

Please sign in to comment.