Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

instruction table: implement trace_evaluation #81

Merged
merged 10 commits into from
Nov 21, 2024
241 changes: 240 additions & 1 deletion crates/brainfuck_prover/src/components/instruction/table.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
use crate::components::{memory::component::Claim, TraceError, TraceEval};
use brainfuck_vm::{
instruction::VALID_INSTRUCTIONS_BF, machine::ProgramMemory, registers::Registers,
};
use num_traits::Zero;
use stwo_prover::core::fields::m31::BaseField;
use stwo_prover::core::{
backend::{
simd::{column::BaseColumn, m31::LOG_N_LANES},
Column,
},
fields::m31::BaseField,
poly::circle::{CanonicCoset, CircleEvaluation},
};

/// Represents a single row in the Instruction Table.
///
Expand All @@ -22,6 +30,23 @@ pub struct InstructionTableRow {
ni: BaseField,
}

impl InstructionTableRow {
/// Get the instruction pointer.
pub const fn ip(&self) -> BaseField {
self.ip
}

/// Get the current instruction.
pub const fn ci(&self) -> BaseField {
self.ci
}

/// Get the next instruction.
pub const fn ni(&self) -> BaseField {
self.ni
}
}

impl From<&Registers> for InstructionTableRow {
fn from(registers: &Registers) -> Self {
Self { ip: registers.ip, ci: registers.ci, ni: registers.ni }
Expand Down Expand Up @@ -85,6 +110,67 @@ impl InstructionTable {
}
}
}

/// Get the instruction table.
pub const fn table(&self) -> &Vec<InstructionTableRow> {
&self.table
}

/// Transforms the [`InstructionTable`] into a [`TraceEval`], to be committed when
/// generating a STARK proof.
///
/// The [`InstructionTable`] is transformed from an array of rows (one element = one step
/// of all registers) to an array of columns (one element = all steps of one register).
/// It is then evaluated on the circle domain.
///
/// # Returns
/// A tuple containing the evaluated trace and claim for STARK proof.
///
/// # Errors
/// Returns [`TraceError::EmptyTrace`] if the table is empty.
pub fn trace_evaluation(&self) -> Result<(TraceEval, Claim), TraceError> {
let n_rows = self.table.len() as u32;
// If the table is empty, there is no data to evaluate, so return an error.
if n_rows == 0 {
return Err(TraceError::EmptyTrace);
}

// Compute `log_n_rows`, the base-2 logarithm of the number of rows.
// This determines the smallest power of two greater than or equal to `n_rows`.
tcoratger marked this conversation as resolved.
Show resolved Hide resolved
let log_n_rows = n_rows.ilog2();

// Add `LOG_N_LANES` to account for SIMD optimization. This ensures that
// the domain size is aligned for parallel processing.
let log_size = log_n_rows + LOG_N_LANES;

// Initialize a trace with 3 columns (for `ip`, `ci`, and `ni` registers),
// each column containing `2^log_size` entries initialized to zero.
let mut trace = vec![BaseColumn::zeros(1 << log_size); InstructionColumn::count()];

// Populate the columns with data from the table rows.
// We iterate over the table rows and, for each row:
// - Map the `ip` value to the first column.
// - Map the `ci` value to the second column.
// - Map the `ni` value to the third column.
for (index, row) in self.table.iter().enumerate().take(1 << log_n_rows) {
trace[InstructionColumn::Ip.index()].data[index] = row.ip().into();
trace[InstructionColumn::Ci.index()].data[index] = row.ci().into();
trace[InstructionColumn::Ni.index()].data[index] = row.ni().into();
}

// Create a circle domain using a canonical coset.
// This domain provides the mathematical structure required for FFT-based evaluation.
let domain = CanonicCoset::new(log_size).circle_domain();

// Map each column into the circle domain.
//
// This converts the columnar data into polynomial evaluations over the domain, enabling
// constraint verification in STARK proofs.
let trace = trace.into_iter().map(|col| CircleEvaluation::new(domain, col)).collect();

// Return the evaluated trace and a claim containing the log size of the domain.
Ok((trace, Claim { log_size }))
}
}

impl From<(Vec<Registers>, &ProgramMemory)> for InstructionTable {
Expand Down Expand Up @@ -125,6 +211,33 @@ impl From<(Vec<Registers>, &ProgramMemory)> for InstructionTable {
}
}

/// Enum representing the column indices in the Instruction trace
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InstructionColumn {
/// Index of the `ip` register column in the Instruction trace.
Ip,
/// Index of the `ci` register column in the Instruction trace.
Ci,
/// Index of the `ni` register column in the Instruction trace.
Ni,
}

impl InstructionColumn {
/// Returns the index of the column in the Instruction trace
pub const fn index(self) -> usize {
match self {
Self::Ip => 0,
Self::Ci => 1,
Self::Ni => 2,
}
}

/// Returns the total number of columns in the Instruction table
tcoratger marked this conversation as resolved.
Show resolved Hide resolved
tcoratger marked this conversation as resolved.
Show resolved Hide resolved
pub const fn count() -> usize {
3
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -365,4 +478,130 @@ mod tests {
// Verify that the instruction table is correct
assert_eq!(instruction_table, expected_instruction_table);
}

#[test]
fn test_trace_evaluation_empty_table() {
let instruction_table = InstructionTable::new();
let result = instruction_table.trace_evaluation();

assert!(matches!(result, Err(TraceError::EmptyTrace)));
}

#[test]
fn test_trace_evaluation_single_row() {
let mut instruction_table = InstructionTable::new();
instruction_table.add_row(InstructionTableRow {
ip: BaseField::from(1),
ci: BaseField::from(43),
ni: BaseField::from(91),
});

let (trace, claim) = instruction_table.trace_evaluation().unwrap();

assert_eq!(claim.log_size, LOG_N_LANES, "Log size should include SIMD lanes.");
assert_eq!(
trace.len(),
InstructionColumn::count(),
"Trace should contain one column per register."
);

// Check that each column contains the correct values
assert_eq!(trace[InstructionColumn::Ip.index()].to_cpu().values[0], BaseField::from(1));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should compare values rather than values[0]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, as this test is simple, I've directly hardcoded 16, the expected number of elements, let me know if this is ok for you

assert_eq!(trace[InstructionColumn::Ci.index()].to_cpu().values[0], BaseField::from(43));
assert_eq!(trace[InstructionColumn::Ni.index()].to_cpu().values[0], BaseField::from(91));
}

#[test]
#[allow(clippy::similar_names)]
fn test_write_instruction_trace() {
tcoratger marked this conversation as resolved.
Show resolved Hide resolved
let mut instruction_table = InstructionTable::new();

// Add rows to the instruction table.
let rows = vec![
InstructionTableRow {
ip: BaseField::zero(),
ci: BaseField::from(43),
ni: BaseField::from(91),
},
InstructionTableRow {
ip: BaseField::one(),
ci: BaseField::from(91),
ni: BaseField::from(9),
},
];
instruction_table.add_rows(rows);

// Perform the trace evaluation.
let (trace, claim) = instruction_table.trace_evaluation().unwrap();

// Calculate the expected parameters.
let expected_log_n_rows: u32 = 1; // log2(2 rows)
let expected_log_size = expected_log_n_rows + LOG_N_LANES;
let expected_size = 1 << expected_log_size;

// Construct the expected trace columns.
let mut ip_column = BaseColumn::zeros(expected_size);
let mut ci_column = BaseColumn::zeros(expected_size);
let mut ni_column = BaseColumn::zeros(expected_size);

ip_column.data[0] = BaseField::zero().into();
ip_column.data[1] = BaseField::one().into();

ci_column.data[0] = BaseField::from(43).into();
ci_column.data[1] = BaseField::from(91).into();

ni_column.data[0] = BaseField::from(91).into();
ni_column.data[1] = BaseField::from(9).into();

// Create the expected domain for evaluation.
let domain = CanonicCoset::new(expected_log_size).circle_domain();

// Transform expected columns into CircleEvaluation.
let expected_trace: TraceEval = vec![ip_column, ci_column, ni_column]
.into_iter()
.map(|col| CircleEvaluation::new(domain, col))
.collect();

// Create the expected claim.
let expected_claim = Claim { log_size: expected_log_size };

// Assert equality of the claim.
assert_eq!(claim, expected_claim);

// Assert equality of the trace.
for col_index in 0..expected_trace.len() {
assert_eq!(trace[col_index].domain, expected_trace[col_index].domain);
assert_eq!(trace[col_index].to_cpu().values, expected_trace[col_index].to_cpu().values);
}
}

#[test]
fn test_trace_evaluation_circle_domain() {
let mut instruction_table = InstructionTable::new();
instruction_table.add_rows(vec![
InstructionTableRow {
ip: BaseField::from(0),
ci: BaseField::from(43),
ni: BaseField::from(91),
},
InstructionTableRow {
ip: BaseField::from(1),
ci: BaseField::from(91),
ni: BaseField::from(9),
},
]);

let (trace, claim) = instruction_table.trace_evaluation().unwrap();

let log_size = claim.log_size;
let domain = CanonicCoset::new(log_size).circle_domain();

// Check that each column is evaluated over the correct domain
for column in trace {
assert_eq!(
column.domain, domain,
"Trace column domain should match expected circle domain."
);
}
}
}
Loading