diff --git a/mithril-client/src/certificate_client/verify.rs b/mithril-client/src/certificate_client/verify.rs index 04fbeb3de8..a1b95d3e4b 100644 --- a/mithril-client/src/certificate_client/verify.rs +++ b/mithril-client/src/certificate_client/verify.rs @@ -104,7 +104,7 @@ impl MithrilCertificateVerifier { Ok(None) } - async fn verify_one( + async fn verify_with_cache( &self, certificate_chain_validation_id: &str, certificate: CertificateToVerify, @@ -123,23 +123,25 @@ impl MithrilCertificateVerifier { hash: previous_hash, })) } else { - self.verify_not_cached_certificate(certificate_chain_validation_id, certificate) - .await + let certificate = match certificate { + CertificateToVerify::Downloaded { certificate } => certificate, + CertificateToVerify::ToDownload { hash } => { + self.retriever.get_certificate_details(&hash).await? + } + }; + + let previous_certificate = self + .verify_without_cache(certificate_chain_validation_id, certificate) + .await?; + Ok(previous_certificate.map(Into::into)) } } - async fn verify_not_cached_certificate( + async fn verify_without_cache( &self, certificate_chain_validation_id: &str, - certificate: CertificateToVerify, - ) -> MithrilResult> { - let certificate = match certificate { - CertificateToVerify::Downloaded { certificate } => certificate, - CertificateToVerify::ToDownload { hash } => { - self.retriever.get_certificate_details(&hash).await? - } - }; - + certificate: Certificate, + ) -> MithrilResult> { let previous_certificate = self .internal_verifier .verify_certificate(&certificate, &self.genesis_verification_key) @@ -162,7 +164,7 @@ impl MithrilCertificateVerifier { }) .await; - Ok(previous_certificate.map(|cert| CertificateToVerify::Downloaded { certificate: cert })) + Ok(previous_certificate) } } @@ -182,6 +184,12 @@ impl CertificateToVerify { } } +impl From for CertificateToVerify { + fn from(value: Certificate) -> Self { + Self::Downloaded { certificate: value } + } +} + #[cfg_attr(target_family = "wasm", async_trait(?Send))] #[cfg_attr(not(target_family = "wasm"), async_trait)] impl CertificateVerifier for MithrilCertificateVerifier { @@ -196,22 +204,36 @@ impl CertificateVerifier for MithrilCertificateVerifier { }) .await; - // always validate the given certificate even if it was cached - let mut current_certificate = self - .verify_not_cached_certificate( - &certificate_chain_validation_id, - CertificateToVerify::Downloaded { - certificate: certificate.clone().try_into()?, - }, - ) - .await?; + // Validate certificates without cache until we cross an epoch boundary + // This is necessary to ensure that the AVK of the current epoch was included + // in a certificate in the previous epoch. + let start_epoch = certificate.epoch; + let mut current_certificate: Option = Some(certificate.clone().try_into()?); + loop { + match current_certificate { + None => break, + Some(next) => { + current_certificate = self + .verify_without_cache(&certificate_chain_validation_id, next) + .await?; + + let has_crossed_epoch_boundary = current_certificate + .as_ref() + .is_some_and(|c| c.epoch != start_epoch); + if has_crossed_epoch_boundary { + break; + } + } + } + } + let mut current_certificate = current_certificate.map(Into::into); loop { match current_certificate { None => break, Some(next) => { current_certificate = self - .verify_one(&certificate_chain_validation_id, next) + .verify_with_cache(&certificate_chain_validation_id, next) .await? } } @@ -304,6 +326,7 @@ mod tests { #[cfg(feature = "unstable")] mod cache { use chrono::TimeDelta; + use mithril_common::test_utils::CertificateChainingMethod; use mockall::predicate::eq; use crate::aggregator_client::MockAggregatorHTTPClient; @@ -349,7 +372,7 @@ mod tests { ); verifier - .verify_one( + .verify_with_cache( "certificate_chain_validation_id", CertificateToVerify::Downloaded { certificate: genesis_certificate.clone(), @@ -385,7 +408,7 @@ mod tests { ); verifier - .verify_one( + .verify_with_cache( "certificate_chain_validation_id", CertificateToVerify::Downloaded { certificate: certificate.clone(), @@ -429,28 +452,64 @@ mod tests { } #[tokio::test] - async fn verification_of_first_certificate_of_a_chain_should_not_use_cache() { + async fn verification_of_certificates_should_not_use_cache_until_crossing_an_epoch_boundary( + ) { + // Chain produced: + // Cert epoch 3 n°6 - parent cert n°5 + // Cert epoch 3 n°5 - parent cert n°4 + // Cert epoch 2 n°4 - parent cert n°3 + // Cert epoch 2 n°3 - parent cert n°2 + // Cert epoch 2 n°2 - parent cert n°1 + // Cert epoch 1 n°1 - genesis let (chain, verifier) = CertificateChainBuilder::new() - .with_total_certificates(2) - .with_certificates_per_epoch(1) + .with_total_certificates(6) + .with_certificates_per_epoch(3) + .with_certificate_chaining_method(CertificateChainingMethod::Sequential) .build(); + let first_certificate = chain.first().unwrap(); let genesis_certificate = chain.last().unwrap(); assert!(genesis_certificate.is_genesis()); + // The two certificates on the last epoch plus the last of the second epoch must be + // fetched from the network as we need to cross an epoch boundary + let certificates_that_must_not_be_fetched_from_cache = chain[..3].to_vec(); + let certificate_that_can_be_fetched_from_cache = chain[2..5].to_vec(); + let mut cache = MockCertificateVerifierCache::new(); - cache.expect_get_previous_hash().returning(|_| Ok(None)); + + for certificate in &certificates_that_must_not_be_fetched_from_cache { + cache + .expect_get_previous_hash() + .with(eq(certificate.hash.clone())) + .never(); + } + for certificate in certificate_that_can_be_fetched_from_cache { + let previous_hash = certificate.previous_hash.clone(); + cache + .expect_get_previous_hash() + .with(eq(certificate.hash.clone())) + .return_once(|_| Ok(Some(previous_hash))) + .once(); + } cache .expect_get_previous_hash() - .with(eq(first_certificate.hash.clone())) - .never(); + .with(eq(genesis_certificate.hash.clone())) + .returning(|_| Ok(None)); + cache .expect_store_validated_certificate() .returning(|_, _| Ok(())); let certificate_client = CertificateClientTestBuilder::default() .config_aggregator_client_mock(|mock| { - mock.expect_certificate_chain(chain.clone()); + mock.expect_certificate_chain( + [ + certificates_that_must_not_be_fetched_from_cache, + vec![genesis_certificate.clone()], + ] + .concat(), + ); }) .with_genesis_verification_key(verifier.to_verification_key()) .with_verifier_cache(Arc::new(cache)) @@ -470,14 +529,14 @@ mod tests { .build(); let last_certificate_hash = chain.first().unwrap().hash.clone(); - // All certificates are cached except the last and the genesis (we always fetch the both) + // All certificates are cached except the last two (to cross an epoch boundary) and the genesis (we always fetch the both) let cache = MemoryCertificateVerifierCache::new(TimeDelta::hours(3)) - .with_items_from_chain(&chain[1..4]); + .with_items_from_chain(&chain[2..4]); let certificate_client = CertificateClientTestBuilder::default() .config_aggregator_client_mock(|mock| { mock.expect_certificate_chain( - [chain[0..2].to_vec(), vec![chain.last().unwrap().clone()]].concat(), + [chain[0..3].to_vec(), vec![chain.last().unwrap().clone()]].concat(), ) }) .with_genesis_verification_key(verifier.to_verification_key())