From dfe22d0a732e42684a5804f76765b719370acbb7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 21 Jun 2022 14:32:26 -0700
Subject: [PATCH] add partial rotary positional embedding

---
 perceiver_ar_pytorch/perceiver_ar_pytorch.py | 59 ++++++++++++++++++--
 setup.py                                     |  2 +-
 2 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/perceiver_ar_pytorch/perceiver_ar_pytorch.py b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
index 4967fe5..6c56a86 100644
--- a/perceiver_ar_pytorch/perceiver_ar_pytorch.py
+++ b/perceiver_ar_pytorch/perceiver_ar_pytorch.py
@@ -4,9 +4,13 @@
 
 from einops import rearrange
 
+# helper functions
+
 def exists(val):
     return val is not None
 
+# feedforward
+
 def FeedForward(dim, mult = 4, dropout = 0.):
     hidden_dim = int(dim * mult)
     return nn.Sequential(
@@ -17,6 +21,36 @@ def FeedForward(dim, mult = 4, dropout = 0.):
         nn.Linear(hidden_dim, dim, bias = False)
     )
 
+# rotary positional embedding
+# https://arxiv.org/abs/2104.09864
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+    def forward(self, max_seq_len, *, device):
+        seq = torch.arange(max_seq_len, device = device, dtype = self.inv_freq.dtype)
+        freqs = einsum("i , j -> i j", seq, self.inv_freq)
+        return torch.cat((freqs, freqs), dim = -1)
+
+
+def rotate_half(x):
+    x = rearrange(x, "... (j d) -> ... j d", j = 2)
+    x1, x2 = x.unbind(dim = -2)
+    return torch.cat((-x2, x1), dim = -1)
+
+
+def apply_rotary_pos_emb(pos, t):
+    seq_len, rotate_dim = t.shape[-2], pos.shape[-1]
+    pos = pos[..., -seq_len:, :]
+    t, t_pass = t[..., :rotate_dim], t[..., rotate_dim:]
+    t = (t * pos.cos()) + (rotate_half(t) * pos.sin())
+    return torch.cat((t, t_pass), dim = -1)
+
+# attention
+
 class CausalAttention(nn.Module):
     def __init__(
         self,
@@ -35,13 +69,18 @@ def __init__(
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x):
+    def forward(self, x, rotary_pos_emb = None):
         x = self.norm(x)
 
         q, k, v = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))
 
         q = q * self.scale
+
+        if exists(rotary_pos_emb):
+            q = apply_rotary_pos_emb(rotary_pos_emb, q)
+            k = apply_rotary_pos_emb(rotary_pos_emb, k)
+
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
 
         i, j = sim.shape[-2:]
@@ -73,7 +112,7 @@ def __init__(
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim)
 
-    def forward(self, x, context):
+    def forward(self, x, context, rotary_pos_emb = None):
         x = self.norm(x)
 
         q = self.to_q(x)
@@ -87,6 +126,10 @@ def forward(self, x, context):
 
         q = q * self.scale
 
+        if exists(rotary_pos_emb):
+            q = apply_rotary_pos_emb(rotary_pos_emb, q)
+            k = apply_rotary_pos_emb(rotary_pos_emb, k)
+
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
 
         i, j = sim.shape[-2:]
@@ -123,6 +166,8 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.pos_emb = nn.Embedding(max_seq_len, dim)
 
+        self.rotary_pos_emb = RotaryEmbedding(dim = max(32, dim_head // 2))
+
         self.perceive_layer = nn.ModuleList([
             CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
             FeedForward(dim, mult = ff_mult, dropout = dropout)
@@ -148,19 +193,25 @@ def forward(
         x = self.token_emb(x)
         x = x + self.pos_emb(torch.arange(seq_len, device = device))
 
+        # rotary positional embedding
+
+        rotary_pos_emb = self.rotary_pos_emb(seq_len, device = device)
+
+        # divide into prefix to cross attend to and sequence to self attend to
+
         prefix, x = x[:, :self.cross_attn_seq_len], x[:, self.cross_attn_seq_len:]
 
         # initial perceiver attention and feedforward (one cross attention)
 
         cross_attn, ff = self.perceive_layer
 
-        x = cross_attn(x, prefix) + x
+        x = cross_attn(x, prefix, rotary_pos_emb = rotary_pos_emb) + x
         x = ff(x) + x
 
         # layers
 
         for attn, ff in self.layers:
-            x = attn(x) + x
+            x = attn(x, rotary_pos_emb = rotary_pos_emb) + x
             x = ff(x) + x
 
         # to logits
diff --git a/setup.py b/setup.py
index 06e9ae0..47813c4 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'perceiver-ar-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'Perceiver AR',
   author = 'Phil Wang',
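
Usage note (not part of the patch): the sketch below illustrates the "partial" rotary behaviour this commit adds, using only the RotaryEmbedding and apply_rotary_pos_emb helpers defined in the changed file. It assumes the patched package (>= 0.0.2) is installed and that the symbols are imported from the perceiver_ar_pytorch.perceiver_ar_pytorch submodule shown in the diff; the dim_head value and tensor shapes are illustrative, with the rotation width following the max(32, dim_head // 2) rule used in PerceiverAR.__init__.

import torch
from perceiver_ar_pytorch.perceiver_ar_pytorch import RotaryEmbedding, apply_rotary_pos_emb

# illustrative sizes, not taken from the patch
dim_head = 64
rot_dim = max(32, dim_head // 2)                 # same rule as in PerceiverAR.__init__ above

rotary = RotaryEmbedding(dim = rot_dim)
pos = rotary(128, device = torch.device('cpu'))  # (128, rot_dim) table of rotation angles

q = torch.randn(1, 8, 128, dim_head)             # (batch, heads, seq, dim_head)
q_rot = apply_rotary_pos_emb(pos, q)

# only the first rot_dim features of each head are rotated; the remaining
# dim_head - rot_dim features pass through unchanged, hence "partial" rotary
assert torch.equal(q_rot[..., rot_dim:], q[..., rot_dim:])
print(q_rot.shape)                               # torch.Size([1, 8, 128, 64])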