From 77b04f5bd453e05250e2d37a30e56d44d93183eb Mon Sep 17 00:00:00 2001 From: Alon Titelman Date: Wed, 4 Dec 2024 16:21:53 +0200 Subject: [PATCH] Decoupled batching from `add_to_relation`. --- .../prover/src/constraint_framework/expr.rs | 31 ++---- .../prover/src/constraint_framework/logup.rs | 9 +- crates/prover/src/constraint_framework/mod.rs | 100 ++++++++++++++---- .../constraint_framework/relation_tracker.rs | 52 ++++----- crates/prover/src/examples/blake/mod.rs | 40 +++---- .../src/examples/blake/round/constraints.rs | 6 +- .../examples/blake/scheduler/constraints.rs | 20 ++-- .../examples/blake/xor_table/constraints.rs | 56 ++++------ crates/prover/src/examples/plonk/mod.rs | 20 ++-- crates/prover/src/examples/poseidon/mod.rs | 12 ++- .../src/examples/state_machine/components.rs | 18 ++-- 11 files changed, 207 insertions(+), 157 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 3098dcc56..31c509f3f 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -755,8 +755,7 @@ pub struct FormalLogupAtRow { pub interaction: usize, pub total_sum: ExtExpr, pub claimed_sum: Option<(ExtExpr, usize)>, - pub prev_col_cumsum: ExtExpr, - pub cur_frac: Option>, + pub fracs: Vec>, pub is_finalized: bool, pub is_first: BaseExpr, pub log_size: u32, @@ -777,8 +776,7 @@ impl FormalLogupAtRow { total_sum: ExtExpr::Param(total_sum_name), claimed_sum: has_partial_sum .then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), - prev_col_cumsum: ExtExpr::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: BaseExpr::zero(), log_size, @@ -873,23 +871,12 @@ impl EvalAtRow for ExprEvaluator { fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - let fracs: Vec> = entries - .iter() - .map( - |RelationEntry { - relation, - multiplicity, - values, - }| { - let intermediate = - self.add_extension_intermediate(combine_formal(*relation, values)); - Fraction::new(multiplicity.clone(), intermediate) - }, - ) - .collect(); - self.write_logup_frac(fracs.into_iter().sum()); + let intermediate = + self.add_extension_intermediate(combine_formal(entry.relation, entry.values)); + let frac = Fraction::new(entry.multiplicity.clone(), intermediate); + self.write_logup_frac(frac); } fn add_intermediate(&mut self, expr: Self::F) -> Self::F { @@ -1115,11 +1102,11 @@ mod tests { let x2 = eval.next_trace_mask(); let intermediate = eval.add_intermediate(x1.clone() * x2.clone()); eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse()); - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( &TestRelation::dummy(), E::EF::one(), &[x0, x1, x2], - )]); + )); eval.finalize_logup(); eval } diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index d6af96e46..bb05c6b5c 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -49,8 +49,7 @@ pub struct LogupAtRow { /// None if the claimed_sum is the total_sum. pub claimed_sum: Option, /// The evaluation of the last cumulative sum column. - pub prev_col_cumsum: E::EF, - pub cur_frac: Option>, + pub fracs: Vec>, pub is_finalized: bool, /// The value of the `is_first` constant column at current row. /// See [`super::preprocessed_columns::gen_is_first()`]. @@ -74,8 +73,7 @@ impl LogupAtRow { interaction, total_sum, claimed_sum, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size, @@ -88,8 +86,7 @@ impl LogupAtRow { interaction: 100, total_sum: SecureField::one(), claimed_sum: None, - prev_col_cumsum: E::EF::zero(), - cur_frac: None, + fracs: vec![], is_finalized: true, is_first: E::F::zero(), log_size: 10, diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index b03c08ce3..bc188eb57 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -32,6 +32,13 @@ pub const PREPROCESSED_TRACE_IDX: usize = 0; pub const ORIGINAL_TRACE_IDX: usize = 1; pub const INTERACTION_TRACE_IDX: usize = 2; +/// A vector that describes the batching of logup entries. +/// Each vector member corresponds to a logup entry, and contains the batch number to which the +/// entry should be added. +/// Note that the batch numbers should be consecutive and start from 0, and that the vector's +/// length should be equal to the number of logup entries. +type Batching = Vec; + /// A trait for evaluating expressions at some point or row. pub trait EvalAtRow { // TODO(Ohad): Use a better trait for these, like 'Algebra' or something. @@ -132,25 +139,30 @@ pub trait EvalAtRow { /// multiplied. fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - let fracs = entries.iter().map( - |RelationEntry { - relation, - multiplicity, - values, - }| { Fraction::new(multiplicity.clone(), relation.combine(values)) }, + let frac = Fraction::new( + entry.multiplicity.clone(), + entry.relation.combine(entry.values), ); - self.write_logup_frac(fracs.sum()); + self.write_logup_frac(frac); } // TODO(alont): Remove these once LogupAtRow is no longer used. fn write_logup_frac(&mut self, _fraction: Fraction) { unimplemented!() } - fn finalize_logup(&mut self) { + fn finalize_logup_batched(&mut self, _batching: &Batching) { unimplemented!() } + + fn finalize_logup(&mut self) { + unimplemented!(); + } + + fn finalize_logup_in_pairs(&mut self) { + unimplemented!(); + } } /// Default implementation for evaluators that have an element called "logup" that works like a @@ -159,26 +171,59 @@ pub trait EvalAtRow { macro_rules! logup_proxy { () => { fn write_logup_frac(&mut self, fraction: Fraction) { - // Add a constraint that num / denom = diff. - if let Some(cur_frac) = self.logup.cur_frac.clone() { - let [cur_cumsum] = - self.next_extension_interaction_mask(self.logup.interaction, [0]); - let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone(); - self.logup.prev_col_cumsum = cur_cumsum; - self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); - } else { + if self.logup.fracs.is_empty() { self.logup.is_first = self.get_preprocessed_column( super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size), ); self.logup.is_finalized = false; } - self.logup.cur_frac = Some(fraction); + self.logup.fracs.push(fraction.clone()); } - fn finalize_logup(&mut self) { + /// Finalize the logup by adding the constraints for the fractions, batched by + /// the given `batching`. + /// `batching` should contain the batch into which every logup entry should be inserted. + fn finalize_logup_batched(&mut self, batching: &super::Batching) { assert!(!self.logup.is_finalized, "LogupAtRow was already finalized"); + assert_eq!( + batching.len(), + self.logup.fracs.len(), + "Batching must be of the same length as the number of entries" + ); + + let last_batch = *batching.iter().max().unwrap(); + + let mut fracs_by_batch = + std::collections::HashMap::>>::new(); + + for (batch, frac) in batching.iter().zip(self.logup.fracs.iter()) { + fracs_by_batch + .entry(*batch) + .or_insert_with(Vec::new) + .push(frac.clone()); + } + + let keys_set: std::collections::HashSet<_> = fracs_by_batch.keys().cloned().collect(); + let all_batches_set: std::collections::HashSet<_> = (0..last_batch + 1).collect(); - let frac = self.logup.cur_frac.clone().unwrap(); + assert_eq!( + keys_set, all_batches_set, + "Batching must contain all consecutive batches" + ); + + let mut prev_col_cumsum = ::zero(); + + // All batches except the last are cumulatively summed in new interaction columns. + for batch_id in (0..last_batch) { + let cur_frac: Fraction<_, _> = fracs_by_batch[&batch_id].iter().cloned().sum(); + let [cur_cumsum] = + self.next_extension_interaction_mask(self.logup.interaction, [0]); + let diff = cur_cumsum.clone() - prev_col_cumsum.clone(); + prev_col_cumsum = cur_cumsum; + self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator); + } + + let frac: Fraction<_, _> = fracs_by_batch[&last_batch].clone().into_iter().sum(); // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted // offset from the is_first column when constant columns are supported. @@ -205,12 +250,25 @@ macro_rules! logup_proxy { // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. let fixed_prev_row_cumsum = prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone(); - let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone(); + let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone(); self.add_constraint(diff * frac.denominator - frac.numerator); self.logup.is_finalized = true; } + + /// Finalizes the row's logup in the default way. Currently, this means no batching. + fn finalize_logup(&mut self) { + let batches = (0..self.logup.fracs.len()).collect(); + self.finalize_logup_batched(&batches) + } + + /// Finalizes the row's logup, batched in pairs. + /// TODO(alont) Remove this once a better batching mechanism is implemented. + fn finalize_logup_in_pairs(&mut self) { + let batches = (0..self.logup.fracs.len()).map(|n| n / 2).collect(); + self.finalize_logup_batched(&batches) + } }; } pub(crate) use logup_proxy; diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index b804d488e..8311209d1 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -6,8 +6,8 @@ use num_traits::Zero; use super::logup::LogupSums; use super::{ - EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator, - INTERACTION_TRACE_IDX, + Batching, EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, + TraceLocationAllocator, INTERACTION_TRACE_IDX, }; use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; @@ -152,38 +152,38 @@ impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> { fn write_logup_frac(&mut self, _fraction: Fraction) {} + fn finalize_logup_batched(&mut self, _batching: &Batching) {} fn finalize_logup(&mut self) {} + fn finalize_logup_in_pairs(&mut self) {} fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, R>], + entry: RelationEntry<'_, Self::F, Self::EF, R>, ) { - for entry in entries { - let relation = entry.relation.get_name().to_owned(); - let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); - let mult = entry.multiplicity.to_array(); + let relation = entry.relation.get_name().to_owned(); + let values = entry.values.iter().map(|v| v.to_array()).collect_vec(); + let mult = entry.multiplicity.to_array(); - // Unpack SIMD. - for j in 0..N_LANES { - // Skip padded values. - let cannonical_index = bit_reverse_index( - coset_index_to_circle_domain_index( - (self.vec_row << LOG_N_LANES) + j, - self.domain_log_size, - ), + // Unpack SIMD. + for j in 0..N_LANES { + // Skip padded values. + let cannonical_index = bit_reverse_index( + coset_index_to_circle_domain_index( + (self.vec_row << LOG_N_LANES) + j, self.domain_log_size, - ); - if cannonical_index >= self.n_rows { - continue; - } - let values = values.iter().map(|v| v[j]).collect_vec(); - let mult = mult[j].to_m31_array()[0]; - self.entries.push(RelationTrackerEntry { - relation: relation.clone(), - mult, - values, - }); + ), + self.domain_log_size, + ); + if cannonical_index >= self.n_rows { + continue; } + let values = values.iter().map(|v| v[j]).collect_vec(); + let mult = mult[j].to_m31_array()[0]; + self.entries.push(RelationTrackerEntry { + relation: relation.clone(), + mult, + values, + }); } } } diff --git a/crates/prover/src/examples/blake/mod.rs b/crates/prover/src/examples/blake/mod.rs index ff62f9f7d..76feb7f8b 100644 --- a/crates/prover/src/examples/blake/mod.rs +++ b/crates/prover/src/examples/blake/mod.rs @@ -88,26 +88,26 @@ impl BlakeXorElements { // TODO(alont): Generalize this to variable sizes batches if ever used. fn use_relation(&self, eval: &mut E, w: u32, values: [&[E::F]; 2]) { match w { - 12 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor12, E::EF::one(), values[0]), - RelationEntry::new(&self.xor12, E::EF::one(), values[1]), - ]), - 9 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor9, E::EF::one(), values[0]), - RelationEntry::new(&self.xor9, E::EF::one(), values[1]), - ]), - 8 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor8, E::EF::one(), values[0]), - RelationEntry::new(&self.xor8, E::EF::one(), values[1]), - ]), - 7 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor7, E::EF::one(), values[0]), - RelationEntry::new(&self.xor7, E::EF::one(), values[1]), - ]), - 4 => eval.add_to_relation(&[ - RelationEntry::new(&self.xor4, E::EF::one(), values[0]), - RelationEntry::new(&self.xor4, E::EF::one(), values[1]), - ]), + 12 => { + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[1])); + } + 9 => { + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[1])); + } + 8 => { + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[1])); + } + 7 => { + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[1])); + } + 4 => { + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[0])); + eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[1])); + } _ => panic!("Invalid w"), }; } diff --git a/crates/prover/src/examples/blake/round/constraints.rs b/crates/prover/src/examples/blake/round/constraints.rs index ada5fb287..e15a225df 100644 --- a/crates/prover/src/examples/blake/round/constraints.rs +++ b/crates/prover/src/examples/blake/round/constraints.rs @@ -65,7 +65,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { ); // Yield `Round(input_v, output_v, message)`. - self.eval.add_to_relation(&[RelationEntry::new( + self.eval.add_to_relation(RelationEntry::new( self.round_lookup_elements, -E::EF::one(), &chain![ @@ -74,9 +74,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> { m.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - self.eval.finalize_logup(); + self.eval.finalize_logup_in_pairs(); self.eval } fn next_u32(&mut self) -> Fu32 { diff --git a/crates/prover/src/examples/blake/scheduler/constraints.rs b/crates/prover/src/examples/blake/scheduler/constraints.rs index 1bf93d1aa..aceece2e8 100644 --- a/crates/prover/src/examples/blake/scheduler/constraints.rs +++ b/crates/prover/src/examples/blake/scheduler/constraints.rs @@ -30,17 +30,23 @@ pub fn eval_blake_scheduler_constraints( ] .collect_vec() }); - eval.add_to_relation(&[ - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_i), - RelationEntry::new(round_lookup_elements, E::EF::one(), &elems_j), - ]); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_i, + )); + eval.add_to_relation(RelationEntry::new( + round_lookup_elements, + E::EF::one(), + &elems_j, + )); } let input_state = &states[0]; let output_state = &states[N_ROUNDS]; // TODO(alont): Remove blake interaction. - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( blake_lookup_elements, E::EF::zero(), &chain![ @@ -49,9 +55,9 @@ pub fn eval_blake_scheduler_constraints( messages.iter().cloned().flat_map(Fu32::into_felts) ] .collect_vec(), - )]); + )); - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); } fn eval_next_u32(eval: &mut E) -> Fu32 { diff --git a/crates/prover/src/examples/blake/xor_table/constraints.rs b/crates/prover/src/examples/blake/xor_table/constraints.rs index 4df0a6c63..60fef8bfe 100644 --- a/crates/prover/src/examples/blake/xor_table/constraints.rs +++ b/crates/prover/src/examples/blake/xor_table/constraints.rs @@ -40,43 +40,31 @@ macro_rules! xor_table_eval { 2, )); - let entry_chunks = (0..(1 << (2 * EXPAND_BITS))) - .map(|i| { - let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); - let multiplicity = self.eval.next_trace_mask(); + for i in (0..(1 << (2 * EXPAND_BITS))) { + let (i, j) = ((i >> EXPAND_BITS) as u32, (i % (1 << EXPAND_BITS)) as u32); + let multiplicity = self.eval.next_trace_mask(); - let a = al.clone() - + E::F::from(BaseField::from_u32_unchecked( - i << limb_bits::(), - )); - let b = bl.clone() - + E::F::from(BaseField::from_u32_unchecked( - j << limb_bits::(), - )); - let c = cl.clone() - + E::F::from(BaseField::from_u32_unchecked( - (i ^ j) << limb_bits::(), - )); + let a = al.clone() + + E::F::from(BaseField::from_u32_unchecked( + i << limb_bits::(), + )); + let b = bl.clone() + + E::F::from(BaseField::from_u32_unchecked( + j << limb_bits::(), + )); + let c = cl.clone() + + E::F::from(BaseField::from_u32_unchecked( + (i ^ j) << limb_bits::(), + )); - (self.lookup_elements, -multiplicity, [a, b, c]) - }) - .collect_vec(); - - for entry_chunk in entry_chunks.chunks(2) { - self.eval.add_to_relation( - &entry_chunk - .iter() - .map(|(lookup, multiplicity, values)| { - RelationEntry::new( - *lookup, - E::EF::from(multiplicity.clone()), - values, - ) - }) - .collect_vec(), - ); + self.eval.add_to_relation(RelationEntry::new( + self.lookup_elements, + -E::EF::from(multiplicity), + &[a, b, c], + )); } - self.eval.finalize_logup(); + + self.eval.finalize_logup_in_pairs(); self.eval } } diff --git a/crates/prover/src/examples/plonk/mod.rs b/crates/prover/src/examples/plonk/mod.rs index 49da86f8a..a1e0362c9 100644 --- a/crates/prover/src/examples/plonk/mod.rs +++ b/crates/prover/src/examples/plonk/mod.rs @@ -66,18 +66,24 @@ impl FrameworkEval for PlonkEval { + (E::F::one() - op) * a_val.clone() * b_val.clone(), ); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[a_wire, a_val]), - RelationEntry::new(&self.lookup_elements, E::EF::one(), &[b_wire, b_val]), - ]); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[a_wire, a_val], + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &[b_wire, b_val], + )); - eval.add_to_relation(&[RelationEntry::new( + eval.add_to_relation(RelationEntry::new( &self.lookup_elements, (-mult).into(), &[c_wire, c_val], - )]); + )); - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); eval } } diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 51b671580..808dcc74d 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -186,13 +186,15 @@ pub fn eval_poseidon_constraints(eval: &mut E, lookup_elements: &P }); // Provide state lookups. - eval.add_to_relation(&[ - RelationEntry::new(lookup_elements, E::EF::one(), &initial_state), - RelationEntry::new(lookup_elements, -E::EF::one(), &state), - ]) + eval.add_to_relation(RelationEntry::new( + lookup_elements, + E::EF::one(), + &initial_state, + )); + eval.add_to_relation(RelationEntry::new(lookup_elements, -E::EF::one(), &state)); } - eval.finalize_logup(); + eval.finalize_logup_in_pairs(); } pub struct LookupData { diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 4600a3cf0..23bcf2977 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -52,12 +52,18 @@ impl FrameworkEval for StateTransitionEval let mut output_state = input_state.clone(); output_state[COORDINATE] += E::F::one(); - eval.add_to_relation(&[ - RelationEntry::new(&self.lookup_elements, E::EF::one(), &input_state), - RelationEntry::new(&self.lookup_elements, -E::EF::one(), &output_state), - ]); - - eval.finalize_logup(); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + E::EF::one(), + &input_state, + )); + eval.add_to_relation(RelationEntry::new( + &self.lookup_elements, + -E::EF::one(), + &output_state, + )); + + eval.finalize_logup_in_pairs(); eval } }