diff --git a/.gitignore b/.gitignore
index 83e47d2..2c7295b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,9 @@
 # Ignore Byebug command history file.
 .byebug_history
 
+# Ignore the RustRover project file
+.idea
+
 ## Specific to RubyMotion:
 .dat*
 .repl_history
@@ -64,4 +67,4 @@ target
 *.o
 *.lock
 
-lib.py.rs
\ No newline at end of file
+lib.py.rs
diff --git a/README.md b/README.md
index 5b97a03..d530fa0 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,50 @@
 x = x.reshape([3, 2])
 # Tensor[[3, 2], f32]
 ```
+```ruby
+require 'candle'
+model = Candle::Model.new
+embedding = model.embedding("Hi there!")
+```
+
+## A note on memory usage
+The `Candle::Model` defaults to the `jinaai/jina-embeddings-v2-base-en` model with the `sentence-transformers/all-MiniLM-L6-v2` tokenizer (both from [HuggingFace](https://huggingface.co)). With this configuration, the model takes a little more than 3GB of memory on my Mac. The memory stays with the instantiated `Candle::Model` instance; if you instantiate more than one, you'll use more memory. Conversely, if you let an instance go out of scope and run the garbage collector, the memory is freed. For example:
+
+```ruby
+> require 'candle'
+# Ruby memory = 25.9 MB
+> model = Candle::Model.new
+# Ruby memory = 3.50 GB
+> model2 = Candle::Model.new
+# Ruby memory = 7.04 GB
+> model2 = nil
+> GC.start
+# Ruby memory = 3.56 GB
+> model = nil
+> GC.start
+# Ruby memory = 55.2 MB
+```
+
+## A note on returned embeddings
+
+The returned embeddings should match the embeddings generated by the Python `transformers` library. For instance, I was able to generate the same embedding for the text "Hi there!" locally using this Python code:
+
+```python
+from transformers import AutoModel
+model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
+sentence = ['Hi there!']
+embedding = model.encode(sentence)
+print(embedding)
+```
+
+And the following Ruby:
+
+```ruby
+require 'candle'
+model = Candle::Model.new
+embedding = model.embedding("Hi there!")
+```
+
 
 ## Development
 FORK IT!
@@ -29,6 +73,7 @@ bundle
 bundle exec rake compile
 ```
 
+
 Implemented with [Magnus](https://github.com/matsadler/magnus), with reference to [Polars Ruby](https://github.com/ankane/polars-ruby)
 
 Policies
diff --git a/ext/candle/src/lib.rs b/ext/candle/src/lib.rs
index 33c5a5b..49e171b 100644
--- a/ext/candle/src/lib.rs
+++ b/ext/candle/src/lib.rs
@@ -1,6 +1,6 @@
 use magnus::{function, method, prelude::*, Ruby};
 
-use crate::model::{candle_utils, ModelConfig, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};
+use crate::model::{candle_utils, RbModel, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};
 
 pub mod model;
 
@@ -22,6 +22,7 @@ fn init(ruby: &Ruby) -> RbResult<()> {
     rb_tensor.define_method("dtype", method!(RbTensor::dtype, 0))?;
     rb_tensor.define_method("device", method!(RbTensor::device, 0))?;
     rb_tensor.define_method("rank", method!(RbTensor::rank, 0))?;
+    rb_tensor.define_method("elem_count", method!(RbTensor::elem_count, 0))?;
     rb_tensor.define_method("sin", method!(RbTensor::sin, 0))?;
     rb_tensor.define_method("cos", method!(RbTensor::cos, 0))?;
     rb_tensor.define_method("log", method!(RbTensor::log, 0))?;
@@ -93,10 +94,10 @@ fn init(ruby: &Ruby) -> RbResult<()> {
     rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;
 
     let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
-    rb_model.define_singleton_method("new", function!(ModelConfig::new, 0))?;
-    rb_model.define_method("embedding", method!(ModelConfig::embedding, 1))?;
-    rb_model.define_method("to_s", method!(ModelConfig::__str__, 0))?;
-    rb_model.define_method("inspect", method!(ModelConfig::__repr__, 0))?;
+    rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
+    rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
+    rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
+    rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;
 
     Ok(())
 }
diff --git a/ext/candle/src/model/mod.rs b/ext/candle/src/model/mod.rs
index 40ef61d..f0ec326 100644
--- a/ext/candle/src/model/mod.rs
+++ b/ext/candle/src/model/mod.rs
@@ -1,5 +1,5 @@
-mod config;
-pub use config::*;
+mod rb_model;
+pub use rb_model::*;
 
 mod errors;
 pub use errors::*;
diff --git a/ext/candle/src/model/config.rs b/ext/candle/src/model/rb_model.rs
similarity index 76%
rename from ext/candle/src/model/config.rs
rename to ext/candle/src/model/rb_model.rs
index a1a3fe7..b6ff1c0 100644
--- a/ext/candle/src/model/config.rs
+++ b/ext/candle/src/model/rb_model.rs
@@ -16,9 +16,9 @@ use crate::model::RbResult;
 use tokenizers::Tokenizer;
 
 #[magnus::wrap(class = "Candle::Model", free_immediately, size)]
-pub struct ModelConfig(pub ModelConfigInner);
+pub struct RbModel(pub RbModelInner);
 
-pub struct ModelConfigInner {
+pub struct RbModelInner {
     device: Device,
     tokenizer_path: Option<String>,
     model_path: Option<String>,
@@ -26,14 +26,14 @@
     tokenizer: Option<Tokenizer>,
 }
 
-impl ModelConfig {
+impl RbModel {
     pub fn new() -> RbResult<Self> {
         Self::new2(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
     }
 
     pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
         let device = device.unwrap_or(Device::Cpu);
-        Ok(ModelConfig(ModelConfigInner {
+        Ok(RbModel(RbModelInner {
             device: device.clone(),
             model_path: model_path.clone(),
             tokenizer_path: tokenizer_path.clone(),
@@ -92,10 +94,14 @@ impl ModelConfig {
             ))
             .get("tokenizer.json")
             .map_err(wrap_hf_err)?;
-        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
-            // .with_padding(None)
-            // .with_truncation(None)
+        let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
             .map_err(wrap_std_err)?;
+        let pp = tokenizers::PaddingParams {
+            strategy: tokenizers::PaddingStrategy::BatchLongest,
+            ..Default::default()
+        };
+        tokenizer.with_padding(Some(pp));
+
         Ok(tokenizer)
     }
@@ -105,9 +109,6 @@ impl ModelConfig {
         model: &BertModel,
         tokenizer: &Tokenizer,
     ) -> Result<Tensor, Error> {
-        let start: std::time::Instant = std::time::Instant::now();
-        // let tokenizer_impl = tokenizer
-        //     .map_err(wrap_std_err)?;
         let tokens = tokenizer
             .encode(prompt, true)
             .map_err(wrap_std_err)?
@@ -117,16 +118,29 @@ impl ModelConfig {
             .get_ids()
             .to_vec();
         let token_ids = Tensor::new(&tokens[..], &self.0.device)
             .map_err(wrap_candle_err)?
             .unsqueeze(0)
             .map_err(wrap_candle_err)?;
-        println!("Loaded and encoded {:?}", start.elapsed());
-        let start: std::time::Instant = std::time::Instant::now();
+
+        // let start: std::time::Instant = std::time::Instant::now();
         let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
-        // println!("{result}");
-        println!("Took {:?}", start.elapsed());
-        Ok(result)
+
+        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let (_n_sentence, n_tokens, _hidden_size) = result.dims3()
+            .map_err(wrap_candle_err)?;
+        let sum = result.sum(1)
+            .map_err(wrap_candle_err)?;
+        let embeddings = (sum / (n_tokens as f64))
+            .map_err(wrap_candle_err)?;
+        // let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?;
+
+        Ok(embeddings)
+    }
+
+    #[allow(dead_code)]
+    fn normalize_l2(v: &Tensor) -> Result<Tensor> {
+        v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
     }
 
     pub fn __repr__(&self) -> String {
-        format!("Candle::Model(path={})", self.0.model_path.as_deref().unwrap_or("None"))
+        format!("#<Candle::Model(path={})", self.0.model_path.as_deref().unwrap_or("None"))
     }
 
     pub fn __str__(&self) -> String {
@@ -143,14 +157,14 @@ impl ModelConfig {
 
 // #[test]
 // fn test_build_model_and_tokenizer() {
-//     let config = super::ModelConfig::build();
+//     let config = super::RbModel::build();
 //     let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap();
 //     assert_eq!(tokenizer.get_vocab_size(true), 30522);
 // }
 
 // #[test]
 // fn test_embedding() {
-//     let config = super::ModelConfig::build();
+//     let config = super::RbModel::build();
 //     // let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap();
 //     // assert_eq!(config.embedding("Scientist.com is a marketplace for pharmaceutical services.")?, None);
 // }
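The block that replaces the removed timing code above is a mean-pooling step: the model's `[1, n_tokens, hidden_size]` output is summed over the token axis (`result.sum(1)`) and divided by the token count, yielding one embedding vector per sentence. A plain-Ruby sketch of the same arithmetic, with made-up values (3 tokens, hidden size 2):

```ruby
# Hypothetical per-token embeddings: n_tokens = 3, hidden_size = 2
token_embeddings = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
n_tokens = token_embeddings.length.to_f

# Sum each hidden dimension across the tokens, then divide by the token
# count, mirroring `result.sum(1) / (n_tokens as f64)` in the Rust above.
mean_pooled = token_embeddings.transpose.map { |column| column.sum / n_tokens }
# => [3.0, 4.0]
```

Note that, as the Rust comment says, padding tokens are included in the mean, so with the new `BatchLongest` padding the pooled values can shift slightly when sentences of different lengths are batched together.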
diff --git a/ext/candle/src/model/rb_tensor.rs b/ext/candle/src/model/rb_tensor.rs
index 8456cd2..b0c4370 100644
--- a/ext/candle/src/model/rb_tensor.rs
+++ b/ext/candle/src/model/rb_tensor.rs
@@ -40,7 +40,6 @@ impl RbTensor {
         ))
     }
 
-    // FXIME: Do not use `to_f64` here.
     pub fn values(&self) -> RbResult<Vec<f64>> {
         let values = self
             .0
@@ -83,6 +82,12 @@ impl RbTensor {
         self.0.rank()
     }
 
+    /// The number of elements stored in this tensor.
+    /// &RETURNS&: int
+    pub fn elem_count(&self) -> usize {
+        self.0.elem_count()
+    }
+
     pub fn __repr__(&self) -> String {
         format!("{}", self.0)
     }
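The new `elem_count` binding surfaces candle's element count, i.e. the product of the tensor's shape dimensions. A minimal sketch, reusing the `[3, 2]` f32 tensor `x` from the README snippet above:

```ruby
require 'candle'

# x is assumed to be the README's reshaped tensor: Tensor[[3, 2], f32]
x.shape      # => [3, 2]
x.rank       # => 2
x.elem_count # => 6, i.e. 3 * 2
```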
diff --git a/ext/candle/src/model/utils.rs b/ext/candle/src/model/utils.rs
index dc065c6..81ea0f8 100644
--- a/ext/candle/src/model/utils.rs
+++ b/ext/candle/src/model/utils.rs
@@ -71,6 +71,7 @@ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {
 
 /// Applies the Softmax function to a given tensor.#
 /// &RETURNS&: Tensor
+#[allow(dead_code)]
 fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {
     let dim = actual_dim(&tensor, dim).map_err(wrap_candle_err)?;
     let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_candle_err)?;
@@ -79,6 +80,7 @@
 /// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
 /// &RETURNS&: Tensor
+#[allow(dead_code)]
 fn silu(tensor: RbTensor) -> RbResult<RbTensor> {
     let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_candle_err)?;
     Ok(RbTensor(s))
 }
diff --git a/lib/candle.rb b/lib/candle.rb
index 886012a..0d82e32 100644
--- a/lib/candle.rb
+++ b/lib/candle.rb
@@ -1 +1,2 @@
 require_relative 'candle/candle'
+require_relative 'candle/tensor'
diff --git a/lib/candle/tensor.rb b/lib/candle/tensor.rb
new file mode 100644
index 0000000..4b412db
--- /dev/null
+++ b/lib/candle/tensor.rb
@@ -0,0 +1,17 @@
+module Candle
+  class Tensor
+    include Enumerable
+
+    def each
+      if self.rank == 1
+        self.values.each do |value|
+          yield value
+        end
+      else
+        shape.first.times do |i|
+          yield self[i]
+        end
+      end
+    end
+  end
+end
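Since `lib/candle/tensor.rb` includes `Enumerable`, `Candle::Tensor` picks up `map`, `to_a`, `first`, and friends on top of `each`: rank-1 tensors yield their scalar values, while higher-rank tensors yield sub-tensors through the same `self[i]` indexing the class already uses. A usage sketch, assuming the model and tensor behave as in the diffs above:

```ruby
require 'candle'

model = Candle::Model.new
embedding = model.embedding("Hi there!") # rank-2 tensor, one row per sentence

row = embedding.first   # Enumerable#first, backed by Tensor#each
floats = row.to_a       # a rank-1 tensor enumerates plain Float values
floats.length == row.elem_count # => true
```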