diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 567a711b3c..90a37be663 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -304,6 +304,7 @@ impl Dim for usize { pub enum D { Minus1, Minus2, + Minus(usize), } impl D { @@ -311,6 +312,7 @@ impl D { let dim = match self { Self::Minus1 => -1, Self::Minus2 => -2, + Self::Minus(u) => -(*u as i32), }; Error::DimOutOfRange { shape: shape.clone(), @@ -327,6 +329,7 @@ impl Dim for D { match self { Self::Minus1 if rank >= 1 => Ok(rank - 1), Self::Minus2 if rank >= 2 => Ok(rank - 2), + Self::Minus(u) if *u > 0 && rank >= *u => Ok(rank - *u), _ => Err(self.out_of_range(shape, op)), } } @@ -336,6 +339,7 @@ impl Dim for D { match self { Self::Minus1 => Ok(rank), Self::Minus2 if rank >= 1 => Ok(rank - 1), + Self::Minus(u) if *u > 0 && rank + 1 >= *u => Ok(rank + 1 - *u), _ => Err(self.out_of_range(shape, op)), } }