Skip to content

Commit

Permalink
Add documentation examples for Tensor::i and Tensor::narrow metho…
Browse files Browse the repository at this point in the history
…ds (#2308)

* Add documentation examples for `Tensor` methods

* Apply fmt.

* Cosmetic tweaks.

---------

Co-authored-by: Laurent <[email protected]>
  • Loading branch information
csicar and LaurentMazare authored Aug 10, 2024
1 parent 14db029 commit d3fe989
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 8 deletions.
102 changes: 94 additions & 8 deletions candle-core/src/indexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,28 +141,114 @@ impl<T> IndexOp<T> for Tensor
where
T: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0., 1.],
/// [2., 3.],
/// [4., 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i(0)?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f64>()?, &[0., 1.]);
///
/// let c = a.i(..2)?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f64>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i(1..)?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f64>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
fn i(&self, index: T) -> Result<Tensor, Error> {
self.index(&[index.into()])
}
}

impl<A> IndexOp<(A,)> for Tensor
where
A: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[
/// [0f32, 1.],
/// [2. , 3.],
/// [4. , 5.]
/// ], &Device::Cpu)?;
///
/// let b = a.i((0,))?;
/// assert_eq!(b.shape().dims(), &[2]);
/// assert_eq!(b.to_vec1::<f32>()?, &[0., 1.]);
///
/// let c = a.i((..2,))?;
/// assert_eq!(c.shape().dims(), &[2, 2]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [0., 1.],
/// [2., 3.]
/// ]);
///
/// let d = a.i((1..,))?;
/// assert_eq!(d.shape().dims(), &[2, 2]);
/// assert_eq!(d.to_vec2::<f32>()?, &[
/// [2., 3.],
/// [4., 5.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
fn i(&self, (a,): (A,)) -> Result<Tensor, Error> {
self.index(&[a.into()])
}
}
#[allow(non_snake_case)]
impl<A, B> IndexOp<(A, B)> for Tensor
where
A: Into<TensorIndexer>,
B: Into<TensorIndexer>,
{
///```rust
/// use candle_core::{Tensor, DType, Device, IndexOp};
/// let a = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.], [6., 7., 8.]], &Device::Cpu)?;
///
/// let b = a.i((1, 0))?;
/// assert_eq!(b.to_vec0::<f32>()?, 3.);
///
/// let c = a.i((..2, 1))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
///
/// let d = a.i((2.., ..))?;
/// assert_eq!(c.shape().dims(), &[2]);
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
/// # Ok::<(), candle_core::Error>(())
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
self.index(&[a.into(), b.into()])
}
}

macro_rules! index_op_tuple {
($($t:ident),+) => {
($doc:tt, $($t:ident),+) => {
#[allow(non_snake_case)]
impl<$($t),*> IndexOp<($($t,)*)> for Tensor
where
$($t: Into<TensorIndexer>,)*
{
#[doc=$doc]
fn i(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor, Error> {
self.index(&[$($t.into(),)*])
}
}
};
}
index_op_tuple!(A);
index_op_tuple!(A, B);
index_op_tuple!(A, B, C);
index_op_tuple!(A, B, C, D);
index_op_tuple!(A, B, C, D, E);
index_op_tuple!(A, B, C, D, E, F);
index_op_tuple!(A, B, C, D, E, F, G);

index_op_tuple!("see [TensorIndex#method.i]", A, B, C);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F);
index_op_tuple!("see [TensorIndex#method.i]", A, B, C, D, E, F, G);
75 changes: 75 additions & 0 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,15 @@ impl Tensor {

/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::full(3.5, (2, 4), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [3.5, 3.5, 3.5, 3.5],
/// [3.5, 3.5, 3.5, 3.5],
/// ]);
/// # Ok::<(), candle_core::Error>(())
pub fn full<D: crate::WithDType, S: Into<Shape>>(
value: D,
shape: S,
Expand All @@ -379,6 +388,13 @@ impl Tensor {
}

/// Creates a new 1D tensor from an iterator.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_iter( [1.0, 2.0, 3.0, 4.0].into_iter(), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[1.0, 2.0, 3.0, 4.0]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
device: &Device,
Expand All @@ -390,12 +406,26 @@ impl Tensor {

/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `1` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange(2., 5., &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2., 3., 4.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange<D: crate::WithDType>(start: D, end: D, device: &Device) -> Result<Self> {
Self::arange_step(start, end, D::one(), device)
}

/// Creates a new 1D tensor with values from the interval `[start, end)` taken with a common
/// difference `step` from `start`.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::arange_step(2.0, 4.0, 0.5, &Device::Cpu)?;
///
/// assert_eq!(a.to_vec1::<f64>()?, &[2.0, 2.5, 3.0, 3.5]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn arange_step<D: crate::WithDType>(
start: D,
end: D,
Expand Down Expand Up @@ -441,6 +471,16 @@ impl Tensor {
/// Creates a new tensor initialized with values from the input vector. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
/// If the device is cpu, no data copy is made.
///```rust
/// use candle_core::{Tensor, Device};
/// let a = Tensor::from_vec(vec!{1., 2., 3., 4., 5., 6.}, (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [1., 2., 3.],
/// [4., 5., 6.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
data: Vec<D>,
shape: S,
Expand All @@ -451,6 +491,17 @@ impl Tensor {

/// Creates a new tensor initialized with values from the input slice. The number of elements
/// in this vector must be the same as the number of elements defined by the shape.
///```rust
/// use candle_core::{Tensor, Device};
/// let values = vec![1., 2., 3., 4., 5., 6., 7., 8.];
/// let a = Tensor::from_slice(&values[1..7], (2, 3), &Device::Cpu)?;
///
/// assert_eq!(a.to_vec2::<f64>()?, &[
/// [2., 3., 4.],
/// [5., 6., 7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
array: &[D],
shape: S,
Expand Down Expand Up @@ -732,6 +783,30 @@ impl Tensor {

/// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
/// ranges from `start` to `start + len`.
/// ```
/// use candle_core::{Tensor, Device};
/// let a = Tensor::new(&[
/// [0f32, 1., 2.],
/// [3. , 4., 5.],
/// [6. , 7., 8.]
/// ], &Device::Cpu)?;
///
/// let b = a.narrow(0, 1, 2)?;
/// assert_eq!(b.shape().dims(), &[2, 3]);
/// assert_eq!(b.to_vec2::<f32>()?, &[
/// [3., 4., 5.],
/// [6., 7., 8.]
/// ]);
///
/// let c = a.narrow(1, 1, 1)?;
/// assert_eq!(c.shape().dims(), &[3, 1]);
/// assert_eq!(c.to_vec2::<f32>()?, &[
/// [1.],
/// [4.],
/// [7.]
/// ]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
Expand Down

0 comments on commit d3fe989

Please sign in to comment.