diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 5d8013402..304c9f0e6 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -3,7 +3,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; -use crate::core::fields::m31::{self, BaseField}; +use crate::core::fields::m31::{self, BaseField, M31}; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; @@ -36,7 +36,6 @@ pub enum Expr { } impl Expr { - #[allow(dead_code)] pub fn format_expr(&self) -> String { match self { Expr::Col(ColumnExpr { @@ -67,6 +66,10 @@ impl Expr { Expr::Inv(a) => format!("1 / ({})", a.format_expr()), } } + + pub fn simplify_and_format(&self) -> String { + simplify(self.clone()).format_expr() + } } impl From for Expr { @@ -190,6 +193,88 @@ impl AddAssign for Expr { } } +const ZERO: M31 = M31(0); +const ONE: M31 = M31(1); +const MINUS_ONE: M31 = M31(m31::P - 1); + +// TODO(alont) Add random point assignment test. +pub fn simplify(expr: Expr) -> Expr { + match expr { + Expr::Add(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + match (a.clone(), b.clone()) { + (Expr::Const(a), Expr::Const(b)) => Expr::Const(a + b), + (Expr::Const(ZERO), _) => b, // 0 + b = b + (_, Expr::Const(ZERO)) => a, // a + 0 = a + // (-a + -b) = -(a + b) + (Expr::Neg(minus_a), Expr::Neg(minus_b)) => -(*minus_a + *minus_b), + (Expr::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a + (_, Expr::Neg(minus_b)) => a - *minus_b, // a + -b = a - b + _ => Expr::Add(Box::new(a), Box::new(b)), + } + } + Expr::Sub(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + match (a.clone(), b.clone()) { + (Expr::Const(a), Expr::Const(b)) => Expr::Const(a - b), + (Expr::Const(ZERO), _) => -b, // 0 - b = -b + (_, Expr::Const(ZERO)) => a, // a - 0 = a + // (-a - -b) = b - a + (Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_b - *minus_a, + (Expr::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b) + (_, Expr::Neg(minus_b)) => a + *minus_b, // a + -b = a - b + _ => Expr::Sub(Box::new(a), Box::new(b)), + } + } + Expr::Mul(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + match (a.clone(), b.clone()) { + (Expr::Const(a), Expr::Const(b)) => Expr::Const(a * b), + (Expr::Const(ZERO), _) => Expr::zero(), // 0 * b = 0 + (_, Expr::Const(ZERO)) => Expr::zero(), // a * 0 = 0 + (Expr::Const(ONE), _) => b, // 1 * b = b + (_, Expr::Const(ONE)) => a, // a * 1 = a + // (-a) * (-b) = a * b + (Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_a * *minus_b, + (Expr::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b) + (_, Expr::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b) + (Expr::Const(MINUS_ONE), _) => -b, // -1 * b = -b + (_, Expr::Const(MINUS_ONE)) => -a, // a * -1 = -a + _ => Expr::Mul(Box::new(a), Box::new(b)), + } + } + Expr::Col(colexpr) => Expr::Col(colexpr), + Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([ + Box::new(simplify(*a)), + Box::new(simplify(*b)), + Box::new(simplify(*c)), + Box::new(simplify(*d)), + ]), + Expr::Const(c) => Expr::Const(c), + Expr::Param(x) => Expr::Param(x), + Expr::Neg(a) => { + let a = simplify(*a); + match a { + Expr::Const(c) => Expr::Const(-c), + Expr::Neg(minus_a) => *minus_a, // -(-a) = a + Expr::Sub(a, b) => Expr::Sub(b, a), // -(a - b) = b - a + _ => Expr::Neg(Box::new(a)), + } + } + Expr::Inv(a) => { + let a = simplify(*a); + match a { + Expr::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a + Expr::Const(c) => Expr::Const(c.inverse()), + _ => Expr::Inv(Box::new(a)), + } + } + } +} + /// Returns the expression /// `value[0] * _alpha0 + value[1] * _alpha1 + ... - _z.` fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { @@ -273,7 +358,7 @@ impl ExprEvaluator { let lets_string = self .intermediates .iter() - .map(|(name, expr)| format!("let {} = {};", name, expr.format_expr())) + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) .collect::>() .join("\n"); @@ -281,7 +366,7 @@ impl ExprEvaluator { .constraints .iter() .enumerate() - .map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";") + .map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";") .collect::>() .join("\n\n"); @@ -369,8 +454,7 @@ mod tests { fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let intermediate0 = 0 \ - + (TestRelation_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (TestRelation_alpha0) * (col_1_0[0]) \ + (TestRelation_alpha1) * (col_1_1[0]) \ + (TestRelation_alpha2) * (col_1_2[0]) \ - (TestRelation_z); @@ -382,8 +466,8 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((col_0_3[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_3[0]) * (total_sum)))\ + ) \ * (intermediate0) \ - (1);" .to_string(); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 787394dd3..2cf8bc2f2 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -300,13 +300,11 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - let expected = "let intermediate0 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); \ - let intermediate1 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); @@ -322,10 +320,10 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_2[0]) * (total_sum)))\ + ) \ * ((intermediate0) * (intermediate1)) \ - - ((intermediate1) * (1) + (intermediate0) * (-(1)));" + - (intermediate1 - (intermediate0));" .to_string(); assert_eq!(eval.format_constraints(), expected);