Skip to content

Commit

Permalink
relation tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Dec 2, 2024
1 parent f7de614 commit 5f5a02f
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 14 deletions.
1 change: 0 additions & 1 deletion crates/prover/src/constraint_framework/relation_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ impl<E: FrameworkEval> RelationTrackerComponent<E> {
}

/// Aggregates relation entries.
// TODO(Ohad): test.
pub struct RelationTrackerEvaluator<'a> {
entries: Vec<RelationTrackerEntry>,
pub trace_eval:
Expand Down
48 changes: 46 additions & 2 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use num_traits::{One, Zero};

use crate::constraint_framework::logup::ClaimedPrefixSum;
use crate::constraint_framework::relation_tracker::{
RelationTrackerComponent, RelationTrackerEntry,
};
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, RelationEntry,
PREPROCESSED_TRACE_IDX,
TraceLocationAllocator, PREPROCESSED_TRACE_IDX,
};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::M31;
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::prover::StarkProof;
use crate::core::vcs::ops::MerkleHasher;

Expand Down Expand Up @@ -124,6 +129,45 @@ impl StateMachineComponents {
}
}

pub fn track_state_machine_relations(
trace: &TreeVec<&Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
x_axis_log_n_rows: u32,
y_axis_log_n_rows: u32,
n_rows_x: u32,
n_rows_y: u32,
) -> Vec<RelationTrackerEntry> {
let tree_span_provider = &mut TraceLocationAllocator::default();
let mut entries = vec![];
entries.extend(
RelationTrackerComponent::new(
tree_span_provider,
StateTransitionEval::<0> {
log_n_rows: x_axis_log_n_rows,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
},
n_rows_x as usize,
)
.entries(&trace.into()),
);
entries.extend(
RelationTrackerComponent::new(
tree_span_provider,
StateTransitionEval::<1> {
log_n_rows: y_axis_log_n_rows,
lookup_elements: StateMachineElements::dummy(),
total_sum: QM31::zero(),
claimed_sum: (QM31::zero(), 0),
},
n_rows_y as usize,
)
.entries(&trace.into()),
);

entries
}

pub struct StateMachineProof<H: MerkleHasher> {
pub public_input: [State; 2], // Initial and final state.
pub stmt0: StateMachineStatement0,
Expand Down
76 changes: 65 additions & 11 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::constraint_framework::relation_tracker::RelationSummary;
use crate::constraint_framework::Relation;
pub mod components;
pub mod gen;

use components::{
State, StateMachineComponents, StateMachineElements, StateMachineOp0Component,
StateMachineOp1Component, StateMachineProof, StateMachineStatement0, StateMachineStatement1,
StateTransitionEval,
track_state_machine_relations, State, StateMachineComponents, StateMachineElements,
StateMachineOp0Component, StateMachineOp1Component, StateMachineProof, StateMachineStatement0,
StateMachineStatement1, StateTransitionEval,
};
use gen::{gen_interaction_trace, gen_trace};
use itertools::{chain, Itertools};
Expand All @@ -19,7 +20,7 @@ use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Blake2sChannel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig};
use crate::core::pcs::{CommitmentSchemeProver, CommitmentSchemeVerifier, PcsConfig, TreeVec};
use crate::core::poly::circle::{CanonicCoset, PolyOps};
use crate::core::prover::{prove, verify, VerificationError};
use crate::core::vcs::blake2_merkle::{Blake2sMerkleChannel, Blake2sMerkleHasher};
Expand All @@ -30,9 +31,11 @@ pub fn prove_state_machine(
initial_state: State,
config: PcsConfig,
channel: &mut Blake2sChannel,
track_relations: bool,
) -> (
StateMachineComponents,
StateMachineProof<Blake2sMerkleHasher>,
Option<RelationSummary>,
) {
let (x_axis_log_rows, y_axis_log_rows) = (log_n_rows, log_n_rows - 1);
let (x_row, y_row) = (34, 56);
Expand Down Expand Up @@ -62,14 +65,32 @@ pub fn prove_state_machine(
];

// Preprocessed trace.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(gen_preprocessed_columns(preprocessed_columns.iter()));
tree_builder.commit(channel);
let preprocessed_trace = gen_preprocessed_columns(preprocessed_columns.iter());

// Trace.
let trace_op0 = gen_trace(x_axis_log_rows, initial_state, 0);
let trace_op1 = gen_trace(y_axis_log_rows, intermediate_state, 1);

let trace = chain![trace_op0.clone(), trace_op1.clone()].collect_vec();

let relation_summary = match track_relations {
false => None,
true => Some(RelationSummary::summarize_relations(
&track_state_machine_relations(
&TreeVec(vec![&preprocessed_trace, &trace]),
x_axis_log_rows,
y_axis_log_rows,
x_row,
y_row,
),
)),
};

// Commitments.
let mut tree_builder = commitment_scheme.tree_builder();
tree_builder.extend_evals(preprocessed_trace);
tree_builder.commit(channel);

let stmt0 = StateMachineStatement0 {
n: x_axis_log_rows,
m: y_axis_log_rows,
Expand Down Expand Up @@ -135,7 +156,7 @@ pub fn prove_state_machine(
stmt1,
stark_proof,
};
(components, proof)
(components, proof, relation_summary)
}

pub fn verify_state_machine(
Expand Down Expand Up @@ -250,7 +271,8 @@ mod tests {

// Setup protocol.
let channel = &mut Blake2sChannel::default();
let (component, _) = prove_state_machine(log_n_rows, initial_state, config, channel);
let (component, ..) =
prove_state_machine(log_n_rows, initial_state, config, channel, false);

let interaction_elements = component.component0.lookup_elements.clone();
let initial_state_comb: QM31 = interaction_elements.combine(&initial_state);
Expand All @@ -262,6 +284,38 @@ mod tests {
);
}

#[test]
fn test_relation_tracker() {
let log_n_rows = 8;
let config = PcsConfig::default();
let initial_state = [M31::zero(); STATE_SIZE];
let final_state = [M31::from_u32_unchecked(34), M31::from_u32_unchecked(56)];

// Summarize `StateMachineElements`.
let (_, _, summary) = prove_state_machine(
log_n_rows,
initial_state,
config,
&mut Blake2sChannel::default(),
true,
);
let summary = summary.unwrap();
let relation_info = summary.get_relation_info("StateMachineElements").unwrap();

// Check the final state inferred from the summary.
let mut curr_state = initial_state;
for entry in relation_info {
let x_step = entry.0[0];
let y_step = entry.0[1];
let mult = entry.1;
let next_state = [curr_state[0] - x_step * mult, curr_state[1] - y_step * mult];

curr_state = next_state;
}

assert_eq!(curr_state, final_state);
}

#[test]
fn test_state_machine_prove() {
let log_n_rows = 8;
Expand All @@ -270,8 +324,8 @@ mod tests {
let prover_channel = &mut Blake2sChannel::default();
let verifier_channel = &mut Blake2sChannel::default();

let (components, proof) =
prove_state_machine(log_n_rows, initial_state, config, prover_channel);
let (components, proof, _) =
prove_state_machine(log_n_rows, initial_state, config, prover_channel, false);

verify_state_machine(config, verifier_channel, components, proof).unwrap();
}
Expand Down

1 comment on commit 5f5a02f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 5f5a02f Previous: cd8b37b Ratio
iffts/simd ifft/21 6145444 ns/iter (± 249518) 3044114 ns/iter (± 156865) 2.02
iffts/simd ifft/22 12838541 ns/iter (± 305315) 6306399 ns/iter (± 210024) 2.04
iffts/simd ifft/28 1293461261 ns/iter (± 32972126) 643771030 ns/iter (± 19147376) 2.01
merkle throughput/simd merkle 32341215 ns/iter (± 618500) 13712527 ns/iter (± 579195) 2.36

This comment was automatically generated by workflow using github-action-benchmark.

CC: @shaharsamocha7

Please sign in to comment.