diff --git a/src/params.nr b/src/params.nr index 01854ab..f5beb50 100644 --- a/src/params.nr +++ b/src/params.nr @@ -39,6 +39,17 @@ impl BigNumParams { } } +impl std::cmp::Eq for BigNumParams { + fn eq(self, other: Self) -> bool { + (self.has_multiplicative_inverse == other.has_multiplicative_inverse) + & (self.modulus == other.modulus) + & (self.modulus_u60 == other.modulus_u60) + & (self.modulus_u60_x4 == other.modulus_u60_x4) + & (self.double_modulus == other.double_modulus) + & (self.redc_param == other.redc_param) + } +} + fn get_double_modulus(modulus: [Field; N]) -> [Field; N] { let TWO_POW_120: Field = 0x1000000000000000000000000000000; let m: U60Repr = U60Repr::from(modulus); diff --git a/src/runtime_bignum.nr b/src/runtime_bignum.nr index 3e64566..0325ca5 100644 --- a/src/runtime_bignum.nr +++ b/src/runtime_bignum.nr @@ -201,6 +201,7 @@ impl RuntimeBigNumTrait for RuntimeB } unconstrained fn __eq(self, other: Self) -> bool { + assert(self.params == other.params); __eq(self.limbs, other.limbs) } @@ -218,6 +219,7 @@ impl RuntimeBigNumTrait for RuntimeB // UNCONSTRAINED! (Hence `__` prefix). fn __add(self, other: Self) -> Self { let params = self.params; + assert(params == other.params); let limbs = unsafe { __add(params, self.limbs, other.limbs) }; Self { params, limbs } } @@ -225,6 +227,7 @@ impl RuntimeBigNumTrait for RuntimeB // UNCONSTRAINED! (Hence `__` prefix). fn __sub(self, other: Self) -> Self { let params = self.params; + assert(params == other.params); let limbs = unsafe { __sub(params, self.limbs, other.limbs) }; Self { params, limbs } } @@ -232,6 +235,7 @@ impl RuntimeBigNumTrait for RuntimeB // UNCONSTRAINED! (Hence `__` prefix). fn __mul(self, other: Self) -> Self { let params = self.params; + assert(params == other.params); let limbs = unsafe { __mul::<_, MOD_BITS>(params, self.limbs, other.limbs) }; Self { params, limbs } } @@ -239,6 +243,7 @@ impl RuntimeBigNumTrait for RuntimeB // UNCONSTRAINED! (Hence `__` prefix). fn __div(self, divisor: Self) -> Self { let params = self.params; + assert(params == divisor.params); let limbs = unsafe { __div::<_, MOD_BITS>(params, self.limbs, divisor.limbs) }; Self { params, limbs } } @@ -246,6 +251,7 @@ impl RuntimeBigNumTrait for RuntimeB // UNCONSTRAINED! (Hence `__` prefix). fn __udiv_mod(self, divisor: Self) -> (Self, Self) { let params = self.params; + assert(params == divisor.params); let (q, r) = unsafe { __udiv_mod(self.limbs, divisor.limbs) }; (Self { limbs: q, params }, Self { limbs: r, params }) } @@ -261,6 +267,7 @@ impl RuntimeBigNumTrait for RuntimeB // UNCONSTRAINED! (Hence `__` prefix). fn __pow(self, exponent: Self) -> Self { let params = self.params; + assert(params == exponent.params); let limbs = unsafe { __pow::<_, MOD_BITS>(params, self.limbs, exponent.limbs) }; Self { limbs, params } } @@ -348,6 +355,7 @@ impl RuntimeBigNumTrait for RuntimeB fn assert_is_not_equal(self, other: Self) { let params = self.params; + assert(params == other.params); assert_is_not_equal(params, self.limbs, other.limbs); } @@ -358,22 +366,26 @@ impl RuntimeBigNumTrait for RuntimeB fn udiv_mod(self, divisor: Self) -> (Self, Self) { let params = self.params; + assert(params == divisor.params); let (q, r) = udiv_mod::<_, MOD_BITS>(params, self.limbs, divisor.limbs); (Self { limbs: q, params }, Self { limbs: r, params }) } fn udiv(self, divisor: Self) -> Self { let params = self.params; + assert(params == divisor.params); Self { limbs: udiv::<_, MOD_BITS>(params, self.limbs, divisor.limbs), params } } fn umod(self, divisor: Self) -> Self { let params = self.params; + assert(params == divisor.params); Self { limbs: umod::<_, MOD_BITS>(params, self.limbs, divisor.limbs), params } } fn conditional_select(lhs: Self, rhs: Self, predicate: bool) -> Self { let params = lhs.params; + assert(params == rhs.params); Self { limbs: conditional_select(lhs.limbs, rhs.limbs, predicate), params } } } @@ -383,6 +395,7 @@ impl std::ops::Add for RuntimeBigNum // via evaluate_quadratic_expression fn add(self, other: Self) -> Self { let params = self.params; + assert(params == other.params); Self { limbs: add::<_, MOD_BITS>(params, self.limbs, other.limbs), params } } } @@ -392,6 +405,7 @@ impl std::ops::Sub for RuntimeBigNum // via evaluate_quadratic_expression fn sub(self, other: Self) -> Self { let params = self.params; + assert(params == other.params); Self { limbs: sub::<_, MOD_BITS>(params, self.limbs, other.limbs), params } } } @@ -403,6 +417,7 @@ impl std::ops::Mul for RuntimeBigNum // will create much fewer constraints than calling `mul` and `add` directly fn mul(self, other: Self) -> Self { let params = self.params; + assert(params == other.params); Self { limbs: mul::<_, MOD_BITS>(params, self.limbs, other.limbs), params } } } @@ -411,6 +426,7 @@ impl std::ops::Div for RuntimeBigNum // Note: this method is expensive! Witness computation is extremely expensive as it requires modular exponentiation fn div(self, divisor: Self) -> Self { let params = self.params; + assert(params == divisor.params); Self { limbs: div::<_, MOD_BITS>(params, self.limbs, divisor.limbs), params } } } @@ -418,6 +434,7 @@ impl std::ops::Div for RuntimeBigNum impl std::cmp::Eq for RuntimeBigNum { fn eq(self, other: Self) -> bool { let params = self.params; + assert(params == other.params); eq::<_, MOD_BITS>(params, self.limbs, other.limbs) } }