Skip to content

Commit

Permalink
Model is now loaded in the constructor and retained for multiple embedding runs
Browse files Browse the repository at this point in the history
  • Loading branch information
cpetersen committed Mar 13, 2024
1 parent b989c2a commit 5288301
Showing 1 changed file with 55 additions and 45 deletions.
100 changes: 55 additions & 45 deletions ext/candle/src/model/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,84 +20,94 @@ pub struct ModelConfig(pub ModelConfigInner);

// Inner state behind ModelConfig: the compute device plus the model/tokenizer,
// loaded once at construction and retained so multiple embedding runs reuse them.
pub struct ModelConfigInner {
// Device used for tensor ops (passed to VarBuilder when loading weights).
device: Device,

// Hugging Face repo id for the tokenizer (fed to Repo::new in build_tokenizer).
tokenizer_path: Option<String>,

// Hugging Face repo id for the model weights (fed to Repo::new in build_model).
model_path: Option<String>,
// Loaded BERT model; Some only when model_path was provided at construction.
model: Option<BertModel>,
// Loaded tokenizer; Some only when tokenizer_path was provided at construction.
tokenizer: Option<Tokenizer>,
}

impl ModelConfig {
/// Default constructor: eagerly downloads and loads the default model and
/// tokenizer on the CPU, so subsequent `embedding` calls reuse them.
/// &RETURNS&: ModelConfig
// NOTE(review): the default tokenizer repo ("sentence-transformers/all-MiniLM-L6-v2")
// does not come from the same repo as the default model ("jinaai/jina-embeddings-v2-base-en");
// confirm this pairing is intentional.
pub fn new() -> RbResult<Self> {
    Self::new2(
        Some("jinaai/jina-embeddings-v2-base-en".to_string()),
        Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
        Some(Device::Cpu),
    )
}

pub fn build() -> ModelConfig {
ModelConfig(ModelConfigInner {
device: Device::Cpu,
model_path: None,
tokenizer_path: None
})
/// Builds a ModelConfig, eagerly loading the model and/or tokenizer for each
/// path that is provided, so they are retained for repeated embedding runs.
/// `device` defaults to `Device::Cpu` when `None`.
/// &RETURNS&: ModelConfig
pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
    let device = device.unwrap_or(Device::Cpu);
    // Option::map + transpose turns Option<Result<_>> into Result<Option<_>>,
    // propagating load errors via `?` without the match pyramids.
    let model = model_path
        .clone()
        .map(|mp| Self::build_model(mp, device.clone()))
        .transpose()?;
    let tokenizer = tokenizer_path
        .clone()
        .map(Self::build_tokenizer)
        .transpose()?;
    Ok(ModelConfig(ModelConfigInner {
        device,
        model_path,
        tokenizer_path,
        model,
        tokenizer,
    }))
}

/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
pub fn embedding(&self, input: String) -> RbResult<RbTensor> {
let config = ModelConfig::build();
let (model, tokenizer) = config.build_model_and_tokenizer()?;
Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?))
match &self.0.model {
Some(model) => {
match &self.0.tokenizer {
Some(tokenizer) => Ok(RbTensor(self.compute_embedding(input, model, tokenizer)?)),
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found"))
}
}
None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer or Model not found"))
}

}

fn build_model_and_tokenizer(&self) -> Result<(BertModel, tokenizers::Tokenizer), Error> {
fn build_model(model_path: String, device: Device) -> RbResult<BertModel> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model_path = match &self.0.model_path {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()
let model_path = Api::new()
.map_err(wrap_hf_err)?
.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
model_path,
RepoType::Model,
))
.get("model.safetensors")
.map_err(wrap_hf_err)?,
.map_err(wrap_hf_err)?;
let config = Config::v2_base();
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &device)
.map_err(wrap_candle_err)?
};
let tokenizer_path = match &self.0.tokenizer_path {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()
let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
Ok(model)
}

fn build_tokenizer(tokenizer_path: String) -> RbResult<Tokenizer> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let tokenizer_path = Api::new()
.map_err(wrap_hf_err)?
.repo(Repo::new(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
tokenizer_path,
RepoType::Model,
))
.get("tokenizer.json")
.map_err(wrap_hf_err)?,
};
// let device = candle_examples::device(self.cpu)?;
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path).map_err(wrap_std_err)?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[model_path], DType::F32, &self.0.device)
.map_err(wrap_candle_err)?
};
let model = BertModel::new(vb, &config).map_err(wrap_candle_err)?;
Ok((model, tokenizer))
.map_err(wrap_hf_err)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
// .with_padding(None)
// .with_truncation(None)
.map_err(wrap_std_err)?;
Ok(tokenizer)
}

fn compute_embedding(
&self,
prompt: String,
model: BertModel,
mut tokenizer: Tokenizer,
model: &BertModel,
tokenizer: &Tokenizer,
) -> Result<Tensor, Error> {
let start: std::time::Instant = std::time::Instant::now();
// let prompt = args.prompt.as_deref().unwrap_or("Hello, world!");
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(wrap_std_err)?;
// let tokenizer_impl = tokenizer
// .map_err(wrap_std_err)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(wrap_std_err)?
Expand All @@ -110,7 +120,7 @@ impl ModelConfig {
println!("Loaded and encoded {:?}", start.elapsed());
let start: std::time::Instant = std::time::Instant::now();
let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
println!("{result}");
// println!("{result}");
println!("Took {:?}", start.elapsed());
Ok(result)
}
Expand Down

0 comments on commit 5288301

Please sign in to comment.