From 95a857cf57c56a34ecdaae5372f2a13ebd900001 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 28 Oct 2023 17:51:19 +0200 Subject: [PATCH 1/9] Move the llama2-c model in transformers. (#1205) --- candle-examples/examples/llama2-c/main.rs | 6 +++--- candle-transformers/Cargo.toml | 1 + .../model.rs => candle-transformers/src/models/llama2_c.rs | 0 .../src/models/llama2_c_weights.rs | 5 ++--- candle-transformers/src/models/mod.rs | 3 +++ .../src/models/quantized_llama2_c.rs | 6 +++--- 6 files changed, 12 insertions(+), 9 deletions(-) rename candle-examples/examples/llama2-c/model.rs => candle-transformers/src/models/llama2_c.rs (100%) rename candle-examples/examples/llama2-c/weights.rs => candle-transformers/src/models/llama2_c_weights.rs (98%) rename candle-examples/examples/llama2-c/qmodel.rs => candle-transformers/src/models/quantized_llama2_c.rs (97%) diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index 77dbc6778..a3f01ae2e 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -6,10 +6,10 @@ extern crate accelerate_src; #[cfg(feature = "mkl")] extern crate intel_mkl_src; -mod model; -mod qmodel; +use candle_transformers::models::llama2_c as model; +use candle_transformers::models::llama2_c_weights as weights; +use candle_transformers::models::quantized_llama2_c as qmodel; mod training; -mod weights; use clap::{Parser, Subcommand}; use anyhow::{Error as E, Result}; diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 5af7e55d7..e7290be63 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -11,6 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } +byteorder = { workspace = true } candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true } candle-nn = { path = "../candle-nn", version = "0.3.0" } diff --git a/candle-examples/examples/llama2-c/model.rs b/candle-transformers/src/models/llama2_c.rs similarity index 100% rename from candle-examples/examples/llama2-c/model.rs rename to candle-transformers/src/models/llama2_c.rs diff --git a/candle-examples/examples/llama2-c/weights.rs b/candle-transformers/src/models/llama2_c_weights.rs similarity index 98% rename from candle-examples/examples/llama2-c/weights.rs rename to candle-transformers/src/models/llama2_c_weights.rs index b78418ce3..e5a8bb880 100644 --- a/candle-examples/examples/llama2-c/weights.rs +++ b/candle-transformers/src/models/llama2_c_weights.rs @@ -1,9 +1,8 @@ -use anyhow::Result; use byteorder::{LittleEndian, ReadBytesExt}; -use candle::{DType, Device, IndexOp, Shape, Tensor}; +use candle::{DType, Device, IndexOp, Result, Shape, Tensor}; use candle_nn::VarBuilder; -use crate::model::Config; +use super::llama2_c::Config; pub struct TransformerWeights { // token embedding table diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index f722e93b0..c59bd880c 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -8,6 +8,8 @@ pub mod efficientnet; pub mod falcon; pub mod jina_bert; pub mod llama; +pub mod llama2_c; +pub mod llama2_c_weights; pub mod mistral; pub mod mixformer; pub mod mpt; @@ -15,6 +17,7 @@ pub mod persimmon; pub mod quantized_blip; pub mod quantized_blip_text; pub mod quantized_llama; +pub mod 
quantized_llama2_c; pub mod quantized_mistral; pub mod quantized_mixformer; pub mod quantized_mpt; diff --git a/candle-examples/examples/llama2-c/qmodel.rs b/candle-transformers/src/models/quantized_llama2_c.rs similarity index 97% rename from candle-examples/examples/llama2-c/qmodel.rs rename to candle-transformers/src/models/quantized_llama2_c.rs index 07db146eb..68ebee0da 100644 --- a/candle-examples/examples/llama2-c/qmodel.rs +++ b/candle-transformers/src/models/quantized_llama2_c.rs @@ -1,7 +1,7 @@ -use super::model::{Cache, Config}; +use super::llama2_c::{Cache, Config}; +use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm}; +pub use crate::quantized_var_builder::VarBuilder; use candle::{DType, IndexOp, Module, Result, Tensor, D}; -use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm}; -pub use candle_transformers::quantized_var_builder::VarBuilder; fn silu(xs: &Tensor) -> Result { xs / (xs.neg()?.exp()? + 1.0)? From 012ae0090e70da67987a0308ef18587e9e8a8e44 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 28 Oct 2023 20:00:39 +0200 Subject: [PATCH 2/9] Infer the config for llama2-c. (#1208) --- candle-examples/examples/llama2-c/main.rs | 14 ++++++- candle-examples/examples/llama2-c/training.rs | 2 +- candle-transformers/src/models/llama2_c.rs | 41 ++++++++++++++++++- .../src/quantized_var_builder.rs | 10 +++++ 4 files changed, 63 insertions(+), 4 deletions(-) diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs index a3f01ae2e..0ceb27af7 100644 --- a/candle-examples/examples/llama2-c/main.rs +++ b/candle-examples/examples/llama2-c/main.rs @@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { .extension() .map_or(false, |v| v == "safetensors"); let (model, config) = if is_gguf { - let config = Config::tiny(); let vb = qmodel::VarBuilder::from_gguf(config_path)?; + let (_vocab_size, dim) = vb + .get_no_shape("model.embed_tokens.weight")? 
+ .shape() + .dims2()?; + let config = match dim { + 64 => Config::tiny_260k(), + 288 => Config::tiny_15m(), + 512 => Config::tiny_42m(), + 768 => Config::tiny_110m(), + _ => anyhow::bail!("no config for dim {dim}"), + }; let freq_cis_real = vb .get( (config.seq_len, config.head_size() / 2), @@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> { let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?); (model, config) } else if is_safetensors { - let config = Config::tiny(); + let config = Config::tiny_15m(); let tensors = candle::safetensors::load(config_path, &device)?; let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device); let cache = model::Cache::new(true, &config, vb.pp("rot"))?; diff --git a/candle-examples/examples/llama2-c/training.rs b/candle-examples/examples/llama2-c/training.rs index 150a32723..b2aa0889f 100644 --- a/candle-examples/examples/llama2-c/training.rs +++ b/candle-examples/examples/llama2-c/training.rs @@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> { ); let varmap = candle_nn::VarMap::new(); let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device); - let config = Config::tiny(); + let config = Config::tiny_15m(); let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone()); let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size); diff --git a/candle-transformers/src/models/llama2_c.rs b/candle-transformers/src/models/llama2_c.rs index 07a6e2f21..753770fb7 100644 --- a/candle-transformers/src/models/llama2_c.rs +++ b/candle-transformers/src/models/llama2_c.rs @@ -17,7 +17,20 @@ pub struct Config { } impl Config { - pub fn tiny() -> Self { + pub fn tiny_260k() -> Self { + Self { + dim: 64, + hidden_dim: 768, + n_layers: 5, + n_heads: 8, + n_kv_heads: 4, + vocab_size: 32000, + seq_len: 512, + norm_eps: 1e-5, + } + } + + pub fn tiny_15m() -> Self { Self { dim: 288, hidden_dim: 768, @@ -29,6 +42,32 @@ impl Config { norm_eps: 1e-5, } } + + pub fn tiny_42m() -> Self { + Self { + dim: 512, + hidden_dim: 768, + n_layers: 8, + n_heads: 8, + n_kv_heads: 8, + vocab_size: 32000, + seq_len: 1024, + norm_eps: 1e-5, + } + } + + pub fn tiny_110m() -> Self { + Self { + dim: 768, + hidden_dim: 768, + n_layers: 12, + n_heads: 12, + n_kv_heads: 12, + vocab_size: 32000, + seq_len: 1024, + norm_eps: 1e-5, + } + } } #[derive(Clone)] diff --git a/candle-transformers/src/quantized_var_builder.rs b/candle-transformers/src/quantized_var_builder.rs index 259496d62..810802e8d 100644 --- a/candle-transformers/src/quantized_var_builder.rs +++ b/candle-transformers/src/quantized_var_builder.rs @@ -77,6 +77,16 @@ impl VarBuilder { } } + pub fn get_no_shape(&self, name: &str) -> Result> { + let path = self.path(name); + match self.data.get(&path) { + None => { + candle::bail!("cannot find tensor {name}") + } + Some(qtensor) => Ok(qtensor.clone()), + } + } + pub fn device(&self) -> &Device { &self.device } From 498c50348ce13456d683c987ad9aef319a45eb4a Mon Sep 17 00:00:00 2001 From: Travis Hammond Date: Sat, 28 Oct 2023 20:53:34 +0200 Subject: [PATCH 3/9] Add DDPG and fix Gym wrapper (#1207) * Fix Gym wrapper - It was returning things in the wrong order - Gym now differentiates between terminated and truncated * Add DDPG * Apply fixes * Remove Result annotations * Also remove Vec annotation * rustfmt * Various small improvements (avoid cloning, mutability, get clippy to pass, ...) 
--------- Co-authored-by: Travis Hammond Co-authored-by: Laurent --- .../examples/reinforcement-learning/ddpg.rs | 451 ++++++++++++++++++ .../reinforcement-learning/gym_env.rs | 38 +- .../examples/reinforcement-learning/main.rs | 85 +++- 3 files changed, 549 insertions(+), 25 deletions(-) create mode 100644 candle-examples/examples/reinforcement-learning/ddpg.rs diff --git a/candle-examples/examples/reinforcement-learning/ddpg.rs b/candle-examples/examples/reinforcement-learning/ddpg.rs new file mode 100644 index 000000000..c6d72fed4 --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/ddpg.rs @@ -0,0 +1,451 @@ +use std::collections::VecDeque; +use std::fmt::Display; + +use candle::{DType, Device, Error, Module, Result, Tensor, Var}; +use candle_nn::{ + func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential, + VarBuilder, VarMap, +}; +use rand::{distributions::Uniform, thread_rng, Rng}; + +pub struct OuNoise { + mu: f64, + theta: f64, + sigma: f64, + state: Tensor, +} +impl OuNoise { + pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result { + Ok(Self { + mu, + theta, + sigma, + state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?, + }) + } + + pub fn sample(&mut self) -> Result { + let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?; + let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?; + self.state = (&self.state + dx)?; + Ok(self.state.clone()) + } +} + +#[derive(Clone)] +struct Transition { + state: Tensor, + action: Tensor, + reward: Tensor, + next_state: Tensor, + terminated: bool, + truncated: bool, +} +impl Transition { + fn new( + state: &Tensor, + action: &Tensor, + reward: &Tensor, + next_state: &Tensor, + terminated: bool, + truncated: bool, + ) -> Self { + Self { + state: state.clone(), + action: action.clone(), + reward: reward.clone(), + next_state: next_state.clone(), + terminated, + truncated, + } + } +} + +pub struct ReplayBuffer { + buffer: VecDeque, + capacity: usize, + size: usize, +} +impl ReplayBuffer { + pub fn new(capacity: usize) -> Self { + Self { + buffer: VecDeque::with_capacity(capacity), + capacity, + size: 0, + } + } + + pub fn push( + &mut self, + state: &Tensor, + action: &Tensor, + reward: &Tensor, + next_state: &Tensor, + terminated: bool, + truncated: bool, + ) { + if self.size == self.capacity { + self.buffer.pop_front(); + } else { + self.size += 1; + } + self.buffer.push_back(Transition::new( + state, action, reward, next_state, terminated, truncated, + )); + } + + #[allow(clippy::type_complexity)] + pub fn random_batch( + &self, + batch_size: usize, + ) -> Result, Vec)>> { + if self.size < batch_size { + Ok(None) + } else { + let transitions: Vec<&Transition> = thread_rng() + .sample_iter(Uniform::from(0..self.size)) + .take(batch_size) + .map(|i| self.buffer.get(i).unwrap()) + .collect(); + + let states: Vec = transitions + .iter() + .map(|t| t.state.unsqueeze(0)) + .collect::>()?; + let actions: Vec = transitions + .iter() + .map(|t| t.action.unsqueeze(0)) + .collect::>()?; + let rewards: Vec = transitions + .iter() + .map(|t| t.reward.unsqueeze(0)) + .collect::>()?; + let next_states: Vec = transitions + .iter() + .map(|t| t.next_state.unsqueeze(0)) + .collect::>()?; + let terminateds: Vec = transitions.iter().map(|t| t.terminated).collect(); + let truncateds: Vec = transitions.iter().map(|t| t.truncated).collect(); + + Ok(Some(( + Tensor::cat(&states, 0)?, + Tensor::cat(&actions, 0)?, + Tensor::cat(&rewards, 0)?, + Tensor::cat(&next_states, 0)?, + 
terminateds, + truncateds, + ))) + } + } +} + +fn track( + varmap: &mut VarMap, + vb: &VarBuilder, + target_prefix: &str, + network_prefix: &str, + dims: &[(usize, usize)], + tau: f64, +) -> Result<()> { + for (i, &(in_dim, out_dim)) in dims.iter().enumerate() { + let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?; + let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?; + varmap.set_one( + format!("{target_prefix}-fc{i}.weight"), + ((tau * network_w)? + ((1.0 - tau) * target_w)?)?, + )?; + + let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?; + let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?; + varmap.set_one( + format!("{target_prefix}-fc{i}.bias"), + ((tau * network_b)? + ((1.0 - tau) * target_b)?)?, + )?; + } + Ok(()) +} + +struct Actor<'a> { + varmap: VarMap, + vb: VarBuilder<'a>, + network: Sequential, + target_network: Sequential, + size_state: usize, + size_action: usize, + dims: Vec<(usize, usize)>, +} + +impl Actor<'_> { + fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result { + let mut varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, dtype, device); + + let dims = vec![(size_state, 400), (400, 300), (300, size_action)]; + + let make_network = |prefix: &str| { + let seq = seq() + .add(linear( + dims[0].0, + dims[0].1, + vb.pp(format!("{prefix}-fc0")), + )?) + .add(Activation::Relu) + .add(linear( + dims[1].0, + dims[1].1, + vb.pp(format!("{prefix}-fc1")), + )?) + .add(Activation::Relu) + .add(linear( + dims[2].0, + dims[2].1, + vb.pp(format!("{prefix}-fc2")), + )?) + .add(func(|xs| xs.tanh())); + Ok::(seq) + }; + + let network = make_network("actor")?; + let target_network = make_network("target-actor")?; + + // this sets the two networks to be equal to each other using tau = 1.0 + track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0); + + Ok(Self { + varmap, + vb, + network, + target_network, + size_state, + size_action, + dims, + }) + } + + fn forward(&self, state: &Tensor) -> Result { + self.network.forward(state) + } + + fn target_forward(&self, state: &Tensor) -> Result { + self.target_network.forward(state) + } + + fn track(&mut self, tau: f64) -> Result<()> { + track( + &mut self.varmap, + &self.vb, + "target-actor", + "actor", + &self.dims, + tau, + ) + } +} + +struct Critic<'a> { + varmap: VarMap, + vb: VarBuilder<'a>, + network: Sequential, + target_network: Sequential, + size_state: usize, + size_action: usize, + dims: Vec<(usize, usize)>, +} + +impl Critic<'_> { + fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result { + let mut varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, dtype, device); + + let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)]; + + let make_network = |prefix: &str| { + let seq = seq() + .add(linear( + dims[0].0, + dims[0].1, + vb.pp(format!("{prefix}-fc0")), + )?) + .add(Activation::Relu) + .add(linear( + dims[1].0, + dims[1].1, + vb.pp(format!("{prefix}-fc1")), + )?) 
+ .add(Activation::Relu) + .add(linear( + dims[2].0, + dims[2].1, + vb.pp(format!("{prefix}-fc2")), + )?); + Ok::(seq) + }; + + let network = make_network("critic")?; + let target_network = make_network("target-critic")?; + + // this sets the two networks to be equal to each other using tau = 1.0 + track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0); + + Ok(Self { + varmap, + vb, + network, + target_network, + size_state, + size_action, + dims, + }) + } + + fn forward(&self, state: &Tensor, action: &Tensor) -> Result { + let xs = Tensor::cat(&[action, state], 1)?; + self.network.forward(&xs) + } + + fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result { + let xs = Tensor::cat(&[action, state], 1)?; + self.target_network.forward(&xs) + } + + fn track(&mut self, tau: f64) -> Result<()> { + track( + &mut self.varmap, + &self.vb, + "target-critic", + "critic", + &self.dims, + tau, + ) + } +} + +#[allow(clippy::upper_case_acronyms)] +pub struct DDPG<'a> { + actor: Actor<'a>, + actor_optim: AdamW, + critic: Critic<'a>, + critic_optim: AdamW, + gamma: f64, + tau: f64, + replay_buffer: ReplayBuffer, + ou_noise: OuNoise, + + size_state: usize, + size_action: usize, + pub train: bool, +} + +impl DDPG<'_> { + #[allow(clippy::too_many_arguments)] + pub fn new( + device: &Device, + size_state: usize, + size_action: usize, + train: bool, + actor_lr: f64, + critic_lr: f64, + gamma: f64, + tau: f64, + buffer_capacity: usize, + ou_noise: OuNoise, + ) -> Result { + let filter_by_prefix = |varmap: &VarMap, prefix: &str| { + varmap + .data() + .lock() + .unwrap() + .iter() + .filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone())) + .collect::>() + }; + + let actor = Actor::new(device, DType::F32, size_state, size_action)?; + let actor_optim = AdamW::new( + filter_by_prefix(&actor.varmap, "actor"), + ParamsAdamW { + lr: actor_lr, + ..Default::default() + }, + )?; + + let critic = Critic::new(device, DType::F32, size_state, size_action)?; + let critic_optim = AdamW::new( + filter_by_prefix(&critic.varmap, "critic"), + ParamsAdamW { + lr: critic_lr, + ..Default::default() + }, + )?; + + Ok(Self { + actor, + actor_optim, + critic, + critic_optim, + gamma, + tau, + replay_buffer: ReplayBuffer::new(buffer_capacity), + ou_noise, + size_state, + size_action, + train, + }) + } + + pub fn remember( + &mut self, + state: &Tensor, + action: &Tensor, + reward: &Tensor, + next_state: &Tensor, + terminated: bool, + truncated: bool, + ) { + self.replay_buffer + .push(state, action, reward, next_state, terminated, truncated) + } + + pub fn actions(&mut self, state: &Tensor) -> Result { + let actions = self + .actor + .forward(&state.detach()?.unsqueeze(0)?)? + .squeeze(0)?; + let actions = if self.train { + (actions + self.ou_noise.sample()?)? + } else { + actions + }; + actions.squeeze(0)?.to_scalar::() + } + + pub fn train(&mut self, batch_size: usize) -> Result<()> { + let (states, actions, rewards, next_states, _, _) = + match self.replay_buffer.random_batch(batch_size)? { + Some(v) => v, + _ => return Ok(()), + }; + + let q_target = self + .critic + .target_forward(&next_states, &self.actor.target_forward(&next_states)?)?; + let q_target = (rewards + (self.gamma * q_target)?.detach())?; + let q = self.critic.forward(&states, &actions)?; + let diff = (q_target - q)?; + + let critic_loss = diff.sqr()?.mean_all()?; + self.critic_optim.backward_step(&critic_loss)?; + + let actor_loss = self + .critic + .forward(&states, &self.actor.forward(&states)?)? + .mean_all()? 
+ .neg()?; + self.actor_optim.backward_step(&actor_loss)?; + + self.critic.track(self.tau)?; + self.actor.track(self.tau)?; + + Ok(()) + } +} diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs index b98be6bc8..8868c1884 100644 --- a/candle-examples/examples/reinforcement-learning/gym_env.rs +++ b/candle-examples/examples/reinforcement-learning/gym_env.rs @@ -7,20 +7,22 @@ use pyo3::types::PyDict; /// The return value for a step. #[derive(Debug)] pub struct Step { - pub obs: Tensor, + pub state: Tensor, pub action: A, pub reward: f64, - pub is_done: bool, + pub terminated: bool, + pub truncated: bool, } impl Step { /// Returns a copy of this step changing the observation tensor. - pub fn copy_with_obs(&self, obs: &Tensor) -> Step { + pub fn copy_with_obs(&self, state: &Tensor) -> Step { Step { - obs: obs.clone(), + state: state.clone(), action: self.action, reward: self.reward, - is_done: self.is_done, + terminated: self.terminated, + truncated: self.truncated, } } } @@ -63,14 +65,14 @@ impl GymEnv { /// Resets the environment, returning the observation tensor. pub fn reset(&self, seed: u64) -> Result { - let obs: Vec = Python::with_gil(|py| { + let state: Vec = Python::with_gil(|py| { let kwargs = PyDict::new(py); kwargs.set_item("seed", seed)?; - let obs = self.env.call_method(py, "reset", (), Some(kwargs))?; - obs.as_ref(py).get_item(0)?.extract() + let state = self.env.call_method(py, "reset", (), Some(kwargs))?; + state.as_ref(py).get_item(0)?.extract() }) .map_err(w)?; - Tensor::new(obs, &Device::Cpu) + Tensor::new(state, &Device::Cpu) } /// Applies an environment step using the specified action. @@ -78,21 +80,23 @@ impl GymEnv { &self, action: A, ) -> Result> { - let (obs, reward, is_done) = Python::with_gil(|py| { + let (state, reward, terminated, truncated) = Python::with_gil(|py| { let step = self.env.call_method(py, "step", (action.clone(),), None)?; let step = step.as_ref(py); - let obs: Vec = step.get_item(0)?.extract()?; + let state: Vec = step.get_item(0)?.extract()?; let reward: f64 = step.get_item(1)?.extract()?; - let is_done: bool = step.get_item(2)?.extract()?; - Ok((obs, reward, is_done)) + let terminated: bool = step.get_item(2)?.extract()?; + let truncated: bool = step.get_item(3)?.extract()?; + Ok((state, reward, terminated, truncated)) }) .map_err(w)?; - let obs = Tensor::new(obs, &Device::Cpu)?; + let state = Tensor::new(state, &Device::Cpu)?; Ok(Step { - obs, - reward, - is_done, + state, action, + reward, + terminated, + truncated, }) } diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs index f16f042e9..96d7102d9 100644 --- a/candle-examples/examples/reinforcement-learning/main.rs +++ b/candle-examples/examples/reinforcement-learning/main.rs @@ -9,14 +9,34 @@ extern crate accelerate_src; mod gym_env; mod vec_gym_env; -use candle::Result; +mod ddpg; + +use candle::{Device, Result, Tensor}; use clap::Parser; use rand::Rng; +// The impact of the q value of the next state on the current state's q value. +const GAMMA: f64 = 0.99; +// The weight for updating the target networks. +const TAU: f64 = 0.005; +// The capacity of the replay buffer used for sampling training data. +const REPLAY_BUFFER_CAPACITY: usize = 100_000; +// The training batch size for each training iteration. +const TRAINING_BATCH_SIZE: usize = 100; // The total number of episodes. 
const MAX_EPISODES: usize = 100; // The maximum length of an episode. const EPISODE_LENGTH: usize = 200; +// The number of training iterations after one episode finishes. +const TRAINING_ITERATIONS: usize = 200; + +// Ornstein-Uhlenbeck process parameters. +const MU: f64 = 0.0; +const THETA: f64 = 0.15; +const SIGMA: f64 = 0.1; + +const ACTOR_LEARNING_RATE: f64 = 1e-4; +const CRITIC_LEARNING_RATE: f64 = 1e-3; #[derive(Parser, Debug, Clone)] #[command(author, version, about, long_about = None)] @@ -48,28 +68,77 @@ fn main() -> Result<()> { println!("action space: {}", env.action_space()); println!("observation space: {:?}", env.observation_space()); - let _num_obs = env.observation_space().iter().product::(); - let _num_actions = env.action_space(); + let size_state = env.observation_space().iter().product::(); + let size_action = env.action_space(); + + let mut agent = ddpg::DDPG::new( + &Device::Cpu, + size_state, + size_action, + true, + ACTOR_LEARNING_RATE, + CRITIC_LEARNING_RATE, + GAMMA, + TAU, + REPLAY_BUFFER_CAPACITY, + ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?, + )?; let mut rng = rand::thread_rng(); for episode in 0..MAX_EPISODES { - let mut obs = env.reset(episode as u64)?; + // let mut state = env.reset(episode as u64)?; + let mut state = env.reset(rng.gen::())?; let mut total_reward = 0.0; for _ in 0..EPISODE_LENGTH { - let actions = rng.gen_range(-2.0..2.0); + let mut action = 2.0 * agent.actions(&state)?; + action = action.clamp(-2.0, 2.0); - let step = env.step(vec![actions])?; + let step = env.step(vec![action])?; total_reward += step.reward; - if step.is_done { + agent.remember( + &state, + &Tensor::new(vec![action], &Device::Cpu)?, + &Tensor::new(vec![step.reward as f32], &Device::Cpu)?, + &step.state, + step.terminated, + step.truncated, + ); + + if step.terminated || step.truncated { break; } - obs = step.obs; + state = step.state; } println!("episode {episode} with total reward of {total_reward}"); + + for _ in 0..TRAINING_ITERATIONS { + agent.train(TRAINING_BATCH_SIZE)?; + } + } + + println!("Testing..."); + agent.train = false; + for episode in 0..10 { + // let mut state = env.reset(episode as u64)?; + let mut state = env.reset(rng.gen::())?; + let mut total_reward = 0.0; + for _ in 0..EPISODE_LENGTH { + let mut action = 2.0 * agent.actions(&state)?; + action = action.clamp(-2.0, 2.0); + + let step = env.step(vec![action])?; + total_reward += step.reward; + + if step.terminated || step.truncated { + break; + } + state = step.state; + } + println!("episode {episode} with total reward of {total_reward}"); } Ok(()) } From dece37c6f4d9c5a52caf59a003afa6ba33034fe3 Mon Sep 17 00:00:00 2001 From: drbh Date: Sun, 29 Oct 2023 02:10:23 -0400 Subject: [PATCH 4/9] feat: implement VGG13, VGG16 and VGG19 (#1211) * feat: implement VGG13, VGG16 and VGG19 * Cosmetic fixes. * More cosmetic tweaks + avoid re-loading the weights on each final layer. 
--------- Co-authored-by: Laurent --- candle-examples/examples/vgg/README.md | 13 ++ candle-examples/examples/vgg/main.rs | 77 ++++++++ candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/vgg.rs | 254 +++++++++++++++++++++++++ 4 files changed, 345 insertions(+) create mode 100644 candle-examples/examples/vgg/README.md create mode 100644 candle-examples/examples/vgg/main.rs create mode 100644 candle-transformers/src/models/vgg.rs diff --git a/candle-examples/examples/vgg/README.md b/candle-examples/examples/vgg/README.md new file mode 100644 index 000000000..473038e80 --- /dev/null +++ b/candle-examples/examples/vgg/README.md @@ -0,0 +1,13 @@ +## VGG Model Implementation + +This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library. + +The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image. + +You can run the example with the following command: + +```bash +cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13 +``` + +In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19). diff --git a/candle-examples/examples/vgg/main.rs b/candle-examples/examples/vgg/main.rs new file mode 100644 index 000000000..e01fa8e8b --- /dev/null +++ b/candle-examples/examples/vgg/main.rs @@ -0,0 +1,77 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::vgg::{Models, Vgg}; +use clap::{Parser, ValueEnum}; + +#[derive(Clone, Copy, Debug, ValueEnum)] +enum Which { + Vgg13, + Vgg16, + Vgg19, +} + +#[derive(Parser)] +struct Args { + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Variant of the model to use. + #[arg(value_enum, long, default_value_t = Which::Vgg13)] + which: Which, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + let image = candle_examples::imagenet::load_image224(args.image)?; + + println!("loaded image {image:?}"); + + let api = hf_hub::api::sync::Api::new()?; + let repo = match args.which { + Which::Vgg13 => "timm/vgg13.tv_in1k", + Which::Vgg16 => "timm/vgg16.tv_in1k", + Which::Vgg19 => "timm/vgg19.tv_in1k", + }; + let api = api.model(repo.into()); + let filename = "model.safetensors"; + let model_file = api.get(filename)?; + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = match args.which { + Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?, + Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?, + Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?, + }; + let logits = model.forward(&image)?; + + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? 
+ .to_vec1::()?; + + // Sort the predictions and take the top 5 + let mut top: Vec<_> = prs.iter().enumerate().collect(); + top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + let top = top.into_iter().take(5).collect::>(); + + // Print the top predictions + for &(i, p) in &top { + println!( + "{:50}: {:.2}%", + candle_examples::imagenet::CLASSES[i], + p * 100.0 + ); + } + + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index c59bd880c..aecfcd672 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -28,6 +28,7 @@ pub mod segment_anything; pub mod stable_diffusion; pub mod stable_lm; pub mod t5; +pub mod vgg; pub mod vit; pub mod whisper; pub mod with_tracing; diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs new file mode 100644 index 000000000..7837dc3e6 --- /dev/null +++ b/candle-transformers/src/models/vgg.rs @@ -0,0 +1,254 @@ +//! VGG-16 model implementation. +//! +//! See Very Deep Convolutional Networks for Large-Scale Image Recognition +//! +use candle::{Module, Result, Tensor}; +use candle_nn::{Func, VarBuilder}; + +// Enum representing the different VGG models +pub enum Models { + Vgg13, + Vgg16, + Vgg19, +} + +// Struct representing a VGG model +#[derive(Debug)] +pub struct Vgg<'a> { + blocks: Vec>, +} + +// Struct representing the configuration for the pre-logit layer +struct PreLogitConfig { + in_dim: (usize, usize, usize, usize), + target_in: usize, + target_out: usize, +} + +// Implementation of the VGG model +impl<'a> Vgg<'a> { + // Function to create a new VGG model + pub fn new(vb: VarBuilder<'a>, model: Models) -> Result { + let blocks = match model { + Models::Vgg13 => vgg13_blocks(vb)?, + Models::Vgg16 => vgg16_blocks(vb)?, + Models::Vgg19 => vgg19_blocks(vb)?, + }; + Ok(Self { blocks }) + } +} + +// Implementation of the forward pass for the VGG model +impl Module for Vgg<'_> { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = xs.unsqueeze(0)?; + for block in self.blocks.iter() { + xs = xs.apply(block)?; + } + Ok(xs) + } +} + +// Function to create a conv2d block +// The block is composed of two conv2d layers followed by a max pool layer +fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { + let layers = convs + .iter() + .enumerate() + .map(|(_, &(in_c, out_c, name))| { + candle_nn::conv2d( + in_c, + out_c, + 3, + candle_nn::Conv2dConfig { + stride: 1, + padding: 1, + ..Default::default() + }, + vb.pp(name), + ) + }) + .collect::>>()?; + + Ok(Func::new(move |xs| { + let mut xs = xs.clone(); + for layer in layers.iter() { + xs = xs.apply(layer)?.relu()? 
+ } + xs = xs.max_pool2d_with_stride(2, 2)?; + Ok(xs) + })) +} + +// Function to create a fully connected layer +// The layer is composed of two linear layers followed by a dropout layer +fn fully_connected( + num_classes: usize, + pre_logit_1: PreLogitConfig, + pre_logit_2: PreLogitConfig, + vb: VarBuilder, +) -> Result { + let lin = get_weights_and_biases( + &vb.pp("pre_logits.fc1"), + pre_logit_1.in_dim, + pre_logit_1.target_in, + pre_logit_1.target_out, + )?; + let lin2 = get_weights_and_biases( + &vb.pp("pre_logits.fc2"), + pre_logit_2.in_dim, + pre_logit_2.target_in, + pre_logit_2.target_out, + )?; + Ok(Func::new(move |xs| { + let xs = xs.reshape((1, pre_logit_1.target_out))?; + let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?; + let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?; + let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?; + let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?; + Ok(xs) + })) +} + +// Function to get the weights and biases for a layer +// This is required because the weights and biases are stored in different format than our linear layer expects +fn get_weights_and_biases( + vs: &VarBuilder, + in_dim: (usize, usize, usize, usize), + target_in: usize, + target_out: usize, +) -> Result { + let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL; + let ws = vs.get_with_hints(in_dim, "weight", init_ws)?; + let ws = ws.reshape((target_in, target_out))?; + let bound = 1. / (target_out as f64).sqrt(); + let init_bs = candle_nn::Init::Uniform { + lo: -bound, + up: bound, + }; + let bs = vs.get_with_hints(target_in, "bias", init_bs)?; + Ok(candle_nn::Linear::new(ws, Some(bs))) +} + +fn vgg13_blocks(vb: VarBuilder) -> Result> { + let num_classes = 1000; + let blocks = vec![ + conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, + conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?, + conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?, + conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?, + conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?, + fully_connected( + num_classes, + PreLogitConfig { + in_dim: (4096, 512, 7, 7), + target_in: 4096, + target_out: 512 * 7 * 7, + }, + PreLogitConfig { + in_dim: (4096, 4096, 1, 1), + target_in: 4096, + target_out: 4096, + }, + vb.clone(), + )?, + ]; + Ok(blocks) +} + +fn vgg16_blocks(vb: VarBuilder) -> Result> { + let num_classes = 1000; + let blocks = vec![ + conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, + conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?, + conv2d_block( + &[ + (128, 256, "features.10"), + (256, 256, "features.12"), + (256, 256, "features.14"), + ], + &vb, + )?, + conv2d_block( + &[ + (256, 512, "features.17"), + (512, 512, "features.19"), + (512, 512, "features.21"), + ], + &vb, + )?, + conv2d_block( + &[ + (512, 512, "features.24"), + (512, 512, "features.26"), + (512, 512, "features.28"), + ], + &vb, + )?, + fully_connected( + num_classes, + PreLogitConfig { + in_dim: (4096, 512, 7, 7), + target_in: 4096, + target_out: 512 * 7 * 7, + }, + PreLogitConfig { + in_dim: (4096, 4096, 1, 1), + target_in: 4096, + target_out: 4096, + }, + vb.clone(), + )?, + ]; + Ok(blocks) +} + +fn vgg19_blocks(vb: VarBuilder) -> Result> { + let num_classes = 1000; + let blocks = vec![ + conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, + conv2d_block(&[(64, 128, "features.5"), (128, 128, 
"features.7")], &vb)?, + conv2d_block( + &[ + (128, 256, "features.10"), + (256, 256, "features.12"), + (256, 256, "features.14"), + (256, 256, "features.16"), + ], + &vb, + )?, + conv2d_block( + &[ + (256, 512, "features.19"), + (512, 512, "features.21"), + (512, 512, "features.23"), + (512, 512, "features.25"), + ], + &vb, + )?, + conv2d_block( + &[ + (512, 512, "features.28"), + (512, 512, "features.30"), + (512, 512, "features.32"), + (512, 512, "features.34"), + ], + &vb, + )?, + fully_connected( + num_classes, + PreLogitConfig { + in_dim: (4096, 512, 7, 7), + target_in: 4096, + target_out: 512 * 7 * 7, + }, + PreLogitConfig { + in_dim: (4096, 4096, 1, 1), + target_in: 4096, + target_out: 4096, + }, + vb.clone(), + )?, + ]; + Ok(blocks) +} From 55bc3382cfd3a86018c54f2343567f7c0c0b677c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 29 Oct 2023 07:53:09 +0100 Subject: [PATCH 5/9] Allow for different behavior between training and eval (#1213) * Forward with training. * Do not use dropout on vgg evaluation. --- candle-core/src/lib.rs | 12 +++++++ candle-core/src/tensor.rs | 5 +++ .../examples/mnist-training/main.rs | 4 +-- candle-examples/examples/vgg/main.rs | 4 +-- candle-nn/src/func.rs | 35 +++++++++++++++++++ candle-nn/src/lib.rs | 4 +-- candle-nn/src/ops.rs | 6 ++++ candle-transformers/src/models/vgg.rs | 35 ++++++++++--------- 8 files changed, 83 insertions(+), 22 deletions(-) diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 52effdcf8..73830229c 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -125,3 +125,15 @@ impl Result> Module for T { self(xs) } } + +// A trait defining a module with forward method using a single tensor argument and a flag to +// separate the training and evaluation behaviors. +pub trait ModuleT { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result; +} + +impl ModuleT for M { + fn forward_t(&self, xs: &Tensor, _train: bool) -> Result { + self.forward(xs) + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index ce81d8aff..c6f2364d6 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2271,6 +2271,11 @@ impl Tensor { m.forward(self) } + /// Run the `forward` method of `m` on `self`. + pub fn apply_t(&self, m: &M, train: bool) -> Result { + m.forward_t(self, train) + } + pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } diff --git a/candle-examples/examples/mnist-training/main.rs b/candle-examples/examples/mnist-training/main.rs index a07505bf4..a41a6496b 100644 --- a/candle-examples/examples/mnist-training/main.rs +++ b/candle-examples/examples/mnist-training/main.rs @@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum}; use rand::prelude::*; use candle::{DType, Result, Tensor, D}; -use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap}; +use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap}; const IMAGE_DIM: usize = 784; const LABELS: usize = 10; @@ -95,7 +95,7 @@ impl ConvNet { .flatten_from(1)? .apply(&self.fc1)? 
.relu()?; - self.dropout.forward(&xs, train)?.apply(&self.fc2) + self.dropout.forward_t(&xs, train)?.apply(&self.fc2) } } diff --git a/candle-examples/examples/vgg/main.rs b/candle-examples/examples/vgg/main.rs index e01fa8e8b..27e141cb9 100644 --- a/candle-examples/examples/vgg/main.rs +++ b/candle-examples/examples/vgg/main.rs @@ -5,7 +5,7 @@ extern crate intel_mkl_src; extern crate accelerate_src; use candle::{DType, IndexOp, D}; -use candle_nn::{Module, VarBuilder}; +use candle_nn::{ModuleT, VarBuilder}; use candle_transformers::models::vgg::{Models, Vgg}; use clap::{Parser, ValueEnum}; @@ -53,7 +53,7 @@ pub fn main() -> anyhow::Result<()> { Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?, Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?, }; - let logits = model.forward(&image)?; + let logits = model.forward_t(&image, /*train=*/ false)?; let prs = candle_nn::ops::softmax(&logits, D::Minus1)? .i(0)? diff --git a/candle-nn/src/func.rs b/candle-nn/src/func.rs index 39311d458..3adfda860 100644 --- a/candle-nn/src/func.rs +++ b/candle-nn/src/func.rs @@ -36,3 +36,38 @@ impl<'a> Func<'a> { Self { f: Arc::new(f) } } } + +/// A layer defined by a simple closure. +#[derive(Clone)] +pub struct FuncT<'a> { + #[allow(clippy::type_complexity)] + f: Arc Result + Send + Sync>, +} + +impl<'a> std::fmt::Debug for FuncT<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "func") + } +} + +pub fn func_t<'a, F>(f: F) -> FuncT<'a> +where + F: 'a + Fn(&Tensor, bool) -> Result + Send + Sync, +{ + FuncT { f: Arc::new(f) } +} + +impl<'a> super::ModuleT for FuncT<'a> { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { + (*self.f)(xs, train) + } +} + +impl<'a> FuncT<'a> { + pub fn new(f: F) -> Self + where + F: 'a + Fn(&Tensor, bool) -> Result + Send + Sync, + { + Self { f: Arc::new(f) } + } +} diff --git a/candle-nn/src/lib.rs b/candle-nn/src/lib.rs index be95f5312..52d8f0c59 100644 --- a/candle-nn/src/lib.rs +++ b/candle-nn/src/lib.rs @@ -22,7 +22,7 @@ pub use conv::{ Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig, }; pub use embedding::{embedding, Embedding}; -pub use func::{func, Func}; +pub use func::{func, func_t, Func, FuncT}; pub use group_norm::{group_norm, GroupNorm}; pub use init::Init; pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm}; @@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential}; pub use var_builder::VarBuilder; pub use var_map::VarMap; -pub use candle::Module; +pub use candle::{Module, ModuleT}; diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 32de1af9c..e98121083 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -84,6 +84,12 @@ impl Dropout { } } +impl candle::ModuleT for Dropout { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { + self.forward(xs, train) + } +} + struct SoftmaxLastDim; impl candle::CustomOp1 for SoftmaxLastDim { diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs index 7837dc3e6..a20b5e372 100644 --- a/candle-transformers/src/models/vgg.rs +++ b/candle-transformers/src/models/vgg.rs @@ -2,8 +2,8 @@ //! //! See Very Deep Convolutional Networks for Large-Scale Image Recognition //! 
-use candle::{Module, Result, Tensor}; -use candle_nn::{Func, VarBuilder}; +use candle::{ModuleT, Result, Tensor}; +use candle_nn::{FuncT, VarBuilder}; // Enum representing the different VGG models pub enum Models { @@ -15,7 +15,7 @@ pub enum Models { // Struct representing a VGG model #[derive(Debug)] pub struct Vgg<'a> { - blocks: Vec>, + blocks: Vec>, } // Struct representing the configuration for the pre-logit layer @@ -39,11 +39,11 @@ impl<'a> Vgg<'a> { } // Implementation of the forward pass for the VGG model -impl Module for Vgg<'_> { - fn forward(&self, xs: &Tensor) -> Result { +impl ModuleT for Vgg<'_> { + fn forward_t(&self, xs: &Tensor, train: bool) -> Result { let mut xs = xs.unsqueeze(0)?; for block in self.blocks.iter() { - xs = xs.apply(block)?; + xs = xs.apply_t(block, train)?; } Ok(xs) } @@ -51,7 +51,7 @@ impl Module for Vgg<'_> { // Function to create a conv2d block // The block is composed of two conv2d layers followed by a max pool layer -fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { +fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { let layers = convs .iter() .enumerate() @@ -70,7 +70,7 @@ fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result>>()?; - Ok(Func::new(move |xs| { + Ok(FuncT::new(move |xs, _train| { let mut xs = xs.clone(); for layer in layers.iter() { xs = xs.apply(layer)?.relu()? @@ -87,7 +87,7 @@ fn fully_connected( pre_logit_1: PreLogitConfig, pre_logit_2: PreLogitConfig, vb: VarBuilder, -) -> Result { +) -> Result { let lin = get_weights_and_biases( &vb.pp("pre_logits.fc1"), pre_logit_1.in_dim, @@ -100,12 +100,15 @@ fn fully_connected( pre_logit_2.target_in, pre_logit_2.target_out, )?; - Ok(Func::new(move |xs| { + let dropout1 = candle_nn::Dropout::new(0.5); + let dropout2 = candle_nn::Dropout::new(0.5); + let dropout3 = candle_nn::Dropout::new(0.5); + Ok(FuncT::new(move |xs, train| { let xs = xs.reshape((1, pre_logit_1.target_out))?; - let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin)?.relu()?; - let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin2)?.relu()?; + let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?; + let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?; let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?; - let xs = candle_nn::ops::dropout(&xs, 0.5)?.apply(&lin3)?.relu()?; + let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?; Ok(xs) })) } @@ -130,7 +133,7 @@ fn get_weights_and_biases( Ok(candle_nn::Linear::new(ws, Some(bs))) } -fn vgg13_blocks(vb: VarBuilder) -> Result> { +fn vgg13_blocks(vb: VarBuilder) -> Result> { let num_classes = 1000; let blocks = vec![ conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, @@ -156,7 +159,7 @@ fn vgg13_blocks(vb: VarBuilder) -> Result> { Ok(blocks) } -fn vgg16_blocks(vb: VarBuilder) -> Result> { +fn vgg16_blocks(vb: VarBuilder) -> Result> { let num_classes = 1000; let blocks = vec![ conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, @@ -203,7 +206,7 @@ fn vgg16_blocks(vb: VarBuilder) -> Result> { Ok(blocks) } -fn vgg19_blocks(vb: VarBuilder) -> Result> { +fn vgg19_blocks(vb: VarBuilder) -> Result> { let num_classes = 1000; let blocks = vec![ conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?, From 46d6566c99f63fc74f3fbf5754183a49219224d5 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 29 Oct 2023 10:50:04 +0100 Subject: [PATCH 6/9] Fix the conv2d gradient computation. 
(#1214) --- candle-core/src/backprop.rs | 7 ++++ candle-core/tests/conv_tests.rs | 65 +++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 7488d9397..155f49c5c 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -238,6 +238,13 @@ impl Tensor { .conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)? .transpose(0, 1)?; let sum_grad = grads.or_insert(kernel)?; + let (_, _, k0, k1) = kernel.dims4()?; + let (_, _, g_k0, g_k1) = grad_kernel.dims4()?; + let grad_kernel = if g_k0 != k0 || g_k1 != k1 { + grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)? + } else { + grad_kernel + }; *sum_grad = sum_grad.add(&grad_kernel)?; } Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported { diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index 937ddf676..e7fdf1381 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> { ] ] ); + + // Replicate the issue from https://github.com/huggingface/candle/issues/1212 + let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?; + let loss = res.sqr()?.sum_all()?; + assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32); + let grads = loss.backward()?; + let grad_t = grads.get(&t).unwrap(); + let grad_w = grads.get(&w).unwrap(); + assert_eq!(grad_t.dims(), [1, 4, 5, 5]); + assert_eq!(grad_w.dims(), [2, 4, 3, 3]); + assert_eq!( + test_utils::to_vec3_round(&grad_t.i(0)?, 2)?, + [ + [ + [9.29, -7.03, 7.87, 0.0, 0.0], + [-1.8, -7.82, 5.9, 0.0, 0.0], + [-3.12, 4.49, 5.52, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [21.73, 3.39, 4.77, 0.0, 0.0], + [8.25, 3.73, 27.61, 0.0, 0.0], + [-20.55, -5.61, -2.77, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [-8.98, 9.91, -7.15, 0.0, 0.0], + [4.93, -0.33, 4.56, 0.0, 0.0], + [-6.7, -5.76, -8.05, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ], + [ + [23.54, 6.98, -10.0, 0.0, 0.0], + [9.65, 6.18, 18.72, 0.0, 0.0], + [3.29, -5.27, 0.79, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0] + ] + ] + ); + assert_eq!( + test_utils::to_vec3_round(&grad_w.i(0)?, 2)?, + [ + [ + [-3.47, 7.44, 0.66], + [12.89, -3.4, -9.29], + [-14.16, -0.83, 7.14] + ], + [ + [-3.23, 5.37, -3.02], + [-2.12, -11.24, 1.94], + [6.97, 7.2, 2.99] + ], + [ + [-4.04, -3.31, 4.87], + [-6.68, -5.68, 1.73], + [-5.54, 4.32, 0.52] + ], + [[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]] + ] + ); + Ok(()) } From c3f2676d4932daaa5aa7e1bb7faad343ad54d36e Mon Sep 17 00:00:00 2001 From: Lukas Kreussel <65088241+LLukas22@users.noreply.github.com> Date: Sun, 29 Oct 2023 14:44:05 +0100 Subject: [PATCH 7/9] PyO3: Add CI to build & upload wheels as artifacts. 
(#1215) * Add maturin ci * fix paths * Change sdist path --- .github/workflows/maturin.yml | Bin 0 -> 5304 bytes candle-pyo3/Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 .github/workflows/maturin.yml diff --git a/.github/workflows/maturin.yml b/.github/workflows/maturin.yml new file mode 100644 index 0000000000000000000000000000000000000000..1413f01475fb1a0baea85c10e3c05c82e19b415d GIT binary patch literal 5304 zcmeH~-EPxB5QXO&iFe2aQlyGBRIOUz3h@GPgSbElrHPvc@`vpZ5aQK=^X+z$wcRFF zMWYHuR$^y8J9FmmWcJVRsr77PnZ2}@y|o|q#*VFH@9k1+nT@Tm$Mz_EW;T@+zgoKH zw$QuFWQEmp%cXB>{jk5Ny+xv<&qOjKNx3f8ORWv1aczNB-_f=MYggpwk}qZDrXBr& zV;~PQ*__L>nLO)C&%sI$K8$sJ66(yp>Q^RxWevl>u(Xu*+`ia_tj%mGivvZV5H7qR zTG4MJG8c+mG2(rpZ{nVM*$*qFq^8=-n^wTzlTOdoXUUvbc8x>C7xum8T`sTD-w9gL zwa`x1N_}^P7lbh`X}*XPx#rza(QW39EF&{&*!0Yj^IW?#$zl0}B%j-mm&`)&dSCvT}lV!A>G_;{{DGW$8qwL73aM#n39%{&| z2T$6U7Z^{F&g>9tF4}#??6SzIUr+zwrbVG;ZA=#hG%QTL7d2~lF>lO3F ze(H*h;2&rlKh^}p^ zDyVDp*V(e}1Kp9<$vb%rpTRuT=dt%pMP=>^wBwd6uLQjDYA^<746w)aJfcN;=9 zI@@oKOm|fc#K_|W8lnip*V(1XvsLNcGp+Mz4~>8KOj>3B`({%9@0fG9v*=oK_Y7K{RjU)~o=-RP>7MhLFdp*Ex#zsya$d2T zDXXWv=Hz($6Z-w1bJO`com%{dJaF65V literal 0 HcmV?d00001 diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 0241d2b29..488404bf0 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -19,7 +19,7 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle-nn = { path = "../candle-nn", version = "0.3.0" } half = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.19.0", features = ["extension-module"] } +pyo3 = { version = "0.19.0", features = ["extension-module", "abi3-py38"] } [build-dependencies] pyo3-build-config = "0.19" From 7bbde55c61d9bff90c9f7d0005ed17bbea4b4a8f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 29 Oct 2023 16:12:22 +0100 Subject: [PATCH 8/9] Marian MT model (#1210) * Skeleton files for the marian MT model. * Marian initialization. * Implement the attention forward method. * Forward pass for the encoder side. * Expose the encoder and decoder. * Start plugging the decoder. * Forward pass for the decoder layer. * Set up the marian example. * Add some missing backtraces. * Bugfix. 
--- candle-core/src/cpu_backend.rs | 20 +- candle-examples/examples/marian-mt/main.rs | 90 ++++ candle-transformers/src/models/marian.rs | 413 ++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + .../src/models/with_tracing.rs | 7 + 5 files changed, 521 insertions(+), 10 deletions(-) create mode 100644 candle-examples/examples/marian-mt/main.rs create mode 100644 candle-transformers/src/models/marian.rs diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 86cbeb78a..e9ff46412 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> { fn f(&self, src: &[T], src_l: &Layout) -> Result> { let ids = match self.ids_l.contiguous_offsets() { Some((a, b)) => &self.ids[a..b], - None => Err(Error::RequiresContiguous { op: "gather" })?, + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, }; let src = match src_l.contiguous_offsets() { Some((a, b)) => &src[a..b], - None => Err(Error::RequiresContiguous { op: "gather" })?, + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, }; let dim = self.dim; let ids_dims = self.ids_l.dims(); @@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> { fn f(&self, src: &[T], layout: &Layout) -> Result> { let src = match layout.contiguous_offsets() { Some((a, b)) => &src[a..b], - None => Err(Error::RequiresContiguous { op: "index-select" })?, + None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?, }; let dim = self.dim; let n_ids = match self.ids_l.dims() { @@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> { let mut dst = vec![T::zero(); dst_len]; copy_strided_src_(v1, &mut dst, 0, l1); let src = match src_l.contiguous_offsets() { - None => Err(Error::RequiresContiguous { op: "scatter-add" })?, + None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?, Some((o1, o2)) => &src[o1..o2], }; @@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> { let ids = match self.ids_l.contiguous_offsets() { Some((a, b)) => &self.ids[a..b], - None => Err(Error::RequiresContiguous { op: "gather" })?, + None => Err(Error::RequiresContiguous { op: "gather" }.bt())?, }; for left_i in 0..ids_left_len { let start_ids_idx = left_i * ids_right_len * ids_dim_len; @@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> { let mut dst = vec![T::zero(); dst_len]; copy_strided_src_(v1, &mut dst, 0, l1); let src = match src_l.contiguous_offsets() { - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, Some((o1, o2)) => &src[o1..o2], }; let dim = self.dim; @@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage { Self::U8(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, }; IndexAdd { ids, dim }.map(self, l, src, src_l) } Self::U32(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, }; IndexAdd { ids, dim }.map(self, l, src, src_l) } Self::I64(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], - None => Err(Error::RequiresContiguous { op: "index-add" })?, + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, }; IndexAdd { 
ids, dim }.map(self, l, src, src_l) } - _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")), + _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()), } } diff --git a/candle-examples/examples/marian-mt/main.rs b/candle-examples/examples/marian-mt/main.rs new file mode 100644 index 000000000..ed044627c --- /dev/null +++ b/candle-examples/examples/marian-mt/main.rs @@ -0,0 +1,90 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use anyhow::Error as E; +use clap::Parser; + +use candle::{DType, Tensor}; +use candle_examples::token_output_stream::TokenOutputStream; +use candle_nn::VarBuilder; +use candle_transformers::models::marian; + +use tokenizers::Tokenizer; + +// TODO: Maybe add support for the conditional prompt. +#[derive(Parser)] +struct Args { + #[arg(long)] + model: String, + + #[arg(long)] + tokenizer: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Use the quantized version of the model. + #[arg(long)] + quantized: bool, + + /// Text to be translated + #[arg(long)] + text: String, +} + +const SEP_TOKEN_ID: u32 = 102; + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let config = marian::Config::opus_mt_tc_big_fr_en(); + + let device = candle_examples::device(args.cpu)?; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? }; + let model = marian::MTModel::new(&config, vb)?; + + let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?; + let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone()); + let mut logits_processor = + candle_transformers::generation::LogitsProcessor::new(1337, None, None); + + let encoder_xs = { + let tokens = tokenizer + .encode(args.text, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?; + model.encoder().forward(&tokens, 0)? + }; + + let mut token_ids = vec![30522u32]; + for index in 0..1000 { + // TODO: Add a kv cache. + let context_size = if index >= 1000 { 1 } else { token_ids.len() }; + let start_pos = token_ids.len().saturating_sub(context_size); + let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?; + let logits = model.decode(&input_ids, &encoder_xs)?; + let logits = logits.squeeze(0)?; + let logits = logits.get(logits.dim(0)? - 1)?; + let token = logits_processor.sample(&logits)?; + if token == SEP_TOKEN_ID { + break; + } + token_ids.push(token); + if let Some(t) = tokenizer_dec.next_token(token)? { + use std::io::Write; + print!("{t}"); + std::io::stdout().flush()?; + } + } + if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? 
{ + print!("{rest}"); + } + + Ok(()) +} diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs new file mode 100644 index 000000000..d48ce38b1 --- /dev/null +++ b/candle-transformers/src/models/marian.rs @@ -0,0 +1,413 @@ +#![allow(unused)] +use super::with_tracing::{linear, linear_no_bias, Embedding, Linear}; +use candle::{Module, Result, Tensor}; +use candle_nn::{layer_norm, LayerNorm, VarBuilder}; + +#[derive(Debug, Clone)] +pub struct Config { + pub vocab_size: usize, + pub decoder_vocab_size: Option, + pub max_position_embeddings: usize, + pub encoder_layers: usize, + pub encoder_ffn_dim: usize, + pub encoder_attention_heads: usize, + pub decoder_layers: usize, + pub decoder_ffn_dim: usize, + pub decoder_attention_heads: usize, + pub use_cache: bool, + pub is_encoder_decoder: bool, + pub activation_function: candle_nn::Activation, + pub d_model: usize, + pub decoder_start_token_id: usize, + pub scale_embedding: bool, + pub pad_token_id: usize, + pub eos_token_id: usize, + pub forced_eos_token_id: usize, + pub share_encoder_decoder_embeddings: bool, +} + +impl Config { + // https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en/blob/main/config.json + pub fn opus_mt_tc_big_fr_en() -> Self { + Self { + activation_function: candle_nn::Activation::Relu, + d_model: 1024, + decoder_attention_heads: 16, + decoder_ffn_dim: 4096, + decoder_layers: 6, + decoder_start_token_id: 53016, + decoder_vocab_size: Some(53017), + encoder_attention_heads: 16, + encoder_ffn_dim: 4096, + encoder_layers: 6, + eos_token_id: 43311, + forced_eos_token_id: 43311, + is_encoder_decoder: true, + max_position_embeddings: 1024, + pad_token_id: 53016, + scale_embedding: true, + share_encoder_decoder_embeddings: true, + use_cache: true, + vocab_size: 53017, + } + } +} + +#[derive(Debug, Clone)] +struct SinusoidalPositionalEmbedding { + emb: Embedding, +} + +impl SinusoidalPositionalEmbedding { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let dev = vb.device(); + let dtype = vb.dtype(); + let num_positions = cfg.max_position_embeddings; + let dim = cfg.d_model; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32)) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, num_positions as u32, dev)? + .to_dtype(dtype)? + .reshape((num_positions, 1))?; + let freqs = t.matmul(&inv_freq)?; + let sin = freqs.sin()?; + let cos = freqs.cos()?; + let weights = Tensor::cat(&[&sin, &cos], 1)?.contiguous()?; + let emb = Embedding::from_weights(weights)?; + Ok(Self { emb }) + } + + fn forward(&self, input_ids: &Tensor, past_kv_len: usize) -> Result { + let seq_len = input_ids.dim(1)?; + Tensor::arange( + past_kv_len as u32, + (past_kv_len + seq_len) as u32, + input_ids.device(), + )? 
+        .apply(&self.emb)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+    q_proj: Linear,
+    k_proj: Linear,
+    v_proj: Linear,
+    out_proj: Linear,
+    scaling: f64,
+    num_heads: usize,
+    head_dim: usize,
+}
+
+impl Attention {
+    fn new(cfg: &Config, is_decoder: bool, vb: VarBuilder) -> Result<Self> {
+        let num_heads = if is_decoder {
+            cfg.decoder_attention_heads
+        } else {
+            cfg.encoder_attention_heads
+        };
+        let embed_dim = cfg.d_model;
+        let head_dim = embed_dim / num_heads;
+        let scaling = (head_dim as f64).powf(-0.5);
+        let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
+        let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
+        let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
+        let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
+        Ok(Self {
+            q_proj,
+            k_proj,
+            v_proj,
+            out_proj,
+            scaling,
+            num_heads,
+            head_dim,
+        })
+    }
+
+    fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
+        tensor
+            .reshape((bsz, (), self.num_heads, self.head_dim))?
+            .transpose(1, 2)?
+            .contiguous()
+    }
+
+    fn forward(&self, xs: &Tensor, kv_states: Option<&Tensor>) -> Result<Tensor> {
+        let is_cross_attn = kv_states.is_some();
+        let (b_sz, tgt_len, _) = xs.dims3()?;
+        let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
+        let (key_states, value_states) = match kv_states {
+            None => {
+                let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
+                let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
+                (key_states, value_states)
+            }
+            Some(kv_states) => {
+                let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
+                let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
+                (key_states, value_states)
+            }
+        };
+        let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
+        let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
+        let key_states = key_states.reshape(proj_shape)?;
+        let value_states = value_states.reshape(proj_shape)?;
+        let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
+        // todo: attn_mask
+        let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_output = attn_probs.matmul(&value_states)?;
+        attn_output
+            .reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
+            .transpose(1, 2)?
+            .reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
+            .apply(&self.out_proj)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct EncoderLayer {
+    self_attn: Attention,
+    self_attn_layer_norm: LayerNorm,
+    activation_fn: candle_nn::Activation,
+    fc1: Linear,
+    fc2: Linear,
+    final_layer_norm: LayerNorm,
+}
+
+impl EncoderLayer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
+        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let fc1 = linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?;
+        let fc2 = linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
+        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
+        Ok(Self {
+            self_attn,
+            self_attn_layer_norm,
+            activation_fn: cfg.activation_function,
+            fc1,
+            fc2,
+            final_layer_norm,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let residual = xs;
+        let xs =
+            (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
+        let residual = &xs;
+        let xs = xs
+            .apply(&self.fc1)?
+            .apply(&self.activation_fn)?
+            .apply(&self.fc2)?;
+        (xs + residual)?.apply(&self.final_layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct DecoderLayer {
+    self_attn: Attention,
+    self_attn_layer_norm: LayerNorm,
+    activation_fn: candle_nn::Activation,
+    encoder_attn: Attention,
+    encoder_attn_layer_norm: LayerNorm,
+    fc1: Linear,
+    fc2: Linear,
+    final_layer_norm: LayerNorm,
+}
+
+impl DecoderLayer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
+        let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
+        let encoder_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
+        let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
+        let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
+        Ok(Self {
+            self_attn,
+            self_attn_layer_norm,
+            activation_fn: cfg.activation_function,
+            encoder_attn,
+            encoder_attn_layer_norm,
+            fc1,
+            fc2,
+            final_layer_norm,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, encoder_xs: Option<&Tensor>) -> Result<Tensor> {
+        let residual = xs;
+        let xs =
+            (self.self_attn.forward(xs, None)? + residual)?.apply(&self.self_attn_layer_norm)?;
+        let xs = match encoder_xs {
+            None => xs,
+            Some(encoder_xs) => {
+                let residual = &xs;
+                let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
+                (residual + xs)?.apply(&self.self_attn_layer_norm)?
+            }
+        };
+        let residual = &xs;
+        let xs = xs
+            .apply(&self.fc1)?
+            .apply(&self.activation_fn)?
+            .apply(&self.fc2)?;
+        (xs + residual)?.apply(&self.final_layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Encoder {
+    embed_tokens: Embedding,
+    embed_positions: SinusoidalPositionalEmbedding,
+    layers: Vec<EncoderLayer>,
+    embed_scale: Option<f64>,
+}
+
+impl Encoder {
+    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
+        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
+        let mut layers = Vec::with_capacity(cfg.encoder_layers);
+        let vb_l = vb.pp("layers");
+        for idx in 0..cfg.encoder_layers {
+            let layer = EncoderLayer::new(cfg, vb_l.pp(idx))?;
+            layers.push(layer)
+        }
+        let embed_scale = if cfg.scale_embedding {
+            Some((cfg.d_model as f64).sqrt())
+        } else {
+            None
+        };
+        Ok(Self {
+            embed_tokens: embed_tokens.clone(),
+            embed_positions,
+            layers,
+            embed_scale,
+        })
+    }
+
+    pub fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+        let xs = xs.apply(&self.embed_tokens)?;
+        let xs = match self.embed_scale {
+            None => xs,
+            Some(scale) => (xs * scale)?,
+        };
+        let embed_pos = self
+            .embed_positions
+            .forward(&xs, past_kv_len)?
+            .unsqueeze(0)?;
+        let mut xs = xs.broadcast_add(&embed_pos)?;
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Decoder {
+    embed_tokens: Embedding,
+    embed_positions: SinusoidalPositionalEmbedding,
+    layers: Vec<DecoderLayer>,
+    embed_scale: Option<f64>,
+}
+
+impl Decoder {
+    fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
+        let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
+        let mut layers = Vec::with_capacity(cfg.decoder_layers);
+        let vb_l = vb.pp("layers");
+        for idx in 0..cfg.decoder_layers {
+            let layer = DecoderLayer::new(cfg, vb_l.pp(idx))?;
+            layers.push(layer)
+        }
+        let embed_scale = if cfg.scale_embedding {
+            Some((cfg.d_model as f64).sqrt())
+        } else {
+            None
+        };
+        Ok(Self {
+            embed_tokens: embed_tokens.clone(),
+            embed_positions,
+            layers,
+            embed_scale,
+        })
+    }
+
+    pub fn forward(
+        &self,
+        xs: &Tensor,
+        encoder_xs: Option<&Tensor>,
+        past_kv_len: usize,
+    ) -> Result<Tensor> {
+        let xs = xs.apply(&self.embed_tokens)?;
+        let xs = match self.embed_scale {
+            None => xs,
+            Some(scale) => (xs * scale)?,
+        };
+        let embed_pos = self
+            .embed_positions
+            .forward(&xs, past_kv_len)?
+            .unsqueeze(0)?;
+        let mut xs = xs.broadcast_add(&embed_pos)?;
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs, encoder_xs)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Model {
+    shared: Embedding,
+    encoder: Encoder,
+    decoder: Decoder,
+}
+
+impl Model {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
+        let encoder = Encoder::new(cfg, &shared, vb.pp("encoder"))?;
+        let decoder = Decoder::new(cfg, &shared, vb.pp("decoder"))?;
+        Ok(Self {
+            shared,
+            encoder,
+            decoder,
+        })
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct MTModel {
+    model: Model,
+    final_logits_bias: Tensor,
+}
+
+impl MTModel {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
+        let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
+        let model = Model::new(cfg, vb.pp("model"))?;
+        Ok(Self {
+            model,
+            final_logits_bias,
+        })
+    }
+
+    pub fn encoder(&self) -> &Encoder {
+        &self.model.encoder
+    }
+
+    pub fn decoder(&self) -> &Decoder {
+        &self.model.decoder
+    }
+
+    pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
+        self.model.decoder.forward(xs, Some(encoder_xs), 0)
+    }
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index aecfcd672..370b9108d 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -10,6 +10,7 @@ pub mod jina_bert;
 pub mod llama;
 pub mod llama2_c;
 pub mod llama2_c_weights;
+pub mod marian;
 pub mod mistral;
 pub mod mixformer;
 pub mod mpt;
diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs
index 39258085d..a657011c3 100644
--- a/candle-transformers/src/models/with_tracing.rs
+++ b/candle-transformers/src/models/with_tracing.rs
@@ -14,6 +14,13 @@ impl Embedding {
         Ok(Self { inner, span })
     }
 
+    pub fn from_weights(weights: Tensor) -> Result<Self> {
+        let (_in_size, out_size) = weights.dims2()?;
+        let inner = candle_nn::Embedding::new(weights, out_size);
+        let span = tracing::span!(tracing::Level::TRACE, "embedding");
+        Ok(Self { inner, span })
+    }
+
     pub fn embeddings(&self) -> &Tensor {
         self.inner.embeddings()
     }

From 154c674a798fd5a40d57ff9a8664856d9c41ca56 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 29 Oct 2023 16:28:53 +0100
Subject: [PATCH 9/9] Add i64-abs. (#1216)

---
 candle-core/src/op.rs             | 35 ++++++++++++++++++++++++++++++-
 candle-core/tests/tensor_tests.rs |  8 +++++++
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs
index b7f99f115..e1168c2e4 100644
--- a/candle-core/src/op.rs
+++ b/candle-core/src/op.rs
@@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
 unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
 unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
 unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
-unary_op!(Abs, "abs", v, v.abs());
 unary_op!(Neg, "neg", v, -v);
 unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
@@ -666,6 +665,40 @@ impl UnaryOpT for Erf {
     }
 }
 
+impl UnaryOpT for Abs {
+    const NAME: &'static str = "abs";
+    const KERNEL: &'static str = "uabs";
+    const V: Self = Abs;
+    #[inline(always)]
+    fn bf16(v: bf16) -> bf16 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f16(v: f16) -> f16 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f32(v: f32) -> f32 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn f64(v: f64) -> f64 {
+        v.abs()
+    }
+    #[inline(always)]
+    fn u8(v: u8) -> u8 {
+        v
+    }
+    #[inline(always)]
+    fn u32(v: u32) -> u32 {
+        v
+    }
+    #[inline(always)]
+    fn i64(v: i64) -> i64 {
+        v.abs()
+    }
+}
+
 impl UnaryOpT for Ceil {
     const NAME: &'static str = "ceil";
     const KERNEL: &'static str = "uceil";
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index ae1bd0581..899efcf3a 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1089,3 +1089,11 @@ fn pad_with_same() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn i64_abs() -> Result<()> {
+    let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
+    let t = t.abs()?;
+    assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
+    Ok(())
+}
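
For reference, a minimal standalone usage sketch (not part of the patch above) of what the new Abs impl means for integer tensors; it assumes the crate is used as candle_core on a CPU device, mirroring the i64_abs test added in tensor_tests.rs:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Signed integers: abs flips the sign of negative values via the new i64 kernel.
    let t = Tensor::new(&[-42i64, 0, 1337], &dev)?.abs()?;
    assert_eq!(t.to_vec1::<i64>()?, [42, 0, 1337]);
    // For u8/u32 the impl above returns the value unchanged, so abs acts as the identity there.
    Ok(())
}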