diff --git a/src/ml_sumcheck/mod.rs b/src/ml_sumcheck/mod.rs index 4fd2f9f..735f0ab 100644 --- a/src/ml_sumcheck/mod.rs +++ b/src/ml_sumcheck/mod.rs @@ -62,6 +62,9 @@ impl MLSumcheck { prover_msgs.push(prover_msg); verifier_msg = Some(IPForMLSumcheck::sample_round(fs_rng)); } + prover_state + .randomness + .push(verifier_msg.unwrap().randomness); Ok((prover_msgs, prover_state)) } diff --git a/src/ml_sumcheck/test.rs b/src/ml_sumcheck/test.rs index 415f5ff..ae5d63c 100644 --- a/src/ml_sumcheck/test.rs +++ b/src/ml_sumcheck/test.rs @@ -107,7 +107,7 @@ fn test_polynomial_as_subprotocol( let (poly, asserted_sum) = random_list_of_products::(nv, num_multiplicands_range, num_products, &mut rng); let poly_info = poly.info(); - let (proof, _prover_state) = + let (proof, prover_state) = MLSumcheck::prove_as_subprotocol(prover_rng, &poly).expect("fail to prove"); let subclaim = MLSumcheck::verify_as_subprotocol(verifier_rng, &poly_info, asserted_sum, &proof) @@ -116,6 +116,7 @@ fn test_polynomial_as_subprotocol( poly.evaluate(&subclaim.point) == subclaim.expected_evaluation, "wrong subclaim" ); + assert_eq!(prover_state.randomness, subclaim.point); } #[test]