diff --git a/crates/utils/src/math.cairo b/crates/utils/src/math.cairo index 6e77fdbc9..cf13026f3 100644 --- a/crates/utils/src/math.cairo +++ b/crates/utils/src/math.cairo @@ -2,6 +2,7 @@ use core::integer::{u512}; use core::num::traits::{Zero, One, BitSize, OverflowingAdd, OverflowingMul}; use core::panic_with_felt252; use core::traits::{BitAnd}; +use utils::constants::POW_2; // === Exponentiation === @@ -236,6 +237,8 @@ impl BitshiftImpl< +PartialOrd, +BitSize, +TryInto, + +TryInto, + +TryInto, > of Bitshift { fn shl(self: T, shift: T) -> T { // if we shift by more than nb_bits of T, the result is 0 @@ -243,6 +246,12 @@ impl BitshiftImpl< if shift > BitSize::::bits().try_into().unwrap() - One::one() { panic_with_felt252('mul Overflow'); } + // if the shift is within the bit size of u256 (<= 255 bits), + // use the POW_2 lookup table to get 2^shift for efficient multiplication + if shift <= BitSize::::bits().try_into().unwrap() - One::::one() { + return self * (*POW_2.span().at(shift.try_into().unwrap())).try_into().unwrap(); + } + // for shifts greater than 255 bits, perform the shift manually let two = One::one() + One::one(); self * two.pow(shift) } @@ -252,6 +261,10 @@ impl BitshiftImpl< if shift > BitSize::::bits().try_into().unwrap() - One::one() { panic_with_felt252('mul Overflow'); } + // use the POW_2 lookup table when the bit size + if shift <= BitSize::::bits().try_into().unwrap() - One::::one() { + return self / (*POW_2.span().at(shift.try_into().unwrap())).try_into().unwrap(); + } let two = One::one() + One::one(); self / two.pow(shift) } @@ -301,16 +314,24 @@ pub impl WrappingBitshiftImpl< +WrappingExponentiation, +BitSize, +TryInto, + +TryInto, + +TryInto > of WrappingBitshift { fn wrapping_shl(self: T, shift: T) -> T { + if shift <= BitSize::::bits().try_into().unwrap() - One::::one() { + let (result, _) = self.overflowing_mul((*POW_2.span().at(shift.try_into().unwrap())).try_into().unwrap()); + return result; + } let two = One::::one() + One::::one(); let (result, _) = self.overflowing_mul(two.wrapping_pow(shift)); result } fn wrapping_shr(self: T, shift: T) -> T { + if shift <= BitSize::::bits().try_into().unwrap() - One::::one() { + return self / (*POW_2.span().at(shift.try_into().unwrap())).try_into().unwrap(); + } let two = One::::one() + One::::one(); - if shift > BitSize::::bits().try_into().unwrap() - One::one() { return Zero::zero(); }