Skip to content

Commit

Permalink
Support shared preproccessed columns. (#869)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware authored Nov 11, 2024
1 parent 8385e54 commit e0dd4fb
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 26 deletions.
103 changes: 97 additions & 6 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::iter::zip;
use std::ops::Deref;
Expand All @@ -9,7 +10,10 @@ use rayon::prelude::*;
use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator};
use super::preprocessed_columns::PreprocessedColumn;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
Expand All @@ -29,12 +33,23 @@ use crate::core::{utils, ColumnVec};

const CHUNK_SIZE: usize = 1;

#[derive(Debug, Default)]
enum PreprocessedColumnsAllocationMode {
#[default]
Dynamic,
Static,
}

// TODO(andrew): Docs.
// TODO(andrew): Consider better location for this.
#[derive(Debug, Default)]
pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
/// Mapping of preprocessed columns to their index.
preprocessed_columns: HashMap<PreprocessedColumn, usize>,
/// Controls whether the preprocessed columns are dynamic or static (default=Dynamic).
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
}

impl TraceLocationAllocator {
Expand Down Expand Up @@ -62,6 +77,23 @@ impl TraceLocationAllocator {
.collect(),
)
}

/// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup.
pub fn new_with_preproccessed_columnds(preprocessed_columns: &[PreprocessedColumn]) -> Self {
Self {
next_tree_offsets: Default::default(),
preprocessed_columns: preprocessed_columns
.iter()
.enumerate()
.map(|(i, &col)| (col, i))
.collect(),
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
}
}

pub fn n_preprocessed_columns(&self) -> usize {
self.preprocessed_columns.len()
}
}

/// A component defined solely in means of the constraints framework.
Expand All @@ -80,16 +112,42 @@ pub struct FrameworkComponent<C: FrameworkEval> {
eval: C,
trace_locations: TreeVec<TreeSubspan>,
info: InfoEvaluator,
preprocessed_column_indices: Vec<usize>,
}

impl<E: FrameworkEval> FrameworkComponent<E> {
pub fn new(location_allocator: &mut TraceLocationAllocator, eval: E) -> Self {
let info = eval.evaluate(InfoEvaluator::default());
let trace_locations = location_allocator.next_for_structure(&info.mask_offsets);

let preprocessed_column_indices = info
.preprocessed_columns
.iter()
.map(|col| {
let next_column = location_allocator.preprocessed_columns.len();
*location_allocator
.preprocessed_columns
.entry(*col)
.or_insert_with(|| {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
PreprocessedColumnsAllocationMode::Static
) {
panic!(
"Preprocessed column {:?} is missing from static alloction",
col
);
}

next_column
})
})
.collect();
Self {
eval,
trace_locations,
info,
preprocessed_column_indices,
}
}

Expand All @@ -108,10 +166,19 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
}

fn trace_log_degree_bounds(&self) -> TreeVec<ColumnVec<u32>> {
self.info
let mut log_degree_bounds = self
.info
.mask_offsets
.as_ref()
.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()])
.map(|tree_offsets| vec![self.eval.log_size(); tree_offsets.len()]);

log_degree_bounds[0] = self
.preprocessed_column_indices
.iter()
.map(|_| self.eval.log_size())
.collect();

log_degree_bounds
}

fn mask_points(
Expand All @@ -127,14 +194,27 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
})
}

fn preproccessed_column_indices(&self) -> ColumnVec<usize> {
self.preprocessed_column_indices.clone()
}

fn evaluate_constraint_quotients_at_point(
&self,
point: CirclePoint<SecureField>,
mask: &TreeVec<ColumnVec<Vec<SecureField>>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
) {
let preprocessed_mask = self
.preprocessed_column_indices
.iter()
.map(|idx| &mask[PREPROCESSED_TRACE_IDX][*idx])
.collect_vec();

let mut mask_points = mask.sub_tree(&self.trace_locations);
mask_points[PREPROCESSED_TRACE_IDX] = preprocessed_mask;

self.eval.evaluate(PointEvaluator::new(
mask.sub_tree(&self.trace_locations),
mask_points,
evaluation_accumulator,
coset_vanishing(CanonicCoset::new(self.eval.log_size()).coset, point).inverse(),
));
Expand All @@ -154,8 +234,19 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_domain = CanonicCoset::new(self.eval.log_size());

let component_polys = trace.polys.sub_tree(&self.trace_locations);
let component_evals = trace.evals.sub_tree(&self.trace_locations);
let mut component_polys = trace.polys.sub_tree(&self.trace_locations);
component_polys[PREPROCESSED_TRACE_IDX] = self
.preprocessed_column_indices
.iter()
.map(|idx| &trace.polys[PREPROCESSED_TRACE_IDX][*idx])
.collect();

let mut component_evals = trace.evals.sub_tree(&self.trace_locations);
component_evals[PREPROCESSED_TRACE_IDX] = self
.preprocessed_column_indices
.iter()
.map(|idx| &trace.evals[PREPROCESSED_TRACE_IDX][*idx])
.collect();

// Extend trace if necessary.
// TODO: Don't extend when eval_size < committed_size. Instead, pick a good
Expand Down
15 changes: 15 additions & 0 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::ops::Mul;

use num_traits::One;

use super::preprocessed_columns::PreprocessedColumn;
use super::EvalAtRow;
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::pcs::TreeVec;
Expand All @@ -13,6 +15,7 @@ use crate::core::pcs::TreeVec;
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
}
impl InfoEvaluator {
pub fn new() -> Self {
Expand All @@ -22,11 +25,17 @@ impl InfoEvaluator {
impl EvalAtRow for InfoEvaluator {
type F = BaseField;
type EF = SecureField;

fn next_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::F; N] {
assert!(
interaction != PREPROCESSED_TRACE_IDX,
"Preprocessed should be accesses with `get_preprocessed_column`",
);

// Check if requested a mask from a new interaction
if self.mask_offsets.len() <= interaction {
// Extend `mask_offsets` so that `interaction` is the last index.
Expand All @@ -35,6 +44,12 @@ impl EvalAtRow for InfoEvaluator {
self.mask_offsets[interaction].push(offsets.into_iter().collect());
[BaseField::one(); N]
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
self.preprocessed_columns.push(column);
BaseField::one()
}

fn add_constraint<G>(&mut self, _constraint: G)
where
Self::EF: Mul<G, Output = Self::EF>,
Expand Down
57 changes: 50 additions & 7 deletions crates/prover/src/core/air/components.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::iter::zip;

use itertools::Itertools;

use super::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::{Component, ComponentProver, Trace};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::backend::Backend;
use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
Expand All @@ -27,11 +30,22 @@ impl<'a> Components<'a> {
&self,
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>> {
TreeVec::concat_cols(
let mut mask_points = TreeVec::concat_cols(
self.components
.iter()
.map(|component| component.mask_points(point)),
)
);

let preprocessed_mask_points = &mut mask_points[PREPROCESSED_TRACE_IDX];
*preprocessed_mask_points = vec![vec![]; self.n_preprocessed_columns];

for component in &self.components {
for idx in component.preproccessed_column_indices() {
preprocessed_mask_points[idx] = vec![point];
}
}

mask_points
}

pub fn eval_composition_polynomial_at_point(
Expand All @@ -52,11 +66,40 @@ impl<'a> Components<'a> {
}

pub fn column_log_sizes(&self) -> TreeVec<ColumnVec<u32>> {
TreeVec::concat_cols(
self.components
.iter()
.map(|component| component.trace_log_degree_bounds()),
)
let mut preprocessed_columns_trace_log_sizes = vec![0; self.n_preprocessed_columns];
let mut visited_columns = vec![false; self.n_preprocessed_columns];

let mut column_log_sizes = TreeVec::concat_cols(self.components.iter().map(|component| {
let component_trace_log_sizes = component.trace_log_degree_bounds();

for (column_index, &log_size) in zip(
component.preproccessed_column_indices(),
&component_trace_log_sizes[PREPROCESSED_TRACE_IDX],
) {
let column_log_size = &mut preprocessed_columns_trace_log_sizes[column_index];
if visited_columns[column_index] {
assert!(
*column_log_size == log_size,
"Preprocessed column size mismatch for column {}",
column_index
);
} else {
*column_log_size = log_size;
visited_columns[column_index] = true;
}
}

component_trace_log_sizes
}));

assert!(
visited_columns.iter().all(|&updated| updated),
"Column size not set for all reprocessed columns"
);

column_log_sizes[PREPROCESSED_TRACE_IDX] = preprocessed_columns_trace_log_sizes;

column_log_sizes
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/prover/src/core/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ pub trait Component {
point: CirclePoint<SecureField>,
) -> TreeVec<ColumnVec<Vec<CirclePoint<SecureField>>>>;

fn preproccessed_column_indices(&self) -> ColumnVec<usize>;

/// Evaluates the constraint quotients combination of the component at a point.
fn evaluate_constraint_quotients_at_point(
&self,
Expand Down
Loading

0 comments on commit e0dd4fb

Please sign in to comment.