Skip to content

Commit

Permalink
chore: upgrade generic polynomials
Browse files Browse the repository at this point in the history
  • Loading branch information
0xJepsen committed Jul 1, 2024
1 parent 075de89 commit c156a9c
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 60 deletions.
2 changes: 1 addition & 1 deletion src/compiler/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ mod tests {
fn program_error() {
let constraints =
&["a public", "d === 9", "b <== a * a + 5", "b public", "c <== -2 * b - a * b"];
let program = Program::new(constraints, 5).unwrap();
let program = Program::<5>::new(constraints).unwrap();

let public_vars = program.public_assignments();

Expand Down
173 changes: 114 additions & 59 deletions src/compiler/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
polynomial::{Lagrange, Polynomial},
};

type Poly = Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField>;
// type Poly =

/// Column represents all three columns in the execution trace which a variable
/// can take.
Expand Down Expand Up @@ -67,42 +67,42 @@ impl Cell {
/// `Program` represents constraints used while defining the arithmetic on the inputs
/// and group order of primitive roots of unity in the field.
#[derive(Debug, PartialEq)]
pub struct Program<'a> {
pub struct Program<'a, const GROUP_ORDER: usize> {
/// `constraints` defined during arithmetic evaluation on inputs in the circuit
constraints: Vec<WireCoeffs<'a>>,
/// order of multiplicative group formed by primitive roots of unity in the scalar field
group_order: usize,
// order of multiplicative group formed by primitive roots of unity in the scalar field
// group_order: usize,
}

/// Represents circuit related input which is apriori known to `Prover` and `Verifier` involved in
/// the process.
pub struct CommonPreprocessedInput {
pub struct CommonPreprocessedInput<const GROUP_ORDER: usize> {
/// multiplicative group order
pub group_order: usize,
// group_order: usize,
/// Q_L(X): left wire selector polynomial
pub ql: Poly,
pub ql: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
/// Q_R(X): right wire selector polynomial
pub qr: Poly,
pub qr: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
/// Q_M(X): multiplication gate selector polynomial
pub qm: Poly,
pub qm: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
/// Q_O(X): output wire selector polynomial
pub qo: Poly,
pub qo: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
/// Q_C(X): constant selector polynomial
pub qc: Poly,
pub qc: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
/// S_σ1(X): first permutation polynomial
pub s1: Poly,
pub s1: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
/// S_σ2(X): second permutation polynomial
pub s2: Poly,
pub s2: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
/// S_σ3(X): third permutation polynomial
pub s3: Poly,
pub s3: Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
}

impl<'a> Program<'a> {
impl<'a, const GROUP_ORDER: usize> Program<'a, GROUP_ORDER> {
/// create a new [`Program`] from list of constraints and group order. Converts constraints into
/// variables and their corresponding activation coefficients
///
/// Assumes: group_order >= constraints.len()
pub fn new(constraints: &[&'a str], group_order: usize) -> Result<Self, ProgramError<'a>> {
pub fn new(constraints: &[&'a str]) -> Result<Self, ProgramError<'a>> {
let assembly: Result<Vec<WireCoeffs>, ParserError> =
constraints.iter().map(|constraint| parse_constraints(constraint)).collect();

Expand All @@ -111,33 +111,48 @@ impl<'a> Program<'a> {
Err(parser_error) => return Err(ProgramError::ParserError(parser_error)),
};

Ok(Self { constraints: assembly, group_order })
Ok(Self { constraints: assembly })
}

/// returns selector polynomial used in execution trace for a gate
fn selector_polynomials(&self) -> (Poly, Poly, Poly, Poly, Poly) {
let mut l = vec![PlutoScalarField::ZERO; self.group_order];
let mut r = vec![PlutoScalarField::ZERO; self.group_order];
let mut m = vec![PlutoScalarField::ZERO; self.group_order];
let mut o = vec![PlutoScalarField::ZERO; self.group_order];
let mut c = vec![PlutoScalarField::ZERO; self.group_order];
#[allow(clippy::type_complexity)]
fn selector_polynomials(
&self,
) -> (
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
) {
let mut l = [PlutoScalarField::ZERO; GROUP_ORDER];
let mut r = [PlutoScalarField::ZERO; GROUP_ORDER];
let mut m = [PlutoScalarField::ZERO; GROUP_ORDER];
let mut o = [PlutoScalarField::ZERO; GROUP_ORDER];
let mut c = [PlutoScalarField::ZERO; GROUP_ORDER];

// iterate through the constraints and assign each selector value
for (i, constraint) in self.constraints.iter().enumerate() {
let gate = constraint.gate();
(l[i], r[i], m[i], o[i], c[i]) = (gate.l, gate.r, gate.m, gate.o, gate.c);
}

let poly_l = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(l);
let poly_r = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(r);
let poly_m = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(m);
let poly_o = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(o);
let poly_c = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(c);
let poly_l = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(l);
let poly_r = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(r);
let poly_m = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(m);
let poly_o = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(o);
let poly_c = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(c);
(poly_l, poly_r, poly_m, poly_o, poly_c)
}

/// Returns `S1,S2,S3` polynomials used for creating permutation argument in PLONK
fn s_polynomials(&self) -> (Poly, Poly, Poly) {
#[allow(clippy::type_complexity)]
fn s_polynomials(
&self,
) -> (
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
Polynomial<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>,
) {
// captures uses of a variable in constraints where each new constraint defines a new row in
// execution trace and columns represent left, right and output wires in a gate.
// Map of variable and (row, column) tuples
Expand Down Expand Up @@ -165,15 +180,15 @@ impl<'a> Program<'a> {
}

// add zero values for unfilled values
for row in self.constraints.len()..self.group_order {
for row in self.constraints.len()..GROUP_ORDER {
let val = variable_uses.get_mut(&None).unwrap();
val.insert(Cell { row: row as u32, column: Column::LEFT });
val.insert(Cell { row: row as u32, column: Column::RIGHT });
val.insert(Cell { row: row as u32, column: Column::OUTPUT });
}

// $S_i$ polynomial in evaluation form
let mut s = vec![vec![PlutoScalarField::ZERO; self.group_order]; 3];
let mut s: [[PlutoScalarField; GROUP_ORDER]; 3] = [[PlutoScalarField::ZERO; GROUP_ORDER]; 3];

// shift each polynomial value right by 1 and assign domain. for example:
// let's say, usage of variable in execution trace looks like:
Expand All @@ -189,23 +204,26 @@ impl<'a> Program<'a> {
let next_i = (i + 1) % row_cols.len();
let next_column = row_cols[next_i].column as u32 - 1;
let next_row = row_cols[next_i].row;
s[next_column as usize][next_row as usize] = cell.label(self.group_order);
s[next_column as usize][next_row as usize] = cell.label(GROUP_ORDER);
}
}

// create polynomials in lagrange basis from variable values as evaluations
let poly_s1 = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(s[0].clone());
let poly_s2 = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(s[1].clone());
let poly_s3 = Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField>::new(s[2].clone());
let poly_s1 =
Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(s[0]);
let poly_s2 =
Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(s[1]);
let poly_s3 =
Polynomial::<Lagrange<PlutoScalarField>, PlutoScalarField, GROUP_ORDER>::new(s[2]);
(poly_s1, poly_s2, poly_s3)
}

/// creates selector and permutation helper polynomials from constraints as part of circuit
/// preprocessing
pub fn common_preprocessed_input(&self) -> CommonPreprocessedInput {
pub fn common_preprocessed_input(&self) -> CommonPreprocessedInput<GROUP_ORDER> {
let (s1, s2, s3) = self.s_polynomials();
let (ql, qr, qm, qo, qc) = self.selector_polynomials();
CommonPreprocessedInput { group_order: self.group_order, ql, qr, qm, qo, qc, s1, s2, s3 }
CommonPreprocessedInput { ql, qr, qm, qo, qc, s1, s2, s3 }
}

/// returns public variables assigned in the circuit
Expand Down Expand Up @@ -308,7 +326,7 @@ mod tests {
#[test]
fn new_program() {
let constraints = &["a public", "b <== a * a"];
let program = Program::new(constraints, 5);
let program = Program::<5>::new(constraints);
assert!(program.is_ok());

assert_eq!(program.unwrap(), Program {
Expand All @@ -326,35 +344,34 @@ mod tests {
coeffs: HashMap::from([(String::from("a*a"), 1)]),
}
]),
group_order: 5,
})
}

#[rstest]
fn s_polys(constraint1: &[&str]) {
// TODO: make this more robust
let program = Program::new(constraint1, 4);
let program = Program::<4>::new(constraint1);

assert!(program.is_ok());

let program = program.unwrap();
let (s1, s2, s3) = program.s_polynomials();

assert_eq!(s1.coefficients, vec![
assert_eq!(s1.coefficients.to_vec(), vec![
PlutoScalarField::from(4),
PlutoScalarField::from(3),
PlutoScalarField::from(1),
PlutoScalarField::from(15),
]);

assert_eq!(s2.coefficients, vec![
assert_eq!(s2.coefficients.to_vec(), vec![
PlutoScalarField::from(9),
PlutoScalarField::from(13),
PlutoScalarField::from(16),
PlutoScalarField::from(14),
]);

assert_eq!(s3.coefficients, vec![
assert_eq!(s3.coefficients.to_vec(), vec![
PlutoScalarField::from(2),
PlutoScalarField::from(5),
PlutoScalarField::from(8),
Expand All @@ -365,37 +382,37 @@ mod tests {
#[test]
fn selector_polys() {
let constraint = &["a public", "d === 9", "b <== a * a + 5", "c <== -2 * b - a * b"];
let program = Program::new(constraint, 4);
let program = Program::<4>::new(constraint);

assert!(program.is_ok());
let program = program.unwrap();

let (ql, qr, qm, qo, qc) = program.selector_polynomials();
assert_eq!(ql.coefficients, vec![
assert_eq!(ql.coefficients.to_vec(), vec![
PlutoScalarField::new(1), // first constraint, left variable is equal to 1
PlutoScalarField::new(0), // second constraint, no left variable
PlutoScalarField::new(0),
PlutoScalarField::new(0),
]);
assert_eq!(qr.coefficients, vec![
assert_eq!(qr.coefficients.to_vec(), vec![
PlutoScalarField::new(0),
PlutoScalarField::new(0),
PlutoScalarField::new(0),
PlutoScalarField::new(2), // 4th constraint, right variable `b`'s coeff = -(-2) = 2
]);
assert_eq!(qm.coefficients, vec![
assert_eq!(qm.coefficients.to_vec(), vec![
PlutoScalarField::new(0),
PlutoScalarField::new(0),
PlutoScalarField::from(-1), // 3rd constraint, `mul` variable = `a*a`, coeff = -(1)
PlutoScalarField::new(1), // 4th, `a*b` = -(-1)
]);
assert_eq!(qo.coefficients, vec![
assert_eq!(qo.coefficients.to_vec(), vec![
PlutoScalarField::new(0),
PlutoScalarField::new(1), // `d`: 1
PlutoScalarField::new(1), // `b`: 1
PlutoScalarField::new(1), // `c`: 1
]);
assert_eq!(qc.coefficients, vec![
assert_eq!(qc.coefficients.to_vec(), vec![
PlutoScalarField::new(0),
PlutoScalarField::new(17 - 9),
PlutoScalarField::new(17 - 5),
Expand All @@ -410,7 +427,7 @@ mod tests {
#[should_panic]
#[case(&["a public", "d === 9", "b <== a * a + 5", "b public", "c <== -2 * b - a * b"], vec![])]
fn public_vars(#[case] constraint: &[&str], #[case] expected: Vec<String>) {
let program = Program::new(constraint, 5);
let program = Program::<5>::new(constraint);
assert!(program.is_ok());

let program = program.unwrap();
Expand All @@ -420,21 +437,59 @@ mod tests {
assert_eq!(public_vars.unwrap(), expected);
}

#[allow(unused_braces)]
#[fixture]
fn group_order_4() -> usize { 4 }

#[allow(unused_braces)]
#[fixture]
fn group_order_8() -> usize { 8 }

#[rstest]
#[case(&["a public", "d === 9", "b <== a * a + 5", "c <== -2 * b - a * b"], 4, vec![PlutoScalarField::from(2)],
HashMap::from([(None, PlutoScalarField::from(0)), (Some("d"), PlutoScalarField::from(9)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9)),
(Some("c"), PlutoScalarField::from(-36))]))]
#[case(&["a public", "b public", "pq public", "b === pq", "c <== -a * b + 9", "e <== a + b * -3"],
8, vec![PlutoScalarField::from(2), PlutoScalarField::from(1), PlutoScalarField::from(1)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(1)), (Some("c"), PlutoScalarField::from(7)), (Some("pq"), PlutoScalarField::from(1)), (Some("e"), PlutoScalarField::from(-1))]))]
#[case(&["a public", "d === 9", "b <== a * a + 5", "c <== -2 * b - a * b"], vec![PlutoScalarField::from(2)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("d"), PlutoScalarField::from(9)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9)), (Some("c"), PlutoScalarField::from(-36))]))]
#[should_panic]
#[case(&["a public", "b === 9", "b <== a * a"], 4, vec![PlutoScalarField::from(2)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9))]))]
#[case(&["a public", "b === 9", "b <== a * a"], vec![PlutoScalarField::from(2)], HashMap::from([(None, PlutoScalarField::from(0)), (Some("a"), PlutoScalarField::from(2)), (Some("b"), PlutoScalarField::from(9))]))]
fn evaluate_circuit_constraints(
#[case] constraint1: &[&str],
#[case] group_order: usize,
#[case] public_var_values: Vec<PlutoScalarField>,
#[case] expected: HashMap<Option<&str>, PlutoScalarField>,
) {
let program = Program::new(constraint1, group_order);
let program = Program::<4>::new(constraint1);
assert!(program.is_ok());

let program = program.unwrap();
let public_vars = program.public_assignments();
assert!(public_vars.is_ok());

let public_vars = public_vars.unwrap();
assert_eq!(public_vars.len(), public_var_values.len());

let starting_assignments: HashMap<Option<&str>, PlutoScalarField> =
HashMap::from_iter(public_vars.iter().map(|var| Some(var.as_str())).zip(public_var_values));
let evaluations = program.evaluate_circuit(starting_assignments);

assert!(evaluations.is_ok());

assert_eq!(evaluations.unwrap(), expected);
}

// #[case(, group_order_8(), ,

#[test]
fn evaluate_circuit_constraints_with_group_order_8() {
let public_var_values =
vec![PlutoScalarField::from(2), PlutoScalarField::from(1), PlutoScalarField::from(1)];
let expected = HashMap::from([
(None, PlutoScalarField::from(0)),
(Some("a"), PlutoScalarField::from(2)),
(Some("b"), PlutoScalarField::from(1)),
(Some("c"), PlutoScalarField::from(7)),
(Some("pq"), PlutoScalarField::from(1)),
(Some("e"), PlutoScalarField::from(-1)),
]);
let constraints =
&["a public", "b public", "pq public", "b === pq", "c <== -a * b + 9", "e <== a + b * -3"];
let program = Program::<8>::new(constraints);
assert!(program.is_ok());

let program = program.unwrap();
Expand Down

0 comments on commit c156a9c

Please sign in to comment.