Commit

Merge branch 'main' into feat/better-shape-handling

LaurentMazare authored Oct 29, 2023
2 parents 2fa0c72 + 154c674 commit 5d0e5ab
Showing 30 changed files with 1,670 additions and 54 deletions.
.github/workflows/maturin.yml added (contents not shown).
7 changes: 7 additions & 0 deletions candle-core/src/backprop.rs
@@ -238,6 +238,13 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
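
The lines added above address https://github.com/huggingface/candle/issues/1212: when the stride does not evenly divide the padded input, the convolution that forms the weight gradient (it reuses the forward stride as dilation) comes out larger than the kernel, e.g. 4×4 for the 3×3 kernel in the test added below, and the surplus rows/columns carry nothing and are trimmed. A minimal sketch of that trimming step, using the `dims4`/`narrow` calls shown in the hunk; the helper name and shapes are illustrative only:

```rust
use candle_core::{DType, Device, Result, Tensor};

// Hypothetical standalone helper mirroring the trimming added in backprop.rs:
// drop any extra rows/columns so the weight gradient matches the kernel's shape.
fn trim_grad_kernel(kernel: &Tensor, grad_kernel: Tensor) -> Result<Tensor> {
    let (_, _, k0, k1) = kernel.dims4()?;
    let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
    if g_k0 != k0 || g_k1 != k1 {
        // narrow(dim, start, len): keep only the first k0 x k1 spatial entries.
        grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)
    } else {
        Ok(grad_kernel)
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Kernel of shape (c_out=2, c_in=4, 3, 3) and an oversized 4x4 spatial gradient for it.
    let kernel = Tensor::zeros((2, 4, 3, 3), DType::F32, &dev)?;
    let oversized = Tensor::zeros((2, 4, 4, 4), DType::F32, &dev)?;
    let trimmed = trim_grad_kernel(&kernel, oversized)?;
    assert_eq!(trimmed.dims(), [2, 4, 3, 3]);
    Ok(())
}
```
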
20 changes: 10 additions & 10 deletions candle-core/src/cpu_backend.rs
@@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
@@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "index-select" })?,
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
};
let dim = self.dim;
let n_ids = match self.ids_l.dims() {
@@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};

@@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {

let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
for left_i in 0..ids_left_len {
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
@@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::U32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
}
}

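The only change in this file is appending `.bt()` where these errors are built, so a backtrace is captured at the failure site instead of a bare `RequiresContiguous`/`UnsupportedDTypeForOp`. A rough sketch of that pattern in isolation; it does not reproduce candle's actual `Error` type:

```rust
use std::backtrace::Backtrace;

#[derive(Debug)]
enum Error {
    RequiresContiguous { op: &'static str },
    WithBacktrace {
        inner: Box<Error>,
        backtrace: Box<Backtrace>,
    },
}

impl Error {
    // Attach a backtrace captured at the point where the error is raised.
    fn bt(self) -> Self {
        Error::WithBacktrace {
            inner: Box::new(self),
            backtrace: Box::new(Backtrace::capture()),
        }
    }
}

// Mirrors the match arms above: contiguous offsets or a backtraced error.
fn contiguous_range(offsets: Option<(usize, usize)>) -> Result<(usize, usize), Error> {
    match offsets {
        Some(o) => Ok(o),
        None => Err(Error::RequiresContiguous { op: "gather" }.bt()),
    }
}

fn main() {
    println!("{:?}", contiguous_range(None));
}
```
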
12 changes: 12 additions & 0 deletions candle-core/src/lib.rs
@@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
self(xs)
}
}

// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}

impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}
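
The blanket `impl<M: Module> ModuleT for M` lets any existing `Module` be used where a `ModuleT` is expected (the flag is simply ignored), while layers that behave differently in training, such as dropout, can implement `forward_t` directly. A small sketch assuming this commit's API, including the `apply_t` helper added to `Tensor` further down; the `ScaleAtTrain` layer is hypothetical:

```rust
use candle_core::{DType, Device, ModuleT, Result, Tensor};

// Hypothetical layer whose behavior depends on the train/eval flag:
// it scales activations by `factor` at train time and is the identity at eval time.
struct ScaleAtTrain {
    factor: f64,
}

impl ModuleT for ScaleAtTrain {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            xs.affine(self.factor, 0.)
        } else {
            Ok(xs.clone())
        }
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::ones((2, 3), DType::F32, &dev)?;
    let layer = ScaleAtTrain { factor: 0.5 };
    let train_out = xs.apply_t(&layer, true)?;
    let eval_out = xs.apply_t(&layer, false)?;
    assert_eq!(train_out.to_vec2::<f32>()?[0][0], 0.5);
    assert_eq!(eval_out.to_vec2::<f32>()?[0][0], 1.0);
    Ok(())
}
```
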
35 changes: 34 additions & 1 deletion candle-core/src/op.rs
@@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
@@ -666,6 +665,40 @@ impl UnaryOpT for Erf {
}
}

impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}

impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
5 changes: 5 additions & 0 deletions candle-core/src/tensor.rs
@@ -2271,6 +2271,11 @@ impl Tensor {
m.forward(self)
}

/// Run the `forward` method of `m` on `self`.
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
m.forward_t(self, train)
}

pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
65 changes: 65 additions & 0 deletions candle-core/tests/conv_tests.rs
@@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
]
);

// Replicate the issue from https://github.com/huggingface/candle/issues/1212
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 7.87, 0.0, 0.0],
[-1.8, -7.82, 5.9, 0.0, 0.0],
[-3.12, 4.49, 5.52, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[21.73, 3.39, 4.77, 0.0, 0.0],
[8.25, 3.73, 27.61, 0.0, 0.0],
[-20.55, -5.61, -2.77, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-8.98, 9.91, -7.15, 0.0, 0.0],
[4.93, -0.33, 4.56, 0.0, 0.0],
[-6.7, -5.76, -8.05, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[23.54, 6.98, -10.0, 0.0, 0.0],
[9.65, 6.18, 18.72, 0.0, 0.0],
[3.29, -5.27, 0.79, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[-3.47, 7.44, 0.66],
[12.89, -3.4, -9.29],
[-14.16, -0.83, 7.14]
],
[
[-3.23, 5.37, -3.02],
[-2.12, -11.24, 1.94],
[6.97, 7.2, 2.99]
],
[
[-4.04, -3.31, 4.87],
[-6.68, -5.68, 1.73],
[-5.54, 4.32, 0.52]
],
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
]
);

Ok(())
}

8 changes: 8 additions & 0 deletions candle-core/tests/tensor_tests.rs
@@ -1089,3 +1089,11 @@ fn pad_with_same() -> Result<()> {
);
Ok(())
}

#[test]
fn i64_abs() -> Result<()> {
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
let t = t.abs()?;
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}
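
The `i64_abs` test above exercises the signed-integer path of the new `Abs` op; per the `UnaryOpT` impl in op.rs, unsigned dtypes (`u8`, `u32`) pass through unchanged. A quick usage sketch on the CPU device:

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Floats and signed integers take the usual absolute value.
    let f = Tensor::new(&[-1.5f32, 2.0, -0.25], &dev)?.abs()?;
    assert_eq!(f.to_vec1::<f32>()?, [1.5, 2.0, 0.25]);
    // Unsigned dtypes are returned unchanged.
    let u = Tensor::new(&[3u32, 7, 11], &dev)?.abs()?;
    assert_eq!(u.to_vec1::<u32>()?, [3, 7, 11]);
    Ok(())
}
```
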
20 changes: 15 additions & 5 deletions candle-examples/examples/llama2-c/main.rs
@@ -6,10 +6,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

mod model;
mod qmodel;
use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod training;
mod weights;
use clap::{Parser, Subcommand};

use anyhow::{Error as E, Result};
@@ -262,8 +262,18 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
let config = Config::tiny();
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
.dims2()?;
let config = match dim {
64 => Config::tiny_260k(),
288 => Config::tiny_15m(),
512 => Config::tiny_42m(),
768 => Config::tiny_110m(),
_ => anyhow::bail!("no config for dim {dim}"),
};
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
@@ -291,7 +301,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny();
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
2 changes: 1 addition & 1 deletion candle-examples/examples/llama2-c/training.rs
@@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let config = Config::tiny_15m();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

90 changes: 90 additions & 0 deletions candle-examples/examples/marian-mt/main.rs
@@ -0,0 +1,90 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;

use tokenizers::Tokenizer;

// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: String,

#[arg(long)]
tokenizer: String,

/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,

/// Text to be translated
#[arg(long)]
text: String,
}

const SEP_TOKEN_ID: u32 = 102;

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();

let config = marian::Config::opus_mt_tc_big_fr_en();

let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[&args.model], DType::F32, &device)? };
let model = marian::MTModel::new(&config, vb)?;

let tokenizer = Tokenizer::from_file(&args.tokenizer).map_err(E::msg)?;
let mut tokenizer_dec = TokenOutputStream::new(tokenizer.clone());
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);

let encoder_xs = {
let tokens = tokenizer
.encode(args.text, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
model.encoder().forward(&tokens, 0)?
};

let mut token_ids = vec![30522u32];
for index in 0..1000 {
// TODO: Add a kv cache.
let context_size = if index >= 1000 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
if token == SEP_TOKEN_ID {
break;
}
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}

Ok(())
}
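
The decoding loop above is greedy: the full `token_ids` prefix is re-fed at every step (no kv cache yet, as the TODO notes), only the last position's logits are kept, and `LogitsProcessor::new(1337, None, None)` with no temperature or top-p effectively selects the highest-scoring token. A plain-Rust sketch of that final selection step, independent of candle:

```rust
// Illustrative only: greedy selection over a logits vector, which is what the
// sampling step above amounts to when no temperature / top-p is configured.
fn argmax(logits: &[f32]) -> Option<u32> {
    logits
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(i, _)| i as u32)
}

fn main() {
    let logits = [0.1f32, 2.5, -1.0, 0.7];
    assert_eq!(argmax(&logits), Some(1));
}
```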

0 comments on commit 5d0e5ab