Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the i32 dtype #2432

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,11 @@
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"rust-analyzer.cargo.features": [
"cuda"
],
"files.associations": {
"cstdint": "cpp"
}
}
5 changes: 5 additions & 0 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ impl Tensor {
f.write_u32::<LittleEndian>(v)?
}
}
DType::I32 => {
for v in vs.to_vec1::<i32>()? {
f.write_i32::<LittleEndian>(v)?
}
}
DType::I64 => {
for v in vs.to_vec1::<i64>()? {
f.write_i64::<LittleEndian>(v)?
Expand Down
11 changes: 11 additions & 0 deletions candle-core/src/cpu/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,17 @@ impl VecOps for u32 {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
<Self as Ord>::min(self, other)
}

#[inline(always)]
fn max(self, other: Self) -> Self {
<Self as Ord>::max(self, other)
}
}
impl VecOps for i64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
Expand Down
114 changes: 112 additions & 2 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ const USE_IM2COL_CONV2D: bool = true;
pub enum CpuStorage {
U8(Vec<u8>),
U32(Vec<u32>),
I32(Vec<i32>),
I64(Vec<i64>),
BF16(Vec<bf16>),
F16(Vec<f16>),
Expand All @@ -30,6 +31,7 @@ pub enum CpuStorage {
pub enum CpuStorageRef<'a> {
U8(&'a [u8]),
U32(&'a [u32]),
I32(&'a [i32]),
I64(&'a [i64]),
BF16(&'a [bf16]),
F16(&'a [f16]),
Expand Down Expand Up @@ -1567,6 +1569,17 @@ impl CpuStorage {
.concat();
Self::U32(storages)
}
Self::I32(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::I32(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::I32(storages)
}
Self::I64(_) => {
let storages = storages
.iter()
Expand Down Expand Up @@ -1634,6 +1647,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(_) => DType::U8,
Self::U32(_) => DType::U32,
Self::I32(_) => DType::I32,
Self::I64(_) => DType::I64,
Self::BF16(_) => DType::BF16,
Self::F16(_) => DType::F16,
Expand All @@ -1653,6 +1667,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::I32(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
}
(Self::I64(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32));
Ok(Self::BF16(data))
Expand Down Expand Up @@ -1681,6 +1699,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::I32(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
}
(Self::I64(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
Expand Down Expand Up @@ -1709,6 +1731,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::I32(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::I64(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
Expand Down Expand Up @@ -1753,6 +1779,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::I32(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::I64(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
Expand All @@ -1765,6 +1795,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U32(data))
}
(Self::I32(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::I64(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
Expand All @@ -1785,6 +1819,38 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::U32(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I32(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::I32(data))
}
(Self::I64(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i32);
Ok(Self::I32(data))
}
(Self::BF16(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i32);
Ok(Self::I32(data))
}
(Self::F16(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i32);
Ok(Self::I32(data))
}
(Self::F32(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i32);
Ok(Self::I32(data))
}
(Self::F64(storage), DType::I32) => {
let data = unary_map(storage, layout, |v| v as i32);
Ok(Self::I32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand All @@ -1793,6 +1859,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I32(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::I64(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::I64(data))
Expand Down Expand Up @@ -1821,6 +1891,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::I32(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
}
(Self::I64(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
Expand Down Expand Up @@ -1956,6 +2030,7 @@ impl BackendStorage for CpuStorage {
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
Expand All @@ -1981,6 +2056,7 @@ impl BackendStorage for CpuStorage {
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
}
}
Expand Down Expand Up @@ -2031,6 +2107,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, B::u32);
Ok(Self::U32(data))
}
Self::I32(storage) => {
let data = unary_map(storage, layout, B::i32);
Ok(Self::I32(data))
}
Self::I64(storage) => {
let data = unary_map(storage, layout, B::i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -2085,6 +2165,14 @@ impl BackendStorage for CpuStorage {
};
Ok(Self::U32(data))
}
(Self::I32(lhs), Self::I32(rhs)) => {
let data = if B::I32_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec)
} else {
binary_map(lhs_l, rhs_l, lhs, rhs, B::i32)
};
Ok(Self::I32(data))
}
(Self::I64(lhs), Self::I64(rhs)) => {
let data = if B::I64_VEC {
binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec)
Expand Down Expand Up @@ -2128,6 +2216,9 @@ impl BackendStorage for CpuStorage {
(Self::U32(src), Self::U32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::I32(src), Self::I32(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
(Self::I64(src), Self::I64(dst)) => {
copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
}
Expand Down Expand Up @@ -2159,6 +2250,7 @@ impl BackendStorage for CpuStorage {
match (self, dst) {
(Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
(Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
Expand Down Expand Up @@ -2188,6 +2280,7 @@ impl BackendStorage for CpuStorage {
match self {
Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")),
}
Expand Down Expand Up @@ -2357,6 +2450,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
}
Expand All @@ -2366,6 +2460,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l),
Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
}
Expand All @@ -2383,6 +2478,7 @@ impl BackendStorage for CpuStorage {
match ids {
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
}
Expand Down Expand Up @@ -2412,6 +2508,13 @@ impl BackendStorage for CpuStorage {
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
Expand Down Expand Up @@ -2483,7 +2586,7 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
DType::U8 | DType::U32 | DType::I32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt())
}
DType::BF16 => {
Expand Down Expand Up @@ -2529,7 +2632,7 @@ impl BackendDevice for CpuDevice {
let elem_count = shape.elem_count();
let mut rng = rand::thread_rng();
match dtype {
DType::U8 | DType::U32 | DType::I64 => {
DType::U8 | DType::U32 | DType::I32 | DType::I64 => {
Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt())
}
DType::BF16 => {
Expand Down Expand Up @@ -2588,6 +2691,11 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::U32(v)
}
DType::I32 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::I32(v)
}
DType::I64 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
Expand Down Expand Up @@ -2622,6 +2730,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
DType::I32 => CpuStorage::I32(vec![1i32; elem_count]),
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
Expand All @@ -2636,6 +2745,7 @@ impl BackendDevice for CpuDevice {
let storage = match dtype {
DType::U8 => CpuStorage::U8(vec![0u8; elem_count]),
DType::U32 => CpuStorage::U32(vec![0u32; elem_count]),
DType::I32 => CpuStorage::I32(vec![0i32; elem_count]),
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
Expand Down
2 changes: 2 additions & 0 deletions candle-core/src/cpu_backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub trait Map1 {
match vs {
C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)),
C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)),
C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)),
C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)),
C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)),
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
Expand All @@ -26,6 +27,7 @@ pub trait Map1Any {
match vs {
C::U8(vs) => Ok(self.f(vs, layout, C::U8)?),
C::U32(vs) => Ok(self.f(vs, layout, C::U32)?),
C::I32(vs) => Ok(self.f(vs, layout, C::I32)?),
C::I64(vs) => Ok(self.f(vs, layout, C::I64)?),
C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?),
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
Expand Down
Loading
Loading