Skip to content

Commit

Permalink
Introduce ProofError/ProofResult for better error handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmaker committed Jan 18, 2024
1 parent 87a0129 commit 80c8ef7
Show file tree
Hide file tree
Showing 13 changed files with 140 additions and 105 deletions.
4 changes: 2 additions & 2 deletions examples/bulletproof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use ark_ec::{AffineRepr, CurveGroup, VariableBaseMSM};
use ark_ff::Field;
use ark_std::log2;
use nimue::plugins::arkworks::prelude::*;
use nimue::{DuplexHash, InvalidTag};
use nimue::{DuplexHash, IOPatternError};
use rand::rngs::OsRng;

fn fold_generators<A: AffineRepr>(
Expand Down Expand Up @@ -126,7 +126,7 @@ fn verify<G, H>(
generators: (&[G::Affine], &[G::Affine], &G::Affine),
mut n: usize,
statement: &G,
) -> Result<(), InvalidTag>
) -> Result<(), IOPatternError>
where
H: DuplexHash<u8>,
G: CurveGroup,
Expand Down
4 changes: 2 additions & 2 deletions examples/schnorr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use ark_ec::{CurveGroup, PrimeGroup};
use ark_std::UniformRand;
use nimue::{DuplexHash, InvalidTag};
use nimue::{DuplexHash, IOPatternError};

use nimue::plugins::arkworks::prelude::*;
use rand::rngs::OsRng;
Expand All @@ -14,7 +14,7 @@ fn keygen<G: CurveGroup>() -> (G::ScalarField, G) {
fn prove<H: DuplexHash<u8>, G: CurveGroup>(
arthur: &mut ArkGroupArthur<G, H>,
witness: G::ScalarField,
) -> Result<&[u8], InvalidTag> {
) -> Result<&[u8], IOPatternError> {
let k = G::ScalarField::rand(&mut arthur.rng());
let commitment = G::generator() * k;
arthur.add_points(&[commitment])?;
Expand Down
14 changes: 7 additions & 7 deletions src/arthur.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::hash::Unit;
use crate::{IOPattern, Safe};

use super::hash::{DuplexHash, Keccak};
use super::{DefaultHash, DefaultRng, InvalidTag};
use super::{DefaultHash, DefaultRng, IOPatternError};

/// A cryptographically-secure random number generator that is bound to the protocol transcript.
///
Expand Down Expand Up @@ -100,7 +100,7 @@ where

impl<R: RngCore + CryptoRng, U: Unit, H: DuplexHash<U>> Arthur<H, R, U> {
#[inline(always)]
pub fn add(&mut self, input: &[U]) -> Result<(), InvalidTag> {
pub fn add(&mut self, input: &[U]) -> Result<(), IOPatternError> {
// let serialized = bincode::serialize(input).unwrap();
// self.arthur.sponge.absorb_unchecked(&serialized);
let old_len = self.transcript.len();
Expand All @@ -114,19 +114,19 @@ impl<R: RngCore + CryptoRng, U: Unit, H: DuplexHash<U>> Arthur<H, R, U> {
Ok(())
}

pub fn public(&mut self, input: &[U]) -> Result<(), InvalidTag> {
pub fn public(&mut self, input: &[U]) -> Result<(), IOPatternError> {
let len = self.transcript.len();
self.add(input)?;
self.transcript.truncate(len);
Ok(())
}

pub fn challenge(&mut self, output: &mut [U]) -> Result<(), InvalidTag> {
pub fn challenge(&mut self, output: &mut [U]) -> Result<(), IOPatternError> {
self.safe.squeeze(output)
}

#[inline(always)]
pub fn ratchet(&mut self) -> Result<(), InvalidTag> {
pub fn ratchet(&mut self) -> Result<(), IOPatternError> {
self.safe.ratchet()
}

Expand All @@ -150,12 +150,12 @@ impl<R: RngCore + CryptoRng, U: Unit, H: DuplexHash<U>> core::fmt::Debug for Art

impl<H: DuplexHash<u8>, R: RngCore + CryptoRng> Arthur<H, R, u8> {
#[inline(always)]
pub fn add_bytes(&mut self, input: &[u8]) -> Result<(), InvalidTag> {
pub fn add_bytes(&mut self, input: &[u8]) -> Result<(), IOPatternError> {
self.add(input)
}

#[inline(always)]
pub fn challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), InvalidTag> {
pub fn challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), IOPatternError> {
self.safe.squeeze(output)
}
}
40 changes: 34 additions & 6 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,56 @@
use std::{error::Error, fmt::Display};
use std::{error::Error, fmt::Display, borrow::Borrow};

/// Signals an invalid IO pattern.
///
/// This error indicates a wrong IO Pattern declared
/// upon instantiation of the SAFE sponge.
#[derive(Debug, Clone)]
pub struct InvalidTag(String);
pub struct IOPatternError(String);

impl Display for InvalidTag {
#[derive(Debug, Clone)]
pub enum ProofError {
InvalidProof,
InvalidIO(IOPatternError),
SerializationError,
}


pub type ProofResult<T> = Result<T, ProofError>;

impl Display for IOPatternError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.0)
}
}

impl Error for InvalidTag {}
impl Display for ProofError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SerializationError =>
write!(f, "Serialization Error"),
Self::InvalidIO(e) => e.fmt(f),
Self::InvalidProof => write!(f, "Invalid proof")
}
}
}

impl Error for IOPatternError {}
impl Error for ProofError {}

impl From<&str> for InvalidTag {
impl From<&str> for IOPatternError {
fn from(s: &str) -> Self {
s.to_string().into()
}
}

impl From<String> for InvalidTag {
impl From<String> for IOPatternError {
fn from(s: String) -> Self {
Self(s)
}
}

impl<B: Borrow<IOPatternError>> From<B> for ProofError {
fn from(value: B) -> Self {
ProofError::InvalidIO(value.borrow().clone())
}
}
29 changes: 16 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,31 +118,34 @@ This crate doesn't support big-endian targets.
"#
);

/// Built-in proof results
mod errors;
/// Support for hash functions.
pub mod hash;
/// APIs for common zkp libraries.
pub mod plugins;
/// Prover's internal state and transcript generation.
mod arthur;
/// Verifier state and transcript deserialization.
mod merlin;
/// SAFE API
mod safe;
/// Unit-tests.
#[cfg(test)]
mod tests;


pub use arthur::Arthur;
pub use errors::InvalidTag;
pub use hash::DuplexHash;
pub use errors::IOPatternError;
pub use hash::{DuplexHash, Unit};
pub use merlin::Merlin;
pub use safe::{IOPattern, Safe};
pub use errors::{ProofResult, ProofError};

// Default random number generator used ([`rand::rngs::OsRng`])
pub type DefaultRng = rand::rngs::OsRng;

/// Default hash function used ([`hash::Keccak`])
pub type DefaultHash = hash::Keccak;

/// Prover's internal state and transcript generation.
mod arthur;
/// Error types.
mod errors;
/// Verifier state and transcript deserialization.
mod merlin;
/// SAFE API
mod safe;
/// Unit-tests.
#[cfg(test)]
mod tests;

20 changes: 10 additions & 10 deletions src/merlin.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::hash::DuplexHash;
use crate::DefaultHash;

use crate::errors::InvalidTag;
use crate::errors::IOPatternError;
use crate::hash::Unit;
use crate::safe::{IOPattern, Safe};

Expand All @@ -28,31 +28,31 @@ impl<'a, U: Unit, H: DuplexHash<U>> Merlin<'a, H, U> {

/// Read `input.len()` elements from the transcript.
#[inline(always)]
pub fn fill_next(&mut self, input: &mut [U]) -> Result<(), InvalidTag> {
pub fn fill_next(&mut self, input: &mut [U]) -> Result<(), IOPatternError> {
U::read(&mut self.transcript, input).unwrap();
self.safe.absorb(input)
}

#[inline(always)]
pub fn public_input(&mut self, input: &[U]) -> Result<(), InvalidTag> {
pub fn public_input(&mut self, input: &[U]) -> Result<(), IOPatternError> {
self.safe.absorb(input)
}

/// Get a challenge of `count` elements.
#[inline(always)]
pub fn fill_challenges(&mut self, input: &mut [U]) -> Result<(), InvalidTag> {
pub fn fill_challenges(&mut self, input: &mut [U]) -> Result<(), IOPatternError> {
self.safe.squeeze(input)
}

/// Signals the end of the statement.
#[inline(always)]
pub fn ratchet(&mut self) -> Result<(), InvalidTag> {
pub fn ratchet(&mut self) -> Result<(), IOPatternError> {
self.safe.ratchet()
}

/// Signals the end of the statement and returns the (compressed) sponge state.
#[inline(always)]
pub fn preprocess(self) -> Result<&'static [U], InvalidTag> {
pub fn preprocess(self) -> Result<&'static [U], IOPatternError> {
self.safe.preprocess()
}
}
Expand All @@ -65,21 +65,21 @@ impl<'a, H: DuplexHash<U>, U: Unit> core::fmt::Debug for Merlin<'a, H, U> {

impl<'a, H: DuplexHash<u8>> Merlin<'a, H, u8> {
#[inline(always)]
pub fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), InvalidTag> {
pub fn fill_next_bytes(&mut self, input: &mut [u8]) -> Result<(), IOPatternError> {
self.fill_next(input)
}

#[inline(always)]
pub fn fill_challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), InvalidTag> {
pub fn fill_challenge_bytes(&mut self, output: &mut [u8]) -> Result<(), IOPatternError> {
self.fill_challenges(output)
}

pub fn next_bytes<const N: usize>(&mut self) -> Result<[u8; N], InvalidTag> {
pub fn next_bytes<const N: usize>(&mut self) -> Result<[u8; N], IOPatternError> {
let mut input = [0u8; N];
self.fill_next_bytes(&mut input).map(|()| input)
}

pub fn challenge_bytes<const N: usize>(&mut self) -> Result<[u8; N], InvalidTag> {
pub fn challenge_bytes<const N: usize>(&mut self) -> Result<[u8; N], IOPatternError> {
let mut output = [0u8; N];
self.fill_challenge_bytes(&mut output).map(|()| output)
}
Expand Down
32 changes: 15 additions & 17 deletions src/plugins/arkworks/arthur.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ark_ec::CurveGroup;
use ark_ff::{Field, PrimeField};
use rand::CryptoRng;

use super::prelude::*;
use crate::{ProofResult, Arthur, DuplexHash, IOPattern, Unit};

pub struct ArkFieldArthur<F: Field, H = crate::DefaultHash, U = u8, R = crate::DefaultRng>
where
Expand Down Expand Up @@ -73,24 +73,23 @@ where
H: DuplexHash<u8>,
R: rand::RngCore + CryptoRng,
{
pub fn public_scalars(&mut self, input: &[F]) -> Result<Vec<u8>, InvalidTag> {
pub fn public_scalars(&mut self, input: &[F]) -> ProofResult<Vec<u8>> {
let mut buf = Vec::<u8>::new();

for scalar in input {
scalar
.serialize_compressed(&mut buf)
.expect("serialization failed");
.serialize_compressed(&mut buf)?
}
self.public(&buf).map(|()| buf)
self.public(&buf).map(|()| buf).map_err(|x| x.into())
}

fn add_scalars(&mut self, input: &[F]) -> Result<(), InvalidTag> {
fn add_scalars(&mut self, input: &[F]) -> ProofResult<()> {
let serialized = self.public_scalars(input);
self.arthur.transcript.extend(serialized?);
Ok(())
}

fn fill_challenge_scalars(&mut self, output: &mut [F]) -> Result<(), InvalidTag> {
fn fill_challenge_scalars(&mut self, output: &mut [F]) -> ProofResult<()> {
let mut buf = vec![0u8; super::f_bytes::<F>()];
for o in output.iter_mut() {
self.arthur.challenge_bytes(&mut buf)?;
Expand All @@ -99,7 +98,7 @@ where
Ok(())
}

pub fn challenge_scalars<const N: usize>(&mut self) -> Result<[F; N], InvalidTag> {
pub fn challenge_scalars<const N: usize>(&mut self) -> ProofResult<[F; N]> {
let mut output = [F::default(); N];
self.fill_challenge_scalars(&mut output)?;
Ok(output)
Expand Down Expand Up @@ -168,36 +167,35 @@ where
Arthur::new(io, csrng).into()
}

pub fn public_scalars(&mut self, input: &[G::ScalarField]) -> Result<Vec<u8>, InvalidTag> {
pub fn public_scalars(&mut self, input: &[G::ScalarField]) -> ProofResult<Vec<u8>> {
self.arthur.public_scalars(input)
}

pub fn add_scalars(&mut self, input: &[G::ScalarField]) -> Result<(), InvalidTag> {
pub fn add_scalars(&mut self, input: &[G::ScalarField]) -> ProofResult<()> {
self.arthur.add_scalars(input)
}

pub fn fill_challenge_scalars(
&mut self,
output: &mut [G::ScalarField],
) -> Result<(), InvalidTag> {
) -> ProofResult<()> {
self.arthur.fill_challenge_scalars(output)
}

pub fn challenge_scalars<const N: usize>(&mut self) -> Result<[G::ScalarField; N], InvalidTag> {
pub fn challenge_scalars<const N: usize>(&mut self) -> ProofResult<[G::ScalarField; N]> {
self.arthur.challenge_scalars()
}

pub fn public_points(&mut self, input: &[G]) -> Result<Vec<u8>, InvalidTag> {
pub fn public_points(&mut self, input: &[G]) -> ProofResult<Vec<u8>> {
let mut buf = Vec::new();
for point in input {
point
.serialize_compressed(&mut buf)
.expect("serialization failed");
.serialize_compressed(&mut buf)?
}
self.arthur.public(&buf).map(|()| buf)
self.arthur.public(&buf).map(|()| buf).map_err(|x| x.into())
}

pub fn add_points(&mut self, input: &[G]) -> Result<(), InvalidTag> {
pub fn add_points(&mut self, input: &[G]) -> ProofResult<()> {
let serialized = self.public_points(input);
self.arthur.transcript.extend(serialized?);
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/arkworks/iopattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ark_ec::CurveGroup;
use ark_ff::{Field, PrimeField};
use core::ops::Deref;

use super::prelude::*;
use super::*;

pub struct ArkFieldIOPattern<F: Field, H = crate::DefaultHash, U = u8>
where
Expand Down
Loading

0 comments on commit 80c8ef7

Please sign in to comment.