Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: remove generic parameter from the BigNum trait #44

Merged
merged 8 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 40 additions & 36 deletions src/bignum.nr
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,23 @@ use crate::fns::{
pub struct BigNum<let N: u32, let MOD_BITS: u32, Params> {
pub limbs: [Field; N],
}

pub(crate) trait BigNumTrait<let N: u32> {
// We aim to avoid needing to add a generic parameter to this trait, for this reason we do not allow
// accessing the limbs of the bignum except through slices.
pub trait BigNumTrait {
// TODO: this crashes the compiler? v0.32
// fn default() -> Self { std::default::Default::default () }
pub fn new() -> Self;
pub fn one() -> Self;
pub fn derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub unconstrained fn __derive_from_seed<let SeedBytes: u32>(seed: [u8; SeedBytes]) -> Self;
pub fn from_slice(limbs: [Field]) -> Self;
pub fn from_array(limbs: [Field; N]) -> Self;
pub fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self;
pub fn to_le_bytes<let NBytes: u32>(self) -> [u8; NBytes];

pub fn modulus() -> Self;
pub fn modulus_bits() -> u32;
pub fn num_limbs() -> u32;
// pub fn get(self) -> [Field];
pub fn get_limbs(self) -> [Field; N];
pub fn modulus_bits(self) -> u32;
pub fn num_limbs(self) -> u32;
pub fn get_limbs_slice(self) -> [Field];
pub fn get_limb(self, idx: u32) -> Field;
pub fn set_limb(&mut self, idx: u32, value: Field);

Expand Down Expand Up @@ -100,7 +99,7 @@ pub(crate) trait BigNumTrait<let N: u32> {
pub fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self;
}

impl<let N: u32, let MOD_BITS: u32, Params> BigNumTrait<N> for BigNum<N, MOD_BITS, Params>
impl<let N: u32, let MOD_BITS: u32, Params> BigNumTrait for BigNum<N, MOD_BITS, Params>
where
Params: BigNumParamsGetter<N, MOD_BITS>,
{
Expand Down Expand Up @@ -129,10 +128,6 @@ where
Self { limbs: limbs.as_array() }
}

fn from_array(limbs: [Field; N]) -> Self {
Self { limbs }
}

fn from_be_bytes<let NBytes: u32>(x: [u8; NBytes]) -> Self {
Self { limbs: from_be_bytes::<_, MOD_BITS, _>(x) }
}
Expand All @@ -145,19 +140,15 @@ where
Self { limbs: Params::get_params().modulus }
}

fn modulus_bits() -> u32 {
fn modulus_bits(_: Self) -> u32 {
MOD_BITS
}

fn num_limbs() -> u32 {
fn num_limbs(_: Self) -> u32 {
N
}

// fn get(self) -> [Field] {
// self.get_limbs()
// }

fn get_limbs(self) -> [Field; N] {
fn get_limbs_slice(self) -> [Field] {
self.limbs
}

Expand All @@ -179,27 +170,27 @@ where

unconstrained fn __neg(self) -> Self {
let params = Params::get_params();
Self::from_array(__neg(params, self.limbs))
Self::from_slice(__neg(params, self.limbs))
}

unconstrained fn __add(self, other: Self) -> Self {
let params = Params::get_params();
Self::from_array(__add(params, self.limbs, other.limbs))
Self::from_slice(__add(params, self.limbs, other.limbs))
}

unconstrained fn __sub(self, other: Self) -> Self {
let params = Params::get_params();
Self::from_array(__sub(params, self.limbs, other.limbs))
Self::from_slice(__sub(params, self.limbs, other.limbs))
}

unconstrained fn __mul(self, other: Self) -> Self {
let params = Params::get_params();
Self::from_array(__mul::<_, MOD_BITS>(params, self.limbs, other.limbs))
Self::from_slice(__mul::<_, MOD_BITS>(params, self.limbs, other.limbs))
}

unconstrained fn __div(self, divisor: Self) -> Self {
let params = Params::get_params();
Self::from_array(__div::<_, MOD_BITS>(params, self.limbs, divisor.limbs))
Self::from_slice(__div::<_, MOD_BITS>(params, self.limbs, divisor.limbs))
}

unconstrained fn __udiv_mod(self, divisor: Self) -> (Self, Self) {
Expand All @@ -221,17 +212,18 @@ where
unconstrained fn __batch_invert<let M: u32>(x: [Self; M]) -> [Self; M] {
let params = Params::get_params();
assert(params.has_multiplicative_inverse);
__batch_invert::<_, MOD_BITS, _>(params, x.map(|bn| Self::get_limbs(bn))).map(|limbs| {
Self { limbs }
})
__batch_invert::<_, MOD_BITS, _>(params, x.map(|bn| Self::get_limbs_slice(bn).as_array()))
.map(|limbs| Self { limbs })
}

unconstrained fn __batch_invert_slice<let M: u32>(x: [Self]) -> [Self] {
let params = Params::get_params();
assert(params.has_multiplicative_inverse);
__batch_invert_slice::<_, MOD_BITS>(params, x.map(|bn| Self::get_limbs(bn))).map(|limbs| {
Self { limbs }
})
__batch_invert_slice::<_, MOD_BITS>(
params,
x.map(|bn| Self::get_limbs_slice(bn).as_array()),
)
.map(|limbs| Self { limbs })
}

unconstrained fn __tonelli_shanks_sqrt(self) -> std::option::Option<Self> {
Expand All @@ -251,11 +243,17 @@ where
let params = Params::get_params();
let (q_limbs, r_limbs) = __compute_quadratic_expression::<_, MOD_BITS, _, _, _, _>(
params,
map(lhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
lhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
lhs_flags,
map(rhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
rhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
rhs_flags,
map(linear_terms, |bn| Self::get_limbs(bn)),
map(linear_terms, |bn| Self::get_limbs_slice(bn).as_array()),
linear_flags,
);
(Self { limbs: q_limbs }, Self { limbs: r_limbs })
Expand All @@ -272,11 +270,17 @@ where
let params = Params::get_params();
evaluate_quadratic_expression::<_, MOD_BITS, _, _, _, _>(
params,
map(lhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
Copy link
Collaborator Author

@iAmMichaelConnor iAmMichaelConnor Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TomAFrench I've honed in on this PR being the one that caused a regression in the constraint counts of BigNum. In particular, I was measuring the constraints of calls to this evaluate_quadratic_expression function.
Any thoughts on why the changes in this PR are increasing constraint counts? Is it simply that converting between slices and arrays is not free?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra ACIR opcodes (for a main function which calls this evaluate_quadratic_expression) function appear to be:

INIT (id: 0, len: 4) 
EXPR [ (-1, _5) 0 ]
MEM (id: 0, read at: x5, value: x6) 
EXPR [ (-1, _7) 1 ]
MEM (id: 0, read at: x7, value: x8) 
EXPR [ (-1, _9) 2 ]
MEM (id: 0, read at: x9, value: x10) 
EXPR [ (-1, _11) 3 ]
MEM (id: 0, read at: x11, value: x12) 
INIT (id: 1, len: 4) 
MEM (id: 1, read at: x5, value: x13) 
MEM (id: 1, read at: x7, value: x14) 
MEM (id: 1, read at: x9, value: x15) 
MEM (id: 1, read at: x11, value: x16) 

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is bad actor codegen as we're just reading all the values out of the array to write into another one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neater code here, to fix it: #53

lhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
lhs_flags,
map(rhs_terms, |bns| map(bns, |bn| Self::get_limbs(bn))),
map(
rhs_terms,
|bns| map(bns, |bn| Self::get_limbs_slice(bn).as_array()),
),
rhs_flags,
map(linear_terms, |bn| Self::get_limbs(bn)),
map(linear_terms, |bn| Self::get_limbs_slice(bn).as_array()),
linear_flags,
)
}
Expand Down
6 changes: 3 additions & 3 deletions src/fields/secp256r1Fq.nr
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use crate::utils::u60_representation::U60Repr;

pub struct Secp256r1_Fq_Params {}

impl BigNumParamsGetter<3, 265> for Secp256r1_Fq_Params {
fn get_params() -> BigNumParams<3, 265> {
impl BigNumParamsGetter<3, 256> for Secp256r1_Fq_Params {
fn get_params() -> BigNumParams<3, 256> {
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
Secp256r1_Fq_PARAMS
}
}

global Secp256r1_Fq_PARAMS: BigNumParams<3, 265> = BigNumParams {
global Secp256r1_Fq_PARAMS: BigNumParams<3, 256> = BigNumParams {
has_multiplicative_inverse: true,
modulus: [0xffffffffffffffffffffffff, 0xffff00000001000000000000000000, 0xffff],
double_modulus: [
Expand Down
1 change: 1 addition & 0 deletions src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub(crate) mod utils;

// Re-export the main structs so that users don't have to specify the paths
pub use bignum::BigNum;
pub use bignum::BigNumTrait; // So that external code can operate on a generic BigNum, `where BigNum: BigNumTrait`.
pub use runtime_bignum::RuntimeBigNum;

// Tests
Expand Down
53 changes: 27 additions & 26 deletions src/tests/bignum_test.nr
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,16 @@ fn test_bls_reduction() {

fn test_eq<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let c = unsafe { BN::__derive_from_seed([2, 2, 3, 4]) };

let modulus = BN::modulus();
let t0: U60Repr<N, 2> = (U60Repr::from(modulus.get_limbs()));
let t1: U60Repr<N, 2> = (U60Repr::from(b.get_limbs()));
let b_plus_modulus = BN::from_array(U60Repr::into(t0 + t1));
let t0: U60Repr<N, 2> = (U60Repr::from(modulus.get_limbs_slice().as_array()));
let t1: U60Repr<N, 2> = (U60Repr::from(b.get_limbs_slice().as_array()));
let b_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1));
assert(a.eq(b) == true);
assert(a.eq(b_plus_modulus) == true);
assert(c.eq(b) == false);
Expand Down Expand Up @@ -275,7 +275,7 @@ where
// // // 929 gates for a 2048 bit mul
fn test_mul<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::ops::Mul + std::ops::Add,
BN: BigNumTrait + std::ops::Mul + std::ops::Add,
{
let a: BN = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b: BN = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -287,7 +287,7 @@ where

fn test_add<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::ops::Add + std::ops::Mul + std::cmp::Eq,
BN: BigNumTrait + std::ops::Add + std::ops::Mul + std::cmp::Eq,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b: BN = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -301,7 +301,7 @@ where
let d = (a + b) * (one + one);
assert(c == (d));
let e = one + one;
let limbs = e.get_limbs();
let limbs: [Field; N] = e.get_limbs_slice().as_array();
let mut first: bool = true;
for limb in limbs {
if first {
Expand All @@ -315,7 +315,7 @@ where

fn test_div<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::ops::Div + std::ops::Mul + std::ops::Add + std::cmp::Eq,
BN: BigNumTrait + std::ops::Div + std::ops::Mul + std::ops::Add + std::cmp::Eq,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -326,7 +326,7 @@ where

fn test_invmod<let N: u32, BN>()
where
BN: BigNumTrait<N> + std::cmp::Eq,
BN: BigNumTrait + std::cmp::Eq,
{
let u = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
for _ in 0..1 {
Expand All @@ -339,7 +339,7 @@ where

fn assert_is_not_equal<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([4, 5, 6, 7]) };
Expand All @@ -349,7 +349,7 @@ where

fn assert_is_not_equal_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
Expand All @@ -359,48 +359,48 @@ where

fn assert_is_not_equal_overloaded_lhs_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };

let modulus = BN::modulus();

let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs());
let a_plus_modulus = BN::from_array(U60Repr::into(t0 + t1));
let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs_slice().as_array());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs_slice().as_array());
let a_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1));
a_plus_modulus.assert_is_not_equal(b);
}

fn assert_is_not_equal_overloaded_rhs_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };

let modulus = BN::modulus();

let t0: U60Repr<N, 2> = U60Repr::from(b.get_limbs());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs());
let b_plus_modulus = BN::from_array(U60Repr::into(t0 + t1));
let t0: U60Repr<N, 2> = U60Repr::from(b.get_limbs_slice().as_array());
let t1: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs_slice().as_array());
let b_plus_modulus = BN::from_slice(U60Repr::into(t0 + t1));
a.assert_is_not_equal(b_plus_modulus);
}

fn assert_is_not_equal_overloaded_fail<let N: u32, BN>()
where
BN: BigNumTrait<N>,
BN: BigNumTrait,
{
let a = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };
let b = unsafe { BN::__derive_from_seed([1, 2, 3, 4]) };

let modulus = BN::modulus();

let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs());
let t1: U60Repr<N, 2> = U60Repr::from(b.get_limbs());
let t2: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs());
let a_plus_modulus: BN = BN::from_array(U60Repr::into(t0 + t2));
let b_plus_modulus: BN = BN::from_array(U60Repr::into(t1 + t2));
let t0: U60Repr<N, 2> = U60Repr::from(a.get_limbs_slice().as_array());
let t1: U60Repr<N, 2> = U60Repr::from(b.get_limbs_slice().as_array());
let t2: U60Repr<N, 2> = U60Repr::from(modulus.get_limbs_slice().as_array());
let a_plus_modulus: BN = BN::from_slice(U60Repr::into(t0 + t2));
let b_plus_modulus: BN = BN::from_slice(U60Repr::into(t1 + t2));
a_plus_modulus.assert_is_not_equal(b_plus_modulus);
}

Expand Down Expand Up @@ -623,7 +623,8 @@ type U256 = BN256;
#[test]
fn test_udiv_mod_U256() {
let a: U256 = unsafe { BigNum::__derive_from_seed([1, 2, 3, 4]) };
let b: U256 = BigNum::from_array([12, 0, 0]);
let b: U256 = BigNum::from_slice([12, 0, 0]);

let (q, r) = a.udiv_mod(b);

// let qb = q.__mul(b);
Expand Down
Loading