diff --git a/src/bignum.nr b/src/bignum.nr index 787d0ed..b50a02b 100644 --- a/src/bignum.nr +++ b/src/bignum.nr @@ -17,8 +17,9 @@ use crate::fns::{ pub struct BigNum { pub limbs: [Field; N], } - -pub(crate) trait BigNumTrait { +// We aim to avoid needing to add a generic parameter to this trait, for this reason we do not allow +// accessing the limbs of the bignum except through slices. +pub trait BigNumTrait { // TODO: this crashes the compiler? v0.32 // fn default() -> Self { std::default::Default::default () } pub fn new() -> Self; @@ -26,15 +27,13 @@ pub(crate) trait BigNumTrait { pub fn derive_from_seed(seed: [u8; SeedBytes]) -> Self; pub unconstrained fn __derive_from_seed(seed: [u8; SeedBytes]) -> Self; pub fn from_slice(limbs: [Field]) -> Self; - pub fn from_array(limbs: [Field; N]) -> Self; pub fn from_be_bytes(x: [u8; NBytes]) -> Self; pub fn to_le_bytes(self) -> [u8; NBytes]; pub fn modulus() -> Self; - pub fn modulus_bits() -> u32; - pub fn num_limbs() -> u32; - // pub fn get(self) -> [Field]; - pub fn get_limbs(self) -> [Field; N]; + pub fn modulus_bits(self) -> u32; + pub fn num_limbs(self) -> u32; + pub fn get_limbs_slice(self) -> [Field]; pub fn get_limb(self, idx: u32) -> Field; pub fn set_limb(&mut self, idx: u32, value: Field); @@ -100,7 +99,7 @@ pub(crate) trait BigNumTrait { pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self; } -impl BigNumTrait for BigNum +impl BigNumTrait for BigNum where Params: BigNumParamsGetter, { @@ -129,10 +128,6 @@ where Self { limbs: limbs.as_array() } } - fn from_array(limbs: [Field; N]) -> Self { - Self { limbs } - } - fn from_be_bytes(x: [u8; NBytes]) -> Self { Self { limbs: from_be_bytes::<_, MOD_BITS, _>(x) } } @@ -145,19 +140,15 @@ where Self { limbs: Params::get_params().modulus } } - fn modulus_bits() -> u32 { + fn modulus_bits(_: Self) -> u32 { MOD_BITS } - fn num_limbs() -> u32 { + fn num_limbs(_: Self) -> u32 { N } - // fn get(self) -> [Field] { - // self.get_limbs() - // } - - fn get_limbs(self) -> [Field; N] { + fn get_limbs_slice(self) -> [Field] { self.limbs } @@ -179,27 +170,27 @@ where unconstrained fn __neg(self) -> Self { let params = Params::get_params(); - Self::from_array(__neg(params, self.limbs)) + Self::from_slice(__neg(params, self.limbs)) } unconstrained fn __add(self, other: Self) -> Self { let params = Params::get_params(); - Self::from_array(__add(params, self.limbs, other.limbs)) + Self::from_slice(__add(params, self.limbs, other.limbs)) } unconstrained fn __sub(self, other: Self) -> Self { let params = Params::get_params(); - Self::from_array(__sub(params, self.limbs, other.limbs)) + Self::from_slice(__sub(params, self.limbs, other.limbs)) } unconstrained fn __mul(self, other: Self) -> Self { let params = Params::get_params(); - Self::from_array(__mul::<_, MOD_BITS>(params, self.limbs, other.limbs)) + Self::from_slice(__mul::<_, MOD_BITS>(params, self.limbs, other.limbs)) } unconstrained fn __div(self, divisor: Self) -> Self { let params = Params::get_params(); - Self::from_array(__div::<_, MOD_BITS>(params, self.limbs, divisor.limbs)) + Self::from_slice(__div::<_, MOD_BITS>(params, self.limbs, divisor.limbs)) } unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self) { @@ -221,17 +212,18 @@ where unconstrained fn __batch_invert(x: [Self; M]) -> [Self; M] { let params = Params::get_params(); assert(params.has_multiplicative_inverse); - __batch_invert::<_, MOD_BITS, _>(params, x.map(|bn| Self::get_limbs(bn))).map(|limbs| { - Self { limbs } - }) + __batch_invert::<_, MOD_BITS, _>(params, x.map(|bn| Self::get_limbs_slice(bn).as_array())) + .map(|limbs| Self { limbs }) } unconstrained fn __batch_invert_slice(x: [Self]) -> [Self] { let params = Params::get_params(); assert(params.has_multiplicative_inverse); - __batch_invert_slice::<_, MOD_BITS>(params, x.map(|bn| Self::get_limbs(bn))).map(|limbs| { - Self { limbs } - }) + __batch_invert_slice::<_, MOD_BITS>( + params, + x.map(|bn| Self::get_limbs_slice(bn).as_array()), + ) + .map(|limbs| Self { limbs }) } unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option { @@ -251,11 +243,17 @@ where let params = Params::get_params(); let (q_limbs, r_limbs) = __compute_quadratic_expression::<_, MOD_BITS, _, _, _, _>( params, - map(lhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))), + map( + lhs_terms, + |bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()), + ), lhs_flags, - map(rhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))), + map( + rhs_terms, + |bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()), + ), rhs_flags, - map(linear_terms, |bn| Self::get_limbs(bn)), + map(linear_terms, |bn| Self::get_limbs_slice(bn).as_array()), linear_flags, ); (Self { limbs: q_limbs }, Self { limbs: r_limbs }) @@ -272,11 +270,17 @@ where let params = Params::get_params(); evaluate_quadratic_expression::<_, MOD_BITS, _, _, _, _>( params, - map(lhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))), + map( + lhs_terms, + |bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()), + ), lhs_flags, - map(rhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))), + map( + rhs_terms, + |bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()), + ), rhs_flags, - map(linear_terms, |bn| Self::get_limbs(bn)), + map(linear_terms, |bn| Self::get_limbs_slice(bn).as_array()), linear_flags, ) } diff --git a/src/fields/secp256r1Fq.nr b/src/fields/secp256r1Fq.nr index 6f661aa..6736a00 100644 --- a/src/fields/secp256r1Fq.nr +++ b/src/fields/secp256r1Fq.nr @@ -4,13 +4,13 @@ use crate::utils::u60_representation::U60Repr; pub struct Secp256r1_Fq_Params {} -impl BigNumParamsGetter<3, 265> for Secp256r1_Fq_Params { - fn get_params() -> BigNumParams<3, 265> { +impl BigNumParamsGetter<3, 256> for Secp256r1_Fq_Params { + fn get_params() -> BigNumParams<3, 256> { Secp256r1_Fq_PARAMS } } -global Secp256r1_Fq_PARAMS: BigNumParams<3, 265> = BigNumParams { +global Secp256r1_Fq_PARAMS: BigNumParams<3, 256> = BigNumParams { has_multiplicative_inverse: true, modulus: [0xffffffffffffffffffffffff, 0xffff00000001000000000000000000, 0xffff], double_modulus: [ diff --git a/src/lib.nr b/src/lib.nr index 60b25d8..fa8840f 100644 --- a/src/lib.nr +++ b/src/lib.nr @@ -16,6 +16,7 @@ pub(crate) mod utils; // Re-export the main structs so that users don't have to specify the paths pub use bignum::BigNum; +pub use bignum::BigNumTrait; // So that external code can operate on a generic BigNum, `where BigNum: BigNumTrait`. pub use runtime_bignum::RuntimeBigNum; // Tests diff --git a/src/tests/bignum_test.nr b/src/tests/bignum_test.nr index 46b936c..ffc8f52 100644 --- a/src/tests/bignum_test.nr +++ b/src/tests/bignum_test.nr @@ -238,16 +238,16 @@ fn test_bls_reduction() { fn test_eq() where - BN: BigNumTrait, + BN: BigNumTrait, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let c = unsafe { BN::__derive_from_seed([2, 2, 3, 4]) }; let modulus = BN::modulus(); - let t0: U60Repr = (U60Repr::from(modulus.get_limbs())); - let t1: U60Repr = (U60Repr::from(b.get_limbs())); - let b_plus_modulus = BN::from_array(U60Repr::into(t0 + t1)); + let t0: U60Repr = (U60Repr::from(modulus.get_limbs_slice().as_array())); + let t1: U60Repr = (U60Repr::from(b.get_limbs_slice().as_array())); + let b_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1)); assert(a.eq(b) == true); assert(a.eq(b_plus_modulus) == true); assert(c.eq(b) == false); @@ -275,7 +275,7 @@ where // // // 929 gates for a 2048 bit mul fn test_mul() where - BN: BigNumTrait + std::ops::Mul + std::ops::Add, + BN: BigNumTrait + std::ops::Mul + std::ops::Add, { let a: BN = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b: BN = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) }; @@ -287,7 +287,7 @@ where fn test_add() where - BN: BigNumTrait + std::ops::Add + std::ops::Mul + std::cmp::Eq, + BN: BigNumTrait + std::ops::Add + std::ops::Mul + std::cmp::Eq, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b: BN = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) }; @@ -301,7 +301,7 @@ where let d = (a + b) * (one + one); assert(c == (d)); let e = one + one; - let limbs = e.get_limbs(); + let limbs: [Field; N] = e.get_limbs_slice().as_array(); let mut first: bool = true; for limb in limbs { if first { @@ -315,7 +315,7 @@ where fn test_div() where - BN: BigNumTrait + std::ops::Div + std::ops::Mul + std::ops::Add + std::cmp::Eq, + BN: BigNumTrait + std::ops::Div + std::ops::Mul + std::ops::Add + std::cmp::Eq, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) }; @@ -326,7 +326,7 @@ where fn test_invmod() where - BN: BigNumTrait + std::cmp::Eq, + BN: BigNumTrait + std::cmp::Eq, { let u = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; for _ in 0..1 { @@ -339,7 +339,7 @@ where fn assert_is_not_equal() where - BN: BigNumTrait, + BN: BigNumTrait, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) }; @@ -349,7 +349,7 @@ where fn assert_is_not_equal_fail() where - BN: BigNumTrait, + BN: BigNumTrait, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; @@ -359,48 +359,48 @@ where fn assert_is_not_equal_overloaded_lhs_fail() where - BN: BigNumTrait, + BN: BigNumTrait, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let modulus = BN::modulus(); - let t0: U60Repr = U60Repr::from(a.get_limbs()); - let t1: U60Repr = U60Repr::from(modulus.get_limbs()); - let a_plus_modulus = BN::from_array(U60Repr::into(t0 + t1)); + let t0: U60Repr = U60Repr::from(a.get_limbs_slice().as_array()); + let t1: U60Repr = U60Repr::from(modulus.get_limbs_slice().as_array()); + let a_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1)); a_plus_modulus.assert_is_not_equal(b); } fn assert_is_not_equal_overloaded_rhs_fail() where - BN: BigNumTrait, + BN: BigNumTrait, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let modulus = BN::modulus(); - let t0: U60Repr = U60Repr::from(b.get_limbs()); - let t1: U60Repr = U60Repr::from(modulus.get_limbs()); - let b_plus_modulus = BN::from_array(U60Repr::into(t0 + t1)); + let t0: U60Repr = U60Repr::from(b.get_limbs_slice().as_array()); + let t1: U60Repr = U60Repr::from(modulus.get_limbs_slice().as_array()); + let b_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1)); a.assert_is_not_equal(b_plus_modulus); } fn assert_is_not_equal_overloaded_fail() where - BN: BigNumTrait, + BN: BigNumTrait, { let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) }; let modulus = BN::modulus(); - let t0: U60Repr = U60Repr::from(a.get_limbs()); - let t1: U60Repr = U60Repr::from(b.get_limbs()); - let t2: U60Repr = U60Repr::from(modulus.get_limbs()); - let a_plus_modulus: BN = BN::from_array(U60Repr::into(t0 + t2)); - let b_plus_modulus: BN = BN::from_array(U60Repr::into(t1 + t2)); + let t0: U60Repr = U60Repr::from(a.get_limbs_slice().as_array()); + let t1: U60Repr = U60Repr::from(b.get_limbs_slice().as_array()); + let t2: U60Repr = U60Repr::from(modulus.get_limbs_slice().as_array()); + let a_plus_modulus: BN = BN::from_slice(U60Repr::into(t0 + t2)); + let b_plus_modulus: BN = BN::from_slice(U60Repr::into(t1 + t2)); a_plus_modulus.assert_is_not_equal(b_plus_modulus); } @@ -623,7 +623,8 @@ type U256 = BN256; #[test] fn test_udiv_mod_U256() { let a: U256 = unsafe { BigNum::__derive_from_seed([1, 2, 3, 4]) }; - let b: U256 = BigNum::from_array([12, 0, 0]); + let b: U256 = BigNum::from_slice([12, 0, 0]); + let (q, r) = a.udiv_mod(b); // let qb = q.__mul(b);