add partial rotary positional embedding
lucidrains committed Jun 21, 2022
1 parent 1d77ba6 commit dfe22d0
Showing 2 changed files with 56 additions and 5 deletions.
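
A brief orientation before the diff: the commit threads rotary position embeddings (RoPE, https://arxiv.org/abs/2104.09864) through both attention blocks, but only partially. The model constructs RotaryEmbedding(dim = max(32, dim_head // 2)), so with dim_head = 64 only the first 32 channels of each head are rotated and the remaining channels pass through unchanged. A minimal usage sketch of the new helpers (editorial illustration, not part of the commit; it assumes RotaryEmbedding and apply_rotary_pos_emb remain importable from perceiver_ar_pytorch/perceiver_ar_pytorch.py and that dim_head = 64):

import torch
from perceiver_ar_pytorch.perceiver_ar_pytorch import RotaryEmbedding, apply_rotary_pos_emb

rotary = RotaryEmbedding(dim = 32)                # rotate 32 of the 64 channels per head
pos = rotary(1024, device = torch.device('cpu'))  # (1024, 32) table of rotation angles

q = torch.randn(1, 8, 1024, 64)                   # (batch, heads, seq, dim_head)
q_rotated = apply_rotary_pos_emb(pos, q)          # same shape as q

# only the first 32 channels change; the last 32 pass through untouched - hence "partial"
assert torch.allclose(q_rotated[..., 32:], q[..., 32:])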
perceiver_ar_pytorch/perceiver_ar_pytorch.py: 55 additions & 4 deletions
@@ -4,9 +4,13 @@

from einops import rearrange

# helper functions

def exists(val):
    return val is not None

# feedforward

def FeedForward(dim, mult = 4, dropout = 0.):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
@@ -17,6 +21,36 @@ def FeedForward(dim, mult = 4, dropout = 0.):
        nn.Linear(hidden_dim, dim, bias = False)
    )

# rotary positional embedding
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, max_seq_len, *, device):
        seq = torch.arange(max_seq_len, device = device, dtype = self.inv_freq.dtype)
        freqs = einsum("i , j -> i j", seq, self.inv_freq)
        return torch.cat((freqs, freqs), dim = -1)


def rotate_half(x):
    x = rearrange(x, "... (j d) -> ... j d", j = 2)
    x1, x2 = x.unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)


def apply_rotary_pos_emb(pos, t):
    seq_len, rotate_dim = t.shape[-2], pos.shape[-1]
    pos = pos[..., -seq_len:, :]
    t, t_pass = t[..., :rotate_dim], t[..., rotate_dim:]
    t = (t * pos.cos()) + (rotate_half(t) * pos.sin())
    return torch.cat((t, t_pass), dim = -1)
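
Two editorial notes on the helpers above (not part of the diff). First, because the table is built as cat((freqs, freqs)) and rotate_half splits the channels as "(j d)" with j = 2, channel i is rotated together with channel i + rotate_dim // 2 and both share one frequency, i.e. the half-and-half (GPT-NeoX style) layout rather than interleaved even/odd pairs. Second, apply_rotary_pos_emb slices pos[..., -seq_len:, :], so a tensor shorter than the table picks up the angles of the last positions; in the model below the table is built for the full input length while the self-attention layers only see the non-prefix suffix, and this slice keeps their angles aligned with absolute positions. A small illustration, reusing the imports from the sketch above and assuming max_seq_len = 4096 with a 3584-token prefix:

pos = RotaryEmbedding(dim = 32)(4096, device = torch.device('cpu'))  # angles for positions 0..4095

q_suffix = torch.randn(1, 8, 512, 64)            # queries for the 512 non-prefix positions
q_suffix = apply_rotary_pos_emb(pos, q_suffix)   # internally uses pos[-512:, :], i.e. positions 3584..4095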

# attention

class CausalAttention(nn.Module):
    def __init__(
        self,
@@ -35,13 +69,18 @@ def __init__(
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
    def forward(self, x, rotary_pos_emb = None):
        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(rotary_pos_emb, q)
            k = apply_rotary_pos_emb(rotary_pos_emb, k)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        i, j = sim.shape[-2:]
@@ -73,7 +112,7 @@ def __init__(
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context):
    def forward(self, x, context, rotary_pos_emb = None):
        x = self.norm(x)
        q = self.to_q(x)

@@ -87,6 +126,10 @@ def forward(self, x, context):

        q = q * self.scale

        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(rotary_pos_emb, q)
            k = apply_rotary_pos_emb(rotary_pos_emb, k)

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        i, j = sim.shape[-2:]
@@ -123,6 +166,8 @@ def __init__(
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.rotary_pos_emb = RotaryEmbedding(dim = max(32, dim_head // 2))

        self.perceive_layer = nn.ModuleList([
            CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
            FeedForward(dim, mult = ff_mult, dropout = dropout)
@@ -148,19 +193,25 @@ def forward(
        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(seq_len, device = device))

        # rotary positional embedding

        rotary_pos_emb = self.rotary_pos_emb(seq_len, device = device)

        # divide into prefix to cross attend to and sequence to self attend to

        prefix, x = x[:, :self.cross_attn_seq_len], x[:, self.cross_attn_seq_len:]

        # initial perceiver attention and feedforward (one cross attention)

        cross_attn, ff = self.perceive_layer

        x = cross_attn(x, prefix) + x
        x = cross_attn(x, prefix, rotary_pos_emb = rotary_pos_emb) + x
        x = ff(x) + x

        # layers

        for attn, ff in self.layers:
            x = attn(x) + x
            x = attn(x, rotary_pos_emb = rotary_pos_emb) + x
            x = ff(x) + x

        # to logits
setup.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
setup(
  name = 'perceiver-ar-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.1',
  version = '0.0.2',
  license='MIT',
  description = 'Perceiver AR',
  author = 'Phil Wang',
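
The model's public interface is unchanged by this commit; the rotary table is created and applied internally. For completeness, a hedged end-to-end sketch (it assumes PerceiverAR is exported at the package level and that the constructor takes the arguments visible in the diff above plus a depth; the concrete values follow the repository README and are illustrative):

# assuming the bumped 0.0.2 release is published to PyPI: pip install perceiver-ar-pytorch
import torch
from perceiver_ar_pytorch import PerceiverAR

model = PerceiverAR(
    num_tokens = 20000,          # vocabulary size for token_emb
    dim = 512,
    depth = 8,                   # assumed number of self-attention layers after the cross attention
    dim_head = 64,               # rotary width becomes max(32, 64 // 2) = 32 per head
    heads = 8,
    max_seq_len = 4096,
    cross_attn_seq_len = 3584    # prefix length; the remaining 512 positions self attend
)

x = torch.randint(0, 20000, (1, 4096))
logits = model(x)                # logits for the 512 non-prefix positions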
