tch 0.14.0 update (#435)
* Updated tch version

* Added casting operation for CPU compatibility

* Fixed ONNX resource path

* Fixed GPT-J bias bool loading

* Updated changelog

* Fixed Clippy warnings

* Updated README
guillaume-be authored Nov 26, 2023
1 parent dc99a30 commit 9f2cd17
Showing 46 changed files with 231 additions and 109 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,12 +5,17 @@ All notable changes to this project will be documented in this file. The format
 ## Added
 - Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel`, allowing custom tokenizers to be passed to sentence embeddings pipelines.
 - Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files.
+- (BREAKING) Most model configurations can now take an optional `kind` parameter to specify the model weight precision. If not provided, weights default to full precision on CPU, or to the serialized weights' precision otherwise.
 
 ## Fixed
 - (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Added a new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
 - Improved MPS device compatibility by setting the `sparse_grad` flag to false for `gather` operations
 - Updated ONNX runtime backend version to 1.15.x
+- Fixed incorrect results for QA models whose tokenizer does not use segment ids
+- Fixed GPT-J incorrectly tracking gradients for the attention bias
+
+## Changed
+- (BREAKING) Upgraded to `torch` 2.1 (via `tch` 0.14.0).
 
 ## [0.21.0] - 2023-06-03
 ## Added
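The new `kind` option surfaces on the pipeline configurations updated throughout this commit. A minimal sketch of requesting half-precision weights (illustrative only; assumes `TextGenerationConfig` implements `Default` and that `anyhow` is available, as in the repository's benchmarks):

use rust_bert::pipelines::text_generation::{TextGenerationConfig, TextGenerationModel};
use tch::{Device, Kind};

fn main() -> anyhow::Result<()> {
    let config = TextGenerationConfig {
        device: Device::cuda_if_available(),
        // New in this release: None keeps full precision on CPU (or the
        // serialized precision otherwise); Some(Kind::Half) forces fp16.
        kind: Some(Kind::Half),
        ..Default::default()
    };
    let _model = TextGenerationModel::new(config)?;
    Ok(())
}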
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -76,7 +76,7 @@ features = ["doc-only"]
 
 [dependencies]
 rust_tokenizers = "8.1.1"
-tch = "0.13.0"
+tch = "0.14.0"
 serde_json = "1"
 serde = { version = "1", features = ["derive"] }
 ordered-float = "3"
@@ -97,7 +97,7 @@ anyhow = "1"
 csv = "1"
 criterion = "0.4"
 tokio = { version = "1.24", features = ["sync", "rt-multi-thread", "macros"] }
-torch-sys = "0.13.0"
+torch-sys = "0.14.0"
 tempfile = "3"
 itertools = "0.10"
 tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
4 changes: 2 additions & 2 deletions README.md
@@ -80,8 +80,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett
 
 ### Manual installation (recommended)
 
-1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.0.0`: if this version is no longer available on the "get started" page,
-the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA 11. **NOTE:** When using `rust-bert` as a dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (which applies to the current repository version).
+1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.1`: if this version is no longer available on the "get started" page,
+the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA 11. **NOTE:** When using `rust-bert` as a dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (which applies to the current repository version).
 2. Extract the library to a location of your choice
 3. Set the following environment variables
 ##### Linux:
1 change: 1 addition & 0 deletions benches/generation_benchmark.rs
@@ -37,6 +37,7 @@ fn create_text_generation_model() -> TextGenerationModel {
         diversity_penalty: None,
         num_return_sequences: 5,
         device: Device::cuda_if_available(),
+        kind: None,
     };
     TextGenerationModel::new(config).unwrap()
 }
2 changes: 1 addition & 1 deletion examples/natural_language_inference_deberta.rs
@@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
     )?;
     let config = DebertaConfig::from_file(config_path);
     let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
-    load_weights(&model_resource, &mut vs)?;
+    load_weights(&model_resource, &mut vs, None, device)?;
 
     // Define input
     let input = [("I love you.", "I like you.")];
15 changes: 9 additions & 6 deletions src/common/resources/mod.rs
@@ -30,6 +30,7 @@ use std::ops::DerefMut;
 use std::path::PathBuf;
 use std::sync::RwLockWriteGuard;
 use tch::nn::VarStore;
+use tch::{Device, Kind};
 
 pub enum Resource<'a> {
     PathBuf(PathBuf),
@@ -84,17 +85,19 @@ impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T> {
 pub fn load_weights(
     rp: &(impl ResourceProvider + ?Sized),
     vs: &mut VarStore,
+    kind: Option<Kind>,
+    device: Device,
 ) -> Result<(), RustBertError> {
     match rp.get_resource()? {
-        Resource::Buffer(mut data) => {
-            vs.load_from_stream(std::io::Cursor::new(data.deref_mut()))?;
-            Ok(())
-        }
-        Resource::PathBuf(path) => Ok(vs.load(path)?),
-    }
+        Resource::Buffer(mut data) => vs.load_from_stream(std::io::Cursor::new(data.deref_mut())),
+        Resource::PathBuf(path) => vs.load(path),
+    }?;
+    cast_var_store(vs, kind, device);
+    Ok(())
 }
 
 #[cfg(feature = "remote")]
 mod remote;
+use crate::pipelines::common::cast_var_store;
 #[cfg(feature = "remote")]
 pub use remote::RemoteResource;
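Call sites now pass the target precision and device explicitly, and `load_weights` casts the store via `cast_var_store` after loading. A hedged usage sketch of the new signature (the local path and fp16 choice are illustrative, not from this diff):

use rust_bert::resources::{load_weights, LocalResource};
use rust_bert::RustBertError;
use std::path::PathBuf;
use tch::{nn, Device, Kind};

fn load_as_fp16(weights_path: PathBuf) -> Result<(), RustBertError> {
    let device = Device::cuda_if_available();
    let mut vs = nn::VarStore::new(device);
    // ... build the model against vs.root() first (as in the DeBERTa example
    // above) so the store contains the variables to fill ...
    let resource = LocalResource {
        local_path: weights_path,
    };
    // New parameters: `kind` selects the target precision (None keeps the
    // default behaviour), `device` drives the post-load cast.
    load_weights(&resource, &mut vs, Some(Kind::Half), device)?;
    Ok(())
}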
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -90,8 +90,8 @@
 //!
 //! ### Manual installation (recommended)
 //!
-//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.0`: if this version is no longer available on the "get started" page,
-//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip` for a Linux version with CUDA 11.
+//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.1`: if this version is no longer available on the "get started" page,
+//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA 11.
 //! 2. Extract the library to a location of your choice
 //! 3. Set the following environment variables
 //! ##### Linux:
7 changes: 6 additions & 1 deletion src/models/bart/bart_model.rs
@@ -1004,7 +1004,12 @@ impl BartGenerator {
         let mut var_store = nn::VarStore::new(device);
         let config = BartConfig::from_file(config_path);
         let model = BartForConditionalGeneration::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
         let eos_token_ids = Some(match config.eos_token_id {
7 changes: 6 additions & 1 deletion src/models/gpt2/gpt2_model.rs
@@ -652,7 +652,12 @@ impl GPT2Generator {
 
         let config = Gpt2Config::from_file(config_path);
         let model = GPT2LMHeadModel::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = tokenizer.get_bos_id();
         let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
21 changes: 7 additions & 14 deletions src/models/gpt_j/attention.rs
@@ -68,11 +68,16 @@ impl GptJAttention {
         let p = p.borrow();
 
         let max_positions = config.n_positions;
-        let bias = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
+        let bias_value = Tensor::ones([max_positions, max_positions], (Kind::Uint8, p.device()))
             .tril(0)
             .view([1, 1, max_positions, max_positions])
             .requires_grad_(false);
-        let bias = p.var_copy("bias", &bias);
+        let mut bias = p
+            .f_ones_no_train("bias", &[1, 1, max_positions, max_positions])
+            .unwrap()
+            .to_kind(Kind::Uint8)
+            .to_device(p.device());
+        bias.copy_(&bias_value);
 
         let attn_pdrop = config.attn_pdrop.unwrap_or(0.1);
         let resid_pdrop = config.resid_pdrop.unwrap_or(0.1);
@@ -95,21 +100,9 @@
             ..Default::default()
         };
         let k_proj = nn::linear(p / "k_proj", config.n_embd, config.n_embd, linear_config);
-        if config.use_float16 {
-            (p / "k_proj").half();
-        }
         let v_proj = nn::linear(p / "v_proj", config.n_embd, config.n_embd, linear_config);
-        if config.use_float16 {
-            (p / "v_proj").half();
-        }
         let q_proj = nn::linear(p / "q_proj", config.n_embd, config.n_embd, linear_config);
-        if config.use_float16 {
-            (p / "q_proj").half();
-        }
         let out_proj = nn::linear(p / "out_proj", config.n_embd, config.n_embd, linear_config);
-        if config.use_float16 {
-            (p / "out_proj").half();
-        }
 
         GptJAttention {
             bias,
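The fix above matters because `var_copy` registers a trainable variable in the `VarStore`, while the `*_no_train` constructors register a frozen one. A minimal sketch of the difference, assuming the tch 0.14 API:

use tch::{nn, Device, Kind, Tensor};

fn main() {
    let vs = nn::VarStore::new(Device::Cpu);
    let root = vs.root();

    // `var_copy` creates a trainable variable: gradients are tracked.
    let trainable = root.var_copy("w", &Tensor::ones([2, 2], (Kind::Float, Device::Cpu)));
    // `ones_no_train` registers the tensor without gradient tracking,
    // which is what a constant attention bias needs.
    let frozen = root.ones_no_train("bias", &[2, 2]);

    assert!(trainable.requires_grad());
    assert!(!frozen.requires_grad());
}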
23 changes: 6 additions & 17 deletions src/models/gpt_j/gpt_j_model.rs
@@ -131,8 +131,6 @@ pub struct GptJConfig {
     pub rotary_dim: Option<i64>,
     pub vocab_size: i64,
     pub scale_attn_weights: Option<bool>,
-    #[serde(default = "default_use_float16")]
-    pub use_float16: bool,
     #[serde(default = "default_preload_on_cpu")]
     pub preload_on_cpu: bool,
     pub decoder_start_token_id: Option<i64>,
@@ -164,7 +162,6 @@ impl Default for GptJConfig {
             rotary_dim: Some(64),
             vocab_size: 50400,
             scale_attn_weights: Some(true),
-            use_float16: default_use_float16(),
             preload_on_cpu: default_preload_on_cpu(),
             decoder_start_token_id: None,
             forced_bos_token_id: None,
@@ -173,10 +170,6 @@
     }
 }
 
-fn default_use_float16() -> bool {
-    true
-}
-
 fn default_preload_on_cpu() -> bool {
     true
 }
@@ -233,9 +226,6 @@ impl GptJModel {
             config.n_embd,
             Default::default(),
         );
-        if config.use_float16 {
-            (&(&p / "wte") / "weight").half()
-        };
 
         let embd_pdrop = config.embd_pdrop.unwrap_or(0.1);
         let drop = Dropout::new(embd_pdrop);
@@ -245,9 +235,6 @@
             ..Default::default()
         };
         let ln_f = nn::layer_norm(&p / "ln_f", vec![config.n_embd], layer_norm_config);
-        if config.use_float16 {
-            (&p / "ln_f").half()
-        };
 
         let mut h: Vec<GptJBlock> = vec![];
         let h_path = &p / "h";
@@ -475,9 +462,6 @@ impl GptJLMHeadModel {
             config.vocab_size,
             Default::default(),
         );
-        if config.use_float16 {
-            (p / "lm_head").half();
-        }
 
         GptJLMHeadModel {
             transformer,
@@ -625,7 +609,12 @@ impl GptJGenerator {
         if config.preload_on_cpu && device != Device::Cpu {
             var_store.set_device(Device::Cpu);
         }
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
         if device != Device::Cpu {
             var_store.set_device(device);
         }
9 changes: 0 additions & 9 deletions src/models/gpt_j/transformer.rs
@@ -43,18 +43,12 @@ impl GptJMLP {
             intermediate_size,
             Default::default(),
         );
-        if config.use_float16 {
-            (p / "fc_in").half()
-        };
         let fc_out = nn::linear(
             p / "fc_out",
             intermediate_size,
             config.n_embd,
             Default::default(),
         );
-        if config.use_float16 {
-            (p / "fc_out").half()
-        };
 
         let activation = match &config.afn {
             Some(activation_enum) => match activation_enum {
@@ -100,9 +94,6 @@ impl GptJBlock {
             ..Default::default()
         };
         let ln_1 = nn::layer_norm(p / "ln_1", vec![config.n_embd], layer_norm_config);
-        if config.use_float16 {
-            (p / "ln_1").half()
-        };
         let attn = GptJAttention::new(p / "attn", config);
         let mlp = GptJMLP::new(p / "mlp", config);
 
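With the per-module `use_float16` casts removed, precision is now handled once at load time through the `kind` parameter. A hedged sketch of the store-wide equivalent, assuming tch's `VarStore::half`:

use tch::nn::VarStore;

// Instead of calling `.half()` path-by-path at construction time (the removed
// `use_float16` blocks above), the loaded store can be cast in one step.
fn to_half_precision(vs: &mut VarStore) {
    vs.half(); // casts every variable in the store to Kind::Half
}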
7 changes: 6 additions & 1 deletion src/models/gpt_neo/gpt_neo_model.rs
@@ -672,7 +672,12 @@ impl GptNeoGenerator {
         let mut var_store = nn::VarStore::new(device);
         let config = GptNeoConfig::from_file(config_path);
         let model = GptNeoForCausalLM::new(var_store.root(), &config)?;
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = tokenizer.get_bos_id();
         let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
4 changes: 2 additions & 2 deletions src/models/longt5/encoder.rs
@@ -288,8 +288,8 @@ impl LongT5Stack {
 
         let (batch_size, sequence_length) = (input_shape[0], input_shape[1]);
 
-        let mask_seq_length = if old_layer_states.is_some() {
-            if old_layer_states.as_ref().unwrap()[0].0.is_some() {
+        let mask_seq_length = if let Some(old_layer_states_value) = &old_layer_states {
+            if old_layer_states_value[0].0.is_some() {
                 old_layer_states.as_ref().unwrap()[0]
                     .0
                     .as_ref()
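This Clippy-driven rewrite replaces an `is_some()`/`unwrap()` chain with an `if let` binding. A minimal illustration of the pattern (not the model code):

// Binding the inner value borrows once and removes the panic path that
// `unwrap()` would otherwise introduce on the checked branch.
fn past_length(cache: &Option<Vec<(Option<i64>, i64)>>, input_len: i64) -> i64 {
    if let Some(cache) = cache {
        if let Some(past) = cache[0].0 {
            past + input_len
        } else {
            input_len
        }
    } else {
        input_len
    }
}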
7 changes: 6 additions & 1 deletion src/models/longt5/longt5_model.rs
@@ -595,7 +595,12 @@ impl LongT5Generator {
 
         let config = LongT5Config::from_file(config_path);
         let model = LongT5ForConditionalGeneration::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = config.bos_token_id;
         let eos_token_ids = Some(match config.eos_token_id {
7 changes: 6 additions & 1 deletion src/models/m2m_100/m2m_100_model.rs
@@ -544,7 +544,12 @@ impl M2M100Generator {
 
         let config = M2M100Config::from_file(config_path);
         let model = M2M100ForConditionalGeneration::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
         let eos_token_ids = Some(match config.eos_token_id {
7 changes: 6 additions & 1 deletion src/models/marian/marian_model.rs
@@ -761,7 +761,12 @@ impl MarianGenerator {
 
         let config = BartConfig::from_file(config_path);
         let model = MarianForConditionalGeneration::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
         let eos_token_ids = Some(match config.eos_token_id {
9 changes: 7 additions & 2 deletions src/models/mbart/mbart_model.rs
@@ -650,7 +650,7 @@ impl MBartForSequenceClassification {
 /// # let device = Device::Cpu;
 /// # let vs = nn::VarStore::new(device);
 /// # let config = MBartConfig::from_file(config_path);
-/// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap();;
+/// # let mbart_model: MBartForSequenceClassification = MBartForSequenceClassification::new(&vs.root(), &config).unwrap();
 /// let (batch_size, source_sequence_length, target_sequence_length) = (64, 128, 56);
 /// let input_tensor = Tensor::rand(&[batch_size, source_sequence_length], (Int64, device));
 /// let target_tensor = Tensor::rand(&[batch_size, target_sequence_length], (Int64, device));
@@ -800,7 +800,12 @@ impl MBartGenerator {
 
         let config = MBartConfig::from_file(config_path);
         let model = MBartForConditionalGeneration::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
         let eos_token_ids = Some(match config.eos_token_id {
7 changes: 6 additions & 1 deletion src/models/openai_gpt/openai_gpt_model.rs
@@ -498,7 +498,12 @@ impl OpenAIGenerator {
         let mut var_store = nn::VarStore::new(device);
         let config = Gpt2Config::from_file(config_path);
         let model = OpenAIGPTLMHeadModel::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = tokenizer.get_bos_id();
         let eos_token_ids = tokenizer.get_eos_id().map(|id| vec![id]);
7 changes: 6 additions & 1 deletion src/models/pegasus/pegasus_model.rs
@@ -505,7 +505,12 @@ impl PegasusConditionalGenerator {
         let mut var_store = nn::VarStore::new(device);
         let config = PegasusConfig::from_file(config_path);
         let model = PegasusForConditionalGeneration::new(var_store.root(), &config);
-        crate::resources::load_weights(&generate_config.model_resource, &mut var_store)?;
+        crate::resources::load_weights(
+            &generate_config.model_resource,
+            &mut var_store,
+            generate_config.kind,
+            device,
+        )?;
 
         let bos_token_id = Some(config.bos_token_id.unwrap_or(0));
         let eos_token_ids = config