From 528830182f96982320d85ead67e52e8828522848 Mon Sep 17 00:00:00 2001
From: Chris Petersen
Date: Wed, 13 Mar 2024 11:43:51 -0700
Subject: [PATCH] Model is now loaded in the constructor and retained for
 multiple embedding runs

---
 ext/candle/src/model/config.rs | 100 ++++++++++++++++++---------------
 1 file changed, 55 insertions(+), 45 deletions(-)

diff --git a/ext/candle/src/model/config.rs b/ext/candle/src/model/config.rs
index f74a91f..62988e2 100644
--- a/ext/candle/src/model/config.rs
+++ b/ext/candle/src/model/config.rs
@@ -20,84 +20,94 @@ pub struct ModelConfig(pub ModelConfigInner);
 pub struct ModelConfigInner {
     device: Device,
-    tokenizer_path: Option<String>,
-    model_path: Option<String>,
+    model: Option<BertModel>,
+    tokenizer: Option<Tokenizer>,
 }
 
 impl ModelConfig {
     pub fn new() -> RbResult<Self> {
-        Ok(ModelConfig(ModelConfigInner {
-            device: Device::Cpu,
-            model_path: None,
-            tokenizer_path: None,
-        }))
-    }
-
-    pub fn build() -> ModelConfig {
-        ModelConfig(ModelConfigInner {
-            device: Device::Cpu,
-            model_path: None,
-            tokenizer_path: None
-        })
+        Self::new2(
+            Some("jinaai/jina-embeddings-v2-base-en".to_string()),
+            Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
+            Some(Device::Cpu),
+        )
+    }
+
+    pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
+        let device = device.unwrap_or(Device::Cpu);
+        // Load the model and tokenizer once; embedding() reuses them on every call.
+        Ok(ModelConfig(ModelConfigInner {
+            device: device.clone(),
+            model: match model_path {
+                Some(mp) => Some(Self::build_model(mp, device)?),
+                None => None,
+            },
+            tokenizer: match tokenizer_path {
+                Some(tp) => Some(Self::build_tokenizer(tp)?),
+                None => None,
+            },
+        }))
     }
 
-    /// Performs the `sin` operation on the tensor.
+    /// Computes an embedding for the input string using the preloaded model.
     /// &RETURNS&: Tensor
     pub fn embedding(&self, input: String) -> RbResult<RbTensor> {
-        let config = ModelConfig::build();
-        let (model, tokenizer) = config.build_model_and_tokenizer()?;
-        Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?))
+        match &self.0.model {
+            Some(model) => {
+                match &self.0.tokenizer {
+                    Some(tokenizer) => Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?)),
+                    None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found")),
+                }
+            }
+            None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Model not found")),
+        }
     }
 
-    fn build_model_and_tokenizer(&self) -> Result<(BertModel, tokenizers::Tokenizer), Error> {
+    fn build_model(model_path: String, device: Device) -> RbResult<BertModel> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
-        let model_path = match &self.0.model_path {
-            Some(model_file) => std::path::PathBuf::from(model_file),
-            None => Api::new()
+        let model_path = Api::new()
             .map_err(wrap_hf_err)?
             .repo(Repo::new(
-                "jinaai/jina-embeddings-v2-base-en".to_string(),
+                model_path,
                 RepoType::Model,
             ))
             .get("model.safetensors")
-            .map_err(wrap_hf_err)?,
+            .map_err(wrap_hf_err)?;
+        let config = Config::v2_base();
+        let vb = unsafe {
+            VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
+                .map_err(wrap_candle_err)?
         };
-        let tokenizer_path = match &self.0.tokenizer_path {
-            Some(file) => std::path::PathBuf::from(file),
-            None => Api::new()
+        let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
+        Ok(model)
+    }
+
+    fn build_tokenizer(tokenizer_path: String) -> RbResult<Tokenizer> {
+        use hf_hub::{api::sync::Api, Repo, RepoType};
+        let tokenizer_path = Api::new()
             .map_err(wrap_hf_err)?
             .repo(Repo::new(
-                "sentence-transformers/all-MiniLM-L6-v2".to_string(),
+                tokenizer_path,
                 RepoType::Model,
             ))
             .get("tokenizer.json")
-            .map_err(wrap_hf_err)?,
-        };
-        // let device = candle_examples::device(self.cpu)?;
-        let config = Config::v2_base();
-        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path).map_err(wrap_std_err)?;
-        let vb = unsafe {
-            VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.0.device)
-                .map_err(wrap_candle_err)?
-        };
-        let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
-        Ok((model, tokenizer))
+            .map_err(wrap_hf_err)?;
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
+            // .with_padding(None)
+            // .with_truncation(None)
+            .map_err(wrap_std_err)?;
+        Ok(tokenizer)
     }
 
     fn compute_embedding(
         &self,
         prompt: String,
-        model: BertModel,
-        mut tokenizer: Tokenizer,
+        model: &BertModel,
+        tokenizer: &Tokenizer,
     ) -> Result<Tensor, Error> {
         let start: std::time::Instant = std::time::Instant::now();
-        // let prompt = args.prompt.as_deref().unwrap_or("Hello, world!");
-        let tokenizer = tokenizer
-            .with_padding(None)
-            .with_truncation(None)
-            .map_err(wrap_std_err)?;
         let tokens = tokenizer
             .encode(prompt, true)
             .map_err(wrap_std_err)?
@@ -110,7 +120,7 @@ impl ModelConfig {
         println!("Loaded and encoded {:?}", start.elapsed());
         let start: std::time::Instant = std::time::Instant::now();
         let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
-        println!("{result}");
+        // println!("{result}");
         println!("Took {:?}", start.elapsed());
         Ok(result)
     }
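
Below is a minimal, self-contained sketch of the load-once/reuse-many pattern this patch introduces. `Model`, `load_model`, and the `String` returned by `embedding` here are hypothetical stand-ins, not the real candle `BertModel`, the safetensors loading in `build_model`, or `Tensor`; only the caching shape mirrors the patch:

// Minimal sketch of the load-once/reuse-many pattern (stand-in types, not
// the real candle bindings).
struct Model {
    name: String,
}

// Stand-in for the expensive Hub download + mmap + weight load in build_model.
fn load_model(name: &str) -> Result<Model, String> {
    Ok(Model { name: name.to_string() })
}

struct ModelConfig {
    model: Option<Model>,
}

impl ModelConfig {
    // The expensive load happens exactly once, at construction time.
    fn new(model_name: Option<&str>) -> Result<Self, String> {
        let model = match model_name {
            Some(name) => Some(load_model(name)?),
            None => None,
        };
        Ok(ModelConfig { model })
    }

    // Each call borrows the cached model instead of rebuilding it, mirroring
    // the match on &self.0.model in embedding() above.
    fn embedding(&self, input: &str) -> Result<String, String> {
        match &self.model {
            Some(model) => Ok(format!("embedding of {:?} via {}", input, model.name)),
            None => Err("Model not found".to_string()),
        }
    }
}

fn main() -> Result<(), String> {
    let config = ModelConfig::new(Some("jinaai/jina-embeddings-v2-base-en"))?;
    // Two embedding runs, one model load.
    println!("{}", config.embedding("first sentence")?);
    println!("{}", config.embedding("second sentence")?);
    Ok(())
}

With this shape the download and weight load run once in the constructor, so each `embedding` call pays only for per-input work, which is what the timing `println!` calls in `compute_embedding` measure after this patch.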