diff --git a/.vscode/settings.json b/.vscode/settings.json index b2dbd68012..6989e19810 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,11 @@ "candle-pyo3" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "rust-analyzer.cargo.features": [ + "cuda" + ], + "files.associations": { + "cstdint": "cpp" + } } \ No newline at end of file diff --git a/candle-core/src/convert.rs b/candle-core/src/convert.rs index 5ea5612a7c..b29ff346f6 100644 --- a/candle-core/src/convert.rs +++ b/candle-core/src/convert.rs @@ -130,6 +130,11 @@ impl Tensor { f.write_u32::(v)? } } + DType::I32 => { + for v in vs.to_vec1::()? { + f.write_i32::(v)? + } + } DType::I64 => { for v in vs.to_vec1::()? { f.write_i64::(v)? diff --git a/candle-core/src/cpu/kernels.rs b/candle-core/src/cpu/kernels.rs index 527646d62b..fe0e241622 100644 --- a/candle-core/src/cpu/kernels.rs +++ b/candle-core/src/cpu/kernels.rs @@ -144,6 +144,17 @@ impl VecOps for u32 { ::max(self, other) } } +impl VecOps for i32 { + #[inline(always)] + fn min(self, other: Self) -> Self { + ::min(self, other) + } + + #[inline(always)] + fn max(self, other: Self) -> Self { + ::max(self, other) + } +} impl VecOps for i64 { #[inline(always)] fn min(self, other: Self) -> Self { diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 58773c8020..28c567523e 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -19,6 +19,7 @@ const USE_IM2COL_CONV2D: bool = true; pub enum CpuStorage { U8(Vec), U32(Vec), + I32(Vec), I64(Vec), BF16(Vec), F16(Vec), @@ -30,6 +31,7 @@ pub enum CpuStorage { pub enum CpuStorageRef<'a> { U8(&'a [u8]), U32(&'a [u32]), + I32(&'a [i32]), I64(&'a [i64]), BF16(&'a [bf16]), F16(&'a [f16]), @@ -1567,6 +1569,17 @@ impl CpuStorage { .concat(); Self::U32(storages) } + Self::I32(_) => { + let storages = storages + .iter() + .map(|s| match s { + Self::I32(s) => Ok(s.as_slice()), + _ => crate::bail!("dtype mismatch"), + }) + .collect::>>()? + .concat(); + Self::I32(storages) + } Self::I64(_) => { let storages = storages .iter() @@ -1634,6 +1647,7 @@ impl BackendStorage for CpuStorage { match self { Self::U8(_) => DType::U8, Self::U32(_) => DType::U32, + Self::I32(_) => DType::I32, Self::I64(_) => DType::I64, Self::BF16(_) => DType::BF16, Self::F16(_) => DType::F16, @@ -1653,6 +1667,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) } + (Self::I32(storage), DType::BF16) => { + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); + Ok(Self::BF16(data)) + } (Self::I64(storage), DType::BF16) => { let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) @@ -1681,6 +1699,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) } + (Self::I32(storage), DType::F16) => { + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); + Ok(Self::F16(data)) + } (Self::I64(storage), DType::F16) => { let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) @@ -1709,6 +1731,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } + (Self::I32(storage), DType::F32) => { + let data = unary_map(storage, layout, |v| v as f32); + Ok(Self::F32(data)) + } (Self::I64(storage), DType::F32) => { let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) @@ -1753,6 +1779,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) } + (Self::I32(storage), DType::U8) => { + let data = unary_map(storage, layout, |v| v as u8); + Ok(Self::U8(data)) + } (Self::I64(storage), DType::U8) => { let data = unary_map(storage, layout, |v| v as u8); Ok(Self::U8(data)) @@ -1765,6 +1795,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v); Ok(Self::U32(data)) } + (Self::I32(storage), DType::U32) => { + let data = unary_map(storage, layout, |v| v as u32); + Ok(Self::U32(data)) + } (Self::I64(storage), DType::U32) => { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) @@ -1785,6 +1819,38 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } + (Self::U8(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::U32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } + (Self::I32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v); + Ok(Self::I32(data)) + } + (Self::I64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::BF16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F16(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v.to_f32() as i32); + Ok(Self::I32(data)) + } + (Self::F32(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } + (Self::F64(storage), DType::I32) => { + let data = unary_map(storage, layout, |v| v as i32); + Ok(Self::I32(data)) + } (Self::U8(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) @@ -1793,6 +1859,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as i64); Ok(Self::I64(data)) } + (Self::I32(storage), DType::I64) => { + let data = unary_map(storage, layout, |v| v as i64); + Ok(Self::I64(data)) + } (Self::I64(storage), DType::I64) => { let data = unary_map(storage, layout, |v| v); Ok(Self::I64(data)) @@ -1821,6 +1891,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } + (Self::I32(storage), DType::F64) => { + let data = unary_map(storage, layout, |v| v as f64); + Ok(Self::F64(data)) + } (Self::I64(storage), DType::F64) => { let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) @@ -1956,6 +2030,7 @@ impl BackendStorage for CpuStorage { } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I32, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } } @@ -1981,6 +2056,7 @@ impl BackendStorage for CpuStorage { } Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()), Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()), + Self::I32(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()), } } @@ -2031,6 +2107,10 @@ impl BackendStorage for CpuStorage { let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } + Self::I32(storage) => { + let data = unary_map(storage, layout, B::i32); + Ok(Self::I32(data)) + } Self::I64(storage) => { let data = unary_map(storage, layout, B::i64); Ok(Self::I64(data)) @@ -2085,6 +2165,14 @@ impl BackendStorage for CpuStorage { }; Ok(Self::U32(data)) } + (Self::I32(lhs), Self::I32(rhs)) => { + let data = if B::I32_VEC { + binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i32, B::i32_vec) + } else { + binary_map(lhs_l, rhs_l, lhs, rhs, B::i32) + }; + Ok(Self::I32(data)) + } (Self::I64(lhs), Self::I64(rhs)) => { let data = if B::I64_VEC { binary_map_vec(lhs_l, rhs_l, lhs, rhs, B::i64, B::i64_vec) @@ -2128,6 +2216,9 @@ impl BackendStorage for CpuStorage { (Self::U32(src), Self::U32(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } + (Self::I32(src), Self::I32(dst)) => { + copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) + } (Self::I64(src), Self::I64(dst)) => { copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o) } @@ -2159,6 +2250,7 @@ impl BackendStorage for CpuStorage { match (self, dst) { (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::I32(src), Self::I32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::I64(src), Self::I64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), @@ -2188,6 +2280,7 @@ impl BackendStorage for CpuStorage { match self { Self::U8(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::U32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), + Self::I32(pred) => WCond(pred, layout).map(t, t_l, f, f_l), Self::I64(pred) => WCond(pred, layout).map(t, t_l, f, f_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "where-cond")), } @@ -2357,6 +2450,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()), } @@ -2366,6 +2460,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l), + Self::I32(ids) => Gather { ids, ids_l, dim }.map(self, l), Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()), } @@ -2383,6 +2478,7 @@ impl BackendStorage for CpuStorage { match ids { Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), + Self::I32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l), _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()), } @@ -2412,6 +2508,13 @@ impl BackendStorage for CpuStorage { }; IndexAdd { ids, dim }.map(self, l, src, src_l) } + Self::I32(ids) => { + let ids = match ids_l.contiguous_offsets() { + Some((a, b)) => &ids[a..b], + None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?, + }; + IndexAdd { ids, dim }.map(self, l, src, src_l) + } Self::I64(ids) => { let ids = match ids_l.contiguous_offsets() { Some((a, b)) => &ids[a..b], @@ -2483,7 +2586,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_uniform").bt()) } DType::BF16 => { @@ -2529,7 +2632,7 @@ impl BackendDevice for CpuDevice { let elem_count = shape.elem_count(); let mut rng = rand::thread_rng(); match dtype { - DType::U8 | DType::U32 | DType::I64 => { + DType::U8 | DType::U32 | DType::I32 | DType::I64 => { Err(Error::UnsupportedDTypeForOp(dtype, "rand_normal").bt()) } DType::BF16 => { @@ -2588,6 +2691,11 @@ impl BackendDevice for CpuDevice { v.set_len(elem_count); CpuStorage::U32(v) } + DType::I32 => { + let mut v = Vec::with_capacity(elem_count); + v.set_len(elem_count); + CpuStorage::I32(v) + } DType::I64 => { let mut v = Vec::with_capacity(elem_count); v.set_len(elem_count); @@ -2622,6 +2730,7 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![1u8; elem_count]), DType::U32 => CpuStorage::U32(vec![1u32; elem_count]), + DType::I32 => CpuStorage::I32(vec![1i32; elem_count]), DType::I64 => CpuStorage::I64(vec![1i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]), @@ -2636,6 +2745,7 @@ impl BackendDevice for CpuDevice { let storage = match dtype { DType::U8 => CpuStorage::U8(vec![0u8; elem_count]), DType::U32 => CpuStorage::U32(vec![0u32; elem_count]), + DType::I32 => CpuStorage::I32(vec![0i32; elem_count]), DType::I64 => CpuStorage::I64(vec![0i64; elem_count]), DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]), DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]), diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index 3e0c69b4f7..53944aee41 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -10,6 +10,7 @@ pub trait Map1 { match vs { C::U8(vs) => Ok(C::U8(self.f(vs, layout)?)), C::U32(vs) => Ok(C::U32(self.f(vs, layout)?)), + C::I32(vs) => Ok(C::I32(self.f(vs, layout)?)), C::I64(vs) => Ok(C::I64(self.f(vs, layout)?)), C::BF16(vs) => Ok(C::BF16(self.f(vs, layout)?)), C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)), @@ -26,6 +27,7 @@ pub trait Map1Any { match vs { C::U8(vs) => Ok(self.f(vs, layout, C::U8)?), C::U32(vs) => Ok(self.f(vs, layout, C::U32)?), + C::I32(vs) => Ok(self.f(vs, layout, C::I32)?), C::I64(vs) => Ok(self.f(vs, layout, C::I64)?), C::BF16(vs) => Ok(self.f(vs, layout, C::BF16)?), C::F16(vs) => Ok(self.f(vs, layout, C::F16)?), diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 0aa58cacde..f282356b9b 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -75,6 +75,14 @@ impl CudaDevice { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(data) } + DType::I32 => { + // SAFETY: Set later by running the fill kernel. + let data = unsafe { self.alloc::(elem_count) }.w()?; + let func = self.get_or_load_func("fill_i32", kernels::FILL)?; + let params = (&data, v as i32, elem_count); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { // SAFETY: Set later by running the fill kernel. let data = unsafe { self.alloc::(elem_count) }.w()?; @@ -188,6 +196,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I32 => { + let data = self.alloc_zeros::(elem_count).w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc_zeros::(elem_count).w()?; CudaStorageSlice::I64(data) @@ -221,7 +233,7 @@ impl BackendDevice for CudaDevice { let slice = match dtype { // TODO: Add support for F16 and BF16 though this is likely to require some upstream // cudarc changes. - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + DType::U8 | DType::U32 | DType::I64 | DType::I32 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { dtype, op: "rand_uniform", @@ -265,7 +277,7 @@ impl BackendDevice for CudaDevice { elem_count }; let slice = match dtype { - DType::U8 | DType::U32 | DType::I64 | DType::F16 | DType::BF16 => { + DType::U8 | DType::U32 | DType::I32 | DType::I64 | DType::F16 | DType::BF16 => { Err(CudaError::UnsupportedDtype { dtype, op: "rand_normal", @@ -307,6 +319,10 @@ impl BackendDevice for CudaDevice { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::U32(data) } + DType::I32 => { + let data = self.alloc::(elem_count).w()?; + CudaStorageSlice::I32(data) + } DType::I64 => { let data = self.alloc::(elem_count).w()?; CudaStorageSlice::I64(data) @@ -344,6 +360,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorageRef::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorageRef::I64(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I64(data) @@ -381,6 +401,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I32(storage) => { + let data = self.htod_sync_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.htod_sync_copy(storage).w()?; CudaStorageSlice::I64(data) @@ -418,6 +442,10 @@ impl BackendDevice for CudaDevice { let data = self.htod_copy(storage).w()?; CudaStorageSlice::U32(data) } + CpuStorage::I32(storage) => { + let data = self.htod_copy(storage).w()?; + CudaStorageSlice::I32(data) + } CpuStorage::I64(storage) => { let data = self.htod_copy(storage).w()?; CudaStorageSlice::I64(data) diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 7edad3d409..e406004a77 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -47,6 +47,7 @@ impl SlicePtrOrNull { pub enum CudaStorageSlice { U8(CudaSlice), U32(CudaSlice), + I32(CudaSlice), I64(CudaSlice), BF16(CudaSlice), F16(CudaSlice), @@ -361,11 +362,14 @@ impl<'a> Map1 for IndexSelect<'a> { CudaStorageSlice::U8(slice) => { ("is_u8", *slice.slice(ids_l.start_offset()..).device_ptr()) } + CudaStorageSlice::I32(slice) => { + ("is_i32", *slice.slice(ids_l.start_offset()..).device_ptr()) + } CudaStorageSlice::I64(slice) => { ("is_i64", *slice.slice(ids_l.start_offset()..).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "index_select ids should be u8 or u32", + msg: "index_select ids should be u8/u32/i32/i64", expected: DType::U32, got: self.0.dtype(), }) @@ -425,11 +429,14 @@ impl<'a> Map1 for Gather<'a> { ("gather_u32", *slice.slice(ids_o1..ids_o2).device_ptr()) } CudaStorageSlice::U8(slice) => ("gather_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => { + ("gather_i32", *slice.slice(ids_o1..ids_o2).device_ptr()) + } CudaStorageSlice::I64(slice) => { ("gather_i64", *slice.slice(ids_o1..ids_o2).device_ptr()) } _ => Err(CudaError::UnexpectedDType { - msg: "gather ids should be u8/u32/i64", + msg: "gather ids should be u8/u32/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -475,10 +482,11 @@ impl<'a> Map2InPlace for IndexAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("ia_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => ("ia_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("ia_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("ia_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "index-add ids should be u8/u32/i64", + msg: "index-add ids should be u8/u32/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -523,10 +531,11 @@ impl<'a> Map2InPlace for ScatterAdd<'a> { }; let (name, ids) = match &ids.slice { CudaStorageSlice::U32(slice) => ("sa_u32", *slice.slice(ids_o1..ids_o2).device_ptr()), + CudaStorageSlice::I32(slice) => ("sa_i32", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::I64(slice) => ("sa_i64", *slice.slice(ids_o1..ids_o2).device_ptr()), CudaStorageSlice::U8(slice) => ("sa_u8", *slice.slice(ids_o1..ids_o2).device_ptr()), _ => Err(CudaError::UnexpectedDType { - msg: "scatter-add ids should be u8/u32/i64", + msg: "scatter-add ids should be u8/u32/i32/i64", expected: DType::U32, got: ids.dtype(), })?, @@ -865,6 +874,10 @@ impl<'a> Map2 for WhereCond<'a> { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_u32") } + CudaStorageSlice::I32(slice) => { + let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); + (ptr, "where_i32") + } CudaStorageSlice::I64(slice) => { let ptr = *slice.slice(ids_l.start_offset()..).device_ptr(); (ptr, "where_i64") @@ -1024,6 +1037,7 @@ macro_rules! cuda_dtype { } cuda_dtype!(u8, U8); cuda_dtype!(u32, U32); +cuda_dtype!(i32, I32); cuda_dtype!(i64, I64); cuda_dtype!(f16, F16); cuda_dtype!(bf16, BF16); @@ -1146,6 +1160,7 @@ impl BackendStorage for CudaStorage { match self.slice { CudaStorageSlice::U8(_) => DType::U8, CudaStorageSlice::U32(_) => DType::U32, + CudaStorageSlice::I32(_) => DType::I32, CudaStorageSlice::I64(_) => DType::I64, CudaStorageSlice::BF16(_) => DType::BF16, CudaStorageSlice::F16(_) => DType::F16, @@ -1172,6 +1187,7 @@ impl BackendStorage for CudaStorage { let inp = match &self.slice { CudaStorageSlice::U8(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::I32(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::I64(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), @@ -1195,6 +1211,12 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()?; CudaStorageSlice::U32(out) } + DType::I32 => { + let out = unsafe { dev.alloc::(el) }.w()?; + let params = (el, dims.len(), &ds, *inp, &out); + unsafe { func.launch(cfg, params) }.w()?; + CudaStorageSlice::I32(out) + } DType::I64 => { let out = unsafe { dev.alloc::(el) }.w()?; let params = (el, dims.len(), &ds, *inp, &out); @@ -1291,6 +1313,11 @@ impl BackendStorage for CudaStorage { let cpu_storage = dev.dtoh_sync_copy(slice).w()?; Ok(CpuStorage::U32(cpu_storage)) } + CudaStorageSlice::I32(slice) => { + let dev = slice.device(); + let cpu_storage = dev.dtoh_sync_copy(slice).w()?; + Ok(CpuStorage::I32(cpu_storage)) + } CudaStorageSlice::I64(slice) => { let dev = slice.device(); let cpu_storage = dev.dtoh_sync_copy(slice).w()?; @@ -1557,6 +1584,7 @@ impl BackendStorage for CudaStorage { S::F64(out) } (S::U32(_), S::U32(_)) => Err(CudaError::InternalError("conv2d does not support u32"))?, + (S::I32(_), S::I32(_)) => Err(CudaError::InternalError("conv2d does not support i32"))?, (S::I64(_), S::I64(_)) => Err(CudaError::InternalError("conv2d does not support i64"))?, _ => Err(CudaError::InternalError("dtype mismatch in conv2d"))?, }; @@ -1740,6 +1768,11 @@ impl BackendStorage for CudaStorage { *d.slice(dst_o..).device_ptr(), "copy2d_u32", ), + (S::I32(s), S::I32(d)) => ( + *s.slice(src_o..).device_ptr(), + *d.slice(dst_o..).device_ptr(), + "copy2d_i32", + ), (S::I64(s), S::I64(d)) => ( *s.slice(src_o..).device_ptr(), *d.slice(dst_o..).device_ptr(), @@ -1846,6 +1879,18 @@ impl BackendStorage for CudaStorage { unsafe { func.launch(cfg, params) }.w()? } } + (CudaStorageSlice::I32(src), CudaStorageSlice::I32(dst)) => { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { + dev.dtod_copy(&src, &mut dst).w()? + } else { + let func = dev.get_or_load_func("ucopy_i32", kernels::UNARY)?; + // SAFETY: Set later by running the kernel. + let params = (el_count, dims.len(), &ds, &src, &mut dst); + // SAFETY: ffi. + unsafe { func.launch(cfg, params) }.w()? + } + } (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => { let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); if src_l.is_contiguous() { diff --git a/candle-core/src/cuda_backend/utils.rs b/candle-core/src/cuda_backend/utils.rs index c1210727ad..ae009b26ab 100644 --- a/candle-core/src/cuda_backend/utils.rs +++ b/candle-core/src/cuda_backend/utils.rs @@ -19,6 +19,7 @@ pub trait Map1 { let out = match s { S::U8(s) => S::U8(self.f(s, d, l)?), S::U32(s) => S::U32(self.f(s, d, l)?), + S::I32(s) => S::I32(self.f(s, d, l)?), S::I64(s) => S::I64(self.f(s, d, l)?), S::BF16(s) => S::BF16(self.f(s, d, l)?), S::F16(s) => S::F16(self.f(s, d, l)?), @@ -136,6 +137,7 @@ pub trait Map1Any { let out = match s { S::U8(s) => self.f(s, d, l, S::U8)?, S::U32(s) => self.f(s, d, l, S::U32)?, + S::I32(s) => self.f(s, d, l, S::I32)?, S::I64(s) => self.f(s, d, l, S::I64)?, S::BF16(s) => self.f(s, d, l, S::BF16)?, S::F16(s) => self.f(s, d, l, S::F16)?, diff --git a/candle-core/src/display.rs b/candle-core/src/display.rs index 7e6e3cf8f1..5fb370b696 100644 --- a/candle-core/src/display.rs +++ b/candle-core/src/display.rs @@ -55,6 +55,7 @@ impl std::fmt::Debug for Tensor { match self.dtype() { DType::U8 => self.fmt_dt::(f), DType::U32 => self.fmt_dt::(f), + DType::I32 => self.fmt_dt::(f), DType::I64 => self.fmt_dt::(f), DType::BF16 => self.fmt_dt::(f), DType::F16 => self.fmt_dt::(f), @@ -463,6 +464,12 @@ impl std::fmt::Display for Tensor { tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; writeln!(f)?; } + DType::I32 => { + let tf: IntFormatter = IntFormatter::new(); + let max_w = tf.max_width(&to_display); + tf.fmt_tensor(self, 1, max_w, summarize, &po, f)?; + writeln!(f)?; + } DType::I64 => { let tf: IntFormatter = IntFormatter::new(); let max_w = tf.max_width(&to_display); diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index de6cddc3a3..c6a0800b24 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -10,6 +10,8 @@ pub enum DType { U8, // Unsigned 32 bits integer. U32, + // Signed 32 bits integer. + I32, // Signed 64 bits integer. I64, // Brain floating-point using half precision (16 bits). @@ -39,6 +41,7 @@ impl std::str::FromStr for DType { match s { "u8" => Ok(Self::U8), "u32" => Ok(Self::U32), + "i32" => Ok(Self::I32), "i64" => Ok(Self::I64), "bf16" => Ok(Self::BF16), "f16" => Ok(Self::F16), @@ -55,6 +58,7 @@ impl DType { match self { Self::U8 => "u8", Self::U32 => "u32", + Self::I32 => "i32", Self::I64 => "i64", Self::BF16 => "bf16", Self::F16 => "f16", @@ -68,6 +72,7 @@ impl DType { match self { Self::U8 => 1, Self::U32 => 4, + Self::I32 => 4, Self::I64 => 8, Self::BF16 => 2, Self::F16 => 2, @@ -78,14 +83,14 @@ impl DType { pub fn is_int(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => true, + Self::U8 | Self::U32 | Self::I32 | Self::I64 => true, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => false, } } pub fn is_float(&self) -> bool { match self { - Self::U8 | Self::U32 | Self::I64 => false, + Self::U8 | Self::U32 | Self::I32 | Self::I64 => false, Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true, } } @@ -169,6 +174,7 @@ use half::{bf16, f16}; with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64); with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64); +with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64); with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64); with_dtype!(f16, F16, f16::from_f64, f16::to_f64); with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64); @@ -180,6 +186,15 @@ pub trait IntDType: WithDType { fn as_usize(&self) -> usize; } +impl IntDType for i32 { + fn is_true(&self) -> bool { + *self != 0 + } + fn as_usize(&self) -> usize { + *self as usize + } +} + impl IntDType for i64 { fn is_true(&self) -> bool { *self != 0 diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 09d5fd49cd..b0129ab08c 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -96,6 +96,7 @@ impl BackendStorage for MetalStorage { match self.dtype { DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)), DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)), + DType::I32 => Ok(CpuStorage::I32(self.to_cpu()?)), DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)), DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)), DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)), @@ -304,6 +305,11 @@ impl BackendStorage for MetalStorage { (ReduceOp::Max, DType::BF16) => ("fast_max_bf16_strided", true, false), (ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16_strided", true, true), (ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16_strided", true, true), + (ReduceOp::Sum, DType::I32) => ("fast_sum_i32_strided", false, false), + (ReduceOp::Min, DType::I32) => ("fast_min_i32_strided", true, false), + (ReduceOp::Max, DType::I32) => ("fast_max_i32_strided", true, false), + (ReduceOp::ArgMin, DType::I32) => ("fast_argmin_i32_strided", true, true), + (ReduceOp::ArgMax, DType::I32) => ("fast_argmax_i32_strided", true, true), (ReduceOp::Sum, DType::I64) => ("fast_sum_i64_strided", false, false), (ReduceOp::Min, DType::I64) => ("fast_min_i64_strided", true, false), (ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false), @@ -363,21 +369,30 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "cast_u32_bf16", (DType::U32, DType::F16) => "cast_u32_f16", (DType::U32, DType::F32) => "cast_u32_f32", + (DType::U32, DType::I32) => "cast_u32_i32", (DType::U32, DType::I64) => "cast_u32_i64", (DType::U32, DType::U8) => "cast_u32_u8", (DType::U8, DType::BF16) => "cast_u8_bf16", (DType::U8, DType::F16) => "cast_u8_f16", (DType::U8, DType::F32) => "cast_u8_f32", + (DType::U8, DType::I32) => "cast_u8_i32", (DType::U8, DType::I64) => "cast_u8_i64", (DType::U8, DType::U32) => "cast_u8_u32", (DType::F32, DType::BF16) => "cast_f32_bf16", (DType::F32, DType::F16) => "cast_f32_f16", + (DType::F32, DType::I32) => "cast_f32_i32", (DType::F32, DType::I64) => "cast_f32_i64", (DType::F32, DType::U32) => "cast_f32_u32", (DType::F32, DType::U8) => "cast_f32_u8", + (DType::I32, DType::BF16) => "cast_i32_bf16", + (DType::I32, DType::F16) => "cast_i32_f16", + (DType::I32, DType::F32) => "cast_i32_f32", + (DType::I32, DType::U32) => "cast_i32_u32", + (DType::I32, DType::U8) => "cast_i32_u8", + (DType::I64, DType::BF16) => "cast_i64_bf16", (DType::I64, DType::F16) => "cast_i64_f16", (DType::I64, DType::F32) => "cast_i64_f32", @@ -386,12 +401,14 @@ impl BackendStorage for MetalStorage { (DType::F16, DType::BF16) => "cast_f16_bf16", (DType::F16, DType::F32) => "cast_f16_f32", + (DType::F16, DType::I32) => "cast_f16_i32", (DType::F16, DType::I64) => "cast_f16_i64", (DType::F16, DType::U32) => "cast_f16_u32", (DType::F16, DType::U8) => "cast_f16_u8", (DType::BF16, DType::F16) => "cast_bf16_f16", (DType::BF16, DType::F32) => "cast_bf16_f32", + (DType::BF16, DType::I32) => "cast_bf16_i32", (DType::BF16, DType::I64) => "cast_bf16_i64", (DType::BF16, DType::U32) => "cast_bf16_u32", (DType::BF16, DType::U8) => "cast_bf16_u8", @@ -414,12 +431,15 @@ impl BackendStorage for MetalStorage { let kernel_name = match (self.dtype, dtype) { (DType::U32, DType::F32) => "cast_u32_f32_strided", (DType::U32, DType::U8) => "cast_u32_u8_strided", + (DType::U32, DType::I32) => "cast_u32_i32_strided", (DType::U32, DType::I64) => "cast_u32_i64_strided", (DType::U8, DType::U32) => "cast_u8_u32_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", + (DType::U8, DType::I32) => "cast_u8_i32_strided", (DType::U8, DType::I64) => "cast_u8_i64_strided", (DType::F32, DType::F16) => "cast_f32_f16_strided", (DType::F16, DType::F32) => "cast_f16_f32_strided", + (DType::I32, DType::F32) => "cast_i32_f32_strided", (DType::I64, DType::F32) => "cast_i64_f32_strided", (DType::F32, DType::BF16) => "cast_f32_bf16_strided", (DType::BF16, DType::F32) => "cast_bf16_f32_strided", @@ -514,6 +534,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous_tiled::sign::HALF, ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, + ("usign", DType::I32) => contiguous_tiled::sign::I32, ("usign", DType::I64) => contiguous_tiled::sign::I64, (name, dtype) => { crate::bail!( @@ -592,6 +613,7 @@ impl BackendStorage for MetalStorage { ("usign", DType::F16) => contiguous::sign::HALF, ("usign", DType::F32) => contiguous::sign::FLOAT, ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I32) => contiguous::sign::I32, ("usign", DType::I64) => contiguous::sign::I64, (name, dtype) => { crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") @@ -723,6 +745,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "where_u32_f32", (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", + (DType::U8, DType::I32) => "where_u8_i32", (DType::U8, DType::I64) => "where_u8_i64", (DType::U8, DType::U32) => "where_u8_u32", (DType::U8, DType::U8) => "where_u8_u8", @@ -1259,6 +1282,9 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "sa_u32_f32", (DType::U32, DType::F16) => "sa_u32_f16", (DType::U32, DType::BF16) => "sa_u32_bf16", + (DType::I32, DType::F32) => "sa_i32_f32", + (DType::I32, DType::F16) => "sa_i32_f16", + (DType::I32, DType::BF16) => "sa_i32_bf16", (DType::I64, DType::F32) => "sa_i64_f32", (DType::I64, DType::F16) => "sa_i64_f16", (DType::I64, DType::BF16) => "sa_i64_bf16", @@ -1307,6 +1333,10 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I32, DType::F32) => "is_i32_f32", + (DType::I32, DType::F16) => "is_i32_f16", + (DType::I32, DType::BF16) => "is_i32_bf16", + (DType::I64, DType::F32) => "is_i64_f32", (DType::I64, DType::F16) => "is_i64_f16", (DType::I64, DType::BF16) => "is_i64_bf16", @@ -1352,9 +1382,18 @@ impl BackendStorage for MetalStorage { return Err(crate::Error::RequiresContiguous { op: "index-add" }.bt()); }; let name = match (ids.dtype, self.dtype) { + (DType::I32, DType::BF16) => "ia_i32_bf16", + (DType::I32, DType::F16) => "ia_i32_f16", + (DType::I32, DType::F32) => "ia_i32_f32", + (DType::I32, DType::I32) => "ia_i32_i32", + (DType::I32, DType::I64) => "ia_i32_i64", + (DType::I32, DType::U32) => "ia_i32_u32", + (DType::I32, DType::U8) => "ia_i32_u8", + (DType::I64, DType::BF16) => "ia_i64_bf16", (DType::I64, DType::F16) => "ia_i64_f16", (DType::I64, DType::F32) => "ia_i64_f32", + (DType::I64, DType::I32) => "ia_i64_i32", (DType::I64, DType::I64) => "ia_i64_i64", (DType::I64, DType::U32) => "ia_i64_u32", (DType::I64, DType::U8) => "ia_i64_u8", @@ -1362,6 +1401,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::BF16) => "ia_u32_bf16", (DType::U32, DType::F16) => "ia_u32_f16", (DType::U32, DType::F32) => "ia_u32_f32", + (DType::U32, DType::I32) => "ia_u32_i32", (DType::U32, DType::I64) => "ia_u32_i64", (DType::U32, DType::U32) => "ia_u32_u32", (DType::U32, DType::U8) => "ia_u32_u8", @@ -1369,6 +1409,7 @@ impl BackendStorage for MetalStorage { (DType::U8, DType::BF16) => "ia_u8_bf16", (DType::U8, DType::F16) => "ia_u8_f16", (DType::U8, DType::F32) => "ia_u8_f32", + (DType::U8, DType::I32) => "ia_u8_i32", (DType::U8, DType::I64) => "ia_u8_i64", (DType::U8, DType::U32) => "ia_u8_u32", (DType::U8, DType::U8) => "ia_u8_u8", @@ -1476,6 +1517,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::copy2d::FLOAT, DType::F16 => candle_metal_kernels::copy2d::HALF, DType::BF16 => candle_metal_kernels::copy2d::BFLOAT, + DType::I32 => candle_metal_kernels::copy2d::I32, DType::I64 => candle_metal_kernels::copy2d::I64, DType::U32 => candle_metal_kernels::copy2d::U32, DType::U8 => candle_metal_kernels::copy2d::U8, @@ -1522,6 +1564,7 @@ impl BackendStorage for MetalStorage { DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT, DType::F16 => candle_metal_kernels::unary::strided::copy::HALF, DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT, + DType::I32 => candle_metal_kernels::unary::strided::copy::I32, DType::I64 => candle_metal_kernels::unary::strided::copy::I64, DType::U32 => candle_metal_kernels::unary::strided::copy::U32, DType::U8 => candle_metal_kernels::unary::strided::copy::U8, @@ -1613,6 +1656,17 @@ impl MetalStorage { ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), + ("add", DType::I32) => (contiguous::add::I32, self.dtype), + ("sub", DType::I32) => (contiguous::sub::I32, self.dtype), + ("mul", DType::I32) => (contiguous::mul::I32, self.dtype), + ("div", DType::I32) => (contiguous::div::I32, self.dtype), + ("eq", DType::I32) => (contiguous::eq::I32, DType::U8), + ("ne", DType::I32) => (contiguous::ne::I32, DType::U8), + ("le", DType::I32) => (contiguous::le::I32, DType::U8), + ("lt", DType::I32) => (contiguous::lt::I32, DType::U8), + ("ge", DType::I32) => (contiguous::ge::I32, DType::U8), + ("gt", DType::I32) => (contiguous::gt::I32, DType::U8), + ("add", DType::I64) => (contiguous::add::I64, self.dtype), ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), @@ -1706,6 +1760,19 @@ impl MetalStorage { ("ge", DType::BF16) => (strided::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (strided::gt::BFLOAT, DType::U8), + ("badd", DType::I32) => (strided::add::I32, self.dtype), + ("bsub", DType::I32) => (strided::sub::I32, self.dtype), + ("bmul", DType::I32) => (strided::mul::I32, self.dtype), + ("bdiv", DType::I32) => (strided::div::I32, self.dtype), + ("bminimum", DType::I32) => (strided::min::I32, self.dtype), + ("bmaximum", DType::I32) => (strided::max::I32, self.dtype), + ("eq", DType::I32) => (strided::eq::I32, DType::U8), + ("ne", DType::I32) => (strided::ne::I32, DType::U8), + ("le", DType::I32) => (strided::le::I32, DType::U8), + ("lt", DType::I32) => (strided::lt::I32, DType::U8), + ("ge", DType::I32) => (strided::ge::I32, DType::U8), + ("gt", DType::I32) => (strided::gt::I32, DType::U8), + ("badd", DType::I64) => (strided::add::I64, self.dtype), ("bsub", DType::I64) => (strided::sub::I64, self.dtype), ("bmul", DType::I64) => (strided::mul::I64, self.dtype), @@ -1861,6 +1928,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match T::cpu_storage_ref(s) { CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorageRef::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorageRef::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), @@ -1874,6 +1942,7 @@ impl BackendDevice for MetalDevice { let (count, buffer) = match storage { CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)), + CpuStorage::I32(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)), CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)), diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 83e4f6527f..b321a619f8 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -85,6 +85,7 @@ impl Header { DType::F16 => "f2", DType::F32 => "f4", DType::F64 => "f8", + DType::I32 => "i4", DType::I64 => "i8", DType::U32 => "u4", DType::U8 => "u1", @@ -234,6 +235,11 @@ impl Tensor { reader.read_u32_into::(&mut data_t)?; Tensor::from_vec(data_t, shape, &Device::Cpu) } + DType::I32 => { + let mut data_t = vec![0i32; elem_count]; + reader.read_i32_into::(&mut data_t)?; + Tensor::from_vec(data_t, shape, &Device::Cpu) + } DType::I64 => { let mut data_t = vec![0i64; elem_count]; reader.read_i64_into::(&mut data_t)?; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 49ba44be89..75931ee2fe 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -189,6 +189,7 @@ pub trait UnaryOpT { fn f64(v1: f64) -> f64; fn u8(v1: u8) -> u8; fn u32(v1: u32) -> u32; + fn i32(v1: i32) -> i32; fn i64(v1: i64) -> i64; // There is no very good way to represent optional function in traits so we go for an explicit @@ -213,6 +214,7 @@ pub trait BinaryOpT { fn f64(v1: f64, v2: f64) -> f64; fn u8(v1: u8, v2: u8) -> u8; fn u32(v1: u32, v2: u32) -> u32; + fn i32(v1: i32, v2: i32) -> i32; fn i64(v1: i64, v2: i64) -> i64; const BF16_VEC: bool = false; @@ -229,6 +231,8 @@ pub trait BinaryOpT { fn u32_vec(_xs1: &[u32], _xs2: &[u32], _ys: &mut [u32]) {} const I64_VEC: bool = false; fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {} + const I32_VEC: bool = false; + fn i32_vec(_xs1: &[i32], _xs2: &[i32], _ys: &mut [i32]) {} } pub(crate) struct Add; @@ -288,6 +292,10 @@ macro_rules! bin_op { $e(v1, v2) } #[inline(always)] + fn i32(v1: i32, v2: i32) -> i32 { + $e(v1, v2) + } + #[inline(always)] fn i64(v1: i64, v2: i64) -> i64 { $e(v1, v2) } @@ -379,6 +387,10 @@ macro_rules! unary_op { fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } } }; @@ -415,6 +427,10 @@ macro_rules! unary_op { fn i64(_: i64) -> i64 { todo!("no unary function for i64") } + #[inline(always)] + fn i32(_: i32) -> i32 { + todo!("no unary function for i32") + } #[cfg(feature = "mkl")] const F32_VEC: bool = true; @@ -514,6 +530,10 @@ impl UnaryOpT for Gelu { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } const KERNEL: &'static str = "ugelu"; #[cfg(feature = "mkl")] @@ -587,6 +607,10 @@ impl UnaryOpT for Erf { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } } /// Silu operation @@ -621,6 +645,10 @@ impl UnaryOpT for Silu { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } const KERNEL: &'static str = "usilu"; #[cfg(feature = "mkl")] @@ -692,6 +720,10 @@ impl UnaryOpT for Abs { fn i64(v: i64) -> i64 { v.abs() } + #[inline(always)] + fn i32(v: i32) -> i32 { + v.abs() + } } impl UnaryOpT for Ceil { @@ -726,6 +758,10 @@ impl UnaryOpT for Ceil { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } impl UnaryOpT for Floor { @@ -760,6 +796,10 @@ impl UnaryOpT for Floor { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } impl UnaryOpT for Round { @@ -794,6 +834,10 @@ impl UnaryOpT for Round { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } impl UnaryOpT for GeluErf { @@ -828,6 +872,10 @@ impl UnaryOpT for GeluErf { fn i64(_: i64) -> i64 { 0 } + #[inline(always)] + fn i32(_: i32) -> i32 { + 0 + } } impl UnaryOpT for Relu { @@ -862,6 +910,10 @@ impl UnaryOpT for Relu { fn i64(v: i64) -> i64 { v } + #[inline(always)] + fn i32(v: i32) -> i32 { + v + } } /// `BackpropOp` is a wrapper around `Option`. The main goal is to ensure that dependencies are @@ -960,4 +1012,8 @@ impl UnaryOpT for Sign { fn i64(v: i64) -> i64 { (v > 0) as i64 - (v < 0) as i64 } + #[inline(always)] + fn i32(v: i32) -> i32 { + (v > 0) as i32 - (v < 0) as i32 + } } diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 5ea1f192b3..162928ec7d 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -11,6 +11,7 @@ impl From for st::Dtype { DType::U8 => st::Dtype::U8, DType::U32 => st::Dtype::U32, DType::I64 => st::Dtype::I64, + DType::I32 => st::Dtype::I32, DType::BF16 => st::Dtype::BF16, DType::F16 => st::Dtype::F16, DType::F32 => st::Dtype::F32, @@ -187,6 +188,7 @@ impl Tensor { match dtype { DType::U8 => convert_slice::(data, shape, device), DType::U32 => convert_slice::(data, shape, device), + DType::I32 => convert_slice::(data, shape, device), DType::I64 => convert_slice::(data, shape, device), DType::BF16 => convert_slice::(data, shape, device), DType::F16 => convert_slice::(data, shape, device), @@ -204,10 +206,7 @@ fn convert(view: &st::TensorView<'_>, device: &Device) -> Result { convert_with_cast_::(view, device, conv) } st::Dtype::U32 => convert_::(view, device), - st::Dtype::I32 => { - let conv = |x| Ok(i64::from(x)); - convert_with_cast_::(view, device, conv) - } + st::Dtype::I32 => convert_::(view, device), st::Dtype::I64 => convert_::(view, device), st::Dtype::BF16 => convert_::(view, device), st::Dtype::F16 => convert_::(view, device), @@ -223,6 +222,7 @@ fn convert_back(tensor: &Tensor) -> Result> { match tensor.dtype() { DType::U8 => Ok(convert_back_::(tensor.to_vec1()?)), DType::U32 => Ok(convert_back_::(tensor.to_vec1()?)), + DType::I32 => Ok(convert_back_::(tensor.to_vec1()?)), DType::I64 => Ok(convert_back_::(tensor.to_vec1()?)), DType::F16 => Ok(convert_back_::(tensor.to_vec1()?)), DType::BF16 => Ok(convert_back_::(tensor.to_vec1()?)), diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs index 614a37fe65..92ad1d5adc 100644 --- a/candle-core/src/sort.rs +++ b/candle-core/src/sort.rs @@ -65,6 +65,7 @@ impl crate::CustomOp1 for ArgSort { let sort_indexes = match storage { crate::CpuStorage::U8(vs) => self.asort(vs, layout), crate::CpuStorage::U32(vs) => self.asort(vs, layout), + crate::CpuStorage::I32(vs) => self.asort(vs, layout), crate::CpuStorage::I64(vs) => self.asort(vs, layout), crate::CpuStorage::BF16(vs) => self.asort(vs, layout), crate::CpuStorage::F16(vs) => self.asort(vs, layout), diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 567b49f1db..7bba53af31 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -17,6 +17,10 @@ fn ones(device: &Device) -> Result<()> { Tensor::ones((2, 3), DType::U32, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], ); + assert_eq!( + Tensor::ones((2, 3), DType::I32, device)?.to_vec2::()?, + [[1, 1, 1], [1, 1, 1]], + ); assert_eq!( Tensor::ones((2, 3), DType::I64, device)?.to_vec2::()?, [[1, 1, 1], [1, 1, 1]], @@ -805,7 +809,7 @@ fn index_select(device: &Device) -> Result<()> { [9.0, 10.0, 11.0] ] ); - for dtype in [DType::U8, DType::U32, DType::I64] { + for dtype in [DType::U8, DType::U32, DType::I32, DType::I64] { let ids = ids.to_dtype(dtype)?; let hs = t.index_select(&ids, 1)?; assert_eq!( diff --git a/candle-kernels/src/affine.cu b/candle-kernels/src/affine.cu index 540d0819f5..c3ff5b8753 100644 --- a/candle-kernels/src/affine.cu +++ b/candle-kernels/src/affine.cu @@ -40,4 +40,5 @@ AFFINE_OP(float, affine_f32) AFFINE_OP(double, affine_f64) AFFINE_OP(uint8_t, affine_u8) AFFINE_OP(uint32_t, affine_u32) +AFFINE_OP(int32_t, affine_i32) AFFINE_OP(int64_t, affine_i64) diff --git a/candle-kernels/src/binary.cu b/candle-kernels/src/binary.cu index d44e3b20ee..f534fc76ad 100644 --- a/candle-kernels/src/binary.cu +++ b/candle-kernels/src/binary.cu @@ -35,65 +35,77 @@ BINARY_OP(float, badd_f32, x + y) BINARY_OP(double, badd_f64, x + y); BINARY_OP(uint8_t, badd_u8, x + y); BINARY_OP(uint32_t, badd_u32, x + y); +BINARY_OP(int32_t, badd_i32, x + y); BINARY_OP(int64_t, badd_i64, x + y); BINARY_OP(float, bdiv_f32, x / y) BINARY_OP(double, bdiv_f64, x / y); BINARY_OP(uint8_t, bdiv_u8, x / y); BINARY_OP(uint32_t, bdiv_u32, x / y); +BINARY_OP(int32_t, bdiv_i32, x / y); BINARY_OP(int64_t, bdiv_i64, x / y); BINARY_OP(float, bmul_f32, x * y) BINARY_OP(double, bmul_f64, x * y); BINARY_OP(uint8_t, bmul_u8, x * y); BINARY_OP(uint32_t, bmul_u32, x * y); +BINARY_OP(int32_t, bmul_i32, x * y); BINARY_OP(int64_t, bmul_i64, x * y); BINARY_OP(float, bsub_f32, x - y) BINARY_OP(double, bsub_f64, x - y); BINARY_OP(uint8_t, bsub_u8, x - y); BINARY_OP(uint32_t, bsub_u32, x - y); +BINARY_OP(int32_t, bsub_i32, x - y); BINARY_OP(int64_t, bsub_i64, x - y); BINARY_OP(float, bminimum_f32, ming(x, y)); BINARY_OP(double, bminimum_f64, ming(x, y)); BINARY_OP(uint8_t, bminimum_u8, ming(x, y)); BINARY_OP(uint32_t, bminimum_u32, ming(x, y)); +BINARY_OP(int32_t, bminimum_i32, ming(x, y)); BINARY_OP(int64_t, bminimum_i64, ming(x, y)); BINARY_OP(float, bmaximum_f32, maxg(x, y)); BINARY_OP(double, bmaximum_f64, maxg(x, y)); BINARY_OP(uint8_t, bmaximum_u8, maxg(x, y)); BINARY_OP(uint32_t, bmaximum_u32, maxg(x, y)); +BINARY_OP(int32_t, bmaximum_i32, maxg(x, y)); BINARY_OP(int64_t, bmaximum_i64, maxg(x, y)); BINARY_OP_OUT(float, uint8_t, eq_f32, x == y) BINARY_OP_OUT(double, uint8_t, eq_f64, x == y) BINARY_OP_OUT(uint8_t, uint8_t, eq_u8, x == y) BINARY_OP_OUT(uint32_t, uint8_t, eq_u32, x == y) +BINARY_OP_OUT(int32_t, uint8_t, eq_i32, x == y) BINARY_OP_OUT(int64_t, uint8_t, eq_i64, x == y) BINARY_OP_OUT(float, uint8_t, ne_f32, x != y) BINARY_OP_OUT(double, uint8_t, ne_f64, x != y) BINARY_OP_OUT(uint8_t, uint8_t, ne_u8, x != y) BINARY_OP_OUT(uint32_t, uint8_t, ne_u32, x != y) +BINARY_OP_OUT(int32_t, uint8_t, ne_i32, x != y) BINARY_OP_OUT(int64_t, uint8_t, ne_i64, x != y) BINARY_OP_OUT(float, uint8_t, lt_f32, x < y) BINARY_OP_OUT(double, uint8_t, lt_f64, x < y) BINARY_OP_OUT(uint8_t, uint8_t, lt_u8, x < y) BINARY_OP_OUT(uint32_t, uint8_t, lt_u32, x < y) +BINARY_OP_OUT(int32_t, uint8_t, lt_i32, x < y) BINARY_OP_OUT(int64_t, uint8_t, lt_i64, x < y) BINARY_OP_OUT(float, uint8_t, le_f32, x <= y) BINARY_OP_OUT(double, uint8_t, le_f64, x <= y) BINARY_OP_OUT(uint8_t, uint8_t, le_u8, x <= y) BINARY_OP_OUT(uint32_t, uint8_t, le_u32, x <= y) +BINARY_OP_OUT(int32_t, uint8_t, le_i32, x <= y) BINARY_OP_OUT(int64_t, uint8_t, le_i64, x <= y) BINARY_OP_OUT(float, uint8_t, gt_f32, x > y) BINARY_OP_OUT(double, uint8_t, gt_f64, x > y) BINARY_OP_OUT(uint8_t, uint8_t, gt_u8, x > y) BINARY_OP_OUT(uint32_t, uint8_t, gt_u32, x > y) +BINARY_OP_OUT(int32_t, uint8_t, gt_i32, x > y) BINARY_OP_OUT(int64_t, uint8_t, gt_i64, x > y) BINARY_OP_OUT(float, uint8_t, ge_f32, x >= y) BINARY_OP_OUT(double, uint8_t, ge_f64, x >= y) BINARY_OP_OUT(uint8_t, uint8_t, ge_u8, x >= y) BINARY_OP_OUT(uint32_t, uint8_t, ge_u32, x >= y) +BINARY_OP_OUT(int32_t, uint8_t, ge_i32, x >= y) BINARY_OP_OUT(int64_t, uint8_t, ge_i64, x >= y) diff --git a/candle-kernels/src/cast.cu b/candle-kernels/src/cast.cu index 90f5e7ba48..f92ac0cbf9 100644 --- a/candle-kernels/src/cast.cu +++ b/candle-kernels/src/cast.cu @@ -83,6 +83,8 @@ CAST_OP(double, __nv_bfloat16, cast_f64_bf16) CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8) CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) #else #include #if CUDA_VERSION >= 11000 @@ -94,6 +96,8 @@ CAST_THROUGH_OP(__nv_bfloat16, double, float, cast_bf16_f64) CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16) CAST_THROUGH_OP(double, __nv_bfloat16, float, cast_f64_bf16) CAST_THROUGH_OP(uint8_t, __nv_bfloat16, float, cast_u8_bf16) +CAST_THROUGH_OP(int32_t, __nv_bfloat16, float, cast_i32_bf16) +CAST_THROUGH_OP(__nv_bfloat16, int32_t, float, cast_bf16_i32) #endif #endif @@ -108,34 +112,48 @@ CAST_OP(uint8_t, __half, cast_u8_f16 ) CAST_OP(uint32_t, __half, cast_u32_f16) CAST_OP(float, __half, cast_f32_f16) CAST_OP(double, __half, cast_f64_f16) +CAST_OP(int32_t, __half, cast_i32_f16 ) +CAST_THROUGH_OP(__half, int32_t, float, cast_f16_i32) #endif CAST_OP(uint32_t, uint32_t, cast_u32_u32) CAST_OP(uint32_t, uint8_t, cast_u32_u8 ) CAST_OP(uint32_t, int64_t, cast_u32_i64 ) +CAST_OP(uint32_t, int32_t, cast_u32_i32 ) CAST_OP(uint32_t, float, cast_u32_f32) CAST_OP(uint32_t, double, cast_u32_f64) CAST_OP(uint8_t, uint32_t, cast_u8_u32) CAST_OP(uint8_t, uint8_t, cast_u8_u8 ) +CAST_OP(uint8_t, int32_t, cast_u8_i32 ) CAST_OP(uint8_t, int64_t, cast_u8_i64 ) CAST_OP(uint8_t, float, cast_u8_f32) CAST_OP(uint8_t, double, cast_u8_f64) CAST_OP(int64_t, uint32_t, cast_i64_u32) CAST_OP(int64_t, uint8_t, cast_i64_u8 ) +CAST_OP(int64_t, int32_t, cast_i64_i32 ) CAST_OP(int64_t, int64_t, cast_i64_i64 ) CAST_OP(int64_t, float, cast_i64_f32) CAST_OP(int64_t, double, cast_i64_f64) +CAST_OP(int32_t, uint32_t, cast_i32_u32) +CAST_OP(int32_t, uint8_t, cast_i32_u8 ) +CAST_OP(int32_t, int64_t, cast_i32_i64 ) +CAST_OP(int32_t, int32_t, cast_i32_i32 ) +CAST_OP(int32_t, float, cast_i32_f32) +CAST_OP(int32_t, double, cast_i32_f64) + CAST_OP(float, uint8_t, cast_f32_u8 ) CAST_OP(float, uint32_t, cast_f32_u32) +CAST_OP(float, int32_t, cast_f32_i32 ) CAST_OP(float, int64_t, cast_f32_i64 ) CAST_OP(float, float, cast_f32_f32) CAST_OP(float, double, cast_f32_f64) CAST_OP(double, uint8_t, cast_f64_u8 ) CAST_OP(double, uint32_t, cast_f64_u32) +CAST_OP(double, int32_t, cast_f64_i32 ) CAST_OP(double, int64_t, cast_f64_i64 ) CAST_OP(double, float, cast_f64_f32) CAST_OP(double, double, cast_f64_f64) diff --git a/candle-kernels/src/cuda_utils.cuh b/candle-kernels/src/cuda_utils.cuh index 2673b8aaf1..3d581fbfdb 100644 --- a/candle-kernels/src/cuda_utils.cuh +++ b/candle-kernels/src/cuda_utils.cuh @@ -152,6 +152,8 @@ __device__ __forceinline__ double absg(double a) { return fabs(a); } __device__ __forceinline__ float copysigng(float a, float b) { return copysignf(a, b); } __device__ __forceinline__ double copysigng(double a, double b) { return copysign(a, b); } +__device__ __forceinline__ int32_t ming(int32_t a, int32_t b) { return min(a, b); } +__device__ __forceinline__ int32_t maxg(int32_t a, int32_t b) { return max(a, b); } __device__ __forceinline__ int64_t ming(int64_t a, int64_t b) { return min(a, b); } __device__ __forceinline__ int64_t maxg(int64_t a, int64_t b) { return max(a, b); } __device__ __forceinline__ uint32_t ming(uint32_t a, uint32_t b) { return min(a, b); } diff --git a/candle-kernels/src/fill.cu b/candle-kernels/src/fill.cu index ca448d989f..42bfddfd9f 100644 --- a/candle-kernels/src/fill.cu +++ b/candle-kernels/src/fill.cu @@ -9,6 +9,7 @@ __device__ void fill_with(T *buf, T value, const size_t numel) { } extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); } +extern "C" __global__ void fill_i32(int32_t *buf, int32_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); } extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); } @@ -34,6 +35,7 @@ COPY2D_OP(float, copy2d_f32) COPY2D_OP(double, copy2d_f64) COPY2D_OP(uint8_t, copy2d_u8) COPY2D_OP(uint32_t, copy2d_u32) +COPY2D_OP(int32_t, copy2d_i32) COPY2D_OP(int64_t, copy2d_i64) #if __CUDA_ARCH__ >= 530 diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu index 8af2954d13..2f3df4de1b 100644 --- a/candle-kernels/src/indexing.cu +++ b/candle-kernels/src/indexing.cu @@ -147,44 +147,61 @@ extern "C" __global__ void FN_NAME( \ #if __CUDA_ARCH__ >= 800 +IS_OP(__nv_bfloat16, int32_t, is_i32_bf16) IS_OP(__nv_bfloat16, int64_t, is_i64_bf16) IS_OP(__nv_bfloat16, uint32_t, is_u32_bf16) IS_OP(__nv_bfloat16, uint8_t, is_u8_bf16) +GATHER_OP(__nv_bfloat16, int32_t, gather_i32_bf16) GATHER_OP(__nv_bfloat16, int64_t, gather_i64_bf16) GATHER_OP(__nv_bfloat16, uint32_t, gather_u32_bf16) GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16) +IA_OP(__nv_bfloat16, int32_t, ia_i32_bf16) IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16) IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16) IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16) +SA_OP(__nv_bfloat16, int32_t, sa_i32_bf16) SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16) SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16) SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +IS_OP(__half, int32_t, is_i32_f16) IS_OP(__half, int64_t, is_i64_f16) IS_OP(__half, uint32_t, is_u32_f16) IS_OP(__half, uint8_t, is_u8_f16) +GATHER_OP(__half, int32_t, gather_i32_f16) GATHER_OP(__half, int64_t, gather_i64_f16) GATHER_OP(__half, uint32_t, gather_u32_f16) GATHER_OP(__half, uint8_t, gather_u8_f16) +IA_OP(__half, int32_t, ia_i32_f16) IA_OP(__half, int64_t, ia_i64_f16) IA_OP(__half, uint32_t, ia_u32_f16) IA_OP(__half, uint8_t, ia_u8_f16) +SA_OP(__half, int32_t, sa_i32_f16) SA_OP(__half, int64_t, sa_i64_f16) SA_OP(__half, uint32_t, sa_u32_f16) SA_OP(__half, uint8_t, sa_u8_f16) #endif +IS_OP(float, int32_t, is_i32_f32) +IS_OP(double, int32_t, is_i32_f64) +IS_OP(uint8_t, int32_t, is_i32_u8) +IS_OP(uint32_t, int32_t, is_i32_u32) +IS_OP(int32_t, int32_t, is_i32_i32) +IS_OP(int64_t, int32_t, is_i32_i64) + IS_OP(float, int64_t, is_i64_f32) IS_OP(double, int64_t, is_i64_f64) IS_OP(uint8_t, int64_t, is_i64_u8) IS_OP(uint32_t, int64_t, is_i64_u32) IS_OP(int64_t, int64_t, is_i64_i64) +IS_OP(int32_t, int64_t, is_i64_i32) IS_OP(float, uint32_t, is_u32_f32) IS_OP(double, uint32_t, is_u32_f64) IS_OP(uint8_t, uint32_t, is_u32_u8) +IS_OP(int32_t, uint32_t, is_u32_i32) IS_OP(int64_t, uint32_t, is_u32_i64) IS_OP(uint32_t, uint32_t, is_u32_u32) @@ -192,17 +209,27 @@ IS_OP(float, uint8_t, is_u8_f32) IS_OP(double, uint8_t, is_u8_f64) IS_OP(uint8_t, uint8_t, is_u8_u8) IS_OP(uint32_t, uint8_t, is_u8_u32) +IS_OP(int32_t, uint8_t, is_u8_i32) IS_OP(int64_t, uint8_t, is_u8_i64) +GATHER_OP(float, int32_t, gather_i32_f32) +GATHER_OP(double, int32_t, gather_i32_f64) +GATHER_OP(uint8_t, int32_t, gather_i32_u8) +GATHER_OP(uint32_t, int32_t, gather_i32_u32) +GATHER_OP(int32_t, int32_t, gather_i32_i32) +GATHER_OP(int64_t, int32_t, gather_i32_i64) + GATHER_OP(float, int64_t, gather_i64_f32) GATHER_OP(double, int64_t, gather_i64_f64) GATHER_OP(uint8_t, int64_t, gather_i64_u8) GATHER_OP(uint32_t, int64_t, gather_i64_u32) GATHER_OP(int64_t, int64_t, gather_i64_i64) +GATHER_OP(int32_t, int64_t, gather_i64_i32) GATHER_OP(float, uint32_t, gather_u32_f32) GATHER_OP(double, uint32_t, gather_u32_f64) GATHER_OP(uint8_t, uint32_t, gather_u32_u8) +GATHER_OP(int32_t, uint32_t, gather_u32_i32) GATHER_OP(int64_t, uint32_t, gather_u32_i64) GATHER_OP(uint32_t, uint32_t, gather_u32_u32) @@ -210,17 +237,26 @@ GATHER_OP(float, uint8_t, gather_u8_f32) GATHER_OP(double, uint8_t, gather_u8_f64) GATHER_OP(uint8_t, uint8_t, gather_u8_u8) GATHER_OP(uint32_t, uint8_t, gather_u8_u32) +GATHER_OP(int32_t, uint8_t, gather_u8_i32) GATHER_OP(int64_t, uint8_t, gather_u8_i64) +IA_OP(float, int32_t, ia_i32_f32) +IA_OP(double, int32_t, ia_i32_f64) +IA_OP(uint8_t, int32_t, ia_i32_u8) +IA_OP(int32_t, int32_t, ia_i32_i32) +IA_OP(uint32_t, int32_t, ia_i32_u32) + IA_OP(float, int64_t, ia_i64_f32) IA_OP(double, int64_t, ia_i64_f64) IA_OP(uint8_t, int64_t, ia_i64_u8) IA_OP(int64_t, int64_t, ia_i64_i64) IA_OP(uint32_t, int64_t, ia_i64_u32) +IA_OP(int32_t, int64_t, ia_i64_i32) IA_OP(float, uint32_t, ia_u32_f32) IA_OP(double, uint32_t, ia_u32_f64) IA_OP(uint8_t, uint32_t, ia_u32_u8) +IA_OP(int32_t, uint32_t, ia_u32_i32) IA_OP(int64_t, uint32_t, ia_u32_i64) IA_OP(uint32_t, uint32_t, ia_u32_u32) @@ -228,17 +264,26 @@ IA_OP(float, uint8_t, ia_u8_f32) IA_OP(double, uint8_t, ia_u8_f64) IA_OP(uint8_t, uint8_t, ia_u8_u8) IA_OP(uint32_t, uint8_t, ia_u8_u32) +IA_OP(int32_t, uint8_t, ia_u8_i32) IA_OP(int64_t, uint8_t, ia_u8_i64) +SA_OP(float, int32_t, sa_i32_f32) +SA_OP(double, int32_t, sa_i32_f64) +SA_OP(uint8_t, int32_t, sa_i32_u8) +SA_OP(int32_t, int32_t, sa_i32_i32) +SA_OP(uint32_t, int32_t, sa_i32_u32) + SA_OP(float, int64_t, sa_i64_f32) SA_OP(double, int64_t, sa_i64_f64) SA_OP(uint8_t, int64_t, sa_i64_u8) +SA_OP(int32_t, int64_t, sa_i64_i32) SA_OP(int64_t, int64_t, sa_i64_i64) SA_OP(uint32_t, int64_t, sa_i64_u32) SA_OP(float, uint32_t, sa_u32_f32) SA_OP(double, uint32_t, sa_u32_f64) SA_OP(uint8_t, uint32_t, sa_u32_u8) +SA_OP(int32_t, uint32_t, sa_u32_i32) SA_OP(int64_t, uint32_t, sa_u32_i64) SA_OP(uint32_t, uint32_t, sa_u32_u32) @@ -246,4 +291,5 @@ SA_OP(float, uint8_t, sa_u8_f32) SA_OP(double, uint8_t, sa_u8_f64) SA_OP(uint8_t, uint8_t, sa_u8_u8) SA_OP(uint32_t, uint8_t, sa_u8_u32) +SA_OP(int32_t, uint8_t, sa_u8_i32) SA_OP(int64_t, uint8_t, sa_u8_i64) diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu index aaac24a146..9a1354a8dc 100644 --- a/candle-kernels/src/reduce.cu +++ b/candle-kernels/src/reduce.cu @@ -606,5 +606,6 @@ ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64) FAST_OP(float, fast_min_f32, fast_max_f32, fast_argmin_f32, fast_argmax_f32, fast_sum_f32) FAST_OP(double, fast_min_f64, fast_max_f64, fast_argmin_f64, fast_argmax_f64, fast_sum_f64) FAST_OP(uint32_t, fast_min_u32, fast_max_u32, fast_argmin_u32, fast_argmax_u32, fast_sum_u32) +FAST_OP(int32_t, fast_min_i32, fast_max_i32, fast_argmin_i32, fast_argmax_i32, fast_sum_i32) FAST_OP(int64_t, fast_min_i64, fast_max_i64, fast_argmin_i64, fast_argmax_i64, fast_sum_i64) FAST_OP(uint8_t, fast_min_u8, fast_max_u8, fast_argmin_u8, fast_argmax_u8, fast_sum_u8) diff --git a/candle-kernels/src/sort.cu b/candle-kernels/src/sort.cu index 08f1f9fc29..7fecf8413e 100644 --- a/candle-kernels/src/sort.cu +++ b/candle-kernels/src/sort.cu @@ -85,4 +85,5 @@ ASORT_OP(float, f32) ASORT_OP(double, f64) ASORT_OP(uint8_t, u8) ASORT_OP(uint32_t, u32) +ASORT_OP(int32_t, i32) ASORT_OP(int64_t, i64) diff --git a/candle-kernels/src/ternary.cu b/candle-kernels/src/ternary.cu index aaa8a881fb..4617c08fbe 100644 --- a/candle-kernels/src/ternary.cu +++ b/candle-kernels/src/ternary.cu @@ -33,17 +33,25 @@ extern "C" __global__ void FN_NAME( \ } \ #if __CUDA_ARCH__ >= 800 +WHERE_OP(__nv_bfloat16, int32_t, where_i32_bf16) WHERE_OP(__nv_bfloat16, int64_t, where_i64_bf16) WHERE_OP(__nv_bfloat16, uint32_t, where_u32_bf16) WHERE_OP(__nv_bfloat16, uint8_t, where_u8_bf16) #endif #if __CUDA_ARCH__ >= 530 +WHERE_OP(__half, int32_t, where_i32_f16) WHERE_OP(__half, int64_t, where_i64_f16) WHERE_OP(__half, uint32_t, where_u32_f16) WHERE_OP(__half, uint8_t, where_u8_f16) #endif +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(double, int32_t, where_i32_f64) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t, where_i32_u32) +WHERE_OP(int32_t, int32_t, where_i32_i64) + WHERE_OP(float, int64_t, where_i64_f32) WHERE_OP(double, int64_t, where_i64_f64) WHERE_OP(uint8_t, int64_t, where_i64_u8) @@ -54,10 +62,12 @@ WHERE_OP(float, uint32_t, where_u32_f32) WHERE_OP(double, uint32_t, where_u32_f64) WHERE_OP(uint8_t, uint32_t, where_u32_u8) WHERE_OP(uint32_t, uint32_t, where_u32_u32) +WHERE_OP(int32_t, uint32_t, where_u32_i32) WHERE_OP(int64_t, uint32_t, where_u32_i64) WHERE_OP(float, uint8_t, where_u8_f32) WHERE_OP(double, uint8_t, where_u8_f64) WHERE_OP(uint8_t, uint8_t, where_u8_u8) WHERE_OP(uint32_t, uint8_t, where_u8_u32) +WHERE_OP(int32_t, uint8_t, where_u8_i32) WHERE_OP(int64_t, uint8_t, where_u8_i64) diff --git a/candle-kernels/src/unary.cu b/candle-kernels/src/unary.cu index c82a88375d..21d3d995c0 100644 --- a/candle-kernels/src/unary.cu +++ b/candle-kernels/src/unary.cu @@ -153,6 +153,7 @@ UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x)) UNARY_OP(uint8_t, ucopy_u8, x) UNARY_OP(uint32_t, ucopy_u32, x) +UNARY_OP(int32_t, ucopy_i32, x) UNARY_OP(int64_t, ucopy_i64, x) UNARY_OP(float, ucopy_f32, x) UNARY_OP(double, ucopy_f64, x) diff --git a/candle-metal-kernels/src/binary.metal b/candle-metal-kernels/src/binary.metal index e83498e40d..a9b8129c3a 100644 --- a/candle-metal-kernels/src/binary.metal +++ b/candle-metal-kernels/src/binary.metal @@ -58,13 +58,15 @@ kernel void FN_NAME_STRIDED( \ BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int32_t, int32_t, NAME##_i32, NAME##_i32_strided); #define BINARY_OP_OUT(NAME, FN) \ BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); +BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); \ +BINARY(FN, int32_t, uint8_t, NAME##_i32, NAME##_i32_strided); #define INT64_BINARY_OP(NAME, FN) \ BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); diff --git a/candle-metal-kernels/src/cast.metal b/candle-metal-kernels/src/cast.metal index 2af3fdceb0..c8122ccf0a 100644 --- a/candle-metal-kernels/src/cast.metal +++ b/candle-metal-kernels/src/cast.metal @@ -76,6 +76,7 @@ kernel void FN_NAME_STRIDED( \ CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float) CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t) CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half) +CAST(cast_u32_i32, cast_u32_i32_strided, uint32_t, int32_t) #if __METAL_VERSION__ >= 220 CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t) #endif @@ -87,6 +88,7 @@ CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat) CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t) CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float) CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half) +CAST(cast_u8_i32, cast_u8_i32_strided, uint8_t, int64_t) #if __METAL_VERSION__ >= 220 CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t) #endif @@ -98,6 +100,7 @@ CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat) CAST(cast_f16_f32, cast_f16_f32_strided, half, float) CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t) CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t) +CAST(cast_f16_i32, cast_f16_i32_strided, half, int64_t) CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) @@ -107,15 +110,27 @@ CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float) CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float) CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t) CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t) +CAST(cast_i64_i32, cast_i64_i32_strided, int64_t, int32_t) CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half) #if defined(__HAVE_BFLOAT__) CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float) #endif +// i32 +CAST(cast_i32_f32, cast_i32_f32_strided, int32_t, float) +CAST(cast_i32_u8, cast_i32_u8_strided, int32_t, uint8_t) +CAST(cast_i32_u32, cast_i32_u32_strided, int32_t, uint32_t) +CAST(cast_i32_i64, cast_i32_i64_strided, int32_t, int64_t) +CAST(cast_i32_f16, cast_i32_f16_strided, int32_t, half) +#if defined(__HAVE_BFLOAT__) +CAST_THROUGH(cast_i32_bf16, cast_i32_bf16_strided, int64_t, bfloat, float) +#endif + // f32 CAST(cast_f32_f16, cast_f32_f16_strided, float, half) CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t) CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t) +CAST(cast_f32_i32, cast_f32_i32_strided, float, int32_t) CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t) #if defined(__HAVE_BFLOAT__) CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) @@ -124,6 +139,7 @@ CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat) // bf16 #if defined(__HAVE_BFLOAT__) CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t) +CAST(cast_bf16_i32, cast_bf16_i32_strided, bfloat, int32_t) CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index 9eee97ca0a..eaa78d7b73 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -193,6 +193,12 @@ INDEX_OP(is_i64_f16, int64_t, half) INDEX_OP(is_i64_bf16, int64_t, bfloat) #endif +INDEX_OP(is_i32_f32, int32_t, float) +INDEX_OP(is_i32_f16, int32_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i32_bf16, int32_t, bfloat) +#endif + INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) @@ -213,9 +219,11 @@ GATHER_OP(gather_u32_bf16, uint, bfloat) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float) +SCATTER_ADD_OP(sa_i32_f32, int32_t, float) SCATTER_ADD_OP(sa_i64_f32, int64_t, float) SCATTER_ADD_OP(sa_u32_f16, uint32_t, half) SCATTER_ADD_OP(sa_u8_f16, uint8_t, half) +SCATTER_ADD_OP(sa_i32_f16, int32_t, half) SCATTER_ADD_OP(sa_i64_f16, int64_t, half) #if defined(__HAVE_BFLOAT__) SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat) @@ -226,6 +234,7 @@ SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat) // i64 INDEX_ADD_OP(ia_i64_f16, int64_t, half) INDEX_ADD_OP(ia_i64_f32, int64_t, float) +INDEX_ADD_OP(ia_i64_i32, int64_t, int32_t) INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t) INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t) INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) @@ -233,9 +242,21 @@ INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t) INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat) #endif +// i64 +INDEX_ADD_OP(ia_i32_f16, int32_t, half) +INDEX_ADD_OP(ia_i32_f32, int32_t, float) +INDEX_ADD_OP(ia_i32_i64, int32_t, int64_t) +INDEX_ADD_OP(ia_i32_i32, int32_t, int32_t) +INDEX_ADD_OP(ia_i32_u32, int32_t, uint32_t) +INDEX_ADD_OP(ia_i32_u8, int32_t, uint8_t) +#if defined(__HAVE_BFLOAT__) +INDEX_ADD_OP(ia_i32_bf16, int32_t, bfloat) +#endif + // u32 INDEX_ADD_OP(ia_u32_f16, uint32_t, half) INDEX_ADD_OP(ia_u32_f32, uint32_t, float) +INDEX_ADD_OP(ia_u32_i32, uint32_t, int32_t) INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t) INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t) INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t) @@ -246,6 +267,7 @@ INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat) // u8 INDEX_ADD_OP(ia_u8_f16, uint8_t, half) INDEX_ADD_OP(ia_u8_f32, uint8_t, float) +INDEX_ADD_OP(ia_u8_i32, uint8_t, int32_t) INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t) INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 743b9fe2b3..7a2a54b608 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -46,6 +46,7 @@ pub mod copy2d { pub const HALF: Kernel = Kernel("copy2d_f16"); pub const BFLOAT: Kernel = Kernel("copy2d_bf16"); pub const I64: Kernel = Kernel("copy2d_i64"); + pub const I32: Kernel = Kernel("copy2d_i32"); pub const U32: Kernel = Kernel("copy2d_u32"); pub const U8: Kernel = Kernel("copy2d_u8"); } @@ -62,6 +63,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8")); } @@ -72,6 +74,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16"); pub const BFLOAT: Kernel = Kernel("copy_bf16"); pub const I64: Kernel = Kernel("copy_i64"); + pub const I32: Kernel = Kernel("copy_i32"); pub const U32: Kernel = Kernel("copy_u32"); pub const U8: Kernel = Kernel("copy_u8"); } @@ -86,6 +89,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_tiled")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); } @@ -96,6 +100,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16_tiled"); pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); pub const I64: Kernel = Kernel("copy_i64_tiled"); + pub const I32: Kernel = Kernel("copy_i32_tiled"); pub const U32: Kernel = Kernel("copy_u32_tiled"); pub const U8: Kernel = Kernel("copy_u8_tiled"); } @@ -110,6 +115,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_strided")); pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_strided")); pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_strided")); + pub const I32: Kernel = Kernel(concat!(stringify!($name), "_i32_strided")); pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_strided")); pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_strided")); } @@ -120,6 +126,7 @@ macro_rules! ops{ pub const HALF: Kernel = Kernel("copy_f16_strided"); pub const BFLOAT: Kernel = Kernel("copy_bf16_strided"); pub const I64: Kernel = Kernel("copy_i64_strided"); + pub const I32: Kernel = Kernel("copy_i32_strided"); pub const U32: Kernel = Kernel("copy_u32_strided"); pub const U8: Kernel = Kernel("copy_u8_strided"); } diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal index e009ca1d6a..484fa0a1b1 100644 --- a/candle-metal-kernels/src/reduce.metal +++ b/candle-metal-kernels/src/reduce.metal @@ -602,6 +602,12 @@ ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX) ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN) #endif +REDUCE(x + y, fast_sum_i32_strided, int32_t, 0) +REDUCE(MIN(x, y), fast_min_i32_strided, int32_t, INT_MAX) +REDUCE(MAX(x, y), fast_max_i32_strided, int32_t, INT_MIN) +ARGMIN(fast_argmin_i32_strided, int32_t, INT_MAX) +ARGMAX(fast_argmax_i32_strided, int32_t, INT_MIN) + #if defined(__HAVE_BFLOAT__) REDUCE(x + y, fast_sum_bf16, bfloat, 0) REDUCE(x + y, fast_sum_bf16_strided, half, 0) diff --git a/candle-metal-kernels/src/sort.metal b/candle-metal-kernels/src/sort.metal index d71ab82234..b7cf71bb58 100644 --- a/candle-metal-kernels/src/sort.metal +++ b/candle-metal-kernels/src/sort.metal @@ -88,6 +88,7 @@ ARGSORT(float, f32) ARGSORT(half, f16) ARGSORT(uint8_t, u8) ARGSORT(uint32_t, u32) +ARGSORT(int32_t, i32) #if __METAL_VERSION__ >= 220 ARGSORT(int64_t, i64) diff --git a/candle-metal-kernels/src/ternary.metal b/candle-metal-kernels/src/ternary.metal index fe04f2378f..0e043332fe 100644 --- a/candle-metal-kernels/src/ternary.metal +++ b/candle-metal-kernels/src/ternary.metal @@ -75,11 +75,25 @@ WHERE_OP(float, int64_t, where_i64_f32) WHERE_OP(uint8_t, int64_t, where_i64_u8) WHERE_OP(uint32_t, int64_t, where_i64_u32) WHERE_OP(int64_t, int64_t, where_i64_i64) +WHERE_OP(int64_t, int32_t, where_i64_i32) #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, int64_t, where_i64_bf16) #endif #endif +WHERE_OP(int64_t, uint8_t, where_u8_i32) +WHERE_OP(int64_t, uint32_t, where_u32_i32) + +WHERE_OP(half, int32_t, where_i32_f16) +WHERE_OP(float, int32_t, where_i32_f32) +WHERE_OP(uint8_t, int32_t, where_i32_u8) +WHERE_OP(uint32_t, int32_t, where_i32_u32) +WHERE_OP(int64_t, int32_t, where_i32_i64) +WHERE_OP(int32_t, int32_t, where_i32_i32) +#if defined(__HAVE_BFLOAT__) +WHERE_OP(bfloat, int32_t, where_i32_bf16) +#endif + #if defined(__HAVE_BFLOAT__) WHERE_OP(bfloat, uint8_t, where_u8_bf16) WHERE_OP(bfloat, uint32_t, where_u32_bf16) diff --git a/candle-metal-kernels/src/unary.metal b/candle-metal-kernels/src/unary.metal index a82bfdbdd6..0c5a2736ee 100644 --- a/candle-metal-kernels/src/unary.metal +++ b/candle-metal-kernels/src/unary.metal @@ -169,6 +169,9 @@ UNARY(id, int64_t, copy_i64, copy_i64_strided) COPY2D(copy2d_i64, int64_t) #endif +UNARY(id, int32_t, copy_i32, copy_i32_strided) +COPY2D(copy2d_i32, int32_t) + #if defined(__HAVE_BFLOAT__) BFLOAT_UNARY_OP(cos) BFLOAT_UNARY_OP(sin) diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 0da2c70028..55b5542ed8 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -151,6 +151,7 @@ macro_rules! pydtype { }; } +pydtype!(i32, |v| v); pydtype!(i64, |v| v); pydtype!(u8, |v| v); pydtype!(u32, |v| v); @@ -200,6 +201,7 @@ trait MapDType { match t.dtype() { DType::U8 => self.f::(t), DType::U32 => self.f::(t), + DType::I32 => self.f::(t), DType::I64 => self.f::(t), DType::BF16 => self.f::(t), DType::F16 => self.f::(t),