From 3c50fa0e78b24645560bf21f68c02549dcf3ad5d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 21 Jun 2022 14:36:22 -0700 Subject: [PATCH] allow for variable number of perceiver layers --- perceiver_ar_pytorch/perceiver_ar_pytorch.py | 21 +++++++++++--------- setup.py | 2 +- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/perceiver_ar_pytorch/perceiver_ar_pytorch.py b/perceiver_ar_pytorch/perceiver_ar_pytorch.py index 6c56a86..cca933f 100644 --- a/perceiver_ar_pytorch/perceiver_ar_pytorch.py +++ b/perceiver_ar_pytorch/perceiver_ar_pytorch.py @@ -156,7 +156,8 @@ def __init__( dim_head = 64, heads = 8, dropout = 0., - ff_mult = 4 + ff_mult = 4, + perceive_depth = 1 ): super().__init__() assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the sequence for which to cross attend to "perceiver" style' @@ -168,10 +169,13 @@ def __init__( self.rotary_pos_emb = RotaryEmbedding(dim = max(32, dim_head // 2)) - self.perceive_layer = nn.ModuleList([ - CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout), - FeedForward(dim, mult = ff_mult, dropout = dropout) - ]) + self.perceive_layers = nn.ModuleList([]) + + for _ in range(perceive_depth): + self.perceive_layers.append(nn.ModuleList([ + CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout), + FeedForward(dim, mult = ff_mult, dropout = dropout) + ])) self.layers = nn.ModuleList([]) for _ in range(depth): @@ -203,10 +207,9 @@ def forward( # initial perceiver attention and feedforward (one cross attention) - cross_attn, ff = self.perceive_layer - - x = cross_attn(x, prefix, rotary_pos_emb = rotary_pos_emb) + x - x = ff(x) + x + for cross_attn, ff in self.perceive_layers: + x = cross_attn(x, prefix, rotary_pos_emb = rotary_pos_emb) + x + x = ff(x) + x # layers diff --git a/setup.py b/setup.py index 47813c4..caee633 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'perceiver-ar-pytorch', packages = find_packages(exclude=[]), - version = '0.0.2', + version = '0.0.3', license='MIT', description = 'Perceiver AR', author = 'Phil Wang',