Incompatible shapes when #2501
Comments
Oh, I managed to get it working by turning off kv-caching and then turning it on again:

```rust
impl Chat {
    // ...
    pub fn run(&mut self, prompt: &str) -> Result<String> {
        self.tokens.extend(
            self.tokenizer
                .encode(prompt, ADD_SPECIAL_TOKENS)
                .map_err(Error::msg)?
                .get_ids(),
        );
        self.cache.use_kv_cache = false; // <---- Here
        for _ in 0..SAMPLE_LEN {
            let tokens_slice = &self.tokens[self.index..];
            let input = Tensor::new(tokens_slice, &self.device)?.unsqueeze(0)?;
            let logits = self
                .model
                .forward(&input, self.index, &mut self.cache)?
                .squeeze(0)?;
            self.cache.use_kv_cache = true; // <---- Here
            let logits = candle_transformers::utils::apply_repeat_penalty(
                &logits,
                REPEAT_PENALTY,
                &self.tokens[self.tokens.len().saturating_sub(REPEAT_LAST_N)..],
            )?;
            // ...
        }
        // ...
    }
    // ...
}
```

This means the first forward of every run is done without kv-caching. Is this the correct way to approach it?
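For context on why this toggle helps, here is a guess at the mechanism (a sketch, not a reading of candle's actual source): with a KV cache holding `n_past` entries, the attention scores for a chunk of `seq_len` new tokens have shape `(seq_len, n_past + seq_len)`, so a causal mask built only from `seq_len` as `(seq_len, seq_len)` broadcasts only when `n_past == 0` (cache off or empty) or `seq_len == 1`. The plain-Rust model below (no candle dependency; all function names are illustrative) checks these shape combinations:

```rust
// Model the shapes involved in masked attention with a KV cache.
// `n_past`: tokens already cached; `seq_len`: new tokens in this forward.

/// Attention-score matrix for the new chunk: each of the `seq_len`
/// queries attends to all `n_past + seq_len` keys.
fn scores_shape(n_past: usize, seq_len: usize) -> (usize, usize) {
    (seq_len, n_past + seq_len)
}

/// A causal mask built only from `seq_len`, ignoring the cached prefix.
fn naive_mask_shape(seq_len: usize) -> (usize, usize) {
    (seq_len, seq_len)
}

/// A mask that accounts for the cached prefix: every query may see the
/// whole prefix plus the causal part of the new chunk.
fn cache_aware_mask_shape(n_past: usize, seq_len: usize) -> (usize, usize) {
    (seq_len, n_past + seq_len)
}

/// Simplified broadcasting rule: each dimension must match or be 1.
fn broadcast_compatible(a: (usize, usize), b: (usize, usize)) -> bool {
    (a.0 == b.0 || a.0 == 1 || b.0 == 1) && (a.1 == b.1 || a.1 == 1 || b.1 == 1)
}

fn main() {
    // First run: empty cache, 5-token prompt -> the naive mask happens to fit.
    assert!(broadcast_compatible(scores_shape(0, 5), naive_mask_shape(5)));
    // Second run: 6 cached tokens, 3 new tokens -> the naive mask no longer fits.
    assert!(!broadcast_compatible(scores_shape(6, 3), naive_mask_shape(3)));
    // A cache-aware mask fits in both cases.
    assert!(broadcast_compatible(scores_shape(6, 3), cache_aware_mask_shape(6, 3)));
    println!("shape checks passed");
}
```

On this reading, turning the cache off for the first forward of a run makes `n_past` effectively zero for that call, which is exactly the case where the naive mask broadcasts.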
I tried to modify the code in https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs to become a chatbot where each new prompt considers the history of all previous prompts. This is my code:
When I run, I get this error:
The error happens the second time I call `Chat::run` in `main` and is thrown from this statement. The first time I run the chat in `main`, the shape of `input` is `[1, 5]`. After producing an output token, the next shape of `input` is `[1, 1]` since I use key-value caching. When I later enter a new prompt and run the chat, the `input` shape is `[1, 3]` (which includes the EOS token from the previous run). The error disappears if I drop some tokens so the shape becomes `[1, 1]`. Is there something that says the shape must be `[1, 1]` when we use key-value caching?
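One cache-compatible alternative to toggling `use_kv_cache`, assuming (as observed above) that single-token forwards always work with the cache: feed the new prompt's tokens one at a time, so every cached forward sees an input of shape `[1, 1]`. The toy driver below is only a sketch of that control flow — `ToyModel` and `forward_one` are hypothetical stand-ins, not candle APIs:

```rust
// Toy stand-in for a cached model step: consumes one token and records
// how many entries the (simulated) KV cache holds.
struct ToyModel {
    kv_cache_len: usize,
}

impl ToyModel {
    fn new() -> Self {
        Self { kv_cache_len: 0 }
    }

    /// Hypothetical single-token forward: the input is always one token
    /// (shape [1, 1]), so the causal mask is trivial regardless of how
    /// many entries the cache already holds.
    fn forward_one(&mut self, token: u32) -> u32 {
        self.kv_cache_len += 1; // a real cache would append K/V here
        token.wrapping_add(1) // dummy "prediction"
    }
}

/// Feed a multi-token prompt one token at a time; only the last output
/// is used to start sampling, mirroring ordinary prompt processing.
fn process_prompt(model: &mut ToyModel, prompt: &[u32]) -> u32 {
    let mut last = 0;
    for &t in prompt {
        last = model.forward_one(t);
    }
    last
}

fn main() {
    let mut model = ToyModel::new();
    // First chat turn: 5-token prompt.
    let _ = process_prompt(&mut model, &[10, 11, 12, 13, 14]);
    assert_eq!(model.kv_cache_len, 5);
    // Second turn: 3 new tokens (including the previous EOS). The cache
    // keeps growing, and no forward ever sees more than one new token.
    let _ = process_prompt(&mut model, &[2, 20, 21]);
    assert_eq!(model.kv_cache_len, 8);
    println!("cache grew to {} entries", model.kv_cache_len);
}
```

The trade-off: this keeps every forward at `seq_len == 1` (slower prompt ingestion than a batched multi-token forward, but no mask-shape issue), whereas the cache-off toggle recomputes attention over the full history once per run.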