diff --git a/examples/sumcheck.rs b/examples/sumcheck.rs index 9917cfc..095dbff 100644 --- a/examples/sumcheck.rs +++ b/examples/sumcheck.rs @@ -23,7 +23,7 @@ where fn add_sumcheck(mut self, num_var: usize) -> Self { for _ in 0..num_var { self = self - .add_scalars(2, "partial evaluation") + .add_scalars(1, "partial evaluation, constant term") .challenge_scalars(1, "sumcheck challenge"); } self = self.add_scalars(1, "folded polynomial"); @@ -43,9 +43,9 @@ where let mut partial_poly = polynomial.clone(); for _ in 0..num_var { let eval = partial_poly.to_evaluations(); + // The partial polynomial of each round is of the form b * x + a. let a = eval.iter().step_by(2).sum(); - let b = eval.iter().skip(1).step_by(2).sum(); - merlin.add_scalars(&[a, b])?; + merlin.add_scalars(&[a])?; let [r] = merlin.challenge_scalars()?; partial_poly = partial_poly.fix_variables(&[r]); } @@ -67,12 +67,10 @@ where let mut value = value.clone(); let num_vars = polynomial.num_vars(); for _ in 0..num_vars { - let [a, b] = arthur.next_scalars()?; - if a + b != value { - return Err(ProofError::InvalidProof); - } + let [a] = arthur.next_scalars()?; + let b = value - a - a; let [r] = arthur.challenge_scalars()?; - value = (b - a) * r + a; + value = b * r + a; } let [folded] = arthur.next_scalars()?; if folded != value { @@ -84,11 +82,10 @@ where fn main() { use ark_curve25519::Fq; - use rand::rngs::OsRng; use ark_poly::DenseMultilinearExtension; + use rand::rngs::OsRng; - let num_vars= 4; - + let num_vars = 4; // initialize the IO Pattern putting the domain separator ("example.com") let iopattern = IOPattern::new("example.com");