Skip to content

Commit

Permalink
Merge pull request #1 from assaydepot/main
Browse files Browse the repository at this point in the history
Made Embeddings work
  • Loading branch information
kojix2 committed Mar 24, 2024
2 parents c9ad1f3 + a65963e commit cca1da9
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 27 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
# Ignore Byebug command history file.
.byebug_history

# Ignore the RustRover project file
.idea

## Specific to RubyMotion:
.dat*
.repl_history
Expand Down Expand Up @@ -64,4 +67,4 @@ target
*.o
*.lock

lib.py.rs
lib.py.rs
45 changes: 45 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,50 @@ x = x.reshape([3, 2])
# Tensor[[3, 2], f32]
```

```ruby
require 'candle'
model = Candle::Model.new
embedding = model.embedding("Hi there!")
```

## A note on memory usage
The `Candle::Model` defaults to the `jinaai/jina-embeddings-v2-base-en` model with the `sentence-transformers/all-MiniLM-L6-v2` tokenizer (both from [HuggingFace](https://huggingface.co)). With this configuration the model takes a little more than 3GB of memory running on my Mac. The memory stays with the instantiated `Candle::Model` class, if you instantiate more than one, you'll use more memory. Likewise, if you let it go out of scope and call the garbage collector, you'll free the memory. For example:

```ruby
> require 'candle'
# Ruby memory = 25.9 MB
> model = Candle::Model.new
# Ruby memory = 3.50 GB
> model2 = Candle::Model.new
# Ruby memory = 7.04 GB
> model2 = nil
> GC.start
# Ruby memory = 3.56 GB
> model = nil
> GC.start
# Ruby memory = 55.2 MB
```

## A note on returned embeddings

The code should match the same embeddings when generated from the python `transformers` library. For instance, locally I was able to generate the same embedding for the text "Hi there!" using the python code:

```python
from transformers import AutoModel
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
sentence = ['Hi there!']
embedding = model.encode(sentence)
print(embedding)
```

And the following ruby:

```ruby
require 'candle'
model = Candle::Model.new
embedding = model.embedding("Hi there!")
```

## Development

FORK IT!
Expand All @@ -29,6 +73,7 @@ bundle
bundle exec rake compile
```


Implemented with [Magnus](https://github.com/matsadler/magnus), with reference to [Polars Ruby](https://github.com/ankane/polars-ruby)

Policies
Expand Down
11 changes: 6 additions & 5 deletions ext/candle/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use magnus::{function, method, prelude::*, Ruby};

use crate::model::{candle_utils, ModelConfig, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};
use crate::model::{candle_utils, RbModel, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};

pub mod model;

Expand All @@ -22,6 +22,7 @@ fn init(ruby: &Ruby) -> RbResult<()> {
rb_tensor.define_method("dtype", method!(RbTensor::dtype, 0))?;
rb_tensor.define_method("device", method!(RbTensor::device, 0))?;
rb_tensor.define_method("rank", method!(RbTensor::rank, 0))?;
rb_tensor.define_method("elem_count", method!(RbTensor::elem_count, 0))?;
rb_tensor.define_method("sin", method!(RbTensor::sin, 0))?;
rb_tensor.define_method("cos", method!(RbTensor::cos, 0))?;
rb_tensor.define_method("log", method!(RbTensor::log, 0))?;
Expand Down Expand Up @@ -93,10 +94,10 @@ fn init(ruby: &Ruby) -> RbResult<()> {
rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;

let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
rb_model.define_singleton_method("new", function!(ModelConfig::new, 0))?;
rb_model.define_method("embedding", method!(ModelConfig::embedding, 1))?;
rb_model.define_method("to_s", method!(ModelConfig::__str__, 0))?;
rb_model.define_method("inspect", method!(ModelConfig::__repr__, 0))?;
rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;

Ok(())
}
4 changes: 2 additions & 2 deletions ext/candle/src/model/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod config;
pub use config::*;
mod rb_model;
pub use rb_model::*;

mod errors;
pub use errors::*;
Expand Down
50 changes: 32 additions & 18 deletions ext/candle/src/model/config.rs → ext/candle/src/model/rb_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,24 @@ use crate::model::RbResult;
use tokenizers::Tokenizer;

#[magnus::wrap(class = "Candle::Model", free_immediately, size)]
pub struct ModelConfig(pub ModelConfigInner);
pub struct RbModel(pub RbModelInner);

pub struct ModelConfigInner {
pub struct RbModelInner {
device: Device,
tokenizer_path: Option<String>,
model_path: Option<String>,
model: Option<BertModel>,
tokenizer: Option<Tokenizer>,
}

impl ModelConfig {
impl RbModel {
pub fn new() -> RbResult<Self> {
Self::new2(Some("jinaai/jina-embeddings-v2-base-en".to_string()), Some("sentence-transformers/all-MiniLM-L6-v2".to_string()), Some(Device::Cpu))
}

pub fn new2(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>) -> RbResult<Self> {
let device = device.unwrap_or(Device::Cpu);
Ok(ModelConfig(ModelConfigInner {
Ok(RbModel(RbModelInner {
device: device.clone(),
model_path: model_path.clone(),
tokenizer_path: tokenizer_path.clone(),
Expand Down Expand Up @@ -92,10 +92,14 @@ impl ModelConfig {
))
.get("tokenizer.json")
.map_err(wrap_hf_err)?;
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
// .with_padding(None)
// .with_truncation(None)
let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
.map_err(wrap_std_err)?;
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));

Ok(tokenizer)
}

Expand All @@ -105,9 +109,6 @@ impl ModelConfig {
model: &BertModel,
tokenizer: &Tokenizer,
) -> Result<Tensor, Error> {
let start: std::time::Instant = std::time::Instant::now();
// let tokenizer_impl = tokenizer
// .map_err(wrap_std_err)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(wrap_std_err)?
Expand All @@ -117,16 +118,29 @@ impl ModelConfig {
.map_err(wrap_candle_err)?
.unsqueeze(0)
.map_err(wrap_candle_err)?;
println!("Loaded and encoded {:?}", start.elapsed());
let start: std::time::Instant = std::time::Instant::now();

// let start: std::time::Instant = std::time::Instant::now();
let result = model.forward(&token_ids).map_err(wrap_candle_err)?;
// println!("{result}");
println!("Took {:?}", start.elapsed());
Ok(result)

// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = result.dims3()
.map_err(wrap_candle_err)?;
let sum = result.sum(1)
.map_err(wrap_candle_err)?;
let embeddings = (sum / (n_tokens as f64))
.map_err(wrap_candle_err)?;
// let embeddings = Self::normalize_l2(&embeddings).map_err(wrap_candle_err)?;

Ok(embeddings)
}

#[allow(dead_code)]
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

pub fn __repr__(&self) -> String {
format!("Candle::Model(path={})", self.0.model_path.as_deref().unwrap_or("None"))
format!("#<Candle::Model model_path: {} tokenizer_path: {})", self.0.model_path.as_deref().unwrap_or("nil"), self.0.tokenizer_path.as_deref().unwrap_or("nil"))
}

pub fn __str__(&self) -> String {
Expand All @@ -143,14 +157,14 @@ impl ModelConfig {

// #[test]
// fn test_build_model_and_tokenizer() {
// let config = super::ModelConfig::build();
// let config = super::RbModel::build();
// let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap();
// assert_eq!(tokenizer.get_vocab_size(true), 30522);
// }

// #[test]
// fn test_embedding() {
// let config = super::ModelConfig::build();
// let config = super::RbModel::build();
// // let (_model, tokenizer) = config.build_model_and_tokenizer().unwrap();
// // assert_eq!(config.embedding("Scientist.com is a marketplace for pharmaceutical services.")?, None);
// }
Expand Down
7 changes: 6 additions & 1 deletion ext/candle/src/model/rb_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ impl RbTensor {
))
}

// FXIME: Do not use `to_f64` here.
pub fn values(&self) -> RbResult<Vec<f64>> {
let values = self
.0
Expand Down Expand Up @@ -83,6 +82,12 @@ impl RbTensor {
self.0.rank()
}

/// The number of elements stored in this tensor.
/// &RETURNS&: int
pub fn elem_count(&self) -> usize {
self.0.elem_count()
}

pub fn __repr__(&self) -> String {
format!("{}", self.0)
}
Expand Down
2 changes: 2 additions & 0 deletions ext/candle/src/model/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub fn candle_utils(rb_candle: magnus::RModule) -> Result<(), Error> {

/// Applies the Softmax function to a given tensor.#
/// &RETURNS&: Tensor
#[allow(dead_code)]
fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {
let dim = actual_dim(&tensor, dim).map_err(wrap_candle_err)?;
let sm = candle_nn::ops::softmax(&tensor.0, dim).map_err(wrap_candle_err)?;
Expand All @@ -79,6 +80,7 @@ fn softmax(tensor: RbTensor, dim: i64) -> RbResult<RbTensor> {

/// Applies the Sigmoid Linear Unit (SiLU) function to a given tensor.
/// &RETURNS&: Tensor
#[allow(dead_code)]
fn silu(tensor: RbTensor) -> RbResult<RbTensor> {
let s = candle_nn::ops::silu(&tensor.0).map_err(wrap_candle_err)?;
Ok(RbTensor(s))
Expand Down
1 change: 1 addition & 0 deletions lib/candle.rb
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
require_relative 'candle/candle'
require_relative 'candle/tensor'
17 changes: 17 additions & 0 deletions lib/candle/tensor.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module Candle
class Tensor
include Enumerable

def each
if self.rank == 1
self.values.each do |value|
yield value
end
else
shape.first.times do |i|
yield self[i]
end
end
end
end
end

0 comments on commit cca1da9

Please sign in to comment.