Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add memory trace prove entrypoint #71

Merged
merged 18 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/brainfuck_prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ workspace = true
brainfuck_vm.workspace = true
stwo-prover.workspace = true
num-traits.workspace = true
thiserror.workspace = true
52 changes: 34 additions & 18 deletions crates/brainfuck_prover/src/brainfuck_air/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::components::memory::{self, table::MemoryTable};
use brainfuck_vm::machine::Machine;
use stwo_prover::core::{
air::{Component, ComponentProver},
backend::simd::SimdBackend,
Expand All @@ -24,15 +26,17 @@ pub struct BrainfuckProof<H: MerkleHasher> {
///
/// It includes the common claim values such as the initial and final states
/// and the claim of each component.
pub struct BrainfuckClaim;
pub struct BrainfuckClaim {
pub memory: memory::component::Claim,
}

impl BrainfuckClaim {
pub fn mix_into(&self, _channel: &mut impl Channel) {
todo!();
pub fn mix_into(&self, channel: &mut impl Channel) {
self.memory.mix_into(channel);
}

pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
todo!();
self.memory.log_sizes()
}
}

Expand Down Expand Up @@ -93,12 +97,19 @@ impl BrainfuckComponents {
}
}

/// `LOG_MAX_ROWS = log2(MAX_ROWS)` ?
/// `LOG_MAX_ROWS = ilog2(MAX_ROWS)`
///
/// Means that the ZK-VM does not accept programs with more than 2^20 steps (1M steps).
const LOG_MAX_ROWS: u32 = 20;

pub fn prove_brainfuck() -> Result<BrainfuckProof<Blake2sMerkleHasher>, ProvingError> {
/// Generate a STARK proof of the given Brainfuck program execution.
///
/// # Arguments
/// * `inputs` - The [`Machine`] struct after the program execution
/// The inputs contains the program, the memory, the I/O and the trace.
pub fn prove_brainfuck(
inputs: &Machine,
) -> Result<BrainfuckProof<Blake2sMerkleHasher>, ProvingError> {
// ┌──────────────────────────┐
// │ Protocol Setup │
// └──────────────────────────┘
Expand All @@ -112,22 +123,26 @@ pub fn prove_brainfuck() -> Result<BrainfuckProof<Blake2sMerkleHasher>, ProvingE
let channel = &mut Blake2sChannel::default();
let commitment_scheme =
&mut CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);
let tree_builder = commitment_scheme.tree_builder();
let mut tree_builder = commitment_scheme.tree_builder();

// ┌──────────────────────────┐
// │ Interaction Phase 0 │
// └──────────────────────────┘
// ┌───────────────────────────────────────┐
// │ Interaction Phase 0 - Main Trace │
// └───────────────────────────────────────┘
let vm_trace = inputs.get_trace();
tcoratger marked this conversation as resolved.
Show resolved Hide resolved

// Generate BrainfuckClaim (from the execution trace provided by brainfuck_vm)
let (memory_trace, memory_claim) = MemoryTable::from(vm_trace).trace_evaluation().unwrap();

tree_builder.extend_evals(memory_trace);

let claim = BrainfuckClaim { memory: memory_claim };

// Commit to the claim and the trace.
let claim = BrainfuckClaim {};
claim.mix_into(channel);
tree_builder.commit(channel);

// ┌──────────────────────────┐
// │ Interaction Phase 1 │
// └──────────────────────────┘
// ┌───────────────────────────────────────────────
// │ Interaction Phase 1 - Interaction Trace
// └───────────────────────────────────────────────

// Draw interaction elements
let interaction_elements = BrainfuckInteractionElements::draw(channel);
Expand All @@ -139,9 +154,10 @@ pub fn prove_brainfuck() -> Result<BrainfuckProof<Blake2sMerkleHasher>, ProvingE
interaction_claim.mix_into(channel);
tree_builder.commit(channel);

// ┌──────────────────────────┐
// │ Interaction Phase 2 │
// └──────────────────────────┘
// TODO: move the preprocessed trace to Phase 0
// ┌───────────────────────────────────────────────┐
// │ Interaction Phase 2 - Preprocessed Trace │
// └───────────────────────────────────────────────┘

// Generate constant columns (e.g. is_first)
let tree_builder = commitment_scheme.tree_builder();
Expand Down
38 changes: 38 additions & 0 deletions crates/brainfuck_prover/src/components/memory/component.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use super::table::N_COLS_MEMORY_TABLE;
use stwo_prover::core::{channel::Channel, pcs::TreeVec};

/// The claim for the Memory component
#[derive(Debug, Eq, PartialEq)]
pub struct Claim {
pub log_size: u32,
}

impl Claim {
/// Returns the `log_size` of the each type of trace commited for the Memory component:
/// - Preprocessed trace,
/// - Main trace,
/// - Interaction trace.
///
/// The number of columns of each trace is known before actually evaluating them.
/// The `log_size` is known once the main trace has been evaluated
/// (the log2 of the size of the [`super::table::MemoryTable`], to which we add
/// [`stwo_prover::core::backend::simd::m31::LOG_N_LANES`]
/// for the [`stwo_prover::core::backend::simd::SimdBackend`])
///
/// Each element of the [`TreeVec`] is dedicated to the commitment of one type of trace.
/// First element is for the preprocessed trace, second for the main trace and third for the
/// interaction one.
///
/// NOTE: Currently only the main trace is provided.
pub fn log_sizes(&self) -> TreeVec<Vec<u32>> {
// TODO: Add the preprocessed and interaction trace sizes
let trace_log_sizes = vec![self.log_size; N_COLS_MEMORY_TABLE];
TreeVec::new(vec![trace_log_sizes])
tcoratger marked this conversation as resolved.
Show resolved Hide resolved
}

/// Mix the log size of the Memory table to the Fiat-Shamir [`Channel`],
/// to bound the channel randomness and the trace.
pub fn mix_into(&self, channel: &mut impl Channel) {
channel.mix_u64(self.log_size.into());
}
}
1 change: 1 addition & 0 deletions crates/brainfuck_prover/src/components/memory/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod component;
pub mod table;
144 changes: 143 additions & 1 deletion crates/brainfuck_prover/src/components/memory/table.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
use super::component::Claim;
use crate::components::{TraceError, TraceEval};
use brainfuck_vm::registers::Registers;
use num_traits::One;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::{
backend::{
simd::{column::BaseColumn, m31::LOG_N_LANES},
Column,
},
fields::m31::BaseField,
poly::circle::{CanonicCoset, CircleEvaluation},
};

/// Represents a single row in the Memory Table.
///
Expand All @@ -22,13 +31,43 @@ pub struct MemoryTableRow {
}

impl MemoryTableRow {
/// Creates a row for the [`MemoryTable`] which is considered 'real'.
///
/// A 'real' row, is a row that is part of the execution trace from the Brainfuck program
/// execution.
pub fn new(clk: BaseField, mp: BaseField, mv: BaseField) -> Self {
Self { clk, mp, mv, ..Default::default() }
}

/// Creates a row for the [`MemoryTable`] which is considered 'dummy'.
///
/// A 'dummy' row, is a row that is not part of the execution trace from the Brainfuck program
/// execution.
/// They are used for padding and filling the `clk` gaps after sorting by `mp`, to enforce the
/// correct sorting.
pub fn new_dummy(clk: BaseField, mp: BaseField, mv: BaseField) -> Self {
Self { clk, mp, mv, d: BaseField::one() }
}

/// Getter for the `clk` field.
pub const fn clk(&self) -> BaseField {
self.clk
}

/// Getter for the `mp` field.
pub const fn mp(&self) -> BaseField {
self.mp
}

/// Getter for the `mv` field.
pub const fn mv(&self) -> BaseField {
self.mv
}

/// Getter for the `d` field.
pub const fn d(&self) -> BaseField {
self.d
}
}

impl From<(&Registers, bool)> for MemoryTableRow {
Expand Down Expand Up @@ -65,6 +104,11 @@ impl MemoryTable {
Self::default()
}

/// Getter for the `table` field.
pub const fn table(&self) -> &Vec<MemoryTableRow> {
&self.table
}

/// Adds a new row to the Memory Table.
///
/// # Arguments
Expand Down Expand Up @@ -136,6 +180,41 @@ impl MemoryTable {
}
}
}

/// Transforms the [`MemoryTable`] into [`super::super::TraceEval`], to be commited when
/// generating a STARK proof.
///
/// The [`MemoryTable`] is transformed from an array of rows (one element = one step of all
/// registers) to an array of columns (one element = all steps of one register).
/// It is then evaluated on the circle domain.
///
/// # Arguments
/// * memory - The [`MemoryTable`] containing the sorted and padded trace as an array of rows.
pub fn trace_evaluation(&self) -> Result<(TraceEval, Claim), TraceError> {
let n_rows = self.table.len() as u32;
if n_rows == 0 {
return Err(TraceError::EmptyTrace);
}
let log_n_rows = n_rows.ilog2();
// TODO: Confirm that the log_size used for evaluation on Circle domain is the log_size of
// the table plus the SIMD lanes
let log_size = log_n_rows + LOG_N_LANES;
let mut trace: Vec<BaseColumn> =
(0..N_COLS_MEMORY_TABLE).map(|_| BaseColumn::zeros(1 << log_size)).collect();

for (vec_row, row) in self.table.iter().enumerate().take(1 << log_n_rows) {
trace[CLK_COL_INDEX].data[vec_row] = row.clk().into();
trace[MP_COL_INDEX].data[vec_row] = row.mp().into();
trace[MV_COL_INDEX].data[vec_row] = row.mv().into();
trace[D_COL_INDEX].data[vec_row] = row.d().into();
}

let domain = CanonicCoset::new(log_size).circle_domain();
let trace = trace.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// TODO: Confirm that the log_size in `Claim` is `log_size`, including the SIMD lanes
Ok((trace, Claim { log_size }))
}
}

impl From<Vec<Registers>> for MemoryTable {
Expand All @@ -153,6 +232,17 @@ impl From<Vec<Registers>> for MemoryTable {
}
}

/// Number of columns in the memory table
pub const N_COLS_MEMORY_TABLE: usize = 4;
/// Index of the `clk` register column in the Memory trace.
const CLK_COL_INDEX: usize = 0;
/// Index of the `mp` register column in the Memory trace.
const MP_COL_INDEX: usize = 1;
/// Index of the `mv` register column in the Memory trace.
const MV_COL_INDEX: usize = 2;
/// Index of the `d` register column in the Memory trace.
const D_COL_INDEX: usize = 3;
Comment on lines +235 to +244
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe (just a suggestion, totally fine to do in a follow up or not if you don't like it) to avoid the number of constants into the code, can't we transform that into an enum?

/// Enum representing the column indices in the Memory trace
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryColumn {
    Clk,
    Mp,
    Mv,
    D,
}

impl MemoryColumn {
    /// Returns the index of the column in the Memory table
    pub fn index(self) -> usize {
        match self {
            MemoryColumn::Clk => 0,
            MemoryColumn::Mp => 1,
            MemoryColumn::Mv => 2,
            MemoryColumn::D => 3,
        }
    }

    /// Returns the total number of columns in the Memory table
    pub fn count() -> usize {
        std::mem::variant_count::<Self>()
    }
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not opinionated on const/enum, the goal is to reduce the number of public variables: having a single enum rather than 5 constants, + the number of columns in the table is not a magic number but calculated from the number of columns defined in the enum

From this angle, having an enum seems better

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let us do in a follow up if you want so that we don't block things here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, i'll open an issue then


#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -334,4 +424,56 @@ mod tests {

assert_eq!(MemoryTable::from(registers), expected_memory_table);
}

#[test]
fn test_write_trace() {
let mut memory_table = MemoryTable::new();
let rows = vec![
MemoryTableRow::new(BaseField::zero(), BaseField::from(43), BaseField::from(91)),
MemoryTableRow::new(BaseField::one(), BaseField::from(91), BaseField::from(9)),
];
memory_table.add_rows(rows);

let (trace, claim) = memory_table.trace_evaluation().unwrap();

let expected_log_n_rows: u32 = 1;
let expected_log_size = expected_log_n_rows + LOG_N_LANES;
let expected_size = 1 << expected_log_size;
let mut clk_column = BaseColumn::zeros(expected_size);
let mut mp_column = BaseColumn::zeros(expected_size);
let mut mv_col = BaseColumn::zeros(expected_size);
let mut d_column = BaseColumn::zeros(expected_size);

clk_column.data[0] = BaseField::zero().into();
clk_column.data[1] = BaseField::one().into();

mp_column.data[0] = BaseField::from(43).into();
mp_column.data[1] = BaseField::from(91).into();

mv_col.data[0] = BaseField::from(91).into();
mv_col.data[1] = BaseField::from(9).into();

d_column.data[0] = BaseField::zero().into();
d_column.data[1] = BaseField::zero().into();

let domain = CanonicCoset::new(expected_log_size).circle_domain();
let expected_trace: TraceEval = vec![clk_column, mp_column, mv_col, d_column]
.into_iter()
.map(|col| CircleEvaluation::new(domain, col))
.collect();
let expected_claim = Claim { log_size: expected_log_size };

assert_eq!(claim, expected_claim);
for col_index in 0..expected_trace.len() {
assert_eq!(trace[col_index].domain, expected_trace[col_index].domain);
assert_eq!(trace[col_index].to_cpu().values, expected_trace[col_index].to_cpu().values);
}
}
#[test]
fn test_write_empty_trace() {
let memory_table = MemoryTable::new();
let run = memory_table.trace_evaluation();

assert!(matches!(run, Err(TraceError::EmptyTrace)));
}
}
19 changes: 19 additions & 0 deletions crates/brainfuck_prover/src/components/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
use stwo_prover::core::{
backend::simd::SimdBackend,
fields::m31::BaseField,
poly::{circle::CircleEvaluation, BitReversedOrder},
ColumnVec,
};
use thiserror::Error;

pub mod instruction;
pub mod io;
pub mod memory;
pub mod processor;

/// Type for trace evaluation to be used in Stwo.
pub type TraceEval = ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>;

/// Custom error type for the Trace.
#[derive(Debug, Error, Eq, PartialEq)]
pub enum TraceError {
/// The component trace is empty.
#[error("The trace is empty.")]
EmptyTrace,
}