Skip to content

Commit

Permalink
Clippy fixes for onnx + fix a broken test.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 26, 2024
1 parent ed48f54 commit 6bab639
Show file tree
Hide file tree
Showing 2 changed files with 273 additions and 281 deletions.
27 changes: 13 additions & 14 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
use std::collections::HashMap;

pub type Value = Tensor;

Expand Down Expand Up @@ -321,7 +321,7 @@ fn simple_eval_(
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
None => bail!("cannot find {input_name} for op '{}'", node.name),
};
let get_opt = |i: usize| {
node.input
Expand Down Expand Up @@ -362,7 +362,7 @@ fn simple_eval_(
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
let output = input0.powf(exp as f64)?;
let output = input0.powf(exp)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
Expand Down Expand Up @@ -643,7 +643,7 @@ fn simple_eval_(
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(&indices)?
.add(indices)?
};

// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
Expand Down Expand Up @@ -767,7 +767,7 @@ fn simple_eval_(

// where_cond requires that all inputs are the same shape.
// In contrast, the Where op in ONNX only requires that they are broadcastable.
let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
let cond = cond.broadcast_as(shape.clone())?;
let a = a.broadcast_as(shape.clone())?;
let b = b.broadcast_as(shape)?;
Expand Down Expand Up @@ -1283,8 +1283,7 @@ fn simple_eval_(
.map(|x| x as usize)
.collect::<Vec<_>>();

let target_shape =
broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;

let expanded_tensor = input_tensor.broadcast_as(target_shape)?;

Expand All @@ -1301,12 +1300,12 @@ fn simple_eval_(
.unwrap_or(0);

let axes = match axes {
Some(axes) => axes?
Some(Ok(axes)) => axes
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
None => {
Some(Err(_)) | None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
Expand Down Expand Up @@ -1640,7 +1639,7 @@ fn simple_eval_(
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
Expand All @@ -1649,8 +1648,8 @@ fn simple_eval_(

// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
let idx_i = Tensor::arange(0, hidden_size, x.device())?;
let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
Expand All @@ -1674,7 +1673,7 @@ fn simple_eval_(
)?;

let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
Expand All @@ -1688,7 +1687,7 @@ fn simple_eval_(
}

assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
if let Some(name) = node.output.get(0) {
if let Some(name) = node.output.first() {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((
Expand Down
Loading

0 comments on commit 6bab639

Please sign in to comment.