From bf21504fc0690ebd6911888230dfe70cfb64baad Mon Sep 17 00:00:00 2001 From: Alon Titelman Date: Sun, 1 Dec 2024 19:10:20 +0200 Subject: [PATCH] Add safe simplify for expressions that compares random assignments before and after. --- .../prover/src/constraint_framework/expr.rs | 161 +++++++++++++++++- 1 file changed, 152 insertions(+), 9 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 95f9f3342..e01a03818 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -1,8 +1,12 @@ +use std::collections::{HashMap, HashSet}; use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}; +use itertools::sorted; use num_traits::{One, Zero}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; -use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; +use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; use crate::core::fields::cm31::CM31; use crate::core::fields::m31::{self, BaseField}; use crate::core::fields::qm31::{SecureField, QM31}; @@ -10,7 +14,7 @@ use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; /// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ColumnExpr { interaction: usize, idx: usize, @@ -174,11 +178,14 @@ impl BaseExpr { } } - pub fn simplify(&self) -> Self { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. + fn unchecked_simplify(&self) -> Self { let simple = simplify_arithmetic!(self); match simple { Self::Inv(a) => { - let a = a.simplify(); + let a = a.unchecked_simplify(); match a { Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a Self::Const(c) => Self::Const(c.inverse()), @@ -189,6 +196,14 @@ impl BaseExpr { } } + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. + pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } @@ -220,6 +235,25 @@ impl BaseExpr { Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), } } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + BaseExpr::Col(col) => ExprVariables::col(col.clone()), + BaseExpr::Const(_) => ExprVariables::default(), + BaseExpr::Param(param) => ExprVariables::param(param.to_string()), + BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Neg(a) => a.collect_variables(), + BaseExpr::Inv(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> BaseField { + let assignment = self.collect_variables().random_assignment(); + assert!(assignment.2.is_empty()); + self.eval_expr::, _, _>(&assignment.0, &assignment.1) + } } impl ExtExpr { @@ -256,14 +290,17 @@ impl ExtExpr { } } - pub fn simplify(&self) -> Self { + /// Helper function, use [`simplify`] instead. + /// + /// Simplifies an expression by applying basic arithmetic rules. + fn unchecked_simplify(&self) -> Self { let simple = simplify_arithmetic!(self); match simple { Self::SecureCol([a, b, c, d]) => { - let a = a.simplify(); - let b = b.simplify(); - let c = c.simplify(); - let d = d.simplify(); + let a = a.unchecked_simplify(); + let b = b.unchecked_simplify(); + let c = c.unchecked_simplify(); + let d = d.unchecked_simplify(); match (a.clone(), b.clone(), c.clone(), d.clone()) { ( BaseExpr::Const(a_val), @@ -278,6 +315,14 @@ impl ExtExpr { } } + /// Simplifies an expression by applying basic arithmetic rules and ensures that the result is + /// equivalent to the original expression by assigning random values. + pub fn simplify(&self) -> Self { + let simplified = self.unchecked_simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } @@ -319,6 +364,104 @@ impl ExtExpr { Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), } } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + a.collect_variables() + + b.collect_variables() + + c.collect_variables() + + d.collect_variables() + } + ExtExpr::Const(_) => ExprVariables::default(), + ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()), + ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Neg(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> SecureField { + let assignment = self.collect_variables().random_assignment(); + self.eval_expr::, _, _, _>(&assignment.0, &assignment.1, &assignment.2) + } +} + +/// An assignment to the variables that may appear in an expression. +pub type ExprVarAssignment = ( + HashMap<(usize, usize, isize), BaseField>, + HashMap, + HashMap, +); + +/// Three sets representing all the variables that can appear in an expression: +/// * `cols`: The columns of the AIR. +/// * `params`: The formal parameters to the AIR. +/// * `ext_params`: The extension field parameters to the AIR. +#[derive(Default)] +pub struct ExprVariables { + pub cols: HashSet, + pub params: HashSet, + pub ext_params: HashSet, +} + +impl ExprVariables { + pub fn col(col: ColumnExpr) -> Self { + Self { + cols: vec![col].into_iter().collect(), + params: HashSet::new(), + ext_params: HashSet::new(), + } + } + + pub fn param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: vec![param].into_iter().collect(), + ext_params: HashSet::new(), + } + } + + pub fn ext_param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: HashSet::new(), + ext_params: vec![param].into_iter().collect(), + } + } + + /// Generates a random assignment to the variables. + /// Note that the assignment is deterministic in the sets of variables (disregarding their + /// order), and this is required. + pub fn random_assignment(&self) -> ExprVarAssignment { + let mut rng = SmallRng::seed_from_u64(0); + + let cols = sorted(self.cols.iter()) + .map(|col| ((col.interaction, col.idx, col.offset), rng.gen())) + .collect(); + + let params = sorted(self.params.iter()) + .map(|param| (param.clone(), rng.gen())) + .collect(); + + let ext_params = sorted(self.ext_params.iter()) + .map(|param| (param.clone(), rng.gen())) + .collect(); + + (cols, params, ext_params) + } +} + +impl Add for ExprVariables { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Self { + cols: self.cols.union(&rhs.cols).cloned().collect(), + params: self.params.union(&rhs.params).cloned().collect(), + ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(), + } + } } impl From for BaseExpr {