Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Tensor::unfold, ops::topk_last_dim #2375

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2454,6 +2454,49 @@ impl Tensor {
/// Broadcasting element-wise power: computes `self^rhs` as `exp(rhs * ln(self))`.
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
    let log_base = self.log()?;
    let scaled = rhs.broadcast_mul(&log_base)?;
    scaled.exp()
}

/// Returns a view that contains all slices of size `size` from this tensor along
/// the dimension `dim`, advancing by `step` elements between consecutive slices.
///
/// An extra trailing dimension of size `size` is appended to the result, and
/// dimension `dim` shrinks to the number of windows, `(dims[dim] - size) / step + 1`.
/// The returned tensor shares the underlying storage (it is a strided view).
pub fn unfold<D: Dim>(&self, dim: D, size: usize, step: usize) -> Result<Self> {
    // https://github.com/pytorch/pytorch/blob/75b0720a97ac5d82e8a7a1a6ae7c5f7a87d7183d/aten/src/ATen/native/TensorShape.cpp#L3785-L3804
    let mut sizes = self.dims().to_vec();
    let mut strides = self.stride().to_vec();

    let dim = dim.to_index(self.shape(), "unfold")?;

    if step == 0 {
        // The f32 formula below would otherwise divide by zero.
        bail!("unfold: step is 0 but must be greater than zero")
    }
    // A 0-dim (scalar) tensor is treated as having a single dimension of size 1.
    let max_len = if self.dims().is_empty() {
        1
    } else {
        sizes[dim]
    };
    if size > max_len {
        bail!("unfold: maximum size for tensor at dimension {dim} is {max_len} but size is {size}")
    }
    sizes.push(size);
    strides.push(if self.dims().is_empty() {
        1
    } else {
        strides[dim]
    });

    if !self.dims().is_empty() {
        // Number of windows along `dim`. Integer arithmetic is exact here
        // because `size <= sizes[dim]` was checked above and `step > 0`.
        sizes[dim] = (sizes[dim] - size) / step + 1;
        strides[dim] *= step;
    }

    // Build the strided view directly; no data is copied.
    // NOTE(review): Op::Reshape is reused for the backprop op — confirm this is
    // the intended gradient behavior for an overlapping-window view.
    let tensor_ = Tensor_ {
        id: TensorId::new(),
        storage: self.storage.clone(),
        layout: Layout::new(sizes.into(), strides, self.layout.start_offset()),
        op: BackpropOp::new1(self, Op::Reshape),
        is_variable: false,
        dtype: self.dtype,
        device: self.device.clone(),
    };
    Ok(Tensor(Arc::new(tensor_)))
}
}

macro_rules! bin_trait {
Expand Down
12 changes: 12 additions & 0 deletions candle-core/tests/tensor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1345,3 +1345,15 @@ fn pow() -> Result<()> {
);
Ok(())
}

#[test]
fn unfold() -> Result<()> {
    // 3x2 input: [[0, 1], [2, 3], [4, 5]].
    let x = Tensor::arange(0i64, 3 * 2, &Device::Cpu)?.reshape((3, 2))?;
    // Windows of size 2 with step 1 along dim 0 -> shape (2, 2, 2).
    let unfolded = x.unfold(0, 2, 1)?;
    assert_eq!(
        unfolded.to_vec3::<i64>()?,
        vec![[[0i64, 2], [1, 3]], [[2, 4], [3, 5]]]
    );
    Ok(())
}
16 changes: 16 additions & 0 deletions candle-nn/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -947,3 +947,19 @@ impl Module for Identity {
Ok(xs.clone())
}
}

/// Values and indices produced by [`topk_last_dim`].
#[derive(Debug, Clone)]
pub struct TopKOutput {
    /// The top-k values along the last dimension, in descending order.
    pub values: Tensor,
    /// The positions of those values within the last dimension of the input.
    pub indices: Tensor,
}

/// Returns the `topk` largest values and their indices along the last dimension.
pub fn topk_last_dim(xs: &Tensor, topk: usize) -> Result<TopKOutput> {
    // arg_sort_last_dim(false) sorts descending, so the first `topk` index
    // entries point at the largest values.
    let indices = xs
        .arg_sort_last_dim(false)?
        .narrow(D::Minus1, 0, topk)?
        .contiguous()?;
    let values = xs.gather(&indices, D::Minus1)?;
    Ok(TopKOutput { values, indices })
}
25 changes: 25 additions & 0 deletions candle-nn/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use candle::{test_device, test_utils::to_vec3_round, Device, Result, Tensor};
use candle_nn::ops::TopKOutput;

fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
Expand Down Expand Up @@ -206,10 +207,34 @@ fn sigmoid(device: &Device) -> Result<()> {
Ok(())
}

fn topk(device: &Device) -> Result<()> {
    // Input laid out as:
    // [[1, 3, 5],
    //  [2, 4, 6]]
    let input = Tensor::arange(1f32, 7f32, device)?
        .reshape((3, 2))?
        .t()?
        .contiguous()?;
    let TopKOutput { values, indices } = candle_nn::ops::topk_last_dim(&input, 2)?;
    // The operation must leave its input untouched.
    assert_eq!(
        input.to_vec2::<f32>()?,
        vec![vec![1f32, 3f32, 5f32], vec![2f32, 4f32, 6f32]]
    );
    // Top-2 values of each row, in descending order.
    assert_eq!(
        values.to_vec2::<f32>()?,
        vec![vec![5f32, 3f32], vec![6f32, 4f32]]
    );
    // Positions of those values within the last dimension.
    assert_eq!(
        indices.to_vec2::<u32>()?,
        vec![vec![2u32, 1u32], vec![2u32, 1u32]]
    );
    Ok(())
}

test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
test_device!(topk, topk_cpu, topk_gpu, topk_metal);
Loading