diff --git a/Cargo.lock b/Cargo.lock
index 4a1c8d090..c31c683d9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2739,6 +2739,7 @@ dependencies = [
  "postgres-types",
  "prio 0.17.0-alpha.0",
  "prometheus",
+ "querystring",
  "quickcheck",
  "quickcheck_macros",
  "rand",
diff --git a/aggregator/Cargo.toml b/aggregator/Cargo.toml
index c0ed1a11d..535fa0a77 100644
--- a/aggregator/Cargo.toml
+++ b/aggregator/Cargo.toml
@@ -29,14 +29,13 @@ test-util = [
     "janus_aggregator_core/test-util",
     "janus_core/test-util",
     "janus_messages/test-util",
-    "dep:assert_matches",
     "dep:testcontainers",
     "dep:trillium-testing",
 ]
 
 [dependencies]
 anyhow.workspace = true
-assert_matches = { workspace = true, optional = true }
+assert_matches = { workspace = true }
 async-trait = { workspace = true }
 aws-lc-rs = { workspace = true }
 backoff = { workspace = true, features = ["tokio"] }
@@ -70,6 +69,7 @@ postgres-protocol = { workspace = true }
 postgres-types = { workspace = true, features = ["derive", "array-impls"] }
 prio.workspace = true
 prometheus = { workspace = true, optional = true }
+querystring = { workspace = true }
 rand = { workspace = true, features = ["min_const_gen"] }
 rayon.workspace = true
 regex = { workspace = true }
diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index fce9da1c1..cfba66a70 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -4,9 +4,10 @@ pub use crate::aggregator::error::Error;
 use crate::{
     aggregator::{
         aggregate_share::compute_aggregate_share,
+        aggregation_job_init::compute_helper_aggregate_init,
         aggregation_job_writer::{
             AggregationJobWriter, AggregationJobWriterMetrics, InitialWrite,
-            ReportAggregationUpdate as _, WritableReportAggregation,
+            ReportAggregationUpdate as _, UpdateWrite, WritableReportAggregation,
         },
         batch_mode::{CollectableBatchMode, UploadableBatchMode},
         error::{
@@ -26,6 +27,7 @@ use crate::{
         report_aggregation_success_counter,
     },
 };
+use aggregation_job_continue::compute_helper_aggregate_continue;
 use aws_lc_rs::{
     digest::{digest, SHA256},
     rand::SystemRandom,
@@ -52,7 +54,7 @@ use janus_aggregator_core::{
         },
         Datastore, Error as DatastoreError, Transaction,
     },
-    task::{self, AggregatorTask, BatchMode, VerifyKey},
+    task::{self, AggregationMode, AggregatorTask, BatchMode},
     taskprov::PeerAggregator,
 };
 #[cfg(feature = "fpvec_bounded_l2")]
@@ -63,8 +65,8 @@ use janus_core::{
     retries::{retry_http_request_notify, HttpResponse},
     time::{Clock, DurationExt, IntervalExt, TimeExt},
     vdaf::{
-        new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128, vdaf_application_context,
-        Prio3SumVecField64MultiproofHmacSha256Aes128, VdafInstance, VERIFY_KEY_LENGTH,
+        new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128,
+        Prio3SumVecField64MultiproofHmacSha256Aes128, VdafInstance,
     },
     Runtime,
 };
@@ -73,10 +75,9 @@ use janus_messages::{
     taskprov::TaskConfig,
     AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq,
     AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobStep,
-    BatchSelector, CollectionJobId, CollectionJobReq, CollectionJobResp, Duration, ExtensionType,
-    HpkeConfig, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
-    PrepareResp, PrepareStepResult, Report, ReportError, ReportIdChecksum, ReportShare, Role,
-    TaskId,
+    BatchSelector, CollectionJobId, CollectionJobReq, CollectionJobResp, Duration, HpkeConfig,
+    HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
+    PrepareResp, Report, ReportError, ReportIdChecksum,
Role, TaskId, }; use opentelemetry::{ metrics::{Counter, Histogram, Meter}, @@ -91,18 +92,16 @@ use prio::{ dp::DifferentialPrivacyStrategy, field::Field64, flp::gadgets::{Mul, ParallelSum}, - topology::ping_pong::{PingPongState, PingPongTopology}, vdaf::{ self, prio3::{Prio3, Prio3Count, Prio3Histogram, Prio3Sum, Prio3SumVec}, }, }; use rand::{thread_rng, Rng}; -use rayon::iter::{IndexedParallelIterator as _, IntoParallelRefIterator as _, ParallelIterator}; use reqwest::Client; use std::{ borrow::Cow, - collections::{HashMap, HashSet}, + collections::HashSet, fmt::Debug, hash::Hash, panic, @@ -110,16 +109,15 @@ use std::{ sync::{Arc, Mutex as SyncMutex}, time::{Duration as StdDuration, Instant}, }; -use tokio::{sync::mpsc, try_join}; -use tracing::{debug, error, info, info_span, trace_span, warn, Level, Span}; +use tokio::try_join; +use tracing::{debug, error, info, warn, Level}; use url::Url; -#[cfg(test)] -mod aggregate_init_tests; pub mod aggregate_share; pub mod aggregation_job_continue; pub mod aggregation_job_creator; pub mod aggregation_job_driver; +pub mod aggregation_job_init; pub mod aggregation_job_writer; pub mod batch_creator; pub mod batch_mode; @@ -162,7 +160,7 @@ pub struct Aggregator { } #[derive(Clone)] -struct AggregatorMetrics { +pub struct AggregatorMetrics { /// Counter tracking the number of failed decryptions while handling the /// `tasks/{task-id}/reports` endpoint. upload_decrypt_failure_counter: Counter, @@ -279,9 +277,9 @@ impl Aggregator { cfg.max_upload_batch_size, cfg.max_upload_batch_write_delay, ), - // If we're in taskprov mode, we can never cache None entries for tasks, since aggregators - // could insert tasks at any time and expect them to be available across all aggregator - // replicas. + // If we're in taskprov mode, we can never cache None entries for tasks, since + // aggregators could insert tasks at any time and expect them to be available across all + // aggregator replicas. !cfg.taskprov_config.enabled, cfg.task_cache_capacity, cfg.task_cache_ttl, @@ -429,7 +427,6 @@ impl Aggregator { task_aggregator .handle_aggregate_init( Arc::clone(&self.datastore), - &self.clock, Arc::clone(&self.hpke_keypairs), &self.metrics, self.cfg.batch_aggregation_shard_count, @@ -492,6 +489,45 @@ impl Aggregator { .await } + async fn handle_aggregate_get( + &self, + task_id: &TaskId, + aggregation_job_id: &AggregationJobId, + auth_token: Option, + taskprov_task_config: Option<&TaskConfig>, + step: AggregationJobStep, + ) -> Result { + let task_aggregator = self + .task_aggregators + .get(task_id) + .await? 
+ .ok_or(Error::UnrecognizedTask(*task_id))?; + if task_aggregator.task.role() != &Role::Helper + || task_aggregator.task.aggregation_mode() != Some(&AggregationMode::Asynchronous) + { + return Err(Error::UnrecognizedTask(*task_id)); + } + + if self.cfg.taskprov_config.enabled && taskprov_task_config.is_some() { + self.taskprov_authorize_request( + &Role::Leader, + task_id, + taskprov_task_config.unwrap(), + auth_token.as_ref(), + ) + .await?; + } else if !task_aggregator + .task + .check_aggregator_auth_token(auth_token.as_ref()) + { + return Err(Error::UnauthorizedRequest(*task_id)); + } + + task_aggregator + .handle_aggregate_get(Arc::clone(&self.datastore), aggregation_job_id, step) + .await + } + async fn handle_aggregate_delete( &self, task_id: &TaskId, @@ -712,7 +748,15 @@ impl Aggregator { u64::from(*task_config.min_batch_size()), *task_config.time_precision(), *peer_aggregator.tolerable_clock_skew(), - task::AggregatorTaskParameters::TaskprovHelper, + task::AggregatorTaskParameters::TaskprovHelper { + aggregation_mode: peer_aggregator.aggregation_mode().copied().ok_or_else( + || { + Error::Internal( + "peer aggregator has no aggregation mode specified".to_string(), + ) + }, + )?, + }, ) .map_err(|err| Error::InvalidTask(*task_id, OptOutReason::TaskParameters(err)))? .with_taskprov_task_info(task_config.task_info().to_vec()), @@ -824,14 +868,12 @@ impl TaskAggregator { let vdaf_ops = match task.vdaf() { VdafInstance::Prio3Count => { let vdaf = Prio3::new_count(2)?; - let verify_key = task.vdaf_verify_key()?; - VdafOps::Prio3Count(Arc::new(vdaf), verify_key) + VdafOps::Prio3Count(Arc::new(vdaf)) } VdafInstance::Prio3Sum { max_measurement } => { let vdaf = Prio3::new_sum(2, *max_measurement)?; - let verify_key = task.vdaf_verify_key()?; - VdafOps::Prio3Sum(Arc::new(vdaf), verify_key) + VdafOps::Prio3Sum(Arc::new(vdaf)) } VdafInstance::Prio3SumVec { @@ -841,10 +883,8 @@ impl TaskAggregator { dp_strategy, } => { let vdaf = Prio3::new_sum_vec(2, *bits, *length, *chunk_length)?; - let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3SumVec( Arc::new(vdaf), - verify_key, vdaf_ops_strategies::Prio3SumVec::from_vdaf_dp_strategy(dp_strategy.clone()), ) } @@ -859,10 +899,8 @@ impl TaskAggregator { let vdaf = new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128::< ParallelSum>, >(*proofs, *bits, *length, *chunk_length)?; - let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3SumVecField64MultiproofHmacSha256Aes128( Arc::new(vdaf), - verify_key, vdaf_ops_strategies::Prio3SumVec::from_vdaf_dp_strategy(dp_strategy.clone()), ) } @@ -873,10 +911,8 @@ impl TaskAggregator { dp_strategy, } => { let vdaf = Prio3::new_histogram(2, *length, *chunk_length)?; - let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3Histogram( Arc::new(vdaf), - verify_key, vdaf_ops_strategies::Prio3Histogram::from_vdaf_dp_strategy(dp_strategy.clone()), ) } @@ -890,10 +926,8 @@ impl TaskAggregator { Prio3FixedPointBoundedL2VecSumBitSize::BitSize16 => { let vdaf: Prio3FixedPointBoundedL2VecSum> = Prio3::new_fixedpoint_boundedl2_vec_sum(2, *length)?; - let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3FixedPoint16BitBoundedL2VecSum( Arc::new(vdaf), - verify_key, vdaf_ops_strategies::Prio3FixedPointBoundedL2VecSum::from_vdaf_dp_strategy( dp_strategy.clone(), ), @@ -902,10 +936,8 @@ impl TaskAggregator { Prio3FixedPointBoundedL2VecSumBitSize::BitSize32 => { let vdaf: Prio3FixedPointBoundedL2VecSum> = Prio3::new_fixedpoint_boundedl2_vec_sum(2, *length)?; - let verify_key = task.vdaf_verify_key()?; 
VdafOps::Prio3FixedPoint32BitBoundedL2VecSum( Arc::new(vdaf), - verify_key, vdaf_ops_strategies::Prio3FixedPointBoundedL2VecSum::from_vdaf_dp_strategy( dp_strategy.clone(), ), @@ -968,27 +1000,25 @@ impl TaskAggregator { async fn handle_aggregate_init( &self, datastore: Arc>, - clock: &C, hpke_keypairs: Arc, metrics: &AggregatorMetrics, batch_aggregation_shard_count: u64, task_counter_shard_count: u64, aggregation_job_id: &AggregationJobId, - require_taskprov_extension: bool, + require_taskbind_extension: bool, log_forbidden_mutations: Option, req_bytes: &[u8], ) -> Result { self.vdaf_ops .handle_aggregate_init( datastore, - clock, hpke_keypairs, metrics, Arc::clone(&self.task), batch_aggregation_shard_count, task_counter_shard_count, aggregation_job_id, - require_taskprov_extension, + require_taskbind_extension, log_forbidden_mutations, req_bytes, ) @@ -1013,12 +1043,23 @@ impl TaskAggregator { batch_aggregation_shard_count, task_counter_shard_count, aggregation_job_id, - Arc::new(req), + req, request_hash, ) .await } + async fn handle_aggregate_get( + &self, + datastore: Arc>, + aggregation_job_id: &AggregationJobId, + step: AggregationJobStep, + ) -> Result { + self.vdaf_ops + .handle_aggregate_get(datastore, Arc::clone(&self.task), aggregation_job_id, step) + .await + } + async fn handle_aggregate_delete( &self, datastore: &Datastore, @@ -1160,33 +1201,22 @@ mod vdaf_ops_strategies { #[allow(clippy::enum_variant_names)] #[derive(Debug)] enum VdafOps { - Prio3Count(Arc, VerifyKey), - Prio3Sum(Arc, VerifyKey), - Prio3SumVec( - Arc, - VerifyKey, - vdaf_ops_strategies::Prio3SumVec, - ), + Prio3Count(Arc), + Prio3Sum(Arc), + Prio3SumVec(Arc, vdaf_ops_strategies::Prio3SumVec), Prio3SumVecField64MultiproofHmacSha256Aes128( Arc>>>, - VerifyKey<32>, vdaf_ops_strategies::Prio3SumVec, ), - Prio3Histogram( - Arc, - VerifyKey, - vdaf_ops_strategies::Prio3Histogram, - ), + Prio3Histogram(Arc, vdaf_ops_strategies::Prio3Histogram), #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint16BitBoundedL2VecSum( Arc>>, - VerifyKey, vdaf_ops_strategies::Prio3FixedPointBoundedL2VecSum, ), #[cfg(feature = "fpvec_bounded_l2")] Prio3FixedPoint32BitBoundedL2VecSum( Arc>>, - VerifyKey, vdaf_ops_strategies::Prio3FixedPointBoundedL2VecSum, ), #[cfg(feature = "test-util")] @@ -1199,35 +1229,32 @@ enum VdafOps { /// specify the VDAF's type, and the name of a const that will be set to the VDAF's verify key /// length, also for explicitly specifying type parameters. macro_rules! 
vdaf_ops_dispatch { - ($vdaf_ops:expr, ($vdaf:pat_param, $verify_key:pat_param, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident, $dp_strategy:ident, $DpStrategy:ident) => $body:tt) => { + ($vdaf_ops:expr, ($vdaf:pat_param, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident, $dp_strategy:ident, $DpStrategy:ident) => $body:tt) => { match $vdaf_ops { - crate::aggregator::VdafOps::Prio3Count(vdaf, verify_key) => { + crate::aggregator::VdafOps::Prio3Count(vdaf) => { let $vdaf = vdaf; - let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3Count; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH_PRIO3; type $DpStrategy = janus_core::dp::NoDifferentialPrivacy; let $dp_strategy = &Arc::new(janus_core::dp::NoDifferentialPrivacy); let body = $body; body } - crate::aggregator::VdafOps::Prio3Sum(vdaf, verify_key) => { + crate::aggregator::VdafOps::Prio3Sum(vdaf) => { let $vdaf = vdaf; - let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH_PRIO3; type $DpStrategy = janus_core::dp::NoDifferentialPrivacy; let $dp_strategy = &Arc::new(janus_core::dp::NoDifferentialPrivacy); let body = $body; body } - crate::aggregator::VdafOps::Prio3SumVec(vdaf, verify_key, _dp_strategy) => { + crate::aggregator::VdafOps::Prio3SumVec(vdaf, _dp_strategy) => { let $vdaf = vdaf; - let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3SumVec; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH_PRIO3; match _dp_strategy { vdaf_ops_strategies::Prio3SumVec::NoDifferentialPrivacy => { type $DpStrategy = janus_core::dp::NoDifferentialPrivacy; @@ -1244,9 +1271,8 @@ macro_rules! vdaf_ops_dispatch { } } - crate::aggregator::VdafOps::Prio3SumVecField64MultiproofHmacSha256Aes128(vdaf, verify_key, _dp_strategy) => { + crate::aggregator::VdafOps::Prio3SumVecField64MultiproofHmacSha256Aes128(vdaf, _dp_strategy) => { let $vdaf = vdaf; - let $verify_key = verify_key; type $Vdaf = ::janus_core::vdaf::Prio3SumVecField64MultiproofHmacSha256Aes128< ::prio::flp::gadgets::ParallelSum< ::prio::field::Field64, @@ -1270,11 +1296,10 @@ macro_rules! vdaf_ops_dispatch { } } - crate::aggregator::VdafOps::Prio3Histogram(vdaf, verify_key, _dp_strategy) => { + crate::aggregator::VdafOps::Prio3Histogram(vdaf, _dp_strategy) => { let $vdaf = vdaf; - let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3Histogram; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH_PRIO3; match _dp_strategy { vdaf_ops_strategies::Prio3Histogram::NoDifferentialPrivacy => { type $DpStrategy = janus_core::dp::NoDifferentialPrivacy; @@ -1295,11 +1320,10 @@ macro_rules! vdaf_ops_dispatch { // Note that the variable `_dp_strategy` is used if `$dp_strategy` // and `$DpStrategy` are given. The underscore suppresses warnings // which occur when `vdaf_ops!` is called without these parameters. 
- crate::aggregator::VdafOps::Prio3FixedPoint16BitBoundedL2VecSum(vdaf, verify_key, _dp_strategy) => { + crate::aggregator::VdafOps::Prio3FixedPoint16BitBoundedL2VecSum(vdaf, _dp_strategy) => { let $vdaf = vdaf; - let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSum>; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH_PRIO3; match _dp_strategy { vdaf_ops_strategies::Prio3FixedPointBoundedL2VecSum::ZCdpDiscreteGaussian(_strategy) => { @@ -1321,11 +1345,10 @@ macro_rules! vdaf_ops_dispatch { // Note that the variable `_dp_strategy` is used if `$dp_strategy` // and `$DpStrategy` are given. The underscore suppresses warnings // which occur when `vdaf_ops!` is called without these parameters. - crate::aggregator::VdafOps::Prio3FixedPoint32BitBoundedL2VecSum(vdaf, verify_key, _dp_strategy) => { + crate::aggregator::VdafOps::Prio3FixedPoint32BitBoundedL2VecSum(vdaf, _dp_strategy) => { let $vdaf = vdaf; - let $verify_key = verify_key; type $Vdaf = ::prio::vdaf::prio3::Prio3FixedPointBoundedL2VecSum>; - const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH; + const $VERIFY_KEY_LENGTH: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH_PRIO3; match _dp_strategy { vdaf_ops_strategies::Prio3FixedPointBoundedL2VecSum::ZCdpDiscreteGaussian(_strategy) => { @@ -1346,7 +1369,6 @@ macro_rules! vdaf_ops_dispatch { #[cfg(feature = "test-util")] crate::aggregator::VdafOps::Fake(vdaf) => { let $vdaf = vdaf; - let $verify_key = &VerifyKey::new([]); type $Vdaf = ::prio::vdaf::dummy::Vdaf; const $VERIFY_KEY_LENGTH: usize = 0; type $DpStrategy = janus_core::dp::NoDifferentialPrivacy; @@ -1357,8 +1379,8 @@ macro_rules! vdaf_ops_dispatch { } }; - ($vdaf_ops:expr, ($vdaf:pat_param, $verify_key:pat_param, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { - vdaf_ops_dispatch!($vdaf_ops, ($vdaf, $verify_key, $Vdaf, $VERIFY_KEY_LENGTH, _unused, _Unused) => $body)}; + ($vdaf_ops:expr, ($vdaf:pat_param, $Vdaf:ident, $VERIFY_KEY_LENGTH:ident) => $body:tt) => { + vdaf_ops_dispatch!($vdaf_ops, ($vdaf, $Vdaf, $VERIFY_KEY_LENGTH, _unused, _Unused) => $body)}; } impl VdafOps { @@ -1374,7 +1396,7 @@ impl VdafOps { ) -> Result<(), Arc> { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_upload_generic::( Arc::clone(vdaf), clock, @@ -1388,7 +1410,7 @@ impl VdafOps { }) } task::BatchMode::LeaderSelected { .. 
} => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_upload_generic::( Arc::clone(vdaf), clock, @@ -1415,23 +1437,21 @@ impl VdafOps { async fn handle_aggregate_init( &self, datastore: Arc>, - clock: &C, hpke_keypairs: Arc, metrics: &AggregatorMetrics, task: Arc, batch_aggregation_shard_count: u64, task_counter_shard_count: u64, aggregation_job_id: &AggregationJobId, - require_taskprov_extension: bool, + require_taskbind_extension: bool, log_forbidden_mutations: Option, req_bytes: &[u8], ) -> Result { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_aggregate_init_generic::( datastore, - clock, hpke_keypairs, Arc::clone(vdaf), metrics, @@ -1439,8 +1459,7 @@ impl VdafOps { batch_aggregation_shard_count, task_counter_shard_count, aggregation_job_id, - verify_key, - require_taskprov_extension, + require_taskbind_extension, log_forbidden_mutations, req_bytes, ) @@ -1448,10 +1467,9 @@ impl VdafOps { }) } task::BatchMode::LeaderSelected { .. } => { - vdaf_ops_dispatch!(self, (vdaf, verify_key, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_aggregate_init_generic::( datastore, - clock, hpke_keypairs, Arc::clone(vdaf), metrics, @@ -1459,8 +1477,7 @@ impl VdafOps { batch_aggregation_shard_count, task_counter_shard_count, aggregation_job_id, - verify_key, - require_taskprov_extension, + require_taskbind_extension, log_forbidden_mutations, req_bytes, ) @@ -1483,12 +1500,12 @@ impl VdafOps { batch_aggregation_shard_count: u64, task_counter_shard_count: u64, aggregation_job_id: &AggregationJobId, - req: Arc, + req: AggregationJobContinueReq, request_hash: [u8; 32], ) -> Result { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_aggregate_continue_generic::( datastore, Arc::clone(vdaf), @@ -1504,7 +1521,7 @@ impl VdafOps { }) } task::BatchMode::LeaderSelected { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_aggregate_continue_generic::( datastore, Arc::clone(vdaf), @@ -1522,6 +1539,41 @@ impl VdafOps { } } + #[tracing::instrument(skip(self, datastore), fields(task_id = ?task.id()), err(level = Level::DEBUG))] + async fn handle_aggregate_get( + &self, + datastore: Arc>, + task: Arc, + aggregation_job_id: &AggregationJobId, + step: AggregationJobStep, + ) -> Result { + match task.batch_mode() { + task::BatchMode::TimeInterval => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { + Self::handle_aggregate_get_generic::( + datastore, + Arc::clone(vdaf), + task, + aggregation_job_id, + step, + ).await + }) + } + + task::BatchMode::LeaderSelected { .. 
} => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { + Self::handle_aggregate_get_generic::( + datastore, + Arc::clone(vdaf), + task, + aggregation_job_id, + step, + ).await + }) + } + } + } + #[tracing::instrument(skip(self, datastore), fields(task_id = ?task.id()), err(level = Level::DEBUG))] async fn handle_aggregate_delete( &self, @@ -1531,7 +1583,7 @@ impl VdafOps { ) -> Result<(), Error> { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (_, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (_, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_aggregate_delete_generic::( datastore, task, @@ -1540,7 +1592,7 @@ impl VdafOps { }) } task::BatchMode::LeaderSelected { .. } => { - vdaf_ops_dispatch!(self, (_, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (_, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_aggregate_delete_generic::( datastore, task, @@ -1725,27 +1777,16 @@ impl VdafOps { } } -/// Used by the aggregation job initialization handler to represent initialization of a report -/// share. -#[derive(Clone)] -struct ReportShareData -where - A: vdaf::Aggregator, -{ - report_share: ReportShare, - report_aggregation: WritableReportAggregation, -} - impl VdafOps { async fn check_aggregate_init_idempotency( tx: &Transaction<'_, C>, vdaf: &A, task_id: &TaskId, - aggregation_job_id: &AggregationJobId, - req: &AggregationJobInitializeReq, request_hash: [u8; 32], + mutating_aggregation_job: &AggregationJob, + mutating_report_aggregations: impl IntoIterator>, log_forbidden_mutations: Option, - ) -> Result, datastore::Error> + ) -> Result>, datastore::Error> where B: AccumulableBatchMode, A: vdaf::Aggregator + 'static + Send + Sync, @@ -1753,7 +1794,7 @@ impl VdafOps { for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>, { let existing_aggregation_job = match tx - .get_aggregation_job::(task_id, aggregation_job_id) + .get_aggregation_job::(task_id, mutating_aggregation_job.id()) .await? { Some(existing_aggregation_job) => existing_aggregation_job, @@ -1762,7 +1803,7 @@ impl VdafOps { if existing_aggregation_job.state() == &AggregationJobState::Deleted { return Err(datastore::Error::User( - Error::DeletedAggregationJob(*task_id, *aggregation_job_id).into(), + Error::DeletedAggregationJob(*task_id, *mutating_aggregation_job.id()).into(), )); } @@ -1773,21 +1814,20 @@ impl VdafOps { vdaf, &Role::Helper, task_id, - aggregation_job_id, + mutating_aggregation_job.id(), existing_aggregation_job.aggregation_parameter(), ) .await? 
             .iter()
             .map(|ra| *ra.report_id())
             .collect();
-            let mutating_request_report_ids: Vec<_> = req
-                .prepare_inits()
-                .iter()
-                .map(|pi| *pi.report_share().metadata().id())
+            let mutating_request_report_ids: Vec<_> = mutating_report_aggregations
+                .into_iter()
+                .map(|ra| *ra.report_id())
                 .collect();
             let event = AggregationJobInitForbiddenMutationEvent {
                 task_id: *task_id,
-                aggregation_job_id: *aggregation_job_id,
+                aggregation_job_id: *mutating_aggregation_job.id(),
                 original_request_hash: existing_aggregation_job.last_request_hash(),
                 original_report_ids,
                 original_batch_id: format!(
                     "{:?}",
                     existing_aggregation_job.partial_batch_identifier()
                 ),
@@ -1802,9 +1842,12 @@
                 mutating_request_report_ids,
                 mutating_request_batch_id: format!(
                     "{:?}",
-                    req.batch_selector().batch_identifier()
+                    mutating_aggregation_job.partial_batch_identifier()
                 ),
-                mutating_request_aggregation_parameter: req.aggregation_parameter().to_vec(),
+                mutating_request_aggregation_parameter: mutating_aggregation_job
+                    .aggregation_parameter()
+                    .get_encoded()
+                    .map_err(|e| datastore::Error::User(e.into()))?,
             };
             let event_id = crate::diagnostic::write_event(
                 log_forbidden_mutations,
@@ -1827,31 +1870,31 @@
                     "request hash mismatch on retried aggregation job request",
                 );
             }
+
             return Err(datastore::Error::User(
                 Error::ForbiddenMutation {
                     resource_type: "aggregation job",
-                    identifier: aggregation_job_id.to_string(),
+                    identifier: mutating_aggregation_job.id().to_string(),
                 }
                 .into(),
             ));
         }
 
-        // This is a repeated request. Send the same response we computed last time.
-        return Ok(Some(AggregationJobResp::Finished {
-            prepare_resps: tx
-                .get_report_aggregations_for_aggregation_job(
-                    vdaf,
-                    &Role::Helper,
-                    task_id,
-                    aggregation_job_id,
-                    existing_aggregation_job.aggregation_parameter(),
-                )
-                .await?
-                .iter()
-                .filter_map(ReportAggregation::last_prep_resp)
-                .cloned()
-                .collect(),
-        }));
+        // This is a repeated request. Send the preparation responses we computed last time.
+        return Ok(Some(
+            tx.get_report_aggregations_for_aggregation_job(
+                vdaf,
+                &Role::Helper,
+                task_id,
+                existing_aggregation_job.id(),
+                existing_aggregation_job.aggregation_parameter(),
+            )
+            .await?
+            .iter()
+            .filter_map(ReportAggregation::last_prep_resp)
+            .cloned()
+            .collect(),
+        ));
     }
 
     /// Implements [helper aggregate initialization][1].
@@ -1859,7 +1902,6 @@
     /// [1]: https://www.ietf.org/archive/id/draft-ietf-ppm-dap-07.html#name-helper-initialization
     async fn handle_aggregate_init_generic<const SEED_SIZE: usize, B, A, C>(
         datastore: Arc<Datastore<C>>,
-        clock: &C,
         hpke_keypairs: Arc<HpkeKeypairCache>,
         vdaf: Arc<A>,
         metrics: &AggregatorMetrics,
@@ -1886,50 +1927,10 @@
         A::PublicShare: Send + Sync,
         A::OutputShare: Send + Sync + PartialEq,
     {
-        // unwrap safety: SHA-256 computed by ring should always be 32 bytes
+        // Unwrap safety: SHA-256 computed by aws-lc-rs should always be 32 bytes.
         let request_hash = digest(&SHA256, req_bytes).as_ref().try_into().unwrap();
-        let req = Arc::new(
-            AggregationJobInitializeReq::<B>::get_decoded(req_bytes)
-                .map_err(Error::MessageDecode)?,
-        );
-
-        // Check if this is a repeated request, and if it is the same as before, send
-        // the same response as last time.
- if let Some(response) = datastore - .run_tx("aggregate_init_idempotecy_check", |tx| { - let vdaf = vdaf.clone(); - let task = Arc::clone(&task); - let aggregation_job_id = *aggregation_job_id; - let req = Arc::clone(&req); - let log_forbidden_mutations = log_forbidden_mutations.clone(); - - Box::pin(async move { - Self::check_aggregate_init_idempotency( - tx, - vdaf.as_ref(), - task.id(), - &aggregation_job_id, - &req, - request_hash, - log_forbidden_mutations, - ) - .await - }) - }) - .await? - { - return Ok(response); - } - - let agg_param = Arc::new( - A::AggregationParam::get_decoded(req.aggregation_parameter()) - .map_err(Error::MessageDecode)?, - ); - - let report_deadline = clock - .now() - .add(task.tolerable_clock_skew()) - .map_err(Error::from)?; + let req = AggregationJobInitializeReq::::get_decoded(req_bytes) + .map_err(Error::MessageDecode)?; // If two ReportShare messages have the same report ID, then the helper MUST abort with // error "invalidMessage". (§4.5.1.2) @@ -1943,328 +1944,7 @@ impl VdafOps { } } - // Compute the next aggregation step. - // - // We validate that each prepare_init can be represented by a `u64` ord value here, so that - // inside the parallel iterator we can unwrap. A conversion failure here will fail the - // entire aggregation. However, this is desirable: this can only happen if we receive too - // many report shares in an aggregation job for us to store, which is a whole-aggregation - // problem rather than a per-report problem. (separately, this would require more than - // u64::MAX report shares in a single aggregation job, which is practically impossible.) - u64::try_from(req.prepare_inits().len())?; - - // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. - // This will cause any attempts to send on `sender` to return a `SendError`, which will be - // returned from the function passed to `try_for_each_with`; `try_for_each_with` will - // terminate early on receiving an error. - let (sender, mut receiver) = mpsc::unbounded_channel(); - let producer_task = tokio::task::spawn_blocking({ - let parent_span = Span::current(); - let hpke_keypairs = Arc::clone(&hpke_keypairs); - let vdaf = Arc::clone(&vdaf); - let task = Arc::clone(&task); - let metrics = metrics.clone(); - let req = Arc::clone(&req); - let aggregation_job_id = *aggregation_job_id; - let verify_key = *verify_key; - let agg_param = Arc::clone(&agg_param); - - move || { - let span = info_span!(parent: parent_span, "handle_aggregate_init_generic threadpool task"); - let ctx = vdaf_application_context(task.id()); - - req - .prepare_inits() - .par_iter() - .enumerate() - .try_for_each(|(ord, prepare_init)| { - let _entered = span.enter(); - - // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`. 
(§4.4.2.2) - let hpke_keypair = hpke_keypairs.keypair( - prepare_init - .report_share() - .encrypted_input_share() - .config_id(), - ).ok_or_else(|| { - debug!( - config_id = %prepare_init.report_share().encrypted_input_share().config_id(), - "Helper encrypted input share references unknown HPKE config ID" - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "unknown_hpke_config_id")]); - ReportError::HpkeUnknownConfigId - }); - - let plaintext = hpke_keypair.and_then(|hpke_keypair| { - let input_share_aad = InputShareAad::new( - *task.id(), - prepare_init.report_share().metadata().clone(), - prepare_init.report_share().public_share().to_vec(), - ) - .get_encoded() - .map_err(|err| { - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - ?err, - "Couldn't encode input share AAD" - ); - metrics.aggregate_step_failure_counter.add( - 1, - &[KeyValue::new("type", "input_share_aad_encode_failure")], - ); - // HpkeDecryptError isn't strictly accurate, but given that this - // fallible encoding is part of the HPKE decryption process, I think - // this is as close as we can get to a meaningful error signal. - ReportError::HpkeDecryptError - })?; - - hpke::open( - &hpke_keypair, - &HpkeApplicationInfo::new( - &Label::InputShare, - &Role::Client, - &Role::Helper, - ), - prepare_init.report_share().encrypted_input_share(), - &input_share_aad, - ) - .map_err(|error| { - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - ?error, - "Couldn't decrypt helper's report share" - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "decrypt_failure")]); - ReportError::HpkeDecryptError - }) - }); - - let plaintext_input_share = plaintext.and_then(|plaintext| { - let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext) - .map_err(|error| { - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - ?error, "Couldn't decode helper's plaintext input share", - ); - metrics.aggregate_step_failure_counter.add( - 1, - &[KeyValue::new( - "type", - "plaintext_input_share_decode_failure", - )], - ); - ReportError::InvalidMessage - })?; - - // Build map of extension type to extension data, checking for duplicates. 
- let mut extensions = HashMap::new(); - if !plaintext_input_share.private_extensions().iter().chain(prepare_init.report_share().metadata().public_extensions()).all(|extension| { - extensions - .insert(*extension.extension_type(), extension.extension_data()) - .is_none() - }) { - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - "Received report share with duplicate extensions", - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "duplicate_extension")]); - return Err(ReportError::InvalidMessage); - } - - if require_taskprov_extension { - let valid_taskprov_extension_present = extensions - .get(&ExtensionType::Taskbind) - .map(|data| data.is_empty()) - .unwrap_or(false); - if !valid_taskprov_extension_present { - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - "Taskprov task received report with missing or malformed \ - taskprov extension", - ); - metrics.aggregate_step_failure_counter.add( - 1, - &[KeyValue::new( - "type", - "missing_or_malformed_taskprov_extension", - )], - ); - return Err(ReportError::InvalidMessage); - } - } else if extensions.contains_key(&ExtensionType::Taskbind) { - // taskprov not enabled, but the taskprov extension is present. - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - "Non-taskprov task received report with unexpected taskprov \ - extension", - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "unexpected_taskprov_extension")]); - return Err(ReportError::InvalidMessage); - } - - Ok(plaintext_input_share) - }); - - let input_share = plaintext_input_share.and_then(|plaintext_input_share| { - A::InputShare::get_decoded_with_param( - &(&vdaf, Role::Helper.index().unwrap()), - plaintext_input_share.payload(), - ) - .map_err(|error| { - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - ?error, "Couldn't decode helper's input share", - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "input_share_decode_failure")]); - ReportError::InvalidMessage - }) - }); - - let public_share = A::PublicShare::get_decoded_with_param( - &vdaf, - prepare_init.report_share().public_share(), - ) - .map_err(|error| { - debug!( - task_id = %task.id(), - report_id = ?prepare_init.report_share().metadata().id(), - ?error, "Couldn't decode public share", - ); - metrics - .aggregate_step_failure_counter - .add(1, &[KeyValue::new("type", "public_share_decode_failure")]); - ReportError::InvalidMessage - }); - - let shares = - input_share.and_then(|input_share| Ok((public_share?, input_share))); - - // Reject reports from too far in the future. - let shares = shares.and_then(|shares| { - if prepare_init - .report_share() - .metadata() - .time() - .is_after(&report_deadline) - { - return Err(ReportError::ReportTooEarly); - } - Ok(shares) - }); - - // Next, the aggregator runs the preparation-state initialization algorithm for the VDAF - // associated with the task and computes the first state transition. [...] If either - // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. 
(§4.4.2.2) - let init_rslt = shares.and_then(|(public_share, input_share)| { - trace_span!("VDAF preparation (helper initialization)").in_scope(|| { - vdaf.helper_initialized( - verify_key.as_bytes(), - &ctx, - &agg_param, - /* report ID is used as VDAF nonce */ - prepare_init.report_share().metadata().id().as_ref(), - &public_share, - &input_share, - prepare_init.message(), - ) - .and_then(|transition| transition.evaluate(&ctx, &vdaf)) - .map_err(|error| { - handle_ping_pong_error( - task.id(), - Role::Helper, - prepare_init.report_share().metadata().id(), - error, - &metrics.aggregate_step_failure_counter, - ) - }) - }) - }); - - let (report_aggregation_state, prepare_step_result, output_share) = - match init_rslt { - Ok((PingPongState::Continued(prepare_state), outgoing_message)) => { - // Helper is not finished. Await the next message from the Leader to advance to - // the next step. - ( - ReportAggregationState::HelperContinue { prepare_state }, - PrepareStepResult::Continue { - message: outgoing_message, - }, - None, - ) - } - Ok((PingPongState::Finished(output_share), outgoing_message)) => ( - ReportAggregationState::Finished, - PrepareStepResult::Continue { - message: outgoing_message, - }, - Some(output_share), - ), - Err(report_error) => ( - ReportAggregationState::Failed { report_error }, - PrepareStepResult::Reject(report_error), - None, - ), - }; - - sender.send(ReportShareData { - report_share: prepare_init.report_share().clone(), - report_aggregation: WritableReportAggregation::new( - ReportAggregation::::new( - *task.id(), - aggregation_job_id, - *prepare_init.report_share().metadata().id(), - *prepare_init.report_share().metadata().time(), - // Unwrap safety: we checked that all ordinal values are representable - // as a u64 before entering the parallel iterator. - ord.try_into().unwrap(), - Some(PrepareResp::new( - *prepare_init.report_share().metadata().id(), - prepare_step_result, - )), - report_aggregation_state, - ), - output_share, - ), - }) - }) - } - }); - - let mut report_share_data = Vec::with_capacity(req.prepare_inits().len()); - while receiver.recv_many(&mut report_share_data, 10).await > 0 {} - let report_share_data = Arc::new(report_share_data); - - // Await the producer task to resume any panics that may have occurred, and to ensure we can - // unwrap the aggregation parameter's Arc in a few lines. The only other errors that can - // occur are: a `JoinError` indicating cancellation, which is impossible because we do not - // cancel the task; and a `SendError`, which can only happen if this future is cancelled (in - // which case we will not run this code at all). - let _ = producer_task.await.map_err(|join_error| { - if let Ok(reason) = join_error.try_into_panic() { - panic::resume_unwind(reason); - } - }); - assert_eq!(report_share_data.len(), req.prepare_inits().len()); - - // Store data to datastore. + // Build initial aggregation job & report aggregations. let min_client_timestamp = req .prepare_inits() .iter() @@ -2283,61 +1963,286 @@ impl VdafOps { .difference(&min_client_timestamp)? .add(&Duration::from_seconds(1))?, )?; - let aggregation_job = Arc::new( - AggregationJob::::new( - *task.id(), - *aggregation_job_id, - Arc::unwrap_or_clone(agg_param), - req.batch_selector().batch_identifier().clone(), - client_timestamp_interval, - // For one-round VDAFs, the aggregation job will actually be finished, but the - // aggregation job writer handles updating its state. 
- AggregationJobState::InProgress, - AggregationJobStep::from(0), - ) - .with_last_request_hash(request_hash), - ); + let aggregation_job = AggregationJob::::new( + *task.id(), + *aggregation_job_id, + A::AggregationParam::get_decoded(req.aggregation_parameter()) + .map_err(Error::MessageDecode)?, + req.batch_selector().batch_identifier().clone(), + client_timestamp_interval, + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + ) + .with_last_request_hash(request_hash); - let (response, counters) = datastore - .run_tx("aggregate_init", |tx| { + let report_aggregations = req + .prepare_inits() + .iter() + .enumerate() + .map(|(ord, prepare_init)| { + Ok(ReportAggregation::::new( + *task.id(), + *aggregation_job_id, + *prepare_init.report_share().metadata().id(), + *prepare_init.report_share().metadata().time(), + u64::try_from(ord)?, + None, + ReportAggregationState::HelperInitProcessing { + prepare_init: prepare_init.clone(), + require_taskbind_extension, + }, + )) + }) + .collect::, Error>>()?; + + match task.aggregation_mode() { + Some(AggregationMode::Synchronous) => { + Self::handle_aggregate_init_generic_sync( + datastore, + hpke_keypairs, + vdaf, + metrics, + task, + batch_aggregation_shard_count, + task_counter_shard_count, + log_forbidden_mutations, + request_hash, + aggregation_job, + report_aggregations, + ) + .await + } + + Some(AggregationMode::Asynchronous) => { + Self::handle_aggregate_init_generic_async( + datastore, + vdaf, + metrics, + task, + batch_aggregation_shard_count, + task_counter_shard_count, + log_forbidden_mutations, + request_hash, + aggregation_job, + report_aggregations, + ) + .await + } + + None => Err(Error::Internal("task has no aggregation mode".to_string())), + } + } + + // All report aggregations must be in the HelperInitProcessing state. + async fn handle_aggregate_init_generic_sync( + datastore: Arc>, + hpke_keypairs: Arc, + vdaf: Arc, + metrics: &AggregatorMetrics, + task: Arc, + batch_aggregation_shard_count: u64, + task_counter_shard_count: u64, + log_forbidden_mutations: Option, + request_hash: [u8; 32], + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result + where + B: AccumulableBatchMode, + A: vdaf::Aggregator + 'static + Send + Sync, + C: Clock, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + A::PrepareMessage: Send + Sync + PartialEq, + A::PrepareShare: Send + Sync + PartialEq, + for<'a> A::PrepareState: + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)> + PartialEq, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync + PartialEq, + { + // Check if this is a repeated request, and if it is the same as before, send + // the same response as last time. + let aggregation_job = Arc::new(aggregation_job); + let report_aggregations = Arc::new(report_aggregations); + if let Some(prepare_resps) = datastore + .run_tx("aggregate_init_idempotecy", |tx| { let vdaf = Arc::clone(&vdaf); let task = Arc::clone(&task); - let aggregation_job_writer_metrics = metrics.for_aggregation_job_writer(); let aggregation_job = Arc::clone(&aggregation_job); - let report_share_data = Arc::clone(&report_share_data); - let req = Arc::clone(&req); + let report_aggregations = Arc::clone(&report_aggregations); let log_forbidden_mutations = log_forbidden_mutations.clone(); Box::pin(async move { - // Check if this is a repeated request, and if it is the same as before, send - // the same response as last time. 
We check again to avoid the possibility of - // races. - if let Some(response) = Self::check_aggregate_init_idempotency( + Self::check_aggregate_init_idempotency( tx, vdaf.as_ref(), task.id(), - aggregation_job.id(), - &req, request_hash, + &aggregation_job, + report_aggregations.iter(), log_forbidden_mutations, ) - .await? - { - return Ok((response, TaskAggregationCounter::default())); - } + .await + }) + }) + .await? + { + return Ok(AggregationJobResp::Finished { prepare_resps }); + } - // Write report shares, and ensure this isn't a repeated report aggregation. - let report_aggregations = try_join_all(report_share_data.iter().map(|rsd| { - let task = Arc::clone(&task); + // Compute the next aggregation step. + let report_aggregations = compute_helper_aggregate_init( + datastore.clock(), + hpke_keypairs, + Arc::clone(&vdaf), + metrics.clone().into(), + Arc::clone(&task), + Arc::clone(&aggregation_job), + Arc::unwrap_or_clone(report_aggregations), + ) + .await?; - async move { - let mut report_aggregation = Cow::Borrowed(&rsd.report_aggregation); - match tx.put_scrubbed_report(task.id(), &rsd.report_share).await { - Ok(()) => (), - Err(datastore::Error::MutationTargetAlreadyExists) => { - report_aggregation = Cow::Owned( - report_aggregation - .into_owned() + // Store data to datastore. + let prepare_resps = Self::handle_aggregate_init_generic_write( + datastore, + vdaf, + metrics, + task, + batch_aggregation_shard_count, + task_counter_shard_count, + log_forbidden_mutations, + request_hash, + aggregation_job, + Arc::new(report_aggregations), + ) + .await?; + + Ok(AggregationJobResp::Finished { prepare_resps }) + } + + // All report aggregations must be in the HelperInitProcessing state. + async fn handle_aggregate_init_generic_async( + datastore: Arc>, + vdaf: Arc, + metrics: &AggregatorMetrics, + task: Arc, + batch_aggregation_shard_count: u64, + task_counter_shard_count: u64, + log_forbidden_mutations: Option, + request_hash: [u8; 32], + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result + where + B: AccumulableBatchMode, + A: vdaf::Aggregator + 'static + Send + Sync, + C: Clock, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + A::PrepareMessage: Send + Sync + PartialEq, + A::PrepareShare: Send + Sync + PartialEq, + for<'a> A::PrepareState: + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)> + PartialEq, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync + PartialEq, + { + Self::handle_aggregate_init_generic_write( + datastore, + vdaf, + metrics, + task, + batch_aggregation_shard_count, + task_counter_shard_count, + log_forbidden_mutations, + request_hash, + Arc::new(aggregation_job.with_state(AggregationJobState::Active)), + Arc::new( + report_aggregations + .into_iter() + .map(|ra| WritableReportAggregation::new(ra, None)) + .collect(), + ), + ) + .await?; + + Ok(AggregationJobResp::Processing) + } + + async fn handle_aggregate_init_generic_write( + datastore: Arc>, + vdaf: Arc, + metrics: &AggregatorMetrics, + task: Arc, + batch_aggregation_shard_count: u64, + task_counter_shard_count: u64, + log_forbidden_mutations: Option, + request_hash: [u8; 32], + aggregation_job: Arc>, + report_aggregations: Arc>>, + ) -> Result, Error> + where + B: AccumulableBatchMode, + A: vdaf::Aggregator + 'static + Send + Sync, + C: Clock, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + A::PrepareMessage: Send + Sync + 
PartialEq, + A::PrepareShare: Send + Sync + PartialEq, + for<'a> A::PrepareState: + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)> + PartialEq, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync + PartialEq, + { + let (prepare_resps, counters) = datastore + .run_tx("aggregate_init_aggregator_write", |tx| { + let vdaf = Arc::clone(&vdaf); + let task = Arc::clone(&task); + let aggregation_job_writer_metrics = metrics.for_aggregation_job_writer(); + let aggregation_job = Arc::clone(&aggregation_job); + let report_aggregations = Arc::clone(&report_aggregations); + let log_forbidden_mutations = log_forbidden_mutations.clone(); + + Box::pin(async move { + // Check if this is a repeated request, and if it is the same as before, send + // the same response as last time. We check during the write transaction, even + // if we have checked before, to avoid the possibility of races for concurrent + // requests. + if let Some(prepare_resps) = Self::check_aggregate_init_idempotency( + tx, + vdaf.as_ref(), + task.id(), + request_hash, + &aggregation_job, + report_aggregations.iter().map(|ra| ra.report_aggregation()), + log_forbidden_mutations, + ) + .await? + { + return Ok((prepare_resps, TaskAggregationCounter::default())); + } + + // Write report shares, and ensure this isn't a repeated report aggregation. + let report_aggregations = try_join_all(report_aggregations.iter().map(|ra| { + let task = Arc::clone(&task); + + async move { + let mut report_aggregation = Cow::Borrowed(ra); + match tx + .put_scrubbed_report( + task.id(), + ra.report_aggregation().report_id(), + ra.report_aggregation().time(), + ) + .await + { + Ok(()) => (), + Err(datastore::Error::MutationTargetAlreadyExists) => { + report_aggregation = Cow::Owned( + report_aggregation + .into_owned() .with_failure(ReportError::ReportReplayed), ) } @@ -2360,11 +2265,9 @@ impl VdafOps { let (mut prep_resps_by_agg_job, counters) = aggregation_job_writer.write(tx, vdaf).await?; Ok(( - AggregationJobResp::Finished { - prepare_resps: prep_resps_by_agg_job - .remove(aggregation_job.id()) - .unwrap_or_default(), - }, + prep_resps_by_agg_job + .remove(aggregation_job.id()) + .unwrap_or_default(), counters, )) }) @@ -2373,7 +2276,7 @@ impl VdafOps { write_task_aggregation_counter(datastore, task_counter_shard_count, *task.id(), counters); - Ok(response) + Ok(prepare_resps) } async fn handle_aggregate_continue_generic< @@ -2389,7 +2292,7 @@ impl VdafOps { batch_aggregation_shard_count: u64, task_counter_shard_count: u64, aggregation_job_id: &AggregationJobId, - req: Arc, + req: AggregationJobContinueReq, request_hash: [u8; 32], ) -> Result where @@ -2410,8 +2313,7 @@ impl VdafOps { )); } - // TODO(#224): don't hold DB transaction open while computing VDAF updates? - // TODO(#1035): don't do O(n) network round-trips (where n is the number of prepare steps) + let req = Arc::new(req); let (response, counters) = datastore .run_tx("aggregate_continue", |tx| { let vdaf = Arc::clone(&vdaf); @@ -2448,6 +2350,8 @@ impl VdafOps { )); } + // Check for a duplicate request, and treat it idempotently. + // // If the leader's request is on the same step as our stored aggregation job, // then we probably have already received this message and computed this step, // but the leader never got our response and so retried stepping the job. 
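The init and continue paths above share one idempotency pattern: the helper stores a SHA-256 hash of the last request body on the aggregation job, replays its previously computed response when the identical bytes arrive again, and rejects a different body for the same job ID as a forbidden mutation. Below is a minimal sketch of that decision in isolation; `StoredJob` and `Outcome` are illustrative stand-ins, not Janus types, and the datastore plumbing is omitted:

```rust
use aws_lc_rs::digest::{digest, SHA256};

/// Illustrative stand-in for the request hash persisted on an aggregation job.
struct StoredJob {
    last_request_hash: Option<[u8; 32]>,
}

/// Illustrative classification of an incoming request against a stored job.
enum Outcome {
    FirstRequest,      // no hash recorded yet: process the request normally
    Replay,            // identical body: re-send the previously computed response
    ForbiddenMutation, // same job ID, different body: reject
}

fn classify(job: &StoredJob, req_bytes: &[u8]) -> Outcome {
    // Unwrap safety: SHA-256 output is always 32 bytes.
    let request_hash: [u8; 32] = digest(&SHA256, req_bytes).as_ref().try_into().unwrap();
    match job.last_request_hash {
        None => Outcome::FirstRequest,
        Some(hash) if hash == request_hash => Outcome::Replay,
        Some(_) => Outcome::ForbiddenMutation,
    }
}
```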
@@ -2476,19 +2380,34 @@ impl VdafOps { } } } - return Ok(( - AggregationJobResp::Finished { - prepare_resps: report_aggregations - .iter() - .filter_map(ReportAggregation::last_prep_resp) - .cloned() - .collect(), - }, - TaskAggregationCounter::default(), - )); - } else if aggregation_job.step().increment() != req.step() { - // If this is not a replay, the leader should be advancing our state to the next - // step and no further. + + let resp = match task.aggregation_mode() { + Some(AggregationMode::Synchronous) => { + AggregationJobResp::Finished { + prepare_resps: report_aggregations + .iter() + .filter_map(ReportAggregation::last_prep_resp) + .cloned() + .collect(), + } + } + Some(AggregationMode::Asynchronous) => { + AggregationJobResp::Processing + } + None => { + return Err(datastore::Error::User( + Error::Internal("task has no aggregation mode".to_string()) + .into(), + )) + } + }; + + return Ok((resp, TaskAggregationCounter::default())); + } + + if aggregation_job.step().increment() != req.step() { + // If this is not a replay, the leader should be advancing our state to the + // next step and no further. return Err(datastore::Error::User( Error::StepMismatch { task_id: *task.id(), @@ -2500,20 +2419,113 @@ impl VdafOps { )); } - // The leader is advancing us to the next step. Step the aggregation job to - // compute the next round of prepare messages and state. - Self::step_aggregation_job( - tx, - task, - vdaf, - batch_aggregation_shard_count, - aggregation_job, - report_aggregations, - req, - request_hash, - &metrics, - ) - .await + // Pair incoming preparation continuation messages with existing report + // aggregations. + let mut report_aggregations_to_write = Vec::with_capacity(report_aggregations.len()); + let mut report_aggregations_iter = report_aggregations.into_iter(); + let mut report_aggregations = Vec::with_capacity(req.prepare_continues().len()); + for prepare_continue in req.prepare_continues() { + let report_aggregation = loop { + let report_aggregation = report_aggregations_iter.next().ok_or_else(|| { + datastore::Error::User( + Error::InvalidMessage( + Some(*task.id()), + "leader sent unexpected, duplicate, or out-of-order prepare steps", + ) + .into(), + ) + })?; + if report_aggregation.report_id() != prepare_continue.report_id() { + // This report was omitted by the leader because of a prior failure. + // Note that the report was dropped (if it's not already in an error + // state) and continue. + if matches!( + report_aggregation.state(), + ReportAggregationState::HelperContinue { .. } + ) { + report_aggregations_to_write.push(WritableReportAggregation::new( + report_aggregation + .with_state(ReportAggregationState::Failed { + report_error: ReportError::ReportDropped, + }) + .with_last_prep_resp(None), + None, + )); + } + continue; + } + break report_aggregation; + }; + + let prepare_state = if let ReportAggregationState::HelperContinue{ prepare_state } = report_aggregation.state() { + prepare_state.clone() + } else { + return Err(datastore::Error::User( + Error::InvalidMessage( + Some(*task.id()), + "leader sent prepare step for non-CONTINUE report aggregation", + ) + .into(), + )) + }; + + report_aggregations.push(report_aggregation + .with_state(ReportAggregationState::HelperContinueProcessing { + prepare_state, + prepare_continue: prepare_continue.clone(), + }) + .with_last_prep_resp(None) + ); + } + + for report_aggregation in report_aggregations_iter { + // This report was omitted by the leader because of a prior failure. 
Note + // that the report was dropped (if it's not already in an error state) and + // continue. + if matches!( + report_aggregation.state(), + ReportAggregationState::HelperContinue { .. } + ) { + report_aggregations_to_write.push(WritableReportAggregation::new( + report_aggregation + .with_state(ReportAggregationState::Failed { + report_error: ReportError::ReportDropped, + }) + .with_last_prep_resp(None), + None, + )); + } + } + + let aggregation_job = aggregation_job + .with_step(req.step()) // Advance the job to the leader's step + .with_last_request_hash(request_hash); + + match task.aggregation_mode() { + Some(AggregationMode::Synchronous) => Self::handle_aggregate_continue_generic_sync( + tx, + task, + vdaf, + batch_aggregation_shard_count, + &metrics, + report_aggregations_to_write, + aggregation_job, + report_aggregations, + ).await, + + Some(AggregationMode::Asynchronous) => Self::handle_aggregate_continue_generic_async( + tx, + task, + vdaf, + batch_aggregation_shard_count, + &metrics, + report_aggregations_to_write, + aggregation_job, + report_aggregations, + ).await, + + None => Err(Error::Internal("task has no aggregation mode".to_string())), + }.map_err(|err| datastore::Error::User(err.into())) }) }) .await?; @@ -2523,6 +2535,251 @@ impl VdafOps { Ok(response) } + // All report aggregations must be in the HelperContinueProcessing state. + async fn handle_aggregate_continue_generic_sync< + const SEED_SIZE: usize, + B: AccumulableBatchMode, + A, + C: Clock, + >( + tx: &Transaction<'_, C>, + task: Arc, + vdaf: Arc, + batch_aggregation_shard_count: u64, + metrics: &AggregatorMetrics, + mut report_aggregations_to_write: Vec>, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(AggregationJobResp, TaskAggregationCounter), Error> + where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync, + { + // Compute the next aggregation step. + // TODO(#224): don't hold DB transaction open while computing VDAF updates? + let aggregation_job = Arc::new(aggregation_job); + report_aggregations_to_write.extend( + compute_helper_aggregate_continue( + Arc::clone(&vdaf), + metrics.clone().into(), + Arc::clone(&task), + Arc::clone(&aggregation_job), + report_aggregations, + ) + .await, + ); + + // Store data to datastore. + let (prepare_resps, counters) = Self::handle_aggregate_continue_generic_write( + tx, + task, + vdaf, + batch_aggregation_shard_count, + metrics, + Arc::unwrap_or_clone(aggregation_job), + report_aggregations_to_write, + ) + .await?; + Ok((AggregationJobResp::Finished { prepare_resps }, counters)) + } + + // All report aggregations must be in the HelperContinueProcessing state. 
+ async fn handle_aggregate_continue_generic_async< + const SEED_SIZE: usize, + B: AccumulableBatchMode, + A, + C: Clock, + >( + tx: &Transaction<'_, C>, + task: Arc, + vdaf: Arc, + batch_aggregation_shard_count: u64, + metrics: &AggregatorMetrics, + mut report_aggregations_to_write: Vec>, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(AggregationJobResp, TaskAggregationCounter), Error> + where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync, + { + report_aggregations_to_write.extend( + report_aggregations + .into_iter() + .map(|ra| WritableReportAggregation::new(ra, None)), + ); + + let (_, counters) = Self::handle_aggregate_continue_generic_write( + tx, + task, + vdaf, + batch_aggregation_shard_count, + metrics, + aggregation_job.with_state(AggregationJobState::Active), + report_aggregations_to_write, + ) + .await?; + + Ok((AggregationJobResp::Processing, counters)) + } + + async fn handle_aggregate_continue_generic_write< + const SEED_SIZE: usize, + B: AccumulableBatchMode, + A, + C: Clock, + >( + tx: &Transaction<'_, C>, + task: Arc, + vdaf: Arc, + batch_aggregation_shard_count: u64, + metrics: &AggregatorMetrics, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(Vec, TaskAggregationCounter), Error> + where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync, + { + // Sanity-check that we have the correct number of report aggregations. + assert_eq!(report_aggregations.len(), report_aggregations.capacity()); + + // Write accumulated aggregation values back to the datastore; this will mark any reports + // that can't be aggregated because the batch is collected with error BatchCollected. + let aggregation_job_id = *aggregation_job.id(); + let mut aggregation_job_writer = + AggregationJobWriter::::new( + task, + batch_aggregation_shard_count, + Some(metrics.for_aggregation_job_writer()), + ); + aggregation_job_writer.put(aggregation_job, report_aggregations)?; + let (mut prep_resps_by_agg_job, counters) = aggregation_job_writer.write(tx, vdaf).await?; + Ok(( + prep_resps_by_agg_job + .remove(&aggregation_job_id) + .unwrap_or_default(), + counters, + )) + } + + /// Handle requests to the helper to get an aggregation job. 
+ async fn handle_aggregate_get_generic< + const SEED_SIZE: usize, + B: AccumulableBatchMode, + A, + C: Clock, + >( + datastore: Arc<Datastore<C>>, + vdaf: Arc<A>, + task: Arc<AggregatorTask>, + aggregation_job_id: &AggregationJobId, + step: AggregationJobStep, + ) -> Result<AggregationJobResp, Error> + where + A: vdaf::Aggregator<SEED_SIZE, 16> + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync, + { + Ok(datastore + .run_tx("get_aggregation_job", |tx| { + let vdaf = Arc::clone(&vdaf); + let task = Arc::clone(&task); + let aggregation_job_id = *aggregation_job_id; + + Box::pin(async move { + // Read aggregation job & report aggregations. + let aggregation_job = tx + .get_aggregation_job::<SEED_SIZE, B, A>(task.id(), &aggregation_job_id) + .await? + .ok_or_else(|| { + datastore::Error::User( + Error::UnrecognizedAggregationJob(*task.id(), aggregation_job_id) + .into(), + ) + })?; + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + &Role::Helper, + task.id(), + &aggregation_job_id, + aggregation_job.aggregation_parameter(), + ) + .await?; + + // Validate that the request is for the expected step. + if aggregation_job.step() != step { + return Err(datastore::Error::User( + Error::StepMismatch { + task_id: *task.id(), + aggregation_job_id, + expected_step: aggregation_job.step(), + got_step: step, + } + .into(), + )); + } + + // Return a value based on the report aggregations. + Ok(match aggregation_job.state() { + AggregationJobState::Active => AggregationJobResp::Processing, + + AggregationJobState::AwaitingRequest | AggregationJobState::Finished => { + AggregationJobResp::Finished { + prepare_resps: report_aggregations + .into_iter() + .filter_map(|ra| ra.last_prep_resp().cloned()) + .collect(), + } + } + + AggregationJobState::Abandoned => { + return Err(datastore::Error::User( + Error::AbandonedAggregationJob(*task.id(), *aggregation_job.id()) + .into(), + )) + } + + AggregationJobState::Deleted => { + return Err(datastore::Error::User( + Error::DeletedAggregationJob(*task.id(), *aggregation_job.id()) + .into(), + )) + } + }) + }) + }) + .await?) + } + /// Handle requests to the helper to delete an aggregation job. async fn handle_aggregate_delete_generic< const SEED_SIZE: usize, @@ -2581,7 +2838,7 @@ impl VdafOps { ) -> Result, Error> { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_create_collection_job_generic::< VERIFY_KEY_LENGTH, TimeInterval, @@ -2592,7 +2849,7 @@ impl VdafOps { }) } task::BatchMode::LeaderSelected { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_create_collection_job_generic::< VERIFY_KEY_LENGTH, LeaderSelected, @@ -2738,7 +2995,7 @@ impl VdafOps { ) -> Result, Error> { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_get_collection_job_generic::< VERIFY_KEY_LENGTH, TimeInterval, @@ -2749,7 +3006,7 @@ impl VdafOps { }) } task::BatchMode::LeaderSelected { ..
} => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_get_collection_job_generic::< VERIFY_KEY_LENGTH, LeaderSelected, @@ -2883,7 +3140,7 @@ impl VdafOps { ) -> Result<(), Error> { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_delete_collection_job_generic::< VERIFY_KEY_LENGTH, TimeInterval, @@ -2894,7 +3151,7 @@ impl VdafOps { }) } task::BatchMode::LeaderSelected { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH) => { Self::handle_delete_collection_job_generic::< VERIFY_KEY_LENGTH, LeaderSelected, @@ -2970,25 +3227,43 @@ impl VdafOps { ) -> Result { match task.batch_mode() { task::BatchMode::TimeInterval => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH, dp_strategy, DpStrategyType) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH, dp_strategy, DpStrategyType) => { Self::handle_aggregate_share_generic::< VERIFY_KEY_LENGTH, TimeInterval, DpStrategyType, VdafType, _, - >(datastore, clock, task, Arc::clone(vdaf), req_bytes, batch_aggregation_shard_count, collector_hpke_config, Arc::clone(dp_strategy)).await + >( + datastore, + clock, + task, + Arc::clone(vdaf), + req_bytes, + batch_aggregation_shard_count, + collector_hpke_config, + Arc::clone(dp_strategy) + ).await }) } task::BatchMode::LeaderSelected { .. } => { - vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH, dp_strategy, DpStrategyType) => { + vdaf_ops_dispatch!(self, (vdaf, VdafType, VERIFY_KEY_LENGTH, dp_strategy, DpStrategyType) => { Self::handle_aggregate_share_generic::< VERIFY_KEY_LENGTH, LeaderSelected, DpStrategyType, VdafType, _, - >(datastore, clock, task, Arc::clone(vdaf), req_bytes, batch_aggregation_shard_count, collector_hpke_config, Arc::clone(dp_strategy)).await + >( + datastore, + clock, + task, + Arc::clone(vdaf), + req_bytes, + batch_aggregation_shard_count, + collector_hpke_config, + Arc::clone(dp_strategy) + ).await }) } } diff --git a/aggregator/src/aggregator/aggregate_init_tests.rs b/aggregator/src/aggregator/aggregate_init_tests.rs deleted file mode 100644 index 85b44b414..000000000 --- a/aggregator/src/aggregator/aggregate_init_tests.rs +++ /dev/null @@ -1,679 +0,0 @@ -use crate::aggregator::{ - http_handlers::{ - test_util::{decode_response_body, take_problem_details}, - AggregatorHandlerBuilder, - }, - test_util::generate_helper_report_share, - Config, -}; -use assert_matches::assert_matches; -use http::StatusCode; -use janus_aggregator_core::{ - datastore::test_util::{ephemeral_datastore, EphemeralDatastore}, - task::{ - test_util::{Task, TaskBuilder}, - AggregatorTask, BatchMode, - }, - test_util::noop_meter, -}; -use janus_core::{ - auth_tokens::{AuthenticationToken, DAP_AUTH_HEADER}, - test_util::{install_test_trace_subscriber, run_vdaf, runtime::TestRuntime, VdafTranscript}, - time::{Clock, MockClock, TimeExt as _}, - vdaf::VdafInstance, -}; -use janus_messages::{ - batch_mode::{self, TimeInterval}, - AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, Duration, Extension, - ExtensionType, HpkeConfig, PartialBatchSelector, PrepareInit, PrepareResp, PrepareStepResult, - ReportError, ReportMetadata, ReportShare, -}; -use prio::{ - codec::Encode, - vdaf::{self, dummy}, -}; -use 
rand::random; -use serde_json::json; -use std::sync::Arc; -use trillium::{Handler, KnownHeaderName, Status}; -use trillium_testing::{prelude::put, TestConn}; - -#[derive(Clone)] -pub(super) struct PrepareInitGenerator -where - V: vdaf::Vdaf, -{ - clock: MockClock, - task: AggregatorTask, - vdaf: V, - aggregation_param: V::AggregationParam, - hpke_config: HpkeConfig, - private_extensions: Vec, -} - -impl PrepareInitGenerator -where - V: vdaf::Vdaf + vdaf::Aggregator + vdaf::Client<16>, -{ - pub(super) fn new( - clock: MockClock, - task: AggregatorTask, - hpke_config: HpkeConfig, - vdaf: V, - aggregation_param: V::AggregationParam, - ) -> Self { - Self { - clock, - task, - vdaf, - aggregation_param, - hpke_config, - private_extensions: Vec::new(), - } - } - - pub(super) fn with_private_extensions(mut self, extensions: Vec) -> Self { - self.private_extensions = extensions; - self - } - - pub(super) fn next( - &self, - measurement: &V::Measurement, - ) -> (PrepareInit, VdafTranscript) { - self.next_with_metadata( - ReportMetadata::new( - random(), - self.clock - .now() - .to_batch_interval_start(self.task.time_precision()) - .unwrap(), - Vec::new(), - ), - measurement, - ) - } - - pub(super) fn next_with_metadata( - &self, - report_metadata: ReportMetadata, - measurement: &V::Measurement, - ) -> (PrepareInit, VdafTranscript) { - let (report_share, transcript) = - self.next_report_share_with_metadata(report_metadata, measurement); - ( - PrepareInit::new( - report_share, - transcript.leader_prepare_transitions[0].message.clone(), - ), - transcript, - ) - } - - pub(super) fn next_report_share( - &self, - measurement: &V::Measurement, - ) -> (ReportShare, VdafTranscript) { - self.next_report_share_with_metadata( - ReportMetadata::new( - random(), - self.clock - .now() - .to_batch_interval_start(self.task.time_precision()) - .unwrap(), - Vec::new(), - ), - measurement, - ) - } - - pub(super) fn next_report_share_with_metadata( - &self, - report_metadata: ReportMetadata, - measurement: &V::Measurement, - ) -> (ReportShare, VdafTranscript) { - let transcript = run_vdaf( - &self.vdaf, - self.task.id(), - self.task.vdaf_verify_key().unwrap().as_bytes(), - &self.aggregation_param, - report_metadata.id(), - measurement, - ); - let report_share = generate_helper_report_share::( - *self.task.id(), - report_metadata, - &self.hpke_config, - &transcript.public_share, - self.private_extensions.clone(), - &transcript.helper_input_share, - ); - (report_share, transcript) - } -} - -pub(super) struct AggregationJobInitTestCase< - const VERIFY_KEY_SIZE: usize, - V: vdaf::Aggregator, -> { - pub(super) clock: MockClock, - pub(super) task: Task, - pub(super) prepare_init_generator: PrepareInitGenerator, - pub(super) aggregation_job_id: AggregationJobId, - pub(super) aggregation_job_init_req: AggregationJobInitializeReq, - aggregation_job_init_resp: Option, - pub(super) aggregation_param: V::AggregationParam, - pub(super) handler: Box, - _ephemeral_datastore: EphemeralDatastore, -} - -pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase<0, dummy::Vdaf> { - setup_aggregate_init_test_for_vdaf( - dummy::Vdaf::new(1), - VdafInstance::Fake { rounds: 1 }, - dummy::AggregationParam(0), - 0, - ) - .await -} - -async fn setup_multi_step_aggregate_init_test() -> AggregationJobInitTestCase<0, dummy::Vdaf> { - setup_aggregate_init_test_for_vdaf( - dummy::Vdaf::new(2), - VdafInstance::Fake { rounds: 2 }, - dummy::AggregationParam(7), - 13, - ) - .await -} - -async fn setup_aggregate_init_test_for_vdaf< - 
const VERIFY_KEY_SIZE: usize, - V: vdaf::Aggregator + vdaf::Client<16>, ->( - vdaf: V, - vdaf_instance: VdafInstance, - aggregation_param: V::AggregationParam, - measurement: V::Measurement, -) -> AggregationJobInitTestCase { - let mut test_case = setup_aggregate_init_test_without_sending_request( - vdaf, - vdaf_instance, - aggregation_param, - measurement, - AuthenticationToken::Bearer(random()), - ) - .await; - - let mut response = put_aggregation_job( - &test_case.task, - &test_case.aggregation_job_id, - &test_case.aggregation_job_init_req, - &test_case.handler, - ) - .await; - assert_eq!(response.status(), Some(Status::Created)); - - let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; - let prepare_resps = assert_matches!( - &aggregation_job_resp, - AggregationJobResp::Finished { prepare_resps } => prepare_resps - ); - assert_eq!( - prepare_resps.len(), - test_case.aggregation_job_init_req.prepare_inits().len(), - ); - assert_matches!( - prepare_resps[0].result(), - &PrepareStepResult::Continue { .. } - ); - - test_case.aggregation_job_init_resp = Some(aggregation_job_resp); - test_case -} - -async fn setup_aggregate_init_test_without_sending_request< - const VERIFY_KEY_SIZE: usize, - V: vdaf::Aggregator + vdaf::Client<16>, ->( - vdaf: V, - vdaf_instance: VdafInstance, - aggregation_param: V::AggregationParam, - measurement: V::Measurement, - auth_token: AuthenticationToken, -) -> AggregationJobInitTestCase { - install_test_trace_subscriber(); - - let task = TaskBuilder::new(BatchMode::TimeInterval, vdaf_instance) - .with_aggregator_auth_token(auth_token) - .build(); - let helper_task = task.helper_view().unwrap(); - let clock = MockClock::default(); - let ephemeral_datastore = ephemeral_datastore().await; - let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - - datastore.put_aggregator_task(&helper_task).await.unwrap(); - let keypair = datastore.put_hpke_key().await.unwrap(); - - let handler = AggregatorHandlerBuilder::new( - Arc::clone(&datastore), - clock.clone(), - TestRuntime::default(), - &noop_meter(), - Config::default(), - ) - .await - .unwrap() - .build() - .unwrap(); - - let prepare_init_generator = PrepareInitGenerator::new( - clock.clone(), - helper_task.clone(), - keypair.config().clone(), - vdaf, - aggregation_param.clone(), - ); - - let prepare_inits = Vec::from([ - prepare_init_generator.next(&measurement).0, - prepare_init_generator.next(&measurement).0, - ]); - - let aggregation_job_id = random(); - let aggregation_job_init_req = AggregationJobInitializeReq::new( - aggregation_param.get_encoded().unwrap(), - PartialBatchSelector::new_time_interval(), - prepare_inits.clone(), - ); - - AggregationJobInitTestCase { - clock, - task, - prepare_init_generator, - aggregation_job_id, - aggregation_job_init_req, - aggregation_job_init_resp: None, - aggregation_param, - handler: Box::new(handler), - _ephemeral_datastore: ephemeral_datastore, - } -} - -pub(crate) async fn put_aggregation_job( - task: &Task, - aggregation_job_id: &AggregationJobId, - aggregation_job: &AggregationJobInitializeReq, - handler: &impl Handler, -) -> TestConn { - let (header, value) = task.aggregator_auth_token().request_authentication(); - - put(task.aggregation_job_uri(aggregation_job_id).unwrap().path()) - .with_request_header(header, value) - .with_request_header( - KnownHeaderName::ContentType, - AggregationJobInitializeReq::::MEDIA_TYPE, - ) - .with_request_body(aggregation_job.get_encoded().unwrap()) - .run_async(handler) - .await 
-} - -#[tokio::test] -async fn aggregation_job_init_authorization_dap_auth_token() { - let test_case = setup_aggregate_init_test_without_sending_request( - dummy::Vdaf::new(1), - VdafInstance::Fake { rounds: 1 }, - dummy::AggregationParam(0), - 0, - AuthenticationToken::DapAuth(random()), - ) - .await; - - let (auth_header, auth_value) = test_case - .task - .aggregator_auth_token() - .request_authentication(); - - let response = put(test_case - .task - .aggregation_job_uri(&test_case.aggregation_job_id) - .unwrap() - .path()) - .with_request_header(auth_header, auth_value) - .with_request_header( - KnownHeaderName::ContentType, - AggregationJobInitializeReq::::MEDIA_TYPE, - ) - .with_request_body(test_case.aggregation_job_init_req.get_encoded().unwrap()) - .run_async(&test_case.handler) - .await; - - assert_eq!(response.status(), Some(Status::Created)); -} - -#[rstest::rstest] -#[case::not_bearer_token("wrong kind of token")] -#[case::not_base64("Bearer: ")] -#[tokio::test] -async fn aggregation_job_init_malformed_authorization_header(#[case] header_value: &'static str) { - let test_case = setup_aggregate_init_test_without_sending_request( - dummy::Vdaf::new(1), - VdafInstance::Fake { rounds: 1 }, - dummy::AggregationParam(0), - 0, - AuthenticationToken::Bearer(random()), - ) - .await; - - let response = put(test_case - .task - .aggregation_job_uri(&test_case.aggregation_job_id) - .unwrap() - .path()) - // Authenticate using a malformed "Authorization: Bearer " header and a `DAP-Auth-Token` - // header. The presence of the former should cause an error despite the latter being present and - // well formed. - .with_request_header(KnownHeaderName::Authorization, header_value.to_string()) - .with_request_header( - DAP_AUTH_HEADER, - test_case.task.aggregator_auth_token().as_ref().to_owned(), - ) - .with_request_header( - KnownHeaderName::ContentType, - AggregationJobInitializeReq::::MEDIA_TYPE, - ) - .with_request_body(test_case.aggregation_job_init_req.get_encoded().unwrap()) - .run_async(&test_case.handler) - .await; - - assert_eq!(response.status(), Some(Status::Forbidden)); -} - -#[tokio::test] -async fn aggregation_job_init_unexpected_taskprov_extension() { - let test_case = setup_aggregate_init_test_without_sending_request( - dummy::Vdaf::new(1), - VdafInstance::Fake { rounds: 1 }, - dummy::AggregationParam(0), - 0, - random(), - ) - .await; - - let prepare_init = test_case - .prepare_init_generator - .clone() - .with_private_extensions(Vec::from([Extension::new( - ExtensionType::Taskbind, - Vec::new(), - )])) - .next(&0) - .0; - let report_id = *prepare_init.report_share().metadata().id(); - let aggregation_job_init_req = AggregationJobInitializeReq::new( - dummy::AggregationParam(1).get_encoded().unwrap(), - PartialBatchSelector::new_time_interval(), - Vec::from([prepare_init]), - ); - - let mut response = put_aggregation_job( - &test_case.task, - &test_case.aggregation_job_id, - &aggregation_job_init_req, - &test_case.handler, - ) - .await; - assert_eq!(response.status(), Some(Status::Created)); - - let want_aggregation_job_resp = AggregationJobResp::Finished { - prepare_resps: Vec::from([PrepareResp::new( - report_id, - PrepareStepResult::Reject(ReportError::InvalidMessage), - )]), - }; - let got_aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; - assert_eq!(want_aggregation_job_resp, got_aggregation_job_resp); -} - -#[tokio::test] -async fn aggregation_job_mutation_aggregation_job() { - let test_case = setup_aggregate_init_test().await; - - // Put 
the aggregation job again, but with a different aggregation parameter. - let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( - dummy::AggregationParam(1).get_encoded().unwrap(), - PartialBatchSelector::new_time_interval(), - test_case.aggregation_job_init_req.prepare_inits().to_vec(), - ); - - let response = put_aggregation_job( - &test_case.task, - &test_case.aggregation_job_id, - &mutated_aggregation_job_init_req, - &test_case.handler, - ) - .await; - assert_eq!(response.status(), Some(Status::Conflict)); -} - -#[tokio::test] -async fn aggregation_job_mutation_report_shares() { - let test_case = setup_aggregate_init_test().await; - - let prepare_inits = test_case.aggregation_job_init_req.prepare_inits(); - - // Put the aggregation job again, mutating the associated report shares' metadata such that - // uniqueness constraints on client_reports are violated - for mutated_prepare_inits in [ - // Omit a report share that was included previously - Vec::from(&prepare_inits[0..prepare_inits.len() - 1]), - // Include a different report share than was included previously - [ - &prepare_inits[0..prepare_inits.len() - 1], - &[test_case.prepare_init_generator.next(&0).0], - ] - .concat(), - // Include an extra report share than was included previously - [ - prepare_inits, - &[test_case.prepare_init_generator.next(&0).0], - ] - .concat(), - // Reverse the order of the reports - prepare_inits.iter().rev().cloned().collect(), - ] { - let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( - test_case.aggregation_param.get_encoded().unwrap(), - PartialBatchSelector::new_time_interval(), - mutated_prepare_inits, - ); - let response = put_aggregation_job( - &test_case.task, - &test_case.aggregation_job_id, - &mutated_aggregation_job_init_req, - &test_case.handler, - ) - .await; - assert_eq!(response.status(), Some(Status::Conflict)); - } -} - -#[tokio::test] -async fn aggregation_job_mutation_report_aggregations() { - // We set up a multi-step VDAF in this test so that the aggregation job won't finish on the - // first step. - let test_case = setup_multi_step_aggregate_init_test().await; - - // Generate some new reports using the existing reports' metadata, but varying the measurement - // values such that the prepare state computed during aggregation initializaton won't match the - // first aggregation job. - let mutated_prepare_inits = test_case - .aggregation_job_init_req - .prepare_inits() - .iter() - .map(|s| { - test_case - .prepare_init_generator - .next_with_metadata(s.report_share().metadata().clone(), &13) - .0 - }) - .collect(); - - let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new( - test_case.aggregation_param.get_encoded().unwrap(), - PartialBatchSelector::new_time_interval(), - mutated_prepare_inits, - ); - - let response = put_aggregation_job( - &test_case.task, - &test_case.aggregation_job_id, - &mutated_aggregation_job_init_req, - &test_case.handler, - ) - .await; - assert_eq!(response.status(), Some(Status::Conflict)); -} - -#[tokio::test] -async fn aggregation_job_intolerable_clock_skew() { - let mut test_case = setup_aggregate_init_test_without_sending_request( - dummy::Vdaf::new(1), - VdafInstance::Fake { rounds: 1 }, - dummy::AggregationParam(0), - 0, - AuthenticationToken::Bearer(random()), - ) - .await; - - test_case.aggregation_job_init_req = AggregationJobInitializeReq::new( - test_case.aggregation_param.get_encoded().unwrap(), - PartialBatchSelector::new_time_interval(), - Vec::from([ - // Barely tolerable. 
- test_case - .prepare_init_generator - .next_with_metadata( - ReportMetadata::new( - random(), - test_case - .clock - .now() - .add(test_case.task.tolerable_clock_skew()) - .unwrap(), - Vec::new(), - ), - &0, - ) - .0, - // Barely intolerable. - test_case - .prepare_init_generator - .next_with_metadata( - ReportMetadata::new( - random(), - test_case - .clock - .now() - .add(test_case.task.tolerable_clock_skew()) - .unwrap() - .add(&Duration::from_seconds(1)) - .unwrap(), - Vec::new(), - ), - &0, - ) - .0, - ]), - ); - - let mut response = put_aggregation_job( - &test_case.task, - &test_case.aggregation_job_id, - &test_case.aggregation_job_init_req, - &test_case.handler, - ) - .await; - assert_eq!(response.status(), Some(Status::Created)); - - let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; - let prepare_resps = assert_matches!( - aggregation_job_resp, - AggregationJobResp::Finished { prepare_resps } => prepare_resps - ); - assert_eq!( - prepare_resps.len(), - test_case.aggregation_job_init_req.prepare_inits().len(), - ); - assert_matches!( - prepare_resps[0].result(), - &PrepareStepResult::Continue { .. } - ); - assert_matches!( - prepare_resps[1].result(), - &PrepareStepResult::Reject(ReportError::ReportTooEarly) - ); -} - -#[tokio::test] -async fn aggregation_job_init_two_step_vdaf_idempotence() { - // We set up a multi-step VDAF in this test so that the aggregation job won't finish on the - // first step. - let test_case = setup_multi_step_aggregate_init_test().await; - - // Send the aggregation job init request again. We should get an identical response back. - let mut response = put_aggregation_job( - &test_case.task, - &test_case.aggregation_job_id, - &test_case.aggregation_job_init_req, - &test_case.handler, - ) - .await; - - let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; - assert_eq!( - aggregation_job_resp, - test_case.aggregation_job_init_resp.unwrap(), - ); -} - -#[tokio::test] -async fn aggregation_job_init_wrong_query() { - let test_case = setup_aggregate_init_test().await; - - // setup_aggregate_init_test sets up a task with a time interval query. We send a - // leader-selected query which should yield an error. - let wrong_query = AggregationJobInitializeReq::new( - test_case.aggregation_param.get_encoded().unwrap(), - PartialBatchSelector::new_leader_selected(random()), - test_case.aggregation_job_init_req.prepare_inits().to_vec(), - ); - - let (header, value) = test_case - .task - .aggregator_auth_token() - .request_authentication(); - - let mut response = put(test_case - .task - .aggregation_job_uri(&random()) - .unwrap() - .path()) - .with_request_header(header, value) - .with_request_header( - KnownHeaderName::ContentType, - AggregationJobInitializeReq::::MEDIA_TYPE, - ) - .with_request_body(wrong_query.get_encoded().unwrap()) - .run_async(&test_case.handler) - .await; - assert_eq!( - take_problem_details(&mut response).await, - json!({ - "status": StatusCode::BAD_REQUEST.as_u16(), - "type": "urn:ietf:params:ppm:dap:error:invalidMessage", - "title": "The message type for a response was incorrect or the payload was malformed.", - }), - ); -} diff --git a/aggregator/src/aggregator/aggregation_job_continue.rs b/aggregator/src/aggregator/aggregation_job_continue.rs index 498dc385c..0c5b6b85b 100644 --- a/aggregator/src/aggregator/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/aggregation_job_continue.rs @@ -1,29 +1,20 @@ -//! 
Implements portions of aggregation job continuation for the helper. +//! Implements portions of aggregation job continuation for the Helper. use crate::aggregator::{ - aggregation_job_writer::{AggregationJobWriter, UpdateWrite, WritableReportAggregation}, - error::handle_ping_pong_error, - AggregatorMetrics, Error, VdafOps, + aggregation_job_writer::WritableReportAggregation, handle_ping_pong_error, AggregatorMetrics, }; +use assert_matches::assert_matches; use janus_aggregator_core::{ batch_mode::AccumulableBatchMode, - datastore::{ - self, - models::{ - AggregationJob, ReportAggregation, ReportAggregationState, TaskAggregationCounter, - }, - Transaction, - }, + datastore::models::{AggregationJob, ReportAggregation, ReportAggregationState}, task::AggregatorTask, }; -use janus_core::{time::Clock, vdaf::vdaf_application_context}; -use janus_messages::{ - AggregationJobContinueReq, AggregationJobResp, PrepareResp, PrepareStepResult, ReportError, - Role, -}; +use janus_core::vdaf::vdaf_application_context; +use janus_messages::{PrepareResp, PrepareStepResult, Role}; +use opentelemetry::metrics::Counter; use prio::{ codec::{Encode, ParameterizedDecode}, - topology::ping_pong::{PingPongContinuedValue, PingPongState, PingPongTopology}, + topology::ping_pong::{PingPongContinuedValue, PingPongState, PingPongTopology as _}, vdaf, }; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; @@ -31,269 +22,192 @@ use std::{panic, sync::Arc}; use tokio::sync::mpsc; use tracing::{info_span, trace_span, Span}; -impl VdafOps { - /// Step the helper's aggregation job to the next step using the step `n` ping pong state in - /// `report_aggregations` with the step `n+1` ping pong messages in `leader_aggregation_job`. - pub(super) async fn step_aggregation_job( - tx: &Transaction<'_, C>, - task: Arc, - vdaf: Arc, - batch_aggregation_shard_count: u64, - aggregation_job: AggregationJob, - report_aggregations: Vec>, - req: Arc, - request_hash: [u8; 32], - metrics: &AggregatorMetrics, - ) -> Result<(AggregationJobResp, TaskAggregationCounter), datastore::Error> - where - C: Clock, - B: AccumulableBatchMode, - A: vdaf::Aggregator + 'static + Send + Sync, - A::AggregationParam: Send + Sync + PartialEq + Eq, - A::InputShare: Send + Sync, - A::OutputShare: Send + Sync, - A::PrepareMessage: Send + Sync, - A::PublicShare: Send + Sync, - for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, - { - let request_step = req.step(); - - // Match preparation step received from leader to stored report aggregation, and extract - // the stored preparation step. - let report_aggregation_count = report_aggregations.len(); - let mut report_aggregations_iter = report_aggregations.into_iter(); - - let mut prep_steps_and_ras = Vec::with_capacity(req.prepare_steps().len()); // matched to prep_steps - let mut report_aggregations_to_write = Vec::with_capacity(report_aggregation_count); - for prep_step in req.prepare_steps() { - let report_aggregation = loop { - let report_agg = report_aggregations_iter.next().ok_or_else(|| { - datastore::Error::User( - Error::InvalidMessage( - Some(*task.id()), - "leader sent unexpected, duplicate, or out-of-order prepare steps", - ) - .into(), - ) - })?; - if report_agg.report_id() != prep_step.report_id() { - // This report was omitted by the leader because of a prior failure. Note that - // the report was dropped (if it's not already in an error state) and continue. - if matches!( - report_agg.state(), - ReportAggregationState::HelperContinue { .. 
} - ) { - report_aggregations_to_write.push(WritableReportAggregation::new( - report_agg - .clone() - .with_state(ReportAggregationState::Failed { - report_error: ReportError::ReportDropped, - }) - .with_last_prep_resp(None), - None, - )); - } - continue; - } - break report_agg; - }; - - let prep_state = match report_aggregation.state() { - ReportAggregationState::HelperContinue { prepare_state } => prepare_state.clone(), - ReportAggregationState::LeaderContinue { .. } => { - return Err(datastore::Error::User( - Error::Internal( - "helper encountered unexpected ReportAggregationState::LeaderContinue" - .to_string(), - ) - .into(), - )) - } - _ => { - return Err(datastore::Error::User( - Error::InvalidMessage( - Some(*task.id()), - "leader sent prepare step for non-WAITING report aggregation", - ) - .into(), - )); - } - }; +#[derive(Clone)] +pub struct AggregateContinueMetrics { + /// Counters tracking the number of failures to step client reports through the aggregation + /// process. + aggregate_step_failure_counter: Counter, +} - prep_steps_and_ras.push((prep_step.clone(), report_aggregation, prep_state)); +impl AggregateContinueMetrics { + pub fn new(aggregate_step_failure_counter: Counter) -> Self { + Self { + aggregate_step_failure_counter, } + } +} - for report_aggregation in report_aggregations_iter { - // This report was omitted by the leader because of a prior failure. Note that - // the report was dropped (if it's not already in an error state) and continue. - if matches!( - report_aggregation.state(), - ReportAggregationState::HelperContinue { .. } - ) { - report_aggregations_to_write.push(WritableReportAggregation::new( - report_aggregation - .clone() - .with_state(ReportAggregationState::Failed { - report_error: ReportError::ReportDropped, - }) - .with_last_prep_resp(None), - None, - )); - } +impl From for AggregateContinueMetrics { + fn from(metrics: AggregatorMetrics) -> Self { + Self { + aggregate_step_failure_counter: metrics.aggregate_step_failure_counter.clone(), } + } +} - // Compute the next aggregation step. - // - // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. - // This will cause any attempts to send on `sender` to return a `SendError`, which will be - // returned from the function passed to `try_for_each_with`; `try_for_each_with` will - // terminate early on receiving an error. - let (sender, mut receiver) = mpsc::unbounded_channel(); - let aggregation_job = Arc::new(aggregation_job); - let producer_task = tokio::task::spawn_blocking({ - let parent_span = Span::current(); - let metrics = metrics.clone(); - let task = Arc::clone(&task); - let vdaf = Arc::clone(&vdaf); - let aggregation_job = Arc::clone(&aggregation_job); - - move || { - let span = info_span!(parent: parent_span, "step_aggregation_job threadpool task"); - let ctx = vdaf_application_context(task.id()); - - prep_steps_and_ras.into_par_iter().try_for_each( - |(prep_step, report_aggregation, prep_state)| { - let _entered = span.enter(); - - let (report_aggregation_state, prepare_step_result, output_share) = - trace_span!("VDAF preparation (helper continuation)") - .in_scope(|| { - // Continue with the incoming message. 
- vdaf.helper_continued( - &ctx, - PingPongState::Continued(prep_state.clone()), - aggregation_job.aggregation_parameter(), - prep_step.message(), - ) - .and_then( - |continued_value| { - match continued_value { - PingPongContinuedValue::WithMessage { - transition, - } => { - let (new_state, message) = - transition.evaluate(&ctx, vdaf.as_ref())?; - let (report_aggregation_state, output_share) = - match new_state { - // Helper did not finish. Store the new - // state and await the next message from - // the Leader to advance preparation. - PingPongState::Continued(prepare_state) => ( - ReportAggregationState::HelperContinue { - prepare_state, - }, - None, - ), - // Helper finished. Commit the output - // share. - PingPongState::Finished(output_share) => ( - ReportAggregationState::Finished, - Some(output_share), - ), - }; - - Ok(( - report_aggregation_state, - // Helper has an outgoing message for Leader - PrepareStepResult::Continue { message }, - output_share, - )) - } - - PingPongContinuedValue::FinishedNoMessage { - output_share, - } => Ok(( - ReportAggregationState::Finished, - PrepareStepResult::Finished, - Some(output_share), - )), - } - }, - ) - }) - .map_err(|error| { - handle_ping_pong_error( - task.id(), - Role::Leader, - prep_step.report_id(), - error, - &metrics.aggregate_step_failure_counter, - ) +/// Given report aggregations in the `HelperContinueProcessing` state, this function computes the +/// next step of the aggregation; the returned [`WritableReportAggregation`]s correspond to the +/// provided report aggregations and will be in the `HelperContinue`, `Finished`, or `Failed` +/// states. +/// +/// Only report aggregations in the `HelperContinueProcessing` state can be provided. The caller +/// must filter report aggregations which are in other states (e.g. `Failed`) prior to calling this +/// function. +/// +/// ### Panics +/// +/// Panics if a provided report aggregation is in a state other than `HelperContinueProcessing`. +pub async fn compute_helper_aggregate_continue( + vdaf: Arc, + metrics: AggregateContinueMetrics, + task: Arc, + aggregation_job: Arc>, + report_aggregations: Vec>, +) -> Vec> +where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareShare: Send + Sync, + A::PrepareMessage: Send + Sync, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync, + B: AccumulableBatchMode, +{ + let report_aggregation_count = report_aggregations.len(); + + // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. + // This will cause any attempts to send on `sender` to return a `SendError`, which will be + // returned from the function passed to `try_for_each_with`; `try_for_each_with` will + // terminate early on receiving an error. 
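+ // Concretely, the closure passed to rayon's `try_for_each` below sends each completed + // report aggregation over `sender`, so once `receiver` is dropped the next send returns a + // `SendError` and the parallel iteration stops early.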
+ let (sender, mut receiver) = mpsc::unbounded_channel(); + let producer_task = tokio::task::spawn_blocking({ + let parent_span = Span::current(); + let metrics = metrics.clone(); + let task = Arc::clone(&task); + let vdaf = Arc::clone(&vdaf); + let aggregation_job = Arc::clone(&aggregation_job); + + move || { + let span = info_span!(parent: parent_span, "step_aggregation_job threadpool task"); + let ctx = vdaf_application_context(task.id()); + + report_aggregations + .into_par_iter() + .try_for_each(|report_aggregation| { + let _entered = span.enter(); + + // Assert safety: this function should only be called with report + // aggregations in the HelperContinueProcessing state. + let (prepare_state, prepare_continue) = assert_matches!( + report_aggregation.state(), + ReportAggregationState::HelperContinueProcessing{ + prepare_state, + prepare_continue + } => (prepare_state, prepare_continue) + ); + + let (report_aggregation_state, prepare_step_result, output_share) = + trace_span!("VDAF preparation (helper continuation)") + .in_scope(|| { + // Continue with the incoming message. + vdaf.helper_continued( + &ctx, + PingPongState::Continued(prepare_state.clone()), + aggregation_job.aggregation_parameter(), + prepare_continue.message(), + ) + .and_then(|continued_value| { + match continued_value { + PingPongContinuedValue::WithMessage { transition } => { + let (new_state, message) = + transition.evaluate(&ctx, vdaf.as_ref())?; + let (report_aggregation_state, output_share) = + match new_state { + // Helper did not finish. Store the new + // state and await the next message from + // the Leader to advance preparation. + PingPongState::Continued(prepare_state) => ( + ReportAggregationState::HelperContinue { + prepare_state, + }, + None, + ), + // Helper finished. Commit the output + // share. 
+ PingPongState::Finished(output_share) => ( + ReportAggregationState::Finished, + Some(output_share), + ), + }; + + Ok(( + report_aggregation_state, + // Helper has an outgoing message for Leader + PrepareStepResult::Continue { message }, + output_share, + )) + } + + PingPongContinuedValue::FinishedNoMessage { + output_share, + } => Ok(( + ReportAggregationState::Finished, + PrepareStepResult::Finished, + Some(output_share), + )), + } }) - .unwrap_or_else(|report_error| { - ( - ReportAggregationState::Failed { report_error }, - PrepareStepResult::Reject(report_error), - None, - ) - }); - - sender.send(WritableReportAggregation::new( - report_aggregation - .clone() - .with_state(report_aggregation_state) - .with_last_prep_resp(Some(PrepareResp::new( - *prep_step.report_id(), - prepare_step_result, - ))), - output_share, - )) - }, - ) - } - }); + }) + .map_err(|error| { + handle_ping_pong_error( + task.id(), + Role::Leader, + prepare_continue.report_id(), + error, + &metrics.aggregate_step_failure_counter, + ) + }) + .unwrap_or_else(|report_error| { + ( + ReportAggregationState::Failed { report_error }, + PrepareStepResult::Reject(report_error), + None, + ) + }); + + let report_id = *report_aggregation.report_id(); + sender.send(WritableReportAggregation::new( + report_aggregation + .with_state(report_aggregation_state) + .with_last_prep_resp(Some(PrepareResp::new( + report_id, + prepare_step_result, + ))), + output_share, + )) + }) + } + }); + + let mut report_aggregations = Vec::with_capacity(report_aggregation_count); + while receiver.recv_many(&mut report_aggregations, 10).await > 0 {} + + // Await the producer task to resume any panics that may have occurred, and to ensure we can + // unwrap the aggregation job's Arc in a few lines. The only other errors that can occur + // are: a `JoinError` indicating cancellation, which is impossible because we do not cancel + // the task; and a `SendError`, which can only happen if this future is cancelled (in which + // case we will not run this code at all). + let _ = producer_task.await.map_err(|join_error| { + if let Ok(reason) = join_error.try_into_panic() { + panic::resume_unwind(reason); + } + }); + assert_eq!(report_aggregations.len(), report_aggregation_count); - while receiver - .recv_many(&mut report_aggregations_to_write, 10) - .await - > 0 - {} - - // Await the producer task to resume any panics that may have occurred, and to ensure we can - // unwrap the aggregation job's Arc in a few lines. The only other errors that can occur - // are: a `JoinError` indicating cancellation, which is impossible because we do not cancel - // the task; and a `SendError`, which can only happen if this future is cancelled (in which - // case we will not run this code at all). - let _ = producer_task.await.map_err(|join_error| { - if let Ok(reason) = join_error.try_into_panic() { - panic::resume_unwind(reason); - } - }); - assert_eq!(report_aggregations_to_write.len(), report_aggregation_count); - - // Write accumulated aggregation values back to the datastore; this will mark any reports - // that can't be aggregated because the batch is collected with error BatchCollected. 
- let aggregation_job_id = *aggregation_job.id(); - let aggregation_job = Arc::unwrap_or_clone(aggregation_job) - .with_step(request_step) // Advance the job to the leader's step - .with_last_request_hash(request_hash); - let mut aggregation_job_writer = - AggregationJobWriter::::new( - task, - batch_aggregation_shard_count, - Some(metrics.for_aggregation_job_writer()), - ); - aggregation_job_writer.put(aggregation_job, report_aggregations_to_write)?; - let (mut prep_resps_by_agg_job, counters) = aggregation_job_writer.write(tx, vdaf).await?; - Ok(( - AggregationJobResp::Finished { - prepare_resps: prep_resps_by_agg_job - .remove(&aggregation_job_id) - .unwrap_or_default(), - }, - counters, - )) - } + report_aggregations } #[cfg(feature = "test-util")] @@ -315,15 +229,19 @@ pub mod test_util { handler: &impl Handler, ) -> TestConn { let (header, value) = task.aggregator_auth_token().request_authentication(); - post(task.aggregation_job_uri(aggregation_job_id).unwrap().path()) - .with_request_header(header, value) - .with_request_header( - KnownHeaderName::ContentType, - AggregationJobContinueReq::MEDIA_TYPE, - ) - .with_request_body(request.get_encoded().unwrap()) - .run_async(handler) - .await + post( + task.aggregation_job_uri(aggregation_job_id, None) + .unwrap() + .path(), + ) + .with_request_header(header, value) + .with_request_header( + KnownHeaderName::ContentType, + AggregationJobContinueReq::MEDIA_TYPE, + ) + .with_request_body(request.get_encoded().unwrap()) + .run_async(handler) + .await } pub async fn post_aggregation_job_and_decode( @@ -395,11 +313,11 @@ pub mod test_util { #[cfg(test)] mod tests { use crate::aggregator::{ - aggregate_init_tests::{put_aggregation_job, PrepareInitGenerator}, aggregation_job_continue::test_util::{ post_aggregation_job_and_decode, post_aggregation_job_expecting_error, post_aggregation_job_expecting_status, }, + aggregation_job_init::test_util::{put_aggregation_job, PrepareInitGenerator}, http_handlers::{ test_util::{take_problem_details, HttpHandlerTest}, AggregatorHandlerBuilder, @@ -416,7 +334,7 @@ mod tests { }, task::{ test_util::{Task, TaskBuilder}, - BatchMode, + AggregationMode, BatchMode, }, test_util::noop_meter, }; @@ -463,8 +381,12 @@ mod tests { install_test_trace_subscriber(); let aggregation_job_id = random(); - let task = - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; @@ -485,33 +407,34 @@ mod tests { datastore .run_unnamed_tx(|tx| { - let (task, aggregation_param, prepare_init, transcript) = ( - helper_task.clone(), - aggregation_parameter, - prepare_init.clone(), - transcript.clone(), - ); + let helper_task = helper_task.clone(); + let prepare_init = prepare_init.clone(); + let transcript = transcript.clone(); Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); - tx.put_scrubbed_report(task.id(), prepare_init.report_share()) - .await - .unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + prepare_init.report_share().metadata().id(), + prepare_init.report_share().metadata().time(), + ) + .await + .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, - 
aggregation_param, + aggregation_parameter, (), Interval::from_time(prepare_init.report_share().metadata().time()).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *prepare_init.report_share().metadata().id(), *prepare_init.report_share().metadata().time(), @@ -585,7 +508,7 @@ mod tests { AggregationJobResp::Finished { prepare_resps: test_case .first_continue_request - .prepare_steps() + .prepare_continues() .iter() .map(|step| PrepareResp::new(*step.report_id(), PrepareStepResult::Finished)) .collect() @@ -605,7 +528,12 @@ mod tests { .. } = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build(); datastore .put_aggregator_task(&task.leader_view().unwrap()) .await @@ -636,7 +564,12 @@ mod tests { .. } = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build(); datastore .put_aggregator_task(&task.leader_view().unwrap()) .await @@ -646,7 +579,7 @@ mod tests { let (header, value) = task.aggregator_auth_token().request_authentication(); let mut test_conn = delete( - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -673,7 +606,10 @@ mod tests { // to advance to step 0. Should be rejected because that is an illegal transition. let step_zero_request = AggregationJobContinueReq::new( AggregationJobStep::from(0), - test_case.first_continue_request.prepare_steps().to_vec(), + test_case + .first_continue_request + .prepare_continues() + .to_vec(), ); post_aggregation_job_expecting_error( @@ -719,15 +655,18 @@ mod tests { let (before_aggregation_job, before_report_aggregations) = test_case .datastore .run_unnamed_tx(|tx| { - let (task_id, unrelated_prepare_init, aggregation_job_id) = ( - *test_case.task.id(), - unrelated_prepare_init.clone(), - test_case.aggregation_job_id, - ); + let task_id = *test_case.task.id(); + let unrelated_prepare_init = unrelated_prepare_init.clone(); + let aggregation_job_id = test_case.aggregation_job_id; + Box::pin(async move { - tx.put_scrubbed_report(&task_id, unrelated_prepare_init.report_share()) - .await - .unwrap(); + tx.put_scrubbed_report( + &task_id, + unrelated_prepare_init.report_share().metadata().id(), + unrelated_prepare_init.report_share().metadata().time(), + ) + .await + .unwrap(); let aggregation_job = tx .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( @@ -842,7 +781,10 @@ mod tests { // Send another request for a step that the helper is past. Should fail. let past_step_request = AggregationJobContinueReq::new( AggregationJobStep::from(1), - test_case.first_continue_request.prepare_steps().to_vec(), + test_case + .first_continue_request + .prepare_continues() + .to_vec(), ); post_aggregation_job_expecting_error( @@ -866,7 +808,10 @@ mod tests { // helper isn't on that step. 
let future_step_request = AggregationJobContinueReq::new( AggregationJobStep::from(17), - test_case.first_continue_request.prepare_steps().to_vec(), + test_case + .first_continue_request + .prepare_continues() + .to_vec(), ); post_aggregation_job_expecting_error( @@ -895,7 +840,7 @@ mod tests { let test_conn = delete( test_case .task - .aggregation_job_uri(&test_case.aggregation_job_id) + .aggregation_job_uri(&test_case.aggregation_job_id, None) .unwrap() .path(), ) diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index 16c6d7aff..a8a83a9ea 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -29,8 +29,8 @@ use janus_core::{ time::{Clock, DurationExt as _, TimeExt as _}, vdaf::{ new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128, - Prio3SumVecField64MultiproofHmacSha256Aes128, VdafInstance, VERIFY_KEY_LENGTH, - VERIFY_KEY_LENGTH_HMACSHA256_AES128, + Prio3SumVecField64MultiproofHmacSha256Aes128, VdafInstance, VERIFY_KEY_LENGTH_PRIO3, + VERIFY_KEY_LENGTH_PRIO3_HMACSHA256_AES128, }, }; use janus_messages::{ @@ -310,13 +310,13 @@ impl AggregationJobCreator { match (task.batch_mode(), task.vdaf()) { (task::BatchMode::TimeInterval, VdafInstance::Prio3Count) => { let vdaf = Arc::new(Prio3::new_count(2)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } (task::BatchMode::TimeInterval, VdafInstance::Prio3Sum { max_measurement }) => { let vdaf = Arc::new(Prio3::new_sum(2, *max_measurement)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -330,7 +330,7 @@ impl AggregationJobCreator { }, ) => { let vdaf = Arc::new(Prio3::new_sum_vec(2, *bits, *length, *chunk_length)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -348,7 +348,7 @@ impl AggregationJobCreator { ParallelSum>, >(*proofs, *bits, *length, *chunk_length)?); self.create_aggregation_jobs_for_time_interval_task_no_param::< - VERIFY_KEY_LENGTH_HMACSHA256_AES128, + VERIFY_KEY_LENGTH_PRIO3_HMACSHA256_AES128, Prio3SumVecField64MultiproofHmacSha256Aes128<_>, >(task, vdaf).await } @@ -362,7 +362,7 @@ impl AggregationJobCreator { }, ) => { let vdaf = Arc::new(Prio3::new_histogram(2, *length, *chunk_length)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -378,13 +378,13 @@ impl AggregationJobCreator { Prio3FixedPointBoundedL2VecSumBitSize::BitSize16 => { let vdaf: Arc>> = Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum(2, *length)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) .await } Prio3FixedPointBoundedL2VecSumBitSize::BitSize32 => { let vdaf: Arc>> = Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum(2, *length)?); - self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) + self.create_aggregation_jobs_for_time_interval_task_no_param::>>(task, vdaf) .await } }, @@ -408,12 +408,12 @@ impl AggregationJobCreator { Prio3< prio::flp::types::Count, vdaf::xof::XofTurboShake128, - VERIFY_KEY_LENGTH, + 
VERIFY_KEY_LENGTH_PRIO3, >, > = Arc::new(Prio3::new_count(2)?); let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_leader_selected_task_no_param::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, Prio3Count, >(task, vdaf, batch_time_window_size).await } @@ -427,7 +427,7 @@ impl AggregationJobCreator { let vdaf = Arc::new(Prio3::new_sum(2, *max_measurement)?); let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_leader_selected_task_no_param::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, Prio3Sum, >(task, vdaf, batch_time_window_size).await } @@ -446,7 +446,7 @@ impl AggregationJobCreator { let vdaf = Arc::new(Prio3::new_sum_vec(2, *bits, *length, *chunk_length)?); let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_leader_selected_task_no_param::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, Prio3SumVec, >(task, vdaf, batch_time_window_size).await } @@ -468,7 +468,7 @@ impl AggregationJobCreator { >(*proofs, *bits, *length, *chunk_length)?); let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_leader_selected_task_no_param::< - VERIFY_KEY_LENGTH_HMACSHA256_AES128, + VERIFY_KEY_LENGTH_PRIO3_HMACSHA256_AES128, Prio3SumVecField64MultiproofHmacSha256Aes128<_>, >(task, vdaf, batch_time_window_size).await } @@ -486,7 +486,7 @@ impl AggregationJobCreator { let vdaf = Arc::new(Prio3::new_histogram(2, *length, *chunk_length)?); let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_leader_selected_task_no_param::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, Prio3Histogram, >(task, vdaf, batch_time_window_size).await } @@ -509,7 +509,7 @@ impl AggregationJobCreator { let vdaf: Arc>> = Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum(2, *length)?); self.create_aggregation_jobs_for_leader_selected_task_no_param::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, Prio3FixedPointBoundedL2VecSum>, >(task, vdaf, batch_time_window_size).await } @@ -517,7 +517,7 @@ impl AggregationJobCreator { let vdaf: Arc>> = Arc::new(Prio3::new_fixedpoint_boundedl2_vec_sum(2, *length)?); self.create_aggregation_jobs_for_leader_selected_task_no_param::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, Prio3FixedPointBoundedL2VecSum>, >(task, vdaf, batch_time_window_size).await } @@ -663,7 +663,7 @@ impl AggregationJobCreator { (), (), client_timestamp_interval, - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); @@ -806,7 +806,7 @@ impl AggregationJobCreator { aggregation_param.clone(), (), client_timestamp_interval, - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); let report_aggregations: Vec<_> = agg_job_reports @@ -915,14 +915,14 @@ mod tests { test_util::ephemeral_datastore, Transaction, }, - task::{test_util::TaskBuilder, BatchMode as TaskBatchMode}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode as TaskBatchMode}, test_util::noop_meter, }; use janus_core::{ hpke::HpkeKeypair, test_util::{install_test_trace_subscriber, run_vdaf}, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, - vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, + vdaf::{VdafInstance, VERIFY_KEY_LENGTH_PRIO3}, }; use janus_messages::{ batch_mode::{LeaderSelected, TimeInterval}, @@ -966,10 +966,14 @@ mod tests { let report_time = Time::from_seconds_since_epoch(0); let leader_task = Arc::new( - TaskBuilder::new(TaskBatchMode::TimeInterval, VdafInstance::Prio3Count) - 
.build() - .leader_view() - .unwrap(), + TaskBuilder::new( + TaskBatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(), ); let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &report_time).unwrap(); @@ -993,10 +997,14 @@ mod tests { )); let helper_task = Arc::new( - TaskBuilder::new(TaskBatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .helper_view() - .unwrap(), + TaskBuilder::new( + TaskBatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .helper_view() + .unwrap(), ); let helper_report = Arc::new(LeaderStoredReport::new_dummy( *helper_task.id(), @@ -1072,7 +1080,7 @@ mod tests { Box::pin(async move { let (leader_aggregations, leader_batch_aggregations) = read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, _, _, @@ -1145,10 +1153,14 @@ mod tests { const MAX_AGGREGATION_JOB_SIZE: usize = 60; let task = Arc::new( - TaskBuilder::new(TaskBatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(), + TaskBuilder::new( + TaskBatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(), ); // In one batch, create enough reports to fill 2 max-size aggregation jobs, a min-size @@ -1245,7 +1257,7 @@ mod tests { Box::pin(async move { Ok(read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, _, _, @@ -1334,10 +1346,14 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let ds = ephemeral_datastore.datastore(clock.clone()).await; let task = Arc::new( - TaskBuilder::new(TaskBatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(), + TaskBuilder::new( + TaskBatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(), ); let report_time = clock.now(); @@ -1432,7 +1448,7 @@ mod tests { Box::pin(async move { Ok(read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, _, _, @@ -1489,7 +1505,7 @@ mod tests { Box::pin(async move { Ok(read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, _, _, @@ -1546,10 +1562,14 @@ mod tests { const MAX_AGGREGATION_JOB_SIZE: usize = 60; let task = Arc::new( - TaskBuilder::new(TaskBatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(), + TaskBuilder::new( + TaskBatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(), ); // Create a min-size batch. 
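Throughout these test hunks, `TaskBuilder::new` gains a required `AggregationMode` argument and the Prio3 verify-key-length constant is renamed to `VERIFY_KEY_LENGTH_PRIO3`. A minimal sketch of the updated builder call, mirroring the surrounding hunks (builder defaults assumed to match those used elsewhere in this patch):

let task = TaskBuilder::new(
    TaskBatchMode::TimeInterval,
    // New required argument: whether the helper aggregates synchronously or asynchronously.
    AggregationMode::Synchronous,
    VdafInstance::Prio3Count,
)
.build()
.leader_view()
.unwrap();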
@@ -1594,7 +1614,7 @@ mod tests { tx.put_client_report(report).await.unwrap(); } tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -1658,7 +1678,7 @@ mod tests { Box::pin(async move { Ok(read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, _, _, @@ -1729,6 +1749,7 @@ mod tests { TaskBatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Prio3Count, ) .with_min_batch_size(MIN_BATCH_SIZE as u64) @@ -1829,7 +1850,7 @@ mod tests { .await .unwrap(), read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, _, _, @@ -1927,6 +1948,7 @@ mod tests { TaskBatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Prio3Count, ) .with_min_batch_size(MIN_BATCH_SIZE as u64) @@ -2022,7 +2044,7 @@ mod tests { .await .unwrap(), read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, _, _, @@ -2089,6 +2111,7 @@ mod tests { TaskBatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Prio3Count, ) .with_min_batch_size(MIN_BATCH_SIZE as u64) @@ -2189,7 +2212,7 @@ mod tests { .await .unwrap(), read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, _, _, @@ -2285,7 +2308,7 @@ mod tests { .await .unwrap(), read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, _, _, @@ -2351,6 +2374,7 @@ mod tests { TaskBatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Prio3Count, ) .with_min_batch_size(MIN_BATCH_SIZE as u64) @@ -2451,7 +2475,7 @@ mod tests { .await .unwrap(), read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, _, _, @@ -2555,7 +2579,7 @@ mod tests { .await .unwrap(), read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, _, _, @@ -2622,6 +2646,7 @@ mod tests { TaskBatchMode::LeaderSelected { batch_time_window_size: Some(batch_time_window_size), }, + AggregationMode::Synchronous, VdafInstance::Prio3Count, ) .with_min_batch_size(MIN_BATCH_SIZE as u64) @@ -2758,7 +2783,7 @@ mod tests { .await .unwrap(), read_and_verify_aggregate_info_for_task::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, _, _, @@ -2860,6 +2885,7 @@ mod tests { let task = Arc::new( TaskBuilder::new( TaskBatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() diff --git a/aggregator/src/aggregator/aggregation_job_driver.rs b/aggregator/src/aggregator/aggregation_job_driver.rs index 67357692a..4cd754551 100644 --- a/aggregator/src/aggregator/aggregation_job_driver.rs +++ b/aggregator/src/aggregator/aggregation_job_driver.rs @@ -1,6 +1,8 @@ use crate::{ aggregator::{ aggregate_step_failure_counter, + aggregation_job_continue::{compute_helper_aggregate_continue, AggregateContinueMetrics}, + aggregation_job_init::{compute_helper_aggregate_init, AggregateInitMetrics}, aggregation_job_writer::{ AggregationJobWriter, AggregationJobWriterMetrics, UpdateWrite, WritableReportAggregation, @@ -11,6 +13,7 @@ use crate::{ report_aggregation_success_counter, send_request_to_helper, write_task_aggregation_counter, Error, RequestBody, }, + 
cache::HpkeKeypairCache, metrics::aggregated_report_share_dimension_histogram, }; use anyhow::{anyhow, Result}; @@ -28,7 +31,7 @@ use janus_aggregator_core::{ }, Datastore, }, - task::{self, AggregatorTask, VerifyKey}, + task::{self, AggregatorTask}, TIME_HISTOGRAM_BOUNDARIES, }; use janus_core::{ @@ -56,12 +59,17 @@ use rayon::iter::{IndexedParallelIterator as _, IntoParallelIterator as _, Paral use reqwest::Method; use retry_after::RetryAfter; use std::{ + borrow::Cow, collections::HashSet, panic, sync::Arc, time::{Duration, UNIX_EPOCH}, }; -use tokio::{join, sync::mpsc, try_join}; +use tokio::{ + join, + sync::{mpsc, Mutex}, + try_join, +}; use tracing::{debug, error, info, info_span, trace_span, warn, Span}; #[cfg(test)] @@ -73,6 +81,8 @@ pub struct AggregationJobDriver { // Configuration. batch_aggregation_shard_count: u64, task_counter_shard_count: u64, + hpke_configs_refresh_interval: Duration, + default_async_poll_interval: Duration, // Dependencies. http_client: reqwest::Client, @@ -102,6 +112,8 @@ where meter: &Meter, batch_aggregation_shard_count: u64, task_counter_shard_count: u64, + hpke_configs_refresh_interval: Duration, + default_async_poll_interval: Duration, ) -> Self { let aggregation_success_counter = report_aggregation_success_counter(meter); let aggregate_step_failure_counter = aggregate_step_failure_counter(meter); @@ -134,6 +146,8 @@ where Self { batch_aggregation_shard_count, task_counter_shard_count, + hpke_configs_refresh_interval, + default_async_poll_interval, http_client, backoff, aggregation_success_counter, @@ -148,17 +162,28 @@ where async fn step_aggregation_job( &self, datastore: Arc>, + hpke_keypairs: Arc, lease: Arc>, ) -> Result<(), Error> { match lease.leased().batch_mode() { task::BatchMode::TimeInterval => { vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await + self.step_aggregation_job_generic::( + datastore, + hpke_keypairs, + Arc::new(vdaf), + lease + ).await }) } task::BatchMode::LeaderSelected { .. } => { vdaf_dispatch!(lease.leased().vdaf(), (vdaf, VdafType, VERIFY_KEY_LENGTH) => { - self.step_aggregation_job_generic::(datastore, Arc::new(vdaf), lease).await + self.step_aggregation_job_generic::( + datastore, + hpke_keypairs, + Arc::new(vdaf), + lease + ).await }) } } @@ -172,6 +197,7 @@ where >( &self, datastore: Arc>, + hpke_keypairs: Arc, vdaf: Arc, lease: Arc>, ) -> Result<(), Error> @@ -188,7 +214,7 @@ where A::PublicShare: PartialEq + Send + Sync, { // Read all information about the aggregation job. 
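The two new `Duration` parameters on `AggregationJobDriver::new` above are pure configuration: `hpke_configs_refresh_interval` is handed to the `HpkeKeypairCache` the driver now builds, and `default_async_poll_interval` is the fallback delay before re-polling an asynchronous helper that sent no Retry-After. A sketch mirroring the updated call sites in the tests (the earlier constructor arguments are elided in those hunks and stay elided here):

```rust
let aggregation_job_driver = AggregationJobDriver::new(
    /* HTTP client, retryer, ... (elided) */
    &noop_meter(),
    BATCH_AGGREGATION_SHARD_COUNT,
    TASK_AGGREGATION_COUNTER_SHARD_COUNT,
    HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, // hpke_configs_refresh_interval
    DEFAULT_ASYNC_POLL_INTERVAL,                // one second in these tests
);
```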
- let (task, aggregation_job, report_aggregations, verify_key) = datastore + let (task, aggregation_job, report_aggregations) = datastore .run_tx("step_aggregation_job_generic", |tx| { let (lease, vdaf) = (Arc::clone(&lease), Arc::clone(&vdaf)); Box::pin(async move { @@ -198,18 +224,13 @@ where lease.leased().aggregation_job_id(), ); - let (task, aggregation_job) = try_join!(task_future, aggregation_job_future,)?; + let (task, aggregation_job) = try_join!(task_future, aggregation_job_future)?; let task = task.ok_or_else(|| { datastore::Error::User( anyhow!("couldn't find task {}", lease.leased().task_id()).into(), ) })?; - let verify_key = task.vdaf_verify_key().map_err(|_| { - datastore::Error::User( - anyhow!("VDAF verification key has wrong length").into(), - ) - })?; let aggregation_job = aggregation_job.ok_or_else(|| { datastore::Error::User( anyhow!( @@ -231,31 +252,97 @@ where ) .await?; - Ok(( - Arc::new(task), - aggregation_job, - report_aggregations, - verify_key, - )) + Ok((task, aggregation_job, report_aggregations)) }) }) .await?; - // Figure out the next step based on the non-error report aggregation states, and dispatch accordingly. - let (mut saw_init, mut saw_continue, mut saw_poll, mut saw_finished) = - (false, false, false, false); + match task.role() { + Role::Leader => { + self.step_aggregation_job_leader( + datastore, + vdaf, + lease, + task, + aggregation_job, + report_aggregations, + ) + .await + } + + Role::Helper => { + self.step_aggregation_job_helper( + datastore, + hpke_keypairs, + vdaf, + lease, + task, + aggregation_job, + report_aggregations, + ) + .await + } + + _ => Err(Error::Internal(format!("unexpected role {}", task.role()))), + } + } + + async fn step_aggregation_job_leader< + const SEED_SIZE: usize, + C: Clock, + B: CollectableBatchMode, + A, + >( + &self, + datastore: Arc>, + vdaf: Arc, + lease: Arc>, + task: AggregatorTask, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(), Error> + where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::OutputShare: PartialEq + Eq + Send + Sync, + for<'a> A::PrepareState: + PartialEq + Eq + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareMessage: PartialEq + Eq + Send + Sync, + A::PrepareShare: PartialEq + Eq + Send + Sync, + A::InputShare: PartialEq + Send + Sync, + A::PublicShare: PartialEq + Send + Sync, + { + // Figure out the next step based on the non-error report aggregation states, and dispatch + // accordingly. + let mut saw_init = false; + let mut saw_continue = false; + let mut saw_poll = false; + let mut saw_finished = false; for report_aggregation in &report_aggregations { match report_aggregation.state() { ReportAggregationState::LeaderInit { .. } => saw_init = true, ReportAggregationState::LeaderContinue { .. } => saw_continue = true, ReportAggregationState::LeaderPoll { .. } => saw_poll = true, + ReportAggregationState::HelperInitProcessing { .. } => { + return Err(Error::Internal( + "Leader encountered unexpected ReportAggregationState::HelperInitProcessing" + .to_string() + )); + } ReportAggregationState::HelperContinue { .. } => { return Err(Error::Internal( "Leader encountered unexpected ReportAggregationState::HelperContinue" .to_string(), )); } + ReportAggregationState::HelperContinueProcessing { .. 
} => { + return Err(Error::Internal( + "Leader encountered unexpected ReportAggregationState::HelperContinueProcessing" + .to_string() + )); + } ReportAggregationState::Finished => saw_finished = true, ReportAggregationState::Failed { .. } => (), // ignore failed aggregations @@ -264,21 +351,20 @@ where match (saw_init, saw_continue, saw_poll, saw_finished) { // Only saw report aggregations in state "init" (or failed). (true, false, false, false) => { - self.step_aggregation_job_aggregate_init( + self.step_aggregation_job_leader_init( datastore, vdaf, lease, task, aggregation_job, report_aggregations, - verify_key, ) .await } // Only saw report aggregations in state "continue" (or failed). (false, true, false, false) => { - self.step_aggregation_job_aggregate_continue( + self.step_aggregation_job_leader_continue( datastore, vdaf, lease, @@ -291,7 +377,7 @@ where // Only saw report aggregations in state "poll" (or failed). (false, false, true, false) => { - self.step_aggregation_job_aggregate_poll( + self.step_aggregation_job_leader_poll( datastore, vdaf, lease, @@ -311,7 +397,7 @@ where } #[allow(clippy::too_many_arguments)] - async fn step_aggregation_job_aggregate_init< + async fn step_aggregation_job_leader_init< const SEED_SIZE: usize, C: Clock, B: CollectableBatchMode, @@ -321,10 +407,9 @@ where datastore: Arc>, vdaf: Arc, lease: Arc>, - task: Arc, + task: AggregatorTask, aggregation_job: AggregationJob, report_aggregations: Vec>, - verify_key: VerifyKey, ) -> Result<(), Error> where A::AggregationParam: Send + Sync + PartialEq + Eq, @@ -336,8 +421,6 @@ where A::PrepareMessage: PartialEq + Eq + Send + Sync, A::PublicShare: PartialEq + Send + Sync, { - let aggregation_job = Arc::new(aggregation_job); - // Only process non-failed report aggregations. let report_aggregations: Vec<_> = report_aggregations .into_iter() @@ -358,6 +441,10 @@ where // on receiving an error. 
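The channels created just below are the heart of the leader-init path: VDAF preparation is CPU-bound, so it runs under `spawn_blocking` and streams results back over unbounded channels, letting the async side stop consuming (and the producer stop producing) on the first error. A self-contained sketch of that shape, with hypothetical arithmetic standing in for VDAF preparation:

```rust
use tokio::{sync::mpsc, task};

async fn prepare_all(inputs: Vec<u64>) -> anyhow::Result<Vec<u64>> {
    let (sender, mut receiver) = mpsc::unbounded_channel();
    let producer = task::spawn_blocking(move || {
        for input in inputs {
            // A send error means the receiver was dropped because the consumer
            // hit an error; stop doing expensive work early.
            if sender.send(input.checked_mul(2)).is_err() {
                break;
            }
        }
    });
    let mut outputs = Vec::new();
    while let Some(result) = receiver.recv().await {
        outputs.push(result.ok_or_else(|| anyhow::anyhow!("overflow"))?);
    }
    producer.await?;
    Ok(outputs)
}
```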
let (ra_sender, mut ra_receiver) = mpsc::unbounded_channel(); let (pi_and_sa_sender, mut pi_and_sa_receiver) = mpsc::unbounded_channel(); + let aggregation_job = Arc::new(aggregation_job); + let verify_key = task + .vdaf_verify_key() + .map_err(|_| Error::Internal("VDAF verification key has wrong length".to_string()))?; let producer_task = tokio::task::spawn_blocking({ let parent_span = Span::current(); let vdaf = Arc::clone(&vdaf); @@ -589,7 +676,7 @@ where let aggregation_job: AggregationJob = Arc::unwrap_or_clone(aggregation_job); - self.process_response_from_helper( + self.step_aggregation_job_leader_process_response( datastore, vdaf, lease, @@ -603,7 +690,7 @@ where .await } - async fn step_aggregation_job_aggregate_continue< + async fn step_aggregation_job_leader_continue< const SEED_SIZE: usize, C: Clock, B: CollectableBatchMode, @@ -613,7 +700,7 @@ where datastore: Arc>, vdaf: Arc, lease: Arc>, - task: Arc, + task: AggregatorTask, aggregation_job: AggregationJob, report_aggregations: Vec>, ) -> Result<(), Error> @@ -780,7 +867,7 @@ where let resp = AggregationJobResp::get_decoded(http_response.body()).map_err(Error::MessageDecode)?; - self.process_response_from_helper( + self.step_aggregation_job_leader_process_response( datastore, vdaf, lease, @@ -794,7 +881,7 @@ where .await } - async fn step_aggregation_job_aggregate_poll< + async fn step_aggregation_job_leader_poll< const SEED_SIZE: usize, C: Clock, B: CollectableBatchMode, @@ -804,7 +891,7 @@ where datastore: Arc>, vdaf: Arc, lease: Arc>, - task: Arc, + task: AggregatorTask, aggregation_job: AggregationJob, report_aggregations: Vec>, ) -> Result<(), Error> @@ -864,7 +951,7 @@ where let resp = AggregationJobResp::get_decoded(http_response.body()).map_err(Error::MessageDecode)?; - self.process_response_from_helper( + self.step_aggregation_job_leader_process_response( datastore, vdaf, lease, @@ -879,7 +966,7 @@ where } #[allow(clippy::too_many_arguments)] - async fn process_response_from_helper< + async fn step_aggregation_job_leader_process_response< const SEED_SIZE: usize, C: Clock, B: CollectableBatchMode, @@ -889,7 +976,7 @@ where datastore: Arc>, vdaf: Arc, lease: Arc>, - task: Arc, + task: AggregatorTask, aggregation_job: AggregationJob, stepped_aggregations: Vec>, report_aggregations_to_write: Vec>, @@ -908,7 +995,7 @@ where { match helper_resp { AggregationJobResp::Processing => { - self.process_response_from_helper_pending( + self.step_aggregation_job_leader_process_response_processing( datastore, vdaf, lease, @@ -922,7 +1009,7 @@ where } AggregationJobResp::Finished { prepare_resps } => { - self.process_response_from_helper_finished( + self.step_aggregation_job_leader_process_response_finished( datastore, vdaf, lease, @@ -937,7 +1024,7 @@ where } } - async fn process_response_from_helper_pending< + async fn step_aggregation_job_leader_process_response_processing< const SEED_SIZE: usize, C: Clock, B: CollectableBatchMode, @@ -947,7 +1034,7 @@ where datastore: Arc>, vdaf: Arc, lease: Arc>, - task: Arc, + task: AggregatorTask, aggregation_job: AggregationJob, stepped_aggregations: Vec>, mut report_aggregations_to_write: Vec>, @@ -984,9 +1071,10 @@ where )); // Write everything back to storage. 
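The write-back that follows also decides when this job is next attempted. Previously the delay was a hard-coded 60 seconds; now the helper's Retry-After wins when present and parseable, with the new `default_async_poll_interval` as the fallback. Distilled, with identifiers from this diff (shown as plain statements rather than inside the hunk's `try_join!`):

```rust
// retry_after: Option<RetryAfter> taken from the helper's response headers.
let retry_after = retry_after
    .map(|ra| retry_after_to_duration(datastore.clock(), ra)) // parsing may fail
    .transpose()?                                             // Option<Result<..>> -> Result<Option<..>>, then ?
    .unwrap_or(self.default_async_poll_interval);             // fall back to the config knob

tx.release_aggregation_job(&lease, Some(&retry_after)).await?; // redrive after the delay
```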
+ let task_id = *task.id(); let mut aggregation_job_writer = AggregationJobWriter::::new( - Arc::clone(&task), + Arc::new(task), self.batch_aggregation_shard_count, Some(AggregationJobWriterMetrics { report_aggregation_success_counter: self.aggregation_success_counter.clone(), @@ -1002,9 +1090,9 @@ where let retry_after = retry_after .map(|ra| retry_after_to_duration(datastore.clock(), ra)) .transpose()? - .or_else(|| Some(Duration::from_secs(60))); + .unwrap_or(self.default_async_poll_interval); let counters = datastore - .run_tx("process_response_from_helper_pending", |tx| { + .run_tx("process_response_from_helper_processing", |tx| { let vdaf = Arc::clone(&vdaf); let aggregation_job_writer = Arc::clone(&aggregation_job_writer); let lease = Arc::clone(&lease); @@ -1012,24 +1100,19 @@ where Box::pin(async move { let ((_, counters), _) = try_join!( aggregation_job_writer.write(tx, Arc::clone(&vdaf)), - tx.release_aggregation_job(&lease, retry_after.as_ref()), + tx.release_aggregation_job(&lease, Some(&retry_after)), )?; Ok(counters) }) }) .await?; - write_task_aggregation_counter( - datastore, - self.task_counter_shard_count, - *task.id(), - counters, - ); + write_task_aggregation_counter(datastore, self.task_counter_shard_count, task_id, counters); Ok(()) } - async fn process_response_from_helper_finished< + async fn step_aggregation_job_leader_process_response_finished< const SEED_SIZE: usize, C: Clock, B: CollectableBatchMode, @@ -1039,7 +1122,7 @@ where datastore: Arc>, vdaf: Arc, lease: Arc>, - task: Arc, + task: AggregatorTask, aggregation_job: AggregationJob, stepped_aggregations: Vec>, mut report_aggregations_to_write: Vec>, @@ -1091,7 +1174,7 @@ where move || { let span = info_span!( parent: parent_span, - "process_Response_from_helper threadpool task" + "process_response_from_helper threadpool task" ); let ctx = vdaf_application_context(&task_id); @@ -1214,9 +1297,10 @@ where // Write everything back to storage. let aggregation_job = Arc::unwrap_or_clone(aggregation_job); + let task_id = *task.id(); let mut aggregation_job_writer = AggregationJobWriter::::new( - Arc::clone(&task), + Arc::new(task), self.batch_aggregation_shard_count, Some(AggregationJobWriterMetrics { report_aggregation_success_counter: self.aggregation_success_counter.clone(), @@ -1249,6 +1333,312 @@ where }) .await?; + write_task_aggregation_counter(datastore, self.task_counter_shard_count, task_id, counters); + + Ok(()) + } + + async fn step_aggregation_job_helper< + const SEED_SIZE: usize, + C: Clock, + B: CollectableBatchMode, + A, + >( + &self, + datastore: Arc>, + hpke_keypairs: Arc, + vdaf: Arc, + lease: Arc>, + task: AggregatorTask, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(), Error> + where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::OutputShare: PartialEq + Eq + Send + Sync, + for<'a> A::PrepareState: + PartialEq + Eq + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareMessage: PartialEq + Eq + Send + Sync, + A::PrepareShare: PartialEq + Eq + Send + Sync, + A::InputShare: PartialEq + Send + Sync, + A::PublicShare: PartialEq + Send + Sync, + { + // Figure out the next step based on the non-error report aggregation states, and dispatch + // accordingly. 
+ let mut saw_init = false; + let mut saw_continue = false; + let mut saw_finished = false; + for report_aggregation in &report_aggregations { + match report_aggregation.state() { + ReportAggregationState::LeaderInit { .. } => { + return Err(Error::Internal( + "Helper encountered unexpected ReportAggregationState::LeaderInit" + .to_string(), + )); + } + ReportAggregationState::LeaderContinue { .. } => { + return Err(Error::Internal( + "Helper encountered unexpected ReportAggregationState::LeaderContinue" + .to_string(), + )); + } + ReportAggregationState::LeaderPoll { .. } => { + return Err(Error::Internal( + "Helper encountered unexpected ReportAggregationState::LeaderPoll" + .to_string(), + )); + } + + ReportAggregationState::HelperInitProcessing { .. } => saw_init = true, + ReportAggregationState::HelperContinue { .. } => { + return Err(Error::Internal( + "Helper encountered unexpected ReportAggregationState::HelperContinue" + .to_string(), + )); + } + ReportAggregationState::HelperContinueProcessing { .. } => saw_continue = true, + + ReportAggregationState::Finished => saw_finished = true, + ReportAggregationState::Failed { .. } => (), // ignore failed aggregations + } + } + + match (saw_init, saw_continue, saw_finished) { + // Only saw report aggregations in state "init processing" (or failed). + (true, false, false) => { + self.step_aggregation_job_helper_init( + datastore, + hpke_keypairs, + vdaf, + lease, + task, + aggregation_job, + report_aggregations, + ) + .await + } + + // Only saw report aggregations in state "continue processing" (or failed). + (false, true, false) => { + self.step_aggregation_job_helper_continue( + datastore, + vdaf, + lease, + task, + aggregation_job, + report_aggregations, + ) + .await + } + + _ => Err(Error::Internal(format!( + "unexpected combination of report aggregation states (saw_init = {saw_init}, \ + saw_continue = {saw_continue}, saw_finished = {saw_finished})", + ))), + } + } + + async fn step_aggregation_job_helper_init< + const SEED_SIZE: usize, + C: Clock, + B: CollectableBatchMode, + A, + >( + &self, + datastore: Arc>, + hpke_keypairs: Arc, + vdaf: Arc, + lease: Arc>, + task: AggregatorTask, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(), Error> + where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::OutputShare: PartialEq + Eq + Send + Sync, + for<'a> A::PrepareState: + PartialEq + Eq + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareMessage: PartialEq + Eq + Send + Sync, + A::PrepareShare: PartialEq + Eq + Send + Sync, + A::InputShare: PartialEq + Send + Sync, + A::PublicShare: PartialEq + Send + Sync, + { + // Only process report aggregations in the HelperInitProcessing state. + let report_aggregations = report_aggregations + .into_iter() + .filter(|ra| { + matches!( + ra.state(), + ReportAggregationState::HelperInitProcessing { .. } + ) + }) + .collect(); + + // Compute the next aggregation step. + let task = Arc::new(task); + let aggregation_job = + Arc::new(aggregation_job.with_state(AggregationJobState::AwaitingRequest)); + let report_aggregations = Arc::new( + compute_helper_aggregate_init( + datastore.clock(), + hpke_keypairs, + Arc::clone(&vdaf), + AggregateInitMetrics::new(self.aggregate_step_failure_counter.clone()), + Arc::clone(&task), + Arc::clone(&aggregation_job), + report_aggregations, + ) + .await?, + ); + + // Write results back to datastore.
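Worth noting in the helper-init path above: the job is parked in `AwaitingRequest` before the write, and the transaction in the next hunk releases the lease with no reacquire delay, because the next transition is driven by the leader polling the helper, not by this driver re-stepping the job. In outline (identifiers from this diff):

```rust
let aggregation_job =
    Arc::new(aggregation_job.with_state(AggregationJobState::AwaitingRequest));
// ... compute_helper_aggregate_init(...) produced the report aggregations ...
let ((_, counters), _) = try_join!(
    aggregation_job_writer.write(tx, vdaf),
    tx.release_aggregation_job(&lease, None), // no redrive: the leader polls us
)?;
```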
+ let metrics = AggregationJobWriterMetrics { + report_aggregation_success_counter: self.aggregation_success_counter.clone(), + aggregate_step_failure_counter: self.aggregate_step_failure_counter.clone(), + aggregated_report_share_dimension_histogram: self + .aggregated_report_share_dimension_histogram + .clone(), + }; + + let counters = datastore + .run_tx("aggregate_init_driver_write", |tx| { + let vdaf = Arc::clone(&vdaf); + let lease = Arc::clone(&lease); + let task = Arc::clone(&task); + let metrics = metrics.clone(); + let aggregation_job = Arc::clone(&aggregation_job); + let report_aggregations = Arc::clone(&report_aggregations); + let batch_aggregation_shard_count = self.batch_aggregation_shard_count; + + Box::pin(async move { + // Write aggregation job, report aggregations, and batch aggregations. + let report_aggregations = + report_aggregations.iter().map(Cow::Borrowed).collect(); + + let mut aggregation_job_writer = + AggregationJobWriter::::new( + task, + batch_aggregation_shard_count, + Some(metrics), + ); + aggregation_job_writer + .put(aggregation_job.as_ref().clone(), report_aggregations)?; + let ((_, counters), _) = try_join!( + aggregation_job_writer.write(tx, vdaf), + tx.release_aggregation_job(&lease, None), + )?; + Ok(counters) + }) + }) + .await?; + + write_task_aggregation_counter( + datastore, + self.task_counter_shard_count, + *task.id(), + counters, + ); + + Ok(()) + } + + async fn step_aggregation_job_helper_continue< + const SEED_SIZE: usize, + C: Clock, + B: CollectableBatchMode, + A, + >( + &self, + datastore: Arc>, + vdaf: Arc, + lease: Arc>, + task: AggregatorTask, + aggregation_job: AggregationJob, + report_aggregations: Vec>, + ) -> Result<(), Error> + where + A: vdaf::Aggregator + Send + Sync + 'static, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::OutputShare: PartialEq + Eq + Send + Sync, + for<'a> A::PrepareState: + PartialEq + Eq + Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)>, + A::PrepareMessage: PartialEq + Eq + Send + Sync, + A::PrepareShare: PartialEq + Eq + Send + Sync, + A::InputShare: PartialEq + Send + Sync, + A::PublicShare: PartialEq + Send + Sync, + { + // Only process report aggregations in the HelperContinueProcessing state. + let report_aggregations = report_aggregations + .into_iter() + .filter(|ra| { + matches!( + ra.state(), + ReportAggregationState::HelperContinueProcessing { .. } + ) + }) + .collect(); + + // Compute the next aggregation step. + let task = Arc::new(task); + let aggregation_job = + Arc::new(aggregation_job.with_state(AggregationJobState::AwaitingRequest)); + let report_aggregations = Arc::new( + compute_helper_aggregate_continue( + Arc::clone(&vdaf), + AggregateContinueMetrics::new(self.aggregate_step_failure_counter.clone()), + Arc::clone(&task), + Arc::clone(&aggregation_job), + report_aggregations, + ) + .await, + ); + + // Write results back to datastore. 
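Further down in this file, the job-stepper closure gains a lazily-initialized `HpkeKeypairCache` shared across invocations behind a `tokio::sync::Mutex`, so the first stepped job pays the construction cost and later jobs reuse the cache. The pattern, distilled from that hunk (surrounding setup elided; `refresh_interval` stands in for `this.hpke_configs_refresh_interval`):

```rust
use std::sync::Arc;
use tokio::sync::Mutex;

// Shared slot, captured by every invocation of the stepper closure.
let hpke_keypairs: Arc<Mutex<Option<Arc<HpkeKeypairCache>>>> = Arc::new(Mutex::new(None));

// Inside the closure: initialize at most once, then hand out clones.
let cache = {
    let mut slot = hpke_keypairs.lock().await;
    match slot.as_ref() {
        Some(cache) => Arc::clone(cache),
        None => {
            let cache = Arc::new(
                HpkeKeypairCache::new(Arc::clone(&datastore), refresh_interval).await?,
            );
            *slot = Some(Arc::clone(&cache));
            cache
        }
    }
};
```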
+ let metrics = AggregationJobWriterMetrics { + report_aggregation_success_counter: self.aggregation_success_counter.clone(), + aggregate_step_failure_counter: self.aggregate_step_failure_counter.clone(), + aggregated_report_share_dimension_histogram: self + .aggregated_report_share_dimension_histogram + .clone(), + }; + + let counters = datastore + .run_tx("aggregate_continue_driver_write", |tx| { + let vdaf = Arc::clone(&vdaf); + let lease = Arc::clone(&lease); + let task = Arc::clone(&task); + let metrics = metrics.clone(); + let aggregation_job = Arc::clone(&aggregation_job); + let report_aggregations = Arc::clone(&report_aggregations); + + let batch_aggregation_shard_count = self.batch_aggregation_shard_count; + + Box::pin(async move { + let report_aggregations = + report_aggregations.iter().map(Cow::Borrowed).collect(); + let mut aggregation_job_writer = + AggregationJobWriter::::new( + task, + batch_aggregation_shard_count, + Some(metrics), + ); + aggregation_job_writer + .put(aggregation_job.as_ref().clone(), report_aggregations)?; + + let ((_, counters), _) = try_join!( + aggregation_job_writer.write(tx, vdaf), + tx.release_aggregation_job(&lease, None), + )?; + Ok(counters) + }) + }) + .await?; + write_task_aggregation_counter( datastore, self.task_counter_shard_count, @@ -1417,6 +1807,7 @@ where { move |max_acquire_count: usize| { let datastore = Arc::clone(&datastore); + Box::pin(async move { datastore .run_tx("acquire_aggregation_jobs", |tx| { @@ -1439,9 +1830,14 @@ where datastore: Arc>, maximum_attempts_before_failure: usize, ) -> impl Fn(Lease) -> BoxFuture<'static, Result<(), Error>> { + let hpke_keypairs = Arc::new(Mutex::new(None)); + move |lease| { - let (this, datastore) = (Arc::clone(&self), Arc::clone(&datastore)); + let this = Arc::clone(&self); + let datastore = Arc::clone(&datastore); + let hpke_keypairs = Arc::clone(&hpke_keypairs); let lease = Arc::new(lease); + Box::pin(async move { let attempts = lease.lease_attempts(); if attempts > maximum_attempts_before_failure { @@ -1458,8 +1854,30 @@ where this.job_retry_counter.add(1, &[]); } + let hpke_keypairs = { + let mut hpke_keypairs = hpke_keypairs.lock().await; + match hpke_keypairs.as_ref() { + Some(hpke_keypairs) => Arc::clone(hpke_keypairs), + None => { + let hk = Arc::new( + HpkeKeypairCache::new( + Arc::clone(&datastore), + this.hpke_configs_refresh_interval, + ) + .await?, + ); + *hpke_keypairs = Some(Arc::clone(&hk)); + hk + } + } + }; + match this - .step_aggregation_job(Arc::clone(&datastore), Arc::clone(&lease)) + .step_aggregation_job( + Arc::clone(&datastore), + Arc::clone(&hpke_keypairs), + Arc::clone(&lease), + ) .await { Ok(_) => Ok(()), diff --git a/aggregator/src/aggregator/aggregation_job_driver/tests.rs b/aggregator/src/aggregator/aggregation_job_driver/tests.rs index c9b12127c..a5b4fccab 100644 --- a/aggregator/src/aggregator/aggregation_job_driver/tests.rs +++ b/aggregator/src/aggregator/aggregation_job_driver/tests.rs @@ -1,11 +1,16 @@ +#![allow(clippy::unit_arg)] // allow reference to dummy::Vdaf's public share, which has the unit type + use crate::{ aggregator::{ aggregation_job_driver::AggregationJobDriver, - test_util::assert_task_aggregation_counter, - test_util::{BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT}, + test_util::{ + assert_task_aggregation_counter, generate_helper_report_share, + BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + }, Error, }, binary_utils::job_driver::JobDriver, + cache::HpkeKeypairCache, }; use 
assert_matches::assert_matches; use futures::future::join_all; @@ -21,7 +26,7 @@ use janus_aggregator_core::{ test_util::{ephemeral_datastore, EphemeralDatastore}, Datastore, }, - task::{test_util::TaskBuilder, AggregatorTask, BatchMode, VerifyKey}, + task::{test_util::TaskBuilder, AggregationMode, AggregatorTask, BatchMode, VerifyKey}, test_util::noop_meter, }; use janus_core::{ @@ -30,7 +35,7 @@ use janus_core::{ retries::test_util::LimitedRetryer, test_util::{install_test_trace_subscriber, run_vdaf, runtime::TestRuntimeManager}, time::{Clock, IntervalExt, MockClock, TimeExt}, - vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, + vdaf::{VdafInstance, VERIFY_KEY_LENGTH_PRIO3}, Runtime, }; use janus_messages::{ @@ -55,6 +60,8 @@ use std::{sync::Arc, time::Duration as StdDuration}; use tokio::time::timeout; use trillium_tokio::Stopper; +const DEFAULT_ASYNC_POLL_INTERVAL: StdDuration = StdDuration::from_secs(1); + #[tokio::test] async fn aggregation_job_driver() { // This is a minimal test that AggregationJobDriver::run() will successfully find @@ -70,10 +77,15 @@ async fn aggregation_job_driver() { let mut runtime_manager = TestRuntimeManager::new(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -128,7 +140,7 @@ async fn aggregation_job_driver() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -199,7 +211,7 @@ async fn aggregation_job_driver() { server .mock( req_method, - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -218,6 +230,8 @@ async fn aggregation_job_driver() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, )); let stopper = Stopper::new(); @@ -347,18 +361,23 @@ async fn aggregation_job_driver() { } #[tokio::test] -async fn sync_time_interval_aggregation_job_init_single_step() { +async fn leader_sync_time_interval_aggregation_job_init_single_step() { // Setup: insert a client report and add it to a new aggregation job. 
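Two pieces of setup now recur in every test in this file: the datastore must contain an HPKE keypair before the driver runs, and `step_aggregation_job` takes an `HpkeKeypairCache` alongside the datastore and lease. The shared boilerplate, exactly as used below:

```rust
ds.put_hpke_key().await.unwrap(); // the keypair cache needs something to serve

aggregation_job_driver
    .step_aggregation_job(
        ds.clone(),
        Arc::new(
            HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL)
                .await
                .unwrap(),
        ),
        Arc::new(lease),
    )
    .await
    .unwrap();
```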
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -368,7 +387,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { .unwrap(); let batch_identifier = TimeInterval::to_batch_identifier(&leader_task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time, Vec::new()); - let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -450,7 +469,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { } tx.put_aggregation_job(&AggregationJob::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -460,7 +479,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -483,7 +502,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { } tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -542,7 +561,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { let mocked_aggregate_failure = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -555,7 +574,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { let mocked_aggregate_success = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -578,9 +597,19 @@ async fn sync_time_interval_aggregation_job_init_single_step() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -588,17 +617,18 @@ async fn sync_time_interval_aggregation_job_init_single_step() { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::Finished, - AggregationJobStep::from(1), - ); + let want_aggregation_job = + AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + 
AggregationJobState::Finished, + AggregationJobStep::from(1), + ); - let want_report_aggregation = ReportAggregation::::new( + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -608,7 +638,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { ReportAggregationState::Finished, ); let want_repeated_public_extension_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *repeated_public_extension_report.metadata().id(), @@ -620,7 +650,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { }, ); let want_repeated_private_extension_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *repeated_private_extension_report.metadata().id(), @@ -632,7 +662,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { }, ); let want_repeated_public_private_extension_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *repeated_public_private_extension_report.metadata().id(), @@ -645,7 +675,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { ); let want_batch_aggregations = Vec::from([BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -681,7 +711,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -737,7 +767,7 @@ async fn sync_time_interval_aggregation_job_init_single_step() { .unwrap() .unwrap(); let batch_aggregations = merge_batch_aggregations_by_batch( - tx.get_batch_aggregations_for_task::(&vdaf, task.id()) + tx.get_batch_aggregations_for_task::(&vdaf, task.id()) .await .unwrap(), ); @@ -776,18 +806,23 @@ async fn sync_time_interval_aggregation_job_init_single_step() { } #[tokio::test] -async fn sync_time_interval_aggregation_job_init_two_steps() { +async fn leader_sync_time_interval_aggregation_job_init_two_steps() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -839,7 +874,7 @@ async fn sync_time_interval_aggregation_job_init_two_steps() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -908,7 +943,7 @@ async fn sync_time_interval_aggregation_job_init_two_steps() { let mocked_aggregate_success = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -931,9 +966,19 @@ async fn sync_time_interval_aggregation_job_init_two_steps() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -946,7 +991,7 @@ async fn sync_time_interval_aggregation_job_init_two_steps() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), ); let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( @@ -1027,7 +1072,7 @@ async fn sync_time_interval_aggregation_job_init_two_steps() { } #[tokio::test] -async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { +async fn leader_sync_time_interval_aggregation_job_init_partially_garbage_collected() { // This is a regression test for https://github.com/divviup/janus/issues/2464. 
const OLDEST_ALLOWED_REPORT_TIMESTAMP: Time = Time::from_seconds_since_epoch(1000); @@ -1040,13 +1085,18 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { let clock = MockClock::new(OLDEST_ALLOWED_REPORT_TIMESTAMP); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .with_time_precision(TIME_PRECISION) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .with_time_precision(TIME_PRECISION) + .build(); let leader_task = task.leader_view().unwrap(); @@ -1069,7 +1119,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { let gc_ineligible_report_metadata = ReportMetadata::new(random(), gc_ineligible_time, Vec::new()); - let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); let gc_eligible_transcript = run_vdaf( vdaf.as_ref(), @@ -1131,7 +1181,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { .unwrap(); tx.put_aggregation_job(&AggregationJob::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -1144,7 +1194,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { gc_ineligible_time.difference(&gc_eligible_time).unwrap(), ) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -1161,7 +1211,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -1181,7 +1231,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { .await .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -1267,7 +1317,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { let mocked_aggregate_init = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -1290,31 +1340,42 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); // Verify. 
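The mocked request paths throughout these tests use the new two-argument `aggregation_job_uri(&aggregation_job_id, None)`. Judging by the `querystring` dependency this change adds, the second argument plausibly carries an optional query string (for example, for asynchronous polling); that reading is an assumption, but every synchronous-path test here passes `None`. For example:

```rust
let mocked_aggregate = server
    .mock(
        "PUT",
        task.aggregation_job_uri(&aggregation_job_id, None) // None: no query parameters
            .unwrap()
            .path(),
    )
    .with_status(200)
    .create_async()
    .await;
// ... step the aggregation job ...
mocked_aggregate.assert_async().await;
```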
mocked_aggregate_init.assert_async().await; - let want_aggregation_job = AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new( - gc_eligible_time, - gc_ineligible_time.difference(&gc_eligible_time).unwrap(), - ) - .unwrap(), - AggregationJobState::Finished, - AggregationJobStep::from(1), - ); + let want_aggregation_job = + AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new( + gc_eligible_time, + gc_ineligible_time.difference(&gc_eligible_time).unwrap(), + ) + .unwrap(), + AggregationJobState::Finished, + AggregationJobStep::from(1), + ); let want_gc_eligible_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *gc_eligible_report.metadata().id(), @@ -1324,7 +1385,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { ReportAggregationState::Finished, ); let want_ineligible_report_aggregation = - ReportAggregation::::new( + ReportAggregation::::new( *task.id(), aggregation_job_id, *gc_ineligible_report.metadata().id(), @@ -1338,7 +1399,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { want_ineligible_report_aggregation, ]); let want_batch_aggregations = Vec::from([BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -1362,7 +1423,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { let task = task.clone(); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1380,7 +1441,7 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { .await .unwrap(); let batch_aggregations = merge_batch_aggregations_by_batch( - tx.get_batch_aggregations_for_task::(&vdaf, task.id()) + tx.get_batch_aggregations_for_task::(&vdaf, task.id()) .await .unwrap(), ); @@ -1399,19 +1460,21 @@ async fn sync_time_interval_aggregation_job_init_partially_garbage_collected() { } #[tokio::test] -async fn sync_leader_selected_aggregation_job_init_single_step() { +async fn leader_sync_leader_selected_aggregation_job_init_single_step() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(Prio3::new_count(2).unwrap()); let task = TaskBuilder::new( BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Prio3Count, ) .with_helper_aggregator_endpoint(server.url().parse().unwrap()) @@ -1427,7 +1490,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { .unwrap(), Vec::new(), ); - let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); let transcript = run_vdaf( vdaf.as_ref(), @@ -1461,7 +1524,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { .unwrap(); tx.put_aggregation_job(&AggregationJob::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, Prio3Count, >::new( @@ -1471,7 +1534,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -1484,7 +1547,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, Prio3Count, >::new( @@ -1543,7 +1606,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { let mocked_aggregate_failure = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -1556,7 +1619,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { let mocked_aggregate_success = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -1579,9 +1642,19 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); let error = aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease.clone())) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease.clone()), + ) .await .unwrap_err(); assert_matches!( @@ -1592,7 +1665,15 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { } ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -1600,16 +1681,17 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { mocked_aggregate_failure.assert_async().await; mocked_aggregate_success.assert_async().await; - let want_aggregation_job = AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - batch_id, - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::Finished, - AggregationJobStep::from(1), - ); - let want_report_aggregation = 
ReportAggregation::::new( + let want_aggregation_job = + AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + batch_id, + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobStep::from(1), + ); + let want_report_aggregation = ReportAggregation::::new( *task.id(), aggregation_job_id, *report.metadata().id(), @@ -1619,7 +1701,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { ReportAggregationState::Finished, ); let want_batch_aggregations = Vec::from([BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, LeaderSelected, Prio3Count, >::new( @@ -1643,7 +1725,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { (Arc::clone(&vdaf), task.clone(), *report.metadata().id()); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), &aggregation_job_id, ) @@ -1663,7 +1745,7 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { .unwrap() .unwrap(); let batch_aggregations = merge_batch_aggregations_by_batch( - tx.get_batch_aggregations_for_task::( + tx.get_batch_aggregations_for_task::( &vdaf, task.id(), ) @@ -1689,19 +1771,21 @@ async fn sync_leader_selected_aggregation_job_init_single_step() { } #[tokio::test] -async fn sync_leader_selected_aggregation_job_init_two_steps() { +async fn leader_sync_leader_selected_aggregation_job_init_two_steps() { // Setup: insert a client report and add it to a new aggregation job. install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); let task = TaskBuilder::new( BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 2 }, ) .with_helper_aggregator_endpoint(server.url().parse().unwrap()) @@ -1760,7 +1844,7 @@ async fn sync_leader_selected_aggregation_job_init_two_steps() { batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -1829,7 +1913,7 @@ async fn sync_leader_selected_aggregation_job_init_two_steps() { let mocked_aggregate_success = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -1852,9 +1936,19 @@ async fn sync_leader_selected_aggregation_job_init_two_steps() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -1867,7 +1961,7 @@ async fn sync_leader_selected_aggregation_job_init_two_steps() { aggregation_param, batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), ); let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( @@ 
-1948,7 +2042,7 @@ async fn sync_leader_selected_aggregation_job_init_two_steps() { } #[tokio::test] -async fn sync_time_interval_aggregation_job_continue() { +async fn leader_sync_time_interval_aggregation_job_continue() { // Setup: insert a client report and add it to an aggregation job whose state has already // been stepped once. install_test_trace_subscriber(); @@ -1956,11 +2050,16 @@ async fn sync_time_interval_aggregation_job_continue() { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); let time = clock .now() @@ -2025,7 +2124,7 @@ async fn sync_time_interval_aggregation_job_continue() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), )) .await @@ -2115,7 +2214,7 @@ async fn sync_time_interval_aggregation_job_continue() { let mocked_aggregate_failure = server .mock( "POST", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -2128,7 +2227,7 @@ async fn sync_time_interval_aggregation_job_continue() { let mocked_aggregate_success = server .mock( "POST", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -2148,9 +2247,19 @@ async fn sync_time_interval_aggregation_job_continue() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); let error = aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease.clone())) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease.clone()), + ) .await .unwrap_err(); assert_matches!( @@ -2161,7 +2270,15 @@ async fn sync_time_interval_aggregation_job_continue() { } ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -2270,7 +2387,7 @@ async fn sync_time_interval_aggregation_job_continue() { } #[tokio::test] -async fn sync_leader_selected_aggregation_job_continue() { +async fn leader_sync_leader_selected_aggregation_job_continue() { // Setup: insert a client report and add it to an aggregation job whose state has already // been stepped once. 
install_test_trace_subscriber(); @@ -2278,12 +2395,14 @@ async fn sync_leader_selected_aggregation_job_continue() { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); let task = TaskBuilder::new( BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 2 }, ) .with_helper_aggregator_endpoint(server.url().parse().unwrap()) @@ -2347,7 +2466,7 @@ async fn sync_leader_selected_aggregation_job_continue() { batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), )) .await @@ -2421,7 +2540,7 @@ async fn sync_leader_selected_aggregation_job_continue() { let mocked_aggregate_failure = server .mock( "POST", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -2434,7 +2553,7 @@ async fn sync_leader_selected_aggregation_job_continue() { let mocked_aggregate_success = server .mock( "POST", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -2454,9 +2573,19 @@ async fn sync_leader_selected_aggregation_job_continue() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -2558,18 +2687,23 @@ async fn sync_leader_selected_aggregation_job_continue() { } #[tokio::test] -async fn async_aggregation_job_init_to_pending() { +async fn leader_async_aggregation_job_init_to_pending() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(1)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -2622,7 +2756,7 @@ async fn async_aggregation_job_init_to_pending() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -2681,7 +2815,7 @@ async fn async_aggregation_job_init_to_pending() { let mocked_aggregate_request = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -2704,9 +2838,19 @@ async fn async_aggregation_job_init_to_pending() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -2719,7 +2863,7 @@ async fn async_aggregation_job_init_to_pending() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); @@ -2802,18 +2946,23 @@ async fn async_aggregation_job_init_to_pending() { } #[tokio::test] -async fn async_aggregation_job_init_to_pending_two_step() { +async fn leader_async_aggregation_job_init_to_pending_two_step() { // Setup: insert a client report and add it to a new aggregation job. 
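These `leader_async_*` tests all check the same protocol behavior introduced by the driver changes above: when the helper answers `AggregationJobResp::Processing`, the job must remain `Active` at its current step so a later request can poll for the result. In outline (not a complete function; identifiers from this diff):

```rust
match helper_resp {
    AggregationJobResp::Processing => {
        // Leave the aggregation job Active at the same step; release the
        // lease with a reacquire delay (Retry-After, else
        // default_async_poll_interval) so the driver polls again with GET.
    }
    AggregationJobResp::Finished { prepare_resps } => {
        // Continue local preparation with the helper's prepare responses,
        // then persist report and batch aggregations.
        let _ = prepare_resps;
    }
}
```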
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -2866,7 +3015,7 @@ async fn async_aggregation_job_init_to_pending_two_step() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -2925,7 +3074,7 @@ async fn async_aggregation_job_init_to_pending_two_step() { let mocked_aggregate_request = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -2948,9 +3097,19 @@ async fn async_aggregation_job_init_to_pending_two_step() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -2963,7 +3122,7 @@ async fn async_aggregation_job_init_to_pending_two_step() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); @@ -3046,18 +3205,23 @@ async fn async_aggregation_job_init_to_pending_two_step() { } #[tokio::test] -async fn async_aggregation_job_continue_to_pending() { +async fn leader_async_aggregation_job_continue_to_pending() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -3114,7 +3278,7 @@ async fn async_aggregation_job_continue_to_pending() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), )) .await @@ -3174,7 +3338,7 @@ async fn async_aggregation_job_continue_to_pending() { let mocked_aggregate_request = server .mock( "POST", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -3194,9 +3358,19 @@ async fn async_aggregation_job_continue_to_pending() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -3209,7 +3383,7 @@ async fn async_aggregation_job_continue_to_pending() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), ); @@ -3292,18 +3466,23 @@ async fn async_aggregation_job_continue_to_pending() { } #[tokio::test] -async fn async_aggregation_job_init_poll_to_pending() { +async fn leader_async_aggregation_job_init_poll_to_pending() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(1)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -3357,7 +3536,7 @@ async fn async_aggregation_job_init_poll_to_pending() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -3411,7 +3590,7 @@ async fn async_aggregation_job_init_poll_to_pending() { let mocked_aggregate_request = server .mock( "GET", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -3430,9 +3609,19 @@ async fn async_aggregation_job_init_poll_to_pending() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -3445,7 +3634,7 @@ async fn async_aggregation_job_init_poll_to_pending() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); @@ -3528,18 +3717,23 @@ async fn async_aggregation_job_init_poll_to_pending() { } #[tokio::test] -async fn async_aggregation_job_init_poll_to_pending_two_step() { +async fn leader_async_aggregation_job_init_poll_to_pending_two_step() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -3593,7 +3787,7 @@ async fn async_aggregation_job_init_poll_to_pending_two_step() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -3647,7 +3841,7 @@ async fn async_aggregation_job_init_poll_to_pending_two_step() { let mocked_aggregate_request = server .mock( "GET", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -3666,9 +3860,19 @@ async fn async_aggregation_job_init_poll_to_pending_two_step() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -3681,7 +3885,7 @@ async fn async_aggregation_job_init_poll_to_pending_two_step() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); @@ -3764,18 +3968,23 @@ async fn async_aggregation_job_init_poll_to_pending_two_step() { } #[tokio::test] -async fn async_aggregation_job_init_poll_to_finished() { +async fn leader_async_aggregation_job_init_poll_to_finished() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(1)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -3829,7 +4038,7 @@ async fn async_aggregation_job_init_poll_to_finished() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -3890,7 +4099,7 @@ async fn async_aggregation_job_init_poll_to_finished() { let mocked_aggregate_request = server .mock( "GET", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -3909,9 +4118,19 @@ async fn async_aggregation_job_init_poll_to_finished() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -4005,18 +4224,23 @@ async fn async_aggregation_job_init_poll_to_finished() { } #[tokio::test] -async fn async_aggregation_job_init_poll_to_continue() { +async fn leader_async_aggregation_job_init_poll_to_continue() { // Setup: insert a client report and add it to a new aggregation job. 
install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); @@ -4070,7 +4294,7 @@ async fn async_aggregation_job_init_poll_to_continue() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -4131,7 +4355,7 @@ async fn async_aggregation_job_init_poll_to_continue() { let mocked_aggregate_request = server .mock( "GET", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -4150,9 +4374,19 @@ async fn async_aggregation_job_init_poll_to_continue() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -4165,7 +4399,7 @@ async fn async_aggregation_job_init_poll_to_continue() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), ); @@ -4251,7 +4485,7 @@ async fn async_aggregation_job_init_poll_to_continue() { } #[tokio::test] -async fn async_aggregation_job_continue_poll_to_pending() { +async fn leader_async_aggregation_job_continue_poll_to_pending() { // Setup: insert a client report and add it to an aggregation job whose state has already // been stepped once. 
install_test_trace_subscriber(); @@ -4259,11 +4493,16 @@ async fn async_aggregation_job_continue_poll_to_pending() { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); let time = clock .now() @@ -4318,7 +4557,7 @@ async fn async_aggregation_job_continue_poll_to_pending() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), )) .await @@ -4373,7 +4612,7 @@ async fn async_aggregation_job_continue_poll_to_pending() { let mocked_aggregate_success = server .mock( "GET", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -4392,9 +4631,19 @@ async fn async_aggregation_job_continue_poll_to_pending() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -4407,7 +4656,7 @@ async fn async_aggregation_job_continue_poll_to_pending() { aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), ); let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( @@ -4489,7 +4738,7 @@ async fn async_aggregation_job_continue_poll_to_pending() { } #[tokio::test] -async fn async_aggregation_job_continue_poll_to_finished() { +async fn leader_async_aggregation_job_continue_poll_to_finished() { // Setup: insert a client report and add it to an aggregation job whose state has already // been stepped once. 
install_test_trace_subscriber(); @@ -4497,11 +4746,16 @@ async fn async_aggregation_job_continue_poll_to_finished() { let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); let time = clock .now() @@ -4556,7 +4810,7 @@ async fn async_aggregation_job_continue_poll_to_finished() { (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(1), )) .await @@ -4616,7 +4870,7 @@ async fn async_aggregation_job_continue_poll_to_finished() { let mocked_aggregate_success = server .mock( "GET", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -4635,9 +4889,19 @@ async fn async_aggregation_job_continue_poll_to_finished() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver - .step_aggregation_job(ds.clone(), Arc::new(lease)) + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) .await .unwrap(); @@ -4729,97 +4993,805 @@ async fn async_aggregation_job_continue_poll_to_finished() { .await; } -struct CancelAggregationJobTestCase { - task: AggregatorTask, - vdaf: Arc, - aggregation_job: AggregationJob, - batch_identifier: Interval, - report_aggregation: ReportAggregation, - _ephemeral_datastore: EphemeralDatastore, - datastore: Arc>, - lease: Lease, - mock_helper: ServerGuard, -} - -async fn setup_cancel_aggregation_job_test() -> CancelAggregationJobTestCase { - // Setup: insert a client report and add it to a new aggregation job. +#[tokio::test] +async fn helper_async_init_processing_to_finished() { + // Setup: insert an aggregation job with a report aggregation in state HelperInitProcessing. 
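+ // In asynchronous aggregation mode the helper does not compute prepare steps in-line with + // the leader's request; its own aggregation job driver later picks up the stored job. With + // a one-round VDAF, stepping the job should move the report aggregation straight from + // HelperInitProcessing to Finished and accumulate it into the batch aggregation.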
install_test_trace_subscriber(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; - let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); - let vdaf = Arc::new(Prio3::new_count(2).unwrap()); - let mock_helper = mockito::Server::new_async().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let hpke_keypair = ds.put_hpke_key().await.unwrap(); + let vdaf = Arc::new(dummy::Vdaf::new(1)); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_helper_aggregator_endpoint(mock_helper.url().parse().unwrap()) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); let time = clock .now() .to_batch_interval_start(task.time_precision()) .unwrap(); - let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); + let active_batch_identifier = + TimeInterval::to_batch_identifier(&helper_task, &(), &time).unwrap(); let report_metadata = ReportMetadata::new(random(), time, Vec::new()); - let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let aggregation_param = dummy::AggregationParam(7); let transcript = run_vdaf( vdaf.as_ref(), task.id(), verify_key.as_bytes(), - &(), + &aggregation_param, report_metadata.id(), - &false, + &13, ); - let helper_hpke_keypair = HpkeKeypair::test(); - let report = LeaderStoredReport::generate( + let report_share = generate_helper_report_share::( *task.id(), report_metadata, - helper_hpke_keypair.config(), + hpke_keypair.config(), + &transcript.public_share, Vec::new(), - &transcript, + &transcript.helper_input_share, ); let aggregation_job_id = random(); - let aggregation_job = AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, - AggregationJobStep::from(0), - ); - let report_aggregation = report.as_leader_init_report_aggregation(aggregation_job_id, 0); - - let lease = datastore + let lease = ds .run_unnamed_tx(|tx| { - let (task, report, aggregation_job, report_aggregation) = ( - task.clone(), - report.clone(), - aggregation_job.clone(), - report_aggregation.clone(), - ); + let helper_task = helper_task.clone(); + let report_share = report_share.clone(); + let message = transcript.leader_prepare_transitions[0].message.clone(); + Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); - tx.put_client_report(&report).await.unwrap(); - tx.scrub_client_report(report.task_id(), report.metadata().id()) - .await - .unwrap(); - tx.put_aggregation_job(&aggregation_job).await.unwrap(); - tx.put_report_aggregation(&report_aggregation) + let report_id = *report_share.metadata().id(); + let report_timestamp = *report_share.metadata().time(); + + tx.put_aggregator_task(&helper_task).await.unwrap(); + tx.put_scrubbed_report(helper_task.id(), &report_id, &report_timestamp) .await .unwrap(); - tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, - TimeInterval, - Prio3Count, - >::new( - *task.id(), - batch_identifier, - (), + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), 
Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + report_id, + report_timestamp, + 0, + None, + ReportAggregationState::HelperInitProcessing { + prepare_init: PrepareInit::new(report_share, message), + require_taskbind_extension: false, + }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(&report_timestamp).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + let lease = tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0); + + Ok(lease) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Run. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(0), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, + ); + aggregation_job_driver + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) + .await + .unwrap(); + + // Verify. + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobStep::from(0), + ); + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report_share.metadata().id(), + *report_share.metadata().time(), + 0, + Some(PrepareResp::new( + *report_share.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )), + ReportAggregationState::Finished, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(report_share.metadata().time()).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: Some(transcript.helper_aggregate_share), + report_count: 1, + checksum: ReportIdChecksum::for_report_id(report_share.metadata().id()), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 1, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let helper_task = helper_task.clone(); + let report_metadata = report_share.metadata().clone(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + helper_task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Helper, + helper_task.id(), + &aggregation_job_id, + report_metadata.id(), + 
&aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + helper_task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(1)) + .await; +} + +#[tokio::test] +async fn helper_async_init_processing_to_continue() { + // Setup: insert an aggregation job with a report aggregation in state HelperInitProcessing. + install_test_trace_subscriber(); + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let hpke_keypair = ds.put_hpke_key().await.unwrap(); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let active_batch_identifier = + TimeInterval::to_batch_identifier(&helper_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + + let aggregation_param = dummy::AggregationParam(7); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &13, + ); + + let report_share = generate_helper_report_share::( + *task.id(), + report_metadata, + hpke_keypair.config(), + &transcript.public_share, + Vec::new(), + &transcript.helper_input_share, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_share = report_share.clone(); + let message = transcript.leader_prepare_transitions[0].message.clone(); + + Box::pin(async move { + let report_id = *report_share.metadata().id(); + let report_timestamp = *report_share.metadata().time(); + + tx.put_aggregator_task(&helper_task).await.unwrap(); + tx.put_scrubbed_report(helper_task.id(), &report_id, &report_timestamp) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + report_id, + report_timestamp, + 0, + None, + ReportAggregationState::HelperInitProcessing { + prepare_init: PrepareInit::new(report_share, message), + require_taskbind_extension: false, + }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(&report_timestamp).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: 
ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + let lease = tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0); + + Ok(lease) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Run. + let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(0), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, + ); + aggregation_job_driver + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) + .await + .unwrap(); + + // Verify. + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + ); + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report_share.metadata().id(), + *report_share.metadata().time(), + 0, + Some(PrepareResp::new( + *report_share.metadata().id(), + PrepareStepResult::Continue { + message: transcript.helper_prepare_transitions[0].message.clone(), + }, + )), + ReportAggregationState::HelperContinue { + prepare_state: *transcript.helper_prepare_transitions[0].prepare_state(), + }, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(report_share.metadata().time()).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let helper_task = helper_task.clone(); + let report_metadata = report_share.metadata().clone(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + helper_task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Helper, + helper_task.id(), + &aggregation_job_id, + report_metadata.id(), + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + helper_task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(0)) + .await; +} + +#[tokio::test] +async fn helper_async_continue_processing_to_finished() { + // Setup: insert an 
aggregation job with a report aggregation in state HelperContinueProcessing. + install_test_trace_subscriber(); + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + let hpke_keypair = ds.put_hpke_key().await.unwrap(); + let vdaf = Arc::new(dummy::Vdaf::new(2)); + + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let active_batch_identifier = + TimeInterval::to_batch_identifier(&helper_task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + + let aggregation_param = dummy::AggregationParam(7); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &13, + ); + + let report_share = generate_helper_report_share::<dummy::Vdaf>( + *task.id(), + report_metadata, + hpke_keypair.config(), + &transcript.public_share, + Vec::new(), + &transcript.helper_input_share, + ); + let aggregation_job_id = random(); + + let lease = ds + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_share = report_share.clone(); + let prepare_state = *transcript.helper_prepare_transitions[0].prepare_state(); + let message = transcript.leader_prepare_transitions[1].message.clone(); + + Box::pin(async move { + let report_id = *report_share.metadata().id(); + let report_timestamp = *report_share.metadata().time(); + + tx.put_aggregator_task(&helper_task).await.unwrap(); + tx.put_scrubbed_report(helper_task.id(), &report_id, &report_timestamp) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(1), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + report_id, + report_timestamp, + 0, + None, + ReportAggregationState::HelperContinueProcessing { + prepare_state, + prepare_continue: PrepareContinue::new(report_id, message), + }, + )) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(&report_timestamp).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: None, + report_count: 0, + checksum: ReportIdChecksum::default(), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 0, + }, + )) + .await + .unwrap(); + + let lease = tx + .acquire_incomplete_aggregation_jobs(&StdDuration::from_secs(60), 1) + .await + .unwrap() + .remove(0); + + Ok(lease) + }) + }) + .await + .unwrap(); + assert_eq!(lease.leased().task_id(), task.id()); + assert_eq!(lease.leased().aggregation_job_id(), &aggregation_job_id); + + // Run.
+ let aggregation_job_driver = AggregationJobDriver::new( + reqwest::Client::builder().build().unwrap(), + LimitedRetryer::new(0), + &noop_meter(), + BATCH_AGGREGATION_SHARD_COUNT, + TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, + ); + aggregation_job_driver + .step_aggregation_job( + ds.clone(), + Arc::new( + HpkeKeypairCache::new(Arc::clone(&ds), HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL) + .await + .unwrap(), + ), + Arc::new(lease), + ) + .await + .unwrap(); + + // Verify. + let want_aggregation_job = AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobStep::from(1), + ); + let want_report_aggregation = ReportAggregation::<0, dummy::Vdaf>::new( + *task.id(), + aggregation_job_id, + *report_share.metadata().id(), + *report_share.metadata().time(), + 0, + Some(PrepareResp::new( + *report_share.metadata().id(), + PrepareStepResult::Finished, + )), + ReportAggregationState::Finished, + ); + + let want_batch_aggregations = + Vec::from([BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( + *task.id(), + active_batch_identifier, + aggregation_param, + 0, + Interval::from_time(report_share.metadata().time()).unwrap(), + BatchAggregationState::Aggregating { + aggregate_share: Some(transcript.helper_aggregate_share), + report_count: 1, + checksum: ReportIdChecksum::for_report_id(report_share.metadata().id()), + aggregation_jobs_created: 1, + aggregation_jobs_terminated: 1, + }, + )]); + + let (got_aggregation_job, got_report_aggregation, got_batch_aggregations) = ds + .run_unnamed_tx(|tx| { + let vdaf = Arc::clone(&vdaf); + let helper_task = helper_task.clone(); + let report_metadata = report_share.metadata().clone(); + + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + helper_task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregation = tx + .get_report_aggregation_by_report_id( + vdaf.as_ref(), + &Role::Helper, + helper_task.id(), + &aggregation_job_id, + report_metadata.id(), + &aggregation_param, + ) + .await + .unwrap() + .unwrap(); + let batch_aggregations = merge_batch_aggregations_by_batch( + tx.get_batch_aggregations_for_task::<0, TimeInterval, dummy::Vdaf>( + &vdaf, + helper_task.id(), + ) + .await + .unwrap(), + ); + + Ok((aggregation_job, report_aggregation, batch_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!(want_aggregation_job, got_aggregation_job); + assert_eq!(want_report_aggregation, got_report_aggregation); + assert_eq!(want_batch_aggregations, got_batch_aggregations); + + assert_task_aggregation_counter(&ds, *task.id(), TaskAggregationCounter::new_with_values(1)) + .await; +} + +struct CancelAggregationJobTestCase { + task: AggregatorTask, + vdaf: Arc, + aggregation_job: AggregationJob, + batch_identifier: Interval, + report_aggregation: ReportAggregation, + _ephemeral_datastore: EphemeralDatastore, + datastore: Arc>, + lease: Lease, + mock_helper: ServerGuard, +} + +async fn setup_cancel_aggregation_job_test() -> CancelAggregationJobTestCase { + // Setup: insert a client report and add it to a new aggregation job. 
+ install_test_trace_subscriber(); + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + datastore.put_hpke_key().await.unwrap(); + let vdaf = Arc::new(Prio3::new_count(2).unwrap()); + let mock_helper = mockito::Server::new_async().await; + + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_helper_aggregator_endpoint(mock_helper.url().parse().unwrap()) + .build() + .leader_view() + .unwrap(); + let time = clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(); + let batch_identifier = TimeInterval::to_batch_identifier(&task, &(), &time).unwrap(); + let report_metadata = ReportMetadata::new(random(), time, Vec::new()); + let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); + + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &(), + report_metadata.id(), + &false, + ); + + let helper_hpke_keypair = HpkeKeypair::test(); + let report = LeaderStoredReport::generate( + *task.id(), + report_metadata, + helper_hpke_keypair.config(), + Vec::new(), + &transcript, + ); + let aggregation_job_id = random(); + + let aggregation_job = AggregationJob::::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + ); + let report_aggregation = report.as_leader_init_report_aggregation(aggregation_job_id, 0); + + let lease = datastore + .run_unnamed_tx(|tx| { + let (task, report, aggregation_job, report_aggregation) = ( + task.clone(), + report.clone(), + aggregation_job.clone(), + report_aggregation.clone(), + ); + Box::pin(async move { + tx.put_aggregator_task(&task).await.unwrap(); + tx.put_client_report(&report).await.unwrap(); + tx.scrub_client_report(report.task_id(), report.metadata().id()) + .await + .unwrap(); + tx.put_aggregation_job(&aggregation_job).await.unwrap(); + tx.put_report_aggregation(&report_aggregation) + .await + .unwrap(); + + tx.put_batch_aggregation(&BatchAggregation::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *task.id(), + batch_identifier, + (), 0, Interval::from_time(&time).unwrap(), BatchAggregationState::Aggregating { @@ -4886,6 +5858,8 @@ async fn cancel_aggregation_job() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver .abandon_aggregation_job(Arc::clone(&test_case.datastore), Arc::new(test_case.lease)) @@ -4927,7 +5901,7 @@ async fn cancel_aggregation_job() { ); Box::pin(async move { let aggregation_job = tx - .get_aggregation_job::( + .get_aggregation_job::( task.id(), aggregation_job.id(), ) @@ -4947,7 +5921,7 @@ async fn cancel_aggregation_job() { .unwrap() .unwrap(); let batch_aggregations = merge_batch_aggregations_by_batch( - tx.get_batch_aggregations_for_task::(&vdaf, task.id()) + tx.get_batch_aggregations_for_task::(&vdaf, task.id()) .await .unwrap(), ); @@ -4996,6 +5970,8 @@ async fn cancel_aggregation_job_helper_aggregation_job_deletion_fails() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, ); aggregation_job_driver .abandon_aggregation_job(Arc::clone(&test_case.datastore), 
Arc::new(test_case.lease)) @@ -5013,16 +5989,21 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { let mut runtime_manager = TestRuntimeManager::new(); let ephemeral_datastore = ephemeral_datastore().await; let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let stopper = Stopper::new(); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); let agg_auth_token = task.aggregator_auth_token(); let aggregation_job_id = random(); - let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); let helper_hpke_keypair = HpkeKeypair::test(); @@ -5061,18 +6042,20 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { .await .unwrap(); - tx.put_aggregation_job( - &AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobStep::from(0), - ), - ) + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + )) .await .unwrap(); @@ -5083,7 +6066,7 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -5116,6 +6099,8 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, )); let job_driver = Arc::new( JobDriver::new( @@ -5141,7 +6126,7 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { let failure_mock = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -5160,7 +6145,7 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { let no_more_requests_mock = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -5207,7 +6192,7 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { .unwrap() .unwrap(); let got_batch_aggregations = merge_batch_aggregations_by_batch( - tx.get_batch_aggregations_for_task::(&vdaf, task.id()) + tx.get_batch_aggregations_for_task::(&vdaf, task.id()) .await .unwrap(), ); @@ -5219,7 +6204,7 @@ async fn abandon_failing_aggregation_job_with_retryable_error() { .unwrap(); assert_eq!( got_aggregation_job, - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), @@ -5256,16 +6241,21 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { let mut runtime_manager = TestRuntimeManager::new(); let ephemeral_datastore = ephemeral_datastore().await; let ds = 
Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + ds.put_hpke_key().await.unwrap(); let stopper = Stopper::new(); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .build(); let leader_task = task.leader_view().unwrap(); let agg_auth_token = task.aggregator_auth_token(); let aggregation_job_id = random(); - let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); + let verify_key: VerifyKey = task.vdaf_verify_key().unwrap(); let helper_hpke_keypair = HpkeKeypair::test(); @@ -5304,18 +6294,20 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { .await .unwrap(); - tx.put_aggregation_job( - &AggregationJob::::new( - *task.id(), - aggregation_job_id, - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobStep::from(0), - ), - ) + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *task.id(), + aggregation_job_id, + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + )) .await .unwrap(); @@ -5326,7 +6318,7 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { .unwrap(); tx.put_batch_aggregation(&BatchAggregation::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -5359,6 +6351,8 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { &noop_meter(), BATCH_AGGREGATION_SHARD_COUNT, TASK_AGGREGATION_COUNTER_SHARD_COUNT, + HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + DEFAULT_ASYNC_POLL_INTERVAL, )); let job_driver = Arc::new( JobDriver::new( @@ -5384,7 +6378,7 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { let failure_mock = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -5403,7 +6397,7 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { let no_more_requests_mock = server .mock( "PUT", - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -5444,7 +6438,7 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { .unwrap() .unwrap(); let got_batch_aggregations = merge_batch_aggregations_by_batch( - tx.get_batch_aggregations_for_task::(&vdaf, task.id()) + tx.get_batch_aggregations_for_task::(&vdaf, task.id()) .await .unwrap(), ); @@ -5456,7 +6450,7 @@ async fn abandon_failing_aggregation_job_with_fatal_error() { .unwrap(); assert_eq!( got_aggregation_job, - AggregationJob::::new( + AggregationJob::::new( *task.id(), aggregation_job_id, (), diff --git a/aggregator/src/aggregator/aggregation_job_init.rs b/aggregator/src/aggregator/aggregation_job_init.rs new file mode 100644 index 000000000..6d3595863 --- /dev/null +++ b/aggregator/src/aggregator/aggregation_job_init.rs @@ -0,0 +1,1124 @@ +//! Implements portions of aggregation job initialization for the Helper. 
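+//! +//! The entry point is [`compute_helper_aggregate_init`], which advances report aggregations in +//! the `HelperInitProcessing` state through the Helper's first VDAF preparation step.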
+ +use crate::{ + aggregator::{ + aggregation_job_writer::WritableReportAggregation, error::handle_ping_pong_error, + AggregatorMetrics, Error, + }, + cache::HpkeKeypairCache, +}; +use assert_matches::assert_matches; +use janus_aggregator_core::{ + batch_mode::AccumulableBatchMode, + datastore::models::{AggregationJob, ReportAggregation, ReportAggregationState}, + task::AggregatorTask, +}; +use janus_core::{ + hpke::{self, HpkeApplicationInfo, Label}, + time::{Clock, TimeExt as _}, + vdaf::vdaf_application_context, +}; +use janus_messages::{ + ExtensionType, InputShareAad, PlaintextInputShare, PrepareResp, PrepareStepResult, ReportError, + Role, +}; +use opentelemetry::{metrics::Counter, KeyValue}; +use prio::{ + codec::{Decode as _, Encode, ParameterizedDecode}, + topology::ping_pong::{PingPongState, PingPongTopology as _}, + vdaf, +}; +use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use std::{collections::HashMap, panic, sync::Arc}; +use tokio::sync::mpsc; +use tracing::{debug, info_span, trace_span, Span}; + +#[derive(Clone)] +pub struct AggregateInitMetrics { + /// Counters tracking the number of failures to step client reports through the aggregation + /// process. + aggregate_step_failure_counter: Counter<u64>, +} + +impl AggregateInitMetrics { + pub fn new(aggregate_step_failure_counter: Counter<u64>) -> Self { + Self { + aggregate_step_failure_counter, + } + } +} + +impl From<AggregatorMetrics> for AggregateInitMetrics { + fn from(metrics: AggregatorMetrics) -> Self { + Self { + aggregate_step_failure_counter: metrics.aggregate_step_failure_counter, + } + } +} + +/// Given report aggregations in the `HelperInitProcessing` state, this function computes the next +/// step of the aggregation; the returned [`WritableReportAggregation`]s correspond to the provided +/// report aggregations and will be in the `HelperContinue`, `Finished`, or `Failed` states. +/// +/// Only report aggregations in the `HelperInitProcessing` state can be provided. The caller must +/// filter out report aggregations which are in other states (e.g. `Failed`) prior to calling this +/// function. +/// +/// ### Panics +/// +/// Panics if a provided report aggregation is in a state other than `HelperInitProcessing`. +pub async fn compute_helper_aggregate_init<const SEED_SIZE: usize, B, A, C>( + clock: &C, + hpke_keypairs: Arc<HpkeKeypairCache>, + vdaf: Arc<A>, + metrics: AggregateInitMetrics, + task: Arc<AggregatorTask>, + aggregation_job: Arc<AggregationJob<SEED_SIZE, B, A>>, + report_aggregations: Vec<ReportAggregation<SEED_SIZE, A>>, +) -> Result<Vec<WritableReportAggregation<SEED_SIZE, A>>, Error> +where + B: AccumulableBatchMode, + A: vdaf::Aggregator<SEED_SIZE, 16> + 'static + Send + Sync, + C: Clock, + A::AggregationParam: Send + Sync + PartialEq + Eq, + A::AggregateShare: Send + Sync, + A::InputShare: Send + Sync, + A::PrepareMessage: Send + Sync + PartialEq, + A::PrepareShare: Send + Sync + PartialEq, + for<'a> A::PrepareState: Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)> + PartialEq, + A::PublicShare: Send + Sync, + A::OutputShare: Send + Sync + PartialEq, +{ + let verify_key = task.vdaf_verify_key()?; + let report_aggregation_count = report_aggregations.len(); + let report_deadline = clock + .now() + .add(task.tolerable_clock_skew()) + .map_err(Error::from)?; + + // Shutdown on cancellation: if this request is cancelled, the `receiver` will be dropped. This + // will cause any attempts to send on `sender` to return a `SendError`, which will be returned + // from the function passed to `try_for_each`; `try_for_each` will terminate early on receiving + // an error.
+ let (sender, mut receiver) = mpsc::unbounded_channel(); + let producer_task = tokio::task::spawn_blocking({ + let parent_span = Span::current(); + let hpke_keypairs = Arc::clone(&hpke_keypairs); + let vdaf = Arc::clone(&vdaf); + let task = Arc::clone(&task); + let metrics = metrics.clone(); + let aggregation_job = Arc::clone(&aggregation_job); + + move || { + let span = + info_span!(parent: parent_span, "compute_helper_aggregate_init threadpool task"); + let ctx = vdaf_application_context(task.id()); + + report_aggregations + .into_par_iter() + .try_for_each(|report_aggregation| { + let _entered = span.enter(); + + // Assert safety: this function should only be called with report + // aggregations in the HelperInitProcessing state. + let (prepare_init, require_taskbind_extension) = assert_matches!( + report_aggregation.state(), + ReportAggregationState::HelperInitProcessing { + prepare_init, + require_taskbind_extension, + } => (prepare_init, *require_taskbind_extension) + ); + + // If decryption fails, then the aggregator MUST fail with error `hpke-decrypt-error`. (§4.4.2.2) + let hpke_keypair = hpke_keypairs.keypair( + prepare_init + .report_share() + .encrypted_input_share() + .config_id(), + ).ok_or_else(|| { + debug!( + config_id = %prepare_init.report_share().encrypted_input_share().config_id(), + "Helper encrypted input share references unknown HPKE config ID" + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "unknown_hpke_config_id")]); + ReportError::HpkeUnknownConfigId + }); + + let plaintext = hpke_keypair.and_then(|hpke_keypair| { + let input_share_aad = InputShareAad::new( + *task.id(), + prepare_init.report_share().metadata().clone(), + prepare_init.report_share().public_share().to_vec(), + ) + .get_encoded() + .map_err(|err| { + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + ?err, + "Couldn't encode input share AAD" + ); + metrics.aggregate_step_failure_counter.add( + 1, + &[KeyValue::new("type", "input_share_aad_encode_failure")], + ); + // HpkeDecryptError isn't strictly accurate, but given that this + // fallible encoding is part of the HPKE decryption process, I think + // this is as close as we can get to a meaningful error signal. + ReportError::HpkeDecryptError + })?; + + hpke::open( + &hpke_keypair, + &HpkeApplicationInfo::new( + &Label::InputShare, + &Role::Client, + &Role::Helper, + ), + prepare_init.report_share().encrypted_input_share(), + &input_share_aad, + ) + .map_err(|error| { + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + ?error, + "Couldn't decrypt helper's report share" + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "decrypt_failure")]); + ReportError::HpkeDecryptError + }) + }); + + let plaintext_input_share = plaintext.and_then(|plaintext| { + let plaintext_input_share = PlaintextInputShare::get_decoded(&plaintext) + .map_err(|error| { + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + ?error, "Couldn't decode helper's plaintext input share", + ); + metrics.aggregate_step_failure_counter.add( + 1, + &[KeyValue::new( + "type", + "plaintext_input_share_decode_failure", + )], + ); + ReportError::InvalidMessage + })?; + + // Build map of extension type to extension data, checking for duplicates. 
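+ // A report may not carry two extensions of the same type; a duplicate, whether within one + // list or across the private and public extension lists, invalidates the report.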
+ let mut extensions = HashMap::new(); + if !plaintext_input_share.private_extensions().iter().chain(prepare_init.report_share().metadata().public_extensions()).all(|extension| { + extensions + .insert(*extension.extension_type(), extension.extension_data()) + .is_none() + }) { + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + "Received report share with duplicate extensions", + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "duplicate_extension")]); + return Err(ReportError::InvalidMessage); + } + + if require_taskbind_extension { + let valid_taskbind_extension_present = extensions + .get(&ExtensionType::Taskbind) + .map(|data| data.is_empty()) + .unwrap_or(false); + if !valid_taskbind_extension_present { + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + "Taskprov task received report with missing or malformed \ + taskbind extension", + ); + metrics.aggregate_step_failure_counter.add( + 1, + &[KeyValue::new( + "type", + "missing_or_malformed_taskbind_extension", + )], + ); + return Err(ReportError::InvalidMessage); + } + } else if extensions.contains_key(&ExtensionType::Taskbind) { + // taskprov not enabled, but the taskbind extension is present. + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + "Non-taskprov task received report with unexpected taskbind \ + extension", + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "unexpected_taskbind_extension")]); + return Err(ReportError::InvalidMessage); + } + + Ok(plaintext_input_share) + }); + + let input_share = plaintext_input_share.and_then(|plaintext_input_share| { + A::InputShare::get_decoded_with_param( + &(&vdaf, Role::Helper.index().unwrap()), + plaintext_input_share.payload(), + ) + .map_err(|error| { + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + ?error, "Couldn't decode helper's input share", + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "input_share_decode_failure")]); + ReportError::InvalidMessage + }) + }); + + let public_share = A::PublicShare::get_decoded_with_param( + &vdaf, + prepare_init.report_share().public_share(), + ) + .map_err(|error| { + debug!( + task_id = %task.id(), + report_id = ?prepare_init.report_share().metadata().id(), + ?error, "Couldn't decode public share", + ); + metrics + .aggregate_step_failure_counter + .add(1, &[KeyValue::new("type", "public_share_decode_failure")]); + ReportError::InvalidMessage + }); + + let shares = + input_share.and_then(|input_share| Ok((public_share?, input_share))); + + // Reject reports from too far in the future. + let shares = shares.and_then(|shares| { + if prepare_init + .report_share() + .metadata() + .time() + .is_after(&report_deadline) + { + return Err(ReportError::ReportTooEarly); + } + Ok(shares) + }); + + // Next, the aggregator runs the preparation-state initialization algorithm for the VDAF + // associated with the task and computes the first state transition. [...] If either + // step fails, then the aggregator MUST fail with error `vdaf-prep-error`. 
(§4.4.2.2) + let init_rslt = shares.and_then(|(public_share, input_share)| { + trace_span!("VDAF preparation (helper initialization)").in_scope(|| { + vdaf.helper_initialized( + verify_key.as_bytes(), + &ctx, + aggregation_job.aggregation_parameter(), + /* report ID is used as VDAF nonce */ + prepare_init.report_share().metadata().id().as_ref(), + &public_share, + &input_share, + prepare_init.message(), + ) + .and_then(|transition| transition.evaluate(&ctx, &vdaf)) + .map_err(|error| { + handle_ping_pong_error( + task.id(), + Role::Helper, + prepare_init.report_share().metadata().id(), + error, + &metrics.aggregate_step_failure_counter, + ) + }) + }) + }); + + let (report_aggregation_state, prepare_step_result, output_share) = + match init_rslt { + Ok((PingPongState::Continued(prepare_state), outgoing_message)) => { + // Helper is not finished. Await the next message from the Leader to advance to + // the next step. + ( + ReportAggregationState::HelperContinue { prepare_state }, + PrepareStepResult::Continue { + message: outgoing_message, + }, + None, + ) + } + Ok((PingPongState::Finished(output_share), outgoing_message)) => ( + ReportAggregationState::Finished, + PrepareStepResult::Continue { + message: outgoing_message, + }, + Some(output_share), + ), + Err(report_error) => ( + ReportAggregationState::Failed { report_error }, + PrepareStepResult::Reject(report_error), + None, + ), + }; + + let report_id = *prepare_init.report_share().metadata().id(); + sender.send(WritableReportAggregation::new( + report_aggregation + .with_last_prep_resp( + Some(PrepareResp::new( + report_id, + prepare_step_result, + )) + ) + .with_state(report_aggregation_state), + output_share + ) + ) + }) + } + }); + + let mut report_aggregations = Vec::with_capacity(report_aggregation_count); + while receiver.recv_many(&mut report_aggregations, 10).await > 0 {} + + // Await the producer task to resume any panics that may have occurred, and to ensure the + // producer task is completely done (e.g. all of its memory is released). The only other errors + // that can occur are: a `JoinError` indicating cancellation, which is impossible because we do + // not cancel the task; and a `SendError`, which can only happen if this future is cancelled (in + // which case we will not run this code at all). 
+ let _ = producer_task.await.map_err(|join_error| { + if let Ok(reason) = join_error.try_into_panic() { + panic::resume_unwind(reason); + } + }); + assert_eq!(report_aggregations.len(), report_aggregation_count); + + Ok(report_aggregations) +} + +#[cfg(feature = "test-util")] +#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))] +pub mod test_util { + use crate::aggregator::test_util::generate_helper_report_share; + use janus_aggregator_core::task::{test_util::Task, AggregatorTask}; + use janus_core::{ + test_util::{run_vdaf, VdafTranscript}, + time::{Clock, MockClock, TimeExt as _}, + }; + use janus_messages::{ + batch_mode::{self}, + AggregationJobId, AggregationJobInitializeReq, Extension, HpkeConfig, PrepareInit, + ReportMetadata, ReportShare, + }; + use prio::{ + codec::Encode, + vdaf::{self}, + }; + use rand::random; + use trillium::{Handler, KnownHeaderName}; + use trillium_testing::{prelude::put, TestConn}; + + #[derive(Clone)] + pub struct PrepareInitGenerator + where + V: vdaf::Vdaf, + { + clock: MockClock, + task: AggregatorTask, + vdaf: V, + aggregation_param: V::AggregationParam, + hpke_config: HpkeConfig, + private_extensions: Vec, + } + + impl PrepareInitGenerator + where + V: vdaf::Vdaf + vdaf::Aggregator + vdaf::Client<16>, + { + pub fn new( + clock: MockClock, + task: AggregatorTask, + hpke_config: HpkeConfig, + vdaf: V, + aggregation_param: V::AggregationParam, + ) -> Self { + Self { + clock, + task, + vdaf, + aggregation_param, + hpke_config, + private_extensions: Vec::new(), + } + } + + pub fn with_private_extensions(mut self, extensions: Vec) -> Self { + self.private_extensions = extensions; + self + } + + pub fn next( + &self, + measurement: &V::Measurement, + ) -> (PrepareInit, VdafTranscript) { + self.next_with_metadata( + ReportMetadata::new( + random(), + self.clock + .now() + .to_batch_interval_start(self.task.time_precision()) + .unwrap(), + Vec::new(), + ), + measurement, + ) + } + + pub fn next_with_metadata( + &self, + report_metadata: ReportMetadata, + measurement: &V::Measurement, + ) -> (PrepareInit, VdafTranscript) { + let (report_share, transcript) = + self.next_report_share_with_metadata(report_metadata, measurement); + ( + PrepareInit::new( + report_share, + transcript.leader_prepare_transitions[0].message.clone(), + ), + transcript, + ) + } + + pub fn next_report_share( + &self, + measurement: &V::Measurement, + ) -> (ReportShare, VdafTranscript) { + self.next_report_share_with_metadata( + ReportMetadata::new( + random(), + self.clock + .now() + .to_batch_interval_start(self.task.time_precision()) + .unwrap(), + Vec::new(), + ), + measurement, + ) + } + + pub fn next_report_share_with_metadata( + &self, + report_metadata: ReportMetadata, + measurement: &V::Measurement, + ) -> (ReportShare, VdafTranscript) { + let transcript = run_vdaf( + &self.vdaf, + self.task.id(), + self.task.vdaf_verify_key().unwrap().as_bytes(), + &self.aggregation_param, + report_metadata.id(), + measurement, + ); + let report_share = generate_helper_report_share::( + *self.task.id(), + report_metadata, + &self.hpke_config, + &transcript.public_share, + self.private_extensions.clone(), + &transcript.helper_input_share, + ); + (report_share, transcript) + } + } + + pub async fn put_aggregation_job( + task: &Task, + aggregation_job_id: &AggregationJobId, + aggregation_job: &AggregationJobInitializeReq, + handler: &impl Handler, + ) -> TestConn { + let (header, value) = task.aggregator_auth_token().request_authentication(); + + put(task + .aggregation_job_uri(aggregation_job_id, 
None) + .unwrap() + .path()) + .with_request_header(header, value) + .with_request_header( + KnownHeaderName::ContentType, + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .with_request_body(aggregation_job.get_encoded().unwrap()) + .run_async(handler) + .await + } +} + +#[cfg(test)] +mod tests { + use crate::aggregator::{ + aggregation_job_init::test_util::{put_aggregation_job, PrepareInitGenerator}, + http_handlers::{ + test_util::{decode_response_body, take_problem_details}, + AggregatorHandlerBuilder, + }, + Config, + }; + use assert_matches::assert_matches; + use http::StatusCode; + use janus_aggregator_core::{ + datastore::test_util::{ephemeral_datastore, EphemeralDatastore}, + task::{ + test_util::{Task, TaskBuilder}, + AggregationMode, BatchMode, + }, + test_util::noop_meter, + }; + use janus_core::{ + auth_tokens::{AuthenticationToken, DAP_AUTH_HEADER}, + test_util::{install_test_trace_subscriber, runtime::TestRuntime}, + time::{Clock, MockClock, TimeExt as _}, + vdaf::VdafInstance, + }; + use janus_messages::{ + batch_mode::TimeInterval, AggregationJobId, AggregationJobInitializeReq, + AggregationJobResp, Duration, Extension, ExtensionType, PartialBatchSelector, PrepareResp, + PrepareStepResult, ReportError, ReportMetadata, + }; + use prio::{ + codec::Encode, + vdaf::{self, dummy}, + }; + use rand::random; + use serde_json::json; + use std::sync::Arc; + use trillium::{Handler, KnownHeaderName, Status}; + use trillium_testing::prelude::put; + + pub(super) struct AggregationJobInitTestCase< + const VERIFY_KEY_SIZE: usize, + V: vdaf::Aggregator, + > { + pub(super) clock: MockClock, + pub(super) task: Task, + pub(super) prepare_init_generator: PrepareInitGenerator, + pub(super) aggregation_job_id: AggregationJobId, + pub(super) aggregation_job_init_req: AggregationJobInitializeReq, + aggregation_job_init_resp: Option, + pub(super) aggregation_param: V::AggregationParam, + pub(super) handler: Box, + _ephemeral_datastore: EphemeralDatastore, + } + + pub(super) async fn setup_aggregate_init_test() -> AggregationJobInitTestCase<0, dummy::Vdaf> { + setup_aggregate_init_test_for_vdaf( + dummy::Vdaf::new(1), + VdafInstance::Fake { rounds: 1 }, + dummy::AggregationParam(0), + 0, + ) + .await + } + + async fn setup_multi_step_aggregate_init_test() -> AggregationJobInitTestCase<0, dummy::Vdaf> { + setup_aggregate_init_test_for_vdaf( + dummy::Vdaf::new(2), + VdafInstance::Fake { rounds: 2 }, + dummy::AggregationParam(7), + 13, + ) + .await + } + + async fn setup_aggregate_init_test_for_vdaf< + const VERIFY_KEY_SIZE: usize, + V: vdaf::Aggregator + vdaf::Client<16>, + >( + vdaf: V, + vdaf_instance: VdafInstance, + aggregation_param: V::AggregationParam, + measurement: V::Measurement, + ) -> AggregationJobInitTestCase { + let mut test_case = setup_aggregate_init_test_without_sending_request( + vdaf, + vdaf_instance, + aggregation_param, + measurement, + AuthenticationToken::Bearer(random()), + ) + .await; + + let mut response = put_aggregation_job( + &test_case.task, + &test_case.aggregation_job_id, + &test_case.aggregation_job_init_req, + &test_case.handler, + ) + .await; + assert_eq!(response.status(), Some(Status::Created)); + + let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; + let prepare_resps = assert_matches!( + &aggregation_job_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); + assert_eq!( + prepare_resps.len(), + test_case.aggregation_job_init_req.prepare_inits().len(), + ); + assert_matches!( + 
prepare_resps[0].result(), + &PrepareStepResult::Continue { .. } + ); + + test_case.aggregation_job_init_resp = Some(aggregation_job_resp); + test_case + } + + async fn setup_aggregate_init_test_without_sending_request< + const VERIFY_KEY_SIZE: usize, + V: vdaf::Aggregator + vdaf::Client<16>, + >( + vdaf: V, + vdaf_instance: VdafInstance, + aggregation_param: V::AggregationParam, + measurement: V::Measurement, + auth_token: AuthenticationToken, + ) -> AggregationJobInitTestCase { + install_test_trace_subscriber(); + + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + vdaf_instance, + ) + .with_aggregator_auth_token(auth_token) + .build(); + let helper_task = task.helper_view().unwrap(); + let clock = MockClock::default(); + let ephemeral_datastore = ephemeral_datastore().await; + let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); + + datastore.put_aggregator_task(&helper_task).await.unwrap(); + let keypair = datastore.put_hpke_key().await.unwrap(); + + let handler = AggregatorHandlerBuilder::new( + Arc::clone(&datastore), + clock.clone(), + TestRuntime::default(), + &noop_meter(), + Config::default(), + ) + .await + .unwrap() + .build() + .unwrap(); + + let prepare_init_generator = PrepareInitGenerator::new( + clock.clone(), + helper_task.clone(), + keypair.config().clone(), + vdaf, + aggregation_param.clone(), + ); + + let prepare_inits = Vec::from([ + prepare_init_generator.next(&measurement).0, + prepare_init_generator.next(&measurement).0, + ]); + + let aggregation_job_id = random(); + let aggregation_job_init_req = AggregationJobInitializeReq::new( + aggregation_param.get_encoded().unwrap(), + PartialBatchSelector::new_time_interval(), + prepare_inits.clone(), + ); + + AggregationJobInitTestCase { + clock, + task, + prepare_init_generator, + aggregation_job_id, + aggregation_job_init_req, + aggregation_job_init_resp: None, + aggregation_param, + handler: Box::new(handler), + _ephemeral_datastore: ephemeral_datastore, + } + } + + #[tokio::test] + async fn aggregation_job_init_authorization_dap_auth_token() { + let test_case = setup_aggregate_init_test_without_sending_request( + dummy::Vdaf::new(1), + VdafInstance::Fake { rounds: 1 }, + dummy::AggregationParam(0), + 0, + AuthenticationToken::DapAuth(random()), + ) + .await; + + let (auth_header, auth_value) = test_case + .task + .aggregator_auth_token() + .request_authentication(); + + let response = put(test_case + .task + .aggregation_job_uri(&test_case.aggregation_job_id, None) + .unwrap() + .path()) + .with_request_header(auth_header, auth_value) + .with_request_header( + KnownHeaderName::ContentType, + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .with_request_body(test_case.aggregation_job_init_req.get_encoded().unwrap()) + .run_async(&test_case.handler) + .await; + + assert_eq!(response.status(), Some(Status::Created)); + } + + #[rstest::rstest] + #[case::not_bearer_token("wrong kind of token")] + #[case::not_base64("Bearer: ")] + #[tokio::test] + async fn aggregation_job_init_malformed_authorization_header( + #[case] header_value: &'static str, + ) { + let test_case = setup_aggregate_init_test_without_sending_request( + dummy::Vdaf::new(1), + VdafInstance::Fake { rounds: 1 }, + dummy::AggregationParam(0), + 0, + AuthenticationToken::Bearer(random()), + ) + .await; + + let response = put(test_case + .task + .aggregation_job_uri(&test_case.aggregation_job_id, None) + .unwrap() + .path()) + // Authenticate using a malformed "Authorization: Bearer " header and a 
`DAP-Auth-Token` + // header. The presence of the former should cause an error despite the latter being present and + // well formed. + .with_request_header(KnownHeaderName::Authorization, header_value.to_string()) + .with_request_header( + DAP_AUTH_HEADER, + test_case.task.aggregator_auth_token().as_ref().to_owned(), + ) + .with_request_header( + KnownHeaderName::ContentType, + AggregationJobInitializeReq::::MEDIA_TYPE, + ) + .with_request_body(test_case.aggregation_job_init_req.get_encoded().unwrap()) + .run_async(&test_case.handler) + .await; + + assert_eq!(response.status(), Some(Status::Forbidden)); + } + + #[tokio::test] + async fn aggregation_job_init_unexpected_taskbind_extension() { + let test_case = setup_aggregate_init_test_without_sending_request( + dummy::Vdaf::new(1), + VdafInstance::Fake { rounds: 1 }, + dummy::AggregationParam(0), + 0, + random(), + ) + .await; + + let prepare_init = test_case + .prepare_init_generator + .clone() + .with_private_extensions(Vec::from([Extension::new( + ExtensionType::Taskbind, + Vec::new(), + )])) + .next(&0) + .0; + let report_id = *prepare_init.report_share().metadata().id(); + let aggregation_job_init_req = AggregationJobInitializeReq::new( + dummy::AggregationParam(1).get_encoded().unwrap(), + PartialBatchSelector::new_time_interval(), + Vec::from([prepare_init]), + ); + + let mut response = put_aggregation_job( + &test_case.task, + &test_case.aggregation_job_id, + &aggregation_job_init_req, + &test_case.handler, + ) + .await; + assert_eq!(response.status(), Some(Status::Created)); + + let want_aggregation_job_resp = AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + report_id, + PrepareStepResult::Reject(ReportError::InvalidMessage), + )]), + }; + let got_aggregation_job_resp: AggregationJobResp = + decode_response_body(&mut response).await; + assert_eq!(want_aggregation_job_resp, got_aggregation_job_resp); + } + + #[tokio::test] + async fn aggregation_job_mutation_aggregation_job() { + let test_case = setup_aggregate_init_test().await; + + // Put the aggregation job again, but with a different aggregation parameter. 
+        let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new(
+            dummy::AggregationParam(1).get_encoded().unwrap(),
+            PartialBatchSelector::new_time_interval(),
+            test_case.aggregation_job_init_req.prepare_inits().to_vec(),
+        );
+
+        let response = put_aggregation_job(
+            &test_case.task,
+            &test_case.aggregation_job_id,
+            &mutated_aggregation_job_init_req,
+            &test_case.handler,
+        )
+        .await;
+        assert_eq!(response.status(), Some(Status::Conflict));
+    }
+
+    #[tokio::test]
+    async fn aggregation_job_mutation_report_shares() {
+        let test_case = setup_aggregate_init_test().await;
+
+        let prepare_inits = test_case.aggregation_job_init_req.prepare_inits();
+
+        // Put the aggregation job again, mutating the associated report shares' metadata such
+        // that uniqueness constraints on client_reports are violated.
+        for mutated_prepare_inits in [
+            // Omit a report share that was included previously.
+            Vec::from(&prepare_inits[0..prepare_inits.len() - 1]),
+            // Include a different report share than was included previously.
+            [
+                &prepare_inits[0..prepare_inits.len() - 1],
+                &[test_case.prepare_init_generator.next(&0).0],
+            ]
+            .concat(),
+            // Include an extra report share that was not included previously.
+            [
+                prepare_inits,
+                &[test_case.prepare_init_generator.next(&0).0],
+            ]
+            .concat(),
+            // Reverse the order of the reports.
+            prepare_inits.iter().rev().cloned().collect(),
+        ] {
+            let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new(
+                test_case.aggregation_param.get_encoded().unwrap(),
+                PartialBatchSelector::new_time_interval(),
+                mutated_prepare_inits,
+            );
+            let response = put_aggregation_job(
+                &test_case.task,
+                &test_case.aggregation_job_id,
+                &mutated_aggregation_job_init_req,
+                &test_case.handler,
+            )
+            .await;
+            assert_eq!(response.status(), Some(Status::Conflict));
+        }
+    }
+
+    #[tokio::test]
+    async fn aggregation_job_mutation_report_aggregations() {
+        // We set up a multi-step VDAF in this test so that the aggregation job won't finish on the
+        // first step.
+        let test_case = setup_multi_step_aggregate_init_test().await;
+
+        // Generate some new reports using the existing reports' metadata, but varying the
+        // measurement values such that the prepare state computed during aggregation
+        // initialization won't match the first aggregation job.
+        let mutated_prepare_inits = test_case
+            .aggregation_job_init_req
+            .prepare_inits()
+            .iter()
+            .map(|s| {
+                test_case
+                    .prepare_init_generator
+                    .next_with_metadata(s.report_share().metadata().clone(), &13)
+                    .0
+            })
+            .collect();
+
+        let mutated_aggregation_job_init_req = AggregationJobInitializeReq::new(
+            test_case.aggregation_param.get_encoded().unwrap(),
+            PartialBatchSelector::new_time_interval(),
+            mutated_prepare_inits,
+        );
+
+        let response = put_aggregation_job(
+            &test_case.task,
+            &test_case.aggregation_job_id,
+            &mutated_aggregation_job_init_req,
+            &test_case.handler,
+        )
+        .await;
+        assert_eq!(response.status(), Some(Status::Conflict));
+    }
+
+    #[tokio::test]
+    async fn aggregation_job_intolerable_clock_skew() {
+        let mut test_case = setup_aggregate_init_test_without_sending_request(
+            dummy::Vdaf::new(1),
+            VdafInstance::Fake { rounds: 1 },
+            dummy::AggregationParam(0),
+            0,
+            AuthenticationToken::Bearer(random()),
+        )
+        .await;
+
+        test_case.aggregation_job_init_req = AggregationJobInitializeReq::new(
+            test_case.aggregation_param.get_encoded().unwrap(),
+            PartialBatchSelector::new_time_interval(),
+            Vec::from([
+                // Barely tolerable.
+ test_case + .prepare_init_generator + .next_with_metadata( + ReportMetadata::new( + random(), + test_case + .clock + .now() + .add(test_case.task.tolerable_clock_skew()) + .unwrap(), + Vec::new(), + ), + &0, + ) + .0, + // Barely intolerable. + test_case + .prepare_init_generator + .next_with_metadata( + ReportMetadata::new( + random(), + test_case + .clock + .now() + .add(test_case.task.tolerable_clock_skew()) + .unwrap() + .add(&Duration::from_seconds(1)) + .unwrap(), + Vec::new(), + ), + &0, + ) + .0, + ]), + ); + + let mut response = put_aggregation_job( + &test_case.task, + &test_case.aggregation_job_id, + &test_case.aggregation_job_init_req, + &test_case.handler, + ) + .await; + assert_eq!(response.status(), Some(Status::Created)); + + let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; + let prepare_resps = assert_matches!( + aggregation_job_resp, + AggregationJobResp::Finished { prepare_resps } => prepare_resps + ); + assert_eq!( + prepare_resps.len(), + test_case.aggregation_job_init_req.prepare_inits().len(), + ); + assert_matches!( + prepare_resps[0].result(), + &PrepareStepResult::Continue { .. } + ); + assert_matches!( + prepare_resps[1].result(), + &PrepareStepResult::Reject(ReportError::ReportTooEarly) + ); + } + + #[tokio::test] + async fn aggregation_job_init_two_step_vdaf_idempotence() { + // We set up a multi-step VDAF in this test so that the aggregation job won't finish on the + // first step. + let test_case = setup_multi_step_aggregate_init_test().await; + + // Send the aggregation job init request again. We should get an identical response back. + let mut response = put_aggregation_job( + &test_case.task, + &test_case.aggregation_job_id, + &test_case.aggregation_job_init_req, + &test_case.handler, + ) + .await; + + let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await; + assert_eq!( + aggregation_job_resp, + test_case.aggregation_job_init_resp.unwrap(), + ); + } + + #[tokio::test] + async fn aggregation_job_init_wrong_query() { + let test_case = setup_aggregate_init_test().await; + + // setup_aggregate_init_test sets up a task with a time interval query. We send a + // leader-selected query which should yield an error. 
+        let wrong_query = AggregationJobInitializeReq::new(
+            test_case.aggregation_param.get_encoded().unwrap(),
+            PartialBatchSelector::new_leader_selected(random()),
+            test_case.aggregation_job_init_req.prepare_inits().to_vec(),
+        );
+
+        let (header, value) = test_case
+            .task
+            .aggregator_auth_token()
+            .request_authentication();
+
+        let mut response = put(test_case
+            .task
+            .aggregation_job_uri(&random(), None)
+            .unwrap()
+            .path())
+        .with_request_header(header, value)
+        .with_request_header(
+            KnownHeaderName::ContentType,
+            AggregationJobInitializeReq::<TimeInterval>::MEDIA_TYPE,
+        )
+        .with_request_body(wrong_query.get_encoded().unwrap())
+        .run_async(&test_case.handler)
+        .await;
+        assert_eq!(
+            take_problem_details(&mut response).await,
+            json!({
+                "status": StatusCode::BAD_REQUEST.as_u16(),
+                "type": "urn:ietf:params:ppm:dap:error:invalidMessage",
+                "title": "The message type for a response was incorrect or the payload was malformed.",
+            }),
+        );
+    }
+}
diff --git a/aggregator/src/aggregator/aggregation_job_writer.rs b/aggregator/src/aggregator/aggregation_job_writer.rs
index 0c0201542..18a6eb94e 100644
--- a/aggregator/src/aggregator/aggregation_job_writer.rs
+++ b/aggregator/src/aggregator/aggregation_job_writer.rs
@@ -344,7 +344,9 @@ impl WriteType for InitialWrite {
     {
         // For new writes (inserts) of aggregation jobs in a non-terminal state, increment
         // aggregation_jobs_created.
-        if aggregation_job.state() == &AggregationJobState::InProgress {
+        if aggregation_job.state() == &AggregationJobState::Active
+            || aggregation_job.state() == &AggregationJobState::AwaitingRequest
+        {
             *batch_aggregation = BatchAggregation::new(
                 *batch_aggregation.task_id(),
                 batch_aggregation.batch_identifier().clone(),
@@ -404,7 +406,9 @@ impl WriteType for UpdateWrite {
         // For updates of aggregation jobs into a terminal state, increment
         // aggregation_jobs_terminated. (This is safe to do since we will not process a terminal
         // aggregation job again.)
-        if aggregation_job.state() != &AggregationJobState::InProgress {
+        if aggregation_job.state() != &AggregationJobState::Active
+            && aggregation_job.state() != &AggregationJobState::AwaitingRequest
+        {
             *batch_aggregation = BatchAggregation::new(
                 *batch_aggregation.task_id(),
                 batch_aggregation.batch_identifier().clone(),
@@ -946,6 +950,11 @@ impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>>
             output_share,
         }
     }
+
+    /// Retrieves the report aggregation to be written.
+    pub fn report_aggregation(&self) -> &ReportAggregation<SEED_SIZE, A> {
+        &self.report_aggregation
+    }
 }
 
 /// Abstracts over multiple representations of a report aggregation.
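For reference, the two WriteType hunks above encode the same invariant from opposite directions: an aggregation job is non-terminal exactly when its state is Active or AwaitingRequest, the two states that replace the old InProgress. A minimal sketch of that rule as a predicate; `is_non_terminal` is hypothetical and not part of this change, and it assumes the AggregationJobState enum from janus_aggregator_core::datastore::models:

    use janus_aggregator_core::datastore::models::AggregationJobState;

    /// True for states in which an aggregation job may still be processed.
    /// InitialWrite increments aggregation_jobs_created when this holds;
    /// UpdateWrite increments aggregation_jobs_terminated once it no longer does.
    fn is_non_terminal(state: &AggregationJobState) -> bool {
        matches!(
            state,
            AggregationJobState::Active | AggregationJobState::AwaitingRequest
        )
    }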
diff --git a/aggregator/src/aggregator/batch_creator.rs b/aggregator/src/aggregator/batch_creator.rs index 3fda9eb26..25ea76151 100644 --- a/aggregator/src/aggregator/batch_creator.rs +++ b/aggregator/src/aggregator/batch_creator.rs @@ -369,7 +369,7 @@ where (), batch_id, client_timestamp_interval, - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); aggregation_job_writer.put(aggregation_job, report_aggregations)?; diff --git a/aggregator/src/aggregator/collection_job_driver.rs b/aggregator/src/aggregator/collection_job_driver.rs index a23d5fd1c..dca7e5099 100644 --- a/aggregator/src/aggregator/collection_job_driver.rs +++ b/aggregator/src/aggregator/collection_job_driver.rs @@ -810,7 +810,7 @@ mod tests { }, task::{ test_util::{Task, TaskBuilder}, - BatchMode, + AggregationMode, BatchMode, }, test_util::noop_meter, }; @@ -845,11 +845,15 @@ mod tests { CollectionJob<0, TimeInterval, dummy::Vdaf>, ) { let time_precision = Duration::from_seconds(500); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .with_time_precision(time_precision) - .with_min_batch_size(10) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .with_time_precision(time_precision) + .with_min_batch_size(10) + .build(); let leader_task = task.leader_view().unwrap(); let batch_interval = Interval::new(clock.now(), Duration::from_seconds(2000)).unwrap(); @@ -985,11 +989,15 @@ mod tests { let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); let time_precision = Duration::from_seconds(500); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_helper_aggregator_endpoint(server.url().parse().unwrap()) - .with_time_precision(time_precision) - .with_min_batch_size(10) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_helper_aggregator_endpoint(server.url().parse().unwrap()) + .with_time_precision(time_precision) + .with_min_batch_size(10) + .build(); let leader_task = task.leader_view().unwrap(); let agg_auth_token = task.aggregator_auth_token(); diff --git a/aggregator/src/aggregator/collection_job_tests.rs b/aggregator/src/aggregator/collection_job_tests.rs index 4dd49520a..862c7052c 100644 --- a/aggregator/src/aggregator/collection_job_tests.rs +++ b/aggregator/src/aggregator/collection_job_tests.rs @@ -16,7 +16,7 @@ use janus_aggregator_core::{ }, task::{ test_util::{Task, TaskBuilder}, - BatchMode, + AggregationMode, BatchMode, }, test_util::noop_meter, }; @@ -259,7 +259,12 @@ pub(crate) async fn setup_collection_job_test_case( ) -> CollectionJobTestCase { install_test_trace_subscriber(); - let task = TaskBuilder::new(batch_mode, VdafInstance::Fake { rounds: 1 }).build(); + let task = TaskBuilder::new( + batch_mode, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); let role_task = task.view_for_role(role).unwrap(); let clock = MockClock::default(); let ephemeral_datastore = ephemeral_datastore().await; diff --git a/aggregator/src/aggregator/error.rs b/aggregator/src/aggregator/error.rs index bb73a7fe0..f004995e2 100644 --- a/aggregator/src/aggregator/error.rs +++ b/aggregator/src/aggregator/error.rs @@ -60,6 +60,9 @@ pub enum 
Error { /// An attempt was made to act on an unknown aggregation job. #[error("task {0}: unrecognized aggregation job: {1}")] UnrecognizedAggregationJob(TaskId, AggregationJobId), + /// An attempt was made to act on a known but abandoned aggregation job. + #[error("task {0}: abandoned aggregation job: {1}")] + AbandonedAggregationJob(TaskId, AggregationJobId), /// An attempt was made to act on a known but deleted aggregation job. #[error("task {0}: deleted aggregation job: {1}")] DeletedAggregationJob(TaskId, AggregationJobId), @@ -291,6 +294,7 @@ impl Error { Error::UnrecognizedTask(_) => "unrecognized_task", Error::MissingTaskId => "missing_task_id", Error::UnrecognizedAggregationJob(_, _) => "unrecognized_aggregation_job", + Error::AbandonedAggregationJob(_, _) => "abandoned_aggregation_job", Error::DeletedAggregationJob(_, _) => "deleted_aggregation_job", Error::DeletedCollectionJob(_, _) => "deleted_collection_job", Error::AbandonedCollectionJob(_, _) => "abandoned_collection_job", diff --git a/aggregator/src/aggregator/garbage_collector.rs b/aggregator/src/aggregator/garbage_collector.rs index 4343ac455..4a3cc247b 100644 --- a/aggregator/src/aggregator/garbage_collector.rs +++ b/aggregator/src/aggregator/garbage_collector.rs @@ -182,7 +182,7 @@ mod tests { }, test_util::ephemeral_datastore, }, - task::{self, test_util::TaskBuilder}, + task::{self, test_util::TaskBuilder, AggregationMode}, test_util::noop_meter, }; use janus_core::{ @@ -219,6 +219,7 @@ mod tests { Box::pin(async move { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -240,7 +241,7 @@ mod tests { aggregation_param, (), Interval::from_time(&client_timestamp).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -369,9 +370,11 @@ mod tests { let task = ds .run_unnamed_tx(|tx| { let clock = clock.clone(); + Box::pin(async move { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -391,9 +394,13 @@ mod tests { Vec::from("payload_0"), ), ); - tx.put_scrubbed_report(task.id(), &report_share) - .await - .unwrap(); + tx.put_scrubbed_report( + task.id(), + report_share.metadata().id(), + report_share.metadata().time(), + ) + .await + .unwrap(); // Aggregation artifacts. 
let aggregation_job_id = random(); @@ -403,7 +410,7 @@ mod tests { aggregation_param, (), Interval::from_time(&client_timestamp).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -548,6 +555,7 @@ mod tests { task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -574,7 +582,7 @@ mod tests { aggregation_param, batch_id, Interval::from_time(&client_timestamp).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); tx.put_aggregation_job(&aggregation_job).await.unwrap(); @@ -713,11 +721,13 @@ mod tests { let task = ds .run_unnamed_tx(|tx| { let clock = clock.clone(); + Box::pin(async move { let task = TaskBuilder::new( task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -742,9 +752,13 @@ mod tests { Vec::from("payload_0"), ), ); - tx.put_scrubbed_report(task.id(), &report_share) - .await - .unwrap(); + tx.put_scrubbed_report( + task.id(), + report_share.metadata().id(), + report_share.metadata().time(), + ) + .await + .unwrap(); // Aggregation artifacts. let batch_id = random(); @@ -754,7 +768,7 @@ mod tests { aggregation_param, batch_id, Interval::from_time(&client_timestamp).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); tx.put_aggregation_job(&aggregation_job).await.unwrap(); diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index 1a25df3c2..37c79d687 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -20,14 +20,15 @@ use janus_core::{ use janus_messages::{ batch_mode::TimeInterval, codec::Decode, problem_type::DapProblemType, taskprov::TaskConfig, AggregateShare, AggregateShareReq, AggregationJobContinueReq, AggregationJobId, - AggregationJobInitializeReq, AggregationJobResp, CollectionJobId, CollectionJobReq, - CollectionJobResp, HpkeConfigList, Report, TaskId, + AggregationJobInitializeReq, AggregationJobResp, AggregationJobStep, CollectionJobId, + CollectionJobReq, CollectionJobResp, HpkeConfigList, Report, TaskId, }; use opentelemetry::{ metrics::{Counter, Meter}, KeyValue, }; use prio::codec::Encode; +use querystring::querify; use serde::{Deserialize, Serialize}; use std::sync::Arc; use std::{borrow::Cow, time::Duration as StdDuration}; @@ -90,6 +91,15 @@ async fn run_error_handler(error: &Error, mut conn: Conn) -> Conn { &ProblemDocument::new_dap(DapProblemType::UnrecognizedAggregationJob) .with_task_id(task_id), ), + Error::AbandonedAggregationJob(task_id, aggregation_job_id) => conn.with_problem_document( + &ProblemDocument::new( + "https://docs.divviup.org/references/janus-errors#aggregation-job-abandoned", + "The aggregation job has been abandoned.", + Status::Gone, + ) + .with_task_id(task_id) + .with_aggregation_job_id(aggregation_job_id), + ), Error::DeletedAggregationJob(task_id, aggregation_job_id) => conn.with_problem_document( &ProblemDocument::new( "https://docs.divviup.org/references/janus-errors#aggregation-job-deleted", @@ -402,6 +412,10 @@ where Box::new(api(aggregation_jobs_post::)) as Box }), ) + .get( + AGGREGATION_JOB_ROUTE, + instrumented(api(aggregation_jobs_get::)), + ) .delete( AGGREGATION_JOB_ROUTE, 
                instrumented(api(aggregation_jobs_delete::<C>)),
@@ -592,6 +606,32 @@ async fn aggregation_jobs_post<C: Clock>(
     Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE).with_status(Status::Accepted))
 }
 
+/// API handler for the "/tasks/.../aggregation_jobs/..." GET endpoint.
+async fn aggregation_jobs_get<C: Clock>(
+    conn: &mut Conn,
+    State(aggregator): State<Arc<Aggregator<C>>>,
+) -> Result<EncodedBody<AggregationJobResp>, Error> {
+    let task_id = parse_task_id(conn)?;
+    let aggregation_job_id = parse_aggregation_job_id(conn)?;
+    let auth_token = parse_auth_token(&task_id, conn)?;
+    let taskprov_task_config = parse_taskprov_header(&aggregator, &task_id, conn)?;
+    let step = parse_step(conn)?
+        .ok_or_else(|| Error::BadRequest("missing step query parameter".to_string()))?;
+
+    let response = conn
+        .cancel_on_disconnect(aggregator.handle_aggregate_get(
+            &task_id,
+            &aggregation_job_id,
+            auth_token,
+            taskprov_task_config.as_ref(),
+            step,
+        ))
+        .await
+        .ok_or(Error::ClientDisconnected)??;
+
+    Ok(EncodedBody::new(response, AggregationJobResp::MEDIA_TYPE).with_status(Status::Ok))
+}
+
 /// API handler for the "/tasks/.../aggregation_jobs/..." DELETE endpoint.
 async fn aggregation_jobs_delete<C: Clock>(
     conn: &mut Conn,
@@ -811,6 +851,17 @@ fn parse_taskprov_header<C: Clock>(
     ))
 }
 
+/// Gets the [`AggregationJobStep`] from the request's query string.
+fn parse_step(conn: &Conn) -> Result<Option<AggregationJobStep>, Error> {
+    const STEP_KEY: &str = "step";
+    querify(conn.querystring())
+        .into_iter()
+        .find(|(key, _)| *key == STEP_KEY)
+        .map(|(_, val)| val.parse::<u16>().map(AggregationJobStep::from))
+        .transpose()
+        .map_err(|err| Error::BadRequest(format!("couldn't parse step: {err}")))
+}
+
 struct BodyBytes(Vec<u8>);
 
 #[async_trait]
diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregate_share.rs b/aggregator/src/aggregator/http_handlers/tests/aggregate_share.rs
index 633efce55..7abbcd507 100644
--- a/aggregator/src/aggregator/http_handlers/tests/aggregate_share.rs
+++ b/aggregator/src/aggregator/http_handlers/tests/aggregate_share.rs
@@ -9,7 +9,7 @@ use janus_aggregator_core::{
     datastore::models::{BatchAggregation, BatchAggregationState},
     task::{
         test_util::{Task, TaskBuilder},
-        BatchMode,
+        AggregationMode, BatchMode,
     },
 };
 use janus_core::{
@@ -58,7 +58,12 @@ async fn aggregate_share_request_to_leader() {
     } = HttpHandlerTest::new().await;
 
     // Prepare parameters.
-    let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }).build();
+    let task = TaskBuilder::new(
+        BatchMode::TimeInterval,
+        AggregationMode::Synchronous,
+        VdafInstance::Fake { rounds: 1 },
+    )
+    .build();
     let leader_task = task.leader_view().unwrap();
     datastore.put_aggregator_task(&leader_task).await.unwrap();
 
@@ -97,9 +102,13 @@ async fn aggregate_share_request_invalid_batch_interval() {
 
     // Prepare parameters.
     const REPORT_EXPIRY_AGE: Duration = Duration::from_seconds(3600);
-    let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 })
-        .with_report_expiry_age(Some(REPORT_EXPIRY_AGE))
-        .build();
+    let task = TaskBuilder::new(
+        BatchMode::TimeInterval,
+        AggregationMode::Synchronous,
+        VdafInstance::Fake { rounds: 1 },
+    )
+    .with_report_expiry_age(Some(REPORT_EXPIRY_AGE))
+    .build();
     let helper_task = task.helper_view().unwrap();
     datastore.put_aggregator_task(&helper_task).await.unwrap();
 
@@ -159,10 +168,14 @@ async fn aggregate_share_request() {
         ..
} = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_time_precision(Duration::from_seconds(500)) - .with_min_batch_size(10) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_time_precision(Duration::from_seconds(500)) + .with_min_batch_size(10) + .build(); let helper_task = task.helper_view().unwrap(); datastore.put_aggregator_task(&helper_task).await.unwrap(); diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs index c26b18e7e..29a12b198 100644 --- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_continue.rs @@ -17,7 +17,7 @@ use janus_aggregator_core::{ merge_batch_aggregations_by_batch, AggregationJob, AggregationJobState, BatchAggregation, BatchAggregationState, ReportAggregation, ReportAggregationState, TaskAggregationCounter, }, - task::{test_util::TaskBuilder, BatchMode, VerifyKey}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode, VerifyKey}, }; use janus_core::{ report_id::ReportIdChecksumExt, @@ -27,9 +27,8 @@ use janus_core::{ }; use janus_messages::{ batch_mode::TimeInterval, AggregationJobContinueReq, AggregationJobResp, AggregationJobStep, - Duration, HpkeCiphertext, HpkeConfigId, Interval, PrepareContinue, PrepareResp, - PrepareStepResult, ReportError, ReportId, ReportIdChecksum, ReportMetadata, ReportShare, Role, - Time, + Duration, Interval, PrepareContinue, PrepareResp, PrepareStepResult, ReportError, ReportId, + ReportIdChecksum, ReportMetadata, Role, Time, }; use prio::{ topology::ping_pong::PingPongMessage, @@ -40,7 +39,7 @@ use std::sync::Arc; use trillium::Status; #[tokio::test] -async fn aggregate_continue() { +async fn aggregate_continue_sync() { let HttpHandlerTest { clock, ephemeral_datastore: _ephemeral_datastore, @@ -51,7 +50,12 @@ async fn aggregate_continue() { } = HttpHandlerTest::new().await; let aggregation_job_id = random(); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let vdaf = Arc::new(dummy::Vdaf::new(2)); @@ -148,51 +152,57 @@ async fn aggregate_continue() { datastore .run_unnamed_tx(|tx| { - let task = helper_task.clone(); - let (report_share_0, report_share_1, report_share_2) = ( - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - ); - let (helper_prep_state_0, helper_prep_state_1, helper_prep_state_2) = ( - *helper_prep_state_0, - *helper_prep_state_1, - *helper_prep_state_2, - ); - let (report_metadata_0, report_metadata_1, report_metadata_2) = ( - report_metadata_0.clone(), - report_metadata_1.clone(), - report_metadata_2.clone(), - ); + let helper_task = helper_task.clone(); + let report_share_0 = report_share_0.clone(); + let report_share_1 = report_share_1.clone(); + let report_share_2 = report_share_2.clone(); + let helper_prep_state_0 = *helper_prep_state_0; + let helper_prep_state_1 = *helper_prep_state_1; + let helper_prep_state_2 = *helper_prep_state_2; + let report_metadata_0 = report_metadata_0.clone(); + let report_metadata_1 = report_metadata_1.clone(); + let report_metadata_2 = 
report_metadata_2.clone(); Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_0) - .await - .unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_1) - .await - .unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_2) - .await - .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_0.metadata().id(), + report_share_0.metadata().time(), + ) + .await + .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_1.metadata().id(), + report_share_1.metadata().time(), + ) + .await + .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_2.metadata().id(), + report_share_2.metadata().time(), + ) + .await + .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::AwaitingRequest, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata_0.id(), *report_metadata_0.time(), @@ -205,7 +215,7 @@ async fn aggregate_continue() { .await .unwrap(); tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata_1.id(), *report_metadata_1.time(), @@ -218,7 +228,7 @@ async fn aggregate_continue() { .await .unwrap(); tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata_2.id(), *report_metadata_2.time(), @@ -235,10 +245,13 @@ async fn aggregate_continue() { // into, which will cause it to fail to prepare. try_join_all( empty_batch_aggregations::<0, TimeInterval, dummy::Vdaf>( - &task, + &helper_task, BATCH_AGGREGATION_SHARD_COUNT, - &Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()) - .unwrap(), + &Interval::new( + Time::from_seconds_since_epoch(0), + *helper_task.time_precision(), + ) + .unwrap(), &aggregation_param, &[], ) @@ -373,6 +386,258 @@ async fn aggregate_continue() { .await; } +#[tokio::test] +async fn aggregate_continue_async() { + let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + handler, + hpke_keypair: hpke_key, + .. + } = HttpHandlerTest::new().await; + + let aggregation_job_id = random(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + + let vdaf = Arc::new(dummy::Vdaf::new(2)); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let measurement = 13; + let aggregation_param = dummy::AggregationParam(7); + + // report_share_0 is a "happy path" report. 
+ let report_metadata_0 = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + Vec::new(), + ); + let transcript_0 = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata_0.id(), + &measurement, + ); + let helper_prep_state_0 = transcript_0.helper_prepare_transitions[0].prepare_state(); + let leader_prep_message_0 = &transcript_0.leader_prepare_transitions[1].message; + let report_share_0 = generate_helper_report_share::( + *task.id(), + report_metadata_0.clone(), + hpke_key.config(), + &transcript_0.public_share, + Vec::new(), + &transcript_0.helper_input_share, + ); + + // report_share_1 is omitted by the leader's request. + let report_metadata_1 = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + Vec::new(), + ); + let transcript_1 = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata_1.id(), + &measurement, + ); + + let helper_prep_state_1 = transcript_1.helper_prepare_transitions[0].prepare_state(); + let report_share_1 = generate_helper_report_share::( + *task.id(), + report_metadata_1.clone(), + hpke_key.config(), + &transcript_1.public_share, + Vec::new(), + &transcript_1.helper_input_share, + ); + + datastore + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_share_0 = report_share_0.clone(); + let report_share_1 = report_share_1.clone(); + let helper_prep_state_0 = *helper_prep_state_0; + let helper_prep_state_1 = *helper_prep_state_1; + let report_metadata_0 = report_metadata_0.clone(); + let report_metadata_1 = report_metadata_1.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&helper_task).await.unwrap(); + + tx.put_scrubbed_report( + helper_task.id(), + report_share_0.metadata().id(), + report_share_0.metadata().time(), + ) + .await + .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_1.metadata().id(), + report_share_1.metadata().time(), + ) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new( + *helper_task.id(), + aggregation_job_id, + *report_metadata_0.id(), + *report_metadata_0.time(), + 0, + None, + ReportAggregationState::HelperContinue { + prepare_state: helper_prep_state_0, + }, + )) + .await + .unwrap(); + tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new( + *helper_task.id(), + aggregation_job_id, + *report_metadata_1.id(), + *report_metadata_1.time(), + 1, + None, + ReportAggregationState::HelperContinue { + prepare_state: helper_prep_state_1, + }, + )) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + let request = AggregationJobContinueReq::new( + AggregationJobStep::from(1), + Vec::from([PrepareContinue::new( + *report_metadata_0.id(), + leader_prep_message_0.clone(), + )]), + ); + + // Send request, and parse response. + let aggregate_resp = + post_aggregation_job_and_decode(&task, &aggregation_job_id, &request, &handler).await; + + // Validate response. + assert_eq!(aggregate_resp, AggregationJobResp::Processing); + + // Validate datastore. 
+ let (aggregation_job, report_aggregations) = datastore + .run_unnamed_tx(|tx| { + let (vdaf, task) = (Arc::clone(&vdaf), task.clone()); + Box::pin(async move { + let aggregation_job = tx + .get_aggregation_job::<0, TimeInterval, dummy::Vdaf>( + task.id(), + &aggregation_job_id, + ) + .await + .unwrap() + .unwrap(); + let report_aggregations = tx + .get_report_aggregations_for_aggregation_job( + vdaf.as_ref(), + &Role::Helper, + task.id(), + &aggregation_job_id, + &aggregation_param, + ) + .await + .unwrap(); + Ok((aggregation_job, report_aggregations)) + }) + }) + .await + .unwrap(); + + assert_eq!( + aggregation_job, + AggregationJob::new( + *task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(1), + ) + .with_last_request_hash(aggregation_job.last_request_hash().unwrap()) + ); + assert_eq!( + report_aggregations, + Vec::from([ + ReportAggregation::new( + *task.id(), + aggregation_job_id, + *report_metadata_0.id(), + *report_metadata_0.time(), + 0, + None, + ReportAggregationState::HelperContinueProcessing { + prepare_state: *helper_prep_state_0, + prepare_continue: PrepareContinue::new( + *report_metadata_0.id(), + leader_prep_message_0.clone() + ), + }, + ), + ReportAggregation::new( + *task.id(), + aggregation_job_id, + *report_metadata_1.id(), + *report_metadata_1.time(), + 1, + None, + ReportAggregationState::Failed { + report_error: ReportError::ReportDropped + }, + ), + ]) + ); + + assert_task_aggregation_counter( + &datastore, + *task.id(), + TaskAggregationCounter::new_with_values(0), + ) + .await; +} + #[tokio::test] async fn aggregate_continue_accumulate_batch_aggregation() { let HttpHandlerTest { @@ -383,7 +648,12 @@ async fn aggregate_continue_accumulate_batch_aggregation() { .. 
} = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let aggregation_job_id_0 = random(); let aggregation_job_id_1 = random(); @@ -525,52 +795,58 @@ async fn aggregate_continue_accumulate_batch_aggregation() { datastore .run_unnamed_tx(|tx| { - let task = helper_task.clone(); - let (report_share_0, report_share_1, report_share_2) = ( - report_share_0.clone(), - report_share_1.clone(), - report_share_2.clone(), - ); - let (helper_prep_state_0, helper_prep_state_1, helper_prep_state_2) = ( - *helper_prep_state_0, - *helper_prep_state_1, - *helper_prep_state_2, - ); - let (report_metadata_0, report_metadata_1, report_metadata_2) = ( - report_metadata_0.clone(), - report_metadata_1.clone(), - report_metadata_2.clone(), - ); + let helper_task = helper_task.clone(); + let report_share_0 = report_share_0.clone(); + let report_share_1 = report_share_1.clone(); + let report_share_2 = report_share_2.clone(); + let helper_prep_state_0 = *helper_prep_state_0; + let helper_prep_state_1 = *helper_prep_state_1; + let helper_prep_state_2 = *helper_prep_state_2; + let report_metadata_0 = report_metadata_0.clone(); + let report_metadata_1 = report_metadata_1.clone(); + let report_metadata_2 = report_metadata_2.clone(); let second_batch_want_batch_aggregations = second_batch_want_batch_aggregations.clone(); Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_0) - .await - .unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_1) - .await - .unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_2) - .await - .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_0.metadata().id(), + report_share_0.metadata().time(), + ) + .await + .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_1.metadata().id(), + report_share_1.metadata().time(), + ) + .await + .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_2.metadata().id(), + report_share_2.metadata().time(), + ) + .await + .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_0, aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_0, *report_metadata_0.id(), *report_metadata_0.time(), @@ -583,7 +859,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_0, *report_metadata_1.id(), *report_metadata_1.time(), @@ -596,7 +872,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_0, *report_metadata_2.id(), *report_metadata_2.time(), @@ -610,7 +886,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { .unwrap(); 
tx.put_batch_aggregation(&BatchAggregation::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), first_batch_identifier, aggregation_param, 0, @@ -836,49 +1112,55 @@ async fn aggregate_continue_accumulate_batch_aggregation() { datastore .run_unnamed_tx(|tx| { - let task = helper_task.clone(); - let (report_share_3, report_share_4, report_share_5) = ( - report_share_3.clone(), - report_share_4.clone(), - report_share_5.clone(), - ); - let (helper_prep_state_3, helper_prep_state_4, helper_prep_state_5) = ( - *helper_prep_state_3, - *helper_prep_state_4, - *helper_prep_state_5, - ); - let (report_metadata_3, report_metadata_4, report_metadata_5) = ( - report_metadata_3.clone(), - report_metadata_4.clone(), - report_metadata_5.clone(), - ); + let helper_task = helper_task.clone(); + let report_share_3 = report_share_3.clone(); + let report_share_4 = report_share_4.clone(); + let report_share_5 = report_share_5.clone(); + let helper_prep_state_3 = *helper_prep_state_3; + let helper_prep_state_4 = *helper_prep_state_4; + let helper_prep_state_5 = *helper_prep_state_5; + let report_metadata_3 = report_metadata_3.clone(); + let report_metadata_4 = report_metadata_4.clone(); + let report_metadata_5 = report_metadata_5.clone(); Box::pin(async move { - tx.put_scrubbed_report(task.id(), &report_share_3) - .await - .unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_4) - .await - .unwrap(); - tx.put_scrubbed_report(task.id(), &report_share_5) - .await - .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_3.metadata().id(), + report_share_3.metadata().time(), + ) + .await + .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_4.metadata().id(), + report_share_4.metadata().time(), + ) + .await + .unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + report_share_5.metadata().id(), + report_share_5.metadata().time(), + ) + .await + .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_1, aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_1, *report_metadata_3.id(), *report_metadata_3.time(), @@ -891,7 +1173,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_1, *report_metadata_4.id(), *report_metadata_4.time(), @@ -904,7 +1186,7 @@ async fn aggregate_continue_accumulate_batch_aggregation() { .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id_1, *report_metadata_5.id(), *report_metadata_5.time(), @@ -1049,7 +1331,12 @@ async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { } = HttpHandlerTest::new().await; // Prepare parameters. 
- let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let report_id = random(); let aggregation_param = dummy::AggregationParam(7); @@ -1071,43 +1358,34 @@ async fn aggregate_continue_leader_sends_non_continue_or_finish_transition() { // Setup datastore. datastore .run_unnamed_tx(|tx| { - let (task, aggregation_param, report_metadata, transcript) = ( - helper_task.clone(), - aggregation_param, - report_metadata.clone(), - transcript.clone(), - ); + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let transcript = transcript.clone(); + Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); tx.put_scrubbed_report( - task.id(), - &ReportShare::new( - report_metadata.clone(), - Vec::from("public share"), - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ), + helper_task.id(), + report_metadata.id(), + report_metadata.time(), ) .await .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata.id(), *report_metadata.time(), @@ -1167,7 +1445,12 @@ async fn aggregate_continue_prep_step_fails() { } = HttpHandlerTest::new().await; // Prepare parameters. - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let vdaf = dummy::Vdaf::new(2); let report_id = random(); @@ -1195,33 +1478,34 @@ async fn aggregate_continue_prep_step_fails() { // Setup datastore. 
datastore .run_unnamed_tx(|tx| { - let (task, aggregation_param, report_metadata, transcript, helper_report_share) = ( - helper_task.clone(), - aggregation_param, - report_metadata.clone(), - transcript.clone(), - helper_report_share.clone(), - ); + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let transcript = transcript.clone(); + let helper_report_share = helper_report_share.clone(); Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); - tx.put_scrubbed_report(task.id(), &helper_report_share) - .await - .unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); + tx.put_scrubbed_report( + helper_task.id(), + helper_report_share.metadata().id(), + helper_report_share.metadata().time(), + ) + .await + .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata.id(), *report_metadata.time(), @@ -1342,7 +1626,12 @@ async fn aggregate_continue_unexpected_transition() { } = HttpHandlerTest::new().await; // Prepare parameters. - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let report_id = random(); let aggregation_param = dummy::AggregationParam(7); @@ -1361,43 +1650,33 @@ async fn aggregate_continue_unexpected_transition() { // Setup datastore. datastore .run_unnamed_tx(|tx| { - let (task, aggregation_param, report_metadata, transcript) = ( - helper_task.clone(), - aggregation_param, - report_metadata.clone(), - transcript.clone(), - ); + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let transcript = transcript.clone(); Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); tx.put_scrubbed_report( - task.id(), - &ReportShare::new( - report_metadata.clone(), - Vec::from("PUBLIC"), - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ), + helper_task.id(), + report_metadata.id(), + report_metadata.time(), ) .await .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata.id(), *report_metadata.time(), @@ -1457,7 +1736,12 @@ async fn aggregate_continue_out_of_order_transition() { } = HttpHandlerTest::new().await; // Prepare parameters. 
- let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 2 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 2 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let report_id_0 = random(); let aggregation_param = dummy::AggregationParam(7); @@ -1493,69 +1777,45 @@ async fn aggregate_continue_out_of_order_transition() { // Setup datastore. datastore .run_unnamed_tx(|tx| { - let ( - task, - aggregation_param, - report_metadata_0, - report_metadata_1, - transcript_0, - transcript_1, - ) = ( - helper_task.clone(), - aggregation_param, - report_metadata_0.clone(), - report_metadata_1.clone(), - transcript_0.clone(), - transcript_1.clone(), - ); + let helper_task = helper_task.clone(); + let report_metadata_0 = report_metadata_0.clone(); + let report_metadata_1 = report_metadata_1.clone(); + let transcript_0 = transcript_0.clone(); + let transcript_1 = transcript_1.clone(); Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); tx.put_scrubbed_report( - task.id(), - &ReportShare::new( - report_metadata_0.clone(), - Vec::from("public"), - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ), + helper_task.id(), + report_metadata_0.id(), + report_metadata_0.time(), ) .await .unwrap(); tx.put_scrubbed_report( - task.id(), - &ReportShare::new( - report_metadata_1.clone(), - Vec::from("public"), - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ), + helper_task.id(), + report_metadata_1.id(), + report_metadata_1.time(), ) .await .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, aggregation_param, (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata_0.id(), *report_metadata_0.time(), @@ -1568,7 +1828,7 @@ async fn aggregate_continue_out_of_order_transition() { .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata_1.id(), *report_metadata_1.time(), @@ -1635,7 +1895,12 @@ async fn aggregate_continue_for_non_waiting_aggregation() { } = HttpHandlerTest::new().await; // Prepare parameters. - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let aggregation_job_id = random(); let report_metadata = ReportMetadata::new( @@ -1647,37 +1912,32 @@ async fn aggregate_continue_for_non_waiting_aggregation() { // Setup datastore. 
datastore .run_unnamed_tx(|tx| { - let (task, report_metadata) = (helper_task.clone(), report_metadata.clone()); + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); tx.put_scrubbed_report( - task.id(), - &ReportShare::new( - report_metadata.clone(), - Vec::from("public share"), - HpkeCiphertext::new( - HpkeConfigId::from(42), - Vec::from("012345"), - Vec::from("543210"), - ), - ), + helper_task.id(), + report_metadata.id(), + report_metadata.time(), ) .await .unwrap(); tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, dummy::AggregationParam(0), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( - *task.id(), + *helper_task.id(), aggregation_job_id, *report_metadata.id(), *report_metadata.time(), diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_get.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_get.rs new file mode 100644 index 000000000..28c8ee72b --- /dev/null +++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_get.rs @@ -0,0 +1,609 @@ +use crate::aggregator::{ + http_handlers::test_util::{decode_response_body, HttpHandlerTest}, + test_util::generate_helper_report_share, +}; +use janus_aggregator_core::{ + datastore::models::{ + AggregationJob, AggregationJobState, ReportAggregation, ReportAggregationState, + }, + task::{ + test_util::{Task, TaskBuilder}, + AggregationMode, BatchMode, VerifyKey, + }, +}; +use janus_core::{ + test_util::run_vdaf, + time::{Clock as _, TimeExt as _}, + vdaf::VdafInstance, +}; +use janus_messages::{ + batch_mode::TimeInterval, AggregationJobId, AggregationJobResp, AggregationJobStep, Duration, + Interval, PrepareInit, PrepareResp, PrepareStepResult, ReportMetadata, +}; +use prio::vdaf::dummy; +use rand::random; +use std::sync::Arc; +use trillium::{Handler, Status}; +use trillium_testing::{assert_headers, prelude::get, TestConn}; + +#[tokio::test] +async fn aggregation_job_get_ready() { + // Prepare state. + let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + handler, + .. 
+ } = HttpHandlerTest::new().await; + + let aggregation_job_id = random(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + + let vdaf = Arc::new(dummy::Vdaf::new(1)); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let measurement = 13; + let aggregation_param = dummy::AggregationParam(7); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + Vec::new(), + ); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &measurement, + ); + let helper_message = &transcript.helper_prepare_transitions[0].message; + + datastore + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let helper_message = helper_message.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&helper_task).await.unwrap(); + + tx.put_scrubbed_report( + helper_task.id(), + report_metadata.id(), + report_metadata.time(), + ) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(*report_metadata.time(), Duration::from_seconds(1)).unwrap(), + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + *report_metadata.id(), + *report_metadata.time(), + 0, + Some(PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Continue { + message: helper_message, + }, + )), + ReportAggregationState::Finished, + )) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + // Send request. + let aggregate_resp = get_aggregation_job_and_decode( + &task, + &aggregation_job_id, + Some(AggregationJobStep::from(0)), + &handler, + ) + .await; + + // Validate result. + assert_eq!( + aggregate_resp, + AggregationJobResp::Finished { + prepare_resps: Vec::from([PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Continue { + message: helper_message.clone(), + } + )]) + } + ); +} + +#[tokio::test] +async fn aggregation_job_get_unready() { + // Prepare state. 
+ let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + handler, + hpke_keypair, + } = HttpHandlerTest::new().await; + + let aggregation_job_id = random(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + + let vdaf = Arc::new(dummy::Vdaf::new(1)); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let measurement = 13; + let aggregation_param = dummy::AggregationParam(7); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + Vec::new(), + ); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &measurement, + ); + let leader_message = &transcript.leader_prepare_transitions[0].message; + let report_share = generate_helper_report_share::( + *task.id(), + report_metadata.clone(), + hpke_keypair.config(), + &transcript.public_share, + Vec::new(), + &transcript.helper_input_share, + ); + + datastore + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let report_share = report_share.clone(); + let leader_message = leader_message.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&helper_task).await.unwrap(); + + tx.put_scrubbed_report( + helper_task.id(), + report_metadata.id(), + report_metadata.time(), + ) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(*report_metadata.time(), Duration::from_seconds(1)).unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + *report_metadata.id(), + *report_metadata.time(), + 0, + None, + ReportAggregationState::HelperInitProcessing { + prepare_init: PrepareInit::new(report_share, leader_message), + require_taskbind_extension: false, + }, + )) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + // Send request. + let aggregate_resp = get_aggregation_job_and_decode( + &task, + &aggregation_job_id, + Some(AggregationJobStep::from(0)), + &handler, + ) + .await; + + // Validate result. + assert_eq!(aggregate_resp, AggregationJobResp::Processing); +} + +#[tokio::test] +async fn aggregation_job_get_wrong_step() { + // Prepare state. + let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + handler, + .. 
+ } = HttpHandlerTest::new().await; + + let aggregation_job_id = random(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + + let vdaf = Arc::new(dummy::Vdaf::new(1)); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let measurement = 13; + let aggregation_param = dummy::AggregationParam(7); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + Vec::new(), + ); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &measurement, + ); + let helper_message = &transcript.helper_prepare_transitions[0].message; + + datastore + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let helper_message = helper_message.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&helper_task).await.unwrap(); + + tx.put_scrubbed_report( + helper_task.id(), + report_metadata.id(), + report_metadata.time(), + ) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(*report_metadata.time(), Duration::from_seconds(1)).unwrap(), + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + *report_metadata.id(), + *report_metadata.time(), + 0, + Some(PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Continue { + message: helper_message, + }, + )), + ReportAggregationState::Finished, + )) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + // Send request. + let test_conn = get_aggregation_job( + &task, + &aggregation_job_id, + Some(AggregationJobStep::from(1)), + &handler, + ) + .await; + + // Validate result. + assert_eq!(test_conn.status(), Some(Status::BadRequest)); +} + +#[tokio::test] +async fn aggregation_job_get_missing_step() { + // Prepare state. + let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + handler, + .. 
+ } = HttpHandlerTest::new().await; + + let aggregation_job_id = random(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + + let vdaf = Arc::new(dummy::Vdaf::new(1)); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let measurement = 13; + let aggregation_param = dummy::AggregationParam(7); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + Vec::new(), + ); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &measurement, + ); + let helper_message = &transcript.helper_prepare_transitions[0].message; + + datastore + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let helper_message = helper_message.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&helper_task).await.unwrap(); + + tx.put_scrubbed_report( + helper_task.id(), + report_metadata.id(), + report_metadata.time(), + ) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(*report_metadata.time(), Duration::from_seconds(1)).unwrap(), + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + *report_metadata.id(), + *report_metadata.time(), + 0, + Some(PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Continue { + message: helper_message, + }, + )), + ReportAggregationState::Finished, + )) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + // Send request. + let test_conn = get_aggregation_job(&task, &aggregation_job_id, None, &handler).await; + + // Validate result. + assert_eq!(test_conn.status(), Some(Status::BadRequest)); +} + +#[tokio::test] +async fn aggregation_job_get_sync() { + // Prepare state. + let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + handler, + .. 
+ } = HttpHandlerTest::new().await; + + let aggregation_job_id = random(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); + let helper_task = task.helper_view().unwrap(); + + let vdaf = Arc::new(dummy::Vdaf::new(1)); + let verify_key: VerifyKey<0> = task.vdaf_verify_key().unwrap(); + let measurement = 13; + let aggregation_param = dummy::AggregationParam(7); + + let report_metadata = ReportMetadata::new( + random(), + clock + .now() + .to_batch_interval_start(task.time_precision()) + .unwrap(), + Vec::new(), + ); + let transcript = run_vdaf( + vdaf.as_ref(), + task.id(), + verify_key.as_bytes(), + &aggregation_param, + report_metadata.id(), + &measurement, + ); + let helper_message = &transcript.helper_prepare_transitions[0].message; + + datastore + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_metadata = report_metadata.clone(); + let helper_message = helper_message.clone(); + + Box::pin(async move { + tx.put_aggregator_task(&helper_task).await.unwrap(); + + tx.put_scrubbed_report( + helper_task.id(), + report_metadata.id(), + report_metadata.time(), + ) + .await + .unwrap(); + + tx.put_aggregation_job(&AggregationJob::<0, TimeInterval, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + aggregation_param, + (), + Interval::new(*report_metadata.time(), Duration::from_seconds(1)).unwrap(), + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + )) + .await + .unwrap(); + + tx.put_report_aggregation(&ReportAggregation::<0, dummy::Vdaf>::new( + *helper_task.id(), + aggregation_job_id, + *report_metadata.id(), + *report_metadata.time(), + 0, + Some(PrepareResp::new( + *report_metadata.id(), + PrepareStepResult::Continue { + message: helper_message, + }, + )), + ReportAggregationState::Finished, + )) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + // Send request. + let test_conn = get_aggregation_job( + &task, + &aggregation_job_id, + Some(AggregationJobStep::from(0)), + &handler, + ) + .await; + + // Validate result. 
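+    // A synchronous task has nothing to poll: its prepare responses come back inline on the
+    // initializing PUT (or continuing POST) itself, so a GET against the aggregation job URI
+    // is rejected outright.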
+    assert_eq!(test_conn.status(), Some(Status::BadRequest));
+}
+
+async fn get_aggregation_job(
+    task: &Task,
+    aggregation_job_id: &AggregationJobId,
+    step: Option<AggregationJobStep>,
+    handler: &impl Handler,
+) -> TestConn {
+    let uri = task.aggregation_job_uri(aggregation_job_id, step).unwrap();
+    let uri = match uri.query() {
+        Some(query) => format!("{}?{}", uri.path(), query),
+        None => uri.path().to_string(),
+    };
+
+    let (header, value) = task.aggregator_auth_token().request_authentication();
+    get(uri)
+        .with_request_header(header, value)
+        .run_async(handler)
+        .await
+}
+
+async fn get_aggregation_job_and_decode(
+    task: &Task,
+    aggregation_job_id: &AggregationJobId,
+    step: Option<AggregationJobStep>,
+    handler: &impl Handler,
+) -> AggregationJobResp {
+    let mut test_conn = get_aggregation_job(task, aggregation_job_id, step, handler).await;
+    assert_eq!(test_conn.status(), Some(Status::Ok));
+    assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE));
+    decode_response_body::<AggregationJobResp>(&mut test_conn).await
+}
diff --git a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs
index 05ddc0074..2a7ae3746 100644
--- a/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs
+++ b/aggregator/src/aggregator/http_handlers/tests/aggregation_job_init.rs
@@ -1,5 +1,7 @@
+#![allow(clippy::unit_arg)] // allow reference to dummy::Vdaf's public share, which has the unit type
+
 use crate::aggregator::{
-    aggregate_init_tests::{put_aggregation_job, PrepareInitGenerator},
+    aggregation_job_init::test_util::{put_aggregation_job, PrepareInitGenerator},
     empty_batch_aggregations,
     http_handlers::test_util::{decode_response_body, take_problem_details, HttpHandlerTest},
     test_util::{
@@ -11,10 +13,10 @@ use assert_matches::assert_matches;
 use futures::future::try_join_all;
 use janus_aggregator_core::{
     datastore::models::{
-        AggregationJob, AggregationJobState, BatchAggregation, BatchAggregationState,
-        ReportAggregation, ReportAggregationState, TaskAggregationCounter,
+        AggregationJobState, BatchAggregation, BatchAggregationState, ReportAggregationState,
+        TaskAggregationCounter,
     },
-    task::{test_util::TaskBuilder, BatchMode, VerifyKey},
+    task::{test_util::TaskBuilder, AggregationMode, BatchMode, VerifyKey},
 };
 use janus_core::{
     auth_tokens::AuthenticationToken,
@@ -26,10 +28,10 @@ use janus_core::{
 };
 use janus_messages::{
     batch_mode::{LeaderSelected, TimeInterval},
-    AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobStep,
-    Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, InputShareAad, Interval,
-    PartialBatchSelector, PrepareInit, PrepareStepResult, ReportError, ReportIdChecksum,
-    ReportMetadata, ReportShare, Time,
+    AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, Duration, Extension,
+    ExtensionType, HpkeCiphertext, HpkeConfigId, InputShareAad, Interval, PartialBatchSelector,
+    PrepareInit, PrepareStepResult, ReportError, ReportIdChecksum, ReportMetadata, ReportShare,
+    Role, Time,
 };
 use prio::{codec::Encode, vdaf::dummy};
 use rand::random;
@@ -46,7 +48,12 @@ async fn aggregate_leader() {
         ..
} = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build(); datastore .put_aggregator_task(&task.leader_view().unwrap()) .await @@ -93,7 +100,7 @@ async fn aggregate_leader() { let test_conn = TestConn::build( trillium::Method::Options, - task.aggregation_job_uri(&aggregation_job_id) + task.aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), (), @@ -116,9 +123,13 @@ async fn aggregate_wrong_agg_auth_token() { let dap_auth_token = AuthenticationToken::DapAuth(random()); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_aggregator_auth_token(dap_auth_token.clone()) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_aggregator_auth_token(dap_auth_token.clone()) + .build(); datastore .put_aggregator_task(&task.helper_view().unwrap()) @@ -146,7 +157,7 @@ async fn aggregate_wrong_agg_auth_token() { Vec::from([dap_auth_token, wrong_token_value]), ] { let mut test_conn = put(task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header( @@ -177,10 +188,7 @@ async fn aggregate_wrong_agg_auth_token() { } #[tokio::test] -// Silence the unit_arg lint so that we can work with dummy::Vdaf::{InputShare, -// Measurement} values (whose type is ()). -#[allow(clippy::unit_arg, clippy::let_unit_value)] -async fn aggregate_init() { +async fn aggregate_init_sync() { let HttpHandlerTest { clock, ephemeral_datastore: _ephemeral_datastore, @@ -190,7 +198,12 @@ async fn aggregate_init() { .. } = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); @@ -443,58 +456,35 @@ async fn aggregate_init() { let mut batch_aggregations_results = Vec::new(); let mut aggregation_jobs_results = Vec::new(); - let conflicting_aggregation_job = datastore + datastore .run_unnamed_tx(|tx| { - let task = helper_task.clone(); + let helper_task = helper_task.clone(); let report_share_4 = prepare_init_4.report_share().clone(); Box::pin(async move { - tx.put_aggregator_task(&task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); // report_share_4 is already in the datastore as it was referenced by an existing // aggregation job. - tx.put_scrubbed_report(task.id(), &report_share_4) - .await - .unwrap(); - - // Put in an aggregation job and report aggregation for report_share_4. It uses - // the same aggregation parameter as the aggregation job this test will later - // add and so should cause report_share_4 to fail to prepare. 
- let conflicting_aggregation_job = AggregationJob::new( - *task.id(), - random(), - dummy::AggregationParam(0), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobStep::from(0), - ); - tx.put_aggregation_job::<0, TimeInterval, dummy::Vdaf>( - &conflicting_aggregation_job, + tx.put_scrubbed_report( + helper_task.id(), + report_share_4.metadata().id(), + report_share_4.metadata().time(), ) .await .unwrap(); - tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new( - *task.id(), - *conflicting_aggregation_job.id(), - *report_share_4.metadata().id(), - *report_share_4.metadata().time(), - 0, - None, - ReportAggregationState::Finished, - )) - .await - .unwrap(); // Write collected batch aggregations for the interval that report_share_5 falls // into, which will cause it to fail to prepare. try_join_all( empty_batch_aggregations::<0, TimeInterval, dummy::Vdaf>( - &task, + &helper_task, BATCH_AGGREGATION_SHARD_COUNT, - &Interval::new(Time::from_seconds_since_epoch(0), *task.time_precision()) - .unwrap(), + &Interval::new( + Time::from_seconds_since_epoch(0), + *helper_task.time_precision(), + ) + .unwrap(), &dummy::AggregationParam(0), &[], ) @@ -504,7 +494,7 @@ async fn aggregate_init() { .await .unwrap(); - Ok(conflicting_aggregation_job) + Ok(()) }) }) .await @@ -665,15 +655,11 @@ async fn aggregate_init() { .await .unwrap(); - assert_eq!(aggregation_jobs.len(), 2); + assert_eq!(aggregation_jobs.len(), 1); - let mut saw_conflicting_aggregation_job = false; let mut saw_new_aggregation_job = false; - for aggregation_job in &aggregation_jobs { - if aggregation_job.eq(&conflicting_aggregation_job) { - saw_conflicting_aggregation_job = true; - } else if aggregation_job.task_id().eq(task.id()) + if aggregation_job.task_id().eq(task.id()) && aggregation_job.id().eq(&aggregation_job_id) && aggregation_job.partial_batch_identifier().eq(&()) && aggregation_job.state().eq(&AggregationJobState::Finished) @@ -681,8 +667,6 @@ async fn aggregate_init() { saw_new_aggregation_job = true; } } - - assert!(saw_conflicting_aggregation_job); assert!(saw_new_aggregation_job); aggregation_jobs_results.push(aggregation_jobs); @@ -700,6 +684,167 @@ async fn aggregate_init() { .await; } +#[tokio::test] +async fn aggregate_init_async() { + let HttpHandlerTest { + clock, + ephemeral_datastore: _ephemeral_datastore, + datastore, + handler, + hpke_keypair, + .. + } = HttpHandlerTest::new().await; + + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); + + let helper_task = task.helper_view().unwrap(); + + let vdaf = dummy::Vdaf::new(1); + let measurement = 0; + let prep_init_generator = PrepareInitGenerator::new( + clock.clone(), + helper_task.clone(), + hpke_keypair.config().clone(), + vdaf.clone(), + dummy::AggregationParam(0), + ); + + // prepare_init_0 is a "happy path" report. + let (prepare_init_0, _) = prep_init_generator.next(&measurement); + + // prepare_init_1 has already been aggregated in another aggregation job, with the same + // aggregation parameter. 
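+    // The Helper must not aggregate it a second time: the assertions below expect this report's
+    // aggregation to end in ReportAggregationState::Failed with ReportError::ReportReplayed.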
+ let (prepare_init_1, _) = prep_init_generator.next(&measurement); + + datastore + .run_unnamed_tx(|tx| { + let helper_task = helper_task.clone(); + let report_share_1 = prepare_init_1.report_share().clone(); + + Box::pin(async move { + tx.put_aggregator_task(&helper_task).await.unwrap(); + + // report_share_1 is already in the datastore as it was referenced by an existing + // aggregation job. + tx.put_scrubbed_report( + helper_task.id(), + report_share_1.metadata().id(), + report_share_1.metadata().time(), + ) + .await + .unwrap(); + + Ok(()) + }) + }) + .await + .unwrap(); + + let aggregation_param = dummy::AggregationParam(0); + let request = AggregationJobInitializeReq::new( + aggregation_param.get_encoded().unwrap(), + PartialBatchSelector::new_time_interval(), + Vec::from([prepare_init_0.clone(), prepare_init_1.clone()]), + ); + + // Send request, parse response. Do this twice to prove that the request is idempotent. + let aggregation_job_id: AggregationJobId = random(); + let mut aggregation_jobs_results = Vec::new(); + let mut report_aggregations_results = Vec::new(); + let mut batch_aggregations_results = Vec::new(); + for _ in 0..2 { + let mut test_conn = + put_aggregation_job(&task, &aggregation_job_id, &request, &handler).await; + assert_eq!(test_conn.status(), Some(Status::Created)); + assert_headers!( + &test_conn, + "content-type" => (AggregationJobResp::MEDIA_TYPE) + ); + let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await; + assert_matches!(aggregate_resp, AggregationJobResp::Processing); + + // Check aggregation job in datastore. + let (aggregation_jobs, report_aggregations, batch_aggregations) = datastore + .run_unnamed_tx(|tx| { + let task = task.clone(); + let vdaf = vdaf.clone(); + Box::pin(async move { + Ok(( + tx.get_aggregation_jobs_for_task::<0, TimeInterval, dummy::Vdaf>(task.id()) + .await + .unwrap(), + tx.get_report_aggregations_for_aggregation_job::<0, dummy::Vdaf>( + &vdaf, + &Role::Helper, + task.id(), + &aggregation_job_id, + &aggregation_param, + ) + .await + .unwrap(), + tx.get_batch_aggregations_for_task::<0, TimeInterval, _>(&vdaf, task.id()) + .await + .unwrap(), + )) + }) + }) + .await + .unwrap(); + + assert_eq!(aggregation_jobs.len(), 1); + + assert_eq!(aggregation_jobs[0].task_id(), task.id()); + assert_eq!(aggregation_jobs[0].id(), &aggregation_job_id); + assert_eq!(aggregation_jobs[0].partial_batch_identifier(), &()); + assert_eq!(aggregation_jobs[0].state(), &AggregationJobState::Active); + + assert_eq!(report_aggregations.len(), 2); + + assert_eq!( + report_aggregations[0].report_id(), + prepare_init_0.report_share().metadata().id() + ); + assert_eq!( + report_aggregations[0].state(), + &ReportAggregationState::HelperInitProcessing { + prepare_init: prepare_init_0.clone(), + require_taskbind_extension: false + } + ); + + assert_eq!( + report_aggregations[1].report_id(), + prepare_init_1.report_share().metadata().id() + ); + assert_eq!( + report_aggregations[1].state(), + &ReportAggregationState::Failed { + report_error: ReportError::ReportReplayed + } + ); + + aggregation_jobs_results.push(aggregation_jobs); + report_aggregations_results.push(report_aggregations); + batch_aggregations_results.push(batch_aggregations); + } + + assert!(aggregation_jobs_results.windows(2).all(|v| v[0] == v[1])); + assert!(report_aggregations_results.windows(2).all(|v| v[0] == v[1])); + assert!(batch_aggregations_results.windows(2).all(|v| v[0] == v[1])); + + assert_task_aggregation_counter( + &datastore, + *task.id(), + 
TaskAggregationCounter::new_with_values(0), + ) + .await; +} + #[tokio::test] async fn aggregate_init_batch_already_collected() { let HttpHandlerTest { @@ -715,6 +860,7 @@ async fn aggregate_init_batch_already_collected() { BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build(); @@ -779,7 +925,7 @@ async fn aggregate_init_batch_already_collected() { let aggregation_job_id: AggregationJobId = random(); let (header, value) = task.aggregator_auth_token().request_authentication(); let mut test_conn = put(task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(header, value) @@ -827,7 +973,12 @@ async fn aggregate_init_prep_init_failed() { .. } = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::FakeFailsPrepInit).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::FakeFailsPrepInit, + ) + .build(); let helper_task = task.helper_view().unwrap(); let prep_init_generator = PrepareInitGenerator::new( clock.clone(), @@ -892,7 +1043,12 @@ async fn aggregate_init_prep_step_failed() { .. } = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::FakeFailsPrepStep).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::FakeFailsPrepStep, + ) + .build(); let helper_task = task.helper_view().unwrap(); let prep_init_generator = PrepareInitGenerator::new( clock.clone(), @@ -956,7 +1112,12 @@ async fn aggregate_init_duplicated_report_id() { .. } = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build(); let helper_task = task.helper_view().unwrap(); let prep_init_generator = PrepareInitGenerator::new( diff --git a/aggregator/src/aggregator/http_handlers/tests/collection_job.rs b/aggregator/src/aggregator/http_handlers/tests/collection_job.rs index d7c7971be..80376a48f 100644 --- a/aggregator/src/aggregator/http_handlers/tests/collection_job.rs +++ b/aggregator/src/aggregator/http_handlers/tests/collection_job.rs @@ -6,7 +6,7 @@ use assert_matches::assert_matches; use janus_aggregator_core::{ batch_mode::AccumulableBatchMode, datastore::models::{CollectionJob, CollectionJobState}, - task::{test_util::TaskBuilder, BatchMode}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode}, }; use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, @@ -146,9 +146,13 @@ async fn collection_job_put_request_invalid_batch_size() { } = HttpHandlerTest::new().await; // Prepare parameters. 
- let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_min_batch_size(1) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_min_batch_size(1) + .build(); let leader_task = task.leader_view().unwrap(); datastore.put_aggregator_task(&leader_task).await.unwrap(); diff --git a/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs b/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs index 1e5799e30..0c1e9798a 100644 --- a/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs +++ b/aggregator/src/aggregator/http_handlers/tests/helper_e2e.rs @@ -1,5 +1,5 @@ use assert_matches::assert_matches; -use janus_aggregator_core::task::{test_util::TaskBuilder, BatchMode}; +use janus_aggregator_core::task::{test_util::TaskBuilder, AggregationMode, BatchMode}; use janus_core::{report_id::ReportIdChecksumExt, vdaf::VdafInstance}; use janus_messages::{ batch_mode::LeaderSelected, AggregateShareReq, AggregationJobInitializeReq, AggregationJobResp, @@ -14,7 +14,7 @@ use trillium::Status; use trillium_testing::assert_status; use crate::aggregator::{ - aggregate_init_tests::{put_aggregation_job, PrepareInitGenerator}, + aggregation_job_init::test_util::{put_aggregation_job, PrepareInitGenerator}, http_handlers::{ test_util::{take_response_body, HttpHandlerTest}, tests::aggregate_share::post_aggregate_share_request, @@ -38,6 +38,7 @@ async fn helper_aggregation_report_share_replay() { BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_min_batch_size(1) diff --git a/aggregator/src/aggregator/http_handlers/tests/hpke_config.rs b/aggregator/src/aggregator/http_handlers/tests/hpke_config.rs index 5e3a031f0..586e18426 100644 --- a/aggregator/src/aggregator/http_handlers/tests/hpke_config.rs +++ b/aggregator/src/aggregator/http_handlers/tests/hpke_config.rs @@ -12,7 +12,7 @@ use crate::{ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _}; use janus_aggregator_core::{ datastore::models::HpkeKeyState, - task::{test_util::TaskBuilder, BatchMode}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode}, test_util::noop_meter, }; use janus_core::{ @@ -156,7 +156,12 @@ async fn hpke_config_with_taskprov() { } = HttpHandlerTest::new().await; // Insert a taskprov task. - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build(); let taskprov_helper_task = task.taskprov_helper_view().unwrap(); datastore .put_aggregator_task(&taskprov_helper_task) @@ -226,10 +231,14 @@ async fn hpke_config_cors_headers() { .. } = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(); datastore.put_aggregator_task(&task).await.unwrap(); // Check for appropriate CORS headers in response to a preflight request. 
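Taken together, the aggregation_job_get tests above pin down the Leader-visible contract for asynchronous aggregation: the initializing PUT returns 201 Created with AggregationJobResp::Processing, and the Leader then polls the aggregation job URI (with the expected step attached as a query parameter) until the Helper answers AggregationJobResp::Finished. A minimal polling loop over the same test fixtures might look like the sketch below; the retry bound and sleep interval are arbitrary illustrative choices, and get_aggregation_job_and_decode is the test helper defined above, so this is a sketch rather than code from the patch.

    // Poll an asynchronous aggregation job until the Helper reports completion.
    let mut polls = 0;
    let prepare_resps = loop {
        match get_aggregation_job_and_decode(
            &task,
            &aggregation_job_id,
            Some(AggregationJobStep::from(0)),
            &handler,
        )
        .await
        {
            AggregationJobResp::Finished { prepare_resps } => break prepare_resps,
            AggregationJobResp::Processing => {
                polls += 1;
                assert!(polls < 10, "Helper never finished the aggregation job");
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        }
    };
    assert!(!prepare_resps.is_empty());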
diff --git a/aggregator/src/aggregator/http_handlers/tests/mod.rs b/aggregator/src/aggregator/http_handlers/tests/mod.rs index ea6e64358..546f232ce 100644 --- a/aggregator/src/aggregator/http_handlers/tests/mod.rs +++ b/aggregator/src/aggregator/http_handlers/tests/mod.rs @@ -1,5 +1,6 @@ mod aggregate_share; mod aggregation_job_continue; +mod aggregation_job_get; mod aggregation_job_init; mod collection_job; mod helper_e2e; diff --git a/aggregator/src/aggregator/http_handlers/tests/report.rs b/aggregator/src/aggregator/http_handlers/tests/report.rs index 754c59667..a8bc1569d 100644 --- a/aggregator/src/aggregator/http_handlers/tests/report.rs +++ b/aggregator/src/aggregator/http_handlers/tests/report.rs @@ -11,7 +11,7 @@ use crate::{ }; use janus_aggregator_core::{ datastore::test_util::{ephemeral_datastore, EphemeralDatastoreBuilder}, - task::{test_util::TaskBuilder, BatchMode}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode}, test_util::noop_meter, }; use janus_core::{ @@ -75,9 +75,13 @@ async fn upload_handler() { } = HttpHandlerTest::new().await; const REPORT_EXPIRY_AGE: u64 = 1_000_000; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_report_expiry_age(Some(Duration::from_seconds(REPORT_EXPIRY_AGE))) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_report_expiry_age(Some(Duration::from_seconds(REPORT_EXPIRY_AGE))) + .build(); let leader_task = task.leader_view().unwrap(); datastore.put_aggregator_task(&leader_task).await.unwrap(); @@ -208,9 +212,13 @@ async fn upload_handler() { .await; // Reports with timestamps past the task's end time should be rejected. - let task_end_soon = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_task_end(Some(clock.now().add(&Duration::from_seconds(60)).unwrap())) - .build(); + let task_end_soon = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_task_end(Some(clock.now().add(&Duration::from_seconds(60)).unwrap())) + .build(); let leader_task_end_soon = task_end_soon.leader_view().unwrap(); datastore .put_aggregator_task(&leader_task_end_soon) @@ -373,7 +381,12 @@ async fn upload_handler_helper() { .. 
} = HttpHandlerTest::new().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build(); let helper_task = task.helper_view().unwrap(); datastore.put_aggregator_task(&helper_task).await.unwrap(); let report = create_report(&helper_task, &hpke_keypair, clock.now()); @@ -432,9 +445,13 @@ async fn upload_handler_error_fanout() { .unwrap(); const REPORT_EXPIRY_AGE: u64 = 1_000_000; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_report_expiry_age(Some(Duration::from_seconds(REPORT_EXPIRY_AGE))) - .build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_report_expiry_age(Some(Duration::from_seconds(REPORT_EXPIRY_AGE))) + .build(); let leader_task = task.leader_view().unwrap(); datastore.put_aggregator_task(&leader_task).await.unwrap(); @@ -544,7 +561,12 @@ async fn upload_client_early_disconnect() { .build() .unwrap(); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build(); let task_id = *task.id(); let leader_task = task.leader_view().unwrap(); datastore.put_aggregator_task(&leader_task).await.unwrap(); diff --git a/aggregator/src/aggregator/taskprov_tests.rs b/aggregator/src/aggregator/taskprov_tests.rs index 1c2aa2aee..ce471c97d 100644 --- a/aggregator/src/aggregator/taskprov_tests.rs +++ b/aggregator/src/aggregator/taskprov_tests.rs @@ -1,6 +1,6 @@ use crate::{ aggregator::{ - aggregate_init_tests::PrepareInitGenerator, + aggregation_job_init::test_util::PrepareInitGenerator, http_handlers::test_util::{decode_response_body, take_problem_details}, Config, }, @@ -19,7 +19,7 @@ use janus_aggregator_core::{ }, task::{ test_util::{Task, TaskBuilder}, - BatchMode, + AggregationMode, BatchMode, }, taskprov::{taskprov_task_id, test_util::PeerAggregatorBuilder, PeerAggregator}, test_util::noop_meter, @@ -104,7 +104,7 @@ where let collector_hpke_keypair = HpkeKeypair::test(); let peer_aggregator = PeerAggregatorBuilder::new() .with_endpoint(url::Url::parse("https://leader.example.com/").unwrap()) - .with_role(Role::Leader) + .with_peer_role(Role::Leader) .with_collector_hpke_config(collector_hpke_keypair.config().clone()) .build(); @@ -169,6 +169,7 @@ where BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, vdaf_instance, ) .with_id(task_id) @@ -276,7 +277,7 @@ async fn taskprov_aggregate_init() { let mut test_conn = put(test .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth.0, "Bearer invalid_token") @@ -305,7 +306,7 @@ async fn taskprov_aggregate_init() { let mut test_conn = put(test .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth.0, auth.1) @@ -371,7 +372,7 @@ async fn taskprov_aggregate_init() { .eq(&batch_id_1) && aggregation_jobs[0] .state() - .eq(&AggregationJobState::InProgress) + .eq(&AggregationJobState::AwaitingRequest) ); assert!( aggregation_jobs[1].task_id().eq(&test.task_id) @@ -381,7 +382,7 @@ async fn taskprov_aggregate_init() { .eq(&batch_id_2) && aggregation_jobs[1] .state() - 
.eq(&AggregationJobState::InProgress) + .eq(&AggregationJobState::AwaitingRequest) ); let got_task = got_task.unwrap(); assert_eq!(test.task.taskprov_helper_view().unwrap(), got_task); @@ -412,7 +413,7 @@ async fn taskprov_aggregate_init_missing_extension() { let mut test_conn = put(test .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth.0, auth.1) @@ -499,7 +500,7 @@ async fn taskprov_aggregate_init_malformed_extension() { let mut test_conn = put(test .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth.0, auth.1) @@ -590,7 +591,7 @@ async fn taskprov_opt_out_task_ended() { let mut test_conn = put(test .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth.0, auth.1) @@ -656,7 +657,7 @@ async fn taskprov_opt_out_mismatched_task_id() { let mut test_conn = put(test // Use the test case task's ID. .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth.0, auth.1) @@ -835,7 +836,12 @@ async fn taskprov_aggregate_continue() { tx.put_aggregator_task(&task.taskprov_helper_view().unwrap()) .await?; - tx.put_scrubbed_report(task.id(), &report_share).await?; + tx.put_scrubbed_report( + task.id(), + report_share.metadata().id(), + report_share.metadata().time(), + ) + .await?; tx.put_aggregation_job(&AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( *task.id(), @@ -844,7 +850,7 @@ async fn taskprov_aggregate_continue() { batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await?; @@ -894,7 +900,7 @@ async fn taskprov_aggregate_continue() { // Attempt using the wrong credentials, should reject. 
let mut test_conn = post( test.task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -923,7 +929,7 @@ async fn taskprov_aggregate_continue() { let mut test_conn = post( test.task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -1095,7 +1101,7 @@ async fn end_to_end() { let mut test_conn = put(test .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth_header_name, auth_header_value.clone()) @@ -1138,7 +1144,7 @@ async fn end_to_end() { let mut test_conn = post( test.task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path(), ) @@ -1245,7 +1251,7 @@ async fn end_to_end_sumvec_hmac() { let mut test_conn = put(test .task - .aggregation_job_uri(&aggregation_job_id) + .aggregation_job_uri(&aggregation_job_id, None) .unwrap() .path()) .with_request_header(auth_header_name, auth_header_value.clone()) diff --git a/aggregator/src/aggregator/upload_tests.rs b/aggregator/src/aggregator/upload_tests.rs index f432f711d..7f4107ad7 100644 --- a/aggregator/src/aggregator/upload_tests.rs +++ b/aggregator/src/aggregator/upload_tests.rs @@ -13,7 +13,7 @@ use janus_aggregator_core::{ }, task::{ test_util::{Task, TaskBuilder}, - BatchMode, + AggregationMode, BatchMode, }, test_util::noop_meter, }; @@ -24,7 +24,7 @@ use janus_core::{ runtime::{TestRuntime, TestRuntimeManager}, }, time::{Clock, MockClock, TimeExt}, - vdaf::{VdafInstance, VERIFY_KEY_LENGTH}, + vdaf::{VdafInstance, VERIFY_KEY_LENGTH_PRIO3}, Runtime, }; use janus_messages::{ @@ -58,7 +58,12 @@ impl UploadTest { let clock = MockClock::default(); let vdaf = Prio3Count::new_count(2).unwrap(); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count).build(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build(); let leader_task = task.leader_view().unwrap(); @@ -446,16 +451,18 @@ async fn upload_report_for_collected_batch() { .run_unnamed_tx(|tx| { let task = task.clone(); Box::pin(async move { - tx.put_collection_job( - &CollectionJob::::new( - *task.id(), - random(), - Query::new_time_interval(batch_interval), - (), - batch_interval, - CollectionJobState::Start, - ), - ) + tx.put_collection_job(&CollectionJob::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *task.id(), + random(), + Query::new_time_interval(batch_interval), + (), + batch_interval, + CollectionJobState::Start, + )) .await }) }) @@ -514,13 +521,17 @@ async fn upload_report_task_not_started() { .await; // Set the task start time to the future, and generate & upload a report from before that time. 
- let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_task_start(Some( - clock.now().add(&Duration::from_seconds(3600)).unwrap(), - )) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_task_start(Some( + clock.now().add(&Duration::from_seconds(3600)).unwrap(), + )) + .build() + .leader_view() + .unwrap(); datastore.put_aggregator_task(&task).await.unwrap(); let report = create_report(&task, &hpke_keypair, clock.now()); @@ -576,11 +587,15 @@ async fn upload_report_task_ended() { ) .await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_task_end(Some(clock.now())) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_task_end(Some(clock.now())) + .build() + .leader_view() + .unwrap(); datastore.put_aggregator_task(&task).await.unwrap(); // Advance the clock to end the task. @@ -638,11 +653,15 @@ async fn upload_report_report_expired() { ) .await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_report_expiry_age(Some(Duration::from_seconds(60))) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_report_expiry_age(Some(Duration::from_seconds(60))) + .build() + .leader_view() + .unwrap(); datastore.put_aggregator_task(&task).await.unwrap(); let report = create_report(&task, &hpke_keypair, clock.now()); diff --git a/aggregator/src/binaries/aggregation_job_driver.rs b/aggregator/src/binaries/aggregation_job_driver.rs index 510d7032b..09044a812 100644 --- a/aggregator/src/binaries/aggregation_job_driver.rs +++ b/aggregator/src/binaries/aggregation_job_driver.rs @@ -1,6 +1,7 @@ use crate::{ aggregator::aggregation_job_driver::AggregationJobDriver, binary_utils::{job_driver::JobDriver, BinaryContext, BinaryOptions, CommonBinaryOptions}, + cache::HpkeKeypairCache, config::{BinaryConfig, CommonConfig, JobDriverConfig, TaskprovConfig}, }; use anyhow::{Context, Result}; @@ -18,6 +19,11 @@ pub async fn main_callback(ctx: BinaryContext) -> Re "/aggregation_job_driver", ); + let hpke_configs_refresh_interval = match ctx.config.hpke_configs_refresh_interval { + Some(duration) => Duration::from_millis(duration), + None => HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL, + }; + let datastore = Arc::new(ctx.datastore); let aggregation_job_driver = Arc::new(AggregationJobDriver::new( reqwest::Client::builder() @@ -36,6 +42,8 @@ pub async fn main_callback(ctx: BinaryContext) -> Re &ctx.meter, ctx.config.batch_aggregation_shard_count, ctx.config.task_counter_shard_count, + hpke_configs_refresh_interval, + Duration::from_millis(ctx.config.default_async_poll_interval), )); let lease_duration = Duration::from_secs(ctx.config.job_driver_config.worker_lease_duration_s); @@ -132,6 +140,19 @@ pub struct Config { /// aggregations, while increasing the cost of getting task metrics. #[serde(default = "default_task_counter_shard_count")] pub task_counter_shard_count: u64, + + /// Defines how often to refresh the HPKE configs cache in milliseconds. This affects how often + /// an aggregator becomes aware of HPKE key state changes. If unspecified, default is defined by + /// [`HpkeKeypairCache::DEFAULT_REFRESH_INTERVAL`]. 
You shouldn't normally have to specify this.
+    #[serde(default)]
+    pub hpke_configs_refresh_interval: Option<u64>,
+
+    /// Defines how frequently, in milliseconds, outstanding asynchronous aggregation jobs for
+    /// which this aggregator is the Leader may be polled when the Helper does not send a
+    /// Retry-After header. (If the Helper does send a Retry-After header, it will be respected.)
+    /// If unspecified, the default is one minute.
+    #[serde(default = "default_default_async_poll_interval")]
+    pub default_async_poll_interval: u64,
 }
 
 impl BinaryConfig for Config {
@@ -148,6 +169,10 @@ fn default_task_counter_shard_count() -> u64 {
     32
 }
 
+fn default_default_async_poll_interval() -> u64 {
+    60_000
+}
+
 #[cfg(test)]
 mod tests {
     use super::{Config, Options};
@@ -190,6 +215,8 @@ mod tests {
             },
             batch_aggregation_shard_count: 32,
             task_counter_shard_count: 64,
+            hpke_configs_refresh_interval: Some(180000),
+            default_async_poll_interval: 5_000,
             taskprov_config: TaskprovConfig::default(),
         })
     }
diff --git a/aggregator/src/binaries/janus_cli.rs b/aggregator/src/binaries/janus_cli.rs
index d4bccc9fd..801c0675c 100644
--- a/aggregator/src/binaries/janus_cli.rs
+++ b/aggregator/src/binaries/janus_cli.rs
@@ -11,7 +11,7 @@ use clap::Parser;
 use janus_aggregator_api::git_revision;
 use janus_aggregator_core::{
     datastore::{self, models::HpkeKeyState, Datastore},
-    task::{AggregatorTask, SerializedAggregatorTask},
+    task::{AggregationMode, AggregatorTask, SerializedAggregatorTask},
     taskprov::{PeerAggregator, VerifyKeyInit},
 };
 use janus_core::{
@@ -126,9 +126,13 @@ enum Command {
         #[arg(long)]
         peer_endpoint: Url,
 
-        /// This aggregator's role.
+        /// The peer aggregator's role.
         #[arg(long)]
-        role: Role,
+        peer_role: Role,
+
+        /// The aggregation mode to use. Specified only if this aggregator is the Helper.
+        #[arg(long)]
+        aggregation_mode: Option<AggregationMode>,
 
         /// The taskprov verify_key_init value, in unpadded base64url.
#[arg(long, env = "VERIFY_KEY_INIT", hide_env_values = true)] @@ -247,7 +251,8 @@ impl Command { Command::AddTaskprovPeerAggregator { kubernetes_secret_options, peer_endpoint, - role, + peer_role, + aggregation_mode, verify_key_init, collector_hpke_config_file, report_expiry_age_secs, @@ -271,7 +276,8 @@ impl Command { &datastore, command_line_options.dry_run, peer_endpoint, - *role, + *peer_role, + *aggregation_mode, *verify_key_init, collector_hpke_config_file, report_expiry_age, @@ -397,6 +403,7 @@ async fn add_taskprov_peer_aggregator( dry_run: bool, peer_endpoint: &Url, role: Role, + aggregation_mode: Option, verify_key_init: VerifyKeyInit, collector_hpke_config_file: &Path, report_expiry_age: Option, @@ -415,6 +422,7 @@ async fn add_taskprov_peer_aggregator( let peer_aggregator = Arc::new(PeerAggregator::new( peer_endpoint.clone(), role, + aggregation_mode, verify_key_init, collector_hpke_config, report_expiry_age, @@ -755,7 +763,7 @@ mod tests { use clap::CommandFactory; use janus_aggregator_core::{ datastore::{models::HpkeKeyState, test_util::ephemeral_datastore, Datastore}, - task::{test_util::TaskBuilder, AggregatorTask, BatchMode}, + task::{test_util::TaskBuilder, AggregationMode, AggregatorTask, BatchMode}, taskprov::{PeerAggregator, VerifyKeyInit}, }; use janus_core::{ @@ -1041,7 +1049,8 @@ mod tests { ds: &Datastore, dry_run: bool, peer_endpoint: &Url, - role: Role, + peer_role: Role, + aggregation_mode: Option, verify_key_init: VerifyKeyInit, collector_hpke_config: &HpkeConfig, report_expiry_age: Option, @@ -1059,7 +1068,8 @@ mod tests { ds, dry_run, peer_endpoint, - role, + peer_role, + aggregation_mode, verify_key_init, &collector_hpke_config_file, report_expiry_age, @@ -1077,7 +1087,8 @@ mod tests { let ds = ephemeral_datastore.datastore(RealClock::default()).await; let peer_endpoint = "https://example.com".try_into().unwrap(); - let role = Role::Leader; + let peer_role = Role::Leader; + let aggregation_mode = Some(AggregationMode::Synchronous); let verify_key_init = random(); let collector_hpke_config = HpkeKeypair::generate( HpkeConfigId::from(96), @@ -1097,7 +1108,8 @@ mod tests { &ds, /* dry_run */ false, &peer_endpoint, - role, + peer_role, + aggregation_mode, verify_key_init, &collector_hpke_config, report_expiry_age, @@ -1109,7 +1121,8 @@ mod tests { let want_peer_aggregator = PeerAggregator::new( peer_endpoint.clone(), - role, + peer_role, + aggregation_mode, verify_key_init, collector_hpke_config, report_expiry_age, @@ -1124,7 +1137,7 @@ mod tests { Box::pin(async move { Ok(tx - .get_taskprov_peer_aggregator(&peer_endpoint, &role) + .get_taskprov_peer_aggregator(&peer_endpoint, &peer_role) .await .unwrap() .unwrap()) @@ -1146,6 +1159,7 @@ mod tests { /* dry_run */ true, &"https://example.com".try_into().unwrap(), Role::Leader, + Some(AggregationMode::Synchronous), random(), &HpkeKeypair::generate( HpkeConfigId::from(96), @@ -1197,12 +1211,17 @@ mod tests { let ds = ephemeral_datastore.datastore(RealClock::default()).await; let tasks = Vec::from([ - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(), TaskBuilder::new( BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(), + TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Prio3Sum { max_measurement: 4096, }, @@ -1231,13 +1250,14 @@ mod tests { let ephemeral_datastore = ephemeral_datastore().await; let ds = 
ephemeral_datastore.datastore(RealClock::default()).await; - let tasks = - Vec::from([ - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(), - ]); + let tasks = Vec::from([TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap()]); let written_tasks = run_provision_tasks_testcase(&ds, &tasks, true).await; @@ -1255,12 +1275,17 @@ mod tests { #[tokio::test] async fn replace_task() { let tasks = Vec::from([ - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(), TaskBuilder::new( BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(), + TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Prio3Sum { max_measurement: 4096, }, @@ -1287,6 +1312,7 @@ mod tests { BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Prio3SumVec { bits: 1, length: 4, @@ -1359,6 +1385,7 @@ mod tests { hpke_keys: [] - peer_aggregator_endpoint: https://leader batch_mode: TimeInterval + aggregation_mode: Asynchronous vdaf: !Prio3Sum max_measurement: 4096 role: Helper diff --git a/aggregator/src/cache.rs b/aggregator/src/cache.rs index fbf8d757e..efaf9deeb 100644 --- a/aggregator/src/cache.rs +++ b/aggregator/src/cache.rs @@ -198,7 +198,7 @@ impl PeerAggregatorCache { // so a linear search should be fine. self.peers .iter() - .find(|peer| peer.endpoint() == endpoint && peer.role() == role) + .find(|peer| peer.endpoint() == endpoint && peer.peer_role() == role) } } @@ -276,7 +276,7 @@ mod tests { use janus_aggregator_core::{ datastore::{models::HpkeKeyState, test_util::ephemeral_datastore}, - task::{test_util::TaskBuilder, BatchMode}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode}, }; use janus_core::{ hpke::HpkeKeypair, @@ -352,10 +352,14 @@ mod tests { ttl, ); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(); assert!(task_aggregators.get(task.id()).await.unwrap().is_none()); // We shouldn't have cached that last call. 
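The PeerAggregatorCache change above also pins down the lookup contract after the rename: entries are matched on (endpoint, peer_role), i.e. the peer's role rather than this aggregator's, and the peer list is scanned linearly since, as the code comments note, deployments have few peers. A self-contained sketch of that lookup under those assumptions (type names are stand-ins, not the real janus_aggregator types):

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Role {
    Leader,
    Helper,
}

struct Peer {
    endpoint: String, // the real code uses url::Url
    peer_role: Role,
}

struct PeerCache {
    peers: Vec<Peer>,
}

impl PeerCache {
    // Linear search, mirroring `self.peers.iter().find(...)` above: fine for
    // the small, rarely-changing peer sets taskprov deployments use.
    fn get(&self, endpoint: &str, role: Role) -> Option<&Peer> {
        self.peers
            .iter()
            .find(|peer| peer.endpoint == endpoint && peer.peer_role == role)
    }
}

fn main() {
    let cache = PeerCache {
        peers: vec![Peer {
            endpoint: "https://helper.example.com/".into(),
            peer_role: Role::Helper,
        }],
    };
    assert!(cache.get("https://helper.example.com/", Role::Helper).is_some());
    assert!(cache.get("https://helper.example.com/", Role::Leader).is_none());
}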
@@ -417,10 +421,14 @@ mod tests { ttl, ); - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(); assert!(task_aggregators.get(task.id()).await.unwrap().is_none()); diff --git a/aggregator/src/metrics.rs b/aggregator/src/metrics.rs index 944661d33..f0e9b17e1 100644 --- a/aggregator/src/metrics.rs +++ b/aggregator/src/metrics.rs @@ -329,8 +329,8 @@ pub(crate) fn aggregate_step_failure_counter(meter: &Meter) -> Counter { "duplicate_extension", "missing_client_report", "missing_prepare_message", - "missing_or_malformed_taskprov_extension", - "unexpected_taskprov_extension", + "missing_or_malformed_taskbind_extension", + "unexpected_taskbind_extension", ] { aggregate_step_failure_counter.add(0, &[KeyValue::new("type", failure_type)]); } diff --git a/aggregator/tests/integration/graceful_shutdown.rs b/aggregator/tests/integration/graceful_shutdown.rs index 567e4d684..edb551999 100644 --- a/aggregator/tests/integration/graceful_shutdown.rs +++ b/aggregator/tests/integration/graceful_shutdown.rs @@ -24,7 +24,7 @@ use janus_aggregator::{ }; use janus_aggregator_core::{ datastore::test_util::ephemeral_datastore, - task::{test_util::TaskBuilder, BatchMode}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode}, }; use janus_core::{ hpke::HpkeCiphersuite, test_util::install_test_trace_subscriber, time::RealClock, @@ -140,10 +140,14 @@ async fn graceful_shutdown(binary_name: &str, mut c common_config.database.connection_pool_timeouts_s = 60; common_config.health_check_listen_address = health_check_listen_address; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(); datastore.put_aggregator_task(&task).await.unwrap(); // Save the above configuration to a temporary file, so that we can pass @@ -415,6 +419,8 @@ async fn aggregation_job_driver_shutdown() { taskprov_config: TaskprovConfig::default(), batch_aggregation_shard_count: 32, task_counter_shard_count: 32, + hpke_configs_refresh_interval: None, + default_async_poll_interval: 1000, }; graceful_shutdown("aggregation_job_driver", config).await; diff --git a/aggregator_api/src/models.rs b/aggregator_api/src/models.rs index 4d391fc2c..f2a17004c 100644 --- a/aggregator_api/src/models.rs +++ b/aggregator_api/src/models.rs @@ -2,7 +2,7 @@ use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use educe::Educe; use janus_aggregator_core::{ datastore::models::{HpkeKeyState, HpkeKeypair, TaskAggregationCounter, TaskUploadCounter}, - task::{AggregatorTask, BatchMode}, + task::{AggregationMode, AggregatorTask, BatchMode}, taskprov::{PeerAggregator, VerifyKeyInit}, }; use janus_core::{ @@ -65,6 +65,9 @@ pub(crate) struct PostTaskReq { pub(crate) peer_aggregator_endpoint: Url, /// DAP batch mode for this task. pub(crate) batch_mode: BatchMode, + /// Aggregation mode (e.g. synchronous vs asynchronous) for this task. Populated if and only if + /// this is a Helper task. + pub(crate) aggregation_mode: Option, /// The VDAF being run by this task. pub(crate) vdaf: VdafInstance, /// The role that this aggregator will play in this task. 
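Since PostTaskReq's new `aggregation_mode` is populated if and only if the task is a Helper task (the routes change below rejects Helper tasks without it), a task-creation body now carries the extra field. A rough sketch of such a body; the `batch_mode` and `aggregation_mode` spellings follow the serde-token test later in this diff, while the endpoint, VDAF, and role values are placeholders:

use serde_json::json;

fn main() {
    // Hypothetical helper-task creation body; only `aggregation_mode` and the
    // externally-tagged `batch_mode` shape are taken from this diff.
    let body = json!({
        "peer_aggregator_endpoint": "https://example.com/",
        "batch_mode": { "LeaderSelected": { "batch_time_window_size": null } },
        "aggregation_mode": "Synchronous", // omit (None) when role is Leader
        "vdaf": "Prio3Count",
        "role": "Helper"
    });
    println!("{body}");
}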
@@ -202,7 +205,7 @@ pub(crate) struct PatchHpkeConfigReq { pub(crate) struct TaskprovPeerAggregatorResp { #[educe(Debug(method(std::fmt::Display::fmt)))] pub(crate) endpoint: Url, - pub(crate) role: Role, + pub(crate) peer_role: Role, pub(crate) collector_hpke_config: HpkeConfig, pub(crate) report_expiry_age: Option, pub(crate) tolerable_clock_skew: Duration, @@ -213,7 +216,7 @@ impl From for TaskprovPeerAggregatorResp { // Exclude sensitive values. Self { endpoint: value.endpoint().clone(), - role: *value.role(), + peer_role: *value.peer_role(), collector_hpke_config: value.collector_hpke_config().clone(), report_expiry_age: value.report_expiry_age().cloned(), tolerable_clock_skew: *value.tolerable_clock_skew(), @@ -224,7 +227,8 @@ impl From for TaskprovPeerAggregatorResp { #[derive(Serialize, Deserialize)] pub(crate) struct PostTaskprovPeerAggregatorReq { pub(crate) endpoint: Url, - pub(crate) role: Role, + pub(crate) peer_role: Role, + pub(crate) aggregation_mode: Option, pub(crate) collector_hpke_config: HpkeConfig, pub(crate) verify_key_init: VerifyKeyInit, pub(crate) report_expiry_age: Option, @@ -236,7 +240,7 @@ pub(crate) struct PostTaskprovPeerAggregatorReq { #[derive(Clone, Serialize, Deserialize)] pub(crate) struct DeleteTaskprovPeerAggregatorReq { pub(crate) endpoint: Url, - pub(crate) role: Role, + pub(crate) peer_role: Role, } // Any value that is present is considered Some value, including null. See diff --git a/aggregator_api/src/routes.rs b/aggregator_api/src/routes.rs index 3c8456a20..5069b1613 100644 --- a/aggregator_api/src/routes.rs +++ b/aggregator_api/src/routes.rs @@ -157,6 +157,12 @@ pub(super) async fn post_task( AggregatorTaskParameters::Helper { aggregator_auth_token_hash, collector_hpke_config: req.collector_hpke_config, + aggregation_mode: req.aggregation_mode.ok_or_else(|| { + Error::BadRequest( + "aggregator acting in helper role must be provided an aggregation mode" + .to_string(), + ) + })?, }, ) } @@ -442,7 +448,8 @@ pub(super) async fn post_taskprov_peer_aggregator( ) -> Result<(Status, Json), Error> { let to_insert = PeerAggregator::new( req.endpoint, - req.role, + req.peer_role, + req.aggregation_mode, req.verify_key_init, req.collector_hpke_config, req.report_expiry_age, @@ -456,7 +463,7 @@ pub(super) async fn post_taskprov_peer_aggregator( let to_insert = to_insert.clone(); Box::pin(async move { tx.put_taskprov_peer_aggregator(&to_insert).await?; - tx.get_taskprov_peer_aggregator(to_insert.endpoint(), to_insert.role()) + tx.get_taskprov_peer_aggregator(to_insert.endpoint(), to_insert.peer_role()) .await }) }) @@ -478,7 +485,7 @@ pub(super) async fn delete_taskprov_peer_aggregator( .run_tx("delete_taskprov_peer_aggregator", |tx| { let req = req.clone(); Box::pin(async move { - tx.delete_taskprov_peer_aggregator(&req.endpoint, &req.role) + tx.delete_taskprov_peer_aggregator(&req.endpoint, &req.peer_role) .await }) }) diff --git a/aggregator_api/src/tests.rs b/aggregator_api/src/tests.rs index 081f301e6..b2888a226 100644 --- a/aggregator_api/src/tests.rs +++ b/aggregator_api/src/tests.rs @@ -16,7 +16,10 @@ use janus_aggregator_core::{ test_util::{ephemeral_datastore, EphemeralDatastore}, Datastore, }, - task::{test_util::TaskBuilder, AggregatorTask, AggregatorTaskParameters, BatchMode}, + task::{ + test_util::TaskBuilder, AggregationMode, AggregatorTask, AggregatorTaskParameters, + BatchMode, + }, taskprov::test_util::PeerAggregatorBuilder, test_util::noop_meter, SecretBytes, @@ -26,7 +29,7 @@ use janus_core::{ hpke::HpkeKeypair, 
test_util::install_test_trace_subscriber, time::MockClock, - vdaf::{vdaf_dp_strategies, VdafInstance, VERIFY_KEY_LENGTH}, + vdaf::{vdaf_dp_strategies, VdafInstance, VERIFY_KEY_LENGTH_PRIO3}, }; use janus_messages::{ Duration, HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeKdfId, HpkeKemId, HpkePublicKey, Role, @@ -92,10 +95,14 @@ async fn get_task_ids() { .run_unnamed_tx(|tx| { Box::pin(async move { let tasks: Vec<_> = iter::repeat_with(|| { - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .build() - .leader_view() - .unwrap() + TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build() + .leader_view() + .unwrap() }) .take(10) .collect(); @@ -186,7 +193,7 @@ async fn post_task_bad_role() { let vdaf_verify_key = SecretBytes::new( thread_rng() .sample_iter(Standard) - .take(VERIFY_KEY_LENGTH) + .take(VERIFY_KEY_LENGTH_PRIO3) .collect(), ); let aggregator_auth_token = AuthenticationToken::DapAuth(random()); @@ -194,6 +201,7 @@ async fn post_task_bad_role() { let req = PostTaskReq { peer_aggregator_endpoint: "http://aggregator.endpoint".try_into().unwrap(), batch_mode: BatchMode::TimeInterval, + aggregation_mode: None, vdaf: VdafInstance::Prio3Count, role: Role::Collector, vdaf_verify_key: URL_SAFE_NO_PAD.encode(&vdaf_verify_key), @@ -225,7 +233,7 @@ async fn post_task_unauthorized() { let vdaf_verify_key = SecretBytes::new( thread_rng() .sample_iter(Standard) - .take(VERIFY_KEY_LENGTH) + .take(VERIFY_KEY_LENGTH_PRIO3) .collect(), ); let aggregator_auth_token = AuthenticationToken::DapAuth(random()); @@ -233,6 +241,7 @@ async fn post_task_unauthorized() { let req = PostTaskReq { peer_aggregator_endpoint: "http://aggregator.endpoint".try_into().unwrap(), batch_mode: BatchMode::TimeInterval, + aggregation_mode: Some(AggregationMode::Synchronous), vdaf: VdafInstance::Prio3Count, role: Role::Helper, vdaf_verify_key: URL_SAFE_NO_PAD.encode(&vdaf_verify_key), @@ -265,7 +274,7 @@ async fn post_task_helper_no_optional_fields() { let vdaf_verify_key = SecretBytes::new( thread_rng() .sample_iter(Standard) - .take(VERIFY_KEY_LENGTH) + .take(VERIFY_KEY_LENGTH_PRIO3) .collect(), ); @@ -273,6 +282,7 @@ async fn post_task_helper_no_optional_fields() { let req = PostTaskReq { peer_aggregator_endpoint: "http://aggregator.endpoint".try_into().unwrap(), batch_mode: BatchMode::TimeInterval, + aggregation_mode: Some(AggregationMode::Synchronous), vdaf: VdafInstance::Prio3Count, role: Role::Helper, vdaf_verify_key: URL_SAFE_NO_PAD.encode(&vdaf_verify_key), @@ -350,7 +360,7 @@ async fn post_task_helper_with_aggregator_auth_token() { let vdaf_verify_key = SecretBytes::new( thread_rng() .sample_iter(Standard) - .take(VERIFY_KEY_LENGTH) + .take(VERIFY_KEY_LENGTH_PRIO3) .collect(), ); let aggregator_auth_token = AuthenticationToken::DapAuth(random()); @@ -359,6 +369,7 @@ async fn post_task_helper_with_aggregator_auth_token() { let req = PostTaskReq { peer_aggregator_endpoint: "http://aggregator.endpoint".try_into().unwrap(), batch_mode: BatchMode::TimeInterval, + aggregation_mode: Some(AggregationMode::Synchronous), vdaf: VdafInstance::Prio3Count, role: Role::Helper, vdaf_verify_key: URL_SAFE_NO_PAD.encode(&vdaf_verify_key), @@ -391,7 +402,7 @@ async fn post_task_idempotence() { let vdaf_verify_key = SecretBytes::new( thread_rng() .sample_iter(Standard) - .take(VERIFY_KEY_LENGTH) + .take(VERIFY_KEY_LENGTH_PRIO3) .collect(), ); let aggregator_auth_token = AuthenticationToken::DapAuth(random()); @@ -400,6 +411,7 @@ async fn 
post_task_idempotence() { let mut req = PostTaskReq { peer_aggregator_endpoint: "http://aggregator.endpoint".try_into().unwrap(), batch_mode: BatchMode::TimeInterval, + aggregation_mode: Some(AggregationMode::Synchronous), vdaf: VdafInstance::Prio3Count, role: Role::Leader, vdaf_verify_key: URL_SAFE_NO_PAD.encode(&vdaf_verify_key), @@ -470,7 +482,7 @@ async fn post_task_leader_all_optional_fields() { let vdaf_verify_key = SecretBytes::new( thread_rng() .sample_iter(Standard) - .take(VERIFY_KEY_LENGTH) + .take(VERIFY_KEY_LENGTH_PRIO3) .collect(), ); let aggregator_auth_token = AuthenticationToken::DapAuth(random()); @@ -479,6 +491,7 @@ async fn post_task_leader_all_optional_fields() { let req = PostTaskReq { peer_aggregator_endpoint: "http://aggregator.endpoint".try_into().unwrap(), batch_mode: BatchMode::TimeInterval, + aggregation_mode: Some(AggregationMode::Synchronous), vdaf: VdafInstance::Prio3Count, role: Role::Leader, vdaf_verify_key: URL_SAFE_NO_PAD.encode(&vdaf_verify_key), @@ -555,7 +568,7 @@ async fn post_task_leader_no_aggregator_auth_token() { let vdaf_verify_key = SecretBytes::new( thread_rng() .sample_iter(Standard) - .take(VERIFY_KEY_LENGTH) + .take(VERIFY_KEY_LENGTH_PRIO3) .collect(), ); @@ -563,6 +576,7 @@ async fn post_task_leader_no_aggregator_auth_token() { let req = PostTaskReq { peer_aggregator_endpoint: "http://aggregator.endpoint".try_into().unwrap(), batch_mode: BatchMode::TimeInterval, + aggregation_mode: Some(AggregationMode::Synchronous), vdaf: VdafInstance::Prio3Count, role: Role::Leader, vdaf_verify_key: URL_SAFE_NO_PAD.encode(&vdaf_verify_key), @@ -595,10 +609,14 @@ async fn get_task(#[case] role: Role) { // Setup: write a task to the datastore. let (handler, _ephemeral_datastore, ds) = setup_api_test().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .build() - .view_for_role(role) - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build() + .view_for_role(role) + .unwrap(); ds.put_aggregator_task(&task).await.unwrap(); @@ -651,11 +669,14 @@ async fn delete_task() { let task_id = ds .run_unnamed_tx(|tx| { Box::pin(async move { - let task = - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build() + .leader_view() + .unwrap(); tx.put_aggregator_task(&task).await?; @@ -726,11 +747,15 @@ async fn patch_task(#[case] role: Role) { // Setup: write a task to the datastore. 
let (handler, _ephemeral_datastore, ds) = setup_api_test().await; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_task_end(Some(Time::from_seconds_since_epoch(1000))) - .build() - .view_for_role(role) - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_task_end(Some(Time::from_seconds_since_epoch(1000))) + .build() + .view_for_role(role) + .unwrap(); ds.put_aggregator_task(&task).await.unwrap(); let task_id = *task.id(); @@ -842,11 +867,14 @@ async fn get_task_upload_metrics() { let task_id = ds .run_unnamed_tx(|tx| { Box::pin(async move { - let task = - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build() + .leader_view() + .unwrap(); let task_id = *task.id(); tx.put_aggregator_task(&task).await.unwrap(); @@ -921,11 +949,14 @@ async fn get_task_aggregation_metrics() { let task_id = ds .run_unnamed_tx(|tx| { Box::pin(async move { - let task = - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .build() + .leader_view() + .unwrap(); let task_id = *task.id(); tx.put_aggregator_task(&task).await.unwrap(); @@ -1423,11 +1454,11 @@ async fn get_taskprov_peer_aggregator() { let leader = PeerAggregatorBuilder::new() .with_endpoint(Url::parse("https://leader.example.com/").unwrap()) - .with_role(Role::Leader) + .with_peer_role(Role::Leader) .build(); let helper = PeerAggregatorBuilder::new() .with_endpoint(Url::parse("https://helper.example.com/").unwrap()) - .with_role(Role::Helper) + .with_peer_role(Role::Helper) .build(); ds.run_unnamed_tx(|tx| { @@ -1464,14 +1495,14 @@ async fn get_taskprov_peer_aggregator() { let mut expected = vec![ TaskprovPeerAggregatorResp { endpoint: leader.endpoint().clone(), - role: *leader.role(), + peer_role: *leader.peer_role(), collector_hpke_config: leader.collector_hpke_config().clone(), report_expiry_age: leader.report_expiry_age().cloned(), tolerable_clock_skew: *leader.tolerable_clock_skew(), }, TaskprovPeerAggregatorResp { endpoint: helper.endpoint().clone(), - role: *helper.role(), + peer_role: *helper.peer_role(), collector_hpke_config: helper.collector_hpke_config().clone(), report_expiry_age: helper.report_expiry_age().cloned(), tolerable_clock_skew: *helper.tolerable_clock_skew(), @@ -1499,12 +1530,13 @@ async fn post_taskprov_peer_aggregator() { let endpoint = Url::parse("https://leader.example.com/").unwrap(); let leader = PeerAggregatorBuilder::new() .with_endpoint(endpoint.clone()) - .with_role(Role::Leader) + .with_peer_role(Role::Leader) .build(); let req = PostTaskprovPeerAggregatorReq { endpoint, - role: Role::Leader, + peer_role: Role::Leader, + aggregation_mode: Some(AggregationMode::Synchronous), collector_hpke_config: leader.collector_hpke_config().clone(), verify_key_init: *leader.verify_key_init(), report_expiry_age: leader.report_expiry_age().cloned(), @@ -1574,7 +1606,7 @@ async fn delete_taskprov_peer_aggregator() { let endpoint = Url::parse("https://leader.example.com/").unwrap(); let leader = PeerAggregatorBuilder::new() .with_endpoint(endpoint.clone()) - .with_role(Role::Leader) + 
.with_peer_role(Role::Leader) .build(); ds.run_unnamed_tx(|tx| { @@ -1586,7 +1618,7 @@ async fn delete_taskprov_peer_aggregator() { let req = DeleteTaskprovPeerAggregatorReq { endpoint, - role: Role::Leader, + peer_role: Role::Leader, }; // Delete target. @@ -1616,7 +1648,7 @@ async fn delete_taskprov_peer_aggregator() { .with_request_body( serde_json::to_vec(&DeleteTaskprovPeerAggregatorReq { endpoint: Url::parse("https://doesnt-exist.example.com/").unwrap(), - role: Role::Leader, + peer_role: Role::Leader, }) .unwrap() ) @@ -1689,6 +1721,7 @@ fn post_task_req_serialization() { batch_mode: BatchMode::LeaderSelected { batch_time_window_size: None, }, + aggregation_mode: Some(AggregationMode::Synchronous), vdaf: VdafInstance::Prio3SumVec { bits: 1, length: 5, @@ -1714,7 +1747,7 @@ fn post_task_req_serialization() { &[ Token::Struct { name: "PostTaskReq", - len: 12, + len: 13, }, Token::Str("peer_aggregator_endpoint"), Token::Str("https://example.com/"), @@ -1727,6 +1760,12 @@ fn post_task_req_serialization() { Token::Str("batch_time_window_size"), Token::None, Token::StructVariantEnd, + Token::Str("aggregation_mode"), + Token::Some, + Token::UnitVariant { + name: "AggregationMode", + variant: "Synchronous", + }, Token::Str("vdaf"), Token::StructVariant { name: "VdafInstance", @@ -1807,6 +1846,7 @@ fn post_task_req_serialization() { batch_mode: BatchMode::LeaderSelected { batch_time_window_size: None, }, + aggregation_mode: None, vdaf: VdafInstance::Prio3SumVec { bits: 1, length: 5, @@ -1836,7 +1876,7 @@ fn post_task_req_serialization() { &[ Token::Struct { name: "PostTaskReq", - len: 12, + len: 13, }, Token::Str("peer_aggregator_endpoint"), Token::Str("https://example.com/"), @@ -1849,6 +1889,8 @@ fn post_task_req_serialization() { Token::Str("batch_time_window_size"), Token::None, Token::StructVariantEnd, + Token::Str("aggregation_mode"), + Token::None, Token::Str("vdaf"), Token::StructVariant { name: "VdafInstance", diff --git a/aggregator_core/src/batch_mode.rs b/aggregator_core/src/batch_mode.rs index dd5022f66..98e4e3a11 100644 --- a/aggregator_core/src/batch_mode.rs +++ b/aggregator_core/src/batch_mode.rs @@ -429,7 +429,7 @@ impl CollectableBatchMode for LeaderSelected { mod tests { use crate::{ batch_mode::CollectableBatchMode, - task::{test_util::TaskBuilder, BatchMode}, + task::{test_util::TaskBuilder, AggregationMode, BatchMode}, }; use janus_core::vdaf::VdafInstance; use janus_messages::{batch_mode::TimeInterval, Duration, Interval, Time}; @@ -437,11 +437,15 @@ mod tests { #[test] fn validate_collect_identifier() { let time_precision_secs = 3600; - let task = TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Fake { rounds: 1 }) - .with_time_precision(Duration::from_seconds(time_precision_secs)) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_time_precision(Duration::from_seconds(time_precision_secs)) + .build() + .leader_view() + .unwrap(); struct TestCase { name: &'static str, diff --git a/aggregator_core/src/datastore.rs b/aggregator_core/src/datastore.rs index 521de03c0..e1e2de80b 100644 --- a/aggregator_core/src/datastore.rs +++ b/aggregator_core/src/datastore.rs @@ -13,7 +13,7 @@ use self::models::{ use crate::VdafHasAggregationParameter; use crate::{ batch_mode::{AccumulableBatchMode, CollectableBatchMode}, - task::{self, AggregatorTask, AggregatorTaskParameters}, + task::{self, AggregationMode, AggregatorTask, AggregatorTaskParameters}, 
taskprov::PeerAggregator, SecretBytes, TIME_HISTOGRAM_BOUNDARIES, }; @@ -29,8 +29,8 @@ use janus_core::{ use janus_messages::{ batch_mode::{BatchMode, LeaderSelected, TimeInterval}, AggregationJobId, BatchId, CollectionJobId, Duration, Extension, HpkeCiphertext, HpkeConfig, - HpkeConfigId, Interval, PrepareResp, Query, ReportId, ReportIdChecksum, ReportMetadata, - ReportShare, Role, TaskId, Time, + HpkeConfigId, Interval, PrepareContinue, PrepareInit, PrepareResp, Query, ReportId, + ReportIdChecksum, ReportMetadata, Role, TaskId, Time, }; use models::UnaggregatedReport; use opentelemetry::{ @@ -658,15 +658,16 @@ WHERE success = TRUE ORDER BY version DESC LIMIT(1)", .prepare_cached( "-- put_aggregator_task() INSERT INTO tasks ( - task_id, aggregator_role, peer_aggregator_endpoint, batch_mode, vdaf, - task_start, task_end, report_expiry_age, min_batch_size, time_precision, - tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, - taskprov_task_info, aggregator_auth_token_type, aggregator_auth_token, - aggregator_auth_token_hash, collector_auth_token_type, - collector_auth_token_hash, created_at, updated_at, updated_by) + task_id, aggregator_role, aggregation_mode, peer_aggregator_endpoint, + batch_mode, vdaf, task_start, task_end, report_expiry_age, min_batch_size, + time_precision, tolerable_clock_skew, collector_hpke_config, + vdaf_verify_key, taskprov_task_info, aggregator_auth_token_type, + aggregator_auth_token, aggregator_auth_token_hash, + collector_auth_token_type, collector_auth_token_hash, created_at, + updated_at, updated_by) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, - $19, $20, $21, $22 + $19, $20, $21, $22, $23 ) ON CONFLICT DO NOTHING", ) @@ -677,6 +678,7 @@ ON CONFLICT DO NOTHING", &[ /* task_id */ &task.id().as_ref(), /* aggregator_role */ &AggregatorRole::from_role(*task.role())?, + /* aggregation_mode */ &task.aggregation_mode().copied(), /* peer_aggregator_endpoint */ &task.peer_aggregator_endpoint().as_str(), /* batch_mode */ &Json(task.batch_mode()), @@ -814,11 +816,12 @@ UPDATE tasks SET task_end = $1, updated_at = $2, updated_by = $3 .prepare_cached( "-- get_aggregator_task() SELECT - aggregator_role, peer_aggregator_endpoint, batch_mode, vdaf, task_start, - task_end, report_expiry_age, min_batch_size, time_precision, - tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, - taskprov_task_info, aggregator_auth_token_type, aggregator_auth_token, - aggregator_auth_token_hash, collector_auth_token_type, collector_auth_token_hash + aggregator_role, aggregation_mode, peer_aggregator_endpoint, batch_mode, + vdaf, task_start, task_end, report_expiry_age, min_batch_size, + time_precision, tolerable_clock_skew, collector_hpke_config, + vdaf_verify_key, taskprov_task_info, aggregator_auth_token_type, + aggregator_auth_token, aggregator_auth_token_hash, + collector_auth_token_type, collector_auth_token_hash FROM tasks WHERE task_id = $1", ) .await?; @@ -836,11 +839,12 @@ FROM tasks WHERE task_id = $1", .prepare_cached( "-- get_aggregator_tasks() SELECT - task_id, aggregator_role, peer_aggregator_endpoint, batch_mode, vdaf, task_start, - task_end, report_expiry_age, min_batch_size, time_precision, - tolerable_clock_skew, collector_hpke_config, vdaf_verify_key, - taskprov_task_info, aggregator_auth_token_type, aggregator_auth_token, - aggregator_auth_token_hash, collector_auth_token_type, collector_auth_token_hash + task_id, aggregator_role, aggregation_mode, peer_aggregator_endpoint, + batch_mode, vdaf, task_start, 
task_end, report_expiry_age, min_batch_size, + time_precision, tolerable_clock_skew, collector_hpke_config, + vdaf_verify_key, taskprov_task_info, aggregator_auth_token_type, + aggregator_auth_token, aggregator_auth_token_hash, + collector_auth_token_type, collector_auth_token_hash FROM tasks", ) .await?; @@ -858,6 +862,7 @@ FROM tasks", let aggregator_role: AggregatorRole = row.get("aggregator_role"); let peer_aggregator_endpoint = row.get::<_, String>("peer_aggregator_endpoint").parse()?; let batch_mode = row.try_get::<_, Json>("batch_mode")?.0; + let aggregation_mode: Option<AggregationMode> = row.get("aggregation_mode"); let vdaf = row.try_get::<_, Json>("vdaf")?.0; let task_start = row .get::<_, Option>("task_start") @@ -920,6 +925,7 @@ FROM tasks", let aggregator_parameters = match ( aggregator_role, + aggregation_mode, aggregator_auth_token, aggregator_auth_token_hash, collector_auth_token_hash, @@ -927,6 +933,7 @@ FROM tasks", ) { ( AggregatorRole::Leader, + None, Some(aggregator_auth_token), None, Some(collector_auth_token_hash), @@ -938,6 +945,7 @@ FROM tasks", }, ( AggregatorRole::Helper, + Some(aggregation_mode), None, Some(aggregator_auth_token_hash), None, @@ -945,9 +953,10 @@ FROM tasks", ) => AggregatorTaskParameters::Helper { aggregator_auth_token_hash, collector_hpke_config, + aggregation_mode, }, - (AggregatorRole::Helper, None, None, None, None) => { - AggregatorTaskParameters::TaskprovHelper + (AggregatorRole::Helper, Some(aggregation_mode), None, None, None, None) => { + AggregatorTaskParameters::TaskprovHelper { aggregation_mode } } values => { return Err(Error::DbState(format!( @@ -1671,12 +1680,13 @@ WHERE task_id = $1 ); } - /// put_scrubbed_report stores a scrubbed report, given its associated task ID & report share. + /// put_scrubbed_report stores a scrubbed report, given its associated task ID, report ID, and client timestamp. #[tracing::instrument(skip(self), err(level = Level::DEBUG))] pub async fn put_scrubbed_report( &self, task_id: &TaskId, - report_share: &ReportShare, + report_id: &ReportId, + client_timestamp: &Time, ) -> Result<(), Error> { let task_info = match self.task_info_for(task_id).await?
{ Some(task_info) => task_info, @@ -1713,9 +1723,8 @@ WHERE client_reports.client_timestamp < $7", &stmt, &[ /* task_id */ &task_info.pkey, - /* report_id */ &report_share.metadata().id().as_ref(), - /* client_timestamp */ - &report_share.metadata().time().as_naive_date_time()?, + /* report_id */ &report_id.as_ref(), + /* client_timestamp */ &client_timestamp.as_naive_date_time()?, /* created_at */ &now, /* updated_at */ &now, /* updated_by */ &self.name, @@ -1873,8 +1882,7 @@ WHERE aggregation_jobs.task_id = $1 WITH incomplete_jobs AS ( SELECT aggregation_jobs.id FROM aggregation_jobs JOIN tasks ON tasks.id = aggregation_jobs.task_id - WHERE tasks.aggregator_role = 'LEADER' - AND aggregation_jobs.state = 'IN_PROGRESS' + WHERE aggregation_jobs.state = 'ACTIVE' AND aggregation_jobs.lease_expiry <= $2 AND UPPER(aggregation_jobs.client_timestamp_interval) >= COALESCE($2::TIMESTAMP - tasks.report_expiry_age * '1 second'::INTERVAL, @@ -1915,6 +1923,7 @@ RETURNING tasks.task_id, tasks.batch_mode, tasks.vdaf, let vdaf = row.try_get::<_, Json>("vdaf")?.0; let lease_token = row.get_bytea_and_convert::("lease_token")?; let lease_attempts = row.get_bigint_and_convert("lease_attempts")?; + Ok(Lease::new( AcquiredAggregationJob::new(task_id, aggregation_job_id, batch_mode, vdaf), lease_expiry_time, @@ -2122,7 +2131,8 @@ SELECT report_aggregations.state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, leader_prep_transition, leader_prep_state, - leader_output_share, helper_prep_state, error_code + leader_output_share, prepare_init, require_taskbind_extension, + helper_prep_state, prepare_continue, error_code FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id WHERE report_aggregations.task_id = $1 @@ -2186,7 +2196,8 @@ SELECT ord, client_timestamp, last_prep_resp, report_aggregations.state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, leader_prep_transition, - leader_prep_state, leader_output_share, helper_prep_state, error_code + leader_prep_state, leader_output_share, prepare_init, + require_taskbind_extension, helper_prep_state, prepare_continue, error_code FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id @@ -2251,7 +2262,8 @@ SELECT client_timestamp, last_prep_resp, report_aggregations.state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, leader_prep_transition, - leader_prep_state, leader_output_share, helper_prep_state, error_code + leader_prep_state, leader_output_share, prepare_init, + require_taskbind_extension, helper_prep_state, prepare_continue, error_code FROM report_aggregations JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id WHERE report_aggregations.task_id = $1 @@ -2370,6 +2382,31 @@ WHERE report_aggregations.task_id = $1 } } + ReportAggregationStateCode::InitProcessing => { + let prepare_init_bytes = + row.get::<_, Option>>("prepare_init") + .ok_or_else(|| { + Error::DbState( + "report aggregation in state INIT_PROCESSING but prepare_init is NULL" + .to_string(), + ) + })?; + let require_taskbind_extension = row.get::<_, Option>("require_taskbind_extension") + .ok_or_else(|| { + Error::DbState( + "report aggregation in state INIT_PROCESSING but require_taskbind_extension is NULL" + .to_string(), + ) + })?; + + let prepare_init = 
PrepareInit::get_decoded(&prepare_init_bytes)?; + + ReportAggregationState::HelperInitProcessing { + prepare_init, + require_taskbind_extension, + } + } + ReportAggregationStateCode::Continue => { match role { Role::Leader => { @@ -2377,7 +2414,7 @@ WHERE report_aggregations.task_id = $1 .get::<_, Option<Vec<u8>>>("leader_prep_transition") .ok_or_else(|| { Error::DbState( - "report aggregation in state WAITING but leader_prep_transition is NULL" + "report aggregation in state CONTINUE but leader_prep_transition is NULL" .to_string(), ) })?; @@ -2395,7 +2432,7 @@ WHERE report_aggregations.task_id = $1 .get::<_, Option<Vec<u8>>>("helper_prep_state") .ok_or_else(|| { Error::DbState( - "report aggregation in state WAITING but helper_prep_state is NULL" + "report aggregation in state CONTINUE but helper_prep_state is NULL" .to_string(), ) })?; @@ -2410,6 +2447,36 @@ WHERE report_aggregations.task_id = $1 } } + ReportAggregationStateCode::ContinueProcessing => { + let helper_prep_state_bytes = row + .get::<_, Option<Vec<u8>>>("helper_prep_state") + .ok_or_else(|| { + Error::DbState( + "report aggregation in state CONTINUE_PROCESSING but helper_prep_state is NULL" + .to_string(), + ) + })?; + let prepare_continue_bytes = row + .get::<_, Option<Vec<u8>>>("prepare_continue") + .ok_or_else(|| { + Error::DbState( + "report aggregation in state CONTINUE_PROCESSING but prepare_continue is NULL" + .to_string(), + ) + })?; + + let prepare_state = A::PrepareState::get_decoded_with_param( + &(vdaf, 1 /* helper */), + &helper_prep_state_bytes, + )?; + let prepare_continue = PrepareContinue::get_decoded(&prepare_continue_bytes)?; + + ReportAggregationState::HelperContinueProcessing { + prepare_state, + prepare_continue, + } + } + ReportAggregationStateCode::Poll => { let leader_prep_state_bytes = row.get::<_, Option<Vec<u8>>>("leader_prep_state"); let leader_output_share_bytes = @@ -2507,10 +2574,12 @@ INSERT INTO report_aggregations last_prep_resp, state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, leader_prep_transition, leader_prep_state, - leader_output_share, helper_prep_state, error_code, created_at, updated_at, updated_by) + leader_output_share, prepare_init, require_taskbind_extension, + helper_prep_state, prepare_continue, error_code, created_at, updated_at, + updated_by) SELECT $1, aggregation_jobs.id, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, - $15, $16, $17, $18, $19, $20 + $15, $16, $17, $18, $19, $20, $21, $22, $23 FROM aggregation_jobs WHERE task_id = $1 AND aggregation_job_id = $2 @@ -2520,20 +2589,22 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, leader_prep_transition, leader_prep_state, leader_output_share, - helper_prep_state, error_code, created_at, updated_at, updated_by + prepare_init, require_taskbind_extension, helper_prep_state, + prepare_continue, error_code, created_at, updated_at, updated_by ) = ( excluded.client_report_id, excluded.client_timestamp, excluded.last_prep_resp, excluded.state, excluded.public_extensions, excluded.public_share, excluded.leader_private_extensions, excluded.leader_input_share, excluded.helper_encrypted_input_share, excluded.leader_prep_transition, excluded.leader_prep_state, - excluded.leader_output_share, excluded.helper_prep_state, - excluded.error_code, excluded.created_at, excluded.updated_at, - excluded.updated_by + excluded.leader_output_share, excluded.prepare_init, +
excluded.require_taskbind_extension, excluded.helper_prep_state, + excluded.prepare_continue, excluded.error_code, excluded.created_at, + excluded.updated_at, excluded.updated_by ) WHERE (SELECT UPPER(client_timestamp_interval) FROM aggregation_jobs - WHERE id = report_aggregations.aggregation_job_id) >= $21", + WHERE id = report_aggregations.aggregation_job_id) >= $24", ) .await?; check_insert( @@ -2559,7 +2630,11 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE &encoded_state_values.leader_prep_transition, /* leader_prep_state */ &encoded_state_values.leader_prep_state, /* leader_output_share */ &encoded_state_values.leader_output_share, + /* prepare_init */ &encoded_state_values.prepare_init, + /* require_taskbind_extension */ + &encoded_state_values.require_taskbind_extension, /* helper_prep_state */ &encoded_state_values.helper_prep_state, + /* prepare_continue */ &encoded_state_values.prepare_continue, /* error_code */ &encoded_state_values.report_error, /* created_at */ &now, /* updated_at */ &now, @@ -2615,14 +2690,17 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE client_report_id, client_timestamp, last_prep_resp, state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, - leader_prep_transition, helper_prep_state, error_code, created_at, - updated_at, updated_by + leader_prep_transition, leader_prep_state, leader_output_share, + prepare_init, helper_prep_state, prepare_continue, error_code, + created_at, updated_at, updated_by ) = ( excluded.client_report_id, excluded.client_timestamp, excluded.last_prep_resp, excluded.state, excluded.public_extensions, excluded.public_share, excluded.leader_private_extensions, excluded.leader_input_share, excluded.helper_encrypted_input_share, - excluded.leader_prep_transition, excluded.helper_prep_state, + excluded.leader_prep_transition, excluded.leader_prep_state, + excluded.leader_output_share, excluded.prepare_init, + excluded.helper_prep_state, excluded.prepare_continue, excluded.error_code, excluded.created_at, excluded.updated_at, excluded.updated_by ) @@ -2674,14 +2752,17 @@ ON CONFLICT(task_id, aggregation_job_id, ord) DO UPDATE client_report_id, client_timestamp, last_prep_resp, state, public_extensions, public_share, leader_private_extensions, leader_input_share, helper_encrypted_input_share, - leader_prep_transition, helper_prep_state, error_code, created_at, - updated_at, updated_by + leader_prep_transition, leader_prep_state, leader_output_share, + prepare_init, helper_prep_state, prepare_continue, error_code, + created_at, updated_at, updated_by ) = ( excluded.client_report_id, excluded.client_timestamp, excluded.last_prep_resp, excluded.state, excluded.public_extensions, excluded.public_share, excluded.leader_private_extensions, excluded.leader_input_share, excluded.helper_encrypted_input_share, - excluded.leader_prep_transition, excluded.helper_prep_state, + excluded.leader_prep_transition, excluded.leader_prep_state, + excluded.leader_output_share, excluded.prepare_init, + excluded.helper_prep_state, excluded.prepare_continue, excluded.error_code, excluded.created_at, excluded.updated_at, excluded.updated_by ) @@ -2747,18 +2828,19 @@ SET last_prep_resp = $1, state = $2, public_extensions = $3, public_share = $4, leader_private_extensions = $5, leader_input_share = $6, helper_encrypted_input_share = $7, leader_prep_transition = $8, - leader_prep_state = $9, leader_output_share = $10, - helper_prep_state = $11, error_code = $12, updated_at = $13, - 
updated_by = $14 + leader_prep_state = $9, leader_output_share = $10, prepare_init = $11, + require_taskbind_extension = $12, helper_prep_state = $13, + prepare_continue = $14, error_code = $15, updated_at = $16, + updated_by = $17 FROM aggregation_jobs WHERE report_aggregations.aggregation_job_id = aggregation_jobs.id - AND aggregation_jobs.aggregation_job_id = $15 - AND aggregation_jobs.task_id = $16 - AND report_aggregations.task_id = $16 - AND report_aggregations.client_report_id = $17 - AND report_aggregations.client_timestamp = $18 - AND report_aggregations.ord = $19 - AND UPPER(aggregation_jobs.client_timestamp_interval) >= $20", + AND aggregation_jobs.aggregation_job_id = $18 + AND aggregation_jobs.task_id = $19 + AND report_aggregations.task_id = $19 + AND report_aggregations.client_report_id = $20 + AND report_aggregations.client_timestamp = $21 + AND report_aggregations.ord = $22 + AND UPPER(aggregation_jobs.client_timestamp_interval) >= $23", ) .await?; check_single_row_mutation( @@ -2778,7 +2860,11 @@ WHERE report_aggregations.aggregation_job_id = aggregation_jobs.id &encoded_state_values.leader_prep_transition, /* leader_prep_state */ &encoded_state_values.leader_prep_state, /* leader_output_share */ &encoded_state_values.leader_output_share, + /* prepare_init */ &encoded_state_values.prepare_init, + /* require_taskbind_extension */ + &encoded_state_values.require_taskbind_extension, /* helper_prep_state */ &encoded_state_values.helper_prep_state, + /* prepare_continue */ &encoded_state_values.prepare_continue, /* error_code */ &encoded_state_values.report_error, /* updated_at */ &now, /* updated_by */ &self.name, @@ -5038,7 +5124,7 @@ INSERT INTO hpke_keys let stmt = self .prepare_cached( "-- get_taskprov_peer_aggregators() -SELECT id, endpoint, role, verify_key_init, collector_hpke_config, +SELECT id, endpoint, peer_role, aggregation_mode, verify_key_init, collector_hpke_config, report_expiry_age, tolerable_clock_skew FROM taskprov_peer_aggregators", ) @@ -5119,9 +5205,9 @@ ord, type, token FROM taskprov_collector_auth_tokens AS a let stmt = self .prepare_cached( "-- get_taskprov_peer_aggregator() -SELECT endpoint, role, verify_key_init, collector_hpke_config, +SELECT endpoint, peer_role, aggregation_mode, verify_key_init, collector_hpke_config, report_expiry_age, tolerable_clock_skew - FROM taskprov_peer_aggregators WHERE endpoint = $1 AND role = $2", + FROM taskprov_peer_aggregators WHERE endpoint = $1 AND peer_role = $2", ) .await?; let peer_aggregator_row = self.query_opt(&stmt, params); @@ -5131,7 +5217,7 @@ SELECT endpoint, role, verify_key_init, collector_hpke_config, "-- get_taskprov_peer_aggregator() SELECT ord, type, token FROM taskprov_aggregator_auth_tokens WHERE peer_aggregator_id = (SELECT id FROM taskprov_peer_aggregators - WHERE endpoint = $1 AND role = $2) + WHERE endpoint = $1 AND peer_role = $2) ORDER BY ord ASC", ) .await?; @@ -5142,7 +5228,7 @@ SELECT ord, type, token FROM taskprov_aggregator_auth_tokens "-- get_taskprov_peer_aggregator() SELECT ord, type, token FROM taskprov_collector_auth_tokens WHERE peer_aggregator_id = (SELECT id FROM taskprov_peer_aggregators - WHERE endpoint = $1 AND role = $2) + WHERE endpoint = $1 AND peer_role = $2) ORDER BY ord ASC", ) .await?; @@ -5172,7 +5258,8 @@ SELECT ord, type, token FROM taskprov_collector_auth_tokens ) -> Result { let endpoint = Url::parse(peer_aggregator_row.get::<_, &str>("endpoint"))?; let endpoint_bytes = endpoint.as_str().as_ref(); - let role: AggregatorRole = peer_aggregator_row.get("role"); + 
let peer_role: AggregatorRole = peer_aggregator_row.get("peer_role"); + let aggregation_mode: Option = peer_aggregator_row.get("aggregation_mode"); let report_expiry_age = peer_aggregator_row .get_nullable_bigint_and_convert("report_expiry_age")? .map(Duration::from_seconds); @@ -5203,7 +5290,7 @@ SELECT ord, type, token FROM taskprov_collector_auth_tokens let mut row_id = Vec::new(); row_id.extend_from_slice(endpoint_bytes); - row_id.extend_from_slice(&role.as_role().get_encoded()?); + row_id.extend_from_slice(&peer_role.as_role().get_encoded()?); row_id.extend_from_slice(&ord.to_be_bytes()); auth_token_type.as_authentication(&self.crypter.decrypt( @@ -5225,7 +5312,8 @@ SELECT ord, type, token FROM taskprov_collector_auth_tokens Ok(PeerAggregator::new( endpoint, - role.as_role(), + peer_role.as_role(), + aggregation_mode, verify_key_init, collector_hpke_config, report_expiry_age, @@ -5241,7 +5329,7 @@ SELECT ord, type, token FROM taskprov_collector_auth_tokens peer_aggregator: &PeerAggregator, ) -> Result<(), Error> { let endpoint = peer_aggregator.endpoint().as_str(); - let role = &AggregatorRole::from_role(*peer_aggregator.role())?; + let peer_role = &AggregatorRole::from_role(*peer_aggregator.peer_role())?; let encrypted_verify_key_init = self.crypter.encrypt( "taskprov_peer_aggregator", endpoint.as_ref(), @@ -5253,9 +5341,9 @@ SELECT ord, type, token FROM taskprov_collector_auth_tokens .prepare_cached( "-- put_taskprov_peer_aggregator() INSERT INTO taskprov_peer_aggregators ( - endpoint, role, verify_key_init, tolerable_clock_skew, report_expiry_age, + endpoint, peer_role, aggregation_mode, verify_key_init, tolerable_clock_skew, report_expiry_age, collector_hpke_config, created_at, updated_by -) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT DO NOTHING", ) .await?; @@ -5264,7 +5352,8 @@ ON CONFLICT DO NOTHING", &stmt, &[ /* endpoint */ &endpoint, - /* role */ role, + /* peer_role */ peer_role, + /* aggregation_mode */ &peer_aggregator.aggregation_mode(), /* verify_key_init */ &encrypted_verify_key_init, /* tolerable_clock_skew */ &i64::try_from(peer_aggregator.tolerable_clock_skew().as_seconds())?, @@ -5293,7 +5382,7 @@ ON CONFLICT DO NOTHING", let mut row_id = Vec::new(); row_id.extend_from_slice(endpoint.as_ref()); - row_id.extend_from_slice(&role.as_role().get_encoded()?); + row_id.extend_from_slice(&peer_role.as_role().get_encoded()?); row_id.extend_from_slice(&ord.to_be_bytes()); let encrypted_auth_token = @@ -5319,13 +5408,13 @@ INSERT INTO taskprov_aggregator_auth_tokens ( peer_aggregator_id, created_at, updated_by, ord, type, token ) SELECT - (SELECT id FROM taskprov_peer_aggregators WHERE endpoint = $1 AND role = $2), + (SELECT id FROM taskprov_peer_aggregators WHERE endpoint = $1 AND peer_role = $2), $3, $4, * FROM UNNEST($5::BIGINT[], $6::AUTH_TOKEN_TYPE[], $7::BYTEA[])", ) .await?; let aggregator_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ /* endpoint */ &endpoint, - /* role */ role, + /* peer_role */ peer_role, /* created_at */ &self.clock.now().as_naive_date_time()?, /* updated_by */ &self.name, /* ords */ &aggregator_auth_token_ords, @@ -5346,13 +5435,13 @@ INSERT INTO taskprov_collector_auth_tokens ( peer_aggregator_id, created_at, updated_by, ord, type, token ) SELECT - (SELECT id FROM taskprov_peer_aggregators WHERE endpoint = $1 AND role = $2), + (SELECT id FROM taskprov_peer_aggregators WHERE endpoint = $1 AND peer_role = $2), $3, $4, * FROM UNNEST($5::BIGINT[], $6::AUTH_TOKEN_TYPE[], $7::BYTEA[])", ) .await?; 
let collector_auth_tokens_params: &[&(dyn ToSql + Sync)] = &[ /* endpoint */ &endpoint, - /* role */ role, + /* peer_role */ peer_role, /* created_at */ &self.clock.now().as_naive_date_time()?, /* updated_by */ &self.name, /* ords */ &collector_auth_token_ords, @@ -5369,19 +5458,19 @@ SELECT pub async fn delete_taskprov_peer_aggregator( &self, aggregator_url: &Url, - role: &Role, + peer_role: &Role, ) -> Result<(), Error> { let aggregator_url = aggregator_url.as_str(); - let role = AggregatorRole::from_role(*role)?; + let peer_role = AggregatorRole::from_role(*peer_role)?; // Deletion of other data implemented via ON DELETE CASCADE. let stmt = self .prepare_cached( "-- delete_taskprov_peer_aggregator() -DELETE FROM taskprov_peer_aggregators WHERE endpoint = $1 AND role = $2", +DELETE FROM taskprov_peer_aggregators WHERE endpoint = $1 AND peer_role = $2", ) .await?; - check_single_row_mutation(self.execute(&stmt, &[&aggregator_url, &role]).await?) + check_single_row_mutation(self.execute(&stmt, &[&aggregator_url, &peer_role]).await?) } /// Get the [`TaskUploadCounter`] for a task. This is aggregated across all shards. Returns diff --git a/aggregator_core/src/datastore/models.rs b/aggregator_core/src/datastore/models.rs index cc5364841..50a317f01 100644 --- a/aggregator_core/src/datastore/models.rs +++ b/aggregator_core/src/datastore/models.rs @@ -15,8 +15,8 @@ use janus_core::{ use janus_messages::{ batch_mode::{BatchMode, LeaderSelected, TimeInterval}, AggregationJobId, AggregationJobStep, BatchId, CollectionJobId, Duration, Extension, - HpkeCiphertext, HpkeConfigId, Interval, PrepareResp, Query, ReportError, ReportId, - ReportIdChecksum, ReportMetadata, Role, TaskId, Time, + HpkeCiphertext, HpkeConfigId, Interval, PrepareContinue, PrepareInit, PrepareResp, Query, + ReportError, ReportId, ReportIdChecksum, ReportMetadata, Role, TaskId, Time, }; use postgres_protocol::types::{ range_from_sql, range_to_sql, timestamp_from_sql, timestamp_to_sql, Range, RangeBound, @@ -538,8 +538,10 @@ where #[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, ToSql, FromSql)] #[postgres(name = "aggregation_job_state")] pub enum AggregationJobState { - #[postgres(name = "IN_PROGRESS")] - InProgress, + #[postgres(name = "ACTIVE")] + Active, + #[postgres(name = "AWAITING_REQUEST")] + AwaitingRequest, #[postgres(name = "FINISHED")] Finished, #[postgres(name = "ABANDONED")] @@ -952,18 +954,37 @@ pub enum ReportAggregationState, }, // // Helper-only states. // + /// The Helper has received an aggregation initialization request from the Leader, and is + /// processing it asynchronously. + HelperInitProcessing { + /// The initialization message received for this report aggregation. + prepare_init: PrepareInit, + /// Does this report aggregation require the taskbind extension? + require_taskbind_extension: bool, + }, /// The Helper is ready to receive an aggregation continuation request from the Leader. HelperContinue { /// Helper's current preparation state #[educe(Debug(ignore))] prepare_state: A::PrepareState, }, + /// The Helper has received an aggregation continuation request from the Leader, and is + /// processing it asynchronously. + HelperContinueProcessing { + /// Helper's current preparation state. + #[educe(Debug(ignore))] + prepare_state: A::PrepareState, + /// The message from the Leader for this report aggregation. + #[educe(Debug(ignore))] + prepare_continue: PrepareContinue, + }, // // Common states.
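The two `*Processing` variants added above are what let an asynchronous Helper detach message handling from the HTTP request: an initialization or continuation request is persisted in HelperInitProcessing or HelperContinueProcessing, and a later processing step advances the report aggregation to HelperContinue, Finished, or Failed. A much-simplified sketch of that lifecycle, with `Vec<u8>` standing in for the VDAF-specific message and state types (the real enum is generic over the VDAF):

// Simplified Helper-side report aggregation lifecycle.
#[derive(Debug, PartialEq)]
enum HelperState {
    InitProcessing { prepare_init: Vec<u8>, require_taskbind_extension: bool },
    Continue { prepare_state: Vec<u8> },
    ContinueProcessing { prepare_state: Vec<u8>, prepare_continue: Vec<u8> },
    Finished,
    Failed { report_error: i16 },
}

impl HelperState {
    // Invoked by a background aggregation step once the stored Leader message
    // has been processed; `done` means VDAF preparation finished.
    fn step(self, next_prepare_state: Vec<u8>, done: bool) -> Self {
        match self {
            HelperState::InitProcessing { .. } | HelperState::ContinueProcessing { .. } => {
                if done {
                    HelperState::Finished
                } else {
                    // Park until the Leader's next continuation request, which
                    // moves the aggregation into ContinueProcessing.
                    HelperState::Continue { prepare_state: next_prepare_state }
                }
            }
            // Continue only advances on a Leader request; terminal states stay put.
            other => other,
        }
    }
}

fn main() {
    let state = HelperState::InitProcessing {
        prepare_init: vec![0u8],
        require_taskbind_extension: false,
    };
    assert_eq!(
        state.step(vec![1u8], false),
        HelperState::Continue { prepare_state: vec![1u8] }
    );
}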
@@ -980,9 +1001,17 @@ impl> pub(super) fn state_code(&self) -> ReportAggregationStateCode { match self { ReportAggregationState::LeaderInit { .. } => ReportAggregationStateCode::Init, - ReportAggregationState::LeaderContinue { .. } - | ReportAggregationState::HelperContinue { .. } => ReportAggregationStateCode::Continue, + ReportAggregationState::LeaderContinue { .. } => ReportAggregationStateCode::Continue, ReportAggregationState::LeaderPoll { .. } => ReportAggregationStateCode::Poll, + + ReportAggregationState::HelperInitProcessing { .. } => { + ReportAggregationStateCode::InitProcessing + } + ReportAggregationState::HelperContinue { .. } => ReportAggregationStateCode::Continue, + ReportAggregationState::HelperContinueProcessing { .. } => { + ReportAggregationStateCode::ContinueProcessing + } + ReportAggregationState::Finished => ReportAggregationStateCode::Finished, ReportAggregationState::Failed { .. } => ReportAggregationStateCode::Failed, } @@ -1046,6 +1075,15 @@ impl> } } + ReportAggregationState::HelperInitProcessing { + prepare_init, + require_taskbind_extension, + } => EncodedReportAggregationStateValues { + prepare_init: Some(prepare_init.get_encoded()?), + require_taskbind_extension: Some(*require_taskbind_extension), + ..Default::default() + }, + ReportAggregationState::HelperContinue { prepare_state } => { EncodedReportAggregationStateValues { helper_prep_state: Some(prepare_state.get_encoded()?), @@ -1053,7 +1091,17 @@ impl> } } + ReportAggregationState::HelperContinueProcessing { + prepare_state, + prepare_continue, + } => EncodedReportAggregationStateValues { + helper_prep_state: Some(prepare_state.get_encoded()?), + prepare_continue: Some(prepare_continue.get_encoded()?), + ..Default::default() + }, + ReportAggregationState::Finished => EncodedReportAggregationStateValues::default(), + ReportAggregationState::Failed { report_error } => { EncodedReportAggregationStateValues { report_error: Some(*report_error as i16), @@ -1080,9 +1128,16 @@ pub(super) struct EncodedReportAggregationStateValues { pub(super) leader_prep_state: Option>, pub(super) leader_output_share: Option>, - // State for HelperContinue. + // State for HelperInitProcessing. + pub(super) prepare_init: Option>, + pub(super) require_taskbind_extension: Option, + + // State for HelperContinue & HelperContinueProcessing. pub(super) helper_prep_state: Option>, + // State for HelperContinueProcessing. + pub(super) prepare_continue: Option>, + // State for Failed. 
pub(super) report_error: Option, } @@ -1095,8 +1150,12 @@ pub(super) struct EncodedReportAggregationStateValues { pub(super) enum ReportAggregationStateCode { #[postgres(name = "INIT")] Init, + #[postgres(name = "INIT_PROCESSING")] + InitProcessing, #[postgres(name = "CONTINUE")] Continue, + #[postgres(name = "CONTINUE_PROCESSING")] + ContinueProcessing, #[postgres(name = "POLL")] Poll, #[postgres(name = "FINISHED")] @@ -1171,6 +1230,20 @@ where _ => false, }, + ( + Self::HelperInitProcessing { + prepare_init: lhs_prepare_init, + require_taskbind_extension: lhs_require_taskbind_extension, + }, + Self::HelperInitProcessing { + prepare_init: rhs_prepare_init, + require_taskbind_extension: rhs_require_taskbind_extension, + }, + ) => { + lhs_prepare_init == rhs_prepare_init + && lhs_require_taskbind_extension == rhs_require_taskbind_extension + } + ( Self::HelperContinue { prepare_state: lhs_state, @@ -1180,6 +1253,17 @@ where }, ) => lhs_state == rhs_state, + ( + Self::HelperContinueProcessing { + prepare_state: lhs_state, + prepare_continue: lhs_prepare_continue, + }, + Self::HelperContinueProcessing { + prepare_state: rhs_state, + prepare_continue: rhs_prepare_continue, + }, + ) => lhs_state == rhs_state && lhs_prepare_continue == rhs_prepare_continue, + ( Self::Failed { report_error: lhs_report_error, diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index 6879cd576..7159855e6 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -1,3 +1,5 @@ +#![allow(clippy::unit_arg)] // allow reference to dummy::Vdaf's public share, which has the unit type + use crate::{ batch_mode::CollectableBatchMode, datastore::{ @@ -16,7 +18,7 @@ use crate::{ }, Crypter, Datastore, Error, RowExt, Transaction, SUPPORTED_SCHEMA_VERSIONS, }, - task::{self, test_util::TaskBuilder, AggregatorTask}, + task::{self, test_util::TaskBuilder, AggregationMode, AggregatorTask}, taskprov::test_util::PeerAggregatorBuilder, test_util::noop_meter, }; @@ -28,14 +30,14 @@ use janus_core::{ hpke::{self, HpkeApplicationInfo, Label}, test_util::{install_test_trace_subscriber, run_vdaf}, time::{Clock, DurationExt, IntervalExt, MockClock, TimeExt}, - vdaf::{vdaf_dp_strategies, VdafInstance, VERIFY_KEY_LENGTH}, + vdaf::{vdaf_dp_strategies, VdafInstance, VERIFY_KEY_LENGTH_PRIO3}, }; use janus_messages::{ batch_mode::{BatchMode, LeaderSelected, TimeInterval}, AggregateShareAad, AggregationJobId, AggregationJobStep, BatchId, BatchSelector, CollectionJobId, Duration, Extension, ExtensionType, HpkeCiphertext, HpkeConfigId, Interval, - PrepareResp, PrepareStepResult, Query, ReportError, ReportId, ReportIdChecksum, ReportMetadata, - ReportShare, Role, TaskId, Time, + PrepareInit, PrepareResp, PrepareStepResult, Query, ReportError, ReportId, ReportIdChecksum, + ReportMetadata, ReportShare, Role, TaskId, Time, }; use prio::{ codec::{Decode, Encode}, @@ -205,13 +207,17 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { Role::Leader, ), ] { - let task = TaskBuilder::new(task::BatchMode::TimeInterval, vdaf) - .with_task_start(Some(Time::from_seconds_since_epoch(1000))) - .with_task_end(Some(Time::from_seconds_since_epoch(4000))) - .with_report_expiry_age(Some(Duration::from_seconds(3600))) - .build() - .view_for_role(role) - .unwrap(); + let task = TaskBuilder::new( + task::BatchMode::TimeInterval, + AggregationMode::Synchronous, + vdaf, + ) + .with_task_start(Some(Time::from_seconds_since_epoch(1000))) + 
.with_task_end(Some(Time::from_seconds_since_epoch(4000))) + .with_report_expiry_age(Some(Duration::from_seconds(3600))) + .build() + .view_for_role(role) + .unwrap(); want_tasks.insert(*task.id(), task.clone()); let err = ds @@ -305,11 +311,15 @@ async fn update_task_end(ephemeral_datastore: EphemeralDatastore) { install_test_trace_subscriber(); let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new(task::BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_task_end(Some(Time::from_seconds_since_epoch(1000))) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + task::BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_task_end(Some(Time::from_seconds_since_epoch(1000))) + .build() + .leader_view() + .unwrap(); ds.put_aggregator_task(&task).await.unwrap(); ds.run_unnamed_tx(|tx| { @@ -354,10 +364,14 @@ async fn put_task_invalid_aggregator_auth_tokens(ephemeral_datastore: EphemeralD install_test_trace_subscriber(); let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new(task::BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + task::BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(); ds.put_aggregator_task(&task).await.unwrap(); @@ -394,10 +408,14 @@ async fn put_task_invalid_collector_auth_tokens(ephemeral_datastore: EphemeralDa install_test_trace_subscriber(); let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new(task::BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(); + let task = TaskBuilder::new( + task::BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(); ds.put_aggregator_task(&task).await.unwrap(); @@ -440,6 +458,7 @@ async fn get_task_ids(ephemeral_datastore: EphemeralDatastore) { let tasks: Vec<_> = iter::repeat_with(|| { TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -484,6 +503,7 @@ async fn roundtrip_report(ephemeral_datastore: EphemeralDatastore) { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(report_expiry_age)) @@ -674,6 +694,7 @@ async fn get_unaggregated_client_reports_for_task(ephemeral_datastore: Ephemeral .unwrap(); let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -682,6 +703,7 @@ async fn get_unaggregated_client_reports_for_task(ephemeral_datastore: Ephemeral .unwrap(); let unrelated_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -895,6 +917,7 @@ async fn get_unaggregated_client_report_ids_with_agg_param_for_task( let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -902,6 +925,7 @@ async fn get_unaggregated_client_report_ids_with_agg_param_for_task( .unwrap(); let unrelated_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -1059,7 
+1083,7 @@ async fn get_unaggregated_client_report_ids_with_agg_param_for_task( (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await?; @@ -1197,6 +1221,7 @@ async fn count_client_reports_for_interval(ephemeral_datastore: EphemeralDatasto let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -1205,6 +1230,7 @@ async fn count_client_reports_for_interval(ephemeral_datastore: EphemeralDatasto .unwrap(); let unrelated_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -1212,6 +1238,7 @@ async fn count_client_reports_for_interval(ephemeral_datastore: EphemeralDatasto .unwrap(); let no_reports_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -1333,6 +1360,7 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -1343,6 +1371,7 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -1388,7 +1417,7 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto Duration::from_seconds(1), ) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); let expired_report_aggregation = expired_report @@ -1401,7 +1430,7 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto batch_id, Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(2)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); let aggregation_job_0_report_aggregation_0 = @@ -1416,7 +1445,7 @@ async fn count_client_reports_for_batch_id(ephemeral_datastore: EphemeralDatasto batch_id, Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); let aggregation_job_1_report_aggregation_0 = @@ -1479,32 +1508,24 @@ async fn roundtrip_scrubbed_report(ephemeral_datastore: EphemeralDatastore) { install_test_trace_subscriber(); let ds = ephemeral_datastore.datastore(MockClock::default()).await; - let task = TaskBuilder::new(task::BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .leader_view() - .unwrap(); - let report_share = ReportShare::new( - ReportMetadata::new( - ReportId::from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]), - Time::from_seconds_since_epoch(12345), - Vec::from([Extension::new( - ExtensionType::Tbd, - "public_extension_tbd".into(), - )]), - ), - Vec::from("public_share"), - HpkeCiphertext::new( - HpkeConfigId::from(12), - Vec::from("encapsulated_context_0"), - Vec::from("payload_0"), - ), - ); + let task = TaskBuilder::new( + task::BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .build() + .leader_view() + .unwrap(); + + let report_id = ReportId::from([1, 
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]); + let client_timestamp = Time::from_seconds_since_epoch(12345); ds.run_tx("test-put-report-share", |tx| { - let (task, report_share) = (task.clone(), report_share.clone()); + let task = task.clone(); + Box::pin(async move { tx.put_aggregator_task(&task).await.unwrap(); - tx.put_scrubbed_report(task.id(), &report_share) + tx.put_scrubbed_report(task.id(), &report_id, &client_timestamp) .await .unwrap(); @@ -1527,7 +1548,6 @@ async fn roundtrip_scrubbed_report(ephemeral_datastore: EphemeralDatastore) { ) = ds .run_unnamed_tx(|tx| { let task_id = *task.id(); - let report_id = *report_share.metadata().id(); Box::pin(async move { // Verify that attempting to read the report share as a report receives the expected @@ -1609,6 +1629,7 @@ async fn roundtrip_aggregation_job(ephemeral_datastore: EphemeralDatastore) { task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -1622,7 +1643,7 @@ async fn roundtrip_aggregation_job(ephemeral_datastore: EphemeralDatastore) { dummy::AggregationParam(23), batch_id, Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); let helper_aggregation_job = AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( @@ -1631,7 +1652,7 @@ async fn roundtrip_aggregation_job(ephemeral_datastore: EphemeralDatastore) { dummy::AggregationParam(23), random(), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); @@ -1772,7 +1793,7 @@ async fn roundtrip_aggregation_job(ephemeral_datastore: EphemeralDatastore) { Duration::from_seconds(6789), ) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); ds.run_unnamed_tx(|tx| { @@ -1833,28 +1854,47 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await); const AGGREGATION_JOB_COUNT: usize = 10; - let task = TaskBuilder::new(task::BatchMode::TimeInterval, VdafInstance::Prio3Count) - .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) - .build() - .leader_view() - .unwrap(); - let mut aggregation_job_ids: Vec<_> = thread_rng() - .sample_iter(Standard) + let leader_task = TaskBuilder::new( + task::BatchMode::TimeInterval, + AggregationMode::Synchronous, + VdafInstance::Prio3Count, + ) + .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) + .build() + .leader_view() + .unwrap(); + let helper_task = TaskBuilder::new( + task::BatchMode::TimeInterval, + AggregationMode::Asynchronous, + VdafInstance::Prio3Count, + ) + .build() + .helper_view() + .unwrap(); + + let mut task_and_aggregation_job_ids: Vec<_> = [*leader_task.id(), *helper_task.id()] + .into_iter() + .cycle() + .zip(thread_rng().sample_iter(Standard)) .take(AGGREGATION_JOB_COUNT) .collect(); - aggregation_job_ids.sort(); + task_and_aggregation_job_ids.sort(); ds.run_unnamed_tx(|tx| { - let (task, aggregation_job_ids) = (task.clone(), aggregation_job_ids.clone()); + let leader_task = leader_task.clone(); + let helper_task = helper_task.clone(); + let task_and_aggregation_job_ids = task_and_aggregation_job_ids.clone(); + Box::pin(async move { // Write a few aggregation jobs we expect to be able to retrieve with // 
acquire_incomplete_aggregation_jobs(). - tx.put_aggregator_task(&task).await.unwrap(); - try_join_all(aggregation_job_ids.into_iter().map(|aggregation_job_id| { - let task_id = *task.id(); - async move { + tx.put_aggregator_task(&leader_task).await.unwrap(); + tx.put_aggregator_task(&helper_task).await.unwrap(); + + try_join_all(task_and_aggregation_job_ids.into_iter().map( + |(task_id, aggregation_job_id)| async move { tx.put_aggregation_job(&AggregationJob::< - VERIFY_KEY_LENGTH, + VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count, >::new( @@ -1871,68 +1911,70 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore Duration::from_seconds(1), ) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await - } - })) + }, + )) .await .unwrap(); // Write an aggregation job that is finished. We don't want to retrieve this one. - tx.put_aggregation_job( - &AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new( - *task.id(), - random(), - (), - (), - Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::Finished, - AggregationJobStep::from(1), - ), - ) + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *leader_task.id(), + random(), + (), + (), + Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)).unwrap(), + AggregationJobState::Finished, + AggregationJobStep::from(1), )) .await .unwrap(); // Write an expired aggregation job. We don't want to retrieve this one, either. - tx.put_aggregation_job( - &AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new( - *task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobStep::from(0), - ), - ) + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *leader_task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::Active, + AggregationJobStep::from(0), )) .await .unwrap();
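The fixture writes above and below pin down the acquirer's eligibility rule: only Active, unexpired aggregation jobs may be leased. A compact model of that rule (illustrative only; Janus enforces it inside the datastore's acquire query, and these types are simplified stand-ins):

// A job is a candidate for acquire_incomplete_aggregation_jobs() only while it is Active:
// Finished jobs are done, expired jobs fall outside the report-expiry horizon, and
// AwaitingRequest jobs (written next) are waiting on the Leader rather than on a driver.
#[derive(Clone, Copy, PartialEq)]
enum JobState {
    Active,
    AwaitingRequest,
    Finished,
}

fn is_acquirable(state: JobState, max_client_timestamp: u64, now: u64, report_expiry_age: Option<u64>) -> bool {
    let expired = report_expiry_age.is_some_and(|age| max_client_timestamp + age < now);
    state == JobState::Active && !expired
}

fn main() {
    assert!(is_acquirable(JobState::Active, 100, 120, Some(60))); // live: leased
    assert!(!is_acquirable(JobState::Active, 10, 120, Some(60))); // expired: skipped
    assert!(!is_acquirable(JobState::Finished, 100, 120, Some(60))); // done: skipped
    assert!(!is_acquirable(JobState::AwaitingRequest, 100, 120, None)); // Leader's turn: skipped
}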
- // Write an aggregation job for a task that we are taking on the helper role for. - // We don't want to retrieve this one, either. - let helper_task = - TaskBuilder::new(task::BatchMode::TimeInterval, VdafInstance::Prio3Count) - .build() - .helper_view() - .unwrap(); - tx.put_aggregator_task(&helper_task).await.unwrap(); - tx.put_aggregation_job( - &AggregationJob::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>::new( - *helper_task.id(), - random(), - (), - (), - Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) - .unwrap(), - AggregationJobState::InProgress, - AggregationJobStep::from(0), - ), - ) + // Write an aggregation job that is awaiting a request from the Leader. We don't want to + // retrieve this one, either. + tx.put_aggregation_job(&AggregationJob::< + VERIFY_KEY_LENGTH_PRIO3, + TimeInterval, + Prio3Count, + >::new( + *helper_task.id(), + random(), + (), + (), + Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) + .unwrap(), + AggregationJobState::AwaitingRequest, + AggregationJobStep::from(0), + )) .await + .unwrap(); + + Ok(()) }) }) .await @@ -1963,13 +2005,13 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore let want_expiry_time = clock.now().as_naive_date_time().unwrap() + chrono::Duration::from_std(LEASE_DURATION).unwrap(); - let want_aggregation_jobs: Vec<_> = aggregation_job_ids + let want_aggregation_jobs: Vec<_> = task_and_aggregation_job_ids .iter() - .map(|&agg_job_id| { + .map(|(task_id, aggregation_job_id)| { ( AcquiredAggregationJob::new( - *task.id(), - agg_job_id, + *task_id, + *aggregation_job_id, task::BatchMode::TimeInterval, VdafInstance::Prio3Count, ), @@ -2113,13 +2155,13 @@ async fn aggregation_job_acquire_release(ephemeral_datastore: EphemeralDatastore )); let want_expiry_time = clock.now().as_naive_date_time().unwrap() + chrono::Duration::from_std(LEASE_DURATION).unwrap(); - let want_aggregation_jobs: Vec<_> = aggregation_job_ids + let want_aggregation_jobs: Vec<_> = task_and_aggregation_job_ids .iter() - .map(|&job_id| { + .map(|(task_id, aggregation_job_id)| { ( AcquiredAggregationJob::new( - *task.id(), - job_id, + *task_id, + *aggregation_job_id, task::BatchMode::TimeInterval, VdafInstance::Prio3Count, ), @@ -2202,7 +2244,7 @@ async fn aggregation_job_not_found(ephemeral_datastore: EphemeralDatastore) { let rslt = ds .run_unnamed_tx(|tx| { Box::pin(async move { - tx.get_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>( + tx.get_aggregation_job::<VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count>( &random(), &random(), ) @@ -2216,7 +2258,7 @@ let rslt = ds .run_unnamed_tx(|tx| { Box::pin(async move { - tx.update_aggregation_job::<VERIFY_KEY_LENGTH, TimeInterval, Prio3Count>( + tx.update_aggregation_job::<VERIFY_KEY_LENGTH_PRIO3, TimeInterval, Prio3Count>( &AggregationJob::new( random(), random(), (), (), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ), ) @@ -2248,6 +2290,7 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -2259,7 +2302,7 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) dummy::AggregationParam(23), random(), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); let second_aggregation_job = AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( *task.id(), dummy::AggregationParam(42), random(), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); let aggregation_job_with_request_hash = AggregationJob::<0, LeaderSelected, dummy::Vdaf>::new( *task.id(), dummy::AggregationParam(42), random(),
Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)).unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ) .with_last_request_hash([3; 32]); @@ -2303,6 +2346,7 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -2316,7 +2360,7 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) random(), Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -2343,6 +2387,8 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore) #[rstest_reuse::apply(schema_versions_template)] #[tokio::test] async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { + use janus_messages::PrepareContinue; + install_test_trace_subscriber(); let task_id = random(); @@ -2400,12 +2446,48 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { leader_state: vdaf_transcript.leader_prepare_transitions[1].state.clone(), }, ), + ( + Role::Helper, + ReportAggregationState::HelperInitProcessing { + prepare_init: PrepareInit::new( + ReportShare::new( + ReportMetadata::new( + report_id, + Time::from_seconds_since_epoch(25000), + Vec::new(), + ), + vdaf_transcript.public_share.get_encoded().unwrap(), + HpkeCiphertext::new( + HpkeConfigId::from(13), + Vec::from("encapsulated_context"), + Vec::from("payload"), + ), + ), + vdaf_transcript.leader_prepare_transitions[0] + .message + .clone(), + ), + require_taskbind_extension: true, + }, + ), ( Role::Helper, ReportAggregationState::HelperContinue { prepare_state: *vdaf_transcript.helper_prepare_transitions[0].prepare_state(), }, ), + ( + Role::Helper, + ReportAggregationState::HelperContinueProcessing { + prepare_state: *vdaf_transcript.helper_prepare_transitions[0].prepare_state(), + prepare_continue: PrepareContinue::new( + report_id, + vdaf_transcript.leader_prepare_transitions[1] + .message + .clone(), + ), + }, + ), (Role::Leader, ReportAggregationState::Finished), (Role::Helper, ReportAggregationState::Finished), ( @@ -2429,6 +2511,7 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 2 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -2451,29 +2534,14 @@ async fn roundtrip_report_aggregation(ephemeral_datastore: EphemeralDatastore) { (), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await .unwrap(); - tx.put_scrubbed_report( - task.id(), - &ReportShare::new( - ReportMetadata::new( - report_id, - OLDEST_ALLOWED_REPORT_TIMESTAMP, - Vec::new(), - ), - Vec::from("public_share"), - HpkeCiphertext::new( - HpkeConfigId::from(12), - Vec::from("encapsulated_context_0"), - Vec::from("payload_0"), - ), - ), - ) - .await - .unwrap(); + tx.put_scrubbed_report(task.id(), &report_id, &OLDEST_ALLOWED_REPORT_TIMESTAMP) + .await + .unwrap(); let report_aggregation = ReportAggregation::new( *task.id(), @@ -2710,6 +2778,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme let task = 
TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 2 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -2732,7 +2801,7 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme (), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -2764,24 +2833,9 @@ async fn get_report_aggregations_for_aggregation_job(ephemeral_datastore: Epheme .enumerate() { let report_id = ReportId::from((ord as u128).to_be_bytes()); - tx.put_scrubbed_report( - task.id(), - &ReportShare::new( - ReportMetadata::new( - report_id, - OLDEST_ALLOWED_REPORT_TIMESTAMP, - Vec::new(), - ), - Vec::from("public_share"), - HpkeCiphertext::new( - HpkeConfigId::from(12), - Vec::from("encapsulated_context_0"), - Vec::from("payload_0"), - ), - ), - ) - .await - .unwrap(); + tx.put_scrubbed_report(task.id(), &report_id, &OLDEST_ALLOWED_REPORT_TIMESTAMP) + .await + .unwrap(); let report_aggregation = ReportAggregation::new( *task.id(), @@ -2872,6 +2926,7 @@ async fn create_report_aggregation_from_client_reports_table( let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 2 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -2894,7 +2949,7 @@ async fn create_report_aggregation_from_client_reports_table( (), Interval::new(OLDEST_ALLOWED_REPORT_TIMESTAMP, Duration::from_seconds(1)) .unwrap(), - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), )) .await @@ -3034,6 +3089,7 @@ async fn get_collection_job(ephemeral_datastore: EphemeralDatastore) { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -3257,6 +3313,7 @@ async fn update_collection_jobs(ephemeral_datastore: EphemeralDatastore) { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -3471,11 +3528,15 @@ async fn setup_collection_job_acquire_test_case( Box::pin(async move { for task_id in &test_case.task_ids { tx.put_aggregator_task( - &TaskBuilder::new(test_case.batch_mode, VdafInstance::Fake { rounds: 1 }) - .with_id(*task_id) - .build() - .leader_view() - .unwrap(), + &TaskBuilder::new( + test_case.batch_mode, + AggregationMode::Synchronous, + VdafInstance::Fake { rounds: 1 }, + ) + .with_id(*task_id) + .build() + .leader_view() + .unwrap(), ) .await .unwrap(); @@ -4316,6 +4377,7 @@ async fn roundtrip_batch_aggregation_time_interval(ephemeral_datastore: Ephemera let time_precision = Duration::from_seconds(100); let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_time_precision(time_precision) @@ -4325,6 +4387,7 @@ async fn roundtrip_batch_aggregation_time_interval(ephemeral_datastore: Ephemera .unwrap(); let other_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -4665,6 +4728,7 @@ async fn roundtrip_batch_aggregation_leader_selected(ephemeral_datastore: Epheme task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -4682,6 +4746,7 
@@ async fn roundtrip_batch_aggregation_leader_selected(ephemeral_datastore: Epheme task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -4875,6 +4940,7 @@ async fn roundtrip_aggregate_share_job_time_interval(ephemeral_datastore: Epheme Box::pin(async move { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -5047,6 +5113,7 @@ async fn roundtrip_aggregate_share_job_leader_selected(ephemeral_datastore: Ephe task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -5202,6 +5269,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -5255,6 +5323,7 @@ async fn roundtrip_outstanding_batch(ephemeral_datastore: EphemeralDatastore) { task::BatchMode::LeaderSelected { batch_time_window_size: Some(batch_time_window_size), }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -5571,6 +5640,7 @@ async fn delete_expired_client_reports(ephemeral_datastore: EphemeralDatastore) Box::pin(async move { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(report_expiry_age)) @@ -5579,6 +5649,7 @@ async fn delete_expired_client_reports(ephemeral_datastore: EphemeralDatastore) .unwrap(); let other_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -5672,6 +5743,7 @@ async fn delete_expired_client_reports_noop(ephemeral_datastore: EphemeralDatast Box::pin(async move { let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(None) @@ -5783,7 +5855,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData *aggregation_param, B::partial_batch_identifier(&batch_identifier).clone(), client_timestamp_interval, - AggregationJobState::InProgress, + AggregationJobState::Active, AggregationJobStep::from(0), ); tx.put_aggregation_job(&aggregation_job).await.unwrap(); @@ -5818,6 +5890,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData Box::pin(async move { let leader_time_interval_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -5826,6 +5899,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData .unwrap(); let helper_time_interval_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -5836,6 +5910,7 @@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -5846,6 +5921,7 
@@ async fn delete_expired_aggregation_artifacts(ephemeral_datastore: EphemeralData task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -6373,6 +6449,7 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas Box::pin(async move { let leader_time_interval_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -6381,6 +6458,7 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas .unwrap(); let helper_time_interval_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -6391,6 +6469,7 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -6401,6 +6480,7 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -6411,6 +6491,7 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas task::BatchMode::LeaderSelected { batch_time_window_size: Some(Duration::from_hours(24).unwrap()), }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -6419,6 +6500,7 @@ async fn delete_expired_collection_artifacts(ephemeral_datastore: EphemeralDatas .unwrap(); let other_task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(REPORT_EXPIRY_AGE)) @@ -7423,10 +7505,11 @@ async fn roundtrip_taskprov_peer_aggregator(ephemeral_datastore: EphemeralDatast let datastore = ephemeral_datastore.datastore(MockClock::default()).await; // Basic aggregator. 
- let example_leader_peer_aggregator = - PeerAggregatorBuilder::new().with_role(Role::Leader).build(); + let example_leader_peer_aggregator = PeerAggregatorBuilder::new() + .with_peer_role(Role::Leader) + .build(); let example_helper_peer_aggregator = PeerAggregatorBuilder::new() - .with_role(Role::Helper) + .with_peer_role(Role::Helper) .with_aggregator_auth_tokens(Vec::from([random(), random()])) .with_collector_auth_tokens(Vec::new()) .build(); @@ -7479,16 +7562,19 @@ async fn roundtrip_taskprov_peer_aggregator(ephemeral_datastore: EphemeralDatast let another_example_leader_peer_aggregator = another_example_leader_peer_aggregator.clone(); Box::pin(async move { - for peer in [ + for peer_aggregator in [ example_leader_peer_aggregator.clone(), example_helper_peer_aggregator.clone(), another_example_leader_peer_aggregator.clone(), ] { assert_eq!( - tx.get_taskprov_peer_aggregator(peer.endpoint(), peer.role()) - .await - .unwrap(), - Some(peer.clone()), + tx.get_taskprov_peer_aggregator( + peer_aggregator.endpoint(), + peer_aggregator.peer_role() + ) + .await + .unwrap(), + Some(peer_aggregator.clone()), ); } @@ -7506,7 +7592,7 @@ async fn roundtrip_taskprov_peer_aggregator(ephemeral_datastore: EphemeralDatast example_helper_peer_aggregator.clone(), another_example_leader_peer_aggregator.clone(), ] { - tx.delete_taskprov_peer_aggregator(peer.endpoint(), peer.role()) + tx.delete_taskprov_peer_aggregator(peer.endpoint(), peer.peer_role()) .await .unwrap(); } @@ -7552,6 +7638,7 @@ async fn accept_write_expired_report(ephemeral_datastore: EphemeralDatastore) { let report_expiry_age = Duration::from_seconds(60); let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .with_report_expiry_age(Some(report_expiry_age)) @@ -7635,6 +7722,7 @@ async fn roundtrip_task_upload_counter(ephemeral_datastore: EphemeralDatastore) let task = TaskBuilder::new( task::BatchMode::TimeInterval, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() @@ -7720,6 +7808,7 @@ async fn roundtrip_task_aggregation_counter(ephemeral_datastore: EphemeralDatast task::BatchMode::LeaderSelected { batch_time_window_size: None, }, + AggregationMode::Synchronous, VdafInstance::Fake { rounds: 1 }, ) .build() diff --git a/aggregator_core/src/task.rs b/aggregator_core/src/task.rs index 77a59d0d4..36f773acc 100644 --- a/aggregator_core/src/task.rs +++ b/aggregator_core/src/task.rs @@ -1,6 +1,7 @@ //! Shared parameters for a DAP task. use crate::SecretBytes; +use anyhow::anyhow; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use educe::Educe; use janus_core::{ @@ -11,9 +12,10 @@ use janus_core::{ use janus_messages::{ batch_mode, AggregationJobId, AggregationJobStep, Duration, HpkeConfig, Role, TaskId, Time, }; +use postgres_types::{FromSql, ToSql}; use rand::{distributions::Standard, random, thread_rng, Rng}; use serde::{de::Error as _, Deserialize, Deserializer, Serialize, Serializer}; -use std::array::TryFromSliceError; +use std::{array::TryFromSliceError, str::FromStr}; use url::Url; /// Errors that methods and functions in this module may return. @@ -267,7 +269,7 @@ impl AggregatorTask { { if matches!( aggregator_parameters, - AggregatorTaskParameters::TaskprovHelper + AggregatorTaskParameters::TaskprovHelper { .. 
}, ) { return Err(Error::InvalidParameter( "batch_time_window_size is not supported for taskprov", )); } @@ -297,6 +299,11 @@ impl AggregatorTask { self.aggregator_parameters.role() } + /// Retrieves the aggregation mode of the task for the Helper, or None for the Leader. + pub fn aggregation_mode(&self) -> Option<&AggregationMode> { + self.aggregator_parameters().aggregation_mode() + } + /// Retrieves the peer aggregator endpoint associated with this task. pub fn peer_aggregator_endpoint(&self) -> &Url { &self.peer_aggregator_endpoint @@ -493,6 +500,7 @@ pub enum AggregatorTaskParameters { /// HPKE configuration for the collector. collector_hpke_config: HpkeConfig, }, + /// Task parameters held exclusively by the DAP helper. Helper { /// Authentication token hash used to validate requests from the leader during the @@ -500,10 +508,39 @@ aggregator_auth_token_hash: AuthenticationTokenHash, /// HPKE configuration for the collector. collector_hpke_config: HpkeConfig, + /// The aggregation mode to use for this task. + aggregation_mode: AggregationMode, }, - /// Task parameters held exclusively by a DAP helper provisioned via taskprov. Currently there - /// are no such parameters. - TaskprovHelper, + + /// Task parameters held exclusively by a DAP helper provisioned via taskprov. + TaskprovHelper { aggregation_mode: AggregationMode }, +} + +/// Indicates an aggregation mode: synchronous or asynchronous. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ToSql, FromSql)] +#[postgres(name = "aggregation_mode")] +pub enum AggregationMode { + /// Aggregation is completed synchronously, i.e. every successful aggregation initialization or + /// continuation request will be responded to with a response in the "finished" status. + #[postgres(name = "SYNCHRONOUS")] + Synchronous, + + /// Aggregation is completed asynchronously, i.e. every successful aggregation initialization or + /// continuation request will be responded to with a response in the "processing" status. + #[postgres(name = "ASYNCHRONOUS")] + Asynchronous, +} + +impl FromStr for AggregationMode { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "synchronous" => Ok(Self::Synchronous), + "asynchronous" => Ok(Self::Asynchronous), + _ => Err(anyhow!("couldn't parse AggregationMode value: {s}")), + } + } } impl AggregatorTaskParameters { @@ -511,7 +548,20 @@ pub fn role(&self) -> &Role { match self { Self::Leader { .. } => &Role::Leader, - Self::Helper { .. } | Self::TaskprovHelper => &Role::Helper, + Self::Helper { .. } | Self::TaskprovHelper { .. } => &Role::Helper, } } + + /// Returns the [`AggregationMode`] for this task for the helper, or `None` for the leader. + fn aggregation_mode(&self) -> Option<&AggregationMode> { + match self { + Self::Leader { .. } => None, + Self::Helper { + aggregation_mode, .. + } => Some(aggregation_mode), + Self::TaskprovHelper { + aggregation_mode, .. + } => Some(aggregation_mode), } }
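A short sketch of the contract the two modes encode, reusing the FromStr impl above and assuming AggregationMode and the anyhow crate are in scope (the status strings follow the variant doc comments; response_status itself is illustrative, not a Janus API):

use std::str::FromStr;

// A helper answering an aggregation initialization or continuation request reports
// "finished" for a Synchronous task (the job was stepped inline) and "processing" for an
// Asynchronous task (the work is queued and the Leader polls for the result).
fn response_status(mode: AggregationMode) -> &'static str {
    match mode {
        AggregationMode::Synchronous => "finished",
        AggregationMode::Asynchronous => "processing",
    }
}

fn demo() -> anyhow::Result<()> {
    assert_eq!(response_status(AggregationMode::from_str("synchronous")?), "finished");
    assert_eq!(response_status(AggregationMode::from_str("asynchronous")?), "processing");
    Ok(())
}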
@@ -576,6 +626,7 @@ pub struct SerializedAggregatorTask { task_id: Option<TaskId>, peer_aggregator_endpoint: Url, batch_mode: BatchMode, + aggregation_mode: Option<AggregationMode>, vdaf: VdafInstance, role: Role, vdaf_verify_key: Option<String>, // in unpadded base64url @@ -633,6 +684,7 @@ impl Serialize for AggregatorTask { task_id: Some(*self.id()), peer_aggregator_endpoint: self.peer_aggregator_endpoint().clone(), batch_mode: *self.batch_mode(), + aggregation_mode: self.aggregator_parameters.aggregation_mode().copied(), vdaf: self.vdaf().clone(), role: *self.role(), vdaf_verify_key: Some(URL_SAFE_NO_PAD.encode(self.opaque_vdaf_verify_key())), @@ -690,6 +742,9 @@ impl TryFrom<SerializedAggregatorTask> for AggregatorTask { Error::InvalidParameter("missing aggregator auth token hash"), )?, collector_hpke_config: serialized_task.collector_hpke_config, + aggregation_mode: serialized_task + .aggregation_mode + .ok_or(Error::InvalidParameter("missing aggregation mode"))?, }, _ => return Err(Error::InvalidParameter("unexpected role")), }; @@ -724,8 +779,8 @@ impl<'de> Deserialize<'de> for AggregatorTask { pub mod test_util { use crate::{ task::{ - AggregatorTask, AggregatorTaskParameters, BatchMode, CommonTaskParameters, Error, - VerifyKey, + AggregationMode, AggregatorTask, AggregatorTaskParameters, BatchMode, + CommonTaskParameters, Error, VerifyKey, }, SecretBytes, }; @@ -738,7 +793,8 @@ pub mod test_util { vdaf::VdafInstance, }; use janus_messages::{ - AggregationJobId, CollectionJobId, Duration, HpkeConfigId, Role, TaskId, Time, + AggregationJobId, AggregationJobStep, CollectionJobId, Duration, HpkeConfigId, Role, + TaskId, Time, }; use rand::{distributions::Standard, random, thread_rng, Rng}; use std::collections::HashMap; @@ -756,6 +812,8 @@ pub mod test_util { /// URL relative to which the leader aggregator's API endpoints are found. #[educe(Debug(method(std::fmt::Display::fmt)))] helper_aggregator_endpoint: Url, + /// The mode used for aggregation by the Helper (synchronous vs asynchronous). + helper_aggregation_mode: AggregationMode, /// HPKE configuration and private key used by the collector to decrypt aggregate shares. collector_hpke_keypair: HpkeKeypair, /// Token used to authenticate messages exchanged between the aggregators in the aggregation @@ -778,6 +836,7 @@ leader_aggregator_endpoint: Url, helper_aggregator_endpoint: Url, batch_mode: BatchMode, + helper_aggregation_mode: AggregationMode, vdaf: VdafInstance, vdaf_verify_key: SecretBytes, task_start: Option