Skip to content

Commit

Permalink
Added test for expression simplifier. (#908)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti authored Dec 2, 2024
1 parent 4af2d44 commit a459c48
Showing 1 changed file with 111 additions and 8 deletions.
119 changes: 111 additions & 8 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -750,16 +750,64 @@ mod tests {
use std::collections::HashMap;

use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ExprEvaluator;
use crate::constraint_framework::{
relation, AssertEvaluator, EvalAtRow, FrameworkEval, RelationEntry,
};
use crate::core::fields::m31::BaseField;
use crate::core::fields::m31::{self, BaseField};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;

macro_rules! secure_col {
($a:expr, $b:expr, $c:expr, $d:expr) => {
ExtExpr::SecureCol([
Box::new($a.into()),
Box::new($b.into()),
Box::new($c.into()),
Box::new($d.into()),
])
};
}

macro_rules! col {
($interaction:expr, $idx:expr, $offset:expr) => {
BaseExpr::Col(($interaction, $idx, $offset).into())
};
}

macro_rules! var {
($var:expr) => {
BaseExpr::Param($var.to_string())
};
}

macro_rules! qvar {
($var:expr) => {
ExtExpr::Param($var.to_string())
};
}

macro_rules! felt {
($val:expr) => {
BaseExpr::Const($val.into())
};
}

macro_rules! qfelt {
($a:expr, $b:expr, $c:expr, $d:expr) => {
ExtExpr::Const(SecureField::from_m31_array([
$a.into(),
$b.into(),
$c.into(),
$d.into(),
]))
};
}

#[test]
fn test_eval_expr() {
let col_1_0_0 = BaseField::from(12);
Expand All @@ -778,13 +826,13 @@ mod tests {
let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]);
let ext_vars = HashMap::from([("c".to_string(), var_c)]);

let expr = ExtExpr::SecureCol([
Box::new(BaseExpr::Col((1, 0, 0).into()) - BaseExpr::Col((1, 1, 0).into())),
Box::new(BaseExpr::Col((1, 1, 0).into()) * (-BaseExpr::Param("a".to_string()))),
Box::new(BaseExpr::Param("a".to_string()) + BaseExpr::Param("a".to_string()).inverse()),
Box::new(BaseExpr::Param("b".to_string()) * BaseExpr::Const(BaseField::from(7))),
]) + ExtExpr::Param("c".to_string()) * ExtExpr::Param("c".to_string())
- ExtExpr::Const(SecureField::one());
let expr = secure_col!(
col!(1, 0, 0) - col!(1, 1, 0),
col!(1, 1, 0) * (-var!("a")),
var!("a") + var!("a").inverse(),
var!("b") * felt!(7)
) + qvar!("c") * qvar!("c")
- qfelt!(1, 0, 0, 0);

let expected = SecureField::from_m31_array([
col_1_0_0 - col_1_1_0,
Expand All @@ -800,6 +848,61 @@ mod tests {
);
}

#[test]
fn test_simplify_expr() {
let c0 = col!(1, 0, 0);
let c1 = col!(1, 1, 0);
let a = var!("a");
let b = qvar!("b");
let zero = felt!(0);
let qzero = qfelt!(0, 0, 0, 0);
let one = felt!(1);
let qone = qfelt!(1, 0, 0, 0);
let minus_one = felt!(m31::P - 1);
let qminus_one = qfelt!(m31::P - 1, 0, 0, 0);

let mut rng = SmallRng::seed_from_u64(0);
let columns: HashMap<(usize, usize, isize), BaseField> =
HashMap::from([((1, 0, 0), rng.gen()), ((1, 1, 0), rng.gen())]);
let vars: HashMap<String, BaseField> = HashMap::from([("a".to_string(), rng.gen())]);
let ext_vars: HashMap<String, SecureField> = HashMap::from([("b".to_string(), rng.gen())]);

let base_expr = (((zero.clone() + c0.clone()) + (a.clone() + zero.clone()))
* ((-c1.clone()) + (-c0.clone()))
+ (-(-(a.clone() + a.clone() + c0.clone())))
- zero.clone())
+ (a.clone() - zero.clone())
+ (-c1.clone() - (a.clone() * a.clone()))
+ (a.clone() * zero.clone())
- (zero.clone() * c1.clone())
+ one.clone()
* a.clone()
* one.clone()
* c1.clone()
* (-a.clone())
* c1.clone()
* (minus_one.clone() * c0.clone());

let expr = (qzero.clone()
+ secure_col!(
base_expr.clone(),
base_expr.clone(),
zero.clone(),
one.clone()
)
- qzero.clone())
* qone.clone()
* b.clone()
* qminus_one.clone();

let full_eval = expr.eval_expr::<AssertEvaluator<'_>, _, _, _>(&columns, &vars, &ext_vars);
let simplified_eval = expr
.simplify()
.eval_expr::<AssertEvaluator<'_>, _, _, _>(&columns, &vars, &ext_vars);

assert_eq!(full_eval, simplified_eval);
}

#[test]
fn test_format_expr() {
let test_struct = TestStruct {};
Expand Down

1 comment on commit a459c48

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: a459c48 Previous: cd8b37b Ratio
merkle throughput/simd merkle 29743866 ns/iter (± 363538) 13712527 ns/iter (± 579195) 2.17

This comment was automatically generated by workflow using github-action-benchmark.

CC: @shaharsamocha7

Please sign in to comment.