Skip to content

Commit

Permalink
Simplify expressions. (#894)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti authored Nov 28, 2024
1 parent 9a9ae36 commit b6ea512
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 15 deletions.
100 changes: 92 additions & 8 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
use num_traits::{One, Zero};

use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::{self, BaseField};
use crate::core::fields::m31::{self, BaseField, M31};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;
Expand Down Expand Up @@ -36,7 +36,6 @@ pub enum Expr {
}

impl Expr {
#[allow(dead_code)]
pub fn format_expr(&self) -> String {
match self {
Expr::Col(ColumnExpr {
Expand Down Expand Up @@ -67,6 +66,10 @@ impl Expr {
Expr::Inv(a) => format!("1 / ({})", a.format_expr()),
}
}

pub fn simplify_and_format(&self) -> String {
simplify(self.clone()).format_expr()
}
}

impl From<BaseField> for Expr {
Expand Down Expand Up @@ -190,6 +193,88 @@ impl AddAssign<BaseField> for Expr {
}
}

const ZERO: M31 = M31(0);
const ONE: M31 = M31(1);
const MINUS_ONE: M31 = M31(m31::P - 1);

// TODO(alont) Add random point assignment test.
pub fn simplify(expr: Expr) -> Expr {
match expr {
Expr::Add(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
match (a.clone(), b.clone()) {
(Expr::Const(a), Expr::Const(b)) => Expr::Const(a + b),
(Expr::Const(ZERO), _) => b, // 0 + b = b
(_, Expr::Const(ZERO)) => a, // a + 0 = a
// (-a + -b) = -(a + b)
(Expr::Neg(minus_a), Expr::Neg(minus_b)) => -(*minus_a + *minus_b),
(Expr::Neg(minus_a), _) => b - *minus_a, // -a + b = b - a
(_, Expr::Neg(minus_b)) => a - *minus_b, // a + -b = a - b
_ => Expr::Add(Box::new(a), Box::new(b)),
}
}
Expr::Sub(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
match (a.clone(), b.clone()) {
(Expr::Const(a), Expr::Const(b)) => Expr::Const(a - b),
(Expr::Const(ZERO), _) => -b, // 0 - b = -b
(_, Expr::Const(ZERO)) => a, // a - 0 = a
// (-a - -b) = b - a
(Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_b - *minus_a,
(Expr::Neg(minus_a), _) => -(*minus_a + b), // -a - b = -(a + b)
(_, Expr::Neg(minus_b)) => a + *minus_b, // a + -b = a - b
_ => Expr::Sub(Box::new(a), Box::new(b)),
}
}
Expr::Mul(a, b) => {
let a = simplify(*a);
let b = simplify(*b);
match (a.clone(), b.clone()) {
(Expr::Const(a), Expr::Const(b)) => Expr::Const(a * b),
(Expr::Const(ZERO), _) => Expr::zero(), // 0 * b = 0
(_, Expr::Const(ZERO)) => Expr::zero(), // a * 0 = 0
(Expr::Const(ONE), _) => b, // 1 * b = b
(_, Expr::Const(ONE)) => a, // a * 1 = a
// (-a) * (-b) = a * b
(Expr::Neg(minus_a), Expr::Neg(minus_b)) => *minus_a * *minus_b,
(Expr::Neg(minus_a), _) => -(*minus_a * b), // (-a) * b = -(a * b)
(_, Expr::Neg(minus_b)) => -(a * *minus_b), // a * (-b) = -(a * b)
(Expr::Const(MINUS_ONE), _) => -b, // -1 * b = -b
(_, Expr::Const(MINUS_ONE)) => -a, // a * -1 = -a
_ => Expr::Mul(Box::new(a), Box::new(b)),
}
}
Expr::Col(colexpr) => Expr::Col(colexpr),
Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([
Box::new(simplify(*a)),
Box::new(simplify(*b)),
Box::new(simplify(*c)),
Box::new(simplify(*d)),
]),
Expr::Const(c) => Expr::Const(c),
Expr::Param(x) => Expr::Param(x),
Expr::Neg(a) => {
let a = simplify(*a);
match a {
Expr::Const(c) => Expr::Const(-c),
Expr::Neg(minus_a) => *minus_a, // -(-a) = a
Expr::Sub(a, b) => Expr::Sub(b, a), // -(a - b) = b - a
_ => Expr::Neg(Box::new(a)),
}
}
Expr::Inv(a) => {
let a = simplify(*a);
match a {
Expr::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a
Expr::Const(c) => Expr::Const(c.inverse()),
_ => Expr::Inv(Box::new(a)),
}
}
}
}

/// Returns the expression
/// `value[0] * <relation>_alpha0 + value[1] * <relation>_alpha1 + ... - <relation>_z.`
fn combine_formal<R: Relation<Expr, Expr>>(relation: &R, values: &[Expr]) -> Expr {
Expand Down Expand Up @@ -273,15 +358,15 @@ impl ExprEvaluator {
let lets_string = self
.intermediates
.iter()
.map(|(name, expr)| format!("let {} = {};", name, expr.format_expr()))
.map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format()))
.collect::<Vec<String>>()
.join("\n");

let constraints_str = self
.constraints
.iter()
.enumerate()
.map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";")
.map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";")
.collect::<Vec<String>>()
.join("\n\n");

Expand Down Expand Up @@ -369,8 +454,7 @@ mod tests {
fn test_format_expr() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
let expected = "let intermediate0 = 0 \
+ (TestRelation_alpha0) * (col_1_0[0]) \
let expected = "let intermediate0 = (TestRelation_alpha0) * (col_1_0[0]) \
+ (TestRelation_alpha1) * (col_1_1[0]) \
+ (TestRelation_alpha2) * (col_1_2[0]) \
- (TestRelation_z);
Expand All @@ -382,8 +466,8 @@ mod tests {
\
let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \
- (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \
- ((col_0_3[0]) * (total_sum))) \
- (0)) \
- ((col_0_3[0]) * (total_sum)))\
) \
* (intermediate0) \
- (1);"
.to_string();
Expand Down
12 changes: 5 additions & 7 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,13 +300,11 @@ mod tests {
);

let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true));
let expected = "let intermediate0 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0]) \
let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
\
let intermediate1 = 0 \
+ (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
Expand All @@ -322,10 +320,10 @@ mod tests {
\
let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \
- (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \
- ((col_0_2[0]) * (total_sum))) \
- (0)) \
- ((col_0_2[0]) * (total_sum)))\
) \
* ((intermediate0) * (intermediate1)) \
- ((intermediate1) * (1) + (intermediate0) * (-(1)));"
- (intermediate1 - (intermediate0));"
.to_string();

assert_eq!(eval.format_constraints(), expected);
Expand Down

1 comment on commit b6ea512

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: b6ea512 Previous: cd8b37b Ratio
iffts/simd ifft/22 12762048 ns/iter (± 325632) 6306399 ns/iter (± 210024) 2.02
merkle throughput/simd merkle 29947591 ns/iter (± 517401) 13712527 ns/iter (± 579195) 2.18

This comment was automatically generated by workflow using github-action-benchmark.

CC: @shaharsamocha7

Please sign in to comment.