Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iquerejeta/pruning #17

Merged
merged 20 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub trait Chip<F: Field>: Sized {
}

/// Index of a region in a layouter
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RegionIndex(usize);

impl From<usize> for RegionIndex {
Expand Down Expand Up @@ -86,7 +86,7 @@ impl std::ops::Deref for RegionStart {
}

/// A pointer to a cell within a circuit.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Cell {
/// Identifies the region in which this cell resides.
pub region_index: RegionIndex,
Expand All @@ -104,6 +104,21 @@ pub struct AssignedCell<V, F: Field> {
_marker: PhantomData<F>,
}

impl<V, F: Field> PartialEq for AssignedCell<V, F> {
fn eq(&self, other: &Self) -> bool {
self.cell == other.cell
}
}

impl<V, F: Field> Eq for AssignedCell<V, F> {}

use std::hash::{Hash, Hasher};
impl<V, F: Field> Hash for AssignedCell<V, F> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.cell.hash(state)
}
}

impl<V, F: Field> AssignedCell<V, F> {
/// Returns the value of the [`AssignedCell`].
pub fn value(&self) -> Value<&V> {
Expand Down
14 changes: 0 additions & 14 deletions src/circuit/floor_planner/v1/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,8 @@ pub fn slot_in_biggest_advice_first(
advice_cols * shape.row_count()
};

// This used to incorrectly use `sort_unstable_by_key` with non-unique keys, which gave
// output that differed between 32-bit and 64-bit platforms, and potentially between Rust
// versions.
// We now use `sort_by_cached_key` with non-unique keys, and rely on `region_shapes`
// being sorted by region index (which we also rely on below to return `RegionStart`s
// in the correct order).
#[cfg(not(feature = "floor-planner-v1-legacy-pdqsort"))]
sorted_regions.sort_by_cached_key(sort_key);

// To preserve compatibility, when the "floor-planner-v1-legacy-pdqsort" feature is enabled,
// we use a copy of the pdqsort implementation from the Rust 1.56.1 standard library, fixed
// to its behaviour on 64-bit platforms.
// https://github.com/rust-lang/rust/blob/1.56.1/library/core/src/slice/mod.rs#L2365-L2402
#[cfg(feature = "floor-planner-v1-legacy-pdqsort")]
halo2_legacy_pdqsort::sort::quicksort(&mut sorted_regions, |a, b| sort_key(a).lt(&sort_key(b)));

sorted_regions.reverse();

// Lay out the sorted regions.
Expand Down
76 changes: 75 additions & 1 deletion src/dev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ pub use tfp::TracingFloorPlanner;
#[cfg(feature = "dev-graph")]
mod graph;

use crate::plonk::VirtualCell;
use crate::rational::Rational;
#[cfg(feature = "dev-graph")]
#[cfg_attr(docsrs, doc(cfg(feature = "dev-graph")))]
pub use graph::{circuit_dot_graph, layout::CircuitLayout};

use crate::poly::Rotation;

#[derive(Debug)]
struct Region {
/// The name of the region. Not required to be unique.
Expand Down Expand Up @@ -820,7 +823,15 @@ impl<F: FromUniformBytes<64> + Ord> MockProver<F> {
}
_ => {
// Check that it was assigned!
if r.cells.contains_key(&(cell.column, cell_row)) {
if r.cells.contains_key(&(cell.column, cell_row))
|| gate.polynomials().par_iter().all(|expr| {
self.cell_is_irrelevant(
iquerejeta marked this conversation as resolved.
Show resolved Hide resolved
cell,
expr,
gate_row as usize,
)
})
{
None
} else {
Some(VerifyFailure::CellNotAssigned {
Expand Down Expand Up @@ -1124,6 +1135,69 @@ impl<F: FromUniformBytes<64> + Ord> MockProver<F> {
}
}

// Checks if the given expression is guaranteed to be constantly zero at the given offset.
fn expr_is_constantly_zero(&self, expr: &Expression<F>, offset: usize) -> bool {
match expr {
Expression::Constant(constant) => constant.is_zero().into(),
Expression::Selector(selector) => !self.selectors[selector.0][offset],
Expression::Fixed(query) => match self.fixed[query.column_index][offset] {
CellValue::Assigned(value) => value.is_zero().into(),
_ => false,
},
Expression::Scaled(e, factor) => {
factor.is_zero().into() || self.expr_is_constantly_zero(e, offset)
}
Expression::Sum(e1, e2) => {
self.expr_is_constantly_zero(e1, offset) && self.expr_is_constantly_zero(e2, offset)
}
Expression::Product(e1, e2) => {
self.expr_is_constantly_zero(e1, offset) || self.expr_is_constantly_zero(e2, offset)
}
_ => false,
}
}

// Verify that the value of the given cell within the given expression is
// irrelevant to the evaluation of the expression. This may be because
// the cell is always multiplied by an expression that evaluates to 0, or
// because the cell is not being queried in the expression at all.
fn cell_is_irrelevant(&self, cell: &VirtualCell, expr: &Expression<F>, offset: usize) -> bool {
// Check if a given query (defined by its columnd and rotation, since we
// want this function to support different query types) is equal to `cell`.
let eq_query = |query_column: usize, query_rotation: Rotation, col_type: Any| {
cell.column.index() == query_column
&& cell.column.column_type() == &col_type
&& query_rotation == cell.rotation
};
match expr {
Expression::Constant(_) | Expression::Selector(_) => true,
Expression::Fixed(query) => !eq_query(query.column_index, query.rotation(), Any::Fixed),
Expression::Advice(query) => !eq_query(
query.column_index,
query.rotation(),
Any::Advice(Advice::new(query.phase)),
),
Expression::Instance(query) => {
!eq_query(query.column_index, query.rotation(), Any::Instance)
}
Expression::Challenge(_) => true,
Expression::Negated(e) => self.cell_is_irrelevant(cell, e, offset),
Expression::Sum(e1, e2) => {
self.cell_is_irrelevant(cell, e1, offset)
&& self.cell_is_irrelevant(cell, e2, offset)
}
Expression::Product(e1, e2) => {
(self.expr_is_constantly_zero(e1, offset)
|| self.expr_is_constantly_zero(e2, offset))
|| (self.cell_is_irrelevant(cell, e1, offset)
&& self.cell_is_irrelevant(cell, e2, offset))
}
Expression::Scaled(e, factor) => {
factor.is_zero().into() || self.cell_is_irrelevant(cell, e, offset)
}
}
}

/// Panics if the circuit being checked by this `MockProver` is not satisfied.
///
/// Any verification failures will be pretty-printed to stderr before the function
Expand Down
33 changes: 21 additions & 12 deletions src/dev/cost_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use std::collections::HashSet;
use std::panic::AssertUnwindSafe;
use std::{iter, num::ParseIntError, panic, str::FromStr};

use crate::plonk::Any::Fixed;
use crate::plonk::Circuit;
use ff::{Field, FromUniformBytes};
use serde::Deserialize;
use serde_derive::Serialize;
use crate::plonk::Any::Fixed;

use super::MockProver;

Expand Down Expand Up @@ -88,7 +88,8 @@ impl FromStr for Poly {
pub struct Lookup;

impl Lookup {
fn queries(&self) -> impl Iterator<Item = Poly> {
/// Returns the queries of the lookup argument
pub fn queries(&self) -> impl Iterator<Item = Poly> {
// - product commitments at x and \omega x
// - input commitments at x and x_inv
// - table commitments at x
Expand All @@ -110,7 +111,8 @@ pub struct Permutation {
}

impl Permutation {
fn queries(&self) -> impl Iterator<Item = Poly> {
/// Returns the queries of the Permutation argument
pub fn queries(&self) -> impl Iterator<Item = Poly> {
// - product commitments at x and x_inv
// - polynomial commitments at x
let product = "0,-1".parse().unwrap();
Expand All @@ -120,13 +122,22 @@ impl Permutation {
.chain(Some(product))
.chain(iter::repeat(poly).take(self.columns))
}

/// Returns the number of columns of the Permutation argument
pub fn nr_columns(&self) -> usize {
self.columns
}
}

/// High-level specifications of an abstract circuit.
#[derive(Debug, Deserialize, Serialize)]
pub struct ModelCircuit {
/// Power-of-2 bound on the number of rows in the circuit.
pub k: usize,
/// Number of rows in the circuit (not including table rows).
pub rows: usize,
/// Number of table rows in the circuit.
pub table_rows: usize,
/// Maximum degree of the circuit.
pub max_deg: usize,
/// Number of advice columns.
Expand Down Expand Up @@ -224,6 +235,8 @@ impl CostOptions {

ModelCircuit {
k: self.min_k,
rows: self.rows_count,
table_rows: self.table_rows_count,
max_deg: self.max_degree,
advice_columns: self.advice.len(),
lookups: self.lookup.len(),
Expand Down Expand Up @@ -260,7 +273,7 @@ fn run_mock_prover_with_fallback<F: Ord + Field + FromUniformBytes<64>, C: Circu
panic::catch_unwind(AssertUnwindSafe(|| {
MockProver::run(k, circuit, instances.clone()).unwrap()
}))
.ok()
.ok()
})
.expect("A circuit which can be implemented with at most 2^24 rows.")
}
Expand Down Expand Up @@ -338,11 +351,7 @@ pub fn from_circuit_to_cost_model_options<F: Ord + Field + FromUniformBytes<64>,
// columns (see that [`plonk::circuit::TableColumn` is a wrapper
// around `Column<Fixed>`]). All of a table region's rows are
// counted towards `table_rows_count.`
if region
.columns
.iter()
.all(|c| *c.column_type() == Fixed)
{
if region.columns.iter().all(|c| *c.column_type() == Fixed) {
table_rows_count += (end + 1) - start;
} else {
rows_count += (end + 1) - start;
Expand All @@ -358,9 +367,9 @@ pub fn from_circuit_to_cost_model_options<F: Ord + Field + FromUniformBytes<64>,
table_rows_count + cs.blinding_factors(),
instance_len,
]
.into_iter()
.max()
.unwrap();
.into_iter()
.max()
.unwrap();
if min_k == instance_len {
println!("WARNING: The dominant factor in your circuit's size is the number of public inputs, which causes the verifier to perform linear work.");
}
Expand Down
2 changes: 1 addition & 1 deletion src/dev/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub(super) fn format_value<F: Field>(v: F) -> String {
// Format value as hex.
let s = format!("{v:?}");
// Remove leading zeroes.
let s = s.strip_prefix("0x").unwrap();
let s = s.split_once("0x").unwrap().1.split(')').next().unwrap();
let s = s.trim_start_matches('0');
format!("0x{s}")
}
Expand Down
12 changes: 6 additions & 6 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ pub trait SerdeCurveAffine: PrimeCurveAffine + SerdeObject + Default {
/// Reads an element from the buffer and parses it according to the `format`:
/// - `Processed`: Reads a compressed curve element and decompress it
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
/// does not perform any checks
/// does not perform any checks
fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
match format {
SerdeFormat::Processed => <Self as CurveRead>::read(reader),
Expand Down Expand Up @@ -83,9 +83,9 @@ impl<C: PrimeCurveAffine + SerdeObject + Default> SerdeCurveAffine for C {}
pub trait SerdePrimeField: PrimeField + SerdeObject {
/// Reads a field element as bytes from the buffer according to the `format`:
/// - `Processed`: Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// - `RawBytes`: Reads a field element from raw bytes in its internal Montgomery representations,
/// and checks that the element is less than the modulus.
/// and checks that the element is less than the modulus.
/// - `RawBytesUnchecked`: Reads a field element in Montgomery form and performs no checks.
fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
match format {
Expand All @@ -103,9 +103,9 @@ pub trait SerdePrimeField: PrimeField + SerdeObject {

/// Writes a field element as bytes to the buffer according to the `format`:
/// - `Processed`: Writes a field element in standard form, with endianness specified by the
/// `PrimeField` implementation.
/// `PrimeField` implementation.
/// - Otherwise: Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
/// WITHOUT performing the expensive Montgomery reduction.
fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
match format {
SerdeFormat::Processed => writer.write_all(self.to_repr().as_ref()),
Expand Down
24 changes: 12 additions & 12 deletions src/plonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ where
///
/// Reads a curve element from the buffer and parses it according to the `format`:
/// - `Processed`: Reads a compressed curve element and decompresses it.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
/// does not perform any checks
/// does not perform any checks
pub fn read<R: io::Read, ConcreteCircuit: Circuit<F>>(
reader: &mut R,
format: SerdeFormat,
Expand Down Expand Up @@ -334,12 +334,12 @@ where
///
/// Writes a curve element according to `format`:
/// - `Processed`: Writes a compressed curve element with coordinates in standard form.
/// Writes a field element in standard form, with endianness specified by the
/// Writes a field element in standard form, with endianness specified by the
/// `PrimeField` implementation.
/// - Otherwise: Writes an uncompressed curve element with coordinates in Montgomery form
/// Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
/// Does so by first writing the verifying key and then serializing the rest of the data (in the form of field polynomials)
/// Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
/// Does so by first writing the verifying key and then serializing the rest of the data (in the form of field polynomials)
pub fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
self.vk.write(writer, format)?;
self.l0.write(writer, format)?;
Expand All @@ -357,12 +357,12 @@ where
///
/// Reads a curve element from the buffer and parses it according to the `format`:
/// - `Processed`: Reads a compressed curve element and decompresses it.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
/// does not perform any checks
/// does not perform any checks
pub fn read<R: io::Read, ConcreteCircuit: Circuit<F>>(
reader: &mut R,
format: SerdeFormat,
Expand Down
5 changes: 3 additions & 2 deletions src/plonk/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::ops::Range;

use ff::{Field, FromUniformBytes, WithSmallOrderMulGroup};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

use super::{
circuit::{
Expand Down Expand Up @@ -312,12 +313,12 @@ where
);

let fixed_polys: Vec<_> = fixed
.iter()
.par_iter()
.map(|poly| vk.domain.lagrange_to_coeff(poly.clone()))
.collect();

let fixed_cosets = fixed_polys
.iter()
.par_iter()
.map(|poly| vk.domain.coeff_to_extended(poly.clone()))
.collect();

Expand Down
Loading
Loading