Skip to content

Commit

Permalink
Remove claimed_sum and LogupSums (#979)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 authored Jan 15, 2025
2 parents 9316f98 + 021ccf8 commit c2a584e
Show file tree
Hide file tree
Showing 21 changed files with 73 additions and 157 deletions.
10 changes: 5 additions & 5 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use num_traits::Zero;

use super::logup::{LogupAtRow, LogupSums};
use super::logup::LogupAtRow;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::{Backend, Column};
use crate::core::fields::m31::BaseField;
Expand All @@ -23,13 +23,13 @@ impl<'a> AssertEvaluator<'a> {
trace: &'a TreeVec<Vec<Vec<BaseField>>>,
row: usize,
log_size: u32,
logup_sums: LogupSums,
total_sum: SecureField,
) -> Self {
Self {
trace,
col_index: TreeVec::new(vec![0; trace.len()]),
row,
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, total_sum, log_size),
}
}
}
Expand Down Expand Up @@ -78,7 +78,7 @@ pub fn assert_constraints<B: Backend>(
trace_polys: &TreeVec<Vec<CirclePoly<B>>>,
trace_domain: CanonicCoset,
assert_func: impl Fn(AssertEvaluator<'_>),
logup_sums: LogupSums,
total_sum: SecureField,
) {
let traces = trace_polys.as_ref().map(|tree| {
tree.iter()
Expand All @@ -94,7 +94,7 @@ pub fn assert_constraints<B: Backend>(
.collect()
});
for row in 0..trace_domain.size() {
let eval = AssertEvaluator::new(&traces, row, trace_domain.log_size(), logup_sums);
let eval = AssertEvaluator::new(&traces, row, trace_domain.log_size(), total_sum);

assert_func(eval);
}
Expand Down
17 changes: 8 additions & 9 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use rayon::prelude::*;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::LogupSums;
use super::preprocessed_columns::PreProcessedColumnId;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
Expand Down Expand Up @@ -124,16 +123,16 @@ pub struct FrameworkComponent<C: FrameworkEval> {
trace_locations: TreeVec<TreeSubspan>,
info: InfoEvaluator,
preprocessed_column_indices: Vec<usize>,
logup_sums: LogupSums,
total_sum: SecureField,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(
location_allocator: &mut TraceLocationAllocator,
eval: E,
logup_sums: LogupSums,
total_sum: SecureField,
) -> Self {
let info = eval.evaluate(InfoEvaluator::new(eval.log_size(), vec![], logup_sums));
let info = eval.evaluate(InfoEvaluator::new(eval.log_size(), vec![], total_sum));
let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);

let preprocessed_column_indices = info
Expand Down Expand Up @@ -167,7 +166,7 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
trace_locations,
info,
preprocessed_column_indices,
logup_sums,
total_sum,
}
}

Expand Down Expand Up @@ -238,7 +237,7 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
self.eval.log_size(),
self.logup_sums,
self.total_sum,
));
}
}
Expand Down Expand Up @@ -319,7 +318,7 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
trace_domain.log_size(),
eval_domain.log_size(),
self.eval.log_size(),
self.logup_sums,
self.total_sum,
);
let row_res = self.eval.evaluate(eval).row_res;

Expand Down Expand Up @@ -348,7 +347,7 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
// Define any `self` values outside the loop to prevent the compiler thinking there is a
// `Sync` requirement on `Self`.
let self_eval = &self.eval;
let self_logup_sums = self.logup_sums;
let self_total_sum = self.total_sum;

iter.for_each(|(chunk_idx, mut chunk)| {
let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref());
Expand All @@ -363,7 +362,7 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
trace_domain.log_size(),
eval_domain.log_size(),
self_eval.log_size(),
self_logup_sums,
self_total_sum,
);
let row_res = self_eval.evaluate(eval).row_res;

Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/constraint_framework/cpu_domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ops::Mul;

use num_traits::Zero;

use super::logup::{LogupAtRow, LogupSums};
use super::logup::LogupAtRow;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -36,7 +36,7 @@ impl<'a> CpuDomainEvaluator<'a> {
domain_log_size: u32,
eval_log_size: u32,
log_size: u32,
logup_sums: LogupSums,
total_sum: SecureField,
) -> Self {
Self {
trace_eval,
Expand All @@ -47,7 +47,7 @@ impl<'a> CpuDomainEvaluator<'a> {
constraint_index: 0,
domain_log_size,
eval_domain_log_size: eval_log_size,
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, total_sum, log_size),
}
}
}
Expand Down
18 changes: 5 additions & 13 deletions crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,28 @@ use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ColumnExpr;
use crate::constraint_framework::preprocessed_columns::PreProcessedColumnId;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::{self, M31};
use crate::core::fields::m31::M31;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

pub struct FormalLogupAtRow {
pub interaction: usize,
pub total_sum: ExtExpr,
pub claimed_sum: Option<(ExtExpr, usize)>,
pub fracs: Vec<Fraction<ExtExpr, ExtExpr>>,
pub is_finalized: bool,
pub is_first: BaseExpr,
pub cumsum_shift: ExtExpr,
pub log_size: u32,
}

// P is an offset no column can reach, it signifies the variable
// offset, which is an input to the verifier.
pub const CLAIMED_SUM_DUMMY_OFFSET: usize = m31::P as usize;

impl FormalLogupAtRow {
pub fn new(interaction: usize, has_partial_sum: bool, log_size: u32) -> Self {
pub fn new(interaction: usize, log_size: u32) -> Self {
let total_sum_name = "total_sum".to_string();
let claimed_sum_name = "claimed_sum".to_string();

Self {
interaction,
// TODO(alont): Should these be Expr::SecureField?
total_sum: ExtExpr::Param(total_sum_name.clone()),
claimed_sum: has_partial_sum
.then_some((ExtExpr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)),
fracs: vec![],
is_finalized: true,
is_first: BaseExpr::zero(),
Expand Down Expand Up @@ -75,11 +67,11 @@ pub struct ExprEvaluator {
}

impl ExprEvaluator {
pub fn new(log_size: u32, has_partial_sum: bool) -> Self {
pub fn new(log_size: u32) -> Self {
Self {
cur_var_index: Default::default(),
constraints: Default::default(),
logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size),
logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, log_size),
intermediates: vec![],
ext_intermediates: vec![],
}
Expand Down Expand Up @@ -200,7 +192,7 @@ mod tests {
#[test]
fn test_expr_evaluator() {
let test_struct = TestStruct {};
let eval = test_struct.evaluate(ExprEvaluator::new(16, false));
let eval = test_struct.evaluate(ExprEvaluator::new(16));
let expected = "let intermediate0 = (trace_1_column_1_offset_0) * (trace_1_column_2_offset_0);
\
Expand Down
6 changes: 2 additions & 4 deletions crates/prover/src/constraint_framework/expr/format.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use num_traits::Zero;

use super::{BaseExpr, ColumnExpr, ExtExpr, CLAIMED_SUM_DUMMY_OFFSET};
use super::{BaseExpr, ColumnExpr, ExtExpr};

impl BaseExpr {
pub fn format_expr(&self) -> String {
Expand All @@ -10,9 +10,7 @@ impl BaseExpr {
idx,
offset,
}) => {
let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET as isize {
"claimed_sum".to_string()
} else {
let offset_str = {
let offset_abs = offset.abs();
if *offset >= 0 {
offset.to_string()
Expand Down
1 change: 0 additions & 1 deletion crates/prover/src/constraint_framework/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
pub use evaluator::ExprEvaluator;
use num_traits::{One, Zero};

use crate::constraint_framework::expr::evaluator::CLAIMED_SUM_DUMMY_OFFSET;
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::{SecureField, QM31};
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::rc::Rc;

use num_traits::{One, Zero};

use super::logup::{LogupAtRow, LogupSums};
use super::logup::LogupAtRow;
use super::preprocessed_columns::PreProcessedColumnId;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
Expand All @@ -30,21 +30,21 @@ impl InfoEvaluator {
pub fn new(
log_size: u32,
preprocessed_columns: Vec<PreProcessedColumnId>,
logup_sums: LogupSums,
total_sum: SecureField,
) -> Self {
Self {
mask_offsets: Default::default(),
n_constraints: Default::default(),
preprocessed_columns,
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size),
logup: LogupAtRow::new(INTERACTION_TRACE_IDX, total_sum, log_size),
arithmetic_counts: Default::default(),
}
}

/// Create an empty `InfoEvaluator`, to measure components before their size and logup sums are
/// available.
pub fn empty() -> Self {
Self::new(16, vec![], (SecureField::default(), None))
Self::new(16, vec![], SecureField::default())
}
}
impl EvalAtRow for InfoEvaluator {
Expand Down
71 changes: 2 additions & 69 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,15 @@ use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::lookups::utils::Fraction;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::ColumnVec;

/// Represents the value of the prefix sum column at some index.
/// Should be used to eliminate padded rows for the logup sum.
pub type ClaimedPrefixSum = (SecureField, usize);
// (total_sum, claimed_sum)
pub type LogupSums = (SecureField, Option<ClaimedPrefixSum>);

pub trait LogupSumsExt {
fn value(&self) -> SecureField;
}

impl LogupSumsExt for LogupSums {
fn value(&self) -> SecureField {
self.1.map(|(claimed_sum, _)| claimed_sum).unwrap_or(self.0)
}
}

/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = total_sum.
pub struct LogupAtRow<E: EvalAtRow> {
/// The index of the interaction used for the cumulative sum columns.
pub interaction: usize,
/// The total sum of all the fractions divided by n_rows.
pub cumsum_shift: SecureField,
/// The claimed sum of the relevant fractions.
/// This is used for padding the component with default rows. Padding should be in bit-reverse.
/// None if the claimed_sum is the total_sum.
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub fracs: Vec<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
Expand All @@ -60,18 +39,10 @@ impl<E: EvalAtRow> Default for LogupAtRow<E> {
}
}
impl<E: EvalAtRow> LogupAtRow<E> {
pub fn new(
interaction: usize,
total_sum: SecureField,
claimed_sum: Option<ClaimedPrefixSum>,
log_size: u32,
) -> Self {
// TODO(ShaharS): remove once claimed sum at internal index is supported.
assert!(claimed_sum.is_none(), "Partial prefix-sum is not supported");
pub fn new(interaction: usize, total_sum: SecureField, log_size: u32) -> Self {
Self {
interaction,
cumsum_shift: total_sum / BaseField::from_u32_unchecked(1 << log_size),
claimed_sum,
fracs: vec![],
is_finalized: true,
log_size,
Expand All @@ -83,7 +54,6 @@ impl<E: EvalAtRow> LogupAtRow<E> {
Self {
interaction: 100,
cumsum_shift: SecureField::one(),
claimed_sum: None,
fracs: vec![],
is_finalized: true,
log_size: 10,
Expand Down Expand Up @@ -222,43 +192,6 @@ impl LogupTraceGenerator {
.collect_vec();
(trace, total_sum)
}

/// Finalize the trace. Returns the trace and the prefix sum of the last column at
/// the corresponding `indices`.
pub fn finalize_at<const N: usize>(
mut self,
indices: [usize; N],
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
[SecureField; N],
) {
// Prefix sum the last column.
let last_col_coords = self.trace.pop().unwrap().columns;
let coord_prefix_sum = last_col_coords.map(inclusive_prefix_sum);
let secure_prefix_sum = SecureColumnByCoords {
columns: coord_prefix_sum,
};
let returned_prefix_sums = indices.map(|idx| {
// Prefix sum column is in bit-reversed circle domain order.
let fixed_index = bit_reverse_index(
coset_index_to_circle_domain_index(idx, self.log_size),
self.log_size,
);
secure_prefix_sum.at(fixed_index)
});
self.trace.push(secure_prefix_sum);

let trace = self
.trace
.into_iter()
.flat_map(|eval| {
eval.columns.map(|col| {
CircleEvaluation::new(CanonicCoset::new(self.log_size).circle_domain(), col)
})
})
.collect_vec();
(trace, returned_prefix_sums)
}
}

/// Trace generator for a single lookup column.
Expand Down
5 changes: 0 additions & 5 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,6 @@ macro_rules! logup_proxy {
/// `batching` should contain the batch into which every logup entry should be inserted.
fn finalize_logup_batched(&mut self, batching: &crate::constraint_framework::Batching) {
assert!(!self.logup.is_finalized, "LogupAtRow was already finalized");

assert!(
self.logup.claimed_sum.is_none(),
"Partial prefix-sum is not supported"
);
assert_eq!(
batching.len(),
self.logup.fracs.len(),
Expand Down
Loading

0 comments on commit c2a584e

Please sign in to comment.