Skip to content

Commit

Permalink
Feat/skip fixed commit of range table (#797)
Browse files Browse the repository at this point in the history
Work for #789

---------

Co-authored-by: sm.wu <[email protected]>
  • Loading branch information
10to4 and hero78119 authored Jan 13, 2025
1 parent 083b2d4 commit 7919da2
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 69 deletions.
4 changes: 2 additions & 2 deletions ceno_zkvm/src/chip_handler/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
pub fn lk_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rom_type: ROMType,
record: Vec<Expression<E>>,
multiplicity: Expression<E>,
Expand All @@ -105,7 +105,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> {
N: FnOnce() -> NR,
{
self.cs
.lk_table_record(name_fn, table_len, rom_type, record, multiplicity)
.lk_table_record(name_fn, table_spec, rom_type, record, multiplicity)
}

pub fn r_table_record<NR, N>(
Expand Down
13 changes: 3 additions & 10 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use ceno_emul::Addr;
use itertools::{Itertools, chain};
use std::{collections::HashMap, iter::once, marker::PhantomData};

Expand Down Expand Up @@ -56,13 +55,7 @@ impl NameSpace {
pub struct LogupTableExpression<E: ExtensionField> {
pub multiplicity: Expression<E>,
pub values: Expression<E>,
pub table_len: usize,
}

#[derive(Clone, Debug)]
pub struct DynamicAddr {
pub addr_witin_id: usize,
pub offset: Addr,
pub table_spec: SetTableSpec,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -297,7 +290,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
pub fn lk_table_record<NR, N>(
&mut self,
name_fn: N,
table_len: usize,
table_spec: SetTableSpec,
rom_type: ROMType,
record: Vec<Expression<E>>,
multiplicity: Expression<E>,
Expand All @@ -321,7 +314,7 @@ impl<E: ExtensionField> ConstraintSystem<E> {
self.lk_table_expressions.push(LogupTableExpression {
values: rlc_record,
multiplicity,
table_len,
table_spec,
});
let path = self.ns.compute_path(name_fn().into());
self.lk_expressions_namespace_map.push(path);
Expand Down
10 changes: 7 additions & 3 deletions ceno_zkvm/src/scheme/mock_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,7 @@ Hints:
);

let mut wit_mles = HashMap::new();
let mut structural_wit_mles = HashMap::new();
let mut fixed_mles = HashMap::new();
let mut num_instances = HashMap::new();

Expand All @@ -815,15 +816,17 @@ Hints:

if witness.num_instances() == 0 {
wit_mles.insert(circuit_name.clone(), vec![]);
structural_wit_mles.insert(circuit_name.clone(), vec![]);
fixed_mles.insert(circuit_name.clone(), vec![]);
num_instances.insert(circuit_name.clone(), num_rows);
continue;
}
let witness = witness
let mut witness = witness
.into_mles()
.into_iter()
.map(|w| w.into())
.collect_vec();
let structural_witness = witness.split_off(cs.num_witin as usize);
let fixed: Vec<_> = fixed_trace
.circuit_fixed_traces
.remove(circuit_name)
Expand Down Expand Up @@ -876,7 +879,7 @@ Hints:
let lk_table = wit_infer_by_expr(
&fixed,
&witness,
&[],
&structural_witness,
&pi_mles,
&challenges,
&expr.values,
Expand All @@ -887,7 +890,7 @@ Hints:
let multiplicity = wit_infer_by_expr(
&fixed,
&witness,
&[],
&structural_witness,
&pi_mles,
&challenges,
&expr.multiplicity,
Expand All @@ -905,6 +908,7 @@ Hints:
}
}
wit_mles.insert(circuit_name.clone(), witness);
structural_wit_mles.insert(circuit_name.clone(), structural_witness);
fixed_mles.insert(circuit_name.clone(), fixed);
num_instances.insert(circuit_name.clone(), num_rows);
}
Expand Down
9 changes: 3 additions & 6 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,12 +964,9 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMProver<E, PCS> {
);
exit_span!(tower_span);

// same point sumcheck is optional when all witin + fixed are in same num_vars
let is_skip_same_point_sumcheck = witnesses
.iter()
.chain(fixed.iter())
.map(|v| v.num_vars())
.all_equal();
// In table proof, we always skip same point sumcheck for now
// as tower sumcheck batch product argument/logup in same length
let is_skip_same_point_sumcheck = true;

let (input_open_point, same_r_sumcheck_proofs, rw_in_evals, lk_in_evals) =
if is_skip_same_point_sumcheck {
Expand Down
31 changes: 22 additions & 9 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
let tower_proofs = &proof.tower_proof;

let expected_rounds = cs
// only iterate r set, as read/write set round should match
.r_table_expressions
.iter()
.flat_map(|r| {
Expand All @@ -538,13 +539,24 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
.max()
.unwrap()
});
[num_vars, num_vars]
[num_vars, num_vars] // format: [read_round, write_round]
})
.chain(
cs.lk_table_expressions
.iter()
.map(|l| ceil_log2(l.table_len)),
)
.chain(cs.lk_table_expressions.iter().map(|l| {
// iterate through structural witins and collect max round.
let num_vars = l.table_spec.len.map(ceil_log2).unwrap_or_else(|| {
l.table_spec
.structural_witins
.iter()
.map(|StructuralWitIn { id, max_len, .. }| {
let hint_num_vars = proof.rw_hints_num_vars[*id as usize];
assert!((1 << hint_num_vars) <= *max_len);
hint_num_vars
})
.max()
.unwrap()
});
num_vars
}))
.collect_vec();

for var in proof.rw_hints_num_vars.iter() {
Expand Down Expand Up @@ -693,9 +705,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
let structural_witnesses = cs
.r_table_expressions
.iter()
.flat_map(|set_table_expression| {
set_table_expression
.table_spec
.map(|r| &r.table_spec)
.chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec))
.flat_map(|table_spec| {
table_spec
.structural_witins
.iter()
.map(
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl CircuitStats {
})
} else {
let table_len = if !system.lk_table_expressions.is_empty() {
system.lk_table_expressions[0].table_len
system.lk_table_expressions[0].table_spec.len.unwrap_or(0)
} else {
0
};
Expand Down
13 changes: 11 additions & 2 deletions ceno_zkvm/src/tables/ops/ops_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato
use std::collections::HashMap;

use crate::{
circuit_builder::CircuitBuilder,
circuit_builder::{CircuitBuilder, SetTableSpec},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
instructions::InstancePaddingStrategy,
Expand Down Expand Up @@ -38,7 +38,16 @@ impl OpTableConfig {

let record_exprs = abc.into_iter().map(|f| Expression::Fixed(f)).collect_vec();

cb.lk_table_record(|| "record", table_len, rom_type, record_exprs, mlt.expr())?;
cb.lk_table_record(
|| "record",
SetTableSpec {
len: Some(table_len),
structural_witins: vec![],
},
rom_type,
record_exprs,
mlt.expr(),
)?;

Ok(Self { abc, mlt })
}
Expand Down
7 changes: 5 additions & 2 deletions ceno_zkvm/src/tables/program.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, marker::PhantomData};

use crate::{
circuit_builder::CircuitBuilder,
circuit_builder::{CircuitBuilder, SetTableSpec},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
instructions::InstancePaddingStrategy,
Expand Down Expand Up @@ -115,7 +115,10 @@ impl<E: ExtensionField> TableCircuit<E> for ProgramTableCircuit<E> {

cb.lk_table_record(
|| "prog table",
cb.params.program_size.next_power_of_two(),
SetTableSpec {
len: Some(cb.params.program_size.next_power_of_two()),
structural_witins: vec![],
},
ROMType::Instruction,
record_exprs,
mlt.expr(),
Expand Down
18 changes: 12 additions & 6 deletions ceno_zkvm/src/tables/range/range_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use super::range_impl::RangeTableConfig;
use std::{collections::HashMap, marker::PhantomData};

use crate::{
circuit_builder::CircuitBuilder, error::ZKVMError, structs::ROMType, tables::TableCircuit,
witness::RowMajorMatrix,
circuit_builder::CircuitBuilder, error::ZKVMError, instructions::InstancePaddingStrategy,
structs::ROMType, tables::TableCircuit, witness::RowMajorMatrix,
};
use ff_ext::ExtensionField;

Expand Down Expand Up @@ -40,11 +40,11 @@ impl<E: ExtensionField, RANGE: RangeTable> TableCircuit<E> for RangeTableCircuit
}

fn generate_fixed_traces(
config: &RangeTableConfig,
num_fixed: usize,
_config: &RangeTableConfig,
_num_fixed: usize,
_input: &(),
) -> RowMajorMatrix<E::BaseField> {
config.generate_fixed_traces(num_fixed, RANGE::content())
RowMajorMatrix::<E::BaseField>::new(0, 0, InstancePaddingStrategy::Default)
}

fn assign_instances(
Expand All @@ -55,6 +55,12 @@ impl<E: ExtensionField, RANGE: RangeTable> TableCircuit<E> for RangeTableCircuit
_input: &(),
) -> Result<RowMajorMatrix<E::BaseField>, ZKVMError> {
let multiplicity = &multiplicity[RANGE::ROM_TYPE as usize];
config.assign_instances(num_witin, num_structural_witin, multiplicity, RANGE::len())
config.assign_instances(
num_witin,
num_structural_witin,
multiplicity,
RANGE::content(),
RANGE::len(),
)
}
}
54 changes: 26 additions & 28 deletions ceno_zkvm/src/tables/range/range_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterato
use std::collections::HashMap;

use crate::{
circuit_builder::CircuitBuilder,
circuit_builder::{CircuitBuilder, SetTableSpec},
error::ZKVMError,
expression::{Expression, Fixed, ToExpr, WitIn},
expression::{StructuralWitIn, ToExpr, WitIn},
instructions::InstancePaddingStrategy,
scheme::constants::MIN_PAR_SIZE,
set_fixed_val, set_val,
set_val,
structs::ROMType,
witness::RowMajorMatrix,
};

#[derive(Clone, Debug)]
pub struct RangeTableConfig {
fixed: Fixed,
range: StructuralWitIn,
mlt: WitIn,
}

Expand All @@ -28,40 +28,31 @@ impl RangeTableConfig {
rom_type: ROMType,
table_len: usize,
) -> Result<Self, ZKVMError> {
let fixed = cb.create_fixed(|| "fixed")?;
let range = cb.create_structural_witin(|| "structural range witin", table_len, 0, 1);
let mlt = cb.create_witin(|| "mlt");

let record_exprs = vec![Expression::Fixed(fixed)];
let record_exprs = vec![range.expr()];

cb.lk_table_record(|| "record", table_len, rom_type, record_exprs, mlt.expr())?;
cb.lk_table_record(
|| "record",
SetTableSpec {
len: Some(table_len),
structural_witins: vec![range],
},
rom_type,
record_exprs,
mlt.expr(),
)?;

Ok(Self { fixed, mlt })
}

pub fn generate_fixed_traces<F: SmallField>(
&self,
num_fixed: usize,
content: Vec<u64>,
) -> RowMajorMatrix<F> {
let mut fixed =
RowMajorMatrix::<F>::new(content.len(), num_fixed, InstancePaddingStrategy::Default);

fixed
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(content.into_par_iter())
.for_each(|(row, i)| {
set_fixed_val!(row, self.fixed, F::from(i));
});

fixed
Ok(Self { range, mlt })
}

pub fn assign_instances<F: SmallField>(
&self,
num_witin: usize,
num_structural_witin: usize,
multiplicity: &HashMap<u64, usize>,
content: Vec<u64>,
length: usize,
) -> Result<RowMajorMatrix<F>, ZKVMError> {
let mut witness = RowMajorMatrix::<F>::new(
Expand All @@ -75,12 +66,19 @@ impl RangeTableConfig {
mlts[*idx as usize] = *mlt;
}

let offset_range = StructuralWitIn {
id: self.range.id + (num_witin as u16),
..self.range
};

witness
.par_iter_mut()
.with_min_len(MIN_PAR_SIZE)
.zip(mlts.into_par_iter())
.for_each(|(row, mlt)| {
.zip(content.into_par_iter())
.for_each(|((row, mlt), i)| {
set_val!(row, self.mlt, F::from(mlt as u64));
set_val!(row, offset_range, F::from(i));
});

Ok(witness)
Expand Down

0 comments on commit 7919da2

Please sign in to comment.