Skip to content

Commit

Permalink
Added add_intermediate and add_secure_intermediate to eval API.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Dec 2, 2024
1 parent a198090 commit 348fc4e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 21 deletions.
69 changes: 48 additions & 21 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -772,34 +772,35 @@ pub struct ExprEvaluator {
pub cur_var_index: usize,
pub constraints: Vec<ExtExpr>,
pub logup: FormalLogupAtRow,
pub intermediates: Vec<(String, ExtExpr)>,
pub intermediates: Vec<(String, BaseExpr)>,
pub secure_intermediates: Vec<(String, ExtExpr)>,
}

impl ExprEvaluator {
#[allow(dead_code)]
pub fn new(log_size: u32, has_partial_sum: bool) -> Self {
Self {
cur_var_index: Default::default(),
constraints: Default::default(),
logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size),
intermediates: vec![],
secure_intermediates: vec![],
}
}

pub fn add_intermediate(&mut self, expr: ExtExpr) -> ExtExpr {
let name = format!("intermediate{}", self.intermediates.len());
let intermediate = ExtExpr::Param(name.clone());
self.intermediates.push((name, expr));
intermediate
}

pub fn format_constraints(&self) -> String {
let lets_string = self
.intermediates
.iter()
.map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format()))
.collect::<Vec<String>>()
.join("\n");
.join("\n\n");

let secure_lets_string = self
.secure_intermediates
.iter()
.map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format()))
.collect::<Vec<String>>()
.join("\n\n");

let constraints_str = self
.constraints
Expand All @@ -809,7 +810,12 @@ impl ExprEvaluator {
.collect::<Vec<String>>()
.join("\n\n");

lets_string + "\n\n" + &constraints_str
[lets_string, secure_lets_string, constraints_str]
.iter()
.filter(|x| !x.is_empty())
.cloned()
.collect::<Vec<_>>()
.join("\n\n")
}
}

Expand Down Expand Up @@ -858,14 +864,35 @@ impl EvalAtRow for ExprEvaluator {
multiplicity,
values,
}| {
let intermediate = self.add_intermediate(combine_formal(*relation, values));
let intermediate =
self.add_secure_intermediate(combine_formal(*relation, values));
Fraction::new(multiplicity.clone(), intermediate)
},
)
.collect();
self.write_logup_frac(fracs.into_iter().sum());
}

fn add_intermediate(&mut self, expr: Self::F) -> Self::F {
let name = format!(
"intermediate{}",
self.intermediates.len() + self.secure_intermediates.len()
);
let intermediate = BaseExpr::Param(name.clone());
self.intermediates.push((name, expr));
intermediate
}

fn add_secure_intermediate(&mut self, expr: Self::EF) -> Self::EF {
let name = format!(
"intermediate{}",
self.intermediates.len() + self.secure_intermediates.len()
);
let intermediate = ExtExpr::Param(name.clone());
self.secure_intermediates.push((name, expr));
intermediate
}

super::logup_proxy!();
}

Expand Down Expand Up @@ -1031,21 +1058,22 @@ mod tests {
fn test_format_expr() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
let expected = "let intermediate0 = (TestRelation_alpha0) * (col_1_0[0]) \
let expected = "let intermediate0 = (col_1_1[0]) * (col_1_2[0]);
\
let intermediate1 = (TestRelation_alpha0) * (col_1_0[0]) \
+ (TestRelation_alpha1) * (col_1_1[0]) \
+ (TestRelation_alpha2) * (col_1_2[0]) \
- (TestRelation_z);
\
let constraint_0 = \
(((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1 / (col_1_0[0] + col_1_1[0]));
let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0]));
\
let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \
- (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \
- ((total_sum) * (col_0_3[0])))\
) \
* (intermediate0) \
- ((total_sum) * (col_0_3[0])))) \
* (intermediate1) \
- (1);"
.to_string();

Expand All @@ -1066,9 +1094,8 @@ mod tests {
let x0 = eval.next_trace_mask();
let x1 = eval.next_trace_mask();
let x2 = eval.next_trace_mask();
eval.add_constraint(
x0.clone() * x1.clone() * x2.clone() * (x0.clone() + x1.clone()).inverse(),
);
let intermediate = eval.add_intermediate(x1.clone() * x2.clone());
eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse());
eval.add_to_relation(&[RelationEntry::new(
&TestRelation::dummy(),
E::EF::one(),
Expand Down
12 changes: 12 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,18 @@ pub trait EvalAtRow {
where
Self::EF: Mul<G, Output = Self::EF> + From<G>;

/// Adds an intermediate value to the component and returns its value.
/// Does nothing by default.
fn add_intermediate(&mut self, val: Self::F) -> Self::F {
val
}

/// Adds a secure intermediate value to the component and returns its value.
/// Does nothing by default.
fn add_secure_intermediate(&mut self, val: Self::EF) -> Self::EF {
val
}

/// Combines 4 base field values into a single extension field value.
fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF;

Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ mod tests {
let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
- (StateMachineElements_z);
\
let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \
+ (StateMachineElements_alpha1) * (col_1_1[0]) \
Expand Down

0 comments on commit 348fc4e

Please sign in to comment.