Skip to content

Commit

Permalink
relation tracker eval
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Nov 28, 2024
1 parent 4d64300 commit b4fefbf
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 0 deletions.
1 change: 1 addition & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod info;
pub mod logup;
mod point;
pub mod preprocessed_columns;
pub mod relation_tracker;
mod simd_domain;

use std::array;
Expand Down
189 changes: 189 additions & 0 deletions crates/prover/src/constraint_framework/relation_tracker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
use std::fmt::Debug;

use itertools::Itertools;
use num_traits::Zero;

use super::logup::LogupSums;
use super::{
EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry, TraceLocationAllocator,
INTERACTION_TRACE_IDX,
};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::very_packed_m31::LOG_N_VERY_PACKED_ELEMS;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::{TreeSubspan, TreeVec};
use crate::core::poly::circle::CircleEvaluation;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{
bit_reverse_index, coset_index_to_circle_domain_index, offset_bit_reversed_circle_domain_index,
};

#[derive(Debug)]
pub struct RelationTrackerEntry {
pub relation: String,
pub mult: M31,
pub values: Vec<M31>,
}

pub struct RelationTrackerComponent<E: FrameworkEval> {
eval: E,
trace_locations: TreeVec<TreeSubspan>,
n_rows: usize,
}
impl<E: FrameworkEval> RelationTrackerComponent<E> {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E, n_rows: usize) -> Self {
let info = eval.evaluate(InfoEvaluator::new(
eval.log_size(),
vec![],
LogupSums::default(),
));
let mut mask_offsets = info.mask_offsets;
mask_offsets.drain(INTERACTION_TRACE_IDX..);
let trace_locations = location_allocator.next_for_structure(&mask_offsets);
Self {
eval,
trace_locations,
n_rows,
}
}

pub fn entries(
self,
trace: &TreeVec<Vec<&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
) -> Vec<RelationTrackerEntry> {
let log_size = self.eval.log_size();

// Deref the sub-tree. Only copies the references.
let sub_tree = trace
.sub_tree(&self.trace_locations)
.map(|vec| vec.into_iter().copied().collect_vec());
let mut entries = vec![];

for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let evaluator =
RelationTrackerEvaluator::new(&sub_tree, vec_row, log_size, self.n_rows);
entries.extend(self.eval.evaluate(evaluator).entries());
}
entries
}
}

/// Aggregates relation entries.
// TODO(Ohad): write a summarize function, test.
pub struct RelationTrackerEvaluator<'a> {
entries: Vec<RelationTrackerEntry>,
pub trace_eval:
&'a TreeVec<Vec<&'a CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
pub column_index_per_interaction: Vec<usize>,
pub vec_row: usize,
pub domain_log_size: u32,
pub n_rows: usize,
}
impl<'a> RelationTrackerEvaluator<'a> {
pub fn new(
trace_eval: &'a TreeVec<Vec<&CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>,
vec_row: usize,
domain_log_size: u32,
n_rows: usize,
) -> Self {
Self {
entries: vec![],
trace_eval,
column_index_per_interaction: vec![0; trace_eval.len()],
vec_row,
domain_log_size,
n_rows,
}
}

pub fn entries(self) -> Vec<RelationTrackerEntry> {
self.entries
}
}
impl<'a> EvalAtRow for RelationTrackerEvaluator<'a> {
type F = PackedBaseField;
type EF = PackedSecureField;

// TODO(Ohad): Add debug boundary checks.
fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
assert_ne!(interaction, INTERACTION_TRACE_IDX);
let col_index = self.column_index_per_interaction[interaction];
self.column_index_per_interaction[interaction] += 1;
offsets.map(|off| {
// If the offset is 0, we can just return the value directly from this row.
if off == 0 {
unsafe {
let col = &self
.trace_eval
.get_unchecked(interaction)
.get_unchecked(col_index)
.values;
return *col.data.get_unchecked(self.vec_row);
};
}
// Otherwise, we need to look up the value at the offset.
// Since the domain is bit-reversed circle domain ordered, we need to look up the value
// at the bit-reversed natural order index at an offset.
PackedBaseField::from_array(std::array::from_fn(|i| {
let row_index = offset_bit_reversed_circle_domain_index(
(self.vec_row << (LOG_N_LANES + LOG_N_VERY_PACKED_ELEMS)) + i,
self.domain_log_size,
self.domain_log_size,
off,
);
self.trace_eval[interaction][col_index].at(row_index)
}))
})
}
fn add_constraint<G>(&mut self, _constraint: G) {}

fn combine_ef(_values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
PackedSecureField::zero()
}

fn write_logup_frac(&mut self, _fraction: Fraction<Self::EF, Self::EF>) {}

fn finalize_logup(&mut self) {}

fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
) {
for entry in entries {
let relation = entry.relation.get_name().to_owned();
let values = entry.values.iter().map(|v| v.to_array()).collect_vec();
let mult = entry.multiplicity.to_array();

// Unpack SIMD.
for j in 0..N_LANES {
// Skip padded values.
let cannonical_index = bit_reverse_index(
coset_index_to_circle_domain_index(
(self.vec_row << LOG_N_LANES) + j,
self.domain_log_size,
),
self.domain_log_size,
);
if cannonical_index >= self.n_rows {
continue;
}
let values = values.iter().map(|v| v[j]).collect_vec();
let mult = mult[j].to_m31_array()[0];
self.entries.push(RelationTrackerEntry {
relation: relation.clone(),
mult,
values,
});
}
}
}
}
7 changes: 7 additions & 0 deletions crates/prover/src/core/pcs/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ impl<'a, T> From<&'a TreeVec<T>> for TreeVec<&'a T> {
}
}

/// Converts `&TreeVec<&Vec<T>>` to `TreeVec<Vec<&T>>`.
impl<'a, T> From<&'a TreeVec<&'a Vec<T>>> for TreeVec<Vec<&'a T>> {
fn from(val: &'a TreeVec<&'a Vec<T>>) -> Self {
TreeVec(val.iter().map(|vec| vec.iter().collect()).collect())
}
}

impl<T> Deref for TreeVec<T> {
type Target = Vec<T>;
fn deref(&self) -> &Self::Target {
Expand Down

0 comments on commit b4fefbf

Please sign in to comment.