first pass at perceiver AR, without rotary embeddings
1 parent bfa7526 · commit 1d77ba6
Showing 9 changed files with 484 additions and 0 deletions.
@@ -0,0 +1,38 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,3 @@
# Data source

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Binary file not shown.
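The archive itself is binary, so the diff does not display it. For orientation only, here is a minimal sketch (not part of this commit) of how the gzipped enwik8 file might be loaded into train and validation byte tensors; the path data/enwik8.gz and the conventional 90M/5M byte split are assumptions, not taken from this diff.

import gzip

import numpy as np
import torch

# assumption: the archive sits at ./data/enwik8.gz relative to the repo root,
# and we keep the usual 90M train / 5M validation byte split
with gzip.open("./data/enwik8.gz") as f:
    data = np.frombuffer(f.read(int(95e6)), dtype=np.uint8).copy()

train_bytes, valid_bytes = np.split(data, [int(90e6)])
data_train = torch.from_numpy(train_bytes)   # (90_000_000,) uint8 token ids
data_val = torch.from_numpy(valid_bytes)     # (5_000_000,) uint8 token ids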
@@ -0,0 +1 @@
from perceiver_ar_pytorch.perceiver_ar_pytorch import PerceiverAR
@@ -0,0 +1,85 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

# helper functions


def exists(val):
    return val is not None


def eval_decorator(fn):
    # run the wrapped method in eval mode, then restore the prior training state
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out

    return inner


# top k filtering


def top_k(logits, thres=0.9):
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float("-inf"))
    probs.scatter_(1, ind, val)
    return probs


class AutoregressiveWrapper(nn.Module):
    def __init__(self, net, pad_value=0):
        super().__init__()
        self.max_seq_len = net.max_seq_len
        self.pad_value = pad_value
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs
    ):
        # unpack into `prompt_len` rather than `seq_len`, so the requested number
        # of tokens to generate is not shadowed by the prompt length
        b, prompt_len, device = *start_tokens.shape, start_tokens.device

        # offset marks where the generated tokens begin in `out`; it shrinks only
        # when prompt tokens are truncated away to stay within max_seq_len
        offset = prompt_len
        out = start_tokens

        for _ in range(seq_len):
            prev_len = out.shape[-1]
            out = out[:, -self.max_seq_len:]
            offset = max(0, offset - (prev_len - out.shape[-1]))

            logits = self.net(out, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres = filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_tokens = out == eos_token

                if is_eos_tokens.any(dim=-1).all():
                    # mask out everything after the eos tokens
                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
                    mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1
                    out = out.masked_fill(mask, self.pad_value)
                    break

        # return only the newly generated tokens
        out = out[:, offset:]
        return out

    def forward(self, x, **kwargs):
        x_inp, x_labels = x[:, :-1], x[:, 1:]
        return self.net(x_inp, labels = x_labels, **kwargs)
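For context, a minimal usage sketch of the wrapper's sampling loop (not part of this diff): the module path perceiver_ar_pytorch.autoregressive_wrapper and all hyperparameter values are assumptions for illustration. The prompt must be longer than cross_attn_seq_len, since that many leading tokens become the cross-attended prefix.

import torch

from perceiver_ar_pytorch import PerceiverAR
# assumption: the wrapper module is importable at this path
from perceiver_ar_pytorch.autoregressive_wrapper import AutoregressiveWrapper

# illustrative hyperparameters only
model = PerceiverAR(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    heads = 8,
    dim_head = 64,
    cross_attn_seq_len = 512,
    max_seq_len = 1024
)

model = AutoregressiveWrapper(model)

prompt = torch.randint(0, 256, (1, 600))  # (batch, prompt length > cross_attn_seq_len)
sampled = model.generate(prompt, 64)      # sample 64 new tokens
assert sampled.shape == (1, 64)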
@@ -0,0 +1,176 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

# helper function

def exists(val):
    return val is not None

def FeedForward(dim, mult = 4, dropout = 0.):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim, bias = False)
    )

class CausalAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        self.attn_dropout = nn.Dropout(dropout)  # applied to the attention weights
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # standard causal mask - each position attends only to itself and earlier positions
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class CausalPrefixAttention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head

        self.norm = nn.LayerNorm(dim)
        self.attn_dropout = nn.Dropout(dropout)  # applied to the attention weights
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, context):
        x = self.norm(x)
        q = self.to_q(x)

        # keys / values come from both the prefix (context) and the input itself
        k_input, v_input = self.to_kv(x).chunk(2, dim = -1)
        k_context, v_context = self.to_kv(context).chunk(2, dim = -1)

        k = torch.cat((k_context, k_input), dim = 1)
        v = torch.cat((v_context, v_input), dim = 1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # causal mask offset by the prefix length - queries may attend to the entire
        # prefix, but only causally within the input portion
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), device = x.device, dtype = torch.bool).triu(j - i + 1)

        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)

class PerceiverAR(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        max_seq_len,
        cross_attn_seq_len,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        ff_mult = 4
    ):
        super().__init__()
        assert max_seq_len > cross_attn_seq_len, 'max_seq_len must be greater than cross_attn_seq_len, the length of the sequence for which to cross attend to "perceiver" style'
        self.max_seq_len = max_seq_len
        self.cross_attn_seq_len = cross_attn_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.perceive_layer = nn.ModuleList([
            CausalPrefixAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
            FeedForward(dim, mult = ff_mult, dropout = dropout)
        ])

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                CausalAttention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                FeedForward(dim, mult = ff_mult, dropout = dropout),
            ]))

        self.to_logits = nn.Linear(dim, num_tokens, bias = False)

    def forward(
        self,
        x,
        labels = None
    ):
        seq_len, device = x.shape[1], x.device
        assert self.cross_attn_seq_len < seq_len <= self.max_seq_len

        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(seq_len, device = device))

        # split off the prefix to be cross attended to, perceiver style
        prefix, x = x[:, :self.cross_attn_seq_len], x[:, self.cross_attn_seq_len:]

        # initial perceiver attention and feedforward (one cross attention)

        cross_attn, ff = self.perceive_layer

        x = cross_attn(x, prefix) + x
        x = ff(x) + x

        # layers

        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        # to logits

        logits = self.to_logits(x)

        # take care of cross entropy loss if labels are provided

        if not exists(labels):
            return logits

        labels = labels[:, self.cross_attn_seq_len:]
        return F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels, ignore_index = 0)
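A minimal sketch of calling the model directly for training (again illustrative; the hyperparameter values are assumptions, not taken from this commit). Only the positions after the cross-attended prefix produce logits, and when labels are passed the cross entropy is likewise computed over that suffix only.

import torch

from perceiver_ar_pytorch import PerceiverAR

# illustrative sizes; the sequence fed to the model must satisfy
# cross_attn_seq_len < seq_len <= max_seq_len
model = PerceiverAR(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    cross_attn_seq_len = 512,
    max_seq_len = 1024
)

seq = torch.randint(0, 256, (2, 1024))

logits = model(seq[:, :-1])                      # (2, 1023 - 512, 256) - suffix positions only
loss = model(seq[:, :-1], labels = seq[:, 1:])   # scalar cross entropy over the suffix
loss.backward()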
@@ -0,0 +1,31 @@
from setuptools import setup, find_packages

setup(
  name = 'perceiver-ar-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.1',
  license='MIT',
  description = 'Perceiver AR',
  author = 'Phil Wang',
  author_email = '[email protected]',
  long_description_content_type = 'text/markdown',
  url = 'https://github.com/lucidrains/perceiver-ar-pytorch',
  keywords = [
    'artificial intelligence',
    'deep learning',
    'transformers',
    'long context',
    'attention'
  ],
  install_requires=[
    'einops>=0.4',
    'torch>=1.6',
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)