Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

abstract Claim for multiple traces #93

Merged
merged 6 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions crates/brainfuck_prover/src/brainfuck_air/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::components::memory::{self, table::MemoryTable};
use crate::components::{
memory::table::{MemoryColumn, MemoryTable},
Claim,
};
use brainfuck_vm::machine::Machine;
use stwo_prover::core::{
air::{Component, ComponentProver},
Expand Down Expand Up @@ -27,7 +30,7 @@ pub struct BrainfuckProof<H: MerkleHasher> {
/// It includes the common claim values such as the initial and final states
/// and the claim of each component.
pub struct BrainfuckClaim {
pub memory: memory::component::Claim,
pub memory: Claim<MemoryColumn>,
}

impl BrainfuckClaim {
Expand Down
14 changes: 8 additions & 6 deletions crates/brainfuck_prover/src/components/instruction/table.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::components::{memory::component::Claim, TraceError, TraceEval};
use crate::components::{Claim, TraceColumn, TraceError, TraceEval};
use brainfuck_vm::{
instruction::VALID_INSTRUCTIONS_BF, machine::ProgramMemory, registers::Registers,
};
Expand Down Expand Up @@ -128,7 +128,7 @@ impl InstructionTable {
///
/// # Errors
/// Returns [`TraceError::EmptyTrace`] if the table is empty.
pub fn trace_evaluation(&self) -> Result<(TraceEval, Claim), TraceError> {
pub fn trace_evaluation(&self) -> Result<(TraceEval, Claim<InstructionColumn>), TraceError> {
let n_rows = self.table.len() as u32;
// If the table is empty, there is no data to evaluate, so return an error.
if n_rows == 0 {
Expand Down Expand Up @@ -171,7 +171,7 @@ impl InstructionTable {
let trace = trace.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// Return the evaluated trace and a claim containing the log size of the domain.
Ok((trace, Claim { log_size }))
Ok((trace, Claim::<InstructionColumn>::new(log_size)))
}
}

Expand Down Expand Up @@ -233,16 +233,18 @@ impl InstructionColumn {
Self::Ni => 2,
}
}
}

/// Returns the total number of columns in the Instruction trace
pub const fn count() -> usize {
impl TraceColumn for InstructionColumn {
fn count() -> usize {
3
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::components::Claim;
use brainfuck_vm::{
compiler::Compiler, instruction::InstructionType, test_helper::create_test_machine,
};
Expand Down Expand Up @@ -583,7 +585,7 @@ mod tests {
.collect();

// Create the expected claim.
let expected_claim = Claim { log_size: expected_log_size };
let expected_claim = Claim::<InstructionColumn>::new(expected_log_size);

// Assert equality of the claim.
assert_eq!(claim, expected_claim);
Expand Down
34 changes: 28 additions & 6 deletions crates/brainfuck_prover/src/components/io/table.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::components::{memory::component::Claim, TraceEval};
use crate::components::{Claim, TraceColumn, TraceEval};
use brainfuck_vm::{instruction::InstructionType, registers::Registers};
use stwo_prover::core::{
backend::{
Expand Down Expand Up @@ -96,12 +96,12 @@ impl<const N: u32> IOTable<N> {
/// # Returns
/// A tuple containing the evaluated trace and claim for STARK proof.
/// If the table is empty, returns an empty trace and a claim with a log size of 0.
pub fn trace_evaluation(&self) -> (TraceEval, Claim) {
pub fn trace_evaluation(&self) -> (TraceEval, Claim<IoColumn>) {
let n_rows = self.table.len() as u32;

// It is possible that the table is empty because the program has no input or output.
if n_rows == 0 {
return (TraceEval::new(), Claim { log_size: 0 });
return (TraceEval::new(), Claim::<IoColumn>::new(0));
}

// Compute `log_n_rows`, the base-2 logarithm of the number of rows.
Expand All @@ -116,7 +116,7 @@ impl<const N: u32> IOTable<N> {

// Populate the column with data from the table rows.
for (index, row) in self.table.iter().enumerate().take(1 << log_n_rows) {
trace[0].data[index] = row.mv.into();
trace[IoColumn::Io.index()].data[index] = row.mv.into();
}

// Create a circle domain using a canonical coset.
Expand All @@ -126,7 +126,7 @@ impl<const N: u32> IOTable<N> {
let trace = trace.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// Return the evaluated trace and a claim containing the log size of the domain.
(trace, Claim { log_size })
(trace, Claim::<IoColumn>::new(log_size))
}
}

Expand Down Expand Up @@ -158,6 +158,28 @@ pub type InputTable = IOTable<{ InstructionType::ReadChar.to_u32() }>;
/// outputs (when the current instruction `ci` equals '.').
pub type OutputTable = IOTable<{ InstructionType::PutChar.to_u32() }>;

/// Enum representing the column indices in the IO trace.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IoColumn {
/// Column representing the input/output operations.
Io,
}

impl IoColumn {
/// Returns the index of the column in the IO table.
pub const fn index(self) -> usize {
match self {
Self::Io => 0,
}
}
}

impl TraceColumn for IoColumn {
fn count() -> usize {
1
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -329,7 +351,7 @@ mod tests {
expected_columns.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// Create the expected claim.
let expected_claim = Claim { log_size: expected_log_size };
let expected_claim = Claim::<IoColumn>::new(expected_log_size);

// Assert equality of the claim.
assert_eq!(claim, expected_claim, "The claim should match the expected claim.");
Expand Down
44 changes: 0 additions & 44 deletions crates/brainfuck_prover/src/components/memory/component.rs

This file was deleted.

1 change: 0 additions & 1 deletion crates/brainfuck_prover/src/components/memory/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
pub mod component;
pub mod table;
14 changes: 7 additions & 7 deletions crates/brainfuck_prover/src/components/memory/table.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::component::Claim;
use crate::components::{TraceError, TraceEval};
use crate::components::{Claim, TraceColumn, TraceError, TraceEval};
use brainfuck_vm::registers::Registers;
use num_traits::One;
use stwo_prover::core::{
Expand Down Expand Up @@ -196,7 +195,7 @@ impl MemoryTable {
///
/// # Errors
/// Returns [`TraceError::EmptyTrace`] if the table is empty.
pub fn trace_evaluation(&self) -> Result<(TraceEval, Claim), TraceError> {
pub fn trace_evaluation(&self) -> Result<(TraceEval, Claim<MemoryColumn>), TraceError> {
let n_rows = self.table.len() as u32;
if n_rows == 0 {
return Err(TraceError::EmptyTrace);
Expand All @@ -218,7 +217,7 @@ impl MemoryTable {
let trace = trace.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// TODO: Confirm that the log_size in `Claim` is `log_size`, including the SIMD lanes
Ok((trace, Claim { log_size }))
Ok((trace, Claim::<MemoryColumn>::new(log_size)))
}
}

Expand Down Expand Up @@ -260,9 +259,10 @@ impl MemoryColumn {
Self::D => 3,
}
}
}

/// Returns the total number of columns in the Memory table
pub const fn count() -> usize {
impl TraceColumn for MemoryColumn {
fn count() -> usize {
4
}
}
Expand Down Expand Up @@ -485,7 +485,7 @@ mod tests {
.into_iter()
.map(|col| CircleEvaluation::new(domain, col))
.collect();
let expected_claim = Claim { log_size: expected_log_size };
let expected_claim = Claim::<MemoryColumn>::new(expected_log_size);

assert_eq!(claim, expected_claim);
for col_index in 0..expected_trace.len() {
Expand Down
58 changes: 58 additions & 0 deletions crates/brainfuck_prover/src/components/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use stwo_prover::core::{
backend::simd::SimdBackend,
channel::Channel,
fields::m31::BaseField,
pcs::TreeVec,
poly::{circle::CircleEvaluation, BitReversedOrder},
ColumnVec,
};
Expand All @@ -21,3 +23,59 @@ pub enum TraceError {
#[error("The trace is empty.")]
EmptyTrace,
}

/// Represents a claim associated with a specific trace in the Brainfuck STARK proving system.
#[derive(Debug, Eq, PartialEq)]
pub struct Claim<T: TraceColumn> {
/// Logarithmic size (`log2`) of the evaluated trace.
pub log_size: u32,
/// Marker for the trace type.
pub _marker: std::marker::PhantomData<T>,
}

impl<T: TraceColumn> Claim<T> {
/// Creates a new claim for the given trace type.
pub const fn new(log_size: u32) -> Self {
Self { log_size, _marker: std::marker::PhantomData }
}

/// Returns the `log_size` for each type of trace committed for the given trace type:
/// - Preprocessed trace,
/// - Main trace,
/// - Interaction trace.
///
/// The number of columns of each trace is known before actually evaluating them.
/// The `log_size` is known once the main trace has been evaluated
/// (the log2 of the size of the [`super::table::MemoryTable`], to which we add
/// [`stwo_prover::core::backend::simd::m31::LOG_N_LANES`]
/// for the [`stwo_prover::core::backend::simd::SimdBackend`])
///
/// Each element of the [`TreeVec`] is dedicated to the commitment of one type of trace.
/// First element is for the preprocessed trace, second for the main trace and third for the
/// interaction one.
///
/// NOTE: Currently only the main trace is provided.
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
// TODO: Add the preprocessed and interaction trace correct sizes
let preprocessed_trace_log_sizes: Vec<u32> = vec![];
let trace_log_sizes = vec![self.log_size; T::count()];
let interaction_trace_log_sizes: Vec<u32> = vec![];
TreeVec::new(vec![
preprocessed_trace_log_sizes,
trace_log_sizes,
interaction_trace_log_sizes,
])
}

/// Mix the log size of the table to the Fiat-Shamir [`Channel`],
/// to bound the channel randomness and the trace.
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_u64(self.log_size.into());
}
}

/// Represents columns of a trace.
pub trait TraceColumn {
/// Returns the number of columns associated with the specific trace type.
fn count() -> usize;
}