Skip to content

Commit

Permalink
Fix for the pyo3-onnx bindings...
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Apr 1, 2024
1 parent ad9b80e commit 0546127
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,7 @@ fn candle_functional_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
}

#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn candle_onnx_m(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
Expand All @@ -1608,9 +1608,9 @@ fn candle(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_submodule(&nn)?;
#[cfg(feature = "onnx")]
{
let onnx = PyModule::new(py, "onnx")?;
candle_onnx_m(py, onnx)?;
m.add_submodule(onnx)?;
let onnx = PyModule::new_bound(py, "onnx")?;
candle_onnx_m(py, &onnx)?;
m.add_submodule(&onnx)?;
}
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
Expand Down
2 changes: 1 addition & 1 deletion candle-pyo3/src/onnx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl PyONNXTensorDescriptor {
/// The shape of the tensor.
/// &RETURNS&: Tuple[Union[int,str,Any]]
fn shape(&self, py: Python) -> PyResult<Py<PyTuple>> {
let shape = PyList::empty(py);
let shape = PyList::empty_bound(py);
if let Some(d) = &self.0.shape {
for dim in d.dim.iter() {
if let Some(value) = &dim.value {
Expand Down

0 comments on commit 0546127

Please sign in to comment.