diff --git a/examples/bulletproof.rs b/examples/bulletproof.rs index edbf920..d5d9311 100644 --- a/examples/bulletproof.rs +++ b/examples/bulletproof.rs @@ -47,9 +47,6 @@ where { assert_eq!(witness.0.len(), witness.1.len()); - arthur.public_points(&[*statement]).unwrap(); - arthur.ratchet().unwrap(); - if witness.0.len() == 1 { assert_eq!(generators.0.len(), 1); @@ -99,9 +96,6 @@ fn verify( where for<'a> Merlin<'a>: GroupReader + FieldChallenges, { - merlin.public_points(&[*statement]).unwrap(); - merlin.ratchet().unwrap(); - let mut g = generators.0.to_vec(); let mut h = generators.1.to_vec(); let u = generators.2.clone(); @@ -196,6 +190,8 @@ fn main() { let witness = (&a[..], &b[..]); let mut arthur = iopattern.to_arthur(); + arthur.public_points(&[statement]).unwrap(); + arthur.ratchet().unwrap(); let proof = prove(&mut arthur, generators, &statement, witness).expect("Error proving"); println!( "Here's a bulletproof for {} elements:\n{}", @@ -203,6 +199,8 @@ fn main() { hex::encode(proof) ); - let mut verifier_transcript = iopattern.to_merlin(proof); - verify(&mut verifier_transcript, generators, size, &statement).expect("Invalid proof"); + let mut merlin = iopattern.to_merlin(proof); + merlin.public_points(&[statement]).unwrap(); + merlin.ratchet().unwrap(); + verify(&mut merlin, generators, size, &statement).expect("Invalid proof"); } diff --git a/examples/schnorr.rs b/examples/schnorr.rs index 42264b0..204f9e0 100644 --- a/examples/schnorr.rs +++ b/examples/schnorr.rs @@ -74,9 +74,6 @@ where G: CurveGroup, Arthur: GroupWriter + FieldChallenges, { - arthur.public_points(&[P, P * x]).unwrap(); - arthur.ratchet().unwrap(); - // `Arthur` types implement a cryptographically-secure random number generator that is tied to the protocol transcript // and that can be accessed via the `rng()` funciton. let k = G::ScalarField::rand(arthur.rng()); @@ -114,12 +111,12 @@ fn verify( where G: CurveGroup, H: DuplexHash, - for<'a> Merlin<'a, H>: GroupReader + FieldChallenges, + for<'a> Merlin<'a, H>: + GroupReader + FieldReader + FieldChallenges, { - merlin.public_points(&[P, X]).unwrap(); - merlin.ratchet().unwrap(); - // Read the protocol from the transcript: + // XXX. possible inconsistent implementations: + // if the point is not validated here (but the public key is) then the proof may fail with InvalidProof, instead of SerializationError let [K] = merlin.next_points().unwrap(); let [c] = merlin.challenge_scalars().unwrap(); let [r] = merlin.next_scalars().unwrap(); @@ -146,11 +143,11 @@ fn main() { type G = ark_curve25519::EdwardsProjective; // Set the hash function (commented out other valid choices): // type H = nimue::hash::Keccak; - // type H = nimue::hash::legacy::DigestBridge; + type H = nimue::hash::legacy::DigestBridge; // type H = nimue::hash::legacy::DigestBridge; // Set up the IO for the protocol transcript with domain separator "nimue::examples::schnorr" - let io = IOPattern::new("nimue::examples::schnorr"); + let io = IOPattern::::new("nimue::examples::schnorr"); let io = SchnorrIOPattern::::add_schnorr_io(io); // Set up the elements to prove @@ -158,7 +155,9 @@ fn main() { let (x, X) = keygen(); // Create the prover transcript, add the statement to it, and then invoke the prover. - let mut arthur: Arthur = io.to_arthur(); + let mut arthur = io.to_arthur(); + arthur.public_points(&[P, P * x]).unwrap(); + arthur.ratchet().unwrap(); let proof = prove(&mut arthur, P, x).expect("Invalid proof"); // Print out the hex-encoded schnorr proof. diff --git a/examples/schnorr_algebraic_hash.rs b/examples/schnorr_algebraic_hash.rs index 9381450..1c45660 100644 --- a/examples/schnorr_algebraic_hash.rs +++ b/examples/schnorr_algebraic_hash.rs @@ -45,21 +45,22 @@ fn keygen() -> (G::ScalarField, G) { /// - the secret key $x \in \mathbb{Z}_p$ /// It returns a zero-knowledge proof of knowledge of `x` as a sequence of bytes. #[allow(non_snake_case)] -fn aprove( +fn prove( // the hash function `H` works over bytes. // Algebraic hashes over a particular domain can be denoted with an additional type argument implementing `nimue::Unit`. - arthur: &mut Arthur, + arthur: &mut Arthur, // the generator P: G, // the secret key x: G::ScalarField, ) -> ProofResult<&[u8]> where - G::BaseField: Unit, + U: Unit, + G::BaseField: PrimeField, R: CryptoRng + RngCore, - H: DuplexHash, + H: DuplexHash, G: CurveGroup, - Arthur: GroupWriter + ByteChallenges, + Arthur: GroupWriter + FieldWriter + ByteChallenges, { // `Arthur` types implement a cryptographically-secure random number generator that is tied to the protocol transcript // and that can be accessed via the `rng()` funciton. @@ -74,9 +75,10 @@ where let c_bytes = arthur.challenge_bytes::<16>()?; let c = G::ScalarField::from_le_bytes_mod_order(&c_bytes); - let _r = k + c * x; + let r = k + c * x; + let r_q = swap_field::(r)?; // Add a sequence of scalar elements to the protocol transcript. - // arthur.add_scalars(&[r])?; + arthur.add_scalars(&[r_q])?; // Output the current protocol transcript as a sequence of bytes. Ok(arthur.transcript()) @@ -87,27 +89,28 @@ where /// - the secret key `witness` /// It returns a zero-knowledge proof of knowledge of `witness` as a sequence of bytes. #[allow(non_snake_case)] -fn averify<'a, G, H>( +fn verify<'a, G, H, U>( // `ArkGroupMelin` contains the veirifier state, including the messages currently read. In addition, it is aware of the group `G` // from which it can serialize/deserialize elements. - merlin: &mut Merlin<'a, H, G::BaseField>, + merlin: &mut Merlin<'a, H, U>, // The group generator `P`` P: G, // The public key `X` X: G, ) -> ProofResult<()> where - G::BaseField: Unit, + U: Unit, + G::BaseField: PrimeField, G: CurveGroup, - H: DuplexHash, - Merlin<'a, H, G::BaseField>: GroupReader + ByteChallenges, + H: DuplexHash, + Merlin<'a, H, U>: GroupReader + FieldReader + ByteChallenges, { // Read the protocol from the transcript: let [K] = merlin.next_points().unwrap(); let c_bytes = merlin.challenge_bytes::<16>().unwrap(); let c = G::ScalarField::from_le_bytes_mod_order(&c_bytes); - let r = G::ScalarField::from(0); - // let [r] = merlin.next_scalars().unwrap(); + let [r_q] = merlin.next_scalars().unwrap(); + let r = swap_field::(r_q)?; // Check the verification equation, otherwise return a verification error. // The type ProofError is an enum that can report: @@ -136,8 +139,12 @@ fn main() { // type H = nimue::hash::legacy::DigestBridge; type H = nimue::plugins::ark::poseidon::PoseidonHash; + // + // type U = u8; + type U = Fq; + // Set up the IO for the protocol transcript with domain separator "nimue::examples::schnorr" - let io = IOPattern::::new("nimue::examples::schnorr"); + let io = IOPattern::::new("nimue::examples::schnorr"); let io = SchnorrIOPattern::::add_schnorr_io(io); // Set up the elements to prove @@ -145,17 +152,17 @@ fn main() { let (x, X) = keygen(); // Create the prover transcript, add the statement to it, and then invoke the prover. - let mut arthur: Arthur = Arthur::::new(&io, OsRng); + let mut arthur = Arthur::::new(&io, OsRng); arthur.public_points(&[P, X]).unwrap(); arthur.ratchet().unwrap(); - let proof = aprove(&mut arthur, P, x).expect("Invalid proof"); + let proof = prove(&mut arthur, P, x).expect("Invalid proof"); // Print out the hex-encoded schnorr proof. println!("Here's a Schnorr signature:\n{}", hex::encode(proof)); // Verify the proof: create the verifier transcript, add the statement to it, and invoke the verifier. - let mut merlin = Merlin::::new(&io, &proof); + let mut merlin = Merlin::::new(&io, &proof); merlin.public_points(&[P, X]).unwrap(); merlin.ratchet().unwrap(); - averify(&mut merlin, P, X).expect("Invalid proof"); + verify(&mut merlin, P, X).expect("Invalid proof"); } diff --git a/src/errors.rs b/src/errors.rs index 9aebc2f..ccb1bc5 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -76,3 +76,9 @@ impl> From for ProofError { ProofError::InvalidIO(value.borrow().clone()) } } + +impl From for IOPatternError { + fn from(value: std::io::Error) -> Self { + IOPatternError(value.to_string()) + } +} diff --git a/src/merlin.rs b/src/merlin.rs index 68b9ec4..5e30cc7 100644 --- a/src/merlin.rs +++ b/src/merlin.rs @@ -29,8 +29,9 @@ impl<'a, U: Unit, H: DuplexHash> Merlin<'a, H, U> { /// Read `input.len()` elements from the transcript. #[inline(always)] pub fn fill_next(&mut self, input: &mut [U]) -> Result<(), IOPatternError> { - U::read(&mut self.transcript, input).unwrap(); - self.safe.absorb(input) + U::read(&mut self.transcript, input)?; + self.safe.absorb(input)?; + Ok(()) } /// Signals the end of the statement. diff --git a/src/plugins/ark/common.rs b/src/plugins/ark/common.rs index 06f7349..5c53778 100644 --- a/src/plugins/ark/common.rs +++ b/src/plugins/ark/common.rs @@ -1,7 +1,7 @@ use std::io; use ark_ec::{AffineRepr, CurveGroup}; -use ark_ff::{Fp, FpConfig, PrimeField}; +use ark_ff::{Field, Fp, FpConfig, PrimeField}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError}; use rand::{CryptoRng, RngCore}; @@ -25,8 +25,10 @@ impl, const N: usize> Unit for Fp { fn read(mut r: &mut impl io::Read, bunch: &mut [Self]) -> Result<(), io::Error> { for b in bunch.iter_mut() { - *b = ark_ff::Fp::::deserialize_compressed(&mut r) - .map_err(|_| io::Error::new(io::ErrorKind::Other, "oh no!"))? + let b_result = Fp::deserialize_compressed(&mut r); + *b = b_result.map_err(|_| { + io::Error::new(io::ErrorKind::Other, "Unable to deserialize into Field.") + })? } Ok(()) } @@ -58,7 +60,7 @@ where impl FieldPublic for T where - F: PrimeField, + F: Field, T: UnitTranscript, { type Repr = Vec; @@ -83,7 +85,7 @@ where for o in output.iter_mut() { self.fill_challenge_bytes(&mut buf)?; - *o = F::from_be_bytes_mod_order(&buf); + *o = F::from_be_bytes_mod_order(&buf).into(); } Ok(()) } @@ -172,7 +174,6 @@ where } } - // Field <-> Bytes interactions: impl<'a, H, C, const N: usize> BytePublic for Merlin<'a, H, Fp> diff --git a/src/plugins/ark/iopattern.rs b/src/plugins/ark/iopattern.rs index 09cddcb..4995454 100644 --- a/src/plugins/ark/iopattern.rs +++ b/src/plugins/ark/iopattern.rs @@ -43,7 +43,7 @@ where fn challenge_bytes(self, count: usize, label: &str) -> Self { let n = bytes_uniform_modp(Fp::::MODULUS_BIT_SIZE); - self.absorb((count + n - 1) / n, label) + self.squeeze((count + n - 1) / n, label) } } @@ -62,7 +62,7 @@ where G: CurveGroup>, H: DuplexHash>, C: FpConfig, - IOPattern>: FieldIOPattern, + IOPattern>: FieldIOPattern>, { fn add_points(self, count: usize, label: &str) -> Self { self.absorb(count * 2, label) diff --git a/src/plugins/ark/mod.rs b/src/plugins/ark/mod.rs index 557abe4..770ecb1 100644 --- a/src/plugins/ark/mod.rs +++ b/src/plugins/ark/mod.rs @@ -8,8 +8,20 @@ pub mod poseidon; pub use crate::traits::*; pub use crate::{hash::Unit, Arthur, DuplexHash, IOPattern, Merlin, ProofError, ProofResult, Safe}; -super::traits::field_traits!(ark_ff::PrimeField); -super::traits::group_traits!(ark_ec::CurveGroup, G::Scalar : ark_ff::Field); +super::traits::field_traits!(ark_ff::Field); +super::traits::group_traits!(ark_ec::CurveGroup, G::BaseField : ark_ff::PrimeField); + +/// Move a value from one field to another. +/// +/// Return an error if the value is larger than the destination field. +pub fn swap_field(a_f1: F1) -> ProofResult { + use ark_ff::BigInteger; + let a_f2 = F2::from_le_bytes_mod_order(&a_f1.into_bigint().to_bytes_le()); + let a_f1_control = F1::from_le_bytes_mod_order(&a_f2.into_bigint().to_bytes_le()); + (a_f1 == a_f1_control) + .then(|| a_f2) + .ok_or(ProofError::SerializationError) +} // pub trait PairingReader: GroupReader + GroupReader { // fn fill_next_g1_points(&mut self, input: &mut [P::G1]) -> crate::ProofResult<()> { diff --git a/src/plugins/ark/poseidon.rs b/src/plugins/ark/poseidon.rs index bd0ce04..dce3c01 100644 --- a/src/plugins/ark/poseidon.rs +++ b/src/plugins/ark/poseidon.rs @@ -1,12 +1,13 @@ -use std::ops::RangeTo; - +use crate::{hash::sponge::DuplexSponge, Unit}; use ark_bls12_381::Fq; use ark_ff::{PrimeField, Zero}; +use std::ops::RangeTo; -type F = Fq; +type FF = Fq; +pub type PoseidonHash = DuplexSponge>; #[derive(Clone, Debug)] -pub struct PoseidonConfig { +pub struct PoseidonConfig { /// Number of rounds in a full-round operation. pub full_rounds: usize, /// Number of rounds in a partial-round operation. @@ -18,15 +19,188 @@ pub struct PoseidonConfig { pub ark: &'static [[F; 3]], /// Maximally Distance Separating (MDS) Matrix. pub mds: &'static [[F; 3]], - /// The rate (in terms of number of field elements). - /// See [On the Indifferentiability of the Sponge Construction](https://iacr.org/archive/eurocrypt2008/49650180/49650180.pdf) - /// for more details on the rate and capacity of a sponge. - pub rate: usize, - /// The capacity (in terms of number of field elements). - pub capacity: usize, } -const MDS: &'static [[F; 3]] = &[ +/// Generate default parameters (bls381-fr-only) for alpha = 17, state-size = 8 +const BLS12381POSEIDON_CONF: PoseidonConfig = { + let alpha = 17; + let full_rounds = 8; + let total_rounds = 37; + let partial_rounds = total_rounds - full_rounds; + PoseidonConfig { + full_rounds, + partial_rounds, + alpha, + ark: ARK, + mds: MDS, + } +}; + +#[derive(Clone)] +pub struct PoseidonSponge { + /// Sponge Config + pub parameters: PoseidonConfig, + + // Sponge State + /// Current sponge's state (current elements in the permutation block) + pub state: Vec, +} + +impl PoseidonSponge { + fn apply_s_box(&self, state: &mut [F], is_full_round: bool) { + // Full rounds apply the S Box (x^alpha) to every element of state + if is_full_round { + for elem in state { + *elem = elem.pow(&[self.parameters.alpha]); + } + } + // Partial rounds apply the S Box (x^alpha) to just the first element of state + else { + state[0] = state[0].pow(&[self.parameters.alpha]); + } + } + + fn apply_ark(&self, state: &mut [F], round_number: usize) { + for (i, state_elem) in state.iter_mut().enumerate() { + state_elem.add_assign(&self.parameters.ark[round_number][i]); + } + } + + fn apply_mds(&self, state: &mut [F]) { + let mut new_state = Vec::new(); + for i in 0..state.len() { + let mut cur = F::zero(); + for (j, state_elem) in state.iter().enumerate() { + let term = state_elem.mul(&self.parameters.mds[i][j]); + cur.add_assign(&term); + } + new_state.push(cur); + } + state.clone_from_slice(&new_state[..state.len()]) + } +} + +impl Default for PoseidonSponge { + fn default() -> Self { + PoseidonSponge { + parameters: BLS12381POSEIDON_CONF.clone(), + state: vec![FF::zero(); 2 + 1], + } + } +} + +macro_rules! impl_index { + ($trait: ty, $struct: ident, Output = $output: ident, Params = [$($type:ident : $trait:ident),*], Constants = $($constgen:ident),*) => { + impl<$($type: $trait,)* $(const $constgen: usize,)*> $trait for $struct<$($type,)* $($constgen,)*> { + type Output = $output; + + fn index(&self, index: usize) -> &Self::Output { + &self.state[index] + } + } + }; +} + +impl_index!(std::ops::Index, PoseidonSponge, Output = F, Params = [F: PrimeField], Constants = RATE, CAPACITY); + +impl std::ops::Index for PoseidonSponge { + type Output = F; + + fn index(&self, index: usize) -> &Self::Output { + &self.state[index] + } +} + + + +impl std::ops::Index> for PoseidonSponge { + type Output = [F]; + + fn index(&self, index: RangeTo) -> &Self::Output { + &self.state[index] + } +} + +impl std::ops::IndexMut> for PoseidonSponge { + fn index_mut(&mut self, index: RangeTo) -> &mut Self::Output { + &mut self.state[index] + } +} + +impl std::ops::Index> for PoseidonSponge { + type Output = [F]; + + fn index(&self, index: std::ops::Range) -> &Self::Output { + &self.state[index] + } +} + +impl std::ops::IndexMut> for PoseidonSponge { + fn index_mut(&mut self, index: std::ops::Range) -> &mut Self::Output { + &mut self.state[index] + } +} + +impl std::ops::Index> for PoseidonSponge { + type Output = [F]; + + fn index(&self, index: std::ops::RangeFrom) -> &Self::Output { + &self.state[index] + } +} + +impl std::ops::IndexMut> for PoseidonSponge { + fn index_mut(&mut self, index: std::ops::RangeFrom) -> &mut Self::Output { + &mut self.state[index] + } +} + +impl zeroize::Zeroize for PoseidonSponge { + fn zeroize(&mut self) { + self.state.zeroize(); + } +} + +impl crate::hash::sponge::Sponge for PoseidonSponge + where PoseidonSponge: Default, F: Unit { + type U = F; + const CAPACITY: usize = 1; + const RATE: usize = 2; + + fn new(iv: [u8; 32]) -> Self { + assert!(Self::CAPACITY >= 1); + let mut ark_sponge = Self::default(); + ark_sponge.state[Self::RATE] = F::from_be_bytes_mod_order(&iv); + ark_sponge + } + + fn permute(&mut self) { + let full_rounds_over_2 = self.parameters.full_rounds / 2; + let mut state = self.state.clone(); + for i in 0..full_rounds_over_2 { + self.apply_ark(&mut state, i); + self.apply_s_box(&mut state, true); + self.apply_mds(&mut state); + } + + for i in full_rounds_over_2..(full_rounds_over_2 + self.parameters.partial_rounds) { + self.apply_ark(&mut state, i); + self.apply_s_box(&mut state, false); + self.apply_mds(&mut state); + } + + for i in (full_rounds_over_2 + self.parameters.partial_rounds) + ..(self.parameters.partial_rounds + self.parameters.full_rounds) + { + self.apply_ark(&mut state, i); + self.apply_s_box(&mut state, true); + self.apply_mds(&mut state); + } + self.state = state; + } +} + +const MDS: &'static [[FF; 3]] = &[ [ ark_ff::MontFp!( "43228725308391137369947362226390319299014033584574058394339561338097152657858" @@ -61,7 +235,7 @@ const MDS: &'static [[F; 3]] = &[ ), ], ]; -const ARK: &'static [[F; 3]] = &[ +const ARK: &'static [[FF; 3]] = &[ [ ark_ff::MontFp!( "44595993092652566245296379427906271087754779418564084732265552598173323099784" @@ -470,188 +644,3 @@ const ARK: &'static [[F; 3]] = &[ ), ], ]; - -/// Generate default parameters (bls381-fr-only) for alpha = 17, state-size = 8 -pub(crate) const fn poseidon_parameters_for_test() -> PoseidonConfig { - let alpha = 17; - let full_rounds = 8; - let total_rounds = 37; - let partial_rounds = total_rounds - full_rounds; - let capacity = 1; - let rate = 2; - PoseidonConfig { - full_rounds, - partial_rounds, - alpha, - ark: ARK, - mds: MDS, - rate, - capacity, - } -} - -#[derive(Clone)] -pub struct PoseidonSponge { - /// Sponge Config - pub parameters: PoseidonConfig, - - // Sponge State - /// Current sponge's state (current elements in the permutation block) - pub state: Vec, -} - -impl PoseidonSponge { - fn apply_s_box(&self, state: &mut [F], is_full_round: bool) { - // Full rounds apply the S Box (x^alpha) to every element of state - if is_full_round { - for elem in state { - *elem = elem.pow(&[self.parameters.alpha]); - } - } - // Partial rounds apply the S Box (x^alpha) to just the first element of state - else { - state[0] = state[0].pow(&[self.parameters.alpha]); - } - } - - fn apply_ark(&self, state: &mut [F], round_number: usize) { - for (i, state_elem) in state.iter_mut().enumerate() { - state_elem.add_assign(&self.parameters.ark[round_number][i]); - } - } - - fn apply_mds(&self, state: &mut [F]) { - let mut new_state = Vec::new(); - for i in 0..state.len() { - let mut cur = F::zero(); - for (j, state_elem) in state.iter().enumerate() { - let term = state_elem.mul(&self.parameters.mds[i][j]); - cur.add_assign(&term); - } - new_state.push(cur); - } - state.clone_from_slice(&new_state[..state.len()]) - } - - fn permute(&mut self) { - let full_rounds_over_2 = self.parameters.full_rounds / 2; - let mut state = self.state.clone(); - for i in 0..full_rounds_over_2 { - self.apply_ark(&mut state, i); - self.apply_s_box(&mut state, true); - self.apply_mds(&mut state); - } - - for i in full_rounds_over_2..(full_rounds_over_2 + self.parameters.partial_rounds) { - self.apply_ark(&mut state, i); - self.apply_s_box(&mut state, false); - self.apply_mds(&mut state); - } - - for i in (full_rounds_over_2 + self.parameters.partial_rounds) - ..(self.parameters.partial_rounds + self.parameters.full_rounds) - { - self.apply_ark(&mut state, i); - self.apply_s_box(&mut state, true); - self.apply_mds(&mut state); - } - self.state = state; - } -} - -const BLS12381POSEIDON_CONF: PoseidonConfig = poseidon_parameters_for_test(); - -#[derive(Clone)] -pub struct Bls12381Poseidon(PoseidonSponge); - -impl Default for Bls12381Poseidon { - fn default() -> Self { - Self(PoseidonSponge { - parameters: BLS12381POSEIDON_CONF.clone(), - state: vec![F::zero(); BLS12381POSEIDON_CONF.rate + BLS12381POSEIDON_CONF.capacity], - }) - } -} - -impl std::ops::Index for Bls12381Poseidon { - type Output = F; - - fn index(&self, index: usize) -> &Self::Output { - &self.0.state[index] - } -} - -impl std::ops::IndexMut for Bls12381Poseidon { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.0.state[index] - } -} - -impl std::ops::Index> for Bls12381Poseidon { - type Output = [F]; - - fn index(&self, index: RangeTo) -> &Self::Output { - &self.0.state[index] - } -} - -impl std::ops::IndexMut> for Bls12381Poseidon { - fn index_mut(&mut self, index: RangeTo) -> &mut Self::Output { - &mut self.0.state[index] - } -} - -impl std::ops::Index> for Bls12381Poseidon { - type Output = [F]; - - fn index(&self, index: std::ops::Range) -> &Self::Output { - &self.0.state[index] - } -} - -impl std::ops::IndexMut> for Bls12381Poseidon { - fn index_mut(&mut self, index: std::ops::Range) -> &mut Self::Output { - &mut self.0.state[index] - } -} - -impl std::ops::Index> for Bls12381Poseidon { - type Output = [F]; - - fn index(&self, index: std::ops::RangeFrom) -> &Self::Output { - &self.0.state[index] - } -} - -impl std::ops::IndexMut> for Bls12381Poseidon { - fn index_mut(&mut self, index: std::ops::RangeFrom) -> &mut Self::Output { - &mut self.0.state[index] - } -} - -impl zeroize::Zeroize for Bls12381Poseidon { - fn zeroize(&mut self) { - self.0.state.zeroize(); - } -} - -impl crate::hash::sponge::Sponge for Bls12381Poseidon { - type U = F; - - const CAPACITY: usize = 1; - - const RATE: usize = 2; - - fn new(iv: [u8; 32]) -> Self { - assert!(Self::CAPACITY >= 1); - let mut ark_sponge = Self::default(); - ark_sponge.0.state[Self::RATE] = F::from_be_bytes_mod_order(&iv); - ark_sponge - } - - fn permute(&mut self) { - self.0.permute(); - } -} - -pub type PoseidonHash = crate::hash::sponge::DuplexSponge; diff --git a/src/plugins/ark/reader.rs b/src/plugins/ark/reader.rs index 03f7542..236dae1 100644 --- a/src/plugins/ark/reader.rs +++ b/src/plugins/ark/reader.rs @@ -1,5 +1,9 @@ +use ark_ec::short_weierstrass::{Affine as SWAffine, Projective as SWCurve, SWCurveConfig}; +use ark_ec::twisted_edwards::{Affine as EdwardsAffine, Projective as EdwardsCurve, TECurveConfig}; use ark_ec::CurveGroup; -use ark_ff::{Fp, FpConfig, PrimeField}; +use ark_ff::Field; +use ark_ff::{Fp, FpConfig}; +use ark_serialize::CanonicalDeserialize; use super::{FieldReader, GroupReader}; use crate::traits::*; @@ -7,7 +11,7 @@ use crate::{DuplexHash, Merlin, ProofResult}; impl<'a, F, H> FieldReader for Merlin<'a, H> where - F: PrimeField, + F: Field, H: DuplexHash, { fn fill_next_scalars(&mut self, output: &mut [F]) -> ProofResult<()> { @@ -38,13 +42,45 @@ where } } -impl<'a, G, H, C, const N: usize> GroupReader for Merlin<'a, H, Fp> +impl<'a, H, C, const N: usize> FieldReader> for Merlin<'a, H, Fp> where C: FpConfig, - G: CurveGroup>, - H: DuplexHash, + H: DuplexHash>, { - fn fill_next_points(&mut self, _output: &mut [G]) -> ProofResult<()> { - todo!() + fn fill_next_scalars(&mut self, output: &mut [Fp]) -> crate::ProofResult<()> { + self.fill_next(output)?; + Ok(()) + } +} + +impl<'a, P, H, C, const N: usize> GroupReader> for Merlin<'a, H, Fp> +where + C: FpConfig, + H: DuplexHash>, + P: TECurveConfig>, +{ + fn fill_next_points(&mut self, output: &mut [EdwardsCurve

]) -> ProofResult<()> { + for o in output.iter_mut() { + let o_affine = EdwardsAffine::deserialize_compressed(&mut self.transcript)?; + *o = o_affine.into(); + self.public_units(&[o.x, o.y])?; + } + Ok(()) + } +} + +impl<'a, P, H, C, const N: usize> GroupReader> for Merlin<'a, H, Fp> +where + C: FpConfig, + H: DuplexHash>, + P: SWCurveConfig>, +{ + fn fill_next_points(&mut self, output: &mut [SWCurve

]) -> ProofResult<()> { + for o in output.iter_mut() { + let o_affine = SWAffine::deserialize_compressed(&mut self.transcript)?; + *o = o_affine.into(); + self.public_units(&[o.x, o.y])?; + } + Ok(()) } } diff --git a/src/plugins/ark/writer.rs b/src/plugins/ark/writer.rs index 987c149..ba51eb6 100644 --- a/src/plugins/ark/writer.rs +++ b/src/plugins/ark/writer.rs @@ -1,9 +1,10 @@ use ark_ec::CurveGroup; use ark_ff::{Fp, FpConfig, PrimeField}; +use ark_serialize::CanonicalSerialize; use rand::{CryptoRng, RngCore}; use super::{FieldPublic, FieldWriter, GroupPublic, GroupWriter}; -use crate::{Arthur, DuplexHash, ProofResult}; +use crate::{Arthur, DuplexHash, ProofResult, UnitTranscript}; impl FieldWriter for Arthur { fn add_scalars(&mut self, input: &[F]) -> ProofResult<()> { @@ -13,10 +14,23 @@ impl FieldWriter for Ar } } +impl, H: DuplexHash>, R: RngCore + CryptoRng, const N: usize> + FieldWriter> for Arthur> +{ + fn add_scalars(&mut self, input: &[Fp]) -> ProofResult<()> { + self.public_units(input)?; + for i in input { + i.serialize_compressed(&mut self.transcript)?; + } + Ok(()) + } +} + impl GroupWriter for Arthur where G: CurveGroup, H: DuplexHash, + G::BaseField: PrimeField, R: RngCore + CryptoRng, Arthur: GroupPublic>, { @@ -28,12 +42,13 @@ where } } -impl, const N: usize> GroupWriter for Arthur> +impl, C2: FpConfig, const N: usize> GroupWriter + for Arthur> where - G: CurveGroup, + G: CurveGroup>, H: DuplexHash>, R: RngCore + CryptoRng, - Arthur>: GroupPublic, + Arthur>: GroupPublic + FieldWriter, { #[inline(always)] fn add_points(&mut self, input: &[G]) -> ProofResult<()> { diff --git a/src/plugins/traits.rs b/src/plugins/traits.rs index c37c83d..a97e5ad 100644 --- a/src/plugins/traits.rs +++ b/src/plugins/traits.rs @@ -21,11 +21,11 @@ macro_rules! field_traits { fn public_scalars(&mut self, input: &[F]) -> crate::ProofResult; } - pub trait FieldWriter: FieldChallenges + FieldPublic { + pub trait FieldWriter: FieldPublic { fn add_scalars(&mut self, input: &[F]) -> crate::ProofResult<()>; } - pub trait FieldReader: FieldChallenges + FieldPublic { + pub trait FieldReader: FieldPublic { fn fill_next_scalars(&mut self, output: &mut [F]) -> crate::ProofResult<()>; fn next_scalars(&mut self) -> crate::ProofResult<[F; N]> { @@ -38,7 +38,7 @@ macro_rules! field_traits { #[macro_export] macro_rules! group_traits { - ($Group:path, $ScalarField:path : $Field:path) => { + ($Group:path, $BaseField:path : $Field:path) => { pub trait GroupIOPattern { fn add_points(self, count: usize, label: &str) -> Self; }