diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 82532f204..03e61bde4 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2454,6 +2454,49 @@ impl Tensor {
     pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
         rhs.broadcast_mul(&self.log()?)?.exp()
     }
+
+    /// Returns a view which contains all slices of size `size` from the tensor in the
+    /// dimension `dim`, stepped by `step`.
+    pub fn unfold<D: Dim>(&self, dim: D, size: usize, step: usize) -> Result<Self> {
+        // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804
+        let mut sizes = self.dims().to_vec();
+        let mut strides = self.stride().to_vec();
+
+        let dim = dim.to_index(self.shape(), "unfold")?;
+
+        let max_len = if self.dims().is_empty() {
+            1
+        } else {
+            sizes[dim]
+        };
+        if size > max_len {
+            bail!(
+                "unfold: maximum size for tensor at dimension {dim} is {max_len} but size is {size}"
+            )
+        }
+        sizes.push(size);
+        strides.push(if self.dims().is_empty() {
+            1
+        } else {
+            strides[dim]
+        });
+
+        if !self.dims().is_empty() {
+            sizes[dim] = ((sizes[dim] as f32 - size as f32) / step as f32 + 1.) as usize;
+            strides[dim] *= step;
+        }
+
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            layout: Layout::new(sizes.into(), strides, self.layout.start_offset()),
+            op: BackpropOp::new1(self, Op::Reshape),
+            is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
 }
 
 macro_rules! bin_trait {
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index cd5f4ca14..975f40ac9 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1345,3 +1345,15 @@ fn pow() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn unfold() -> Result<()> {
+    let x = Tensor::arange(0i64, 3 * 2, &Device::Cpu)?.reshape((3, 2))?;
+    let unfolded = x.unfold(0, 2, 1)?;
+    dbg!(&unfolded);
+    assert_eq!(
+        unfolded.to_vec3::<i64>()?,
+        vec![[[0i64, 2], [1, 3]], [[2, 4], [3, 5]]]
+    );
+    Ok(())
+}
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 9a360c472..a6ba04d03 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -947,3 +947,19 @@ impl Module for Identity {
         Ok(xs.clone())
     }
 }
+
+pub struct TopKOutput {
+    pub values: Tensor,
+    pub indices: Tensor,
+}
+
+/// Top-K in the last dimension.
+pub fn topk_last_dim(xs: &Tensor, topk: usize) -> Result<TopKOutput> {
+    // Sorted descending
+    let sorted_indices = xs.arg_sort_last_dim(false)?;
+    let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
+    Ok(TopKOutput {
+        values: xs.gather(&topk_indices, D::Minus1)?,
+        indices: topk_indices,
+    })
+}
diff --git a/candle-nn/tests/ops.rs b/candle-nn/tests/ops.rs
index 65a8fbf28..a77794ba2 100644
--- a/candle-nn/tests/ops.rs
+++ b/candle-nn/tests/ops.rs
@@ -5,6 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
+use candle_nn::ops::TopKOutput;
 
 fn softmax(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
@@ -206,6 +207,29 @@ fn sigmoid(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn topk(device: &Device) -> Result<()> {
+    // [[1, 3, 5],
+    //  [2, 4, 6]]
+    let x = Tensor::arange(1f32, 7f32, device)?
+        .reshape((3, 2))?
+        .t()?
+        .contiguous()?;
+    let TopKOutput { values, indices } = candle_nn::ops::topk_last_dim(&x, 2)?;
+    assert_eq!(
+        x.to_vec2::<f32>()?,
+        vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
+    );
+    assert_eq!(
+        values.to_vec2::<f32>()?,
+        vec![vec![5f32, 3f32], vec![6f32, 4f32]]
+    );
+    assert_eq!(
+        indices.to_vec2::<u32>()?,
+        vec![vec![2u32, 1u32], vec![2u32, 1u32]]
+    );
+    Ok(())
+}
+
 test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
 test_device!(rope, rope_cpu, rope_gpu, rope_metal);
 test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
@@ -213,3 +237,4 @@ test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
 test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
 test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
 test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
+test_device!(topk, topk_cpu, topk_gpu, topk_metal);
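
Usage sketch (not from the patch itself): a minimal example combining the two new APIs, sliding windows via `unfold` followed by a per-window top-k via `topk_last_dim`. The shapes, values, and the `main` wrapper are illustrative assumptions; it presumes the diff above is applied on top of the `candle` and `candle-nn` crates.

use candle::{Device, Result, Tensor};
use candle_nn::ops::{topk_last_dim, TopKOutput};

fn main() -> Result<()> {
    // 1D signal [0., 1., 2., 3., 4., 5.]; unfold(0, 3, 1) views it as the
    // four sliding windows of length 3, i.e. shape (4, 3), without copying.
    let signal = Tensor::arange(0f32, 6f32, &Device::Cpu)?;
    let windows = signal.unfold(0, 3, 1)?;
    assert_eq!(windows.dims(), &[4, 3]);

    // Top-2 per window, values sorted descending along with their indices.
    // `contiguous` materializes the strided view first, mirroring the topk test.
    let TopKOutput { values, indices } = topk_last_dim(&windows.contiguous()?, 2)?;
    assert_eq!(values.to_vec2::<f32>()?[0], vec![2f32, 1f32]);
    println!("{values}\n{indices}");
    Ok(())
}

The `.contiguous()?` call matters because `unfold` only rewrites sizes and strides over the same storage, and `topk_last_dim` is built on `arg_sort_last_dim`, which expects a contiguous layout.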