Commit

Put prompt_token_ids, attentions_mask and weights on the same device
rlouf committed Mar 1, 2024
1 parent c4de2e0 commit 42f465c
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions outlines/generate/api.py
@@ -279,13 +279,9 @@ def stream(
             stop_sequences = stop_at
         num_samples = self.num_samples
 
-        if rng is None:
-            rng = torch.Generator(device=self.device)
-            rng.seed()
-
         prompt_token_ids, attention_masks = self.tokenizer.encode(prompts)
         prompt_token_ids = prompt_token_ids.to(self.device)
-        attention_masks = attention_masks.to(self.device)
+        attention_masks = attention_masks.to(prompt_token_ids.device)
 
         # To draw multiple samples we repeat the prompt as many times
         # as there are samples. We copy the FSMs and initialize the
@@ -298,9 +294,15 @@ def stream(
         fsm_states = [FSMState(0) for _ in range(batch_size * num_samples)]
         fsms = [self.fsm.copy() for _ in range(batch_size * num_samples)]
         weights = torch.zeros(
-            (batch_size * num_samples), dtype=torch.float, device=self.device
+            (batch_size * num_samples),
+            dtype=torch.float,
+            device=prompt_token_ids.device,
         )
 
+        if rng is None:
+            rng = torch.Generator(device=prompt_token_ids.device)
+            rng.seed()
+
         states = sequence_generator(
             self.model,
             self.sampler,
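
In short, the rng generator is now created after the prompts have been encoded, so the generator, the weights tensor and attention_masks are all placed on prompt_token_ids.device instead of the cached self.device. Below is a minimal, self-contained sketch of that pattern; init_generation_state is a hypothetical helper written for illustration only, not a function from the outlines codebase.

    import torch

    def init_generation_state(prompt_token_ids: torch.Tensor, num_samples: int = 1):
        # Derive the device from an existing tensor rather than a cached attribute,
        # so every auxiliary object ends up wherever the encoded prompts actually live.
        device = prompt_token_ids.device

        batch_size = prompt_token_ids.shape[0]
        # Per-sequence cumulative weights, created directly on that device.
        weights = torch.zeros(
            batch_size * num_samples, dtype=torch.float, device=device
        )

        # torch.Generator is device-specific: sampling ops such as torch.multinomial
        # expect the generator's device to match the tensor they sample from.
        rng = torch.Generator(device=device)
        rng.seed()
        return weights, rng

    # Works unchanged whether the prompts are on CPU or CUDA:
    weights, rng = init_generation_state(torch.tensor([[101, 2023, 102]]), num_samples=2)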
