diff --git a/crates/prover/src/constraint_framework/relation_tracker.rs b/crates/prover/src/constraint_framework/relation_tracker.rs index b5220130a..3866df39a 100644 --- a/crates/prover/src/constraint_framework/relation_tracker.rs +++ b/crates/prover/src/constraint_framework/relation_tracker.rs @@ -75,7 +75,6 @@ impl RelationTrackerComponent { } /// Aggregates relation entries. -// TODO(Ohad): test. pub struct RelationTrackerEvaluator<'a> { entries: Vec, pub trace_eval: diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index 2451eef23..4600a3cf0 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -1,16 +1,21 @@ use num_traits::{One, Zero}; use crate::constraint_framework::logup::ClaimedPrefixSum; +use crate::constraint_framework::relation_tracker::{ + RelationTrackerComponent, RelationTrackerEntry, +}; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, RelationEntry, - PREPROCESSED_TRACE_IDX, + TraceLocationAllocator, PREPROCESSED_TRACE_IDX, }; use crate::core::air::{Component, ComponentProver}; use crate::core::backend::simd::SimdBackend; use crate::core::channel::Channel; -use crate::core::fields::m31::M31; +use crate::core::fields::m31::{BaseField, M31}; use crate::core::fields::qm31::{SecureField, QM31}; use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; use crate::core::prover::StarkProof; use crate::core::vcs::ops::MerkleHasher; @@ -124,6 +129,45 @@ impl StateMachineComponents { } } +pub fn track_state_machine_relations( + trace: &TreeVec<&Vec>>, + x_axis_log_n_rows: u32, + y_axis_log_n_rows: u32, + n_rows_x: u32, + n_rows_y: u32, +) -> Vec { + let tree_span_provider = &mut TraceLocationAllocator::default(); + let mut entries = vec![]; + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<0> { + log_n_rows: x_axis_log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + n_rows_x as usize, + ) + .entries(&trace.into()), + ); + entries.extend( + RelationTrackerComponent::new( + tree_span_provider, + StateTransitionEval::<1> { + log_n_rows: y_axis_log_n_rows, + lookup_elements: StateMachineElements::dummy(), + total_sum: QM31::zero(), + claimed_sum: (QM31::zero(), 0), + }, + n_rows_y as usize, + ) + .entries(&trace.into()), + ); + + entries +} + pub struct StateMachineProof { pub public_input: [State; 2], // Initial and final state. pub stmt0: StateMachineStatement0, diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 8dbe3a068..23973a960 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -1,11 +1,12 @@ +use crate::constraint_framework::relation_tracker::RelationSummary; use crate::constraint_framework::Relation; pub mod components; pub mod gen; use components::{ - State, StateMachineComponents, StateMachineElements, StateMachineOp0Component, - StateMachineOp1Component, StateMachineProof, StateMachineStatement0, StateMachineStatement1, - StateTransitionEval, + track_state_machine_relations, State, StateMachineComponents, StateMachineElements, + StateMachineOp0Component, StateMachineOp1Component, StateMachineProof, StateMachineStatement0, + StateMachineStatement1, StateTransitionEval, }; use gen::{gen_interaction_trace, gen_trace}; use itertools::{chain, Itertools}; @@ -19,7 +20,7 @@ use crate::core::backend::simd::SimdBackend; use crate::core::channel::Blake2sChannel; use crate::core::fields::m31::M31; use crate::core::fields::qm31::QM31; -use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig}; +use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec}; use crate::core::poly::circle::{CanonicCoset, PolyOps}; use crate::core::prover::{prove, verify, VerificationError}; use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher}; @@ -30,9 +31,11 @@ pub fn prove_state_machine( initial_state: State, config: PcsConfig, channel: &mut Blake2sChannel, + track_relations: bool, ) -> ( StateMachineComponents, StateMachineProof, + Option, ) { let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1); let (x_row, y_row) = (34, 56); @@ -62,14 +65,32 @@ pub fn prove_state_machine( ]; // Preprocessed trace. - let mut tree_builder = commitment_scheme.tree_builder(); - tree_builder.extend_evals(gen_preprocessed_columns(preprocessed_columns.iter())); - tree_builder.commit(channel); + let preprocessed_trace = gen_preprocessed_columns(preprocessed_columns.iter()); // Trace. let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0); let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1); + let trace = chain![trace_op0.clone(), trace_op1.clone()].collect_vec(); + + let relation_summary = match track_relations { + false => None, + true => Some(RelationSummary::summarize_relations( + &track_state_machine_relations( + &TreeVec(vec![&preprocessed_trace, &trace]), + x_axis_log_rows, + y_axis_log_rows, + x_row, + y_row, + ), + )), + }; + + // Commitments. + let mut tree_builder = commitment_scheme.tree_builder(); + tree_builder.extend_evals(preprocessed_trace); + tree_builder.commit(channel); + let stmt0 = StateMachineStatement0 { n: x_axis_log_rows, m: y_axis_log_rows, @@ -135,7 +156,7 @@ pub fn prove_state_machine( stmt1, stark_proof, }; - (components, proof) + (components, proof, relation_summary) } pub fn verify_state_machine( @@ -250,7 +271,8 @@ mod tests { // Setup protocol. let channel = &mut Blake2sChannel::default(); - let (component, _) = prove_state_machine(log_n_rows, initial_state, config, channel); + let (component, ..) = + prove_state_machine(log_n_rows, initial_state, config, channel, false); let interaction_elements = component.component0.lookup_elements.clone(); let initial_state_comb: QM31 = interaction_elements.combine(&initial_state); @@ -262,6 +284,38 @@ mod tests { ); } + #[test] + fn test_relation_tracker() { + let log_n_rows = 8; + let config = PcsConfig::default(); + let initial_state = [M31::zero(); STATE_SIZE]; + let final_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)]; + + // Summarize `StateMachineElements`. + let (_, _, summary) = prove_state_machine( + log_n_rows, + initial_state, + config, + &mut Blake2sChannel::default(), + true, + ); + let summary = summary.unwrap(); + let relation_info = summary.get_relation_info("StateMachineElements").unwrap(); + + // Check the final state inferred from the summary. + let mut curr_state = initial_state; + for entry in relation_info { + let x_step = entry.0[0]; + let y_step = entry.0[1]; + let mult = entry.1; + let next_state = [curr_state[0] - x_step * mult, curr_state[1] - y_step * mult]; + + curr_state = next_state; + } + + assert_eq!(curr_state, final_state); + } + #[test] fn test_state_machine_prove() { let log_n_rows = 8; @@ -270,8 +324,8 @@ mod tests { let prover_channel = &mut Blake2sChannel::default(); let verifier_channel = &mut Blake2sChannel::default(); - let (components, proof) = - prove_state_machine(log_n_rows, initial_state, config, prover_channel); + let (components, proof, _) = + prove_state_machine(log_n_rows, initial_state, config, prover_channel, false); verify_state_machine(config, verifier_channel, components, proof).unwrap(); }