diff --git a/perceiver_ar_pytorch/autoregressive_wrapper.py b/perceiver_ar_pytorch/autoregressive_wrapper.py index 107d9c0..f3163ed 100644 --- a/perceiver_ar_pytorch/autoregressive_wrapper.py +++ b/perceiver_ar_pytorch/autoregressive_wrapper.py @@ -50,22 +50,21 @@ def generate( filter_thres=0.9, **kwargs ): - b, seq_len, device = *start_tokens.shape, start_tokens.device + b, n, device = *start_tokens.shape, start_tokens.device - offset = seq_len out = start_tokens for _ in range(seq_len): - out = out[:, -self.max_seq_len:] - logits = self.net(out, **kwargs)[:, -1, :] + logits = self.net( + out[:, -self.max_seq_len:], + **kwargs + )[:, -1] filtered_logits = top_k(logits, thres = filter_thres) probs = F.softmax(filtered_logits / temperature, dim=-1) sample = torch.multinomial(probs, 1) - out = torch.cat((out, sample), dim=-1) - offset = max(0, offset - 1) if exists(eos_token): is_eos_token = out == eos_token @@ -77,7 +76,7 @@ def generate( out = out.masked_fill(mask, self.pad_value) break - out = out[:, offset:] + out = out[:, n:] return out def forward(self, x, **kwargs): diff --git a/setup.py b/setup.py index 1433117..215f2df 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'perceiver-ar-pytorch', packages = find_packages(exclude=[]), - version = '0.0.7', + version = '0.0.9', license='MIT', description = 'Perceiver AR', author = 'Phil Wang',