diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index ec34f4c11..3a1869b37 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -1,6 +1,6 @@ use num_traits::{One, Zero}; -use crate::constraint_framework::logup::{LogupAtRow, LookupElements}; +use crate::constraint_framework::logup::{ClaimedPrefixSum, LogupAtRow, LookupElements}; use crate::constraint_framework::{EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator}; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::SimdBackend; @@ -28,6 +28,7 @@ pub struct StateTransitionEval { pub log_n_rows: u32, pub lookup_elements: StateMachineElements, pub total_sum: QM31, + pub claimed_sum: ClaimedPrefixSum, } impl FrameworkEval for StateTransitionEval { @@ -39,7 +40,8 @@ impl FrameworkEval for StateTransitionEval } fn evaluate(&self, mut eval: E) -> E { let [is_first] = eval.next_interaction_mask(2, [0]); - let mut logup: LogupAtRow = LogupAtRow::new(1, self.total_sum, None, is_first); + let mut logup: LogupAtRow = + LogupAtRow::new(1, self.total_sum, Some(self.claimed_sum), is_first); let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask()); let input_denom: E::EF = self.lookup_elements.combine(&input_state); @@ -98,6 +100,7 @@ fn state_transition_info() -> InfoEvaluator { log_n_rows: 1, lookup_elements: StateMachineElements::dummy(), total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), }; component.evaluate(InfoEvaluator::default()) } diff --git a/crates/prover/src/examples/state_machine/gen.rs b/crates/prover/src/examples/state_machine/gen.rs index 7c868aadc..4ec6fdbb2 100644 --- a/crates/prover/src/examples/state_machine/gen.rs +++ b/crates/prover/src/examples/state_machine/gen.rs @@ -11,6 +11,7 @@ use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; +use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use crate::core::ColumnVec; // Given `initial state`, generate a trace that row `i` is the initial state plus `i` in the @@ -30,7 +31,9 @@ pub fn gen_trace( // Add the states in bit reversed circle domain order. for i in 0..1 << log_size { for j in 0..STATE_SIZE { - trace[j][i] = curr_state[j]; + let bit_rev_index = + bit_reverse_index(coset_index_to_circle_domain_index(i, log_size), log_size); + trace[j][bit_rev_index] = curr_state[j]; } // Increment the state to the next state row. curr_state[inc_index] += M31::one(); @@ -48,14 +51,17 @@ pub fn gen_trace( } pub fn gen_interaction_trace( - log_size: u32, + n_rows: usize, trace: &ColumnVec>, inc_index: usize, lookup_elements: &LookupElements, ) -> ( ColumnVec>, - QM31, + [QM31; 2], ) { + let log_size = trace[0].domain.log_size(); + assert!(n_rows <= 1 << log_size, "n_rows exceeds the trace size"); + let ones = PackedM31::broadcast(M31::one()); let mut logup_gen = LogupTraceGenerator::new(log_size); let mut col_gen = logup_gen.new_col(); @@ -78,7 +84,7 @@ pub fn gen_interaction_trace( } col_gen.finalize_col(); - logup_gen.finalize_last() + logup_gen.finalize_at([(1 << log_size) - 1, n_rows]) } #[cfg(test)] @@ -88,6 +94,7 @@ mod tests { use crate::core::fields::qm31::QM31; use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use crate::core::fields::FieldExpOps; + use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index}; use crate::examples::state_machine::components::StateMachineElements; use crate::examples::state_machine::gen::{gen_interaction_trace, gen_trace}; @@ -97,13 +104,15 @@ mod tests { let initial_state = [M31::from_u32_unchecked(17), M31::from_u32_unchecked(16)]; let inc_index = 1; let row = 123; + let bit_rev_row = + bit_reverse_index(coset_index_to_circle_domain_index(row, log_size), log_size); let trace = gen_trace(log_size, initial_state, inc_index); assert_eq!(trace.len(), 2); assert_eq!(trace[0].at(row), initial_state[0]); assert_eq!( - trace[1].at(row), + trace[1].at(bit_rev_row), initial_state[1] + M31::from_u32_unchecked(row as u32) ); } @@ -122,10 +131,11 @@ mod tests { let first_state_comb: QM31 = lookup_elements.combine(&first_state); let last_state_comb: QM31 = lookup_elements.combine(&last_state); - let (interaction_trace, total_sum) = - gen_interaction_trace(log_size, &trace, inc_index, &lookup_elements); + let (interaction_trace, [total_sum, claimed_sum]) = + gen_interaction_trace((1 << log_size) - 1, &trace, inc_index, &lookup_elements); assert_eq!(interaction_trace.len(), SECURE_EXTENSION_DEGREE); // One extension column. + assert_eq!(claimed_sum, total_sum); assert_eq!( total_sum, first_state_comb.inverse() - last_state_comb.inverse() diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 268d37c30..1855a1980 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -31,14 +31,16 @@ pub fn prove_state_machine( StateMachineComponents, StateMachineProof, ) { - assert!(log_n_rows >= LOG_N_LANES); - let x_axis_log_rows = log_n_rows; - let y_axis_log_rows = log_n_rows - 1; + let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1); + let (x_row, y_row) = (34, 56); + assert!(y_axis_log_rows >= LOG_N_LANES && x_axis_log_rows >= LOG_N_LANES); + assert!(x_row < 1 << x_axis_log_rows); + assert!(y_row < 1 << y_axis_log_rows); let mut intermediate_state = initial_state; - intermediate_state[0] += M31::from_u32_unchecked(1 << x_axis_log_rows); + intermediate_state[0] += M31::from_u32_unchecked(x_row); let mut final_state = intermediate_state; - final_state[1] += M31::from_u32_unchecked(1 << y_axis_log_rows); + final_state[1] += M31::from_u32_unchecked(y_row); // Precompute twiddles. let twiddles = SimdBackend::precompute_twiddles( @@ -69,14 +71,14 @@ pub fn prove_state_machine( let lookup_elements = StateMachineElements::draw(channel); // Interaction trace. - let (interaction_trace_op0, total_sum_op0) = - gen_interaction_trace(x_axis_log_rows, &trace_op0, 0, &lookup_elements); - let (interaction_trace_op1, total_sum_op1) = - gen_interaction_trace(y_axis_log_rows, &trace_op1, 1, &lookup_elements); + let (interaction_trace_op0, [total_sum_op0, claimed_sum_op0]) = + gen_interaction_trace(x_row as usize - 1, &trace_op0, 0, &lookup_elements); + let (interaction_trace_op1, [total_sum_op1, claimed_sum_op1]) = + gen_interaction_trace(y_row as usize - 1, &trace_op1, 1, &lookup_elements); let stmt1 = StateMachineStatement1 { - x_axis_claimed_sum: total_sum_op0, - y_axis_claimed_sum: total_sum_op1, + x_axis_claimed_sum: claimed_sum_op0, + y_axis_claimed_sum: claimed_sum_op1, }; stmt1.mix_into(channel); @@ -100,6 +102,7 @@ pub fn prove_state_machine( log_n_rows: x_axis_log_rows, lookup_elements: lookup_elements.clone(), total_sum: total_sum_op0, + claimed_sum: (claimed_sum_op0, x_row as usize - 1), }, ); let component1 = StateMachineOp1Component::new( @@ -108,6 +111,7 @@ pub fn prove_state_machine( log_n_rows: y_axis_log_rows, lookup_elements, total_sum: total_sum_op1, + claimed_sum: (claimed_sum_op1, y_row as usize - 1), }, ); let components = StateMachineComponents { @@ -190,15 +194,17 @@ mod tests { let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default()); // Interaction trace. - let (interaction_trace, total_sum) = - gen_interaction_trace(log_n_rows, &trace, 0, &lookup_elements); + let (interaction_trace, [total_sum, claimed_sum]) = + gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements); + assert_eq!(total_sum, claimed_sum); let component = StateMachineOp0Component::new( &mut TraceLocationAllocator::default(), StateTransitionEval { log_n_rows, lookup_elements, total_sum, + claimed_sum: (total_sum, (1 << log_n_rows) - 1), }, ); @@ -214,16 +220,13 @@ mod tests { } #[test] - fn test_state_machine_total_sum() { + fn test_state_machine_claimed_sum() { let log_n_rows = 8; let config = PcsConfig::default(); // Initial and last state. let initial_state = [M31::zero(); STATE_SIZE]; - let last_state = [ - M31::from_u32_unchecked(1 << log_n_rows), - M31::from_u32_unchecked(1 << (log_n_rows - 1)), - ]; + let last_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)]; // Setup protocol. let channel = &mut Blake2sChannel::default(); @@ -234,7 +237,7 @@ mod tests { let last_state_comb: QM31 = interaction_elements.combine(&last_state); assert_eq!( - component.component0.total_sum + component.component1.total_sum, + component.component0.claimed_sum.0 + component.component1.claimed_sum.0, initial_state_comb.inverse() - last_state_comb.inverse() ); }