diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7961368e..76202a55 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,7 +41,7 @@ jobs: matrix: include: - os: macos-latest - - os: ubuntu-latest + # - os: ubuntu-latest steps: - name: Checkout code diff --git a/Cargo.lock b/Cargo.lock index 3150675f..a41aa1ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -565,6 +565,7 @@ version = "0.1.0" dependencies = [ "config", "field_hashers", + "gf2", "gkr_field_config", "mersenne31", "poly_commit", diff --git a/config/config_macros/Cargo.toml b/config/config_macros/Cargo.toml index dd96b48f..bb251f17 100644 --- a/config/config_macros/Cargo.toml +++ b/config/config_macros/Cargo.toml @@ -15,6 +15,7 @@ quote = "1.0" # For generating code proc-macro2 = "1.0" # For working with tokens [dev-dependencies] +gf2 = { path = "../../arith/gf2" } mersenne31 = { path = "../../arith/mersenne31/" } [lib] diff --git a/config/config_macros/src/lib.rs b/config/config_macros/src/lib.rs index 014183a1..9cc7430e 100644 --- a/config/config_macros/src/lib.rs +++ b/config/config_macros/src/lib.rs @@ -88,6 +88,7 @@ fn parse_fiat_shamir_hash_type( } fn parse_polynomial_commitment_type( + field_type: &str, field_config: &str, transcript_type: &str, polynomial_commitment_type: ExprPath, @@ -99,11 +100,19 @@ fn parse_polynomial_commitment_type( .expect("Empty path for polynomial commitment type"); let pcs_type_str = binding.ident.to_string(); - match pcs_type_str.as_str() { - "Raw" => ( + match (pcs_type_str.as_str(), field_type) { + ("Raw", _) => ( "Raw".to_owned(), format!("RawExpanderGKR::<{field_config}, {transcript_type}>").to_owned(), ), + ("Orion", "GF2") => ( + "Orion".to_owned(), + format!("OrionPCSForGKR::<{field_config}, GF2x128, {transcript_type}>").to_owned(), + ), + ("Orion", "M31") => ( + "Orion".to_owned(), + format!("OrionPCSForGKR::<{field_config}, M31x16, {transcript_type}>").to_owned(), + ), _ => panic!("Unknown polynomial commitment type in config macro expansion"), } } @@ -134,7 +143,8 @@ fn declare_gkr_config_impl(input: proc_macro::TokenStream) -> proc_macro::TokenS let (fiat_shamir_hash_type, transcript_type) = parse_fiat_shamir_hash_type(&field_type, &field_config, fiat_shamir_hash_type_expr); let (polynomial_commitment_enum, polynomial_commitment_type) = parse_polynomial_commitment_type( - field_config.as_str(), + &field_type, + &field_config, &transcript_type, polynomial_commitment_type, ); diff --git a/config/config_macros/tests/macro_expansion.rs b/config/config_macros/tests/macro_expansion.rs index 14e1ae6e..af8dba85 100644 --- a/config/config_macros/tests/macro_expansion.rs +++ b/config/config_macros/tests/macro_expansion.rs @@ -8,9 +8,10 @@ use gkr_field_config::FieldType; use config::GKRConfig; use config_macros::declare_gkr_config; use field_hashers::{MiMC5FiatShamirHasher, PoseidonFiatShamirHasher}; +use gf2::GF2x128; use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use mersenne31::M31x16; -use poly_commit::raw::RawExpanderGKR; +use poly_commit::{OrionPCSForGKR, RawExpanderGKR}; use transcript::{BytesHashTranscript, FieldHashTranscript, Keccak256hasher, SHA256hasher}; fn print_type_name() { @@ -26,11 +27,17 @@ fn main() { PolynomialCommitmentType::Raw ); declare_gkr_config!( - M31PoseidonConfig, + M31PoseidonRawConfig, FieldType::M31, FiatShamirHashType::Poseidon, PolynomialCommitmentType::Raw ); + declare_gkr_config!( + M31PoseidonOrionConfig, + FieldType::M31, + FiatShamirHashType::Poseidon, + PolynomialCommitmentType::Orion + ); declare_gkr_config!( BN254MIMCConfig, FieldType::BN254, @@ -43,8 +50,17 @@ fn main() { FiatShamirHashType::Keccak256, PolynomialCommitmentType::Raw ); + declare_gkr_config!( + GF2Keccak256OrionConfig, + FieldType::GF2, + FiatShamirHashType::Keccak256, + PolynomialCommitmentType::Orion + ); print_type_name::(); + print_type_name::(); + print_type_name::(); print_type_name::(); print_type_name::(); + print_type_name::(); } diff --git a/config/mpi_config/src/lib.rs b/config/mpi_config/src/lib.rs index 2fae9994..38eadd07 100644 --- a/config/mpi_config/src/lib.rs +++ b/config/mpi_config/src/lib.rs @@ -134,16 +134,16 @@ impl MPIConfig { /// Return an u8 vector sharing THE SAME MEMORY SLOT with the input. #[inline] - unsafe fn vec_to_u8_bytes(vec: &Vec) -> Vec { + unsafe fn vec_to_u8_bytes(vec: &Vec) -> Vec { Vec::::from_raw_parts( vec.as_ptr() as *mut u8, - vec.len() * F::SIZE, - vec.capacity() * F::SIZE, + vec.len() * size_of::(), + vec.capacity() * size_of::(), ) } #[allow(clippy::collapsible_else_if)] - pub fn gather_vec(&self, local_vec: &Vec, global_vec: &mut Vec) { + pub fn gather_vec(&self, local_vec: &Vec, global_vec: &mut Vec) { unsafe { if self.world_size == 1 { *global_vec = local_vec.clone() diff --git a/gkr/src/executor.rs b/gkr/src/executor.rs index af6910c0..859afe9c 100644 --- a/gkr/src/executor.rs +++ b/gkr/src/executor.rs @@ -12,9 +12,10 @@ use config::{ }; use config_macros::declare_gkr_config; use field_hashers::{MiMC5FiatShamirHasher, PoseidonFiatShamirHasher}; +use gf2::GF2x128; use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use mersenne31::M31x16; -use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR}; +use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR, OrionPCSForGKR}; use rand::SeedableRng; use rand_chacha::ChaCha12Rng; use transcript::{BytesHashTranscript, FieldHashTranscript, SHA256hasher}; @@ -289,5 +290,5 @@ declare_gkr_config!( pub GF2ExtConfigSha2, FieldType::GF2, FiatShamirHashType::SHA256, - PolynomialCommitmentType::Raw + PolynomialCommitmentType::Orion ); diff --git a/gkr/src/main.rs b/gkr/src/main.rs index 2c11ef5b..48dc770e 100644 --- a/gkr/src/main.rs +++ b/gkr/src/main.rs @@ -7,10 +7,11 @@ use circuit::Circuit; use clap::Parser; use config::{Config, GKRConfig, GKRScheme}; use config_macros::declare_gkr_config; +use gf2::GF2x128; use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use mpi_config::MPIConfig; -use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR}; +use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR, OrionPCSForGKR}; use rand::SeedableRng; use rand_chacha::ChaCha12Rng; use transcript::{BytesHashTranscript, SHA256hasher}; @@ -72,7 +73,7 @@ fn main() { GF2ExtConfigSha2, FieldType::GF2, FiatShamirHashType::SHA256, - PolynomialCommitmentType::Raw + PolynomialCommitmentType::Orion ); match args.field.as_str() { diff --git a/gkr/src/main_mpi.rs b/gkr/src/main_mpi.rs index 2a03a453..afc198e5 100644 --- a/gkr/src/main_mpi.rs +++ b/gkr/src/main_mpi.rs @@ -4,8 +4,9 @@ use config::{Config, GKRConfig, GKRScheme}; use config_macros::declare_gkr_config; use mpi_config::MPIConfig; +use gf2::GF2x128; use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; -use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR}; +use poly_commit::{expander_pcs_init_testing_only, raw::RawExpanderGKR, OrionPCSForGKR}; use rand::SeedableRng; use rand_chacha::ChaCha12Rng; use transcript::{BytesHashTranscript, SHA256hasher}; @@ -63,7 +64,7 @@ fn main() { GF2ExtConfigSha2, FieldType::GF2, FiatShamirHashType::SHA256, - PolynomialCommitmentType::Raw + PolynomialCommitmentType::Orion ); match args.field.as_str() { diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index a6daed0d..5444b6d3 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -8,11 +8,11 @@ use circuit::Circuit; use config::{Config, FiatShamirHashType, GKRConfig, GKRScheme, PolynomialCommitmentType}; use config_macros::declare_gkr_config; use field_hashers::{MiMC5FiatShamirHasher, PoseidonFiatShamirHasher}; +use gf2::GF2x128; use gkr_field_config::{BN254Config, FieldType, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use mersenne31::M31x16; use mpi_config::{root_println, MPIConfig}; -use poly_commit::expander_pcs_init_testing_only; -use poly_commit::raw::RawExpanderGKR; +use poly_commit::{expander_pcs_init_testing_only, OrionPCSForGKR, RawExpanderGKR}; use rand::{Rng, SeedableRng}; use rand_chacha::ChaCha12Rng; use sha2::Digest; @@ -67,12 +67,24 @@ fn test_gkr_correctness() { FiatShamirHashType::MIMC5, PolynomialCommitmentType::Raw ); + declare_gkr_config!( + C7, + FieldType::GF2, + FiatShamirHashType::Keccak256, + PolynomialCommitmentType::Orion, + ); declare_gkr_config!( C8, FieldType::M31, FiatShamirHashType::Poseidon, PolynomialCommitmentType::Raw, ); + declare_gkr_config!( + C9, + FieldType::M31, + FiatShamirHashType::Poseidon, + PolynomialCommitmentType::Orion, + ); test_gkr_correctness_helper( &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), @@ -102,10 +114,18 @@ fn test_gkr_correctness() { &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), Some("../data/gkr_proof.txt"), ); + test_gkr_correctness_helper( + &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + None, + ); test_gkr_correctness_helper( &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), None, ); + test_gkr_correctness_helper( + &Config::::new(GKRScheme::Vanilla, mpi_config.clone()), + None, + ); MPIConfig::finalize(); } diff --git a/poly_commit/src/orion.rs b/poly_commit/src/orion.rs index 48d144e9..375b0f30 100644 --- a/poly_commit/src/orion.rs +++ b/poly_commit/src/orion.rs @@ -26,6 +26,14 @@ pub use simd_field_impl::{ #[cfg(test)] mod simd_field_tests; +mod simd_field_agg_impl; + +#[cfg(test)] +mod simd_field_agg_tests; + +mod pcs_for_expander_gkr; +pub use pcs_for_expander_gkr::OrionPCSForGKR; + mod pcs_trait_impl; pub use pcs_trait_impl::{OrionBaseFieldPCS, OrionSIMDFieldPCS}; diff --git a/poly_commit/src/orion/pcs_for_expander_gkr.rs b/poly_commit/src/orion/pcs_for_expander_gkr.rs new file mode 100644 index 00000000..1157a6d5 --- /dev/null +++ b/poly_commit/src/orion/pcs_for_expander_gkr.rs @@ -0,0 +1,212 @@ +use std::io::Cursor; + +use arith::{FieldSerde, SimdField}; +use gkr_field_config::GKRFieldConfig; +use mpi_config::MPIConfig; +use polynomials::{EqPolynomial, MultilinearExtension}; +use transcript::Transcript; + +use crate::{ + orion::{simd_field_agg_impl::*, *}, + ExpanderGKRChallenge, PCSForExpanderGKR, StructuredReferenceString, +}; + +impl PCSForExpanderGKR + for OrionSIMDFieldPCS +where + C: GKRFieldConfig, + ComPackF: SimdField, + T: Transcript, +{ + const NAME: &'static str = "OrionPCSForExpanderGKR"; + + type Params = usize; + type ScratchPad = OrionScratchPad; + + type Commitment = OrionCommitment; + type Opening = OrionProof; + type SRS = OrionSRS; + + /// NOTE(HS): this is actually number of variables in polynomial, + /// ignoring the variables for MPI parties and SIMD field element + fn gen_params(n_input_vars: usize) -> Self::Params { + n_input_vars + } + + fn gen_srs_for_testing( + params: &Self::Params, + #[allow(unused)] mpi_config: &MPIConfig, + rng: impl rand::RngCore, + ) -> Self::SRS { + let num_vars_each_core = *params + C::SimdCircuitField::PACK_SIZE.ilog2() as usize; + OrionSRS::from_random::( + num_vars_each_core, + ORION_CODE_PARAMETER_INSTANCE, + rng, + ) + } + + fn init_scratch_pad(_params: &Self::Params, _mpi_config: &MPIConfig) -> Self::ScratchPad { + Self::ScratchPad::default() + } + + fn commit( + params: &Self::Params, + mpi_config: &MPIConfig, + proving_key: &::PKey, + poly: &impl MultilinearExtension, + scratch_pad: &mut Self::ScratchPad, + ) -> Self::Commitment { + let num_vars_each_core = *params + C::SimdCircuitField::PACK_SIZE.ilog2() as usize; + assert_eq!(num_vars_each_core, proving_key.num_vars); + + let commitment = orion_commit_simd_field(proving_key, poly, scratch_pad).unwrap(); + + // NOTE: Hang also assume that, linear GKR will take over the commitment + // and force sync transcript hash state of subordinate machines to be the same. + if mpi_config.world_size() == 1 { + return commitment; + } + + let local_buffer = vec![commitment]; + let mut buffer = vec![Self::Commitment::default(); mpi_config.world_size()]; + mpi_config.gather_vec(&local_buffer, &mut buffer); + + if !mpi_config.is_root() { + return commitment; + } + + let final_tree_height = 1 + buffer.len().ilog2(); + let (internals, _) = tree::Tree::new_with_leaf_nodes(buffer.clone(), final_tree_height); + internals[0] + } + + fn open( + params: &Self::Params, + mpi_config: &MPIConfig, + proving_key: &::PKey, + poly: &impl MultilinearExtension, + eval_point: &ExpanderGKRChallenge, + transcript: &mut T, // add transcript here to allow interactive arguments + scratch_pad: &mut Self::ScratchPad, + ) -> Self::Opening { + let num_vars_each_core = *params + C::SimdCircuitField::PACK_SIZE.ilog2() as usize; + assert_eq!(num_vars_each_core, proving_key.num_vars); + + let local_xs = eval_point.local_xs(); + let local_opening = orion_open_simd_field::< + C::CircuitField, + C::SimdCircuitField, + C::ChallengeField, + ComPackF, + T, + >(proving_key, poly, &local_xs, transcript, scratch_pad); + if mpi_config.world_size() == 1 { + return local_opening; + } + + // NOTE: eval row combine from MPI + let mpi_eq_coeffs = EqPolynomial::build_eq_x_r(&eval_point.x_mpi); + let eval_row = mpi_config.coef_combine_vec(&local_opening.eval_row, &mpi_eq_coeffs); + + // NOTE: sample MPI linear combination coeffs for proximity rows, + // and proximity rows combine with MPI + let proximity_rows = local_opening + .proximity_rows + .iter() + .map(|row| { + let weights = transcript.generate_challenge_field_elements(mpi_config.world_size()); + mpi_config.coef_combine_vec(row, &weights) + }) + .collect(); + + // NOTE: local query openings serialized to bytes + let mut local_mt_paths_serialized = Vec::new(); + local_opening + .query_openings + .serialize_into(&mut local_mt_paths_serialized) + .unwrap(); + + // NOTE: Hang does not think this is a good move, but this is mostly + // working with MPI behavior, so we align local MT openings serialization + // against power-of-2 bytes length. + local_mt_paths_serialized.resize(local_mt_paths_serialized.len().next_power_of_two(), 0u8); + + // NOTE: gather all merkle paths + let mut mt_paths_serialized = + vec![0u8; mpi_config.world_size() * local_mt_paths_serialized.len()]; + mpi_config.gather_vec(&local_mt_paths_serialized, &mut mt_paths_serialized); + + let query_openings: Vec = mt_paths_serialized + .chunks(local_mt_paths_serialized.len()) + .flat_map(|bs| { + let mut read_cursor = Cursor::new(bs); + Vec::deserialize_from(&mut read_cursor).unwrap() + }) + .collect(); + + if !mpi_config.is_root() { + return local_opening; + } + + // NOTE: we only care about the root machine's opening as final proof, Hang assume. + OrionProof { + eval_row, + proximity_rows, + query_openings, + } + } + + fn verify( + params: &Self::Params, + mpi_config: &MPIConfig, + verifying_key: &::VKey, + commitment: &Self::Commitment, + eval_point: &ExpanderGKRChallenge, + eval: C::ChallengeField, + transcript: &mut T, // add transcript here to allow interactive arguments + opening: &Self::Opening, + ) -> bool { + let global_poly_num_vars = *params + + mpi_config.world_size().ilog2() as usize + + C::SimdCircuitField::PACK_SIZE.ilog2() as usize; + assert_eq!(global_poly_num_vars, eval_point.num_vars()); + + if mpi_config.world_size() == 1 || !mpi_config.is_root() { + return orion_verify_simd_field::< + C::CircuitField, + C::SimdCircuitField, + C::ChallengeField, + ComPackF, + T, + >( + verifying_key, + commitment, + &eval_point.local_xs(), + eval, + transcript, + opening, + ); + } + + // NOTE: we now assume that the input opening is from the root machine, + // as proofs from other machines are typically undefined + orion_verify_simd_field_aggregated::( + mpi_config.world_size(), + verifying_key, + commitment, + eval_point, + eval, + transcript, + opening, + ) + } +} + +pub type OrionPCSForGKR = OrionSIMDFieldPCS< + ::CircuitField, + ::SimdCircuitField, + ::ChallengeField, + ComPack, + T, +>; diff --git a/poly_commit/src/orion/simd_field_agg_impl.rs b/poly_commit/src/orion/simd_field_agg_impl.rs new file mode 100644 index 00000000..d1935510 --- /dev/null +++ b/poly_commit/src/orion/simd_field_agg_impl.rs @@ -0,0 +1,160 @@ +use std::iter; + +use arith::{Field, SimdField}; +use gf2::GF2; +use gkr_field_config::GKRFieldConfig; +use itertools::izip; +use polynomials::{EqPolynomial, MultilinearExtension, RefMultiLinearPoly}; +use transcript::Transcript; + +use crate::{ + orion::utils::*, traits::TensorCodeIOPPCS, ExpanderGKRChallenge, OrionCommitment, OrionProof, + OrionSRS, PCS_SOUNDNESS_BITS, +}; + +pub(crate) fn orion_verify_simd_field_aggregated( + mpi_world_size: usize, + vk: &OrionSRS, + commitment: &OrionCommitment, + eval_point: &ExpanderGKRChallenge, + eval: C::ChallengeField, + transcript: &mut T, + proof: &OrionProof, +) -> bool +where + C: GKRFieldConfig, + ComPackF: SimdField, + T: Transcript, +{ + let local_num_vars = eval_point.num_vars() - mpi_world_size.ilog2() as usize; + assert_eq!(local_num_vars, vk.num_vars); + + let (row_num, msg_size) = { + let (row_field_elems, msg_size) = OrionSRS::evals_shape::(local_num_vars); + let row_num = row_field_elems / C::SimdCircuitField::PACK_SIZE; + (row_num, msg_size) + }; + + let num_vars_in_simd = C::SimdCircuitField::PACK_SIZE.ilog2() as usize; + let num_vars_in_msg = msg_size.ilog2() as usize; + + let global_xs = eval_point.global_xs(); + + // NOTE: working on evaluation response + let mut scratch = vec![C::ChallengeField::ZERO; msg_size]; + let final_eval = RefMultiLinearPoly::from_ref(&proof.eval_row).evaluate_with_buffer( + &global_xs[num_vars_in_simd..num_vars_in_simd + num_vars_in_msg], + &mut scratch, + ); + if final_eval != eval { + return false; + } + + // NOTE: working on proximity responses, draw random linear combinations + // then draw query points from fiat shamir transcripts + let proximity_reps = vk.proximity_repetitions::(PCS_SOUNDNESS_BITS); + let proximity_local_coeffs: Vec> = (0..proximity_reps) + .map(|_| { + transcript.generate_challenge_field_elements(row_num * C::SimdCircuitField::PACK_SIZE) + }) + .collect(); + + let query_num = vk.query_complexity(PCS_SOUNDNESS_BITS); + let query_indices = transcript.generate_challenge_index_vector(query_num); + + let proximity_worlds_coeffs: Vec> = (0..proximity_reps) + .map(|_| transcript.generate_challenge_field_elements(mpi_world_size)) + .collect(); + + // NOTE: work on the Merkle tree path validity + let roots: Vec<_> = proof + .query_openings + .chunks(query_num) + .map(|qs| qs[0].root()) + .collect(); + + let final_root = { + // NOTE: check all merkle paths, and check merkle roots against commitment + let final_tree_height = 1 + roots.len().ilog2(); + let (internals, _) = tree::Tree::new_with_leaf_nodes(roots.clone(), final_tree_height); + internals[0] + }; + if final_root != *commitment { + return false; + } + + if !orion_mt_verify( + vk, + &query_indices, + &proof.query_openings[..query_num], + &roots[0], + ) { + return false; + } + + let mut packed_interleaved_alphabets: Vec> = + vec![Vec::new(); query_num]; + + let concatenated_packed_interleaved_alphabets: Vec<_> = proof + .query_openings + .iter() + .map(|c| -> Vec<_> { + let elts = c.unpack_field_elems::(); + elts.chunks(C::SimdCircuitField::PACK_SIZE) + .map(C::SimdCircuitField::pack) + .collect() + }) + .collect(); + + concatenated_packed_interleaved_alphabets + .chunks(query_num) + .for_each(|alphabets| { + izip!(&mut packed_interleaved_alphabets, alphabets) + .for_each(|(packed, alphabet)| packed.extend_from_slice(alphabet)) + }); + + let mut eq_vars = vec![C::ChallengeField::ZERO; eval_point.num_vars() - num_vars_in_msg]; + eq_vars[..num_vars_in_simd].copy_from_slice(&global_xs[..num_vars_in_simd]); + eq_vars[num_vars_in_simd..].copy_from_slice(&global_xs[num_vars_in_simd + num_vars_in_msg..]); + + let eq_col_coeffs = EqPolynomial::build_eq_x_r(&eq_vars); + + let proximity_coeffs: Vec> = (0..proximity_reps) + .map(|i| { + proximity_worlds_coeffs[i] + .iter() + .flat_map(|w| { + proximity_local_coeffs[i] + .iter() + .map(|l| *l * *w) + .collect::>() + }) + .collect() + }) + .collect(); + + // NOTE: decide if expected alphabet matches actual responses + izip!(&proximity_coeffs, &proof.proximity_rows) + .chain(iter::once((&eq_col_coeffs, &proof.eval_row))) + .all(|(rl, msg)| { + let codeword = match vk.code_instance.encode(msg) { + Ok(c) => c, + _ => return false, + }; + + match C::CircuitField::NAME { + GF2::NAME => lut_verify_alphabet_check( + &codeword, + rl, + &query_indices, + &packed_interleaved_alphabets, + ), + _ => simd_verify_alphabet_check( + &codeword, + rl, + &query_indices, + &packed_interleaved_alphabets, + ), + } + }) +} diff --git a/poly_commit/src/orion/simd_field_agg_tests.rs b/poly_commit/src/orion/simd_field_agg_tests.rs new file mode 100644 index 00000000..9cd1c2e4 --- /dev/null +++ b/poly_commit/src/orion/simd_field_agg_tests.rs @@ -0,0 +1,205 @@ +use std::marker::PhantomData; + +use arith::{ExtensionField, Field, SimdField}; +use ark_std::test_rng; +use gf2::GF2x128; +use gf2_128::GF2_128; +use gkr_field_config::{GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; +use itertools::izip; +use mersenne31::{M31Ext3, M31x16}; +use polynomials::{EqPolynomial, MultiLinearPoly}; +use transcript::{BytesHashTranscript, Keccak256hasher, Transcript}; + +use crate::{ + orion::{simd_field_agg_impl::*, utils::*, *}, + ExpanderGKRChallenge, RawExpanderGKR, +}; + +#[derive(Clone)] +struct DistributedCommitter +where + F: Field, + EvalF: ExtensionField, + ComPackF: SimdField, + T: Transcript, +{ + pub scratch_pad: OrionScratchPad, + pub transcript: T, + + _phantom: PhantomData, +} + +fn orion_proof_aggregate( + openings: &[OrionProof], + x_mpi: &[C::ChallengeField], + transcript: &mut T, +) -> OrionProof +where + C: GKRFieldConfig, + T: Transcript, +{ + let paths = openings + .iter() + .flat_map(|o| o.query_openings.clone()) + .collect(); + let num_parties = 1 << x_mpi.len(); + + let proximity_reps = openings[0].proximity_rows.len(); + let mut scratch = vec![C::ChallengeField::ZERO; num_parties * openings[0].eval_row.len()]; + + let aggregated_proximity_rows = (0..proximity_reps) + .map(|i| { + let weights = transcript.generate_challenge_field_elements(num_parties); + let mut rows: Vec<_> = openings + .iter() + .flat_map(|o| o.proximity_rows[i].clone()) + .collect(); + transpose_in_place(&mut rows, &mut scratch, num_parties); + rows.chunks(num_parties) + .map(|c| izip!(c, &weights).map(|(&l, &r)| l * r).sum()) + .collect() + }) + .collect(); + + let aggregated_eval_row: Vec<_> = { + let eq_worlds_coeffs = EqPolynomial::build_eq_x_r(x_mpi); + let mut rows: Vec<_> = openings.iter().flat_map(|o| o.eval_row.clone()).collect(); + transpose_in_place(&mut rows, &mut scratch, num_parties); + rows.chunks(num_parties) + .map(|c| izip!(c, &eq_worlds_coeffs).map(|(&l, &r)| l * r).sum()) + .collect() + }; + + OrionProof { + eval_row: aggregated_eval_row, + proximity_rows: aggregated_proximity_rows, + query_openings: paths, + } +} + +fn test_orion_simd_aggregate_verify_helper(num_parties: usize, num_vars: usize) +where + C: GKRFieldConfig, + ComPackF: SimdField, + T: Transcript, +{ + assert!(num_parties.is_power_of_two()); + + let mut rng = test_rng(); + + let simd_num_vars = C::SimdCircuitField::PACK_SIZE.ilog2() as usize; + let world_num_vars = num_parties.ilog2() as usize; + + let global_poly = + MultiLinearPoly::::random(num_vars - simd_num_vars, &mut rng); + + let global_real_num_vars = global_poly.get_num_vars(); + let local_real_num_vars = global_real_num_vars - world_num_vars; + + let eval_point: Vec<_> = (0..num_vars) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect(); + + let gkr_challenge: ExpanderGKRChallenge = ExpanderGKRChallenge { + x_mpi: eval_point[num_vars - world_num_vars..].to_vec(), + x: eval_point[simd_num_vars..num_vars - world_num_vars].to_vec(), + x_simd: eval_point[..simd_num_vars].to_vec(), + }; + + let mut committee = vec![ + DistributedCommitter { + scratch_pad: OrionScratchPad::::default(), + transcript: T::new(), + _phantom: PhantomData, + }; + num_parties + ]; + let mut verifier_transcript = T::new(); + + let srs = OrionSRS::from_random::( + num_vars - world_num_vars, + ORION_CODE_PARAMETER_INSTANCE, + &mut rng, + ); + + let final_commitment = { + let roots: Vec<_> = izip!( + &mut committee, + global_poly.coeffs.chunks(1 << local_real_num_vars) + ) + .map(|(committer, eval_slice)| { + let cloned_poly = MultiLinearPoly::new(eval_slice.to_vec()); + orion_commit_simd_field(&srs, &cloned_poly, &mut committer.scratch_pad).unwrap() + }) + .collect(); + + let final_tree_height = 1 + roots.len().ilog2(); + let (internals, _) = tree::Tree::new_with_leaf_nodes(roots, final_tree_height); + internals[0] + }; + + let openings: Vec<_> = + izip!( + &mut committee, + global_poly.coeffs.chunks(1 << local_real_num_vars) + ) + .map(|(committer, eval_slice)| { + let cloned_poly = MultiLinearPoly::new(eval_slice.to_vec()); + orion_open_simd_field::< + C::CircuitField, + C::SimdCircuitField, + C::ChallengeField, + ComPackF, + T, + >( + &srs, + &cloned_poly, + &gkr_challenge.local_xs(), + &mut committer.transcript, + &committer.scratch_pad, + ) + }) + .collect(); + + let mut aggregator_transcript = committee[0].transcript.clone(); + let aggregated_proof = + orion_proof_aggregate::(&openings, &gkr_challenge.x_mpi, &mut aggregator_transcript); + + let final_expected_eval = RawExpanderGKR::::eval( + &global_poly.coeffs, + &gkr_challenge.x, + &gkr_challenge.x_simd, + &gkr_challenge.x_mpi, + ); + + assert!(orion_verify_simd_field_aggregated::( + num_parties, + &srs, + &final_commitment, + &gkr_challenge, + final_expected_eval, + &mut verifier_transcript, + &aggregated_proof, + )); +} + +#[test] +fn test_orion_simd_aggregate_verify() { + let parties = 16; + + (16..18).for_each(|num_var| { + test_orion_simd_aggregate_verify_helper::< + GF2ExtConfig, + GF2x128, + BytesHashTranscript, + >(parties, num_var) + }); + + (12..15).for_each(|num_var| { + test_orion_simd_aggregate_verify_helper::< + M31ExtConfig, + M31x16, + BytesHashTranscript, + >(parties, num_var) + }) +} diff --git a/poly_commit/src/orion/utils.rs b/poly_commit/src/orion/utils.rs index c715199b..025a4d88 100644 --- a/poly_commit/src/orion/utils.rs +++ b/poly_commit/src/orion/utils.rs @@ -472,3 +472,39 @@ where alphabet == codeword[index] }) } + +#[cfg(test)] +mod tests { + use arith::{Field, SimdField}; + use ark_std::test_rng; + use gf2::{GF2x8, GF2}; + use gf2_128::{GF2_128x8, GF2_128}; + use itertools::izip; + + use super::SubsetSumLUTs; + + #[test] + fn test_lut_simd_inner_prod_consistency() { + let mut rng = test_rng(); + + let weights: Vec<_> = (0..8).map(|_| GF2_128::random_unsafe(&mut rng)).collect(); + let bases: Vec<_> = (0..8).map(|_| GF2::random_unsafe(&mut rng)).collect(); + + let simd_weights = GF2_128x8::pack(&weights); + let simd_bases = GF2x8::pack(&bases); + + let expected_simd_inner_prod: GF2_128 = (simd_weights * simd_bases).unpack().iter().sum(); + + let expected_vanilla_inner_prod: GF2_128 = + izip!(&weights, &bases).map(|(w, b)| *w * *b).sum(); + + assert_eq!(expected_simd_inner_prod, expected_vanilla_inner_prod); + + let mut table = SubsetSumLUTs::new(8, 1); + table.build(&weights); + + let actual_lut_inner_prod = table.lookup_and_sum(&vec![simd_bases]); + + assert_eq!(expected_simd_inner_prod, actual_lut_inner_prod) + } +} diff --git a/poly_commit/tests/common.rs b/poly_commit/tests/common.rs index f7635412..0145657b 100644 --- a/poly_commit/tests/common.rs +++ b/poly_commit/tests/common.rs @@ -45,7 +45,6 @@ pub fn test_pcs, P: PolynomialCommitmentSche } } -#[allow(unused)] pub fn test_pcs_for_expander_gkr< C: GKRFieldConfig, T: Transcript, diff --git a/poly_commit/tests/test_orion.rs b/poly_commit/tests/test_orion.rs index 3c209afe..8e970c68 100644 --- a/poly_commit/tests/test_orion.rs +++ b/poly_commit/tests/test_orion.rs @@ -4,10 +4,12 @@ use arith::{ExtensionField, Field, SimdField}; use ark_std::test_rng; use gf2::{GF2x128, GF2x64, GF2x8, GF2}; use gf2_128::GF2_128; +use gkr_field_config::{GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use mersenne31::{M31Ext3, M31x16, M31}; +use mpi_config::MPIConfig; use poly_commit::*; use polynomials::MultiLinearPoly; -use transcript::{BytesHashTranscript, Keccak256hasher}; +use transcript::{BytesHashTranscript, Keccak256hasher, Transcript}; const TEST_REPETITION: usize = 3; @@ -95,3 +97,82 @@ fn test_orion_simd_field_pcs_full_e2e() { test_orion_simd_field_pcs_generics::(19, 25); test_orion_simd_field_pcs_generics::(16, 22); } + +fn test_orion_for_expander_gkr_generics( + mpi_config_ref: &MPIConfig, + total_num_vars: usize, +) where + C: GKRFieldConfig, + ComPackF: SimdField, + T: Transcript, +{ + let mut rng = test_rng(); + + // NOTE: generate global random polynomial + let num_vars_in_simd = C::SimdCircuitField::PACK_SIZE.ilog2() as usize; + let num_vars_in_mpi = mpi_config_ref.world_size().ilog2() as usize; + let num_vars_in_each_poly = total_num_vars - num_vars_in_mpi - num_vars_in_simd; + let num_vars_in_global_poly = total_num_vars - num_vars_in_simd; + + let global_poly = + MultiLinearPoly::::random(num_vars_in_global_poly, &mut rng); + + // NOTE generate srs for each party, and shared challenge point in each party + let challenge_point = ExpanderGKRChallenge:: { + x_mpi: (0..num_vars_in_mpi) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect(), + x_simd: (0..num_vars_in_simd) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect(), + x: (0..num_vars_in_each_poly) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect(), + }; + + let mut transcript = T::new(); + + dbg!(global_poly.get_num_vars(), global_poly.coeffs[0]); + dbg!(&challenge_point.x_mpi); + dbg!(mpi_config_ref.world_size(), mpi_config_ref.world_rank()); + + // NOTE separate polynomial into different pieces by mpi rank + let poly_vars_stride = (1 << global_poly.get_num_vars()) / mpi_config_ref.world_size(); + let poly_coeff_starts = mpi_config_ref.world_rank() * poly_vars_stride; + let poly_coeff_ends = poly_coeff_starts + poly_vars_stride; + let local_poly = + MultiLinearPoly::new(global_poly.coeffs[poly_coeff_starts..poly_coeff_ends].to_vec()); + + dbg!(local_poly.get_num_vars(), local_poly.coeffs[0]); + + common::test_pcs_for_expander_gkr::< + C, + T, + OrionSIMDFieldPCS, + >( + &num_vars_in_each_poly, + &mpi_config_ref, + &mut transcript, + &local_poly, + &vec![challenge_point], + ); +} + +#[test] +fn test_orion_for_expander_gkr() { + let mpi_config = MPIConfig::new(); + + test_orion_for_expander_gkr_generics::< + GF2ExtConfig, + GF2x128, + BytesHashTranscript<_, Keccak256hasher>, + >(&mpi_config, 16); + + test_orion_for_expander_gkr_generics::< + M31ExtConfig, + M31x16, + BytesHashTranscript<_, Keccak256hasher>, + >(&mpi_config, 15); + + MPIConfig::finalize() +}