Skip to content

Commit

Permalink
some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 2, 2022
1 parent 4864ba1 commit 685d77d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
13 changes: 6 additions & 7 deletions perceiver_ar_pytorch/autoregressive_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,21 @@ def generate(
filter_thres=0.9,
**kwargs
):
b, seq_len, device = *start_tokens.shape, start_tokens.device
b, n, device = *start_tokens.shape, start_tokens.device

offset = seq_len
out = start_tokens

for _ in range(seq_len):
out = out[:, -self.max_seq_len:]
logits = self.net(out, **kwargs)[:, -1, :]
logits = self.net(
out[:, -self.max_seq_len:],
**kwargs
)[:, -1]

filtered_logits = top_k(logits, thres = filter_thres)
probs = F.softmax(filtered_logits / temperature, dim=-1)

sample = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)
offset = max(0, offset - 1)

if exists(eos_token):
is_eos_token = out == eos_token
Expand All @@ -77,7 +76,7 @@ def generate(
out = out.masked_fill(mask, self.pad_value)
break

out = out[:, offset:]
out = out[:, n:]
return out

def forward(self, x, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'perceiver-ar-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.7',
version = '0.0.9',
license='MIT',
description = 'Perceiver AR',
author = 'Phil Wang',
Expand Down

0 comments on commit 685d77d

Please sign in to comment.