diff --git a/mistralrs-core/src/pipeline/gguf.rs b/mistralrs-core/src/pipeline/gguf.rs index 8ffc8350b..c2957a626 100644 --- a/mistralrs-core/src/pipeline/gguf.rs +++ b/mistralrs-core/src/pipeline/gguf.rs @@ -40,7 +40,7 @@ use crate::{ utils::tokens::get_token, xlora_models::{XLoraQLlama, XLoraQPhi3}, }; -use anyhow::{bail, Context, Result}; +use anyhow::{bail, Result}; use candle_core::{DType, Device, Tensor}; use either::Either; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; @@ -645,30 +645,32 @@ impl Pipeline for GGUFPipeline { flash_meta, flash_meta_full, } = *inputs.downcast().expect("Downcast failed."); - let paged_attn_meta = paged_attn_meta - .as_mut() - .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") - .map_err(|e| candle_core::Error::Msg(e.to_string()))?; + let paged_attn_meta = match ( + self.get_metadata().cache_engine.as_ref(), + &mut paged_attn_meta, + ) { + (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)), + (Some(_), None) => { + // This can happen if Rust-side user code is wrong + candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") + } + (None, Some(_)) => { + // This should never happen but we handle it anyway + candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") + } + (None, None) => None, + }; let logits = match self.model { Model::Llama(ref model) => model.forward( &input_ids, &seqlen_offsets, seqlen_offsets_kernel, context_lens, - self.get_metadata() - .cache_engine - .as_ref() - .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)), - )?, - Model::Phi2(ref model) => model.forward( - &input_ids, - &seqlen_offsets, - context_lens, - self.get_metadata() - .cache_engine - .as_ref() - .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)), + paged_attn_meta, )?, + Model::Phi2(ref model) => { + model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)? + } Model::XLoraLlama(ref model) => model.forward( &input_ids, input_ids_full.as_ref().unwrap_or(&input_ids), @@ -682,14 +684,9 @@ impl Pipeline for GGUFPipeline { &flash_meta, flash_meta_full.as_ref().unwrap_or(&flash_meta), )?, - Model::Phi3(ref model) => model.forward( - &input_ids, - &seqlen_offsets, - self.get_metadata() - .cache_engine - .as_ref() - .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)), - )?, + Model::Phi3(ref model) => { + model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)? + } Model::XLoraPhi3(ref model) => model.forward( &input_ids, input_ids_full.as_ref().unwrap_or(&input_ids), @@ -707,10 +704,7 @@ impl Pipeline for GGUFPipeline { &input_ids, &seqlen_offsets, seqlen_offsets_kernel, - self.get_metadata() - .cache_engine - .as_ref() - .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)), + paged_attn_meta, )?, }; Ok(ForwardInputsResult::CausalGeneration { logits }) diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 5ea931468..867493507 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -32,7 +32,7 @@ use crate::{ normal_model_loader, xlora_model_loader, DeviceMapMetadata, PagedAttentionConfig, Pipeline, Topology, TryIntoDType, }; -use anyhow::{Context, Result}; +use anyhow::Result; use candle_core::{Device, Tensor, Var}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use mistralrs_quant::IsqType; @@ -515,10 +515,21 @@ impl Pipeline for NormalPipeline { flash_meta, flash_meta_full, } = *inputs.downcast().expect("Downcast failed."); - let paged_attn_meta = paged_attn_meta - .as_mut() - .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") - .map_err(|e| candle_core::Error::Msg(e.to_string()))?; + let paged_attn_meta = match ( + self.get_metadata().cache_engine.as_ref(), + &mut paged_attn_meta, + ) { + (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)), + (Some(_), None) => { + // This can happen if Rust-side user code is wrong + candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") + } + (None, Some(_)) => { + // This should never happen but we handle it anyway + candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") + } + (None, None) => None, + }; let logits = match self.model.is_xlora() { false => self.model.forward( &input_ids, @@ -526,10 +537,7 @@ impl Pipeline for NormalPipeline { seqlen_offsets_kernel, context_lens, position_ids, - self.get_metadata() - .cache_engine - .as_ref() - .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)), + paged_attn_meta, &flash_meta, )?, true => self.model.xlora_forward( diff --git a/mistralrs-core/src/pipeline/vision.rs b/mistralrs-core/src/pipeline/vision.rs index 7d883afa3..01ce33ccc 100644 --- a/mistralrs-core/src/pipeline/vision.rs +++ b/mistralrs-core/src/pipeline/vision.rs @@ -24,7 +24,7 @@ use crate::{ api_dir_list, api_get_file, get_paths, vision_normal_model_loader, AnyMoeExpertType, DeviceMapMetadata, Ordering, PagedAttentionConfig, Pipeline, Topology, TryIntoDType, }; -use anyhow::{Context, Result}; +use anyhow::Result; use candle_core::{Device, Tensor, Var}; use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; use mistralrs_quant::IsqType; @@ -411,10 +411,21 @@ impl Pipeline for VisionPipeline { mut paged_attn_meta, flash_meta, } = *inputs.downcast::().expect("Downcast failed."); - let paged_attn_meta = paged_attn_meta - .as_mut() - .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") - .map_err(|e| candle_core::Error::Msg(e.to_string()))?; + let paged_attn_meta = match ( + self.get_metadata().cache_engine.as_ref(), + &mut paged_attn_meta, + ) { + (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)), + (Some(_), None) => { + // This can happen if Rust-side user code is wrong + candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.") + } + (None, Some(_)) => { + // This should never happen but we handle it anyway + candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.") + } + (None, None) => None, + }; let logits = self.model.forward( &input_ids, pixel_values, @@ -423,10 +434,7 @@ impl Pipeline for VisionPipeline { context_lens, position_ids, model_specific_args, - self.get_metadata() - .cache_engine - .as_ref() - .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)), + paged_attn_meta, &flash_meta, )?; Ok(ForwardInputsResult::CausalGeneration { logits })