From c156a9c0b1bcb55f1a8c263c76df91a911277a89 Mon Sep 17 00:00:00 2001 From: Waylon Jepsen Date: Sun, 30 Jun 2024 19:50:01 -0600 Subject: [PATCH] chore: upgrade generic polynomials --- src/compiler/errors.rs | 2 +- src/compiler/program.rs | 173 ++++++++++++++++++++++++++-------------- 2 files changed, 115 insertions(+), 60 deletions(-) diff --git a/src/compiler/errors.rs b/src/compiler/errors.rs index 5c5cdd0a..e2db422c 100644 --- a/src/compiler/errors.rs +++ b/src/compiler/errors.rs @@ -64,7 +64,7 @@ mod tests { fn program_error() { let constraints = &["a public", "d === 9", "b <== a * a + 5", "b public", "c <== -2 * b - a * b"]; - let program = Program::new(constraints, 5).unwrap(); + let program = Program::<5>::new(constraints).unwrap(); let public_vars = program.public_assignments(); diff --git a/src/compiler/program.rs b/src/compiler/program.rs index 1ad6f396..6a421eb3 100644 --- a/src/compiler/program.rs +++ b/src/compiler/program.rs @@ -19,7 +19,7 @@ use crate::{ polynomial::{Lagrange, Polynomial}, }; -type Poly = Polynomial, PlutoScalarField>; +// type Poly = /// Column represents all three columns in the execution trace which a variable /// can take. @@ -67,42 +67,42 @@ impl Cell { /// `Program` represents constraints used while defining the arithmetic on the inputs /// and group order of primitive roots of unity in the field. #[derive(Debug, PartialEq)] -pub struct Program<'a> { +pub struct Program<'a, const GROUP_ORDER: usize> { /// `constraints` defined during arithmetic evaluation on inputs in the circuit constraints: Vec>, - /// order of multiplicative group formed by primitive roots of unity in the scalar field - group_order: usize, + // order of multiplicative group formed by primitive roots of unity in the scalar field + // group_order: usize, } /// Represents circuit related input which is apriori known to `Prover` and `Verifier` involved in /// the process. -pub struct CommonPreprocessedInput { +pub struct CommonPreprocessedInput { /// multiplicative group order - pub group_order: usize, + // group_order: usize, /// Q_L(X): left wire selector polynomial - pub ql: Poly, + pub ql: Polynomial, PlutoScalarField, GROUP_ORDER>, /// Q_R(X): right wire selector polynomial - pub qr: Poly, + pub qr: Polynomial, PlutoScalarField, GROUP_ORDER>, /// Q_M(X): multiplication gate selector polynomial - pub qm: Poly, + pub qm: Polynomial, PlutoScalarField, GROUP_ORDER>, /// Q_O(X): output wire selector polynomial - pub qo: Poly, + pub qo: Polynomial, PlutoScalarField, GROUP_ORDER>, /// Q_C(X): constant selector polynomial - pub qc: Poly, + pub qc: Polynomial, PlutoScalarField, GROUP_ORDER>, /// S_σ1(X): first permutation polynomial - pub s1: Poly, + pub s1: Polynomial, PlutoScalarField, GROUP_ORDER>, /// S_σ2(X): second permutation polynomial - pub s2: Poly, + pub s2: Polynomial, PlutoScalarField, GROUP_ORDER>, /// S_σ3(X): third permutation polynomial - pub s3: Poly, + pub s3: Polynomial, PlutoScalarField, GROUP_ORDER>, } -impl<'a> Program<'a> { +impl<'a, const GROUP_ORDER: usize> Program<'a, GROUP_ORDER> { /// create a new [`Program`] from list of constraints and group order. Converts constraints into /// variables and their corresponding activation coefficients /// /// Assumes: group_order >= constraints.len() - pub fn new(constraints: &[&'a str], group_order: usize) -> Result> { + pub fn new(constraints: &[&'a str]) -> Result> { let assembly: Result, ParserError> = constraints.iter().map(|constraint| parse_constraints(constraint)).collect(); @@ -111,33 +111,48 @@ impl<'a> Program<'a> { Err(parser_error) => return Err(ProgramError::ParserError(parser_error)), }; - Ok(Self { constraints: assembly, group_order }) + Ok(Self { constraints: assembly }) } /// returns selector polynomial used in execution trace for a gate - fn selector_polynomials(&self) -> (Poly, Poly, Poly, Poly, Poly) { - let mut l = vec![PlutoScalarField::ZERO; self.group_order]; - let mut r = vec![PlutoScalarField::ZERO; self.group_order]; - let mut m = vec![PlutoScalarField::ZERO; self.group_order]; - let mut o = vec![PlutoScalarField::ZERO; self.group_order]; - let mut c = vec![PlutoScalarField::ZERO; self.group_order]; + #[allow(clippy::type_complexity)] + fn selector_polynomials( + &self, + ) -> ( + Polynomial, PlutoScalarField, GROUP_ORDER>, + Polynomial, PlutoScalarField, GROUP_ORDER>, + Polynomial, PlutoScalarField, GROUP_ORDER>, + Polynomial, PlutoScalarField, GROUP_ORDER>, + Polynomial, PlutoScalarField, GROUP_ORDER>, + ) { + let mut l = [PlutoScalarField::ZERO; GROUP_ORDER]; + let mut r = [PlutoScalarField::ZERO; GROUP_ORDER]; + let mut m = [PlutoScalarField::ZERO; GROUP_ORDER]; + let mut o = [PlutoScalarField::ZERO; GROUP_ORDER]; + let mut c = [PlutoScalarField::ZERO; GROUP_ORDER]; // iterate through the constraints and assign each selector value for (i, constraint) in self.constraints.iter().enumerate() { let gate = constraint.gate(); (l[i], r[i], m[i], o[i], c[i]) = (gate.l, gate.r, gate.m, gate.o, gate.c); } - - let poly_l = Polynomial::, PlutoScalarField>::new(l); - let poly_r = Polynomial::, PlutoScalarField>::new(r); - let poly_m = Polynomial::, PlutoScalarField>::new(m); - let poly_o = Polynomial::, PlutoScalarField>::new(o); - let poly_c = Polynomial::, PlutoScalarField>::new(c); + let poly_l = Polynomial::, PlutoScalarField, GROUP_ORDER>::new(l); + let poly_r = Polynomial::, PlutoScalarField, GROUP_ORDER>::new(r); + let poly_m = Polynomial::, PlutoScalarField, GROUP_ORDER>::new(m); + let poly_o = Polynomial::, PlutoScalarField, GROUP_ORDER>::new(o); + let poly_c = Polynomial::, PlutoScalarField, GROUP_ORDER>::new(c); (poly_l, poly_r, poly_m, poly_o, poly_c) } /// Returns `S1,S2,S3` polynomials used for creating permutation argument in PLONK - fn s_polynomials(&self) -> (Poly, Poly, Poly) { + #[allow(clippy::type_complexity)] + fn s_polynomials( + &self, + ) -> ( + Polynomial, PlutoScalarField, GROUP_ORDER>, + Polynomial, PlutoScalarField, GROUP_ORDER>, + Polynomial, PlutoScalarField, GROUP_ORDER>, + ) { // captures uses of a variable in constraints where each new constraint defines a new row in // execution trace and columns represent left, right and output wires in a gate. // Map of variable and (row, column) tuples @@ -165,7 +180,7 @@ impl<'a> Program<'a> { } // add zero values for unfilled values - for row in self.constraints.len()..self.group_order { + for row in self.constraints.len()..GROUP_ORDER { let val = variable_uses.get_mut(&None).unwrap(); val.insert(Cell { row: row as u32, column: Column::LEFT }); val.insert(Cell { row: row as u32, column: Column::RIGHT }); @@ -173,7 +188,7 @@ impl<'a> Program<'a> { } // $S_i$ polynomial in evaluation form - let mut s = vec![vec![PlutoScalarField::ZERO; self.group_order]; 3]; + let mut s: [[PlutoScalarField; GROUP_ORDER]; 3] = [[PlutoScalarField::ZERO; GROUP_ORDER]; 3]; // shift each polynomial value right by 1 and assign domain. for example: // let's say, usage of variable in execution trace looks like: @@ -189,23 +204,26 @@ impl<'a> Program<'a> { let next_i = (i + 1) % row_cols.len(); let next_column = row_cols[next_i].column as u32 - 1; let next_row = row_cols[next_i].row; - s[next_column as usize][next_row as usize] = cell.label(self.group_order); + s[next_column as usize][next_row as usize] = cell.label(GROUP_ORDER); } } // create polynomials in lagrange basis from variable values as evaluations - let poly_s1 = Polynomial::, PlutoScalarField>::new(s[0].clone()); - let poly_s2 = Polynomial::, PlutoScalarField>::new(s[1].clone()); - let poly_s3 = Polynomial::, PlutoScalarField>::new(s[2].clone()); + let poly_s1 = + Polynomial::, PlutoScalarField, GROUP_ORDER>::new(s[0]); + let poly_s2 = + Polynomial::, PlutoScalarField, GROUP_ORDER>::new(s[1]); + let poly_s3 = + Polynomial::, PlutoScalarField, GROUP_ORDER>::new(s[2]); (poly_s1, poly_s2, poly_s3) } /// creates selector and permutation helper polynomials from constraints as part of circuit /// preprocessing - pub fn common_preprocessed_input(&self) -> CommonPreprocessedInput { + pub fn common_preprocessed_input(&self) -> CommonPreprocessedInput { let (s1, s2, s3) = self.s_polynomials(); let (ql, qr, qm, qo, qc) = self.selector_polynomials(); - CommonPreprocessedInput { group_order: self.group_order, ql, qr, qm, qo, qc, s1, s2, s3 } + CommonPreprocessedInput { ql, qr, qm, qo, qc, s1, s2, s3 } } /// returns public variables assigned in the circuit @@ -308,7 +326,7 @@ mod tests { #[test] fn new_program() { let constraints = &["a public", "b <== a * a"]; - let program = Program::new(constraints, 5); + let program = Program::<5>::new(constraints); assert!(program.is_ok()); assert_eq!(program.unwrap(), Program { @@ -326,35 +344,34 @@ mod tests { coeffs: HashMap::from([(String::from("a*a"), 1)]), } ]), - group_order: 5, }) } #[rstest] fn s_polys(constraint1: &[&str]) { // TODO: make this more robust - let program = Program::new(constraint1, 4); + let program = Program::<4>::new(constraint1); assert!(program.is_ok()); let program = program.unwrap(); let (s1, s2, s3) = program.s_polynomials(); - assert_eq!(s1.coefficients, vec![ + assert_eq!(s1.coefficients.to_vec(), vec![ PlutoScalarField::from(4), PlutoScalarField::from(3), PlutoScalarField::from(1), PlutoScalarField::from(15), ]); - assert_eq!(s2.coefficients, vec![ + assert_eq!(s2.coefficients.to_vec(), vec![ PlutoScalarField::from(9), PlutoScalarField::from(13), PlutoScalarField::from(16), PlutoScalarField::from(14), ]); - assert_eq!(s3.coefficients, vec![ + assert_eq!(s3.coefficients.to_vec(), vec![ PlutoScalarField::from(2), PlutoScalarField::from(5), PlutoScalarField::from(8), @@ -365,37 +382,37 @@ mod tests { #[test] fn selector_polys() { let constraint = &["a public", "d === 9", "b <== a * a + 5", "c <== -2 * b - a * b"]; - let program = Program::new(constraint, 4); + let program = Program::<4>::new(constraint); assert!(program.is_ok()); let program = program.unwrap(); let (ql, qr, qm, qo, qc) = program.selector_polynomials(); - assert_eq!(ql.coefficients, vec![ + assert_eq!(ql.coefficients.to_vec(), vec![ PlutoScalarField::new(1), // first constraint, left variable is equal to 1 PlutoScalarField::new(0), // second constraint, no left variable PlutoScalarField::new(0), PlutoScalarField::new(0), ]); - assert_eq!(qr.coefficients, vec![ + assert_eq!(qr.coefficients.to_vec(), vec![ PlutoScalarField::new(0), PlutoScalarField::new(0), PlutoScalarField::new(0), PlutoScalarField::new(2), // 4th constraint, right variable `b`'s coeff = -(-2) = 2 ]); - assert_eq!(qm.coefficients, vec![ + assert_eq!(qm.coefficients.to_vec(), vec![ PlutoScalarField::new(0), PlutoScalarField::new(0), PlutoScalarField::from(-1), // 3rd constraint, `mul` variable = `a*a`, coeff = -(1) PlutoScalarField::new(1), // 4th, `a*b` = -(-1) ]); - assert_eq!(qo.coefficients, vec![ + assert_eq!(qo.coefficients.to_vec(), vec![ PlutoScalarField::new(0), PlutoScalarField::new(1), // `d`: 1 PlutoScalarField::new(1), // `b`: 1 PlutoScalarField::new(1), // `c`: 1 ]); - assert_eq!(qc.coefficients, vec![ + assert_eq!(qc.coefficients.to_vec(), vec![ PlutoScalarField::new(0), PlutoScalarField::new(17 - 9), PlutoScalarField::new(17 - 5), @@ -410,7 +427,7 @@ mod tests { #[should_panic] #[case(&["a public", "d === 9", "b <== a * a + 5", "b public", "c <== -2 * b - a * b"], vec![])] fn public_vars(#[case] constraint: &[&str], #[case] expected: Vec) { - let program = Program::new(constraint, 5); + let program = Program::<5>::new(constraint); assert!(program.is_ok()); let program = program.unwrap(); @@ -420,21 +437,59 @@ mod tests { assert_eq!(public_vars.unwrap(), expected); } + #[allow(unused_braces)] + #[fixture] + fn group_order_4() -> usize { 4 } + + #[allow(unused_braces)] + #[fixture] + fn group_order_8() -> usize { 8 } + #[rstest] - #[case(&["a public", "d === 9", "b <== a * a + 5", "c <== -2 * b - a * b"], 4, vec![PlutoScalarField::from(2)], - HashMap::from([(None, PlutoScalarField::from(0)), (Some("d"), PlutoScalarField::from(9)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9)), - (Some("c"), PlutoScalarField::from(-36))]))] - #[case(&["a public", "b public", "pq public", "b === pq", "c <== -a * b + 9", "e <== a + b * -3"], - 8, vec![PlutoScalarField::from(2), PlutoScalarField::from(1), PlutoScalarField::from(1)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(1)), (Some("c"), PlutoScalarField::from(7)), (Some("pq"), PlutoScalarField::from(1)), (Some("e"), PlutoScalarField::from(-1))]))] + #[case(&["a public", "d === 9", "b <== a * a + 5", "c <== -2 * b - a * b"], vec![PlutoScalarField::from(2)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("d"), PlutoScalarField::from(9)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9)), (Some("c"), PlutoScalarField::from(-36))]))] #[should_panic] - #[case(&["a public", "b === 9", "b <== a * a"], 4, vec![PlutoScalarField::from(2)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9))]))] + #[case(&["a public", "b === 9", "b <== a * a"], vec![PlutoScalarField::from(2)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9))]))] fn evaluate_circuit_constraints( #[case] constraint1: &[&str], - #[case] group_order: usize, #[case] public_var_values: Vec, #[case] expected: HashMap, PlutoScalarField>, ) { - let program = Program::new(constraint1, group_order); + let program = Program::<4>::new(constraint1); + assert!(program.is_ok()); + + let program = program.unwrap(); + let public_vars = program.public_assignments(); + assert!(public_vars.is_ok()); + + let public_vars = public_vars.unwrap(); + assert_eq!(public_vars.len(), public_var_values.len()); + + let starting_assignments: HashMap, PlutoScalarField> = + HashMap::from_iter(public_vars.iter().map(|var| Some(var.as_str())).zip(public_var_values)); + let evaluations = program.evaluate_circuit(starting_assignments); + + assert!(evaluations.is_ok()); + + assert_eq!(evaluations.unwrap(), expected); + } + + // #[case(, group_order_8(), , + + #[test] + fn evaluate_circuit_constraints_with_group_order_8() { + let public_var_values = + vec![PlutoScalarField::from(2), PlutoScalarField::from(1), PlutoScalarField::from(1)]; + let expected = HashMap::from([ + (None, PlutoScalarField::from(0)), + (Some("a"), PlutoScalarField::from(2)), + (Some("b"), PlutoScalarField::from(1)), + (Some("c"), PlutoScalarField::from(7)), + (Some("pq"), PlutoScalarField::from(1)), + (Some("e"), PlutoScalarField::from(-1)), + ]); + let constraints = + &["a public", "b public", "pq public", "b === pq", "c <== -a * b + 9", "e <== a + b * -3"]; + let program = Program::<8>::new(constraints); assert!(program.is_ok()); let program = program.unwrap();