Incompatible shapes when #2501

Open

segeljakt opened this issue Sep 25, 2024 · 1 comment

segeljakt commented Sep 25, 2024

I tried to modify the code in https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs into a chatbot where each new prompt takes the history of all previous prompts into account. This is my code:

use anyhow::{Error, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;
use candle_transformers::models::llama::Llama;
use candle_transformers::models::llama::LlamaConfig;
use candle_transformers::models::llama::LlamaEosToks;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

const EOS_TOKEN: &str = "</s>";
const REPEAT_PENALTY: f32 = 1.1;
const REPEAT_LAST_N: usize = 128;
const SEED: u64 = 299792458;
const SAMPLE_LEN: usize = 10000;
const ADD_SPECIAL_TOKENS: bool = true;
const SKIP_SPECIAL_TOKENS: bool = true;
const USE_KV_CACHE: bool = true;
const USE_FLASH_ATTENTION: bool = false;

pub struct Chat {
    model: Llama,
    logits_processor: LogitsProcessor,
    cache: Cache,
    tokenizer: Tokenizer,
    device: Device,
    eos_token_id: Option<LlamaEosToks>,
    tokens: Vec<u32>,
    index: usize,
}

impl Chat {
    pub fn new() -> Result<Self> {
        let device = Device::new_metal(0)?;
        let dtype = DType::F16;
        let api = Api::new()?;
        let api = api.repo(Repo::with_revision(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
            RepoType::Model,
            "main".to_string(),
        ));

        let tokenizer_filename = api.get("tokenizer.json")?;
        let config_filename = api.get("config.json")?;
        let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
        let config = config.into_config(USE_FLASH_ATTENTION);
        let filenames = vec![api.get("model.safetensors")?];
        let cache = Cache::new(USE_KV_CACHE, dtype, &config, &device)?;

        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
        let model = Llama::load(vb, &config)?;

        let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(Error::msg)?;
        let eos_token_id = config
            .eos_token_id
            .or_else(|| tokenizer.token_to_id(EOS_TOKEN).map(LlamaEosToks::Single));
        let logits_processor = LogitsProcessor::from_sampling(SEED, Sampling::ArgMax);

        Ok(Self {
            model,
            tokenizer,
            logits_processor,
            eos_token_id,
            cache,
            device,
            tokens: Vec::new(),
            index: 0,
        })
    }

    pub fn run(&mut self, prompt: &str) -> Result<String> {
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );

        for _ in 0..SAMPLE_LEN {
            // Only feed the tokens the model has not seen yet; with the kv
            // cache enabled this is a single token after the first iteration.
            let tokens_slice = &self.tokens[self.index..];
            let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
                .squeeze(0)?;
            let logits = candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY,
                &self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
            )?;
            // Advance past the tokens that are now held in the kv cache.
            self.index += tokens_slice.len();

            let next_token = self.logits_processor.sample(&logits)?;
            self.tokens.push(next_token);

            if self.is_eos_token(next_token) {
                break;
            }
        }
        let output = self
            .tokenizer
            .decode(&self.tokens, SKIP_SPECIAL_TOKENS)
            .map_err(Error::msg)?;
        Ok(output)
    }

    fn is_eos_token(&self, token: u32) -> bool {
        matches!(self.eos_token_id, Some(LlamaEosToks::Single(id)) if token == id)
            || matches!(self.eos_token_id, Some(LlamaEosToks::Multiple(ref ids)) if ids.contains(&token))
    }
}

fn main() {
    let mut chat = Chat::new().unwrap();
    println!("{}", chat.run("Hello my name is").unwrap());
    println!("{}", chat.run("Today").unwrap());
}

When I run it, I get this error:

called `Result::unwrap()` on an `Err` value: BroadcastIncompatibleShapes { src_shape: [3, 3], dst_shape: [1, 32, 3, 349] }

The error happens the second time I call Chat::run in main, and it is thrown from this statement:

    // ...
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
    // ...

The first time I run the chat in main, the shape of input is [1,5]. After producing an output token, the next shape of input is [1,1] since I use key-value caching.

When I later enter a new prompt and run the chat, the input shape is [1,3] (which includes the EOS token from the previous run). The error disappears if I drop some tokens so the shape becomes [1,1]. Is there something that says the shape must be [1,1] when we use key-value caching?
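
If I am reading the shapes correctly (this is a guess, I have not traced through candle's attention code), the [3, 3] tensor is a causal mask built only for the 3 new tokens, while [1, 32, 3, 349] is the attention score tensor over all 349 positions (the 346 cached ones plus the 3 new ones), so a mask that broadcasts against it would need shape [3, 349]. Here is a rough standalone sketch just to illustrate the shapes; extended_causal_mask is my own hypothetical helper, not a candle API:

use candle::{Device, Result, Tensor};

// Hypothetical helper (not a candle API), only to illustrate the shapes: with
// `past_len` cached positions and `seq_len` new tokens, a broadcastable causal
// mask has shape (seq_len, past_len + seq_len). Cached columns are never
// masked (0); a new column j is masked (1) when it lies in the future of the
// query token i.
fn extended_causal_mask(seq_len: usize, past_len: usize, device: &Device) -> Result<Tensor> {
    let total = past_len + seq_len;
    let data: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..total).map(move |j| u8::from(j > past_len + i)))
        .collect();
    Tensor::from_vec(data, (seq_len, total), device)
}

fn main() -> Result<()> {
    let mask = extended_causal_mask(3, 346, &Device::Cpu)?;
    println!("{:?}", mask.shape()); // (3, 349): this would broadcast to (1, 32, 3, 349)
    Ok(())
}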


segeljakt commented Sep 25, 2024

Oh, I managed to get it working by turning off kv-caching and then turning it on again:

impl Chat {
    // ...
    pub fn run(&mut self, prompt: &str) -> Result<String> {
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );
        self.cache.use_kv_cache = false; // <---- Here

        for _ in 0..SAMPLE_LEN {
            let tokens_slice = &self.tokens[self.index..];
            let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
                .squeeze(0)?;
            self.cache.use_kv_cache = true;  // <---- Here
            let logits = candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY,
                &self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
            )?;
            // ...
        }
        // ...   
    }
    // ...
}

This means the first forward of every run is done without kv-caching. Is this the correct way to approach it?
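
Another idea I have not tested: keep use_kv_cache enabled and instead prefill the cache with the new prompt one token at a time, so every forward sees a [1, 1] input. This assumes candle's llama only builds the (seq_len, seq_len) causal mask when the input has more than one token, which I have not verified. Roughly:

impl Chat {
    // Untested sketch: prefill the kv cache with all but the last new prompt
    // token, one token at a time, so the existing sampling loop only ever
    // forwards a [1, 1] input and use_kv_cache can stay true.
    pub fn run(&mut self, prompt: &str) -> Result<String> {
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );
        // Feed every pending token except the last one; its logits are not
        // needed here because the first iteration of the sampling loop
        // produces them again.
        while self.index + 1 < self.tokens.len() {
            let t = self.tokens[self.index];
            let input = Tensor::new(&[t], &self.device)?.unsqueeze(0)?; // shape [1, 1]
            let _ = self.model.forward(&input, self.index, &mut self.cache)?;
            self.index += 1;
        }
        // ... the original sampling loop follows unchanged; its first forward
        // now also gets a [1, 1] input.
    }
    // ...
}

This would be slower for long prompts than a batched prefill, but it avoids toggling the cache on and off.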
