Skip to content

Commit

Permalink
feat: remove generic parameter from the BigNum trait (#44)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom French <[email protected]>
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
3 people authored Nov 8, 2024
1 parent 63e6c85 commit 53f652b
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 65 deletions.
76 changes: 40 additions & 36 deletions src/bignum.nr
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,23 @@ use crate::fns::{
pub struct BigNum<let N: u32, let MOD_BITS: u32, Params> {
pub limbs: [Field; N],
}

pub(crate) trait BigNumTrait<let N: u32> {
// We aim to avoid needing to add a generic parameter to this trait, for this reason we do not allow
// accessing the limbs of the bignum except through slices.
pub trait BigNumTrait {
// TODO: this crashes the compiler? v0.32
// fn default() -> Self { std::default::Default::default () }
pub fn new() -> Self;
pub fn one() -> Self;
pub fn derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub unconstrained fn __derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub fn from_slice(limbs: [Field]) -> Self;
pub fn from_array(limbs: [Field; N]) -> Self;
pub fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self;
pub fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

pub fn modulus() -> Self;
pub fn modulus_bits() -> u32;
pub fn num_limbs() -> u32;
// pub fn get(self) -> [Field];
pub fn get_limbs(self) -> [Field; N];
pub fn modulus_bits(self) -> u32;
pub fn num_limbs(self) -> u32;
pub fn get_limbs_slice(self) -> [Field];
pub fn get_limb(self, idx: u32) -> Field;
pub fn set_limb(&mut self, idx: u32, value: Field);

Expand Down Expand Up @@ -100,7 +99,7 @@ pub(crate) trait BigNumTrait<let N: u32> {
pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
}

impl<let N: u32, let MOD_BITS: u32, Params> BigNumTrait<N> for BigNum<N, MOD_BITS, Params>
impl<let N: u32, let MOD_BITS: u32, Params> BigNumTrait for BigNum<N, MOD_BITS, Params>
where
Params: BigNumParamsGetter<N, MOD_BITS>,
{
Expand Down Expand Up @@ -129,10 +128,6 @@ where
Self { limbs: limbs.as_array() }
}

fn from_array(limbs: [Field; N]) -> Self {
Self { limbs }
}

fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self {
Self { limbs: from_be_bytes::<_, MOD_BITS, _>(x) }
}
Expand All @@ -145,19 +140,15 @@ where
Self { limbs: Params::get_params().modulus }
}

fn modulus_bits() -> u32 {
fn modulus_bits(_: Self) -> u32 {
MOD_BITS
}

fn num_limbs() -> u32 {
fn num_limbs(_: Self) -> u32 {
N
}

// fn get(self) -> [Field] {
// self.get_limbs()
// }

fn get_limbs(self) -> [Field; N] {
fn get_limbs_slice(self) -> [Field] {
self.limbs
}

Expand All @@ -179,27 +170,27 @@ where

unconstrained fn __neg(self) -> Self {
let params = Params::get_params();
Self::from_array(__neg(params, self.limbs))
Self::from_slice(__neg(params, self.limbs))
}

unconstrained fn __add(self, other: Self) -> Self {
let params = Params::get_params();
Self::from_array(__add(params, self.limbs, other.limbs))
Self::from_slice(__add(params, self.limbs, other.limbs))
}

unconstrained fn __sub(self, other: Self) -> Self {
let params = Params::get_params();
Self::from_array(__sub(params, self.limbs, other.limbs))
Self::from_slice(__sub(params, self.limbs, other.limbs))
}

unconstrained fn __mul(self, other: Self) -> Self {
let params = Params::get_params();
Self::from_array(__mul::<_, MOD_BITS>(params, self.limbs, other.limbs))
Self::from_slice(__mul::<_, MOD_BITS>(params, self.limbs, other.limbs))
}

unconstrained fn __div(self, divisor: Self) -> Self {
let params = Params::get_params();
Self::from_array(__div::<_, MOD_BITS>(params, self.limbs, divisor.limbs))
Self::from_slice(__div::<_, MOD_BITS>(params, self.limbs, divisor.limbs))
}

unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self) {
Expand All @@ -221,17 +212,18 @@ where
unconstrained fn __batch_invert<let M: u32>(x: [Self; M]) -> [Self; M] {
let params = Params::get_params();
assert(params.has_multiplicative_inverse);
__batch_invert::<_, MOD_BITS, _>(params, x.map(|bn| Self::get_limbs(bn))).map(|limbs| {
Self { limbs }
})
__batch_invert::<_, MOD_BITS, _>(params, x.map(|bn| Self::get_limbs_slice(bn).as_array()))
.map(|limbs| Self { limbs })
}

unconstrained fn __batch_invert_slice<let M: u32>(x: [Self]) -> [Self] {
let params = Params::get_params();
assert(params.has_multiplicative_inverse);
__batch_invert_slice::<_, MOD_BITS>(params, x.map(|bn| Self::get_limbs(bn))).map(|limbs| {
Self { limbs }
})
__batch_invert_slice::<_, MOD_BITS>(
params,
x.map(|bn| Self::get_limbs_slice(bn).as_array()),
)
.map(|limbs| Self { limbs })
}

unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self> {
Expand All @@ -251,11 +243,17 @@ where
let params = Params::get_params();
let (q_limbs, r_limbs) = __compute_quadratic_expression::<_, MOD_BITS, _, _, _, _>(
params,
map(lhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
lhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
lhs_flags,
map(rhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
rhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
rhs_flags,
map(linear_terms, |bn| Self::get_limbs(bn)),
map(linear_terms, |bn| Self::get_limbs_slice(bn).as_array()),
linear_flags,
);
(Self { limbs: q_limbs }, Self { limbs: r_limbs })
Expand All @@ -272,11 +270,17 @@ where
let params = Params::get_params();
evaluate_quadratic_expression::<_, MOD_BITS, _, _, _, _>(
params,
map(lhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
lhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
lhs_flags,
map(rhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
rhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
rhs_flags,
map(linear_terms, |bn| Self::get_limbs(bn)),
map(linear_terms, |bn| Self::get_limbs_slice(bn).as_array()),
linear_flags,
)
}
Expand Down
6 changes: 3 additions & 3 deletions src/fields/secp256r1Fq.nr
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use crate::utils::u60_representation::U60Repr;

pub struct Secp256r1_Fq_Params {}

impl BigNumParamsGetter<3, 265> for Secp256r1_Fq_Params {
fn get_params() -> BigNumParams<3, 265> {
impl BigNumParamsGetter<3, 256> for Secp256r1_Fq_Params {
fn get_params() -> BigNumParams<3, 256> {
Secp256r1_Fq_PARAMS
}
}

global Secp256r1_Fq_PARAMS: BigNumParams<3, 265> = BigNumParams {
global Secp256r1_Fq_PARAMS: BigNumParams<3, 256> = BigNumParams {
has_multiplicative_inverse: true,
modulus: [0xffffffffffffffffffffffff, 0xffff00000001000000000000000000, 0xffff],
double_modulus: [
Expand Down
1 change: 1 addition & 0 deletions src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub(crate) mod utils;

// Re-export the main structs so that users don't have to specify the paths
pub use bignum::BigNum;
pub use bignum::BigNumTrait; // So that external code can operate on a generic BigNum, `where BigNum: BigNumTrait`.
pub use runtime_bignum::RuntimeBigNum;

// Tests
Expand Down
53 changes: 27 additions & 26 deletions src/tests/bignum_test.nr
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,16 @@ fn test_bls_reduction() {

fn test_eq<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let c = unsafe { BN::__derive_from_seed([2, 2, 3, 4]) };

let modulus = BN::modulus();
let t0: U60Repr<N, 2> = (U60Repr::from(modulus.get_limbs()));
let t1: U60Repr<N, 2> = (U60Repr::from(b.get_limbs()));
let b_plus_modulus = BN::from_array(U60Repr::into(t0 + t1));
let t0: U60Repr<N, 2> = (U60Repr::from(modulus.get_limbs_slice().as_array()));
let t1: U60Repr<N, 2> = (U60Repr::from(b.get_limbs_slice().as_array()));
let b_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1));
assert(a.eq(b) == true);
assert(a.eq(b_plus_modulus) == true);
assert(c.eq(b) == false);
Expand Down Expand Up @@ -275,7 +275,7 @@ where
// // // 929 gates for a 2048 bit mul
fn test_mul<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::ops::Mul + std::ops::Add,
BN: BigNumTrait + std::ops::Mul + std::ops::Add,
{
let a: BN = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b: BN = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -287,7 +287,7 @@ where

fn test_add<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::ops::Add + std::ops::Mul + std::cmp::Eq,
BN: BigNumTrait + std::ops::Add + std::ops::Mul + std::cmp::Eq,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b: BN = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -301,7 +301,7 @@ where
let d = (a + b) * (one + one);
assert(c == (d));
let e = one + one;
let limbs = e.get_limbs();
let limbs: [Field; N] = e.get_limbs_slice().as_array();
let mut first: bool = true;
for limb in limbs {
if first {
Expand All @@ -315,7 +315,7 @@ where

fn test_div<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::ops::Div + std::ops::Mul + std::ops::Add + std::cmp::Eq,
BN: BigNumTrait + std::ops::Div + std::ops::Mul + std::ops::Add + std::cmp::Eq,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -326,7 +326,7 @@ where

fn test_invmod<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::cmp::Eq,
BN: BigNumTrait + std::cmp::Eq,
{
let u = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
for _ in 0..1 {
Expand All @@ -339,7 +339,7 @@ where

fn assert_is_not_equal<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -349,7 +349,7 @@ where

fn assert_is_not_equal_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
Expand All @@ -359,48 +359,48 @@ where

fn assert_is_not_equal_overloaded_lhs_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };

let modulus = BN::modulus();

let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs());
let a_plus_modulus = BN::from_array(U60Repr::into(t0 + t1));
let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs_slice().as_array());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs_slice().as_array());
let a_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1));
a_plus_modulus.assert_is_not_equal(b);
}

fn assert_is_not_equal_overloaded_rhs_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };

let modulus = BN::modulus();

let t0: U60Repr<N, 2> = U60Repr::from(b.get_limbs());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs());
let b_plus_modulus = BN::from_array(U60Repr::into(t0 + t1));
let t0: U60Repr<N, 2> = U60Repr::from(b.get_limbs_slice().as_array());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs_slice().as_array());
let b_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1));
a.assert_is_not_equal(b_plus_modulus);
}

fn assert_is_not_equal_overloaded_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };

let modulus = BN::modulus();

let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs());
let t1: U60Repr<N, 2> = U60Repr::from(b.get_limbs());
let t2: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs());
let a_plus_modulus: BN = BN::from_array(U60Repr::into(t0 + t2));
let b_plus_modulus: BN = BN::from_array(U60Repr::into(t1 + t2));
let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs_slice().as_array());
let t1: U60Repr<N, 2> = U60Repr::from(b.get_limbs_slice().as_array());
let t2: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs_slice().as_array());
let a_plus_modulus: BN = BN::from_slice(U60Repr::into(t0 + t2));
let b_plus_modulus: BN = BN::from_slice(U60Repr::into(t1 + t2));
a_plus_modulus.assert_is_not_equal(b_plus_modulus);
}

Expand Down Expand Up @@ -623,7 +623,8 @@ type U256 = BN256;
#[test]
fn test_udiv_mod_U256() {
let a: U256 = unsafe { BigNum::__derive_from_seed([1, 2, 3, 4]) };
let b: U256 = BigNum::from_array([12, 0, 0]);
let b: U256 = BigNum::from_slice([12, 0, 0]);

let (q, r) = a.udiv_mod(b);

// let qb = q.__mul(b);
Expand Down

0 comments on commit 53f652b

Please sign in to comment.