diff --git a/primitives/src/vid/advz.rs b/primitives/src/vid/advz.rs index fab0bd425..0db7978a0 100644 --- a/primitives/src/vid/advz.rs +++ b/primitives/src/vid/advz.rs @@ -1024,7 +1024,7 @@ mod tests { #[test] fn sad_path_verify_share_corrupt_share() { - let (mut advz, bytes_random) = advz_init(); + let (mut advz, bytes_random, _) = advz_init(); let disperse = advz.disperse(bytes_random).unwrap(); let (shares, common, commit) = (disperse.shares, disperse.common, disperse.commit); @@ -1090,7 +1090,7 @@ mod tests { #[test] fn sad_path_verify_share_corrupt_commit() { - let (mut advz, bytes_random) = advz_init(); + let (mut advz, bytes_random, _) = advz_init(); let disperse = advz.disperse(bytes_random).unwrap(); let (shares, common, commit) = (disperse.shares, disperse.common, disperse.commit); @@ -1136,7 +1136,7 @@ mod tests { #[test] fn sad_path_verify_share_corrupt_share_and_commit() { - let (mut advz, bytes_random) = advz_init(); + let (mut advz, bytes_random, _) = advz_init(); let disperse = advz.disperse(bytes_random).unwrap(); let (mut shares, mut common, commit) = (disperse.shares, disperse.common, disperse.commit); @@ -1161,8 +1161,8 @@ mod tests { #[test] fn sad_path_recover_payload_corrupt_shares() { - let (mut advz, bytes_random) = advz_init(); - let disperse = advz.disperse(&bytes_random).unwrap(); + let (mut advz, bytes_random, _) = advz_init(); + let disperse = advz.disperse(bytes_random.clone()).unwrap(); let (shares, common) = (disperse.shares, disperse.common); { @@ -1221,13 +1221,13 @@ mod tests { /// Returns the following tuple: /// 1. An initialized [`Advz`] instance. /// 2. A `Vec` filled with random bytes. - pub(super) fn advz_init() -> (Advz, Vec) { + pub(super) fn advz_init() -> (Advz, Vec, u32) { let (recovery_threshold, num_storage_nodes) = (4, 6); let mut rng = jf_utils::test_rng(); let srs = init_srs(recovery_threshold as usize, &mut rng); let advz = Advz::new(num_storage_nodes, recovery_threshold, srs).unwrap(); let bytes_random = init_random_payload(4000, &mut rng); - (advz, bytes_random) + (advz, bytes_random, num_storage_nodes) } /// Convenience wrapper to assert [`VidError::Argument`] return value. diff --git a/primitives/src/vid/advz/precomputable.rs b/primitives/src/vid/advz/precomputable.rs index 2ae8ff3db..9c64dc54b 100644 --- a/primitives/src/vid/advz/precomputable.rs +++ b/primitives/src/vid/advz/precomputable.rs @@ -15,15 +15,18 @@ use crate::{ MaybeGPU, Pairing, PolynomialMultiplier, UnivariateKzgPCS, }, precomputable::Precomputable, - vid, VidDisperse, VidResult, + vid, VidDisperse, VidError, VidResult, }, }; +use alloc::string::ToString; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::{end_timer, start_timer, vec::Vec}; use jf_utils::canonical; use serde::{Deserialize, Serialize}; +use super::Advz; + impl Precomputable for AdvzInternal where E: Pairing, @@ -141,6 +144,26 @@ where commit, }) } + + fn is_consistent_precompute( + commit: &Self::Commit, + precompute_data: &Self::PrecomputeData, + payload_byte_len: u32, + num_storage_nodes: u32, + ) -> VidResult<()> { + if *commit + != Advz::::derive_commit( + &precompute_data.poly_commits, + payload_byte_len, + num_storage_nodes, + )? + { + return Err(VidError::Argument( + "precompute data inconsistent with commit".to_string(), + )); + } + Ok(()) + } } #[derive( @@ -222,7 +245,7 @@ mod tests { #[test] fn commit_disperse_recover_with_precomputed_data() { - let (advz, bytes_random) = advz_init(); + let (advz, bytes_random, _) = advz_init(); let (commit, data) = advz.commit_only_precompute(&bytes_random).unwrap(); let disperse = advz.disperse_precompute(&bytes_random, &data).unwrap(); let (shares, common) = (disperse.shares, disperse.common); @@ -236,4 +259,17 @@ mod tests { .expect("recover_payload should succeed"); assert_eq!(bytes_recovered, bytes_random); } + + #[test] + fn commit_and_verify_consistent_precomputed_data() { + let (advz, bytes_random, num_storage_nodes) = advz_init(); + let (commit, data) = advz.commit_only_precompute(&bytes_random).unwrap(); + assert!(Advz::is_consistent_precompute( + &commit, + &data, + bytes_random.len() as u32, + num_storage_nodes + ) + .is_ok()) + } } diff --git a/primitives/src/vid/precomputable.rs b/primitives/src/vid/precomputable.rs index 584704770..564fa0ec0 100644 --- a/primitives/src/vid/precomputable.rs +++ b/primitives/src/vid/precomputable.rs @@ -36,4 +36,13 @@ pub trait Precomputable: VidScheme { ) -> VidResult> where B: AsRef<[u8]>; + + /// Check that a [`Precomputable::PrecomputeData`] is consistent with a + /// [`VidScheme::Commit`]. + fn is_consistent_precompute( + commit: &Self::Commit, + precompute_data: &Self::PrecomputeData, + payload_byte_len: u32, + num_storage_nodes: u32, + ) -> VidResult<()>; }