Skip to content

Commit

Permalink
refactor: memory table cleanup (#64)
Browse files Browse the repository at this point in the history
* refactor: remove get_row method

* refactor: remove add_row_from_registers

* refactor: change Memory table and row visibility

* refactor: fix comment
  • Loading branch information
zmalatrax authored Nov 14, 2024
1 parent e1fc830 commit c87b629
Showing 1 changed file with 19 additions and 109 deletions.
128 changes: 19 additions & 109 deletions crates/brainfuck_prover/src/components/memory/table.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use brainfuck_vm::registers::Registers;
use num_traits::{One, Zero};
use num_traits::One;
use stwo_prover::core::fields::m31::BaseField;

/// Represents a single row in the Memory Table.
Expand All @@ -12,11 +12,11 @@ use stwo_prover::core::fields::m31::BaseField;
#[derive(Debug, Default, PartialEq, Eq, Clone)]
pub struct MemoryTableRow {
/// Clock cycle counter: the current step.
pub clk: BaseField,
clk: BaseField,
/// Memory pointer: points to a memory cell.
pub mp: BaseField,
mp: BaseField,
/// Memory value: value of the cell pointer by `mp` - values in [0..2^31 - 1)
pub mv: BaseField,
mv: BaseField,
/// Dummy: Flag whether the row is a dummy one or not
d: BaseField,
}
Expand All @@ -31,6 +31,16 @@ impl MemoryTableRow {
}
}

impl From<(&Registers, bool)> for MemoryTableRow {
fn from((registers, is_dummy): (&Registers, bool)) -> Self {
if is_dummy {
Self::new_dummy(registers.clk, registers.mp, registers.mv)
} else {
Self::new(registers.clk, registers.mp, registers.mv)
}
}
}

/// Represents the Memory Table, which holds the required registers
/// for the Memory component.
///
Expand All @@ -43,7 +53,7 @@ impl MemoryTableRow {
#[derive(Debug, Default, PartialEq, Eq, Clone)]
pub struct MemoryTable {
/// A vector of [`MemoryTableRow`] representing the table rows.
pub table: Vec<MemoryTableRow>,
table: Vec<MemoryTableRow>,
}

impl MemoryTable {
Expand All @@ -55,33 +65,13 @@ impl MemoryTable {
Self::default()
}

/// Adds a new row to the Memory Table from the provided registers.
///
/// # Arguments
/// * `clk` - The clock cycle counter for the new row.
/// * `mp` - The memory pointer for the new row.
/// * `mv` - The memory value for the new row.
/// * `is_dummy` - Flag whether the new row is dummy.
///
/// This method pushes a new [`MemoryTableRow`] onto the `table` vector.
pub fn add_row_from_registers(
&mut self,
clk: BaseField,
mp: BaseField,
mv: BaseField,
is_dummy: bool,
) {
let d = if is_dummy { BaseField::one() } else { BaseField::zero() };
self.table.push(MemoryTableRow { clk, mp, mv, d });
}

/// Adds a new row to the Memory Table.
///
/// # Arguments
/// * `row` - The [`MemoryTableRow`] to add to the table.
///
/// This method pushes a new [`MemoryTableRow`] onto the `table` vector.
pub fn add_row(&mut self, row: MemoryTableRow) {
fn add_row(&mut self, row: MemoryTableRow) {
self.table.push(row);
}

Expand All @@ -91,22 +81,10 @@ impl MemoryTable {
/// * `rows` - A vector of [`MemoryTableRow`] to add to the table.
///
/// This method extends the `table` vector with the provided rows.
pub fn add_rows(&mut self, rows: Vec<MemoryTableRow>) {
fn add_rows(&mut self, rows: Vec<MemoryTableRow>) {
self.table.extend(rows);
}

/// Retrieves a reference to a specific row in the Memory Table.
///
/// # Arguments
/// * `row` - The [`MemoryTableRow`] to search for in the table.
///
/// # Returns
/// An `Option` containing a reference to the matching row if found,
/// or `None` if the row does not exist in the table.
pub fn get_row(&self, row: &MemoryTableRow) -> Option<&MemoryTableRow> {
self.table.iter().find(|r| *r == row)
}

/// Sorts in-place the existing [`MemoryTableRow`] rows in the Memory Table by `mp`, then `clk`.
///
/// Having the rows sorted is required to ensure a correct proof generation (such that the
Expand Down Expand Up @@ -164,9 +142,8 @@ impl From<Vec<Registers>> for MemoryTable {
fn from(registers: Vec<Registers>) -> Self {
let mut memory_table = Self::new();

for register in registers {
memory_table.add_row_from_registers(register.clk, register.mp, register.mv, false);
}
let memory_rows = registers.iter().map(|reg| (reg, false).into()).collect();
memory_table.add_rows(memory_rows);

memory_table.sort();
memory_table.complete_with_dummy_rows();
Expand Down Expand Up @@ -212,39 +189,6 @@ mod tests {
assert!(memory_table.table.is_empty(), "Memory table should be empty upon initialization.");
}

#[test]
fn test_add_row_from_registers() {
let mut memory_table = MemoryTable::new();
// Create a row to add to the table
let row = MemoryTableRow::new(BaseField::zero(), BaseField::from(43), BaseField::from(91));
// Add the row to the table
memory_table.add_row_from_registers(
BaseField::zero(),
BaseField::from(43),
BaseField::from(91),
false,
);
// Check that the table contains the added row
assert_eq!(memory_table.table, vec![row], "Added row should match the expected row.");
}

#[test]
fn test_add_dummy_row_from_registers() {
let mut memory_table = MemoryTable::new();
// Create a row to add to the table
let row =
MemoryTableRow::new_dummy(BaseField::zero(), BaseField::from(43), BaseField::from(91));
// Add the row to the table
memory_table.add_row_from_registers(
BaseField::zero(),
BaseField::from(43),
BaseField::from(91),
true,
);
// Check that the table contains the added row
assert_eq!(memory_table.table, vec![row], "Added row should match the expected row.");
}

#[test]
fn test_add_row() {
let mut memory_table = MemoryTable::new();
Expand Down Expand Up @@ -283,40 +227,6 @@ mod tests {
assert_eq!(memory_table, MemoryTable { table: rows });
}

#[test]
fn test_get_existing_row() {
let mut memory_table = MemoryTable::new();
// Create a row to add to the table
let row = MemoryTableRow {
clk: BaseField::zero(),
mp: BaseField::from(43),
mv: BaseField::from(91),
d: BaseField::zero(),
};
// Add the row to the table
memory_table.add_row(row.clone());
// Retrieve the row from the table
let retrieved = memory_table.get_row(&row);
// Check that the retrieved row matches the added row
assert_eq!(retrieved.unwrap(), &row, "Retrieved row should match the added row.");
}

#[test]
fn test_get_non_existing_row() {
let memory_table = MemoryTable::new();
// Create a row to search for in the table
let row = MemoryTableRow {
clk: BaseField::zero(),
mp: BaseField::from(43),
mv: BaseField::from(91),
d: BaseField::zero(),
};
// Try to retrieve the non-existing row from the table
let retrieved = memory_table.get_row(&row);
// Check that the retrieved row is None
assert!(retrieved.is_none(), "Should return None for a non-existing row.");
}

#[test]
fn test_sort() {
let mut memory_table = MemoryTable::new();
Expand Down

0 comments on commit c87b629

Please sign in to comment.