Skip to content

Commit

Permalink
Take positions as input for .forward
Browse files: browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Dec 17, 2023
1 parent a9fdd0e commit 9c08e04
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/openai/pipelines/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ impl<'s> ModulePipeline<'s> for LlamaPipeline {
fn forward(
&mut self,
_input_tokens: Tensor,
_input_positions: Tensor,
_kv_cache: Option<Arc<Vec<(Tensor, Tensor)>>>,
_input_metadata: InputMetadata,
) -> Result<Vec<TokenOrFinishReason>, APIError> {
Expand Down
11 changes: 7 additions & 4 deletions src/openai/pipelines/llm_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<'a> LLMEngine<'a> {

let PreparedInputs {
tokens,
positions: _,
positions,
metadata,
} = if scheduled
.front()
Expand All @@ -147,9 +147,12 @@ impl<'a> LLMEngine<'a> {
self.prepare_decode(scheduled)
}?;

let result =
self.pipeline
.forward(tokens, Some(self.cache_engine.get_kv_cache()), metadata)?;
let result = self.pipeline.forward(
tokens,
positions,
Some(self.cache_engine.get_kv_cache()),
metadata,
)?;

for (result, (_, seq)) in zip(result, seqs) {
match result {
Expand Down
1 change: 1 addition & 0 deletions src/openai/pipelines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub trait ModulePipeline<'s>: Send + Sync {
/// Runs the pipeline's forward step on the given inputs.
///
/// * `input_tokens` - tensor holding the input tokens.
/// * `input_positions` - tensor holding the input positions
///   (presumably one position per input token; confirm against the caller).
/// * `kv_cache` - optional shared vector of `(Tensor, Tensor)` pairs
///   (NOTE(review): presumably per-layer key/value caches; confirm).
/// * `input_metadata` - additional metadata accompanying the inputs.
///
/// Returns a vector of `TokenOrFinishReason` values on success, or an
/// `APIError` on failure.
fn forward(
&mut self,
input_tokens: Tensor,
input_positions: Tensor,
kv_cache: Option<Arc<Vec<(Tensor, Tensor)>>>,
input_metadata: InputMetadata,
) -> Result<Vec<TokenOrFinishReason>, APIError>;
Expand Down

0 comments on commit 9c08e04

Please sign in to comment.