From 3188ea74fe3b059219a2ea87899589c266256d74 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Wed, 23 Oct 2024 14:45:21 +0100 Subject: [PATCH] feat!: update to support Noir 0.36.0 (#7) * feat: update to support Noir 0.36.0 --- .github/workflows/test.yml | 4 +- Nargo.toml | 2 +- src/bjj.nr | 7 ++- src/lib.nr | 111 ++++++++++++++++++++++++++----------- src/scalar_field.nr | 49 ++++++++-------- src/test.nr | 10 ++-- 6 files changed, 118 insertions(+), 65 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d8ab990..971fc98 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - toolchain: [nightly, 0.34.0] + toolchain: [nightly, 0.36.0] steps: - name: Checkout sources uses: actions/checkout@v4 @@ -38,7 +38,7 @@ jobs: - name: Install Nargo uses: noir-lang/noirup@v0.1.3 with: - toolchain: 0.34.0 + toolchain: 0.36.0 - name: Run formatter run: nargo fmt --check diff --git a/Nargo.toml b/Nargo.toml index de6d9ac..b6a63f4 100644 --- a/Nargo.toml +++ b/Nargo.toml @@ -2,6 +2,6 @@ name = "edwards" type = "lib" authors = [""] -compiler_version = ">=0.34.0" +compiler_version = ">=0.36.0" [dependencies] diff --git a/src/bjj.nr b/src/bjj.nr index 78b238c..826f99a 100644 --- a/src/bjj.nr +++ b/src/bjj.nr @@ -1,7 +1,7 @@ use crate::TECurveParameterTrait; use crate::Curve; -struct BabyJubJubParams {} +pub struct BabyJubJubParams {} impl TECurveParameterTrait for BabyJubJubParams { fn a() -> Field { 168700 @@ -11,8 +11,9 @@ impl TECurveParameterTrait for BabyJubJubParams { } fn gen() -> (Field, Field) { ( - 0x0bb77a6ad63e739b4eacb2e09d6277c12ab8d8010534e0b62893f3f6bb957051, 0x25797203f7a0b24925572e1cd16bf9edfce0051fb9e133774b3c257a872d7d8b + 0x0bb77a6ad63e739b4eacb2e09d6277c12ab8d8010534e0b62893f3f6bb957051, + 0x25797203f7a0b24925572e1cd16bf9edfce0051fb9e133774b3c257a872d7d8b, ) } } -type BabyJubJub = Curve; +pub type BabyJubJub = Curve; diff --git a/src/lib.nr b/src/lib.nr index b583336..b04f938 100644 --- a/src/lib.nr +++ b/src/lib.nr @@ -4,7 +4,7 @@ mod bjj; use crate::scalar_field::ScalarField; -struct Curve { +pub struct Curve { x: Field, y: Field, } @@ -23,21 +23,39 @@ trait TECurveParameterTrait { } /// Defines methods that a valid Curve implementation must satisfy -trait CurveTrait where CurveTrait: std::ops::Add + std::ops::Sub + std::ops::Eq + std::ops::Neg { - fn default() -> Self { std::default::Default::default () } +trait CurveTrait +where + CurveTrait: std::ops::Add + std::ops::Sub + std::ops::Eq + std::ops::Neg, +{ + fn default() -> Self { + std::default::Default::default() + } fn new(x: Field, y: Field) -> Self; fn zero() -> Self; fn one() -> Self; - fn add(self, x: Self) -> Self { self + x } - fn sub(self, x: Self) -> Self { self - x } - fn neg(self) -> Self { std::ops::Neg::neg(self) } + fn add(self, x: Self) -> Self { + self + x + } + fn sub(self, x: Self) -> Self { + self - x + } + fn neg(self) -> Self { + std::ops::Neg::neg(self) + } fn dbl(self) -> Self; fn mul(self, x: ScalarField) -> Self; - fn msm (points: [Self; N], scalars: [ScalarField; N]) -> Self; + fn msm( + points: [Self; N], + scalars: [ScalarField; N], + ) -> Self; - fn eq(self, x: Self) -> bool { self == x } - fn is_zero(self) -> bool { self == Self::zero() } + fn eq(self, x: Self) -> bool { + self == x + } + fn is_zero(self) -> bool { + self == Self::zero() + } fn is_on_curve(self) -> bool; fn assert_is_on_curve(self); @@ -49,7 +67,10 @@ trait CurveTrait where CurveTrait: std::ops::Add + std::ops::Sub // ### C O N S T R A I N E D F U N C T I O N S // #################################################################################################################### // #################################################################################################################### -impl std::default::Default for Curve where Params: TECurveParameterTrait { +impl std::default::Default for Curve +where + Params: TECurveParameterTrait, +{ /// Returns point at infinity /// /// Cost: 0 gates @@ -58,7 +79,10 @@ impl std::default::Default for Curve where Params: TECurveParame } } -impl std::ops::Add for Curve where Params: TECurveParameterTrait { +impl std::ops::Add for Curve +where + Params: TECurveParameterTrait, +{ /// Compute `self + other` /// /// Cost: 7 gates @@ -67,7 +91,10 @@ impl std::ops::Add for Curve where Params: TECurveParameterTrait } } -impl std::ops::Neg for Curve where Params: TECurveParameterTrait { +impl std::ops::Neg for Curve +where + Params: TECurveParameterTrait, +{ /// Negate a point /// /// Cost: usually 0, will cost 1 gate if the `x` coordinate needs to be converted into a witness @@ -76,7 +103,10 @@ impl std::ops::Neg for Curve where Params: TECurveParameterTrait } } -impl std::ops::Sub for Curve where Params: TECurveParameterTrait { +impl std::ops::Sub for Curve +where + Params: TECurveParameterTrait, +{ /// Compute `self - other` /// /// Cost: 7 gates @@ -85,7 +115,10 @@ impl std::ops::Sub for Curve where Params: TECurveParameterTrait } } -impl std::cmp::Eq for Curve where Params: TECurveParameterTrait { +impl std::cmp::Eq for Curve +where + Params: TECurveParameterTrait, +{ /// Compute `self == other` /// /// Cost: 6 gates @@ -94,9 +127,12 @@ impl std::cmp::Eq for Curve where Params: TECurveParameterTrait } } -impl std::convert::From<(Field, Field)> for Curve where Params: TECurveParameterTrait { +impl std::convert::From<(Field, Field)> for Curve +where + Params: TECurveParameterTrait, +{ /// Construct from tuple of field elements - /// + /// /// Use this method instead of `new` if you know x/y is on the curve /// /// Cost: 0 gates @@ -105,10 +141,13 @@ impl std::convert::From<(Field, Field)> for Curve where Params: } } -impl CurveTrait for Curve where Params: TECurveParameterTrait { +impl CurveTrait for Curve +where + Params: TECurveParameterTrait, +{ /// Construct a new point - /// + /// /// If you know the x/y coords form a valid point DO NOT USE THIS METHOD /// This method calls `assert_is_on_curve` which costs 3 gates. /// Instead, directly construct via Curve{x, y} or use from((x, y)) @@ -136,7 +175,7 @@ impl CurveTrait for Curve where Params: TECurveParameter } /// Validate a point is on the curve - /// + /// /// cheaper than `is_on_curve` (assert is cheaper than returning a bool) /// /// Cost: 3 gates @@ -151,7 +190,7 @@ impl CurveTrait for Curve where Params: TECurveParameter } /// Constrain two points to equal each other - /// + /// /// Cheaper than `assert(self == other)` because no need to return a bool /// /// Cost: 0-2 gates (can do these asserts with just copy constraints) @@ -161,7 +200,7 @@ impl CurveTrait for Curve where Params: TECurveParameter } /// Return a bool that describes whether the point is on the curve - /// + /// /// If you don't need to handle the failure case, it is cheaper to call `assert_is_on_curve` /// /// Cost: 5 gates @@ -183,11 +222,11 @@ impl CurveTrait for Curve where Params: TECurveParameter } /// Compute `self * scalar` - /// + /// /// Uses the Straus method via lookup tables. /// Assumes backend has an efficient implementation of a memory table abstraction /// i.e. `let x = table[y]` is efficient even if `y` is not known at compile time - /// + /// /// Key cost components are as follows: /// 1: computing the Straus point lookup table (169 gates) /// 2: 252 point doublings (1260 gates) @@ -198,7 +237,7 @@ impl CurveTrait for Curve where Params: TECurveParameter /// /// TODO: use windowed non-adjacent form to remove 7 point additions when creating lookup table fn mul(self: Self, scalar: ScalarField) -> Self { - // define a, d params locally to make code more readable (shouldn't affect performance) + // define a, d params locally to make code more readable (shouldn't affect performance) let a = Params::a(); let d = Params::d(); @@ -222,7 +261,7 @@ impl CurveTrait for Curve where Params: TECurveParameter let idx: u8 = scalar.base4_slices[i]; let x = table_x[idx]; let y = table_y[idx]; - accumulator = accumulator.add_internal(Curve{ x, y }, a, d); + accumulator = accumulator.add_internal(Curve { x, y }, a, d); } // todo fix @@ -238,7 +277,7 @@ impl CurveTrait for Curve where Params: TECurveParameter /// uses the Straus MSM method via lookup tables. /// Assumes backend has an efficient implementation of a memory table abstraction /// i.e. `let x = table[y]` is efficient even if `y` is not known at compile time - /// + /// /// Key cost components are as follows /// PER POINT costs: /// 1: computing the Straus point lookup table (169N gates) @@ -251,7 +290,10 @@ impl CurveTrait for Curve where Params: TECurveParameter /// Cost: 1260 + 862N + cost of creating ScalarField (110N gates) /// /// TODO: use windowed non-adjacent form to remove 7 point additions per point when creating lookup table - fn msm(points: [Self; N], scalars: [ScalarField; N]) -> Self { + fn msm( + points: [Self; N], + scalars: [ScalarField; N], + ) -> Self { let a = Params::a(); let d = Params::d(); @@ -277,7 +319,7 @@ impl CurveTrait for Curve where Params: TECurveParameter let idx: u8 = scalars[j].base4_slices[i]; let x = point_tables[j].0[idx]; let y = point_tables[j].1[idx]; - accumulator = accumulator.add_internal(Curve{ x, y }, a, d); + accumulator = accumulator.add_internal(Curve { x, y }, a, d); } } @@ -298,7 +340,7 @@ impl CurveTrait for Curve where Params: TECurveParameter impl Curve { /// add two points together - /// + /// /// This method exists because of a Noir bug where `Params` cannot be accessed by an internal function /// called from internal function. e.g. compiler error if `mul` impl tries to call `add` :( fn add_internal(self, other: Self, a: Field, d: Field) -> Self { @@ -352,7 +394,7 @@ impl Curve { } /// Compute a 4-bit lookup table of point multiples for the Straus windowed scalar multiplication algorithm. - /// + /// /// Table contains [0, P, 2P, ..., 15P], which is used in the scalar mul algorithm to minimize the total number of required point additions /// /// It is cheaper to use ([Field; 16], [Field; 16]) than it is ([Curve; 16]). @@ -401,7 +443,14 @@ impl Curve { } /// add points together, return output + lambda ter -unconstrained fn __add_unconstrained(x1: Field, x2: Field, y1: Field, y2: Field, a: Field, d: Field) -> (Field, Field, Field) { +unconstrained fn __add_unconstrained( + x1: Field, + x2: Field, + y1: Field, + y2: Field, + a: Field, + d: Field, +) -> (Field, Field, Field) { let lambda = y1 * y2 * x1 * x2; let y = (x1 * x2 * a - y1 * y2) / (lambda * d - 1); let x = (x1 * y2 + y1 * x2) / (lambda * d + 1); diff --git a/src/scalar_field.nr b/src/scalar_field.nr index 40021fe..8c821dc 100644 --- a/src/scalar_field.nr +++ b/src/scalar_field.nr @@ -1,5 +1,5 @@ /// ScalarField represents a scalar multiplier as a sequence of 4-bit slices -/// +/// /// There is nuance to ScalarField, because twisted edwards curves generally have prime group orders that easily fit into a Field /// We can therefore obtain cheap conversions by simply summing up the bit slices and validate they equal the input scalar /// However...when converting arbitrary field elements (i.e. scalars that are multiples of a TE curve group order), @@ -12,7 +12,7 @@ /// ScalarField is used when performing scalar multiplications, where all operations wrap modulo the curve order struct ScalarField { base4_slices: [u8; N], - skew: bool + skew: bool, } unconstrained fn get_wnaf_slices(x: Field) -> ([u8; N], bool) { @@ -21,12 +21,12 @@ unconstrained fn get_wnaf_slices(x: Field) -> ([u8; N], bool) { let skew: bool = nibbles[0] & 1 == 0; nibbles[0] += skew as u8; - result[N-1] = (nibbles[0] + 15) / 2; + result[N - 1] = (nibbles[0] + 15) / 2; for i in 1..N { let mut nibble: u8 = nibbles[i]; - result[N-1 - i] = (nibble + 15) / 2; + result[N - 1 - i] = (nibble + 15) / 2; if (nibble & 1 == 0) { - result[N-1 - i] += 1; + result[N - 1 - i] += 1; result[N - i] -= 8; } } @@ -46,13 +46,15 @@ unconstrained fn from_wnaf_slices(x: [u8; 64], skew: bool) -> Field { #[test] fn test_wnaf() { - let result: Field = 0x123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0; - let (t0, t1) = get_wnaf_slices(result); - let expected = from_wnaf_slices(t0, t1); - assert(result == expected); + unsafe { + let result: Field = 0x123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0; + let (t0, t1) = get_wnaf_slices(result); + let expected = from_wnaf_slices(t0, t1); + assert(result == expected); + } } -unconstrained fn get_modulus_slices() -> (Field, Field) { +comptime fn get_modulus_slices() -> (Field, Field) { let bytes = std::field::modulus_be_bytes(); let num_bytes = std::field::modulus_num_bits() / 8; let mut lo: Field = 0; @@ -61,7 +63,7 @@ unconstrained fn get_modulus_slices() -> (Field, Field) { hi *= 256; hi += bytes[i] as Field; lo *= 256; - lo += bytes[i + (num_bytes/2)] as Field; + lo += bytes[i + (num_bytes / 2)] as Field; } if (num_bytes & 1 == 1) { lo *= 256; @@ -99,29 +101,30 @@ impl std::convert::From for ScalarField { for i in 1..(N / 2) { borrow_shift *= 16; lo *= 16; - lo += (slices[(N/2) + i] as Field) * 2 - 15; + lo += (slices[(N / 2) + i] as Field) * 2 - 15; hi *= 16; hi += (slices[i] as Field) * 2 - 15; } if ((N & 1) == 1) { borrow_shift *= 16; lo *= 16; - lo += (slices[N-1] as Field) * 2 - 15; + lo += (slices[N - 1] as Field) * 2 - 15; } lo -= skew as Field; // Validate that the integer represented by (lo, hi) is smaller than the integer represented by (plo, phi) - let (plo, phi) = get_modulus_slices(); - let borrow = get_borrow_flag(plo, lo) as Field; - let rlo = plo - lo + borrow * borrow_shift - 1; // -1 because we are checking a strict <, not <= - let rhi = phi - hi - borrow; - let offset = (N & 1 == 1) as u8; - let hibits = (N as u8 / 2) * 4; - let lobits = hibits + offset * 4; - rlo.assert_max_bit_size(lobits as u32); - rhi.assert_max_bit_size(hibits as u32); + let (plo, phi) = comptime { get_modulus_slices() }; + unsafe { + // Safety: `borrow`'s value is constrained to be correct by below range constraints. + let borrow = get_borrow_flag(plo, lo) as Field; + + let rlo = plo - lo + borrow * borrow_shift - 1; // -1 because we are checking a strict <, not <= + let rhi = phi - hi - borrow; + rlo.assert_max_bit_size::<(N / 2 + N % 2) * 4>(); + rhi.assert_max_bit_size::(); + } } for i in 0..N { - (result.base4_slices[i] as Field).assert_max_bit_size(4); + (result.base4_slices[i] as Field).assert_max_bit_size::<4>(); } result } diff --git a/src/test.nr b/src/test.nr index fcf2404..4b1abf5 100644 --- a/src/test.nr +++ b/src/test.nr @@ -10,7 +10,7 @@ type BabyJubJub = Curve; fn test_sub() { let bjj = baby_jubjub(); let bjj_point = bjj.base8; - let point: Curve = Curve { x: bjj_point.x, y: bjj_point.y }; + let point: Curve = Curve { x: bjj_point.x, y: bjj_point.y }; let expected = point + (point + (point)); let result = point.dbl().dbl().sub(point); @@ -26,7 +26,7 @@ fn test_mul() { let expected = bjj.curve.mul(scalar, bjj_point); let scalar_f: ScalarField<63> = ScalarField::from(scalar); - let point: BabyJubJub = Curve { x: bjj_point.x, y: bjj_point.y }; + let point: BabyJubJub = Curve { x: bjj_point.x, y: bjj_point.y }; let result = point.mul(scalar_f); let scalar_converted: Field = ScalarField::into(scalar_f); assert(scalar_converted == scalar); @@ -40,7 +40,7 @@ fn test_msm() { let bjj = baby_jubjub(); let bjj_point = bjj.base8; - let point: BabyJubJub = Curve { x: bjj_point.x, y: bjj_point.y }; + let point: BabyJubJub = Curve { x: bjj_point.x, y: bjj_point.y }; let mut scalar_values: [Field; 1] = [0; 1]; let mut points: [BabyJubJub; 1] = [Curve { x: 0, y: 0 }; 1]; let mut scalars: [ScalarField<63>; 1] = [ScalarField::new(); 1]; @@ -52,8 +52,8 @@ fn test_msm() { bjj_points[0] = TEPoint::new(points[0].x, points[0].y); for i in 1..1 { - points[i] = points[i-1].dbl(); - scalar_values[i] = scalar_values[i-1] + scalar; + points[i] = points[i - 1].dbl(); + scalar_values[i] = scalar_values[i - 1] + scalar; } for i in 0..1 {