first pass at perceiver AR, without rotary embeddings
lucidrains committed Jun 21, 2022
1 parent bfa7526 commit 1d77ba6
Showing 9 changed files with 484 additions and 0 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/python-publish.yml
@@ -0,0 +1,38 @@


# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
release:
types: [published]

jobs:
deploy:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
35 changes: 35 additions & 0 deletions README.md
@@ -12,6 +12,41 @@ However, the experimental results they provided are still worthwhile and I'll bu

<a href="https://github.com/google-research/perceiver-ar">Official Jax repository</a>

## Install

```bash
$ pip install perceiver-ar-pytorch
```

## Usage

```python
import torch
from perceiver_ar_pytorch import PerceiverAR

model = PerceiverAR(
num_tokens = 20000, # number of tokens
dim = 512, # model dimensions
depth = 8, # model depth
dim_head = 64, # attention head dimension
heads = 8, # attention heads
max_seq_len = 4096, # total max sequence length
    cross_attn_seq_len = 3072, # length of the prefix that is cross attended to ("perceived") but does not undergo self attention (must be less than max_seq_len)
)

x = torch.randint(0, 20000, (1, 4096))

logits = model(x) # (1, 1024, 20000) - logits are only returned for the final 1024 positions (4096 [seq len] - 3072 [perceived prefix] == 1024)
```
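
The commit also includes an `AutoregressiveWrapper` (in `perceiver_ar_pytorch/autoregressive_wrapper.py`; import it from the submodule, since it is not re-exported in `__init__.py`). Below is a rough sampling sketch continuing from the model above; the prompt length of 4096 and the 512 generated tokens are illustrative values only, the one real constraint being that the prompt must be longer than `cross_attn_seq_len`.

```python
from perceiver_ar_pytorch.autoregressive_wrapper import AutoregressiveWrapper

wrapper = AutoregressiveWrapper(model)

prime = torch.randint(0, 20000, (1, 4096))  # prompt, longer than cross_attn_seq_len

sampled = wrapper.generate(prime, 512, temperature = 1.0, filter_thres = 0.9)  # (1, 512)
```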

## Test

Train on enwik8 at a sequence length of 4096:

```bash
$ python train.py
```

## Citations

```bibtex
3 changes: 3 additions & 0 deletions data/README.md
@@ -0,0 +1,3 @@
# Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
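
For reference, a minimal sketch of how the gzipped enwik8 dump might be read into byte-level tensors; the 95M byte read and the 90M/5M train/validation split are common conventions for this dataset, not something prescribed by this commit.

```python
import gzip

import numpy as np
import torch

# enwik8 is raw UTF-8 encoded XML, modeled here at the byte level
with gzip.open('./data/enwik8.gz') as f:
    data = np.frombuffer(f.read(int(95e6)), dtype=np.uint8).copy()

train_data, valid_data = map(torch.from_numpy, np.split(data, [int(90e6)]))
```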
Binary file added data/enwik8.gz
Binary file not shown.
1 change: 1 addition & 0 deletions perceiver_ar_pytorch/__init__.py
@@ -0,0 +1 @@
from perceiver_ar_pytorch.perceiver_ar_pytorch import PerceiverAR
85 changes: 85 additions & 0 deletions perceiver_ar_pytorch/autoregressive_wrapper.py
@@ -0,0 +1,85 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

# helper functions


def exists(val):
return val is not None


def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out

return inner


# top k filtering


def top_k(logits, thres=0.9):
k = int((1 - thres) * logits.shape[-1])
val, ind = torch.topk(logits, k)
probs = torch.full_like(logits, float("-inf"))
probs.scatter_(1, ind, val)
return probs


class AutoregressiveWrapper(nn.Module):
def __init__(self, net, pad_value=0):
super().__init__()
self.max_seq_len = net.max_seq_len
self.pad_value = pad_value
self.net = net

@torch.no_grad()
@eval_decorator
def generate(
self,
start_tokens,
seq_len,
eos_token=None,
temperature=1.0,
filter_thres=0.9,
**kwargs
):
        b, prompt_len, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # feed only the most recent max_seq_len tokens to the network
            logits = self.net(out[:, -self.max_seq_len:], **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_tokens = out == eos_token

                if is_eos_tokens.any(dim=-1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # return only the newly generated tokens
        out = out[:, prompt_len:]
        return out

def forward(self, x, **kwargs):
x_inp, x_labels = x[:, :-1], x[:, 1:]
return self.net(x_inp, labels = x_labels, **kwargs)
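
To make the wrapper's role concrete, here is a minimal training-step sketch; the model hyperparameters mirror the README and the optimizer choice is an assumption for illustration only. `forward` shifts the sequence by one position to form inputs and labels, and the returned loss is the cross entropy computed by the underlying `PerceiverAR` over the non-prefix positions.

```python
import torch

from perceiver_ar_pytorch import PerceiverAR
from perceiver_ar_pytorch.autoregressive_wrapper import AutoregressiveWrapper

model = PerceiverAR(
    num_tokens = 20000,
    dim = 512,
    depth = 8,
    max_seq_len = 4096,
    cross_attn_seq_len = 3072
)

wrapper = AutoregressiveWrapper(model)
optim = torch.optim.Adam(wrapper.parameters(), lr = 3e-4)  # illustrative optimizer / lr

# the sequence is one token longer than max_seq_len so the shifted
# inputs x[:, :-1] and labels x[:, 1:] use the full context
seq = torch.randint(0, 20000, (1, 4097))

loss = wrapper(seq)  # scalar cross entropy over the last max_seq_len - cross_attn_seq_len positions
loss.backward()
optim.step()
optim.zero_grad()
```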
176 changes: 176 additions & 0 deletions perceiver_ar_pytorch/perceiver_ar_pytorch.py
@@ -0,0 +1,176 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

def exists(val):
return val is not None

def FeedForward(dim, mult = 4, dropout = 0.):
hidden_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, hidden_dim, bias = False),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim, bias = False)
)

class CausalAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = heads * dim_head

self.norm = nn.LayerNorm(dim)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)

def forward(self, x):
x = self.norm(x)

q, k, v = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)

i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)

out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)

class CausalPrefixAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
dropout = 0.
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = heads * dim_head

self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)

def forward(self, x, context):
x = self.norm(x)
q = self.to_q(x)

k_input, v_input = self.to_kv(x).chunk(2, dim = -1)
k_context, v_context = self.to_kv(context).chunk(2, dim = -1)

k = torch.cat((k_context, k_input), dim = 1)
v = torch.cat((v_context, v_input), dim = 1)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

q = q * self.scale

sim = einsum('b h i d, b h j d -> b h i j', q, k)

i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)

sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)

out = rearrange(out, 'b h n d -> b n (h d)')

return self.to_out(out)

class PerceiverAR(nn.Module):
def __init__(
self,
*,
num_tokens,
dim,
depth,
max_seq_len,
cross_attn_seq_len,
dim_head = 64,
heads = 8,
dropout = 0.,
ff_mult = 4
):
super().__init__()
        assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the prefix that is cross attended to, "perceiver" style'
self.max_seq_len = max_seq_len
self.cross_attn_seq_len = cross_attn_seq_len

self.token_emb = nn.Embedding(num_tokens, dim)
self.pos_emb = nn.Embedding(max_seq_len, dim)

self.perceive_layer = nn.ModuleList([
CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
FeedForward(dim, mult = ff_mult, dropout = dropout)
])

self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
CausalAttention(dim = dim, dim_head = dim_head, heads = heads),
FeedForward(dim, mult = ff_mult, dropout = dropout),
]))

self.to_logits = nn.Linear(dim, num_tokens, bias = False)

def forward(
self,
x,
labels = None
):
seq_len, device = x.shape[1], x.device
assert self.cross_attn_seq_len < seq_len <= self.max_seq_len

x = self.token_emb(x)
x = x + self.pos_emb(torch.arange(seq_len, device = device))

prefix, x = x[:, :self.cross_attn_seq_len], x[:, self.cross_attn_seq_len:]

# initial perceiver attention and feedforward (one cross attention)

cross_attn, ff = self.perceive_layer

x = cross_attn(x, prefix) + x
x = ff(x) + x

# layers

for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x

# to logits

logits = self.to_logits(x)

# take care of cross entropy loss if labels are provided

if not exists(labels):
return logits

labels = labels[:, self.cross_attn_seq_len:]
return F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = 0)
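
As a sanity check on the masking in `CausalPrefixAttention`, the sketch below builds the same `triu(j - i + 1)` mask for a toy case of 4 query (latent) positions against a 6 token prefix; the sizes are made up purely for illustration. The prefix columns come out fully visible, while attention among the latent tokens stays strictly causal.

```python
import torch

i = 4               # latent / query positions (the non-prefix part of the sequence)
prefix_len = 6      # tokens that are only cross attended to
j = prefix_len + i  # keys = prefix followed by the latent tokens themselves

# same construction as in CausalPrefixAttention.forward
causal_mask = torch.ones((i, j), dtype = torch.bool).triu(j - i + 1)

print(causal_mask.long())
# tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
# 1 marks masked-out entries: every query sees the full 6 token prefix,
# and latent query t only attends to latent keys at positions <= t
```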
31 changes: 31 additions & 0 deletions setup.py
@@ -0,0 +1,31 @@
from setuptools import setup, find_packages

setup(
name = 'perceiver-ar-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
license='MIT',
description = 'Perceiver AR',
author = 'Phil Wang',
author_email = '[email protected]',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/perceiver-ar-pytorch',
keywords = [
'artificial intelligence',
'deep learning',
'transformers',
'long context',
'attention'
],
install_requires=[
'einops>=0.4',
'torch>=1.6',
],
classifiers=[
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'License :: OSI Approved :: MIT License',
'Programming Language :: Python :: 3.6',
],
)
