Skip to content

Commit

Permalink
Decoupled batching from add_to_relation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Dec 9, 2024
1 parent 76af3c6 commit 7dfb944
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 157 deletions.
31 changes: 9 additions & 22 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -755,8 +755,7 @@ pub struct FormalLogupAtRow {
pub interaction: usize,
pub total_sum: ExtExpr,
pub claimed_sum: Option<(ExtExpr, usize)>,
pub prev_col_cumsum: ExtExpr,
pub cur_frac: Option<Fraction<ExtExpr, ExtExpr>>,
pub fracs: Vec<Fraction<ExtExpr, ExtExpr>>,
pub is_finalized: bool,
pub is_first: BaseExpr,
pub log_size: u32,
Expand All @@ -777,8 +776,7 @@ impl FormalLogupAtRow {
total_sum: ExtExpr::Param(total_sum_name),
claimed_sum: has_partial_sum
.then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)),
prev_col_cumsum: ExtExpr::zero(),
cur_frac: None,
fracs: vec![],
is_finalized: true,
is_first: BaseExpr::zero(),
log_size,
Expand Down Expand Up @@ -873,23 +871,12 @@ impl EvalAtRow for ExprEvaluator {

fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
entry: RelationEntry<'_, Self::F, Self::EF, R>,
) {
let fracs: Vec<Fraction<Self::EF, Self::EF>> = entries
.iter()
.map(
|RelationEntry {
relation,
multiplicity,
values,
}| {
let intermediate =
self.add_extension_intermediate(combine_formal(*relation, values));
Fraction::new(multiplicity.clone(), intermediate)
},
)
.collect();
self.write_logup_frac(fracs.into_iter().sum());
let intermediate =
self.add_extension_intermediate(combine_formal(entry.relation, entry.values));
let frac = Fraction::new(entry.multiplicity.clone(), intermediate);
self.write_logup_frac(frac);
}

fn add_intermediate(&mut self, expr: Self::F) -> Self::F {
Expand Down Expand Up @@ -1115,11 +1102,11 @@ mod tests {
let x2 = eval.next_trace_mask();
let intermediate = eval.add_intermediate(x1.clone() * x2.clone());
eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse());
eval.add_to_relation(&[RelationEntry::new(
eval.add_to_relation(RelationEntry::new(
&TestRelation::dummy(),
E::EF::one(),
&[x0, x1, x2],
)]);
));
eval.finalize_logup();
eval
}
Expand Down
9 changes: 3 additions & 6 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ pub struct LogupAtRow<E: EvalAtRow> {
/// None if the claimed_sum is the total_sum.
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
pub cur_frac: Option<Fraction<E::EF, E::EF>>,
pub fracs: Vec<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
Expand All @@ -74,8 +73,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
interaction,
total_sum,
claimed_sum,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size,
Expand All @@ -88,8 +86,7 @@ impl<E: EvalAtRow> LogupAtRow<E> {
interaction: 100,
total_sum: SecureField::one(),
claimed_sum: None,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
fracs: vec![],
is_finalized: true,
is_first: E::F::zero(),
log_size: 10,
Expand Down
95 changes: 74 additions & 21 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ pub const PREPROCESSED_TRACE_IDX: usize = 0;
pub const ORIGINAL_TRACE_IDX: usize = 1;
pub const INTERACTION_TRACE_IDX: usize = 2;

type Batching = Vec<usize>;

/// A trait for evaluating expressions at some point or row.
pub trait EvalAtRow {
// TODO(Ohad): Use a better trait for these, like 'Algebra' or something.
Expand Down Expand Up @@ -132,25 +134,30 @@ pub trait EvalAtRow {
/// multiplied.
fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
entry: RelationEntry<'_, Self::F, Self::EF, R>,
) {
let fracs = entries.iter().map(
|RelationEntry {
relation,
multiplicity,
values,
}| { Fraction::new(multiplicity.clone(), relation.combine(values)) },
let frac = Fraction::new(
entry.multiplicity.clone(),
entry.relation.combine(entry.values),
);
self.write_logup_frac(fracs.sum());
self.write_logup_frac(frac);
}

// TODO(alont): Remove these once LogupAtRow is no longer used.
fn write_logup_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {
unimplemented!()
}
fn finalize_logup(&mut self) {
fn finalize_logup_batched(&mut self, _batching: &Batching) {
unimplemented!()
}

fn finalize_logup(&mut self) {
unimplemented!();
}

fn finalize_logup_in_pairs(&mut self) {
unimplemented!();
}
}

/// Default implementation for evaluators that have an element called "logup" that works like a
Expand All @@ -159,26 +166,59 @@ pub trait EvalAtRow {
macro_rules! logup_proxy {
() => {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
// Add a constraint that num / denom = diff.
if let Some(cur_frac) = self.logup.cur_frac.clone() {
let [cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [0]);
let diff = cur_cumsum.clone() - self.logup.prev_col_cumsum.clone();
self.logup.prev_col_cumsum = cur_cumsum;
self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
} else {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
super::preprocessed_columns::PreprocessedColumn::IsFirst(self.logup.log_size),
);
self.logup.is_finalized = false;
}
self.logup.cur_frac = Some(fraction);
self.logup.fracs.push(fraction.clone());
}

fn finalize_logup(&mut self) {
/// Finalize the logup by adding the constraints for the fractions, batched by
/// the given `batching`.
/// `batching` should contain the batch into which every logup entry should be inserted.
fn finalize_logup_batched(&mut self, batching: &super::Batching) {
assert!(!self.logup.is_finalized, "LogupAtRow was already finalized");
assert_eq!(
batching.len(),
self.logup.fracs.len(),
"Batching must be of the same length as the number of entries"
);

let last_batch = *batching.iter().max().unwrap();

let mut fracs_by_batch =
std::collections::HashMap::<usize, Vec<Fraction<Self::EF, Self::EF>>>::new();

for (batch, frac) in batching.iter().zip(self.logup.fracs.iter()) {
fracs_by_batch
.entry(*batch)
.or_insert_with(Vec::new)
.push(frac.clone());
}

let keys_set: std::collections::HashSet<_> = fracs_by_batch.keys().cloned().collect();
let all_batches_set: std::collections::HashSet<_> = (0..last_batch + 1).collect();

let frac = self.logup.cur_frac.clone().unwrap();
assert_eq!(
keys_set, all_batches_set,
"Batching must contain all consecutive batches"
);

let mut prev_col_cumsum = <Self::EF as num_traits::Zero>::zero();

// All batches except the last are cumulatively summed in new interaction columns.
for batch_id in (0..last_batch) {
let cur_frac: Fraction<_, _> = fracs_by_batch[&batch_id].iter().cloned().sum();
let [cur_cumsum] =
self.next_extension_interaction_mask(self.logup.interaction, [0]);
let diff = cur_cumsum.clone() - prev_col_cumsum.clone();
prev_col_cumsum = cur_cumsum;
self.add_constraint(diff * cur_frac.denominator - cur_frac.numerator);
}

let frac: Fraction<_, _> = fracs_by_batch[&last_batch].clone().into_iter().sum();

// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted
// offset from the is_first column when constant columns are supported.
Expand All @@ -205,12 +245,25 @@ macro_rules! logup_proxy {
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum =
prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone();
let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone();
let diff = cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum.clone();

self.add_constraint(diff * frac.denominator - frac.numerator);

self.logup.is_finalized = true;
}

/// Finalizes the row's logup in the default way. Currently, this means no batching.
fn finalize_logup(&mut self) {
let batches = (0..self.logup.fracs.len()).collect();
self.finalize_logup_batched(&batches)
}

/// Finalizes the row's logup, batched in pairs.
/// TODO(alont) Remove this once a better batching mechanism is implemented.
fn finalize_logup_in_pairs(&mut self) {
let batches = (0..self.logup.fracs.len()).map(|n| n / 2).collect();
self.finalize_logup_batched(&batches)
}
};
}
pub(crate) use logup_proxy;
Expand Down
52 changes: 26 additions & 26 deletions crates/prover/src/constraint_framework/relation_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use num_traits::Zero;

use super::logup::LogupSums;
use super::{
EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator,
INTERACTION_TRACE_IDX,
Batching, EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry,
TraceLocationAllocator, INTERACTION_TRACE_IDX,
};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
Expand Down Expand Up @@ -152,38 +152,38 @@ impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> {

fn write_logup_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {}

fn finalize_logup_batched(&mut self, _batching: &Batching) {}
fn finalize_logup(&mut self) {}
fn finalize_logup_in_pairs(&mut self) {}

fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
entry: RelationEntry<'_, Self::F, Self::EF, R>,
) {
for entry in entries {
let relation = entry.relation.get_name().to_owned();
let values = entry.values.iter().map(|v| v.to_array()).collect_vec();
let mult = entry.multiplicity.to_array();
let relation = entry.relation.get_name().to_owned();
let values = entry.values.iter().map(|v| v.to_array()).collect_vec();
let mult = entry.multiplicity.to_array();

// Unpack SIMD.
for j in 0..N_LANES {
// Skip padded values.
let cannonical_index = bit_reverse_index(
coset_index_to_circle_domain_index(
(self.vec_row << LOG_N_LANES) + j,
self.domain_log_size,
),
// Unpack SIMD.
for j in 0..N_LANES {
// Skip padded values.
let cannonical_index = bit_reverse_index(
coset_index_to_circle_domain_index(
(self.vec_row << LOG_N_LANES) + j,
self.domain_log_size,
);
if cannonical_index >= self.n_rows {
continue;
}
let values = values.iter().map(|v| v[j]).collect_vec();
let mult = mult[j].to_m31_array()[0];
self.entries.push(RelationTrackerEntry {
relation: relation.clone(),
mult,
values,
});
),
self.domain_log_size,
);
if cannonical_index >= self.n_rows {
continue;
}
let values = values.iter().map(|v| v[j]).collect_vec();
let mult = mult[j].to_m31_array()[0];
self.entries.push(RelationTrackerEntry {
relation: relation.clone(),
mult,
values,
});
}
}
}
Expand Down
40 changes: 20 additions & 20 deletions crates/prover/src/examples/blake/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,26 +88,26 @@ impl BlakeXorElements {
// TODO(alont): Generalize this to variable sizes batches if ever used.
fn use_relation<E: EvalAtRow>(&self, eval: &mut E, w: u32, values: [&[E::F]; 2]) {
match w {
12 => eval.add_to_relation(&[
RelationEntry::new(&self.xor12, E::EF::one(), values[0]),
RelationEntry::new(&self.xor12, E::EF::one(), values[1]),
]),
9 => eval.add_to_relation(&[
RelationEntry::new(&self.xor9, E::EF::one(), values[0]),
RelationEntry::new(&self.xor9, E::EF::one(), values[1]),
]),
8 => eval.add_to_relation(&[
RelationEntry::new(&self.xor8, E::EF::one(), values[0]),
RelationEntry::new(&self.xor8, E::EF::one(), values[1]),
]),
7 => eval.add_to_relation(&[
RelationEntry::new(&self.xor7, E::EF::one(), values[0]),
RelationEntry::new(&self.xor7, E::EF::one(), values[1]),
]),
4 => eval.add_to_relation(&[
RelationEntry::new(&self.xor4, E::EF::one(), values[0]),
RelationEntry::new(&self.xor4, E::EF::one(), values[1]),
]),
12 => {
eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor12, E::EF::one(), values[1]));
}
9 => {
eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor9, E::EF::one(), values[1]));
}
8 => {
eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor8, E::EF::one(), values[1]));
}
7 => {
eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor7, E::EF::one(), values[1]));
}
4 => {
eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[0]));
eval.add_to_relation(RelationEntry::new(&self.xor4, E::EF::one(), values[1]));
}
_ => panic!("Invalid w"),
};
}
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/examples/blake/round/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
);

// Yield `Round(input_v, output_v, message)`.
self.eval.add_to_relation(&[RelationEntry::new(
self.eval.add_to_relation(RelationEntry::new(
self.round_lookup_elements,
-E::EF::one(),
&chain![
Expand All @@ -74,9 +74,9 @@ impl<'a, E: EvalAtRow> BlakeRoundEval<'a, E> {
m.iter().cloned().flat_map(Fu32::into_felts)
]
.collect_vec(),
)]);
));

self.eval.finalize_logup();
self.eval.finalize_logup_in_pairs();
self.eval
}
fn next_u32(&mut self) -> Fu32<E::F> {
Expand Down
Loading

0 comments on commit 7dfb944

Please sign in to comment.