Skip to content

Commit

Permalink
Add constraint for Logup claimed cumsum
Browse files Browse the repository at this point in the history
  • Loading branch information
shaharsamocha7 committed Sep 24, 2024
1 parent 69bcf5e commit 8ba2299
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 14 deletions.
40 changes: 36 additions & 4 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,21 @@ use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::ColumnVec;

/// Represents the value of the prefix sum column at some index.
/// Should be used to eliminate padded rows for the logup sum.
pub type ClaimedPrefixSum = (SecureField, usize);

/// Evaluates constraints for batched logups.
/// These constraint enforce the sum of multiplicity_i / (z + sum_j alpha^j * x_j) = claimed_sum.
pub struct LogupAtRow<E: EvalAtRow> {
/// The index of the interaction used for the cumulative sum columns.
pub interaction: usize,
/// The total sum of all the fractions.
pub total_sum: SecureField,
/// The claimed sum of the relevant fractions.
/// This is used for padding the component with default rows. Padding should be in bit-reverse.
/// None if the claimed_sum is the total_sum.
pub claimed_sum: Option<ClaimedPrefixSum>,
/// The evaluation of the last cumulative sum column.
pub prev_col_cumsum: E::EF,
cur_frac: Option<Fraction<E::EF, E::EF>>,
Expand All @@ -37,10 +45,16 @@ pub struct LogupAtRow<E: EvalAtRow> {
pub is_first: E::F,
}
impl<E: EvalAtRow> LogupAtRow<E> {
pub fn new(interaction: usize, total_sum: SecureField, is_first: E::F) -> Self {
pub fn new(
interaction: usize,
total_sum: SecureField,
claimed_sum: Option<ClaimedPrefixSum>,
is_first: E::F,
) -> Self {
Self {
interaction,
total_sum,
claimed_sum,
prev_col_cumsum: E::EF::zero(),
cur_frac: None,
is_finalized: false,
Expand All @@ -64,9 +78,26 @@ impl<E: EvalAtRow> LogupAtRow<E> {

let frac = self.cur_frac.unwrap();

let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
// TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted offset
// from the is_first column when constant columns are supported.
let (cur_cumsum, prev_row_cumsum) = match self.claimed_sum {
Some((claimed_sum, claimed_row_index)) => {
let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = eval
.next_extension_interaction_mask(
self.interaction,
[0, -1, claimed_row_index as isize],
);

// Constrain that the claimed_sum in case that it is not equal to the total_sum.
eval.add_constraint((claimed_cumsum - claimed_sum) * self.is_first);
(cur_cumsum, prev_row_cumsum)
}
None => {
let [cur_cumsum, prev_row_cumsum] =
eval.next_extension_interaction_mask(self.interaction, [0, -1]);
(cur_cumsum, prev_row_cumsum)
}
};
// Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row.
let fixed_prev_row_cumsum = prev_row_cumsum - self.is_first * self.total_sum;
let diff = cur_cumsum - fixed_prev_row_cumsum - self.prev_col_cumsum;
Expand Down Expand Up @@ -277,7 +308,8 @@ mod tests {
#[test]
#[should_panic]
fn test_logup_not_finalized_panic() {
let mut logup = LogupAtRow::<InfoEvaluator>::new(1, SecureField::one(), BaseField::one());
let mut logup =
LogupAtRow::<InfoEvaluator>::new(1, SecureField::one(), None, BaseField::one());
logup.write_frac(
&mut InfoEvaluator::default(),
Fraction::new(SecureField::one(), SecureField::one()),
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/round/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl FrameworkEval for BlakeRoundEval {
eval,
xor_lookup_elements: &self.xor_lookup_elements,
round_lookup_elements: &self.round_lookup_elements,
logup: LogupAtRow::new(1, self.total_sum, is_first),
logup: LogupAtRow::new(1, self.total_sum, None, is_first),
};
blake_eval.eval()
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/scheduler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ impl FrameworkEval for BlakeSchedulerEval {
&mut eval,
&self.blake_lookup_elements,
&self.round_lookup_elements,
LogupAtRow::new(1, self.total_sum, is_first),
LogupAtRow::new(1, self.total_sum, None, is_first),
);
eval
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/examples/blake/xor_table/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl<const ELEM_BITS: u32, const EXPAND_BITS: u32> FrameworkEval
let xor_eval = constraints::XorTableEval::<'_, _, ELEM_BITS, EXPAND_BITS> {
eval,
lookup_elements: &self.lookup_elements,
logup: LogupAtRow::new(1, self.claimed_sum, is_first),
logup: LogupAtRow::new(1, self.claimed_sum, None, is_first),
};
xor_eval.eval()
}
Expand Down
17 changes: 12 additions & 5 deletions crates/prover/src/examples/plonk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use num_traits::One;
use tracing::{span, Level};

use crate::constraint_framework::constant_columns::gen_is_first;
use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements};
use crate::constraint_framework::logup::{
ClaimedPrefixSum, LogupAtRow, LogupTraceGenerator, LookupElements,
};
use crate::constraint_framework::{
assert_constraints, EvalAtRow, FrameworkComponent, FrameworkEval, TraceLocationAllocator,
};
Expand All @@ -29,6 +31,7 @@ pub type PlonkComponent = FrameworkComponent<PlonkEval>;
pub struct PlonkEval {
pub log_n_rows: u32,
pub lookup_elements: LookupElements<2>,
pub claimed_sum: ClaimedPrefixSum,
pub total_sum: SecureField,
pub base_trace_location: TreeSubspan,
pub interaction_trace_location: TreeSubspan,
Expand All @@ -46,7 +49,7 @@ impl FrameworkEval for PlonkEval {

fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let mut logup = LogupAtRow::<_>::new(1, self.total_sum, is_first);
let mut logup = LogupAtRow::<_>::new(1, self.total_sum, Some(self.claimed_sum), is_first);

let [a_wire] = eval.next_interaction_mask(2, [0]);
let [b_wire] = eval.next_interaction_mask(2, [0]);
Expand Down Expand Up @@ -113,11 +116,12 @@ pub fn gen_trace(

pub fn gen_interaction_trace(
log_size: u32,
padding_offset: usize,
circuit: &PlonkCircuitTrace,
lookup_elements: &LookupElements<2>,
) -> (
ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>,
SecureField,
[SecureField; 2],
) {
let _span = span!(Level::INFO, "Generate interaction trace").entered();
let mut logup_gen = LogupTraceGenerator::new(log_size);
Expand All @@ -141,7 +145,7 @@ pub fn gen_interaction_trace(
}
col_gen.finalize_col();

logup_gen.finalize_last()
logup_gen.finalize_at([(1 << log_size) - 1, padding_offset])
}

#[allow(unused)]
Expand All @@ -156,6 +160,7 @@ pub fn prove_fibonacci_plonk(
for _ in 0..(1 << log_n_rows) {
fib_values.push(fib_values[fib_values.len() - 1] + fib_values[fib_values.len() - 2]);
}
let padding_offset = 17;
let range = 0..(1 << log_n_rows);
let mut circuit = PlonkCircuitTrace {
mult: range.clone().map(|_| 2.into()).collect(),
Expand Down Expand Up @@ -197,7 +202,8 @@ pub fn prove_fibonacci_plonk(

// Interaction trace.
let span = span!(Level::INFO, "Interaction").entered();
let (trace, total_sum) = gen_interaction_trace(log_n_rows, &circuit, &lookup_elements);
let (trace, [total_sum, claimed_sum]) =
gen_interaction_trace(log_n_rows, padding_offset, &circuit, &lookup_elements);
let mut tree_builder = commitment_scheme.tree_builder();
let interaction_trace_location = tree_builder.extend_evals(trace);
tree_builder.commit(channel);
Expand Down Expand Up @@ -227,6 +233,7 @@ pub fn prove_fibonacci_plonk(
PlonkEval {
log_n_rows,
lookup_elements,
claimed_sum: (claimed_sum, padding_offset),
total_sum,
base_trace_location,
interaction_trace_location,
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/examples/poseidon/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl FrameworkEval for PoseidonEval {
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let [is_first] = eval.next_interaction_mask(2, [0]);
let logup = LogupAtRow::new(1, self.total_sum, is_first);
let logup = LogupAtRow::new(1, self.total_sum, None, is_first);
eval_poseidon_constraints(&mut eval, logup, &self.lookup_elements);
eval
}
Expand Down Expand Up @@ -482,7 +482,7 @@ mod tests {
let [is_first] = eval.next_interaction_mask(2, [0]);
eval_poseidon_constraints(
&mut eval,
LogupAtRow::new(1, total_sum, is_first),
LogupAtRow::new(1, total_sum, None, is_first),
&lookup_elements,
);
});
Expand Down

0 comments on commit 8ba2299

Please sign in to comment.