From 143c481c20abc3420e848eab075d1547a96cc447 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 18 Mar 2024 21:54:15 +0100 Subject: [PATCH] Expose candle gather op in pyo3. (#1870) --- candle-pyo3/py_src/candle/__init__.pyi | 6 ++++++ candle-pyo3/src/lib.rs | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/candle-pyo3/py_src/candle/__init__.pyi b/candle-pyo3/py_src/candle/__init__.pyi index aef0707d5..b0f05de59 100644 --- a/candle-pyo3/py_src/candle/__init__.pyi +++ b/candle-pyo3/py_src/candle/__init__.pyi @@ -324,6 +324,12 @@ class Tensor: """ pass + def gather(self, index, dim): + """ + Gathers values along an axis specified by dim. + """ + pass + def get(self, index: int) -> Tensor: """ Gets the value at the specified index. diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index 7b9a74134..e0d3bf300 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -448,6 +448,12 @@ impl PyTensor { Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?)) } + /// Gathers values along an axis specified by dim. + fn gather(&self, index: &Self, dim: i64) -> PyResult { + let dim = actual_dim(self, dim).map_err(wrap_err)?; + Ok(PyTensor(self.0.gather(index, dim).map_err(wrap_err)?)) + } + #[pyo3(text_signature = "(self, rhs:Tensor)")] /// Performs a matrix multiplication between the two tensors. /// &RETURNS&: Tensor