Switch to fixed length token tensor (slice_assign) + concurrent stop state
laggui committed Jan 23, 2025
1 parent 0e544a0 commit 3cd82f8
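
In short, the commit makes two changes to `generate`: the token buffer is now pre-allocated once at `prompt_len + sample_len` and filled in place with `slice_assign` (instead of growing it with `Tensor::cat` every step), and the stop-token comparison moves onto a background thread behind the new `GenerateState`, so the sampling loop only polls an `AtomicBool` rather than blocking on `into_scalar()` every iteration. A minimal sketch of the pre-allocation pattern, written against the same burn tensor calls that appear in the diff below (the function name and the dummy next token are illustrative, not part of the commit):

use burn::tensor::{backend::Backend, Int, Tensor};

/// Fill a fixed-length token buffer in place, mirroring the slice_assign
/// pattern introduced by this commit.
fn fill_tokens<B: Backend>(
    prompt_tokens: Tensor<B, 1, Int>,
    sample_len: usize,
    device: &B::Device,
) -> Tensor<B, 1, Int> {
    let prompt_len = prompt_tokens.dims()[0];
    // One allocation for the whole generation instead of a Tensor::cat per step.
    let mut tokens = Tensor::<B, 1, Int>::empty([prompt_len + sample_len], device);
    tokens = tokens.slice_assign([0..prompt_len], prompt_tokens);

    for i in 0..sample_len {
        // In the real loop this token comes from the sampler; a constant stands in here.
        let next_token = Tensor::<B, 1, Int>::from_ints([0], device);
        // Write the new token into its pre-allocated slot instead of concatenating.
        tokens = tokens.slice_assign([prompt_len + i..prompt_len + i + 1], next_token);
    }
    tokens
}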
Showing 1 changed file with 76 additions and 20 deletions.
96 changes: 76 additions & 20 deletions llama-burn/src/llama.rs
@@ -1,4 +1,11 @@
-use std::time::Instant;
+use std::{
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        mpsc::Sender,
+        Arc,
+    },
+    time::Instant,
+};
 
 use burn::{
     config::Config,
@@ -597,6 +604,48 @@ pub struct GenerationOutput {
     pub time: f64,
 }
 
+#[derive(Clone)]
+/// The text generation state, used to check when a stop token has been reached.
+pub struct GenerateState<B: Backend> {
+    sender: Sender<Tensor<B, 1, Int>>,
+    state: Arc<AtomicBool>,
+    num_tokens: usize,
+}
+
+impl<B: Backend> GenerateState<B> {
+    /// Create a new instance.
+    pub fn new(stop_tokens: Tensor<B, 1, Int>) -> Self {
+        let (sender, receiver) = std::sync::mpsc::channel();
+        let state = Arc::new(AtomicBool::new(false));
+        let state_clone = Arc::clone(&state);
+
+        std::thread::spawn(move || {
+            for tokens in receiver {
+                if stop_tokens.clone().equal(tokens).any().into_scalar() {
+                    state_clone.store(true, Ordering::Relaxed);
+                }
+            }
+        });
+
+        Self {
+            sender,
+            state,
+            num_tokens: 0,
+        }
+    }
+
+    /// Check if a stop token has been encountered during generation.
+    pub fn check_stop_token(&mut self, tokens: Tensor<B, 1, Int>) {
+        self.num_tokens += tokens.shape().num_elements(); // should always be +1
+        self.sender.send(tokens).unwrap();
+    }
+
+    /// True if `.check_stop_token` previously detected a stop token.
+    pub fn should_stop(&self) -> bool {
+        self.state.load(Ordering::Relaxed)
+    }
+}
+
 /// Meta Llama large language model and tokenizer.
 pub struct Llama<B: Backend, T: Tokenizer> {
     /// The tokenizer.
@@ -629,14 +678,23 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
         temperature: f64,
         sampler: &mut Sampler,
     ) -> GenerationOutput {
-        let mut tokens = self.tokenize(prompt);
-        let prompt_len = tokens.dims()[0];
-        let stop_tokens = Tensor::from_ints(self.tokenizer.stop_ids().as_slice(), &self.device);
+        let input_tokens = self.tokenize(prompt);
+        let prompt_len = input_tokens.dims()[0];
+        let mut tokens = Tensor::<B, 1, Int>::empty([prompt_len + sample_len], &self.device);
+        tokens = tokens.slice_assign([0..prompt_len], input_tokens);
 
-        let mut num_tokens: usize = 0;
-        let mut input_pos = Tensor::<B, 1, Int>::arange(0..tokens.dims()[0] as i64, &self.device);
+        let mut state = GenerateState::new(Tensor::from_ints(
+            self.tokenizer.stop_ids().as_slice(),
+            &self.device,
+        ));
+
+        let mut input_pos = Tensor::<B, 1, Int>::arange(0..prompt_len as i64, &self.device);
         let now = Instant::now();
-        for _ in 0..sample_len {
+        for i in 0..sample_len {
+            if state.should_stop() {
+                break;
+            }
+
             let x = tokens.clone().select(0, input_pos.clone()).reshape([1, -1]);
             let logits = self.model.forward(x, &mut self.cache, &self.rope);
 
@@ -652,25 +710,18 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
             let next_token = sampler.sample(next_token_logits).squeeze(0);
 
             // Stop when any of the valid stop tokens is encountered
-            if stop_tokens
-                .clone()
-                .equal(next_token.clone())
-                .any()
-                .into_scalar()
-            {
-                break;
-            }
+            state.check_stop_token(next_token.clone());
 
-            // Concatenate the new generated token
-            tokens = Tensor::cat(vec![tokens, next_token], 0);
-            num_tokens += 1;
+            // Update with the new generated token
+            tokens = tokens.slice_assign([prompt_len + i..prompt_len + i + 1], next_token);
 
             // Advance
             let t = input_pos.dims()[0];
             input_pos = input_pos.slice([t - 1..t]) + 1;
         }
 
-        let tokens = tokens.into_data().as_slice::<B::IntElem>().unwrap()[prompt_len..]
+        let tokens = tokens.into_data().as_slice::<B::IntElem>().unwrap()
+            [prompt_len..prompt_len + state.num_tokens]
            .iter()
            .map(|t| t.elem::<u32>())
            .collect::<Vec<_>>();
@@ -680,7 +731,7 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {

         GenerationOutput {
             text: generated,
-            tokens: num_tokens,
+            tokens: state.num_tokens,
             time: elapsed,
         }
     }
@@ -723,6 +774,11 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {

         Ok(self)
     }
+
+    /// Reset the model state (used between generations)
+    pub fn reset(&mut self) {
+        self.cache.iter_mut().for_each(|cache| cache.reset());
+    }
 }
 
 impl RopeFrequencyScaling {
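Downstream, the only new public additions in this file are `GenerateState` (used internally by `generate`) and the `reset()` method. A hypothetical usage sketch of resetting between generations follows; the crate/module paths and the exact `generate` argument order are assumptions and are not shown in this diff:

use burn::tensor::backend::Backend;
// The paths below are assumptions about the crate layout, not taken from this diff.
use llama_burn::{llama::Llama, sampling::Sampler, tokenizer::Tokenizer};

/// Run two independent prompts against the same model, clearing the cache in
/// between with the reset() method added by this commit.
fn generate_twice<B: Backend, T: Tokenizer>(llama: &mut Llama<B, T>, sampler: &mut Sampler) {
    // Argument order (prompt, sample_len, temperature, sampler) is assumed.
    let first = llama.generate("What is the Burn framework?", 64, 0.7, sampler);
    println!("{} ({} tokens in {:.2}s)", first.text, first.tokens, first.time);

    // Without a reset, the cache would still hold state from the first prompt.
    llama.reset();

    let second = llama.generate("Write a haiku about tensors.", 64, 0.7, sampler);
    println!("{}", second.text);
}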