Skip to content

Commit

Permalink
Expose candle gather op in pyo3.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Mar 18, 2024
1 parent 5860525 commit e368a86
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions candle-pyo3/py_src/candle/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,12 @@ class Tensor:
"""
pass

def gather(self, index, dim):
"""
Gathers values along an axis specified by dim.
"""
pass

def get(self, index: int) -> Tensor:
"""
Gets the value at the specified index.
Expand Down
6 changes: 6 additions & 0 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ impl PyTensor {
Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
}

/// Gathers values along an axis specified by dim.
fn gather(&self, index: &Self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?))
}

#[pyo3(text_signature = "(self, rhs:Tensor)")]
/// Performs a matrix multiplication between the two tensors.
/// &RETURNS&: Tensor
Expand Down

0 comments on commit e368a86

Please sign in to comment.