From 0f5cd85a5b32df1baabffad48e71cf9e96164579 Mon Sep 17 00:00:00 2001 From: Ahmed Farghal Date: Fri, 22 Nov 2024 13:18:11 +0000 Subject: [PATCH 1/4] [TaskCenter] Stage 2 of refactoring --- Cargo.lock | 2 +- Cargo.toml | 1 + benchmarks/src/lib.rs | 2 +- crates/bifrost/benches/append_throughput.rs | 2 +- crates/bifrost/benches/util.rs | 4 +- crates/core/Cargo.toml | 2 +- crates/core/src/metadata/manager.rs | 4 +- crates/core/src/metadata/mod.rs | 24 ++ crates/core/src/task_center/extensions.rs | 171 ++++++++++++ crates/core/src/task_center/mod.rs | 262 +++++++++++------- crates/core/src/task_center/task.rs | 4 +- crates/core/src/task_center/task_kind.rs | 4 +- crates/ingress-http/Cargo.toml | 2 +- crates/metadata-store/src/local/tests.rs | 4 +- .../src/network_server/grpc_svc_handler.rs | 15 +- .../benches/basic_benchmark.rs | 6 +- crates/partition-store/src/tests/mod.rs | 4 +- crates/service-client/Cargo.toml | 2 +- crates/storage-query-datafusion/src/mocks.rs | 5 +- crates/worker/src/partition/leadership.rs | 4 +- .../src/partition_processor_manager/mod.rs | 9 +- .../persisted_lsn_watchdog.rs | 6 +- server/src/main.rs | 2 +- server/tests/common/replicated_loglet.rs | 24 +- tools/bifrost-benchpress/Cargo.toml | 2 +- tools/bifrost-benchpress/src/main.rs | 4 +- tools/bifrost-benchpress/src/write_to_read.rs | 4 +- 27 files changed, 402 insertions(+), 173 deletions(-) create mode 100644 crates/core/src/task_center/extensions.rs diff --git a/Cargo.lock b/Cargo.lock index c87af8888..f99eba173 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6103,7 +6103,7 @@ dependencies = [ "once_cell", "opentelemetry", "parking_lot", - "pin-project", + "pin-project-lite", "prost", "prost-types", "rand", diff --git a/Cargo.toml b/Cargo.toml index dff64d45e..7154a9645 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,6 +141,7 @@ opentelemetry_sdk = { version = "0.24.0" } parking_lot = { version = "0.12" } paste = "1.0" pin-project = "1.0" +pin-project-lite = { version = "0.2" } prost = { version = "0.13.1" } prost-build = { version = "0.13.1" } priority-queue = "2.0.3" diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index 40269a6eb..d5521fd37 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -99,7 +99,7 @@ pub fn spawn_restate(config: Configuration) -> TaskCenter { restate_types::config::set_current_config(config.clone()); let updateable_config = Configuration::updateable(); - tc.block_on("benchmark", None, async { + tc.block_on(async { RocksDbManager::init(Constant::new(config.common)); tc.spawn(TaskKind::SystemBoot, "restate", None, async move { diff --git a/crates/bifrost/benches/append_throughput.rs b/crates/bifrost/benches/append_throughput.rs index 7e32818c9..6a2d5ed43 100644 --- a/crates/bifrost/benches/append_throughput.rs +++ b/crates/bifrost/benches/append_throughput.rs @@ -104,7 +104,7 @@ fn write_throughput_local_loglet(c: &mut Criterion) { provider, )); - let bifrost = tc.block_on("bifrost-init", None, async { + let bifrost = tc.block_on(async { let metadata = metadata(); let bifrost_svc = BifrostService::new(restate_core::task_center(), metadata) .enable_local_loglet(&Live::from_value(config)); diff --git a/crates/bifrost/benches/util.rs b/crates/bifrost/benches/util.rs index 871846ec9..d2af12f47 100644 --- a/crates/bifrost/benches/util.rs +++ b/crates/bifrost/benches/util.rs @@ -44,9 +44,7 @@ pub async fn spawn_environment( let metadata_writer = metadata_manager.writer(); tc.try_set_global_metadata(metadata.clone()); - tc.run_in_scope_sync("db-manager-init", None, || { - 
RocksDbManager::init(Constant::new(config.common)) - }); + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(config.common))); let logs = restate_types::logs::metadata::bootstrap_logs_metadata(provider, None, num_logs); diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 89e61849e..2aaab89b0 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -40,7 +40,7 @@ metrics = { workspace = true } opentelemetry = { workspace = true } once_cell = { workspace = true } parking_lot = { workspace = true } -pin-project = { workspace = true } +pin-project-lite = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } rand = { workspace = true } diff --git a/crates/core/src/metadata/manager.rs b/crates/core/src/metadata/manager.rs index 1ce54a7cc..1f5dacd00 100644 --- a/crates/core/src/metadata/manager.rs +++ b/crates/core/src/metadata/manager.rs @@ -617,7 +617,7 @@ mod tests { S: Fn(&mut T, Version), { let tc = TaskCenterBuilder::default().build()?; - tc.block_on("test", None, async move { + tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); @@ -689,7 +689,7 @@ mod tests { I: Fn(&mut T), { let tc = TaskCenterBuilder::default().build()?; - tc.block_on("test", None, async move { + tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); diff --git a/crates/core/src/metadata/mod.rs b/crates/core/src/metadata/mod.rs index 7591ebd9a..4735d292f 100644 --- a/crates/core/src/metadata/mod.rs +++ b/crates/core/src/metadata/mod.rs @@ -73,6 +73,30 @@ pub struct Metadata { } impl Metadata { + pub fn try_with_current(f: F) -> Option + where + F: Fn(&Metadata) -> R, + { + TaskCenter::with_metadata(|m| f(m)) + } + + pub fn try_current() -> Option { + TaskCenter::with_current(|tc| tc.metadata()) + } + + #[track_caller] + pub fn with_current(f: F) -> R + where + F: FnOnce(&Metadata) -> R, + { + TaskCenter::with_metadata(|m| f(m)).expect("called outside task-center scope") + } + + #[track_caller] + pub fn current() -> Metadata { + TaskCenter::with_current(|tc| tc.metadata()).expect("called outside task-center scope") + } + #[inline(always)] pub fn nodes_config_snapshot(&self) -> Arc { self.inner.nodes_config.load_full() diff --git a/crates/core/src/task_center/extensions.rs b/crates/core/src/task_center/extensions.rs new file mode 100644 index 000000000..6f66a9a61 --- /dev/null +++ b/crates/core/src/task_center/extensions.rs @@ -0,0 +1,171 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
+
+use std::future::Future;
+use std::pin::Pin;
+
+use pin_project_lite::pin_project;
+use tokio::task::futures::TaskLocalFuture;
+use tokio_util::sync::CancellationToken;
+
+use crate::task_center::TaskContext;
+use crate::Metadata;
+
+use super::{
+    GlobalOverrides, TaskCenter, TaskId, TaskKind, CURRENT_TASK_CENTER, OVERRIDES, TASK_CONTEXT,
+};
+
+type TaskCenterFuture =
+    TaskLocalFuture>>;
+
+/// Adds the ability to override task-center for a future and all its children
+pub trait TaskCenterFutureExt: Sized {
+    /// Ensures that a future will run within a task-center context. This will inherit the current
+    /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0).
+    fn in_tc(self, task_center: &TaskCenter) -> WithTaskCenter;
+
+    /// Lets task-center treat this future as a pseudo-task. It gets its own TaskId and an
+    /// independent cancellation token. However, task-center will not spawn this as a task nor
+    /// manage its lifecycle.
+    fn in_tc_as_task(
+        self,
+        task_center: &TaskCenter,
+        kind: TaskKind,
+        name: &'static str,
+    ) -> WithTaskCenter;
+
+    /// Ensures that a future will run within the task-center in the current scope. This will inherit the current
+    /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0).
+    ///
+    /// This is useful when dispatching a future as a task on an external runtime/thread,
+    /// or when running a future on tokio's JoinSet without representing those tokio tasks as
+    /// task-center tasks. However, in the latter case, it's preferred to use
+    /// [`Self::in_current_tc_as_task`] instead.
+    fn in_current_tc(self) -> WithTaskCenter;
+
+    /// Attaches current task-center and lets it treat the future as a pseudo-task. It gets its own TaskId and an
+    /// independent cancellation token. However, task-center will not spawn this as a task nor
+    /// manage its lifecycle.
+    fn in_current_tc_as_task(self, kind: TaskKind, name: &'static str) -> WithTaskCenter;
+}
+
+pin_project! {
+    pub struct WithTaskCenter {
+        #[pin]
+        inner_fut: TaskCenterFuture,
+    }
+}
+
+impl TaskCenterFutureExt for F
+where
+    F: Future,
+{
+    fn in_tc(self, task_center: &TaskCenter) -> WithTaskCenter {
+        let ctx = task_center.with_task_context(Clone::clone);
+
+        let inner = CURRENT_TASK_CENTER.scope(
+            task_center.clone(),
+            OVERRIDES.scope(
+                OVERRIDES.try_with(Clone::clone).unwrap_or_default(),
+                TASK_CONTEXT.scope(ctx, self),
+            ),
+        );
+        WithTaskCenter { inner_fut: inner }
+    }
+
+    fn in_tc_as_task(
+        self,
+        task_center: &TaskCenter,
+        kind: TaskKind,
+        name: &'static str,
+    ) -> WithTaskCenter {
+        let ctx = task_center.with_task_context(move |parent| TaskContext {
+            id: TaskId::default(),
+            name,
+            kind,
+            cancellation_token: CancellationToken::new(),
+            partition_id: parent.partition_id,
+        });
+
+        let inner = CURRENT_TASK_CENTER.scope(
+            task_center.clone(),
+            OVERRIDES.scope(
+                OVERRIDES.try_with(Clone::clone).unwrap_or_default(),
+                TASK_CONTEXT.scope(ctx, self),
+            ),
+        );
+        WithTaskCenter { inner_fut: inner }
+    }
+
+    /// Ensures that a future will run within a task-center context. This will inherit the current
+    /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0).
+ fn in_current_tc(self) -> WithTaskCenter { + TaskCenter::with_current(|tc| self.in_tc(tc)) + } + + fn in_current_tc_as_task(self, kind: TaskKind, name: &'static str) -> WithTaskCenter { + TaskCenter::with_current(|tc| self.in_tc_as_task(tc, kind, name)) + } +} + +impl Future for WithTaskCenter { + type Output = T::Output; + + fn poll( + self: Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.inner_fut.poll(ctx) + } +} + +/// Adds the ability to override Metadata for a future and all its children +pub trait MetadataFutureExt: Sized { + /// Attaches restate's Metadata as an override on a future and all children futures or + /// task-center tasks spawned from it. + fn with_metadata(self, metadata: &Metadata) -> WithMetadata; +} + +pin_project! { + pub struct WithMetadata { + #[pin] + inner_fut: TaskLocalFuture, + } +} + +impl MetadataFutureExt for F +where + F: Future, +{ + fn with_metadata(self, metadata: &Metadata) -> WithMetadata { + let current_overrides = OVERRIDES.try_with(Clone::clone).unwrap_or_default(); + // temporary mute until overrides include more fields + #[allow(clippy::needless_update)] + let overrides = GlobalOverrides { + metadata: Some(metadata.clone()), + ..current_overrides + }; + let inner = OVERRIDES.scope(overrides, self); + WithMetadata { inner_fut: inner } + } +} + +impl Future for WithMetadata { + type Output = T::Output; + + fn poll( + self: Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.inner_fut.poll(ctx) + } +} diff --git a/crates/core/src/task_center/mod.rs b/crates/core/src/task_center/mod.rs index 67524f1c7..4c07d1446 100644 --- a/crates/core/src/task_center/mod.rs +++ b/crates/core/src/task_center/mod.rs @@ -9,22 +9,25 @@ // by the Apache License, Version 2.0. mod builder; +mod extensions; mod runtime; mod task; mod task_kind; pub use builder::*; +pub use extensions::*; pub use runtime::*; pub use task::*; pub use task_kind::*; use std::collections::HashMap; +use std::future::Future; use std::panic::AssertUnwindSafe; use std::sync::atomic::{AtomicBool, AtomicI32, Ordering}; use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; -use futures::{Future, FutureExt}; +use futures::FutureExt; use metrics::{counter, gauge}; use parking_lot::Mutex; use tokio::runtime::RuntimeMetrics; @@ -46,9 +49,18 @@ const EXIT_CODE_FAILURE: i32 = 1; task_local! { // Current task center - static CURRENT_TASK_CENTER: TaskCenter; + pub(self) static CURRENT_TASK_CENTER: TaskCenter; // Tasks provide access to their context - static CONTEXT: TaskContext; + static TASK_CONTEXT: TaskContext; + + /// Access to a task-level global overrides. 
+ static OVERRIDES: GlobalOverrides; +} + +#[derive(Default, Clone)] +struct GlobalOverrides { + metadata: Option, + //config: Arc, } #[derive(Debug, thiserror::Error)] @@ -75,6 +87,13 @@ impl TaskCenter { ingress_runtime: Option, ) -> Self { metric_definitions::describe_metrics(); + let root_task_context = TaskContext { + id: TaskId::ROOT, + name: "::", + kind: TaskKind::InPlace, + cancellation_token: CancellationToken::new(), + partition_id: None, + }; Self { inner: Arc::new(TaskCenterInner { start_time: Instant::now(), @@ -88,10 +107,39 @@ impl TaskCenter { managed_tasks: Mutex::new(HashMap::new()), global_metadata: OnceLock::new(), managed_runtimes: Mutex::new(HashMap::with_capacity(64)), + root_task_context, }), } } + pub fn try_current() -> Option { + Self::try_with_current(Clone::clone) + } + + pub fn try_with_current(f: F) -> Option + where + F: FnOnce(&TaskCenter) -> R, + { + CURRENT_TASK_CENTER.try_with(|tc| f(tc)).ok() + } + + /// Get the current task center. Use this to spawn tasks on the current task center. + /// This must be called from within a task-center task. + #[track_caller] + pub fn current() -> TaskCenter { + Self::with_current(Clone::clone) + } + + #[track_caller] + pub fn with_current(f: F) -> R + where + F: FnOnce(&TaskCenter) -> R, + { + CURRENT_TASK_CENTER + .try_with(|tc| f(tc)) + .expect("called outside task-center task") + } + pub fn default_runtime_metrics(&self) -> RuntimeMetrics { self.inner.default_runtime_handle.metrics() } @@ -137,10 +185,6 @@ impl TaskCenter { self.inner.current_exit_code.load(Ordering::Relaxed) } - pub fn metadata(&self) -> Option<&Metadata> { - self.inner.global_metadata.get() - } - fn submit_runtime_metrics(runtime: &'static str, stats: RuntimeMetrics) { gauge!("restate.tokio.num_workers", "runtime" => runtime).set(stats.num_workers() as f64); gauge!("restate.tokio.blocking_threads", "runtime" => runtime) @@ -173,15 +217,40 @@ impl TaskCenter { } } - /// Clone the currently set METADATA (and if Some()). Otherwise falls back to global metadata. + pub(crate) fn metadata(&self) -> Option { + match OVERRIDES.try_with(|overrides| overrides.metadata.clone()) { + Ok(Some(o)) => Some(o), + // No metadata override, use task-center-level metadata + _ => self.inner.global_metadata.get().cloned(), + } + } + #[track_caller] - fn clone_metadata(&self) -> Option { - CONTEXT - .try_with(|m| m.metadata.clone()) + /// Attempt to access task-level overridden metadata first, if we don't have an override, + /// fallback to task-center's level metadata. + pub(crate) fn with_metadata(f: F) -> Option + where + F: FnOnce(&Metadata) -> R, + { + OVERRIDES + .try_with(|overrides| match &overrides.metadata { + Some(m) => Some(f(m)), + // No metadata override, use task-center-level metadata + None => CURRENT_TASK_CENTER.with(|tc| tc.inner.global_metadata.get().map(f)), + }) .ok() .flatten() - .or_else(|| self.inner.global_metadata.get().cloned()) } + + fn with_task_context(&self, f: F) -> R + where + F: Fn(&TaskContext) -> R, + { + TASK_CONTEXT + .try_with(|ctx| f(ctx)) + .unwrap_or_else(|_| f(&self.inner.root_task_context)) + } + /// Triggers a shutdown of the system. All running tasks will be asked gracefully /// to cancel but we will only wait for tasks with a TaskKind that has the property /// "OnCancel" set to "wait". 
@@ -231,14 +300,12 @@ impl TaskCenter { { let inner = self.inner.clone(); let id = TaskId::default(); - let metadata = self.clone_metadata(); let context = TaskContext { id, name, kind, partition_id, cancellation_token: cancel.clone(), - metadata, }; let task = Arc::new(Task { context: context.clone(), @@ -323,14 +390,12 @@ impl TaskCenter { let cancel = CancellationToken::new(); let id = TaskId::default(); - let metadata = self.clone_metadata(); let context = TaskContext { id, name, kind, partition_id, cancellation_token: cancel.clone(), - metadata, }; let fut = unmanaged_wrapper(self.clone(), context, future); @@ -374,8 +439,7 @@ impl TaskCenter { kind, cancellation_token: cancellation_token.clone(), // We must be within task-context already. let's get inherit partition_id - partition_id: CONTEXT.with(|c| c.partition_id), - metadata: Some(metadata()), + partition_id: self.with_task_context(|c| c.partition_id), }; let task = Arc::new(Task { @@ -458,7 +522,6 @@ impl TaskCenter { kind: root_task_kind, cancellation_token: cancel.clone(), partition_id, - metadata: Some(metadata()), }; let (result_tx, result_rx) = oneshot::channel(); @@ -513,17 +576,15 @@ impl TaskCenter { return Err(ShutdownError); } - let parent_id = - current_task_id().expect("spawn_child called outside of a task-center task"); - // From this point onwards, we unwrap() directly with the assumption that we are in task-center - // context and that the previous (expect) guards against reaching this point if we are - // outside task-center. - let parent_kind = current_task_kind().unwrap(); - let parent_name = CONTEXT.try_with(|ctx| ctx.name).unwrap(); + let (parent_id, parent_name, parent_kind, cancel) = self.with_task_context(|ctx| { + ( + ctx.id, + ctx.name, + ctx.kind, + ctx.cancellation_token.child_token(), + ) + }); - let cancel = CONTEXT - .try_with(|ctx| ctx.cancellation_token.child_token()) - .unwrap(); let result = self.spawn_inner(kind, name, partition_id, cancel, future); trace!( @@ -540,17 +601,17 @@ impl TaskCenter { pub fn spawn_blocking_unmanaged( &self, name: &'static str, - partition_id: Option, future: F, ) -> tokio::task::JoinHandle where F: Future + Send + 'static, O: Send + 'static, { - let tc = self.clone(); + let rt_handle = self.inner.default_runtime_handle.clone(); + let future = future.in_tc_as_task(self, TaskKind::InPlace, name); self.inner .default_runtime_handle - .spawn_blocking(move || tc.block_on(name, partition_id, future)) + .spawn_blocking(move || rt_handle.block_on(future)) } /// Cancelling the child will not cancel the parent. Note that parent task will not @@ -566,9 +627,7 @@ impl TaskCenter { where F: Future> + Send + 'static, { - let cancel = CONTEXT - .try_with(|ctx| ctx.cancellation_token.child_token()) - .expect("spawning inside task-center context"); + let cancel = self.with_task_context(|ctx| ctx.cancellation_token.child_token()); self.spawn_inner(kind, name, partition_id, cancel, future) } @@ -644,18 +703,13 @@ impl TaskCenter { /// Sets the current task_center but doesn't create a task. Use this when you need to run a /// future within task_center scope. - pub fn block_on( - &self, - name: &'static str, - partition_id: Option, - future: F, - ) -> O + pub fn block_on(&self, future: F) -> O where F: Future, { self.inner .default_runtime_handle - .block_on(self.run_in_scope(name, partition_id, future)) + .block_on(future.in_tc(self)) } /// Sets the current task_center but doesn't create a task. 
Use this when you need to run a @@ -669,46 +723,38 @@ impl TaskCenter { where F: Future, { - let cancel = CancellationToken::new(); + let cancellation_token = CancellationToken::new(); let id = TaskId::default(); - let metadata = self.clone_metadata(); - let context = TaskContext { + let ctx = TaskContext { id, name, kind: TaskKind::InPlace, - cancellation_token: cancel.clone(), + cancellation_token: cancellation_token.clone(), partition_id, - metadata, }; CURRENT_TASK_CENTER - .scope(self.clone(), CONTEXT.scope(context, future)) + .scope( + self.clone(), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope(ctx, future), + ), + ) .await } /// Sets the current task_center but doesn't create a task. Use this when you need to run a /// closure within task_center scope. - pub fn run_in_scope_sync( - &self, - name: &'static str, - partition_id: Option, - f: F, - ) -> O + pub fn run_in_scope_sync(&self, f: F) -> O where F: FnOnce() -> O, { - let cancel = CancellationToken::new(); - let id = TaskId::default(); - let metadata = self.clone_metadata(); - let context = TaskContext { - id, - name, - kind: TaskKind::InPlace, - partition_id, - cancellation_token: cancel.clone(), - metadata, - }; - CURRENT_TASK_CENTER.sync_scope(self.clone(), || CONTEXT.sync_scope(context, f)) + CURRENT_TASK_CENTER.sync_scope(self.clone(), || { + OVERRIDES.sync_scope(OVERRIDES.try_with(Clone::clone).unwrap_or_default(), || { + TASK_CONTEXT.sync_scope(self.with_task_context(Clone::clone), f) + }) + }) } /// Take control over the running task from task-center. This returns None if the task was not @@ -835,10 +881,11 @@ struct TaskCenterInner { current_exit_code: AtomicI32, managed_tasks: Mutex>>, global_metadata: OnceLock, + root_task_context: TaskContext, } /// This wrapper function runs in a newly-spawned task. It initializes the -/// task-local variables and calls the payload function. +/// task-local variables and wraps the inner future. async fn wrapper(task_center: TaskCenter, context: TaskContext, future: F) where F: Future> + 'static, @@ -849,12 +896,16 @@ where let result = CURRENT_TASK_CENTER .scope( task_center.clone(), - CONTEXT.scope(context, { - // We use AssertUnwindSafe here so that the wrapped function - // doesn't need to be UnwindSafe. We should not do anything after - // unwinding that'd risk us being in unwind-unsafe behavior. - AssertUnwindSafe(future).catch_unwind() - }), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope( + context, + // We use AssertUnwindSafe here so that the wrapped function + // doesn't need to be UnwindSafe. We should not do anything after + // unwinding that'd risk us being in unwind-unsafe behavior. + AssertUnwindSafe(future).catch_unwind(), + ), + ), ) .await; task_center.on_finish(id, result).await; @@ -868,29 +919,21 @@ where trace!(kind = ?context.kind, name = ?context.name, "Starting task {}", context.id); CURRENT_TASK_CENTER - .scope(task_center.clone(), CONTEXT.scope(context, future)) + .scope( + task_center.clone(), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope(context, future), + ), + ) .await } -/// The current task-center task kind. This returns None if we are not in the scope -/// of a task-center task. -pub fn current_task_kind() -> Option { - CONTEXT.try_with(|ctx| ctx.kind).ok() -} - -/// The current task-center task Id. This returns None if we are not in the scope -/// of a task-center task. 
-pub fn current_task_id() -> Option { - CONTEXT.try_with(|ctx| ctx.id).ok() -} - /// Access to global metadata handle. This is available in task-center tasks only! #[track_caller] pub fn metadata() -> Metadata { - CONTEXT - .try_with(|ctx| ctx.metadata.clone()) - .expect("metadata() called outside task-center scope") - .expect("metadata() called before global metadata was set") + // todo: migrate call-sites + Metadata::current() } #[track_caller] @@ -898,32 +941,41 @@ pub fn with_metadata(f: F) -> R where F: FnOnce(&Metadata) -> R, { - CURRENT_TASK_CENTER.with(|tc| { - f(tc.metadata() - .expect("metadata must be set. Is global metadata set?")) - }) + Metadata::with_current(f) } /// Access to this node id. This is available in task-center tasks only! #[track_caller] pub fn my_node_id() -> GenerationalNodeId { - CONTEXT - .try_with(|ctx| ctx.metadata.as_ref().map(|m| m.my_node_id())) - .expect("my_node_id() called outside task-center scope") - .expect("my_node_id() called before global metadata was set") + // todo: migrate call-sites + Metadata::with_current(|m| m.my_node_id()) +} + +/// The current task-center task Id. This returns None if we are not in the scope +/// of a task-center task. +pub fn current_task_id() -> Option { + TASK_CONTEXT + .try_with(|ctx| Some(ctx.id)) + .unwrap_or(TaskCenter::try_with_current(|tc| { + tc.inner.root_task_context.id + })) } /// The current partition Id associated to the running task-center task. pub fn current_task_partition_id() -> Option { - CONTEXT.try_with(|ctx| ctx.partition_id).ok().flatten() + TASK_CONTEXT + .try_with(|ctx| Some(ctx.partition_id)) + .unwrap_or(TaskCenter::try_with_current(|tc| { + tc.inner.root_task_context.partition_id + })) + .flatten() } /// Get the current task center. Use this to spawn tasks on the current task center. /// This must be called from within a task-center task. pub fn task_center() -> TaskCenter { - CURRENT_TASK_CENTER - .try_with(|t| t.clone()) - .expect("task_center() called in a task-center task") + // migrate call-sites + TaskCenter::current() } /// A Future that can be used to check if the current task has been requested to @@ -939,7 +991,7 @@ pub async fn cancellation_watcher() { /// cancel_task() call, or if it's a child and the parent is being cancelled by a /// cancel_task() call, this cancellation token will be set to cancelled. pub fn cancellation_token() -> CancellationToken { - let res = CONTEXT.try_with(|ctx| ctx.cancellation_token.clone()); + let res = TASK_CONTEXT.try_with(|ctx| ctx.cancellation_token.clone()); if cfg!(any(test, feature = "test-util")) { // allow in tests to call from non-task-center tasks. @@ -951,7 +1003,7 @@ pub fn cancellation_token() -> CancellationToken { /// Has the current task been requested to cancel? 
pub fn is_cancellation_requested() -> bool { - CONTEXT + TASK_CONTEXT .try_with(|ctx| ctx.cancellation_token.is_cancelled()) .unwrap_or_else(|_| { if cfg!(any(test, feature = "test-util")) { @@ -966,7 +1018,6 @@ mod tests { use super::*; use googletest::prelude::*; - use restate_test_util::assert_eq; use restate_types::config::CommonOptionsBuilder; use tracing_test::traced_test; @@ -986,7 +1037,6 @@ mod tests { tc.spawn(TaskKind::RoleRunner, "worker-role", None, async { info!("Hello async"); tokio::time::sleep(Duration::from_secs(10)).await; - assert_eq!(TaskKind::RoleRunner, current_task_kind().unwrap()); info!("Bye async"); Ok(()) }) diff --git a/crates/core/src/task_center/task.rs b/crates/core/src/task_center/task.rs index c595e5674..beb1b08f8 100644 --- a/crates/core/src/task_center/task.rs +++ b/crates/core/src/task_center/task.rs @@ -18,7 +18,7 @@ use tokio_util::sync::CancellationToken; use restate_types::identifiers::PartitionId; use super::{TaskId, TaskKind}; -use crate::{Metadata, ShutdownError}; +use crate::ShutdownError; #[derive(Clone)] pub(super) struct TaskContext { @@ -31,8 +31,6 @@ pub(super) struct TaskContext { /// Tasks associated with a specific partition ID will have this set. This allows /// for cancellation of tasks associated with that partition. pub(super) partition_id: Option, - /// Access to a locally-cached metadata view. - pub(super) metadata: Option, } pub(super) struct Task { diff --git a/crates/core/src/task_center/task_kind.rs b/crates/core/src/task_center/task_kind.rs index 39008c4dc..b18ab226a 100644 --- a/crates/core/src/task_center/task_kind.rs +++ b/crates/core/src/task_center/task_kind.rs @@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use strum::EnumProperty; -static NEXT_TASK_ID: AtomicU64 = const { AtomicU64::new(0) }; +static NEXT_TASK_ID: AtomicU64 = const { AtomicU64::new(1) }; #[derive( Clone, @@ -36,6 +36,8 @@ impl Default for TaskId { } impl TaskId { + pub const ROOT: TaskId = TaskId(0); + pub fn new() -> Self { Default::default() } diff --git a/crates/ingress-http/Cargo.toml b/crates/ingress-http/Cargo.toml index fcc77e4fd..8b47165b4 100644 --- a/crates/ingress-http/Cargo.toml +++ b/crates/ingress-http/Cargo.toml @@ -37,7 +37,7 @@ hyper-util = { workspace = true, features = ["http1", "http2", "server", "tokio" metrics = { workspace = true } opentelemetry = { workspace = true } opentelemetry_sdk = { workspace = true } -pin-project-lite = "0.2.13" +pin-project-lite = { workspace = true } schemars = { workspace = true, optional = true } serde = { workspace = true } serde_with = { workspace = true } diff --git a/crates/metadata-store/src/local/tests.rs b/crates/metadata-store/src/local/tests.rs index bca1aa32c..8111f3aaa 100644 --- a/crates/metadata-store/src/local/tests.rs +++ b/crates/metadata-store/src/local/tests.rs @@ -331,9 +331,7 @@ async fn create_test_environment( let task_center = &env.tc; - task_center.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(config.clone().map(|c| &c.common)) - }); + task_center.run_in_scope_sync(|| RocksDbManager::init(config.clone().map(|c| &c.common))); let client = start_metadata_store( config.pinned().common.metadata_store_client.clone(), diff --git a/crates/node/src/network_server/grpc_svc_handler.rs b/crates/node/src/network_server/grpc_svc_handler.rs index cbdb101ae..5ff24a22d 100644 --- a/crates/node/src/network_server/grpc_svc_handler.rs +++ b/crates/node/src/network_server/grpc_svc_handler.rs @@ -21,7 +21,9 @@ use restate_core::network::protobuf::node_svc::{ 
}; use restate_core::network::ConnectionManager; use restate_core::network::{ProtocolError, TransportConnect}; -use restate_core::{metadata, MetadataKind, TargetVersion, TaskCenter}; +use restate_core::{ + metadata, MetadataKind, TargetVersion, TaskCenter, TaskCenterFutureExt, TaskKind, +}; use restate_types::health::Health; use restate_types::nodes_config::Role; use restate_types::protobuf::node::Message; @@ -61,7 +63,7 @@ impl NodeSvc for NodeSvcHandler { let metadata_server_status = self.health.current_metadata_server_status(); let log_server_status = self.health.current_log_server_status(); let age_s = self.task_center.age().as_secs(); - self.task_center.run_in_scope_sync("get_ident", None, || { + self.task_center.run_in_scope_sync(|| { let metadata = metadata(); Ok(Response::new(IdentResponse { status: node_status.into(), @@ -98,12 +100,9 @@ impl NodeSvc for NodeSvcHandler { let incoming = request.into_inner(); let transformed = incoming.map(|x| x.map_err(ProtocolError::from)); let output_stream = self - .task_center - .run_in_scope( - "accept-connection", - None, - self.connections.accept_incoming_connection(transformed), - ) + .connections + .accept_incoming_connection(transformed) + .in_current_tc_as_task(TaskKind::InPlace, "accept-connection") .await?; // For uniformity with outbound connections, we map all responses to Ok, we never rely on diff --git a/crates/partition-store/benches/basic_benchmark.rs b/crates/partition-store/benches/basic_benchmark.rs index 6802c017b..135de7b39 100644 --- a/crates/partition-store/benches/basic_benchmark.rs +++ b/crates/partition-store/benches/basic_benchmark.rs @@ -47,10 +47,8 @@ fn basic_writing_reading_benchmark(c: &mut Criterion) { .expect("task_center builds"); let worker_options = WorkerOptions::default(); - tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); - let rocksdb = tc.block_on("test-setup", None, async { + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); + let rocksdb = tc.block_on(async { // // setup // diff --git a/crates/partition-store/src/tests/mod.rs b/crates/partition-store/src/tests/mod.rs index 0328d9d7e..b986a3672 100644 --- a/crates/partition-store/src/tests/mod.rs +++ b/crates/partition-store/src/tests/mod.rs @@ -52,9 +52,7 @@ async fn storage_test_environment_with_manager() -> (PartitionStoreManager, Part .ingress_runtime_handle(tokio::runtime::Handle::current()) .build() .expect("task_center builds"); - tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let worker_options = Live::from_value(WorkerOptions::default()); let manager = PartitionStoreManager::create( worker_options.clone().map(|c| &c.storage), diff --git a/crates/service-client/Cargo.toml b/crates/service-client/Cargo.toml index c3835836b..ca85eb334 100644 --- a/crates/service-client/Cargo.toml +++ b/crates/service-client/Cargo.toml @@ -51,7 +51,7 @@ aws-smithy-async = {version = "1.2.1", default-features = false} aws-smithy-runtime = {version = "1.6.2", default-features = false} aws-smithy-runtime-api = {version = "1.7.1", default-features = false} aws-smithy-types = { version = "1.2.0", default-features = false} -pin-project-lite = "0.2.13" +pin-project-lite = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/crates/storage-query-datafusion/src/mocks.rs 
b/crates/storage-query-datafusion/src/mocks.rs index 5d5bddefa..46428b066 100644 --- a/crates/storage-query-datafusion/src/mocks.rs +++ b/crates/storage-query-datafusion/src/mocks.rs @@ -161,9 +161,8 @@ impl MockQueryEngine { + 'static, ) -> Self { // Prepare Rocksdb - task_center().run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + task_center() + .run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let worker_options = Live::from_value(WorkerOptions::default()); let manager = PartitionStoreManager::create( worker_options.clone().map(|c| &c.storage), diff --git a/crates/worker/src/partition/leadership.rs b/crates/worker/src/partition/leadership.rs index a0dbacf2e..b613078de 100644 --- a/crates/worker/src/partition/leadership.rs +++ b/crates/worker/src/partition/leadership.rs @@ -1136,9 +1136,7 @@ mod tests { let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); - tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let bifrost = tc .run_in_scope( diff --git a/crates/worker/src/partition_processor_manager/mod.rs b/crates/worker/src/partition_processor_manager/mod.rs index c9f06e805..ade4c7574 100644 --- a/crates/worker/src/partition_processor_manager/mod.rs +++ b/crates/worker/src/partition_processor_manager/mod.rs @@ -651,7 +651,6 @@ impl PartitionProcessorManager { // where doing otherwise appears to starve the Tokio event loop, causing very slow startup. let handle = self.task_center.spawn_blocking_unmanaged( "starting-partition-processor", - Some(partition_id), starting_task.run(), ); @@ -955,11 +954,9 @@ mod tests { TestCoreEnvBuilder::with_incoming_only_connector().set_nodes_config(nodes_config); let health_status = HealthStatus::default(); - env_builder - .tc - .run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())); - }); + env_builder.tc.run_in_scope_sync(|| { + RocksDbManager::init(Constant::new(CommonOptions::default())); + }); let bifrost_svc = BifrostService::new(env_builder.tc.clone(), env_builder.metadata.clone()) .with_factory(memory_loglet::Factory::default()); diff --git a/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs b/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs index 10dab0651..99b39c523 100644 --- a/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs +++ b/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs @@ -178,9 +178,9 @@ mod tests { let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); - node_env.tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + node_env + .tc + .run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let all_partition_keys = RangeInclusive::new(0, PartitionKey::MAX); let partition_store_manager = PartitionStoreManager::create( diff --git a/server/src/main.rs b/server/src/main.rs index cafcf36eb..98004b374 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -161,7 +161,7 @@ fn main() { .options(Configuration::pinned().common.clone()) .build() .expect("task_center builds"); - tc.block_on("main", None, { + tc.block_on({ let tc = tc.clone(); async move 
{ // Apply tracing config globally diff --git a/server/tests/common/replicated_loglet.rs b/server/tests/common/replicated_loglet.rs index 46845df26..d326998ef 100644 --- a/server/tests/common/replicated_loglet.rs +++ b/server/tests/common/replicated_loglet.rs @@ -17,6 +17,7 @@ use googletest::IntoTestResult; use restate_bifrost::{loglet::Loglet, Bifrost, BifrostAdmin}; use restate_core::metadata_store::Precondition; +use restate_core::TaskCenterFutureExt; use restate_core::{metadata_store::MetadataStoreClient, MetadataWriter, TaskCenterBuilder}; use restate_local_cluster_runner::{ cluster::{Cluster, MaybeTempDir, StartedCluster}, @@ -130,7 +131,7 @@ where // this will still respect LOCAL_CLUSTER_RUNNER_RETAIN_TEMPDIR=true let base_dir: MaybeTempDir = tempfile::tempdir()?.into(); - tc.run_in_scope("test", None, async { + async { RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); let cluster = Cluster::builder() @@ -177,19 +178,16 @@ where .await?; // global metadata should now be set, running in scope sets it in the task center context - tc.run_in_scope( - "test-fn", - None, - future(TestEnv { - bifrost, - loglet, - cluster, - metadata_writer, - metadata_store_client, - }), - ) + future(TestEnv { + bifrost, + loglet, + cluster, + metadata_writer, + metadata_store_client, + }) .await - }) + } + .in_tc(&tc) .await?; tc.shutdown_node("test completed", 0).await; diff --git a/tools/bifrost-benchpress/Cargo.toml b/tools/bifrost-benchpress/Cargo.toml index 7563ace25..433a8642f 100644 --- a/tools/bifrost-benchpress/Cargo.toml +++ b/tools/bifrost-benchpress/Cargo.toml @@ -26,7 +26,7 @@ restate-metadata-store = { workspace = true } restate-rocksdb = { workspace = true } restate-test-util = { workspace = true } restate-tracing-instrumentation = { workspace = true, features = ["rt-tokio"] } -restate-types = { workspace = true, features = ["test-util"] } +restate-types = { workspace = true, features = ["test-util", "clap"] } anyhow = { workspace = true } bytes = { workspace = true } diff --git a/tools/bifrost-benchpress/src/main.rs b/tools/bifrost-benchpress/src/main.rs index 3070f2e41..abe9d6a4a 100644 --- a/tools/bifrost-benchpress/src/main.rs +++ b/tools/bifrost-benchpress/src/main.rs @@ -92,7 +92,7 @@ fn main() -> anyhow::Result<()> { let (tc, bifrost) = spawn_environment(Configuration::updateable(), 1); let task_center = tc.clone(); let args = cli_args.clone(); - tc.block_on("benchpress", None, async move { + tc.block_on(async move { let tracing_guard = init_tracing_and_logging(&config.common, "Bifrost benchpress") .expect("failed to configure logging and tracing!"); @@ -149,7 +149,7 @@ fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, .expect("task_center builds"); let task_center = tc.clone(); - let bifrost = tc.block_on("spawn", None, async move { + let bifrost = tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); diff --git a/tools/bifrost-benchpress/src/write_to_read.rs b/tools/bifrost-benchpress/src/write_to_read.rs index ee8e6571f..8b391c543 100644 --- a/tools/bifrost-benchpress/src/write_to_read.rs +++ b/tools/bifrost-benchpress/src/write_to_read.rs @@ -17,7 +17,7 @@ use hdrhistogram::Histogram; use tracing::info; use restate_bifrost::Bifrost; -use restate_core::{TaskCenter, TaskHandle, TaskKind}; +use restate_core::{Metadata, TaskCenter, TaskHandle, TaskKind}; use restate_types::logs::{KeyFilter, LogId, Lsn, 
SequenceNumber, WithKeys}; use crate::util::{print_latencies, DummyPayload}; @@ -140,7 +140,7 @@ pub async fn run( println!( "Log Chain: {:#?}", - tc.metadata().unwrap().logs_ref().chain(&LOG_ID).unwrap() + Metadata::current().logs_ref().chain(&LOG_ID).unwrap() ); println!("Payload size per record: {} bytes", args.payload_size); println!(); From edf7f314045223e418d27b73708ee8e05fcdc8c0 Mon Sep 17 00:00:00 2001 From: Ahmed Farghal Date: Fri, 22 Nov 2024 17:21:43 +0000 Subject: [PATCH 2/4] [TaskCenter] Stage 3 --- .../cluster_state_refresher.rs | 23 +- .../admin/src/cluster_controller/scheduler.rs | 200 +- .../admin/src/cluster_controller/service.rs | 19 +- .../src/cluster_controller/service/state.rs | 1 - crates/bifrost/benches/util.rs | 36 +- .../local_loglet/log_store_writer.rs | 5 +- .../src/providers/replicated_loglet/loglet.rs | 2 - .../providers/replicated_loglet/network.rs | 32 +- .../replicated_loglet/tasks/find_tail.rs | 1 - .../providers/replicated_loglet/tasks/seal.rs | 28 +- crates/core/src/metadata/manager.rs | 26 +- crates/core/src/metadata/mod.rs | 7 +- crates/core/src/network/connection_manager.rs | 7 +- crates/core/src/network/net_util.rs | 27 +- crates/core/src/task_center/mod.rs | 45 +- crates/core/src/task_center/task_kind.rs | 2 + crates/core/src/test_env.rs | 147 +- crates/ingress-kafka/src/consumer_task.rs | 2 +- .../src/subscription_controller.rs | 10 +- crates/log-server/src/loglet_worker.rs | 1658 +++++++++-------- crates/log-server/src/network.rs | 25 +- .../src/rocksdb_logstore/builder.rs | 5 +- .../log-server/src/rocksdb_logstore/store.rs | 32 +- .../log-server/src/rocksdb_logstore/writer.rs | 5 +- crates/log-server/src/service.rs | 16 +- crates/metadata-store/src/local/service.rs | 29 +- crates/node/src/lib.rs | 8 +- crates/node/src/roles/admin.rs | 11 +- crates/node/src/roles/base.rs | 6 +- crates/node/src/roles/worker.rs | 9 +- crates/worker/src/lib.rs | 15 +- crates/worker/src/partition/leadership.rs | 16 +- crates/worker/src/partition/shuffle.rs | 129 +- .../message_handler.rs | 84 +- .../src/partition_processor_manager/mod.rs | 7 +- .../spawn_processor_task.rs | 23 +- tools/bifrost-benchpress/src/main.rs | 4 +- tools/restatectl/src/commands/log/dump_log.rs | 14 +- tools/restatectl/src/commands/metadata/get.rs | 45 +- .../restatectl/src/commands/metadata/patch.rs | 45 +- .../restatectl/src/environment/task_center.rs | 8 +- 41 files changed, 1361 insertions(+), 1453 deletions(-) diff --git a/crates/admin/src/cluster_controller/cluster_state_refresher.rs b/crates/admin/src/cluster_controller/cluster_state_refresher.rs index 54ff22a63..a8fd1968a 100644 --- a/crates/admin/src/cluster_controller/cluster_state_refresher.rs +++ b/crates/admin/src/cluster_controller/cluster_state_refresher.rs @@ -19,7 +19,9 @@ use restate_core::network::rpc_router::RpcRouter; use restate_core::network::{ MessageRouterBuilder, NetworkError, Networking, Outgoing, TransportConnect, }; -use restate_core::{Metadata, ShutdownError, TaskCenter, TaskHandle}; +use restate_core::{ + Metadata, ShutdownError, TaskCenter, TaskCenterFutureExt, TaskHandle, TaskKind, +}; use restate_types::cluster::cluster_state::{ AliveNode, ClusterState, DeadNode, NodeState, SuspectNode, }; @@ -28,7 +30,6 @@ use restate_types::time::MillisSinceEpoch; use restate_types::Version; pub struct ClusterStateRefresher { - task_center: TaskCenter, metadata: Metadata, network_sender: Networking, get_state_router: RpcRouter, @@ -39,7 +40,6 @@ pub struct ClusterStateRefresher { impl ClusterStateRefresher { pub fn new( - 
task_center: TaskCenter, metadata: Metadata, network_sender: Networking, router_builder: &mut MessageRouterBuilder, @@ -57,7 +57,6 @@ impl ClusterStateRefresher { watch::channel(Arc::from(initial_state)); Self { - task_center, metadata, network_sender, get_state_router, @@ -97,7 +96,6 @@ impl ClusterStateRefresher { } self.in_flight_refresh = Self::start_refresh_task( - self.task_center.clone(), self.get_state_router.clone(), self.network_sender.clone(), Arc::clone(&self.cluster_state_update_tx), @@ -108,13 +106,11 @@ impl ClusterStateRefresher { } fn start_refresh_task( - tc: TaskCenter, get_state_router: RpcRouter, network_sender: Networking, cluster_state_tx: Arc>>, metadata: Metadata, ) -> Result>>, ShutdownError> { - let task_center = tc.clone(); let refresh = async move { let last_state = Arc::clone(&cluster_state_tx.borrow()); // make sure we have a partition table that equals or newer than last refresh @@ -137,13 +133,12 @@ impl ClusterStateRefresher { for (_, node_config) in nodes_config.iter() { let node_id = node_config.current_generation; let rpc_router = get_state_router.clone(); - let tc = tc.clone(); let network_sender = network_sender.clone(); join_set .build_task() .name("get-nodes-state") - .spawn(async move { - tc.run_in_scope("get-node-state", None, async move { + .spawn( + async move { match network_sender.node_connection(node_id).await { Ok(connection) => { let outgoing = Outgoing::new(node_id, GetNodeState::default()) @@ -161,9 +156,9 @@ impl ClusterStateRefresher { } Err(network_error) => (node_id, Err(network_error)), } - }) - .await - }) + } + .in_current_tc_as_task(TaskKind::InPlace, "get-nodes-state"), + ) .expect("to spawn task"); } while let Some(Ok((node_id, result))) = join_set.join_next().await { @@ -233,7 +228,7 @@ impl ClusterStateRefresher { Ok(()) }; - let handle = task_center.spawn_unmanaged( + let handle = TaskCenter::current().spawn_unmanaged( restate_core::TaskKind::Disposable, "cluster-state-refresh", None, diff --git a/crates/admin/src/cluster_controller/scheduler.rs b/crates/admin/src/cluster_controller/scheduler.rs index 63c80930e..ca40af8c5 100644 --- a/crates/admin/src/cluster_controller/scheduler.rs +++ b/crates/admin/src/cluster_controller/scheduler.rs @@ -91,8 +91,6 @@ impl PartitionProcessorPlacementHints for & pub struct Scheduler { scheduling_plan: SchedulingPlan, last_updated_scheduling_plan: Instant, - - task_center: TaskCenter, metadata_store_client: MetadataStoreClient, networking: Networking, } @@ -104,7 +102,6 @@ pub struct Scheduler { impl Scheduler { pub async fn init( configuration: &Configuration, - task_center: TaskCenter, metadata_store_client: MetadataStoreClient, networking: Networking, ) -> Result { @@ -120,7 +117,6 @@ impl Scheduler { Ok(Self { scheduling_plan, last_updated_scheduling_plan: Instant::now(), - task_center, metadata_store_client, networking, }) @@ -478,10 +474,9 @@ impl Scheduler { commands, }; - self.task_center.spawn_child( + TaskCenter::spawn_child( TaskKind::Disposable, "send-control-processors-to-node", - None, { let networking = self.networking.clone(); async move { @@ -584,7 +579,9 @@ mod tests { HashSet, PartitionProcessorPlacementHints, Scheduler, }; use restate_core::network::{ForwardingHandler, Incoming, MessageCollectorMockConnector}; - use restate_core::{metadata, TaskCenterBuilder, TestCoreEnv, TestCoreEnvBuilder}; + use restate_core::{ + metadata, TaskCenterBuilder, TaskCenterFutureExt, TestCoreEnv, TestCoreEnvBuilder, + }; use restate_types::cluster::cluster_state::{ AliveNode, ClusterState, 
DeadNode, NodeState, PartitionProcessorStatus, RunMode, }; @@ -618,44 +615,41 @@ mod tests { #[test(tokio::test)] async fn empty_leadership_changes_dont_modify_plan() -> googletest::Result<()> { let test_env = TestCoreEnv::create_with_single_node(0, 0).await; - let tc = test_env.tc.clone(); let metadata_store_client = test_env.metadata_store_client.clone(); let networking = test_env.networking.clone(); - test_env - .tc - .run_in_scope("test", None, async { - let initial_scheduling_plan = metadata_store_client - .get::(SCHEDULING_PLAN_KEY.clone()) - .await - .expect("scheduling plan"); - let mut scheduler = Scheduler::init( - Configuration::pinned().as_ref(), - tc, - metadata_store_client.clone(), - networking, + async { + let initial_scheduling_plan = metadata_store_client + .get::(SCHEDULING_PLAN_KEY.clone()) + .await + .expect("scheduling plan"); + let mut scheduler = Scheduler::init( + Configuration::pinned().as_ref(), + metadata_store_client.clone(), + networking, + ) + .await?; + let observed_cluster_state = ObservedClusterState::default(); + + scheduler + .on_observed_cluster_state( + &observed_cluster_state, + &metadata().nodes_config_ref(), + NoPlacementHints, ) .await?; - let observed_cluster_state = ObservedClusterState::default(); - - scheduler - .on_observed_cluster_state( - &observed_cluster_state, - &metadata().nodes_config_ref(), - NoPlacementHints, - ) - .await?; - let scheduling_plan = metadata_store_client - .get::(SCHEDULING_PLAN_KEY.clone()) - .await - .expect("scheduling plan"); + let scheduling_plan = metadata_store_client + .get::(SCHEDULING_PLAN_KEY.clone()) + .await + .expect("scheduling plan"); - assert_eq!(initial_scheduling_plan, scheduling_plan); + assert_eq!(initial_scheduling_plan, scheduling_plan); - Ok(()) - }) - .await + Ok(()) + } + .in_tc(&test_env.tc) + .await } #[test(tokio::test(start_paused = true))] @@ -742,78 +736,76 @@ mod tests { .set_scheduling_plan(initial_scheduling_plan) .build() .await; - let tc = env.tc.clone(); - env.tc - .run_in_scope("test", None, async move { - let mut scheduler = Scheduler::init( - Configuration::pinned().as_ref(), - tc, - metadata_store_client.clone(), - networking, - ) - .await?; - let mut observed_cluster_state = ObservedClusterState::default(); + async move { + let mut scheduler = Scheduler::init( + Configuration::pinned().as_ref(), + metadata_store_client.clone(), + networking, + ) + .await?; + let mut observed_cluster_state = ObservedClusterState::default(); - for _ in 0..num_scheduling_rounds { - let cluster_state = random_cluster_state(&node_ids, num_partitions); + for _ in 0..num_scheduling_rounds { + let cluster_state = random_cluster_state(&node_ids, num_partitions); - observed_cluster_state.update(&cluster_state); - scheduler - .on_observed_cluster_state( - &observed_cluster_state, - &metadata().nodes_config_ref(), - NoPlacementHints, - ) - .await?; - // collect all control messages from the network to build up the effective scheduling plan - let control_messages = control_recv - .as_mut() - .take_until(tokio::time::sleep(Duration::from_secs(10))) - .collect::>() - .await; - - let observed_cluster_state = - derive_observed_cluster_state(&cluster_state, control_messages); - let target_scheduling_plan = metadata_store_client - .get::(SCHEDULING_PLAN_KEY.clone()) - .await? 
- .expect("the scheduler should have created a scheduling plan"); - - // assert that the effective scheduling plan aligns with the target scheduling plan - assert_that!( - observed_cluster_state, - matches_scheduling_plan(&target_scheduling_plan) - ); - - let alive_nodes: HashSet<_> = cluster_state - .alive_nodes() - .map(|node| node.generational_node_id.as_plain()) - .collect(); - - for (_, target_state) in target_scheduling_plan.iter() { - // assert that every partition has a leader which is part of the alive nodes set - assert!(target_state - .leader - .is_some_and(|leader| alive_nodes.contains(&leader))); - - // assert that the replication strategy was respected - match replication_strategy { - ReplicationStrategy::OnAllNodes => { - assert_eq!(target_state.node_set, alive_nodes) - } - ReplicationStrategy::Factor(replication_factor) => assert_eq!( - target_state.node_set.len(), - alive_nodes.len().min( - usize::try_from(replication_factor.get()) - .expect("u32 fits into usize") - ) - ), + observed_cluster_state.update(&cluster_state); + scheduler + .on_observed_cluster_state( + &observed_cluster_state, + &metadata().nodes_config_ref(), + NoPlacementHints, + ) + .await?; + // collect all control messages from the network to build up the effective scheduling plan + let control_messages = control_recv + .as_mut() + .take_until(tokio::time::sleep(Duration::from_secs(10))) + .collect::>() + .await; + + let observed_cluster_state = + derive_observed_cluster_state(&cluster_state, control_messages); + let target_scheduling_plan = metadata_store_client + .get::(SCHEDULING_PLAN_KEY.clone()) + .await? + .expect("the scheduler should have created a scheduling plan"); + + // assert that the effective scheduling plan aligns with the target scheduling plan + assert_that!( + observed_cluster_state, + matches_scheduling_plan(&target_scheduling_plan) + ); + + let alive_nodes: HashSet<_> = cluster_state + .alive_nodes() + .map(|node| node.generational_node_id.as_plain()) + .collect(); + + for (_, target_state) in target_scheduling_plan.iter() { + // assert that every partition has a leader which is part of the alive nodes set + assert!(target_state + .leader + .is_some_and(|leader| alive_nodes.contains(&leader))); + + // assert that the replication strategy was respected + match replication_strategy { + ReplicationStrategy::OnAllNodes => { + assert_eq!(target_state.node_set, alive_nodes) } + ReplicationStrategy::Factor(replication_factor) => assert_eq!( + target_state.node_set.len(), + alive_nodes.len().min( + usize::try_from(replication_factor.get()) + .expect("u32 fits into usize") + ) + ), } } - googletest::Result::Ok(()) - }) - .await?; + } + googletest::Result::Ok(()) + } + .in_tc(&env.tc) + .await?; Ok(()) } diff --git a/crates/admin/src/cluster_controller/service.rs b/crates/admin/src/cluster_controller/service.rs index 1c07b4ca8..be9f47afe 100644 --- a/crates/admin/src/cluster_controller/service.rs +++ b/crates/admin/src/cluster_controller/service.rs @@ -58,7 +58,6 @@ pub enum Error { } pub struct Service { - task_center: TaskCenter, metadata: Metadata, networking: Networking, bifrost: Bifrost, @@ -84,7 +83,6 @@ where mut configuration: Live, health_status: HealthStatus, bifrost: Bifrost, - task_center: TaskCenter, metadata: Metadata, networking: Networking, router_builder: &mut MessageRouterBuilder, @@ -94,12 +92,8 @@ where ) -> Self { let (command_tx, command_rx) = mpsc::channel(2); - let cluster_state_refresher = ClusterStateRefresher::new( - task_center.clone(), - metadata.clone(), - 
networking.clone(), - router_builder, - ); + let cluster_state_refresher = + ClusterStateRefresher::new(metadata.clone(), networking.clone(), router_builder); let processor_manager_client = PartitionProcessorManagerClient::new(networking.clone(), router_builder); @@ -125,7 +119,6 @@ where Service { configuration, health_status, - task_center, metadata, networking, bifrost, @@ -230,10 +223,9 @@ impl Service { let mut config_watcher = Configuration::watcher(); let mut cluster_state_watcher = self.cluster_state_refresher.cluster_state_watcher(); - self.task_center.spawn_child( + TaskCenter::spawn_child( TaskKind::SystemService, "cluster-controller-metadata-sync", - None, sync_cluster_controller_metadata(self.metadata.clone()), )?; @@ -346,10 +338,9 @@ impl Service { ); let mut node_rpc_client = self.processor_manager_client.clone(); - let _ = self.task_center.spawn_child( + let _ = TaskCenter::spawn_child( TaskKind::Disposable, "create-snapshot-response", - Some(partition_id), async move { let _ = response_tx.send( node_rpc_client @@ -521,7 +512,6 @@ mod tests { Live::from_value(Configuration::default()), HealthStatus::default(), bifrost.clone(), - builder.tc.clone(), builder.metadata.clone(), builder.networking.clone(), &mut builder.router_builder, @@ -841,7 +831,6 @@ mod tests { Live::from_value(config), HealthStatus::default(), bifrost.clone(), - builder.tc.clone(), builder.metadata.clone(), builder.networking.clone(), &mut builder.router_builder, diff --git a/crates/admin/src/cluster_controller/service/state.rs b/crates/admin/src/cluster_controller/service/state.rs index b0647fecb..d8e0a5948 100644 --- a/crates/admin/src/cluster_controller/service/state.rs +++ b/crates/admin/src/cluster_controller/service/state.rs @@ -162,7 +162,6 @@ where let scheduler = Scheduler::init( &configuration, - service.task_center.clone(), service.metadata_store_client.clone(), service.networking.clone(), ) diff --git a/crates/bifrost/benches/util.rs b/crates/bifrost/benches/util.rs index d2af12f47..30a5c8ea1 100644 --- a/crates/bifrost/benches/util.rs +++ b/crates/bifrost/benches/util.rs @@ -13,6 +13,7 @@ use tracing::warn; use restate_core::{ spawn_metadata_manager, MetadataBuilder, MetadataManager, TaskCenter, TaskCenterBuilder, + TaskCenterFutureExt, }; use restate_metadata_store::{MetadataStoreClient, Precondition}; use restate_rocksdb::RocksDbManager; @@ -34,25 +35,30 @@ pub async fn spawn_environment( .build() .expect("task_center builds"); - restate_types::config::set_current_config(config.clone()); - let metadata_builder = MetadataBuilder::default(); + async { + restate_types::config::set_current_config(config.clone()); + let metadata_builder = MetadataBuilder::default(); - let metadata_store_client = MetadataStoreClient::new_in_memory(); - let metadata = metadata_builder.to_metadata(); - let metadata_manager = MetadataManager::new(metadata_builder, metadata_store_client.clone()); + let metadata_store_client = MetadataStoreClient::new_in_memory(); + let metadata = metadata_builder.to_metadata(); + let metadata_manager = + MetadataManager::new(metadata_builder, metadata_store_client.clone()); - let metadata_writer = metadata_manager.writer(); - tc.try_set_global_metadata(metadata.clone()); + let metadata_writer = metadata_manager.writer(); + TaskCenter::try_set_global_metadata(metadata.clone()); - tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(config.common))); + RocksDbManager::init(Constant::new(config.common)); - let logs = restate_types::logs::metadata::bootstrap_logs_metadata(provider, 
None, num_logs); + let logs = restate_types::logs::metadata::bootstrap_logs_metadata(provider, None, num_logs); - metadata_store_client - .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) - .await - .expect("to store bifrost config in metadata store"); - metadata_writer.submit(Arc::new(logs)); - spawn_metadata_manager(&tc, metadata_manager).expect("metadata manager starts"); + metadata_store_client + .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) + .await + .expect("to store bifrost config in metadata store"); + metadata_writer.submit(Arc::new(logs)); + spawn_metadata_manager(metadata_manager).expect("metadata manager starts"); + } + .in_tc(&tc) + .await; tc } diff --git a/crates/bifrost/src/providers/local_loglet/log_store_writer.rs b/crates/bifrost/src/providers/local_loglet/log_store_writer.rs index d4a7460ea..6d97ede08 100644 --- a/crates/bifrost/src/providers/local_loglet/log_store_writer.rs +++ b/crates/bifrost/src/providers/local_loglet/log_store_writer.rs @@ -20,7 +20,7 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt as TokioStreamExt; use tracing::{debug, error, trace, warn}; -use restate_core::{cancellation_watcher, task_center, ShutdownError, TaskKind}; +use restate_core::{cancellation_watcher, ShutdownError, TaskCenter, TaskKind}; use restate_rocksdb::{IoMode, Priority, RocksDb}; use restate_types::config::LocalLogletOptions; use restate_types::live::BoxedLiveLoad; @@ -85,10 +85,9 @@ impl LogStoreWriter { // the backlog while we process this one. let (sender, receiver) = mpsc::channel(batch_size * 2); - task_center().spawn_child( + TaskCenter::spawn_child( TaskKind::LogletProvider, "local-loglet-writer", - None, async move { debug!("Start running LogStoreWriter"); let opts = updateable.live_load(); diff --git a/crates/bifrost/src/providers/replicated_loglet/loglet.rs b/crates/bifrost/src/providers/replicated_loglet/loglet.rs index 956483202..552470693 100644 --- a/crates/bifrost/src/providers/replicated_loglet/loglet.rs +++ b/crates/bifrost/src/providers/replicated_loglet/loglet.rs @@ -317,7 +317,6 @@ impl Loglet for ReplicatedLoglet { async fn seal(&self) -> Result<(), OperationError> { let _ = SealTask::new( - task_center(), self.my_params.clone(), self.logservers_rpc.seal.clone(), self.known_global_tail.clone(), @@ -388,7 +387,6 @@ mod tests { let log_server = LogServerService::create( HealthStatus::default(), config.clone(), - node_env.tc.clone(), node_env.metadata.clone(), node_env.metadata_store_client.clone(), record_cache.clone(), diff --git a/crates/bifrost/src/providers/replicated_loglet/network.rs b/crates/bifrost/src/providers/replicated_loglet/network.rs index c99174f3b..259c9f1c0 100644 --- a/crates/bifrost/src/providers/replicated_loglet/network.rs +++ b/crates/bifrost/src/providers/replicated_loglet/network.rs @@ -15,7 +15,6 @@ use std::sync::Arc; use std::time::Duration; use futures::StreamExt; -use restate_types::errors::MaybeRetryableError; use tracing::{instrument, trace}; use restate_core::network::{ @@ -23,9 +22,10 @@ use restate_core::network::{ TransportConnect, }; use restate_core::{ - cancellation_watcher, task_center, Metadata, MetadataKind, SyncError, TargetVersion, TaskKind, + cancellation_watcher, Metadata, MetadataKind, SyncError, TargetVersion, TaskCenter, TaskKind, }; use restate_types::config::ReplicatedLogletOptions; +use restate_types::errors::MaybeRetryableError; use restate_types::logs::{LogletOffset, SequenceNumber}; use restate_types::net::replicated_loglet::{ Append, Appended, 
CommonRequestHeader, CommonResponseHeader, GetSequencerState, SequencerState, @@ -49,15 +49,10 @@ macro_rules! return_error_status { }, }; - let _ = task_center().spawn_child( - TaskKind::Disposable, - "append-return-error", - None, - async move { - $reciprocal.prepare(msg).send().await?; - Ok(()) - }, - ); + let _ = TaskCenter::spawn_child(TaskKind::Disposable, "append-return-error", async move { + $reciprocal.prepare(msg).send().await?; + Ok(()) + }); return; }}; @@ -71,15 +66,10 @@ macro_rules! return_error_status { }, }; - let _ = task_center().spawn_child( - TaskKind::Disposable, - "append-return-error", - None, - async move { - $reciprocal.prepare(msg).send().await?; - Ok(()) - }, - ); + let _ = TaskCenter::spawn_child(TaskKind::Disposable, "append-return-error", async move { + $reciprocal.prepare(msg).send().await?; + Ok(()) + }); return; }}; @@ -230,7 +220,7 @@ impl RequestPump { global_tail: global_tail.clone(), }; - let _ = task_center().spawn_child(TaskKind::Disposable, "wait-appended", None, task.run()); + let _ = TaskCenter::spawn_child(TaskKind::Disposable, "wait-appended", task.run()); } async fn get_loglet( diff --git a/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs b/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs index 9b271270e..d8f504763 100644 --- a/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs +++ b/crates/bifrost/src/providers/replicated_loglet/tasks/find_tail.rs @@ -350,7 +350,6 @@ impl FindTailTask { if nodeset_checker.any(NodeTailStatus::is_known_sealed) { // run seal task then retry the find-tail check. let seal_task = SealTask::new( - self.task_center.clone(), self.my_params.clone(), self.logservers_rpc.seal.clone(), self.known_global_tail.clone(), diff --git a/crates/bifrost/src/providers/replicated_loglet/tasks/seal.rs b/crates/bifrost/src/providers/replicated_loglet/tasks/seal.rs index bf958a2c7..eafaca25e 100644 --- a/crates/bifrost/src/providers/replicated_loglet/tasks/seal.rs +++ b/crates/bifrost/src/providers/replicated_loglet/tasks/seal.rs @@ -35,7 +35,6 @@ use crate::providers::replicated_loglet::replication::NodeSetChecker; /// The seal operation is idempotent. It's safe to seal a loglet if it's already partially or fully /// sealed. Note that the seal task ignores the "seal" state in the input known_global_tail watch. pub struct SealTask { - task_center: TaskCenter, my_params: ReplicatedLogletParams, seal_router: RpcRouter, known_global_tail: TailOffsetWatch, @@ -43,13 +42,11 @@ pub struct SealTask { impl SealTask { pub fn new( - task_center: TaskCenter, my_params: ReplicatedLogletParams, seal_router: RpcRouter, known_global_tail: TailOffsetWatch, ) -> Self { Self { - task_center, my_params, seal_router, known_global_tail, @@ -89,20 +86,19 @@ impl SealTask { networking: networking.clone(), known_global_tail: self.known_global_tail.clone(), }; - self.task_center - .spawn_child(TaskKind::Disposable, "send-seal-request", None, { - let retry_policy = retry_policy.clone(); - let tx = tx.clone(); - async move { - if let Err(e) = task.run(tx, retry_policy).await { - // We only want to trace-log if an individual seal request fails. - // If we leave the task to fail, task-center will log a scary error-level log - // which can be misleading to users. 
- trace!("Seal: {e}"); - } - Ok(()) + TaskCenter::spawn_child(TaskKind::Disposable, "send-seal-request", { + let retry_policy = retry_policy.clone(); + let tx = tx.clone(); + async move { + if let Err(e) = task.run(tx, retry_policy).await { + // We only want to trace-log if an individual seal request fails. + // If we leave the task to fail, task-center will log a scary error-level log + // which can be misleading to users. + trace!("Seal: {e}"); } - })?; + Ok(()) + } + })?; } drop(tx); diff --git a/crates/core/src/metadata/manager.rs b/crates/core/src/metadata/manager.rs index 1f5dacd00..5629a2e61 100644 --- a/crates/core/src/metadata/manager.rs +++ b/crates/core/src/metadata/manager.rs @@ -39,7 +39,8 @@ use crate::network::Outgoing; use crate::network::Reciprocal; use crate::network::WeakConnection; use crate::network::{MessageHandler, MessageRouterBuilder, NetworkError}; -use crate::task_center; +use crate::TaskCenter; +use crate::TaskKind; pub(super) type CommandSender = mpsc::UnboundedSender; pub(super) type CommandReceiver = mpsc::UnboundedReceiver; @@ -162,17 +163,11 @@ impl MetadataMessageHandler { container: MetadataContainer::from(metadata), })); - let _ = task_center().spawn_child( - crate::TaskKind::Disposable, - "send-metadata-to-peer", - None, - { - async move { - outgoing.send().await?; - Ok(()) - } - }, - ); + let _ = + TaskCenter::spawn_child(TaskKind::Disposable, "send-metadata-to-peer", async move { + outgoing.send().await?; + Ok(()) + }); } } @@ -630,9 +625,8 @@ mod tests { // updates happening before metadata manager start should not get lost. metadata_writer.submit(Arc::new(value.clone())); - let tc = task_center(); // start metadata manager - spawn_metadata_manager(&tc, metadata_manager)?; + spawn_metadata_manager(metadata_manager)?; let version = metadata.wait_for_version(kind, Version::MIN).await.unwrap(); assert_eq!(Version::MIN, version); @@ -651,7 +645,7 @@ mod tests { let _ = metadata.wait_for_version(kind, Version::from(3)).await; - tc.cancel_tasks(None, None).await; + TaskCenter::current().cancel_tasks(None, None).await; Ok(()) }) } @@ -702,7 +696,7 @@ mod tests { assert_eq!(Version::MIN, value.version()); // start metadata manager - spawn_metadata_manager(&task_center(), metadata_manager)?; + spawn_metadata_manager(metadata_manager)?; let mut watcher1 = metadata.watch(kind); assert_eq!(Version::INVALID, *watcher1.borrow()); diff --git a/crates/core/src/metadata/mod.rs b/crates/core/src/metadata/mod.rs index 4735d292f..e2dd6809a 100644 --- a/crates/core/src/metadata/mod.rs +++ b/crates/core/src/metadata/mod.rs @@ -386,11 +386,8 @@ impl Default for VersionWatch { } } -pub fn spawn_metadata_manager( - tc: &TaskCenter, - metadata_manager: MetadataManager, -) -> Result { - tc.spawn( +pub fn spawn_metadata_manager(metadata_manager: MetadataManager) -> Result { + TaskCenter::current().spawn( TaskKind::MetadataBackgroundSync, "metadata-manager", None, diff --git a/crates/core/src/network/connection_manager.rs b/crates/core/src/network/connection_manager.rs index edb43cffd..dd9ce033d 100644 --- a/crates/core/src/network/connection_manager.rs +++ b/crates/core/src/network/connection_manager.rs @@ -41,8 +41,8 @@ use super::{Handler, MessageRouter}; use crate::metadata::Urgency; use crate::network::handshake::{negotiate_protocol_version, wait_for_hello}; use crate::network::{Incoming, PeerMetadataVersion}; -use crate::Metadata; -use crate::{cancellation_watcher, current_task_id, task_center, TaskId, TaskKind}; +use crate::{cancellation_watcher, current_task_id, TaskId, 
TaskKind}; +use crate::{Metadata, TaskCenter}; struct ConnectionManagerInner { router: MessageRouter, @@ -451,10 +451,9 @@ impl ConnectionManager { ); let router = guard.router.clone(); - let task_id = task_center().spawn_child( + let task_id = TaskCenter::spawn_child( TaskKind::ConnectionReactor, "network-connection-reactor", - None, run_reactor( self.inner.clone(), connection.clone(), diff --git a/crates/core/src/network/net_util.rs b/crates/core/src/network/net_util.rs index 58290a769..0f184411f 100644 --- a/crates/core/src/network/net_util.rs +++ b/crates/core/src/network/net_util.rs @@ -14,7 +14,6 @@ use std::net::SocketAddr; use std::path::PathBuf; use std::time::Duration; -use crate::{cancellation_watcher, task_center, ShutdownError, TaskCenter, TaskKind}; use http::Uri; use hyper::body::{Body, Incoming}; use hyper::rt::{Read, Write}; @@ -29,6 +28,8 @@ use restate_types::config::{MetadataStoreClientOptions, NetworkingOptions}; use restate_types::errors::GenericError; use restate_types::net::{AdvertisedAddress, BindAddress}; +use crate::{cancellation_watcher, ShutdownError, TaskCenter, TaskKind}; + pub fn create_tonic_channel_from_advertised_address( address: AdvertisedAddress, options: &T, @@ -161,8 +162,6 @@ where B::Error: Into>, { let mut shutdown = std::pin::pin!(cancellation_watcher()); - let tc = task_center(); - let executor = TaskCenterExecutor::new(tc.clone(), server_name); loop { tokio::select! { biased; @@ -174,10 +173,10 @@ where let io = TokioIo::new(stream); debug!(?remote_addr, "Accepting incoming connection"); - tc.spawn_child(TaskKind::RpcConnection, server_name, None, handle_connection( + TaskCenter::spawn_child(TaskKind::RpcConnection, server_name, handle_connection( + server_name, io, service.clone(), - executor.clone(), remote_addr, ))?; } @@ -188,9 +187,9 @@ where } async fn handle_connection( + server_name: &'static str, io: I, service: S, - executor: TaskCenterExecutor, remote_addr: A, ) -> anyhow::Result<()> where @@ -206,6 +205,8 @@ where I: Read + Write + Unpin + 'static, A: Send + Debug, { + // todo: asoli + let executor = TaskCenterExecutor::new(TaskCenter::current(), server_name); let builder = hyper_util::server::conn::auto::Builder::new(executor); let connection = builder.serve_connection(io, service); @@ -246,13 +247,13 @@ where { fn execute(&self, fut: F) { // ignore shutdown error - let _ = - self.task_center - .spawn_child(TaskKind::RpcConnection, self.name, None, async move { - // ignore the future output - let _ = fut.await; - Ok(()) - }); + self.task_center.run_in_scope_sync(|| { + let _ = TaskCenter::spawn_child(TaskKind::RpcConnection, self.name, async move { + // ignore the future output + let _ = fut.await; + Ok(()) + }); + }); } } diff --git a/crates/core/src/task_center/mod.rs b/crates/core/src/task_center/mod.rs index 4c07d1446..1c48897d0 100644 --- a/crates/core/src/task_center/mod.rs +++ b/crates/core/src/task_center/mod.rs @@ -282,7 +282,13 @@ impl TaskCenter { /// Attempt to set the global metadata handle. This should be called once /// at the startup of the node. - pub fn try_set_global_metadata(&self, metadata: Metadata) -> bool { + pub fn try_set_global_metadata(metadata: Metadata) -> bool { + Self::with_current(|tc| tc.try_set_global_metadata_inner(metadata)) + } + + /// Attempt to set the global metadata handle. This should be called once + /// at the startup of the node. 
+ pub(crate) fn try_set_global_metadata_inner(&self, metadata: Metadata) -> bool { self.inner.global_metadata.set(metadata).is_ok() } @@ -563,10 +569,23 @@ impl TaskCenter { /// finish before completion, but this might change in the future if the need for that arises. #[track_caller] pub fn spawn_child( + kind: TaskKind, + name: &'static str, + future: F, + ) -> Result + where + F: Future> + Send + 'static, + { + Self::with_current(|tc| tc.spawn_child_inner(kind, name, future)) + } + + /// Spawn a new task that is a child of the current task. The child task will be cancelled if the parent + /// task is cancelled. At the moment, the parent task will not automatically wait for children tasks to + /// finish before completion, but this might change in the future if the need for that arises. + fn spawn_child_inner( &self, kind: TaskKind, name: &'static str, - partition_id: Option, future: F, ) -> Result where @@ -576,16 +595,18 @@ impl TaskCenter { return Err(ShutdownError); } - let (parent_id, parent_name, parent_kind, cancel) = self.with_task_context(|ctx| { - ( - ctx.id, - ctx.name, - ctx.kind, - ctx.cancellation_token.child_token(), - ) - }); - - let result = self.spawn_inner(kind, name, partition_id, cancel, future); + let (parent_id, parent_name, parent_kind, parent_partition, cancel) = self + .with_task_context(|ctx| { + ( + ctx.id, + ctx.name, + ctx.kind, + ctx.partition_id, + ctx.cancellation_token.child_token(), + ) + }); + + let result = self.spawn_inner(kind, name, parent_partition, cancel, future); trace!( kind = ?parent_kind, diff --git a/crates/core/src/task_center/task_kind.rs b/crates/core/src/task_center/task_kind.rs index b18ab226a..a8ddeda9d 100644 --- a/crates/core/src/task_center/task_kind.rs +++ b/crates/core/src/task_center/task_kind.rs @@ -95,6 +95,8 @@ pub enum TaskKind { SystemService, #[strum(props(OnCancel = "abort", runtime = "ingress"))] Ingress, + /// Kafka ingestion related task + Kafka, PartitionProcessor, /// Longer-running, low-priority tasks that is responsible for the export, and potentially /// upload to remote storage, of partition store snapshots. 
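To make the new calling convention concrete, here is a minimal sketch (illustration only, not part of the patch) of how call sites look after this stage: TaskCenter::spawn_child now takes only a kind, a name and a future, with the partition id inherited from the parent task's context, and a future can be pinned to a specific task center through the in_tc extension introduced in extensions.rs. The names do_work, sketch and "illustrative-child" below are placeholders, not code from this repository.

use restate_core::{TaskCenter, TaskCenterFutureExt, TaskKind};

// Placeholder for any async unit of work (illustrative only).
async fn do_work() -> anyhow::Result<()> {
    Ok(())
}

// A minimal sketch, assuming a task center handle built elsewhere
// (for example via TaskCenterBuilder).
async fn sketch(tc: &TaskCenter) -> anyhow::Result<()> {
    async {
        // Child tasks are spawned through the ambient task-center context:
        // kind + name + future only; the partition id is inherited from the
        // parent task's context rather than passed explicitly.
        let _task_id =
            TaskCenter::spawn_child(TaskKind::Disposable, "illustrative-child", async move {
                do_work().await
            })?;

        // An owned handle is still reachable where one is needed explicitly.
        TaskCenter::current().cancel_tasks(None, None).await;
        anyhow::Ok(())
    }
    // TaskCenterFutureExt::in_tc runs the wrapped future inside `tc`'s scope.
    .in_tc(tc)
    .await
}
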
diff --git a/crates/core/src/test_env.rs b/crates/core/src/test_env.rs index 2faccb5b8..5ebdd702b 100644 --- a/crates/core/src/test_env.rs +++ b/crates/core/src/test_env.rs @@ -33,7 +33,7 @@ use crate::network::{ ConnectionManager, FailingConnector, Incoming, MessageHandler, MessageRouterBuilder, NetworkError, Networking, ProtocolError, TransportConnect, }; -use crate::{spawn_metadata_manager, MetadataBuilder, TaskId}; +use crate::{spawn_metadata_manager, MetadataBuilder, TaskCenterFutureExt, TaskId}; use crate::{Metadata, MetadataManager, MetadataWriter}; use crate::{TaskCenter, TaskCenterBuilder}; @@ -103,7 +103,7 @@ impl TestCoreEnvBuilder { let partition_table = PartitionTable::with_equally_sized_partitions(Version::MIN, 10); let scheduling_plan = SchedulingPlan::from(&partition_table, ReplicationStrategy::OnAllNodes); - tc.try_set_global_metadata(metadata.clone()); + tc.try_set_global_metadata_inner(metadata.clone()); // Use memory-loglet as a default if in test-mode #[cfg(any(test, feature = "test-util"))] @@ -167,78 +167,79 @@ impl TestCoreEnvBuilder { } pub async fn build(mut self) -> TestCoreEnv { - self.metadata_manager - .register_in_message_router(&mut self.router_builder); - self.networking - .connection_manager() - .set_message_router(self.router_builder.build()); - - let metadata_manager_task = spawn_metadata_manager(&self.tc, self.metadata_manager) - .expect("metadata manager should start"); - - self.metadata_store_client - .put( - NODES_CONFIG_KEY.clone(), - &self.nodes_config, - Precondition::None, - ) - .await - .expect("to store nodes config in metadata store"); - self.metadata_writer - .submit(Arc::new(self.nodes_config.clone())); - - let logs = bootstrap_logs_metadata( - self.provider_kind, - None, - self.partition_table.num_partitions(), - ); - self.metadata_store_client - .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) - .await - .expect("to store bifrost config in metadata store"); - self.metadata_writer.submit(Arc::new(logs)); - - self.metadata_store_client - .put( - PARTITION_TABLE_KEY.clone(), - &self.partition_table, - Precondition::None, - ) - .await - .expect("to store partition table in metadata store"); - self.metadata_writer.submit(Arc::new(self.partition_table)); - - self.metadata_store_client - .put( - SCHEDULING_PLAN_KEY.clone(), - &self.scheduling_plan, - Precondition::None, - ) - .await - .expect("sot store scheduling plan in metadata store"); - - self.tc - .run_in_scope("test-env", None, async { - let _ = self - .metadata - .wait_for_version( - MetadataKind::NodesConfiguration, - self.nodes_config.version(), - ) - .await - .unwrap(); - }) - .await; - self.metadata_writer.set_my_node_id(self.my_node_id); - - TestCoreEnv { - tc: self.tc, - metadata: self.metadata, - metadata_manager_task, - metadata_writer: self.metadata_writer, - networking: self.networking, - metadata_store_client: self.metadata_store_client, + let tc = self.tc; + async { + self.metadata_manager + .register_in_message_router(&mut self.router_builder); + self.networking + .connection_manager() + .set_message_router(self.router_builder.build()); + + let metadata_manager_task = spawn_metadata_manager(self.metadata_manager) + .expect("metadata manager should start"); + + self.metadata_store_client + .put( + NODES_CONFIG_KEY.clone(), + &self.nodes_config, + Precondition::None, + ) + .await + .expect("to store nodes config in metadata store"); + self.metadata_writer + .submit(Arc::new(self.nodes_config.clone())); + + let logs = bootstrap_logs_metadata( + self.provider_kind, + 
None, + self.partition_table.num_partitions(), + ); + self.metadata_store_client + .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) + .await + .expect("to store bifrost config in metadata store"); + self.metadata_writer.submit(Arc::new(logs)); + + self.metadata_store_client + .put( + PARTITION_TABLE_KEY.clone(), + &self.partition_table, + Precondition::None, + ) + .await + .expect("to store partition table in metadata store"); + self.metadata_writer.submit(Arc::new(self.partition_table)); + + self.metadata_store_client + .put( + SCHEDULING_PLAN_KEY.clone(), + &self.scheduling_plan, + Precondition::None, + ) + .await + .expect("sot store scheduling plan in metadata store"); + + let _ = self + .metadata + .wait_for_version( + MetadataKind::NodesConfiguration, + self.nodes_config.version(), + ) + .await + .unwrap(); + self.metadata_writer.set_my_node_id(self.my_node_id); + + TestCoreEnv { + tc: TaskCenter::current(), + metadata: self.metadata, + metadata_manager_task, + metadata_writer: self.metadata_writer, + networking: self.networking, + metadata_store_client: self.metadata_store_client, + } } + .in_tc(&tc) + .await } } diff --git a/crates/ingress-kafka/src/consumer_task.rs b/crates/ingress-kafka/src/consumer_task.rs index e6124e481..71619b3fb 100644 --- a/crates/ingress-kafka/src/consumer_task.rs +++ b/crates/ingress-kafka/src/consumer_task.rs @@ -287,7 +287,7 @@ impl ConsumerTask { consumer_group_id.clone() ); - if let Ok(task_id) = self.task_center.spawn_child(TaskKind::Ingress, "partition-queue", None, task) { + if let Ok(task_id) = TaskCenter::spawn_child(TaskKind::Ingress, "partition-queue", task) { e.insert(task_id); } else { break Ok(()); diff --git a/crates/ingress-kafka/src/subscription_controller.rs b/crates/ingress-kafka/src/subscription_controller.rs index ec95f1df1..42619a7c4 100644 --- a/crates/ingress-kafka/src/subscription_controller.rs +++ b/crates/ingress-kafka/src/subscription_controller.rs @@ -191,7 +191,7 @@ impl Service { mod task_orchestrator { use crate::consumer_task; - use restate_core::task_center; + use restate_core::{TaskCenterFutureExt, TaskKind}; use restate_timer_queue::TimerQueue; use restate_types::identifiers::SubscriptionId; use restate_types::retries::{RetryIter, RetryPolicy}; @@ -344,12 +344,10 @@ mod task_orchestrator { let task_id = self .tasks .spawn({ - let tc = task_center(); let consumer_task_clone = consumer_task_clone.clone(); - async move { - tc.run_in_scope("kafka-consumer-task", None, consumer_task_clone.run(rx)) - .await - } + consumer_task_clone + .run(rx) + .in_current_tc_as_task(TaskKind::Kafka, "kafka-consumer-task") }) .id(); diff --git a/crates/log-server/src/loglet_worker.rs b/crates/log-server/src/loglet_worker.rs index 64106cf4a..c83c86b85 100644 --- a/crates/log-server/src/loglet_worker.rs +++ b/crates/log-server/src/loglet_worker.rs @@ -93,7 +93,6 @@ impl LogletWorkerHandle { } pub struct LogletWorker { - task_center: TaskCenter, loglet_id: ReplicatedLogletId, log_store: S, loglet_state: LogletState, @@ -101,13 +100,11 @@ pub struct LogletWorker { impl LogletWorker { pub fn start( - task_center: TaskCenter, loglet_id: ReplicatedLogletId, log_store: S, loglet_state: LogletState, ) -> Result { let writer = Self { - task_center: task_center.clone(), loglet_id, log_store, loglet_state, @@ -121,7 +118,8 @@ impl LogletWorker { let (trim_tx, trim_rx) = mpsc::unbounded_channel(); let (wait_for_tail_tx, wait_for_tail_rx) = mpsc::unbounded_channel(); let (get_digest_tx, get_digest_rx) = mpsc::unbounded_channel(); - let 
tc_handle = task_center.spawn_unmanaged( + // todo + let tc_handle = TaskCenter::current().spawn_unmanaged( TaskKind::LogletWriter, "loglet-worker", None, @@ -423,7 +421,7 @@ impl LogletWorker { fn process_wait_for_tail(&mut self, msg: Incoming) { let loglet_state = self.loglet_state.clone(); // fails on shutdown, in this case, we ignore the request - let _ = self.task_center.spawn( + let _ = TaskCenter::current().spawn( TaskKind::Disposable, "logserver-tail-monitor", None, @@ -465,7 +463,7 @@ impl LogletWorker { let log_store = self.log_store.clone(); let loglet_state = self.loglet_state.clone(); // fails on shutdown, in this case, we ignore the request - let _ = self.task_center.spawn( + let _ = TaskCenter::current().spawn( TaskKind::Disposable, "logserver-get-records", None, @@ -500,7 +498,7 @@ impl LogletWorker { let log_store = self.log_store.clone(); let loglet_state = self.loglet_state.clone(); // fails on shutdown, in this case, we ignore the request - let _ = self.task_center.spawn( + let _ = TaskCenter::current().spawn( TaskKind::Disposable, "logserver-get-digest", None, @@ -538,8 +536,8 @@ impl LogletWorker { // fails on shutdown, in this case, we ignore the request let mut loglet_state = self.loglet_state.clone(); let log_store = self.log_store.clone(); - let _ = self - .task_center + let _ = + TaskCenter::current() .spawn(TaskKind::Disposable, "logserver-trim", None, async move { let loglet_id = msg.body().header.loglet_id; let new_trim_point = msg.body().trim_point; @@ -616,7 +614,7 @@ mod tests { use test_log::test; use restate_core::network::OwnedConnection; - use restate_core::{MetadataBuilder, TaskCenter, TaskCenterBuilder}; + use restate_core::{MetadataBuilder, TaskCenter, TaskCenterBuilder, TaskCenterFutureExt}; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -636,872 +634,884 @@ mod tests { let tc = TaskCenterBuilder::default_for_tests().build()?; let config = Live::from_value(Configuration::default()); let common_rocks_opts = config.clone().map(|c| &c.common); - let log_store = tc - .run_in_scope("test-setup", None, async { - RocksDbManager::init(common_rocks_opts); - let metadata_builder = MetadataBuilder::default(); - assert!(tc.try_set_global_metadata(metadata_builder.to_metadata())); - // create logstore. - let builder = RocksDbLogStoreBuilder::create( - config.clone().map(|c| &c.log_server).boxed(), - config.map(|c| &c.log_server.rocksdb).boxed(), - RecordCache::new(1_000_000), - ) - .await?; - let log_store = builder.start(&tc, Default::default()).await?; - Result::Ok(log_store) - }) + let log_store = async { + RocksDbManager::init(common_rocks_opts); + let metadata_builder = MetadataBuilder::default(); + assert!(TaskCenter::try_set_global_metadata( + metadata_builder.to_metadata() + )); + // create logstore. 
+ let builder = RocksDbLogStoreBuilder::create( + config.clone().map(|c| &c.log_server).boxed(), + config.map(|c| &c.log_server.rocksdb).boxed(), + RecordCache::new(1_000_000), + ) .await?; + let log_store = builder.start(Default::default()).await?; + Result::Ok(log_store) + } + .in_tc(&tc) + .await?; Ok((tc, log_store)) } #[test(tokio::test(start_paused = true))] async fn test_simple_store_flow() -> Result<()> { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let (tc, log_store) = setup().await?; - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(tc.clone(), LOGLET, log_store, loglet_state)?; - - let payloads: Arc<[Record]> = vec![ - Record::from("a sample record"), - Record::from("another record"), - ] - .into(); - - // offsets 1, 2 - let msg1 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::OLDEST, - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; + async { + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + let payloads: Arc<[Record]> = vec![ + Record::from("a sample record"), + Record::from("another record"), + ] + .into(); + + // offsets 1, 2 + let msg1 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::OLDEST, + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + + // offsets 3, 4 + let msg2 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(3), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); + let msg1_id = msg1.msg_id(); + let msg2_id = msg2.msg_id(); + + // pipelined writes + worker.enqueue_store(msg1).unwrap(); + worker.enqueue_store(msg2).unwrap(); + // wait for response (in test-env, it's safe to assume that responses will arrive in order) + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg1_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - // offsets 3, 4 - let msg2 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - 
first_offset: LogletOffset::new(3), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; + // response 2 + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg2_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(5))); + + tc.shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; - let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); - let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); - let msg1_id = msg1.msg_id(); - let msg2_id = msg2.msg_id(); - - // pipelined writes - worker.enqueue_store(msg1).unwrap(); - worker.enqueue_store(msg2).unwrap(); - // wait for response (in test-env, it's safe to assume that responses will arrive in order) - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg1_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - - // response 2 - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg2_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(5))); - - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - - Ok(()) + Ok(()) + } + .in_tc(&tc) + .await } #[test(tokio::test(start_paused = true))] async fn test_store_and_seal() -> Result<()> { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let (tc, log_store) = setup().await?; - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(tc.clone(), LOGLET, log_store, loglet_state)?; - - let payloads: Arc<[Record]> = vec![ - Record::from("a sample record"), - Record::from("another record"), - ] - .into(); - - // offsets 1, 2 - let msg1 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::OLDEST, - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - - let seal1 = Seal { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - sequencer: SEQUENCER, - }; - - let seal2 = Seal { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - sequencer: SEQUENCER, - }; - - // offsets 3, 4 - let msg2 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(3), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - - let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); - let seal1 = Incoming::for_testing(connection.downgrade(), seal1, 
None); - let seal2 = Incoming::for_testing(connection.downgrade(), seal2, None); - let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); - let msg1_id = msg1.msg_id(); - let seal1_id = seal1.msg_id(); - let seal2_id = seal2.msg_id(); - let msg2_id = msg2.msg_id(); - - worker.enqueue_store(msg1).unwrap(); - // first store is successful - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg1_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - worker.enqueue_seal(seal1).unwrap(); - // should latch onto existing seal - worker.enqueue_seal(seal2).unwrap(); - // seal takes precedence, but it gets processed in the background. This store is likely to - // observe Status::Sealing - worker.enqueue_store(msg2).unwrap(); - // sealing - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg2_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Sealing)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - // seal responses can come at any order, but we'll consume waiters queue before we process - // store messages. - // sealed - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); - let sealed: Sealed = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(sealed.status, eq(Status::Ok)); - assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); - - // sealed2 - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); - let sealed: Sealed = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(sealed.status, eq(Status::Ok)); - assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); - - // try another store - let msg3 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(3)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(3), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - let msg3 = Incoming::for_testing(connection.downgrade(), msg3, None); - let msg3_id = msg3.msg_id(); - worker.enqueue_store(msg3).unwrap(); - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg3_id)); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Sealed)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - - // GetLogletInfo - // offsets 3, 4 - let msg = GetLogletInfo { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - }; - let msg = Incoming::for_testing(connection.downgrade(), msg, None); - let msg_id = msg.msg_id(); - worker.enqueue_get_loglet_info(msg).unwrap(); - - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg_id)); - let info: LogletInfo = response - .body - .unwrap() - 
.try_decode(connection.protocol_version())?; - assert_that!(info.status, eq(Status::Ok)); - assert_that!(info.local_tail, eq(LogletOffset::new(3))); - assert_that!(info.trim_point, eq(LogletOffset::INVALID)); - assert_that!(info.sealed, eq(true)); - - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - - Ok(()) + async { + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + let payloads: Arc<[Record]> = vec![ + Record::from("a sample record"), + Record::from("another record"), + ] + .into(); + + // offsets 1, 2 + let msg1 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::OLDEST, + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + + let seal1 = Seal { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + sequencer: SEQUENCER, + }; + + let seal2 = Seal { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + sequencer: SEQUENCER, + }; + + // offsets 3, 4 + let msg2 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(3), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); + let seal2 = Incoming::for_testing(connection.downgrade(), seal2, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); + let msg1_id = msg1.msg_id(); + let seal1_id = seal1.msg_id(); + let seal2_id = seal2.msg_id(); + let msg2_id = msg2.msg_id(); + + worker.enqueue_store(msg1).unwrap(); + // first store is successful + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg1_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + worker.enqueue_seal(seal1).unwrap(); + // should latch onto existing seal + worker.enqueue_seal(seal2).unwrap(); + // seal takes precedence, but it gets processed in the background. This store is likely to + // observe Status::Sealing + worker.enqueue_store(msg2).unwrap(); + // sealing + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg2_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Sealing)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + // seal responses can come at any order, but we'll consume waiters queue before we process + // store messages. 
+ // sealed + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); + let sealed: Sealed = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(sealed.status, eq(Status::Ok)); + assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); + + // sealed2 + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), any!(eq(seal1_id), eq(seal2_id))); + let sealed: Sealed = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(sealed.status, eq(Status::Ok)); + assert_that!(sealed.local_tail, eq(LogletOffset::new(3))); + + // try another store + let msg3 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(3)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(3), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + let msg3 = Incoming::for_testing(connection.downgrade(), msg3, None); + let msg3_id = msg3.msg_id(); + worker.enqueue_store(msg3).unwrap(); + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg3_id)); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Sealed)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); + + // GetLogletInfo + // offsets 3, 4 + let msg = GetLogletInfo { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + }; + let msg = Incoming::for_testing(connection.downgrade(), msg, None); + let msg_id = msg.msg_id(); + worker.enqueue_get_loglet_info(msg).unwrap(); + + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg_id)); + let info: LogletInfo = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(info.status, eq(Status::Ok)); + assert_that!(info.local_tail, eq(LogletOffset::new(3))); + assert_that!(info.trim_point, eq(LogletOffset::INVALID)); + assert_that!(info.sealed, eq(true)); + + tc.shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + Ok(()) + } + .in_tc(&tc) + .await } #[test(tokio::test(start_paused = true))] async fn test_repair_store() -> Result<()> { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const PEER: GenerationalNodeId = GenerationalNodeId::new(2, 2); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let (tc, log_store) = setup().await?; - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let (peer_net_tx, mut peer_net_rx) = mpsc::channel(10); - let repair_connection = - OwnedConnection::new_fake(PEER, CURRENT_PROTOCOL_VERSION, peer_net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(tc.clone(), LOGLET, log_store, loglet_state)?; - - let payloads: Arc<[Record]> = vec![ - Record::from("a sample record"), - Record::from("another record"), - ] - .into(); - - // offsets 1, 2 - let msg1 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - timeout_at: None, - sequencer: 
SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::OLDEST, - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; - - // offsets 10, 11 - let msg2 = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(10), - flags: StoreFlags::empty(), - payloads: payloads.clone(), - }; + async { + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const PEER: GenerationalNodeId = GenerationalNodeId::new(2, 2); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let (peer_net_tx, mut peer_net_rx) = mpsc::channel(10); + let repair_connection = + OwnedConnection::new_fake(PEER, CURRENT_PROTOCOL_VERSION, peer_net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + let payloads: Arc<[Record]> = vec![ + Record::from("a sample record"), + Record::from("another record"), + ] + .into(); + + // offsets 1, 2 + let msg1 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::OLDEST, + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + + // offsets 10, 11 + let msg2 = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(10), + flags: StoreFlags::empty(), + payloads: payloads.clone(), + }; + + let seal1 = Seal { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + sequencer: SEQUENCER, + }; + + // 5, 6 + let repair_message_before_local_tail = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(5), + flags: StoreFlags::IgnoreSeal, + payloads: payloads.clone(), + }; + + // 16, 17 + let repair_message_after_local_tail = Store { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(16)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(16), + flags: StoreFlags::IgnoreSeal, + payloads: payloads.clone(), + }; + + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); + let repair1 = Incoming::for_testing( + repair_connection.downgrade(), + repair_message_before_local_tail, + None, + ); + let repair2 = Incoming::for_testing( + repair_connection.downgrade(), + repair_message_after_local_tail, + None, + ); + let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); - let seal1 = Seal { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - sequencer: SEQUENCER, - }; + worker.enqueue_store(msg1).unwrap(); + worker.enqueue_store(msg2).unwrap(); + // first store is successful + let response = net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, 
eq(Status::Ok)); + assert_that!(stored.sealed, eq(false)); + assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - // 5, 6 - let repair_message_before_local_tail = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(5), - flags: StoreFlags::IgnoreSeal, - payloads: payloads.clone(), - }; + // 10, 11 + let response = net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.sealed, eq(false)); + assert_that!(stored.local_tail, eq(LogletOffset::new(12))); + + worker.enqueue_seal(seal1).unwrap(); + // seal responses can come at any order, but we'll consume waiters queue before we process + // store messages. + // sealed + let response = net_rx.recv().await.unwrap(); + let sealed: Sealed = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(sealed.status, eq(Status::Ok)); + assert_that!(sealed.local_tail, eq(LogletOffset::new(12))); - // 16, 17 - let repair_message_after_local_tail = Store { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(16)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(16), - flags: StoreFlags::IgnoreSeal, - payloads: payloads.clone(), - }; + // repair store (before local tail, local tail won't move) + worker.enqueue_store(repair1).unwrap(); + let response = peer_net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(12))); - let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); - let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); - let repair1 = Incoming::for_testing( - repair_connection.downgrade(), - repair_message_before_local_tail, - None, - ); - let repair2 = Incoming::for_testing( - repair_connection.downgrade(), - repair_message_after_local_tail, - None, - ); - let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); - - worker.enqueue_store(msg1).unwrap(); - worker.enqueue_store(msg2).unwrap(); - // first store is successful - let response = net_rx.recv().await.unwrap(); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.sealed, eq(false)); - assert_that!(stored.local_tail, eq(LogletOffset::new(3))); - - // 10, 11 - let response = net_rx.recv().await.unwrap(); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.sealed, eq(false)); - assert_that!(stored.local_tail, eq(LogletOffset::new(12))); - - worker.enqueue_seal(seal1).unwrap(); - // seal responses can come at any order, but we'll consume waiters queue before we process - // store messages. 
- // sealed - let response = net_rx.recv().await.unwrap(); - let sealed: Sealed = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(sealed.status, eq(Status::Ok)); - assert_that!(sealed.local_tail, eq(LogletOffset::new(12))); - - // repair store (before local tail, local tail won't move) - worker.enqueue_store(repair1).unwrap(); - let response = peer_net_rx.recv().await.unwrap(); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(12))); - - worker.enqueue_store(repair2).unwrap(); - let response = peer_net_rx.recv().await.unwrap(); - let stored: Stored = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(18))); - - // GetLogletInfo - // offsets 3, 4 - let msg = GetLogletInfo { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - }; - let msg = Incoming::for_testing(connection.downgrade(), msg, None); - let msg_id = msg.msg_id(); - worker.enqueue_get_loglet_info(msg).unwrap(); - - let response = net_rx.recv().await.unwrap(); - let header = response.header.unwrap(); - assert_that!(header.in_response_to(), eq(msg_id)); - let info: LogletInfo = response - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(info.status, eq(Status::Ok)); - assert_that!(info.local_tail, eq(LogletOffset::new(18))); - assert_that!(info.trim_point, eq(LogletOffset::INVALID)); - assert_that!(info.sealed, eq(true)); - - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - - Ok(()) + worker.enqueue_store(repair2).unwrap(); + let response = peer_net_rx.recv().await.unwrap(); + let stored: Stored = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(18))); + + // GetLogletInfo + // offsets 3, 4 + let msg = GetLogletInfo { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + }; + let msg = Incoming::for_testing(connection.downgrade(), msg, None); + let msg_id = msg.msg_id(); + worker.enqueue_get_loglet_info(msg).unwrap(); + + let response = net_rx.recv().await.unwrap(); + let header = response.header.unwrap(); + assert_that!(header.in_response_to(), eq(msg_id)); + let info: LogletInfo = response + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(info.status, eq(Status::Ok)); + assert_that!(info.local_tail, eq(LogletOffset::new(18))); + assert_that!(info.trim_point, eq(LogletOffset::INVALID)); + assert_that!(info.sealed, eq(true)); + tc.shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + Ok(()) + } + .in_tc(&tc) + .await } #[test(tokio::test(start_paused = true))] async fn test_simple_get_records_flow() -> Result<()> { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let (tc, log_store) = setup().await?; - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = LogletWorker::start(tc.clone(), LOGLET, 
log_store, loglet_state)?; - - // Populate the log-store with some records (..,2,..,5,..,10, 11) - // Note: dots mean we don't have records at those globally committed offsets. - worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=1 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(2)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(2), - flags: StoreFlags::empty(), - payloads: vec![Record::from("record2")].into(), - }, - None, - )) - .unwrap(); - - worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=4 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(5)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(5), - flags: StoreFlags::empty(), - payloads: vec![Record::from(("record5", Keys::Single(11)))].into(), - }, - None, - )) - .unwrap(); - - worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=9 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(10), - flags: StoreFlags::empty(), - payloads: vec![Record::from("record10"), Record::from("record11")].into(), - }, - None, - )) - .unwrap(); + async { + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store, loglet_state)?; + + // Populate the log-store with some records (..,2,..,5,..,10, 11) + // Note: dots mean we don't have records at those globally committed offsets. + worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=1 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(2)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(2), + flags: StoreFlags::empty(), + payloads: vec![Record::from("record2")].into(), + }, + None, + )) + .unwrap(); + + worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=4 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(5)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(5), + flags: StoreFlags::empty(), + payloads: vec![Record::from(("record5", Keys::Single(11)))].into(), + }, + None, + )) + .unwrap(); + + worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=9 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(10), + flags: StoreFlags::empty(), + payloads: vec![Record::from("record10"), Record::from("record11")].into(), + }, + None, + )) + .unwrap(); + + // Wait for stores to complete. 
+ for _ in 0..3 { + let stored: Stored = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + } - // Wait for stores to complete. - for _ in 0..3 { - let stored: Stored = net_rx + // We expect to see [2, 5]. No trim gaps, no filtered gaps. + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + // faking that offset=9 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + filter: KeyFilter::Any, + // no memory limits + total_limit_in_bytes: None, + from_offset: LogletOffset::new(1), + to_offset: LogletOffset::new(7), + }, + None, + )) + .unwrap(); + + let mut records: Records = net_rx .recv() .await .unwrap() .body .unwrap() .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - } - - // We expect to see [2, 5]. No trim gaps, no filtered gaps. - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - // faking that offset=9 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - filter: KeyFilter::Any, - // no memory limits - total_limit_in_bytes: None, - from_offset: LogletOffset::new(1), - to_offset: LogletOffset::new(7), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(12))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.next_offset, eq(LogletOffset::new(8))); - assert_that!(records.records.len(), eq(2)); - // pop in reverse order - for i in [5, 2] { - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); - assert_that!(record.is_data(), eq(true)); - let data = record.try_unwrap_data().unwrap(); - let original: String = data.decode().unwrap(); - assert_that!(original, eq(format!("record{}", i))); - } - - // We expect to see [2, FILTERED(5), 10, 11]. No trim gaps. - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - // INVALID can be used when we don't have a reasonable value to pass in. 
- header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - // no memory limits - total_limit_in_bytes: None, - filter: KeyFilter::Within(0..=5), - from_offset: LogletOffset::new(1), - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(12))); - assert_that!(records.next_offset, eq(LogletOffset::new(12))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(4)); - // pop() returns records in reverse order - for i in [11, 10, 5, 2] { - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); - if i == 5 { - // this one is filtered - assert_that!(record.is_filtered_gap(), eq(true)); - let gap = record.try_unwrap_filtered_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(5))); - } else { + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(12))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.next_offset, eq(LogletOffset::new(8))); + assert_that!(records.records.len(), eq(2)); + // pop in reverse order + for i in [5, 2] { + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(i))); assert_that!(record.is_data(), eq(true)); let data = record.try_unwrap_data().unwrap(); let original: String = data.decode().unwrap(); assert_that!(original, eq(format!("record{}", i))); } - } - // Apply memory limits (2 bytes) should always see the first real record. - // We expect to see [FILTERED(5), 10]. (11 is not returend due to budget) - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - // INVALID can be used when we don't have a reasonable value to pass in. - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - // no memory limits - total_limit_in_bytes: Some(2), - filter: KeyFilter::Within(0..=5), - from_offset: LogletOffset::new(4), - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); + // We expect to see [2, FILTERED(5), 10, 11]. No trim gaps. + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + // INVALID can be used when we don't have a reasonable value to pass in. 
+ header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + // no memory limits + total_limit_in_bytes: None, + filter: KeyFilter::Within(0..=5), + from_offset: LogletOffset::new(1), + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); + + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(12))); + assert_that!(records.next_offset, eq(LogletOffset::new(12))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(4)); + // pop() returns records in reverse order + for i in [11, 10, 5, 2] { + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(i))); + if i == 5 { + // this one is filtered + assert_that!(record.is_filtered_gap(), eq(true)); + let gap = record.try_unwrap_filtered_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(5))); + } else { + assert_that!(record.is_data(), eq(true)); + let data = record.try_unwrap_data().unwrap(); + let original: String = data.decode().unwrap(); + assert_that!(original, eq(format!("record{}", i))); + } + } - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(12))); - assert_that!(records.next_offset, eq(LogletOffset::new(11))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(2)); - // pop() returns records in reverse order - for i in [10, 5] { - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); - if i == 5 { - // this one is filtered - assert_that!(record.is_filtered_gap(), eq(true)); - let gap = record.try_unwrap_filtered_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(5))); - } else { - assert_that!(record.is_data(), eq(true)); - let data = record.try_unwrap_data().unwrap(); - let original: String = data.decode().unwrap(); - assert_that!(original, eq(format!("record{}", i))); + // Apply memory limits (2 bytes) should always see the first real record. + // We expect to see [FILTERED(5), 10]. (11 is not returend due to budget) + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + // INVALID can be used when we don't have a reasonable value to pass in. 
+ header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + // no memory limits + total_limit_in_bytes: Some(2), + filter: KeyFilter::Within(0..=5), + from_offset: LogletOffset::new(4), + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); + + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(12))); + assert_that!(records.next_offset, eq(LogletOffset::new(11))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(2)); + // pop() returns records in reverse order + for i in [10, 5] { + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(i))); + if i == 5 { + // this one is filtered + assert_that!(record.is_filtered_gap(), eq(true)); + let gap = record.try_unwrap_filtered_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(5))); + } else { + assert_that!(record.is_data(), eq(true)); + let data = record.try_unwrap_data().unwrap(); + let original: String = data.decode().unwrap(); + assert_that!(original, eq(format!("record{}", i))); + } } - } - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; + tc.shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; - Ok(()) + Ok(()) + } + .in_tc(&tc) + .await } #[test(tokio::test(start_paused = true))] async fn test_trim_basics() -> Result<()> { - const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); - const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); - let (tc, log_store) = setup().await?; - let loglet_state_map = LogletStateMap::default(); - let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); - - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - let worker = - LogletWorker::start(tc.clone(), LOGLET, log_store.clone(), loglet_state.clone())?; - - assert_that!(loglet_state.trim_point(), eq(LogletOffset::INVALID)); - assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::OLDEST)); - // The loglet has no knowledge of global commits, it shouldn't accept trims. - worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::OLDEST), - trim_point: LogletOffset::OLDEST, - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Malformed)); - assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); - assert_that!(trimmed.sealed, eq(false)); - - // The loglet has knowledge of global tail of 10, it should accept trims up to 9 but it - // won't move trim point beyond its local tail. 
- worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - trim_point: LogletOffset::new(9), - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Ok)); - assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); - assert_that!(trimmed.sealed, eq(false)); - - // let's store some records at offsets (5, 6) - worker - .enqueue_store(Incoming::for_testing( - connection.downgrade(), - Store { - // faking that offset=9 is released - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - timeout_at: None, - sequencer: SEQUENCER, - known_archived: LogletOffset::INVALID, - first_offset: LogletOffset::new(5), - flags: StoreFlags::empty(), - payloads: vec![Record::from("record5"), Record::from("record6")].into(), - }, - None, - )) - .unwrap(); - let stored: Stored = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(stored.status, eq(Status::Ok)); - assert_that!(stored.local_tail, eq(LogletOffset::new(7))); - - // trim to 5 - worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - trim_point: LogletOffset::new(5), - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Ok)); - assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); - assert_that!(trimmed.sealed, eq(false)); - - // Attempt to read. We expect to see a trim gap (1->5, 6 (data-record)) - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - total_limit_in_bytes: None, - filter: KeyFilter::Any, - from_offset: LogletOffset::OLDEST, - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); + async { + const SEQUENCER: GenerationalNodeId = GenerationalNodeId::new(1, 1); + const LOGLET: ReplicatedLogletId = ReplicatedLogletId::new_unchecked(1); + let loglet_state_map = LogletStateMap::default(); + let (net_tx, mut net_rx) = mpsc::channel(10); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + let worker = LogletWorker::start(LOGLET, log_store.clone(), loglet_state.clone())?; + + assert_that!(loglet_state.trim_point(), eq(LogletOffset::INVALID)); + assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::OLDEST)); + // The loglet has no knowledge of global commits, it shouldn't accept trims. 
+ worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::OLDEST), + trim_point: LogletOffset::OLDEST, + }, + None, + )) + .unwrap(); + + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Malformed)); + assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); + assert_that!(trimmed.sealed, eq(false)); + + // The loglet has knowledge of global tail of 10, it should accept trims up to 9 but it + // won't move trim point beyond its local tail. + worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + trim_point: LogletOffset::new(9), + }, + None, + )) + .unwrap(); + + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Ok)); + assert_that!(trimmed.local_tail, eq(LogletOffset::OLDEST)); + assert_that!(trimmed.sealed, eq(false)); + + // let's store some records at offsets (5, 6) + worker + .enqueue_store(Incoming::for_testing( + connection.downgrade(), + Store { + // faking that offset=9 is released + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + timeout_at: None, + sequencer: SEQUENCER, + known_archived: LogletOffset::INVALID, + first_offset: LogletOffset::new(5), + flags: StoreFlags::empty(), + payloads: vec![Record::from("record5"), Record::from("record6")].into(), + }, + None, + )) + .unwrap(); + let stored: Stored = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(stored.status, eq(Status::Ok)); + assert_that!(stored.local_tail, eq(LogletOffset::new(7))); + + // trim to 5 + worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + trim_point: LogletOffset::new(5), + }, + None, + )) + .unwrap(); + + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Ok)); + assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); + assert_that!(trimmed.sealed, eq(false)); + + // Attempt to read. 
We expect to see a trim gap (1->5, 6 (data-record)) + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + total_limit_in_bytes: None, + filter: KeyFilter::Any, + from_offset: LogletOffset::OLDEST, + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); + + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(7))); + assert_that!(records.next_offset, eq(LogletOffset::new(7))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(2)); + // pop() returns records in reverse order + for i in [6, 1] { + let (offset, record) = records.records.pop().unwrap(); + assert_that!(offset, eq(LogletOffset::from(i))); + if i == 1 { + // this one is a trim gap + assert_that!(record.is_trim_gap(), eq(true)); + let gap = record.try_unwrap_trim_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(5))); + } else { + assert_that!(record.is_data(), eq(true)); + let data = record.try_unwrap_data().unwrap(); + let original: String = data.decode().unwrap(); + assert_that!(original, eq(format!("record{}", i))); + } + } - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(7))); - assert_that!(records.next_offset, eq(LogletOffset::new(7))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(2)); - // pop() returns records in reverse order - for i in [6, 1] { + // trim everything + worker + .enqueue_trim(Incoming::for_testing( + connection.downgrade(), + Trim { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), + trim_point: LogletOffset::new(9), + }, + None, + )) + .unwrap(); + + let trimmed: Trimmed = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(trimmed.status, eq(Status::Ok)); + assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); + assert_that!(trimmed.sealed, eq(false)); + + // Attempt to read again. 
We expect to see a trim gap (1->6) + worker + .enqueue_get_records(Incoming::for_testing( + connection.downgrade(), + GetRecords { + header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), + total_limit_in_bytes: None, + filter: KeyFilter::Any, + from_offset: LogletOffset::OLDEST, + // to a point beyond local tail + to_offset: LogletOffset::new(100), + }, + None, + )) + .unwrap(); + + let mut records: Records = net_rx + .recv() + .await + .unwrap() + .body + .unwrap() + .try_decode(connection.protocol_version())?; + assert_that!(records.status, eq(Status::Ok)); + assert_that!(records.local_tail, eq(LogletOffset::new(7))); + assert_that!(records.next_offset, eq(LogletOffset::new(7))); + assert_that!(records.sealed, eq(false)); + assert_that!(records.records.len(), eq(1)); let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(i))); - if i == 1 { - // this one is a trim gap - assert_that!(record.is_trim_gap(), eq(true)); - let gap = record.try_unwrap_trim_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(5))); - } else { - assert_that!(record.is_data(), eq(true)); - let data = record.try_unwrap_data().unwrap(); - let original: String = data.decode().unwrap(); - assert_that!(original, eq(format!("record{}", i))); - } + assert_that!(offset, eq(LogletOffset::from(1))); + assert_that!(record.is_trim_gap(), eq(true)); + let gap = record.try_unwrap_trim_gap().unwrap(); + assert_that!(gap.to, eq(LogletOffset::new(6))); + + // Make sure that we can load the local-tail correctly when loading the loglet_state + let loglet_state_map = LogletStateMap::default(); + let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; + assert_that!(loglet_state.trim_point(), eq(LogletOffset::new(6))); + assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::new(7))); + + tc.shutdown_node("test completed", 0).await; + RocksDbManager::get().shutdown().await; + Ok(()) } - - // trim everything - worker - .enqueue_trim(Incoming::for_testing( - connection.downgrade(), - Trim { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::new(10)), - trim_point: LogletOffset::new(9), - }, - None, - )) - .unwrap(); - - let trimmed: Trimmed = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(trimmed.status, eq(Status::Ok)); - assert_that!(trimmed.local_tail, eq(LogletOffset::new(7))); - assert_that!(trimmed.sealed, eq(false)); - - // Attempt to read again. 
We expect to see a trim gap (1->6) - worker - .enqueue_get_records(Incoming::for_testing( - connection.downgrade(), - GetRecords { - header: LogServerRequestHeader::new(LOGLET, LogletOffset::INVALID), - total_limit_in_bytes: None, - filter: KeyFilter::Any, - from_offset: LogletOffset::OLDEST, - // to a point beyond local tail - to_offset: LogletOffset::new(100), - }, - None, - )) - .unwrap(); - - let mut records: Records = net_rx - .recv() - .await - .unwrap() - .body - .unwrap() - .try_decode(connection.protocol_version())?; - assert_that!(records.status, eq(Status::Ok)); - assert_that!(records.local_tail, eq(LogletOffset::new(7))); - assert_that!(records.next_offset, eq(LogletOffset::new(7))); - assert_that!(records.sealed, eq(false)); - assert_that!(records.records.len(), eq(1)); - let (offset, record) = records.records.pop().unwrap(); - assert_that!(offset, eq(LogletOffset::from(1))); - assert_that!(record.is_trim_gap(), eq(true)); - let gap = record.try_unwrap_trim_gap().unwrap(); - assert_that!(gap.to, eq(LogletOffset::new(6))); - - // Make sure that we can load the local-tail correctly when loading the loglet_state - let loglet_state_map = LogletStateMap::default(); - let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; - assert_that!(loglet_state.trim_point(), eq(LogletOffset::new(6))); - assert_that!(loglet_state.local_tail().offset(), eq(LogletOffset::new(7))); - - tc.shutdown_node("test completed", 0).await; - RocksDbManager::get().shutdown().await; - - Ok(()) + .in_tc(&tc) + .await } } diff --git a/crates/log-server/src/network.rs b/crates/log-server/src/network.rs index 419c4e12f..d8ebad0d0 100644 --- a/crates/log-server/src/network.rs +++ b/crates/log-server/src/network.rs @@ -20,8 +20,8 @@ use tokio_stream::StreamExt as TokioStreamExt; use tracing::{debug, trace}; use xxhash_rust::xxh3::Xxh3Builder; +use restate_core::cancellation_watcher; use restate_core::network::{Incoming, MessageRouterBuilder, MessageStream}; -use restate_core::{cancellation_watcher, Metadata, TaskCenter}; use restate_types::config::Configuration; use restate_types::health::HealthStatus; use restate_types::live::Live; @@ -39,8 +39,6 @@ const DEFAULT_WRITERS_CAPACITY: usize = 128; type LogletWorkerMap = HashMap; pub struct RequestPump { - task_center: TaskCenter, - _metadata: Metadata, _configuration: Live, store_stream: MessageStream, release_stream: MessageStream, @@ -54,8 +52,6 @@ pub struct RequestPump { impl RequestPump { pub fn new( - task_center: TaskCenter, - metadata: Metadata, mut configuration: Live, router_builder: &mut MessageRouterBuilder, ) -> Self { @@ -74,8 +70,6 @@ impl RequestPump { let wait_for_tail_stream = router_builder.subscribe_to_stream(queue_length); let get_digest_stream = router_builder.subscribe_to_stream(queue_length); Self { - task_center, - _metadata: metadata, _configuration: configuration, store_stream, release_stream, @@ -100,7 +94,6 @@ impl RequestPump { S: LogStore + Clone + Sync + Send + 'static, { let RequestPump { - task_center, mut store_stream, mut release_stream, mut seal_stream, @@ -148,7 +141,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( get_digest.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -160,7 +152,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( wait_for_tail.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -172,7 +163,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( 
release.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -184,7 +174,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( seal.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -196,7 +185,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( get_loglet_info.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -208,7 +196,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( get_records.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -220,7 +207,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( trim.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -232,7 +218,6 @@ impl RequestPump { let worker = Self::find_or_create_worker( store.body().header.loglet_id, &log_store, - &task_center, &state_map, &mut loglet_workers, ).await?; @@ -329,7 +314,6 @@ impl RequestPump { async fn find_or_create_worker<'a, S: LogStore>( loglet_id: ReplicatedLogletId, log_store: &S, - task_center: &TaskCenter, state_map: &LogletStateMap, loglet_workers: &'a mut LogletWorkerMap, ) -> anyhow::Result<&'a LogletWorkerHandle> { @@ -338,12 +322,7 @@ impl RequestPump { .get_or_load(loglet_id, log_store) .await .context("cannot load loglet state map from logstore")?; - let handle = LogletWorker::start( - task_center.clone(), - loglet_id, - log_store.clone(), - state.clone(), - )?; + let handle = LogletWorker::start(loglet_id, log_store.clone(), state.clone())?; e.insert(handle); } diff --git a/crates/log-server/src/rocksdb_logstore/builder.rs b/crates/log-server/src/rocksdb_logstore/builder.rs index d442a9d4f..9a529603c 100644 --- a/crates/log-server/src/rocksdb_logstore/builder.rs +++ b/crates/log-server/src/rocksdb_logstore/builder.rs @@ -16,7 +16,7 @@ use restate_types::protobuf::common::LogServerStatus; use rocksdb::{DBCompressionType, SliceTransform}; use static_assertions::const_assert; -use restate_core::{ShutdownError, TaskCenter}; +use restate_core::ShutdownError; use restate_rocksdb::{CfExactPattern, CfName, DbName, DbSpecBuilder, RocksDb, RocksDbManager}; use restate_types::config::{LogServerOptions, RocksDbOptions}; use restate_types::live::BoxedLiveLoad; @@ -81,7 +81,6 @@ impl RocksDbLogStoreBuilder { pub async fn start( self, - task_center: &TaskCenter, health_status: HealthStatus, ) -> Result { let RocksDbLogStoreBuilder { @@ -96,7 +95,7 @@ impl RocksDbLogStoreBuilder { record_cache, health_status.clone(), ) - .start(task_center)?; + .start()?; Ok(RocksDbLogStore { health_status, diff --git a/crates/log-server/src/rocksdb_logstore/store.rs b/crates/log-server/src/rocksdb_logstore/store.rs index a54b74a16..6c0a43364 100644 --- a/crates/log-server/src/rocksdb_logstore/store.rs +++ b/crates/log-server/src/rocksdb_logstore/store.rs @@ -497,7 +497,7 @@ mod tests { use googletest::prelude::*; use test_log::test; - use restate_core::{TaskCenter, TaskCenterBuilder}; + use restate_core::{TaskCenter, TaskCenterBuilder, TaskCenterFutureExt}; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -517,22 +517,22 @@ mod tests { async fn setup() -> Result<(TaskCenter, RocksDbLogStore)> { setup_panic_handler(); let tc = TaskCenterBuilder::default_for_tests().build()?; - let config = Live::from_value(Configuration::default()); - let common_rocks_opts = config.clone().map(|c| 
&c.common); - let log_store = tc - .run_in_scope("test-setup", None, async { - RocksDbManager::init(common_rocks_opts); - // create logstore. - let builder = RocksDbLogStoreBuilder::create( - config.clone().map(|c| &c.log_server).boxed(), - config.map(|c| &c.log_server.rocksdb).boxed(), - RecordCache::new(1_000_000), - ) - .await?; - let log_store = builder.start(&tc, Default::default()).await?; - Result::Ok(log_store) - }) + let log_store = async { + let config = Live::from_value(Configuration::default()); + let common_rocks_opts = config.clone().map(|c| &c.common); + RocksDbManager::init(common_rocks_opts); + // create logstore. + let builder = RocksDbLogStoreBuilder::create( + config.clone().map(|c| &c.log_server).boxed(), + config.map(|c| &c.log_server.rocksdb).boxed(), + RecordCache::new(1_000_000), + ) .await?; + let log_store = builder.start(Default::default()).await?; + Result::Ok(log_store) + } + .in_tc(&tc) + .await?; Ok((tc, log_store)) } diff --git a/crates/log-server/src/rocksdb_logstore/writer.rs b/crates/log-server/src/rocksdb_logstore/writer.rs index 48c760b0d..08498c49a 100644 --- a/crates/log-server/src/rocksdb_logstore/writer.rs +++ b/crates/log-server/src/rocksdb_logstore/writer.rs @@ -95,7 +95,7 @@ impl LogStoreWriter { } /// Must be called from task_center context - pub fn start(mut self, tc: &TaskCenter) -> Result { + pub fn start(mut self) -> Result { // big enough to allow a second full batch to queue up while the existing one is being processed let batch_size = std::cmp::max( 1, @@ -107,10 +107,9 @@ impl LogStoreWriter { // the backlog while we process this one. let (sender, receiver) = mpsc::channel(batch_size * 2); - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::SystemService, "log-server-rocksdb-writer", - None, async move { debug!("Start running LogStoreWriter"); let mut opts = self.updateable_options.clone(); diff --git a/crates/log-server/src/service.rs b/crates/log-server/src/service.rs index 9199df121..32587d73c 100644 --- a/crates/log-server/src/service.rs +++ b/crates/log-server/src/service.rs @@ -38,7 +38,6 @@ use crate::rocksdb_logstore::RocksDbLogStoreBuilder; pub struct LogServerService { health_status: HealthStatus, updateable_config: Live, - task_center: TaskCenter, metadata: Metadata, request_processor: RequestPump, metadata_store_client: MetadataStoreClient, @@ -49,7 +48,6 @@ impl LogServerService { pub async fn create( health_status: HealthStatus, updateable_config: Live, - task_center: TaskCenter, metadata: Metadata, metadata_store_client: MetadataStoreClient, record_cache: RecordCache, @@ -58,17 +56,11 @@ impl LogServerService { describe_metrics(); health_status.update(LogServerStatus::StartingUp); - let request_processor = RequestPump::new( - task_center.clone(), - metadata.clone(), - updateable_config.clone(), - router_builder, - ); + let request_processor = RequestPump::new(updateable_config.clone(), router_builder); Ok(Self { health_status, updateable_config, - task_center, metadata, request_processor, metadata_store_client, @@ -84,7 +76,6 @@ impl LogServerService { let LogServerService { health_status, updateable_config, - task_center, metadata, request_processor: request_pump, mut metadata_store_client, @@ -102,7 +93,7 @@ impl LogServerService { // 2. Fire up the log store. 
let mut log_store = log_store_builder - .start(&task_center, health_status.clone()) + .start(health_status.clone()) .await .context("Couldn't start log-server's log store")?; @@ -132,10 +123,9 @@ impl LogServerService { crate::protobuf::FILE_DESCRIPTOR_SET, ); - let _ = task_center.spawn_child( + let _ = TaskCenter::spawn_child( TaskKind::SystemService, "log-server", - None, request_pump.run(health_status, log_store, state_map, storage_state), )?; diff --git a/crates/metadata-store/src/local/service.rs b/crates/metadata-store/src/local/service.rs index d8d634b69..17cf85d4a 100644 --- a/crates/metadata-store/src/local/service.rs +++ b/crates/metadata-store/src/local/service.rs @@ -18,7 +18,7 @@ use tower::ServiceExt; use tower_http::classify::{GrpcCode, GrpcErrorsAsFailures, SharedClassifier}; use restate_core::network::net_util; -use restate_core::{task_center, ShutdownError, TaskKind}; +use restate_core::{ShutdownError, TaskCenter, TaskKind}; use restate_rocksdb::RocksError; use restate_types::config::{MetadataStoreOptions, RocksDbOptions}; use restate_types::live::BoxedLiveLoad; @@ -106,22 +106,17 @@ impl LocalMetadataStoreService { .map_request(|req: Request| req.map(boxed)), ); - task_center().spawn_child( - TaskKind::RpcServer, - "metadata-store-grpc", - None, - async move { - net_util::run_hyper_server( - &bind_address, - service, - "metadata-store-grpc", - || health_status.update(MetadataServerStatus::Ready), - || health_status.update(MetadataServerStatus::Unknown), - ) - .await?; - Ok(()) - }, - )?; + TaskCenter::spawn_child(TaskKind::RpcServer, "metadata-store-grpc", async move { + net_util::run_hyper_server( + &bind_address, + service, + "metadata-store-grpc", + || health_status.update(MetadataServerStatus::Ready), + || health_status.update(MetadataServerStatus::Unknown), + ) + .await?; + Ok(()) + })?; store.run().await; diff --git a/crates/node/src/lib.rs b/crates/node/src/lib.rs index 5759b1562..d4db04716 100644 --- a/crates/node/src/lib.rs +++ b/crates/node/src/lib.rs @@ -25,6 +25,7 @@ use restate_core::network::{ use restate_core::partitions::{spawn_partition_routing_refresher, PartitionRoutingRefresher}; use restate_core::{ spawn_metadata_manager, MetadataBuilder, MetadataKind, MetadataManager, TargetVersion, + TaskCenter, }; use restate_core::{task_center, TaskKind}; #[cfg(feature = "replicated-loglet")] @@ -197,7 +198,6 @@ impl Node { LogServerService::create( health.log_server_status(), updateable_config.clone(), - tc.clone(), metadata.clone(), metadata_store_client.clone(), record_cache, @@ -329,11 +329,11 @@ impl Node { let metadata_writer = self.metadata_manager.writer(); let metadata = self.metadata_manager.metadata().clone(); - let is_set = tc.try_set_global_metadata(metadata.clone()); + let is_set = TaskCenter::try_set_global_metadata(metadata.clone()); debug_assert!(is_set, "Global metadata was already set"); // Start metadata manager - spawn_metadata_manager(&tc, self.metadata_manager)?; + spawn_metadata_manager(self.metadata_manager)?; // Start partition routing information refresher spawn_partition_routing_refresher(&tc, self.partition_routing_refresher)?; @@ -441,7 +441,7 @@ impl Node { } if let Some(ingress_role) = self.ingress_role { - tc.spawn_child(TaskKind::Ingress, "ingress-http", None, ingress_role.run())?; + TaskCenter::spawn_child(TaskKind::Ingress, "ingress-http", ingress_role.run())?; } tc.spawn(TaskKind::RpcServer, "node-rpc-server", None, { diff --git a/crates/node/src/roles/admin.rs b/crates/node/src/roles/admin.rs index ac832358a..d2fb4a7eb 
100644 --- a/crates/node/src/roles/admin.rs +++ b/crates/node/src/roles/admin.rs @@ -20,7 +20,7 @@ use restate_core::network::NetworkServerBuilder; use restate_core::network::Networking; use restate_core::network::TransportConnect; use restate_core::partitions::PartitionRouting; -use restate_core::{task_center, Metadata, MetadataWriter, TaskCenter, TaskKind}; +use restate_core::{Metadata, MetadataWriter, TaskCenter, TaskKind}; use restate_service_client::{AssumeRoleCacheMode, ServiceClient}; use restate_service_protocol::discovery::ServiceDiscovery; use restate_storage_query_datafusion::context::{QueryContext, SelectPartitionsFromMetadata}; @@ -122,7 +122,6 @@ impl AdminRole { updateable_config.clone(), health_status, bifrost, - task_center, metadata, networking, router_builder, @@ -142,21 +141,17 @@ impl AdminRole { } pub async fn start(self) -> Result<(), anyhow::Error> { - let tc = task_center(); - if let Some(cluster_controller) = self.controller { - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::SystemService, "cluster-controller-service", - None, cluster_controller.run(), )?; } - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::RpcServer, "admin-rpc-server", - None, self.admin.run(self.updateable_config.map(|c| &c.admin)), )?; diff --git a/crates/node/src/roles/base.rs b/crates/node/src/roles/base.rs index 84c8b6645..e6140dea4 100644 --- a/crates/node/src/roles/base.rs +++ b/crates/node/src/roles/base.rs @@ -14,9 +14,8 @@ use futures::StreamExt; use restate_core::{ cancellation_watcher, network::{Incoming, MessageRouterBuilder, MessageStream, NetworkError}, - task_center, worker_api::ProcessorsManagerHandle, - ShutdownError, TaskKind, + ShutdownError, TaskCenter, TaskKind, }; use restate_types::net::node::{GetNodeState, NodeStateResponse}; @@ -39,8 +38,7 @@ impl BaseRole { } pub fn start(self) -> anyhow::Result<()> { - let tc = task_center(); - tc.spawn_child(TaskKind::RoleRunner, "base-role-service", None, async { + TaskCenter::spawn_child(TaskKind::RoleRunner, "base-role-service", async { let cancelled = cancellation_watcher(); tokio::select! 
{ diff --git a/crates/node/src/roles/worker.rs b/crates/node/src/roles/worker.rs index 66bda22ce..23851da8a 100644 --- a/crates/node/src/roles/worker.rs +++ b/crates/node/src/roles/worker.rs @@ -16,7 +16,8 @@ use restate_core::network::Networking; use restate_core::network::TransportConnect; use restate_core::partitions::PartitionRouting; use restate_core::worker_api::ProcessorsManagerHandle; -use restate_core::{cancellation_watcher, task_center, Metadata, MetadataKind}; +use restate_core::TaskCenter; +use restate_core::{cancellation_watcher, Metadata, MetadataKind}; use restate_core::{ShutdownError, TaskKind}; use restate_metadata_store::MetadataStoreClient; use restate_storage_query_datafusion::context::QueryContext; @@ -103,16 +104,14 @@ impl WorkerRole { } pub async fn start(self) -> anyhow::Result<()> { - let tc = task_center(); // todo: only run subscriptions on node 0 once being distributed - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::MetadataBackgroundSync, "subscription_controller", - None, Self::watch_subscriptions(self.metadata, self.worker.subscription_controller_handle()), )?; - tc.spawn_child(TaskKind::RoleRunner, "worker-service", None, async { + TaskCenter::spawn_child(TaskKind::RoleRunner, "worker-service", async { self.worker.run().await })?; diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index fbc9f4ad4..416fb4784 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -20,6 +20,7 @@ mod subscription_controller; mod subscription_integration; use codederror::CodedError; +use restate_core::TaskCenter; use std::time::Duration; use restate_bifrost::Bifrost; @@ -199,37 +200,31 @@ impl Worker { } pub async fn run(self) -> anyhow::Result<()> { - let tc = task_center(); - // Postgres external server - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::RpcServer, "postgres-query-server", - None, self.storage_query_postgres.run(), )?; // Datafusion remote scanner - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::SystemService, "datafusion-scan-server", - None, self.datafusion_remote_scanner.run(), )?; // Kafka Ingress - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::SystemService, "kafka-ingress", - None, self.ingress_kafka .run(self.updateable_config.clone().map(|c| &c.ingress)), )?; - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::SystemService, "partition-processor-manager", - None, self.partition_processor_manager.run(), )?; diff --git a/crates/worker/src/partition/leadership.rs b/crates/worker/src/partition/leadership.rs index b613078de..f18d7da82 100644 --- a/crates/worker/src/partition/leadership.rs +++ b/crates/worker/src/partition/leadership.rs @@ -395,12 +395,8 @@ where let shuffle_hint_tx = shuffle.create_hint_sender(); - let shuffle_task_id = task_center().spawn_child( - TaskKind::Shuffle, - "shuffle", - Some(self.partition_processor_metadata.partition_id), - shuffle.run(), - )?; + let shuffle_task_id = + TaskCenter::spawn_child(TaskKind::Shuffle, "shuffle", shuffle.run())?; let self_proposer = SelfProposer::new( self.partition_processor_metadata.partition_id, @@ -421,12 +417,8 @@ where self.cleanup_interval, ); - let cleaner_task_id = task_center().spawn_child( - TaskKind::Cleaner, - "cleaner", - Some(self.partition_processor_metadata.partition_id), - cleaner.run(), - )?; + let cleaner_task_id = + TaskCenter::spawn_child(TaskKind::Cleaner, "cleaner", cleaner.run())?; self.state = State::Leader(LeaderState { leader_epoch, diff --git a/crates/worker/src/partition/shuffle.rs 
b/crates/worker/src/partition/shuffle.rs index b34a0af09..1ca1d1e2b 100644 --- a/crates/worker/src/partition/shuffle.rs +++ b/crates/worker/src/partition/shuffle.rs @@ -460,7 +460,9 @@ mod tests { use restate_bifrost::{Bifrost, LogEntry}; use restate_core::network::FailingConnector; - use restate_core::{TaskKind, TestCoreEnv, TestCoreEnvBuilder}; + use restate_core::{ + TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnv, TestCoreEnvBuilder, + }; use restate_storage_api::outbox_table::OutboxMessage; use restate_storage_api::StorageError; use restate_types::identifiers::{InvocationId, LeaderEpoch, PartitionId}; @@ -686,14 +688,9 @@ mod tests { let shuffle_env = create_shuffle_env(outbox_reader).await; let tc = shuffle_env.env.tc.clone(); - tc.run_in_scope("test", None, async { + async { let partition_id = shuffle_env.shuffle.metadata.partition_id; - tc.spawn_child( - TaskKind::Shuffle, - "shuffle", - None, - shuffle_env.shuffle.run(), - )?; + TaskCenter::spawn_child(TaskKind::Shuffle, "shuffle", shuffle_env.shuffle.run())?; let reader = shuffle_env.bifrost.create_reader( LogId::from(partition_id), KeyFilter::Any, @@ -706,7 +703,8 @@ mod tests { assert_received_invoke_commands(messages, expected_messages); Ok::<(), anyhow::Error>(()) - }) + } + .in_tc(&tc) .await } @@ -732,14 +730,9 @@ mod tests { let shuffle_env = create_shuffle_env(outbox_reader).await; let tc = shuffle_env.env.tc.clone(); - tc.run_in_scope("test", None, async { + async { let partition_id = shuffle_env.shuffle.metadata.partition_id; - tc.spawn_child( - TaskKind::Shuffle, - "shuffle", - None, - shuffle_env.shuffle.run(), - )?; + TaskCenter::spawn_child(TaskKind::Shuffle, "shuffle", shuffle_env.shuffle.run())?; let reader = shuffle_env.bifrost.create_reader( LogId::from(partition_id), KeyFilter::Any, @@ -752,7 +745,8 @@ mod tests { assert_received_invoke_commands(messages, expected_messages); Ok::<(), anyhow::Error>(()) - }) + } + .in_tc(&tc) .await } @@ -775,64 +769,63 @@ mod tests { let tc = shuffle_env.env.tc.clone(); let total_restarts = Arc::new(AtomicUsize::new(0)); - let shuffle_task_id = tc - .run_in_scope("test", None, async { - let partition_id = shuffle_env.shuffle.metadata.partition_id; - let reader = shuffle_env.bifrost.create_reader( - LogId::from(partition_id), - KeyFilter::Any, - Lsn::INVALID, - Lsn::MAX, - )?; - let total_restarts = Arc::clone(&total_restarts); - - let shuffle_task = - tc.spawn_child(TaskKind::Shuffle, "shuffle", None, async move { - let mut shuffle = shuffle_env.shuffle; - let metadata = shuffle.metadata; - let truncation_tx = shuffle.truncation_tx.clone(); - let mut processed_range = 0; - let mut num_restarts = 0; - - // restart shuffle on failures and update failing outbox reader - while shuffle.run().await.is_err() { - num_restarts += 1; - // update the failing outbox reader to make a bit more progress and delete some of the delivered records - { - let outbox_reader = Arc::get_mut(&mut outbox_reader) - .expect("only one reference should exist"); - - // leave the first entry to generate some holes - for idx in (processed_range + 1)..outbox_reader.fail_index { - outbox_reader.records[usize::try_from(idx) - .expect("index should fit in usize")] = None; - } + let shuffle_task_id = async { + let partition_id = shuffle_env.shuffle.metadata.partition_id; + let reader = shuffle_env.bifrost.create_reader( + LogId::from(partition_id), + KeyFilter::Any, + Lsn::INVALID, + Lsn::MAX, + )?; + let total_restarts = Arc::clone(&total_restarts); + + let shuffle_task = 
TaskCenter::spawn_child(TaskKind::Shuffle, "shuffle", async move { + let mut shuffle = shuffle_env.shuffle; + let metadata = shuffle.metadata; + let truncation_tx = shuffle.truncation_tx.clone(); + let mut processed_range = 0; + let mut num_restarts = 0; + + // restart shuffle on failures and update failing outbox reader + while shuffle.run().await.is_err() { + num_restarts += 1; + // update the failing outbox reader to make a bit more progress and delete some of the delivered records + { + let outbox_reader = Arc::get_mut(&mut outbox_reader) + .expect("only one reference should exist"); + + // leave the first entry to generate some holes + for idx in (processed_range + 1)..outbox_reader.fail_index { + outbox_reader.records + [usize::try_from(idx).expect("index should fit in usize")] = None; + } - processed_range = outbox_reader.fail_index; - outbox_reader.fail_index += 10; - } + processed_range = outbox_reader.fail_index; + outbox_reader.fail_index += 10; + } - shuffle = Shuffle::new( - metadata, - Arc::clone(&outbox_reader), - truncation_tx.clone(), - 1, - shuffle_env.bifrost.clone(), - ); - } + shuffle = Shuffle::new( + metadata, + Arc::clone(&outbox_reader), + truncation_tx.clone(), + 1, + shuffle_env.bifrost.clone(), + ); + } - total_restarts.store(num_restarts, Ordering::Relaxed); + total_restarts.store(num_restarts, Ordering::Relaxed); - Ok(()) - })?; + Ok(()) + })?; - let messages = collect_invoke_commands_until(reader, last_invocation_id).await?; + let messages = collect_invoke_commands_until(reader, last_invocation_id).await?; - assert_received_invoke_commands(messages, expected_messages); + assert_received_invoke_commands(messages, expected_messages); - Ok::<_, anyhow::Error>(shuffle_task) - }) - .await?; + Ok::<_, anyhow::Error>(shuffle_task) + } + .in_tc(&tc) + .await?; let shuffle_task = tc.cancel_task(shuffle_task_id).expect("should exist"); shuffle_task.await?; diff --git a/crates/worker/src/partition_processor_manager/message_handler.rs b/crates/worker/src/partition_processor_manager/message_handler.rs index ff5901b4e..394135b29 100644 --- a/crates/worker/src/partition_processor_manager/message_handler.rs +++ b/crates/worker/src/partition_processor_manager/message_handler.rs @@ -10,7 +10,7 @@ use restate_core::network::{Incoming, MessageHandler}; use restate_core::worker_api::ProcessorsManagerHandle; -use restate_core::{task_center, TaskKind}; +use restate_core::{TaskCenter, TaskKind}; use restate_types::net::partition_processor_manager::{ CreateSnapshotRequest, CreateSnapshotResponse, SnapshotError, }; @@ -36,49 +36,47 @@ impl MessageHandler for PartitionProcessorManagerMessageHandler { async fn on_message(&self, msg: Incoming) { let processors_manager_handle = self.processors_manager_handle.clone(); - task_center() - .spawn_child( - TaskKind::Disposable, - "create-snapshot-request-rpc", - None, - async move { - let create_snapshot_result = processors_manager_handle - .create_snapshot(msg.body().partition_id) - .await; + TaskCenter::spawn_child( + TaskKind::Disposable, + "create-snapshot-request-rpc", + async move { + let create_snapshot_result = processors_manager_handle + .create_snapshot(msg.body().partition_id) + .await; - match create_snapshot_result.as_ref() { - Ok(snapshot) => { - debug!( - partition_id = %msg.body().partition_id, - %snapshot, - "Create snapshot successfully completed", - ); - msg.to_rpc_response(CreateSnapshotResponse { - result: Ok(snapshot.snapshot_id), - }) - } - Err(err) => { - warn!( - partition_id = %msg.body().partition_id, - "Create 
snapshot failed: {}", - err - ); - msg.to_rpc_response(CreateSnapshotResponse { - result: Err(SnapshotError::SnapshotCreationFailed(err.to_string())), - }) - } + match create_snapshot_result.as_ref() { + Ok(snapshot) => { + debug!( + partition_id = %msg.body().partition_id, + %snapshot, + "Create snapshot successfully completed", + ); + msg.to_rpc_response(CreateSnapshotResponse { + result: Ok(snapshot.snapshot_id), + }) } - .send() - .await - .map_err(|e| { - warn!(result = ?create_snapshot_result, "Failed to send response: {}", e); - anyhow::anyhow!("Failed to send response to create snapshot request: {}", e) - }) - }, - ) - .map_err(|e| { - warn!("Failed to spawn request handler: {}", e); - }) - .ok(); + Err(err) => { + warn!( + partition_id = %msg.body().partition_id, + "Create snapshot failed: {}", + err + ); + msg.to_rpc_response(CreateSnapshotResponse { + result: Err(SnapshotError::SnapshotCreationFailed(err.to_string())), + }) + } + } + .send() + .await + .map_err(|e| { + warn!(result = ?create_snapshot_result, "Failed to send response: {}", e); + anyhow::anyhow!("Failed to send response to create snapshot request: {}", e) + }) + }, + ) + .map_err(|e| { + warn!("Failed to spawn request handler: {}", e); + }) + .ok(); } } diff --git a/crates/worker/src/partition_processor_manager/mod.rs b/crates/worker/src/partition_processor_manager/mod.rs index ade4c7574..f528de7e6 100644 --- a/crates/worker/src/partition_processor_manager/mod.rs +++ b/crates/worker/src/partition_processor_manager/mod.rs @@ -229,12 +229,7 @@ impl PartitionProcessorManager { self.partition_store_manager.clone(), persisted_lsns_tx, ); - self.task_center.spawn_child( - TaskKind::Watchdog, - "persisted-lsn-watchdog", - None, - watchdog.run(), - )?; + TaskCenter::spawn_child(TaskKind::Watchdog, "persisted-lsn-watchdog", watchdog.run())?; let mut logs_version_watcher = self.metadata.watch(MetadataKind::Logs); let mut partition_table_version_watcher = self.metadata.watch(MetadataKind::PartitionTable); diff --git a/crates/worker/src/partition_processor_manager/spawn_processor_task.rs b/crates/worker/src/partition_processor_manager/spawn_processor_task.rs index a6fb8c413..13a58d0a0 100644 --- a/crates/worker/src/partition_processor_manager/spawn_processor_task.rs +++ b/crates/worker/src/partition_processor_manager/spawn_processor_task.rs @@ -13,12 +13,8 @@ use std::ops::RangeInclusive; use tokio::sync::{mpsc, watch}; use tracing::instrument; -use crate::invoker_integration::EntryEnricher; -use crate::partition::invoker_storage_reader::InvokerStorageReader; -use crate::partition_processor_manager::processor_state::StartedProcessor; -use crate::PartitionProcessorBuilder; use restate_bifrost::Bifrost; -use restate_core::{task_center, Metadata, RuntimeRootTaskHandle, TaskKind}; +use restate_core::{Metadata, RuntimeRootTaskHandle, TaskCenter, TaskKind}; use restate_invoker_impl::Service as InvokerService; use restate_partition_store::{OpenMode, PartitionStore, PartitionStoreManager}; use restate_service_protocol::codec::ProtobufRawEntryCodec; @@ -29,6 +25,11 @@ use restate_types::live::Live; use restate_types::schema::Schema; use restate_types::GenerationalNodeId; +use crate::invoker_integration::EntryEnricher; +use crate::partition::invoker_storage_reader::InvokerStorageReader; +use crate::partition_processor_manager::processor_state::StartedProcessor; +use crate::PartitionProcessorBuilder; + pub struct SpawnPartitionProcessorTask { task_name: &'static str, node_id: GenerationalNodeId, @@ -121,8 +122,7 @@ impl 
SpawnPartitionProcessorTask { let invoker_name = Box::leak(Box::new(format!("invoker-{}", partition_id))); let invoker_config = configuration.clone().map(|c| &c.worker.invoker); - let tc = task_center(); - let root_task_handle = tc.clone().start_runtime( + let root_task_handle = TaskCenter::current().start_runtime( TaskKind::PartitionProcessor, task_name, Some(pp_builder.partition_id), @@ -139,15 +139,18 @@ impl SpawnPartitionProcessorTask { .await?; move || async move { - tc.spawn_child( + TaskCenter::spawn_child( TaskKind::SystemService, invoker_name, - Some(pp_builder.partition_id), invoker.run(invoker_config), )?; pp_builder - .build::(tc, bifrost, partition_store) + .build::( + TaskCenter::current(), + bifrost, + partition_store, + ) .await? .run() .await diff --git a/tools/bifrost-benchpress/src/main.rs b/tools/bifrost-benchpress/src/main.rs index abe9d6a4a..0bdf283d5 100644 --- a/tools/bifrost-benchpress/src/main.rs +++ b/tools/bifrost-benchpress/src/main.rs @@ -157,7 +157,7 @@ fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, MetadataManager::new(metadata_builder, metadata_store_client.clone()); let metadata_writer = metadata_manager.writer(); - task_center.try_set_global_metadata(metadata.clone()); + TaskCenter::try_set_global_metadata(metadata.clone()); RocksDbManager::init(config.clone().map(|c| &c.common)); @@ -172,7 +172,7 @@ fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, .await .expect("to store bifrost config in metadata store"); metadata_writer.submit(Arc::new(logs)); - spawn_metadata_manager(&task_center, metadata_manager).expect("metadata manager starts"); + spawn_metadata_manager(metadata_manager).expect("metadata manager starts"); let bifrost_svc = BifrostService::new(task_center, metadata) .enable_in_memory_loglet() diff --git a/tools/restatectl/src/commands/log/dump_log.rs b/tools/restatectl/src/commands/log/dump_log.rs index 3ab6ef6e9..d99e749f5 100644 --- a/tools/restatectl/src/commands/log/dump_log.rs +++ b/tools/restatectl/src/commands/log/dump_log.rs @@ -18,7 +18,7 @@ use tracing::{debug, info}; use restate_bifrost::BifrostService; use restate_core::network::MessageRouterBuilder; -use restate_core::{MetadataBuilder, MetadataManager, TaskKind}; +use restate_core::{MetadataBuilder, MetadataManager, TaskCenter, TaskKind}; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -59,7 +59,7 @@ struct DecodedLogRecord { } async fn dump_log(opts: &DumpLogOpts) -> anyhow::Result<()> { - run_in_task_center(opts.config_file.as_ref(), |config, tc| async move { + run_in_task_center(opts.config_file.as_ref(), |config| async move { if !config.bifrost.local.data_dir().exists() { bail!( "The specified path '{}' does not contain a local-loglet directory.", @@ -72,7 +72,7 @@ async fn dump_log(opts: &DumpLogOpts) -> anyhow::Result<()> { let metadata_builder = MetadataBuilder::default(); let metadata = metadata_builder.to_metadata(); - tc.try_set_global_metadata(metadata.clone()); + TaskCenter::try_set_global_metadata(metadata.clone()); let metadata_store_client = metadata_store::start_metadata_store( config.common.metadata_store_client.clone(), @@ -80,7 +80,7 @@ async fn dump_log(opts: &DumpLogOpts) -> anyhow::Result<()> { Live::from_value(config.metadata_store.clone()) .map(|c| &c.rocksdb) .boxed(), - &tc, + &TaskCenter::current(), ) .await?; debug!("Metadata store client created"); @@ -90,19 +90,19 @@ async fn dump_log(opts: &DumpLogOpts) -> anyhow::Result<()> { let mut router_builder = 
MessageRouterBuilder::default(); metadata_manager.register_in_message_router(&mut router_builder); - tc.spawn( + TaskCenter::current().spawn( TaskKind::SystemService, "metadata-manager", None, metadata_manager.run(), )?; - let bifrost_svc = BifrostService::new(tc.clone(), metadata.clone()) + let bifrost_svc = BifrostService::new(TaskCenter::current(), metadata.clone()) .enable_local_loglet(&Configuration::updateable()); let bifrost = bifrost_svc.handle(); // Ensures bifrost has initial metadata synced up before starting the worker. - // Need to run start in new tc scope to have access to metadata() + // Need to run start in tc scope to have access to metadata() bifrost_svc.start().await?; let log_id = LogId::from(opts.log_id); diff --git a/tools/restatectl/src/commands/metadata/get.rs b/tools/restatectl/src/commands/metadata/get.rs index 7c04861be..6a1533dc9 100644 --- a/tools/restatectl/src/commands/metadata/get.rs +++ b/tools/restatectl/src/commands/metadata/get.rs @@ -13,6 +13,7 @@ use clap::Parser; use cling::{Collect, Run}; use tracing::debug; +use restate_core::TaskCenter; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -57,32 +58,28 @@ async fn get_value_remote(opts: &GetValueOpts) -> anyhow::Result anyhow::Result> { - run_in_task_center( - opts.metadata.config_file.as_ref(), - |config, task_center| async move { - let rocksdb_manager = - RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); - debug!("RocksDB Initialized"); + run_in_task_center(opts.metadata.config_file.as_ref(), |config| async move { + let rocksdb_manager = RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); + debug!("RocksDB Initialized"); - let metadata_store_client = metadata_store::start_metadata_store( - config.common.metadata_store_client.clone(), - Live::from_value(config.metadata_store.clone()).boxed(), - Live::from_value(config.metadata_store.clone()) - .map(|c| &c.rocksdb) - .boxed(), - &task_center, - ) - .await?; - debug!("Metadata store client created"); + let metadata_store_client = metadata_store::start_metadata_store( + config.common.metadata_store_client.clone(), + Live::from_value(config.metadata_store.clone()).boxed(), + Live::from_value(config.metadata_store.clone()) + .map(|c| &c.rocksdb) + .boxed(), + &TaskCenter::current(), + ) + .await?; + debug!("Metadata store client created"); - let value: Option = metadata_store_client - .get(ByteString::from(opts.key.as_str())) - .await - .map_err(|e| anyhow::anyhow!("Failed to get value: {}", e))?; + let value: Option = metadata_store_client + .get(ByteString::from(opts.key.as_str())) + .await + .map_err(|e| anyhow::anyhow!("Failed to get value: {}", e))?; - rocksdb_manager.shutdown().await; - anyhow::Ok(value) - }, - ) + rocksdb_manager.shutdown().await; + anyhow::Ok(value) + }) .await } diff --git a/tools/restatectl/src/commands/metadata/patch.rs b/tools/restatectl/src/commands/metadata/patch.rs index 55ced61a8..bceeb3acf 100644 --- a/tools/restatectl/src/commands/metadata/patch.rs +++ b/tools/restatectl/src/commands/metadata/patch.rs @@ -15,6 +15,7 @@ use json_patch::Patch; use tracing::debug; use restate_core::metadata_store::{MetadataStoreClient, Precondition}; +use restate_core::TaskCenter; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -79,30 +80,26 @@ async fn patch_value_direct( opts: &PatchValueOpts, patch: Patch, ) -> anyhow::Result> { - let value = run_in_task_center( - 
opts.metadata.config_file.as_ref(), - |config, task_center| async move { - let rocksdb_manager = - RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); - debug!("RocksDB Initialized"); - - let metadata_store_client = start_metadata_store( - config.common.metadata_store_client.clone(), - Live::from_value(config.metadata_store.clone()).boxed(), - Live::from_value(config.metadata_store.clone()) - .map(|c| &c.rocksdb) - .boxed(), - &task_center, - ) - .await?; - debug!("Metadata store client created"); - - let result = patch_value_inner(opts, &patch, &metadata_store_client).await; - - rocksdb_manager.shutdown().await; - result - }, - ) + let value = run_in_task_center(opts.metadata.config_file.as_ref(), |config| async move { + let rocksdb_manager = RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); + debug!("RocksDB Initialized"); + + let metadata_store_client = start_metadata_store( + config.common.metadata_store_client.clone(), + Live::from_value(config.metadata_store.clone()).boxed(), + Live::from_value(config.metadata_store.clone()) + .map(|c| &c.rocksdb) + .boxed(), + &TaskCenter::current(), + ) + .await?; + debug!("Metadata store client created"); + + let result = patch_value_inner(opts, &patch, &metadata_store_client).await; + + rocksdb_manager.shutdown().await; + result + }) .await?; Ok(Some(value)) diff --git a/tools/restatectl/src/environment/task_center.rs b/tools/restatectl/src/environment/task_center.rs index 4e7a5e7a4..5e286aa4b 100644 --- a/tools/restatectl/src/environment/task_center.rs +++ b/tools/restatectl/src/environment/task_center.rs @@ -13,7 +13,7 @@ use std::path::PathBuf; use tracing::warn; -use restate_core::{TaskCenter, TaskCenterBuilder}; +use restate_core::TaskCenterBuilder; use restate_types::config::Configuration; use restate_types::config_loader::ConfigLoaderBuilder; use restate_types::live::Pinned; @@ -21,7 +21,7 @@ use restate_types::live::Pinned; /// Loads configuration, creates a task center, executes the supplied function body in scope of TC, and shuts down. 
pub async fn run_in_task_center(config_file: Option<&PathBuf>, fn_body: F) -> O::Output where - F: FnOnce(Pinned, TaskCenter) -> O, + F: FnOnce(Pinned) -> O, O: Future, { let config_path = config_file @@ -57,9 +57,7 @@ where .build() .expect("task_center builds"); - let result = task_center - .run_in_scope("main", None, fn_body(config, task_center.clone())) - .await; + let result = task_center.run_in_scope_sync(|| fn_body(config)).await; task_center.shutdown_node("finished", 0).await; result From 1df125250b2bec6ed6f840fea3923b87a18b2d50 Mon Sep 17 00:00:00 2001 From: Ahmed Farghal Date: Fri, 22 Nov 2024 17:24:42 +0000 Subject: [PATCH 3/4] [TaskCenter] Stage 4 --- .../admin/src/cluster_controller/service.rs | 399 +++++++++--------- crates/bifrost/benches/append_throughput.rs | 13 +- crates/bifrost/src/appender.rs | 6 +- crates/bifrost/src/bifrost.rs | 204 +++++---- crates/bifrost/src/bifrost_admin.rs | 17 +- crates/bifrost/src/read_stream.rs | 253 +++++------ crates/bifrost/src/service.rs | 38 +- crates/bifrost/src/watchdog.rs | 7 +- crates/node/src/lib.rs | 3 +- crates/worker/src/partition/cleaner.rs | 168 ++++---- crates/worker/src/partition/leadership.rs | 24 +- crates/worker/src/partition/shuffle.rs | 9 +- .../src/partition_processor_manager/mod.rs | 72 ++-- tools/bifrost-benchpress/src/main.rs | 3 +- tools/restatectl/src/commands/log/dump_log.rs | 3 +- tools/xtask/src/main.rs | 11 +- 16 files changed, 587 insertions(+), 643 deletions(-) diff --git a/crates/admin/src/cluster_controller/service.rs b/crates/admin/src/cluster_controller/service.rs index be9f47afe..4a07f5e35 100644 --- a/crates/admin/src/cluster_controller/service.rs +++ b/crates/admin/src/cluster_controller/service.rs @@ -484,7 +484,10 @@ mod tests { use restate_core::network::{ FailingConnector, Incoming, MessageHandler, MockPeerConnection, NetworkServerBuilder, }; - use restate_core::{NoOpMessageHandler, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; + use restate_core::{ + NoOpMessageHandler, TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnv, + TestCoreEnvBuilder, + }; use restate_types::cluster::cluster_state::PartitionProcessorStatus; use restate_types::config::{AdminOptions, Configuration}; use restate_types::health::HealthStatus; @@ -501,53 +504,49 @@ mod tests { async fn manual_log_trim() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); let mut builder = TestCoreEnvBuilder::with_incoming_only_connector(); + let tc = builder.tc.clone(); + async { + let bifrost_svc = BifrostService::new().with_factory(memory_loglet::Factory::default()); + let bifrost = bifrost_svc.handle(); + + let svc = Service::new( + Live::from_value(Configuration::default()), + HealthStatus::default(), + bifrost.clone(), + builder.metadata.clone(), + builder.networking.clone(), + &mut builder.router_builder, + &mut NetworkServerBuilder::default(), + builder.metadata_writer.clone(), + builder.metadata_store_client.clone(), + ); + let svc_handle = svc.handle(); + + let _ = builder.build().await; + bifrost_svc.start().await?; + + let mut appender = bifrost.create_appender(LOG_ID)?; + + TaskCenter::current().spawn( + TaskKind::SystemService, + "cluster-controller", + None, + svc.run(), + )?; + + for _ in 1..=5 { + appender.append("").await?; + } - let metadata = builder.metadata.clone(); - - let bifrost_svc = BifrostService::new(builder.tc.clone(), metadata.clone()) - .with_factory(memory_loglet::Factory::default()); - let bifrost = bifrost_svc.handle(); - - let svc = Service::new( - Live::from_value(Configuration::default()), - 
HealthStatus::default(), - bifrost.clone(), - builder.metadata.clone(), - builder.networking.clone(), - &mut builder.router_builder, - &mut NetworkServerBuilder::default(), - builder.metadata_writer.clone(), - builder.metadata_store_client.clone(), - ); - let svc_handle = svc.handle(); - - let node_env = builder.build().await; - bifrost_svc.start().await?; - - let mut appender = bifrost.create_appender(LOG_ID)?; - - node_env.tc.spawn( - TaskKind::SystemService, - "cluster-controller", - None, - svc.run(), - )?; - - node_env - .tc - .run_in_scope("test", None, async move { - for _ in 1..=5 { - appender.append("").await?; - } - - svc_handle.trim_log(LOG_ID, Lsn::from(3)).await??; + svc_handle.trim_log(LOG_ID, Lsn::from(3)).await??; - let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::OLDEST)); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(3)))); - Ok::<(), anyhow::Error>(()) - }) - .await?; + let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::OLDEST)); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(3)))); + Ok::<(), anyhow::Error>(()) + } + .in_tc(&tc) + .await?; Ok(()) } @@ -613,63 +612,60 @@ mod tests { .add_message_handler(NoOpMessageHandler::::default()) }) .await?; + let tc = node_env.tc.clone(); + + async move { + // simulate a connection from node 2 so we can have a connection between the two + // nodes + let node_2 = MockPeerConnection::connect( + GenerationalNodeId::new(2, 2), + node_env.metadata.nodes_config_version(), + node_env + .metadata + .nodes_config_ref() + .cluster_name() + .to_owned(), + node_env.networking.connection_manager(), + 10, + ) + .await?; + // let node2 receive messages and use the same message handler as node1 + let (_node_2, _node2_reactor) = node_2 + .process_with_message_handler(&TaskCenter::current(), get_node_state_handler)?; + + let mut appender = bifrost.create_appender(LOG_ID)?; + for i in 1..=20 { + let lsn = appender.append("").await?; + assert_eq!(Lsn::from(i), lsn); + } - node_env - .tc - .clone() - .run_in_scope("test", None, async move { - // simulate a connection from node 2 so we can have a connection between the two - // nodes - let node_2 = MockPeerConnection::connect( - GenerationalNodeId::new(2, 2), - node_env.metadata.nodes_config_version(), - node_env - .metadata - .nodes_config_ref() - .cluster_name() - .to_owned(), - node_env.networking.connection_manager(), - 10, - ) - .await?; - // let node2 receive messages and use the same message handler as node1 - let (_node_2, _node2_reactor) = - node_2.process_with_message_handler(&node_env.tc, get_node_state_handler)?; - - let mut appender = bifrost.create_appender(LOG_ID)?; - for i in 1..=20 { - let lsn = appender.append("").await?; - assert_eq!(Lsn::from(i), lsn); - } - - tokio::time::sleep(interval_duration * 10).await; - - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + tokio::time::sleep(interval_duration * 10).await; - // report persisted lsn back to cluster controller - persisted_lsn.store(6, Ordering::Relaxed); + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - tokio::time::sleep(interval_duration * 10).await; - // we delete 1-6. 
- assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); + // report persisted lsn back to cluster controller + persisted_lsn.store(6, Ordering::Relaxed); - // increase by 4 more, this should not overcome the threshold - persisted_lsn.store(10, Ordering::Relaxed); + tokio::time::sleep(interval_duration * 10).await; + // we delete 1-6. + assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); + // increase by 4 more, this should not overcome the threshold + persisted_lsn.store(10, Ordering::Relaxed); - // now we have reached the min threshold wrt to the last trim point - persisted_lsn.store(11, Ordering::Relaxed); + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::from(6), bifrost.get_trim_point(LOG_ID).await?); - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::from(11), bifrost.get_trim_point(LOG_ID).await?); + // now we have reached the min threshold wrt to the last trim point + persisted_lsn.store(11, Ordering::Relaxed); - Ok::<(), anyhow::Error>(()) - }) - .await?; + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::from(11), bifrost.get_trim_point(LOG_ID).await?); - Ok(()) + Ok::<(), anyhow::Error>(()) + } + .in_tc(&tc) + .await } #[test(tokio::test(start_paused = true))] @@ -699,58 +695,55 @@ mod tests { }) .await?; - node_env - .tc - .clone() - .run_in_scope("test", None, async move { - // simulate a connection from node 2 so we can have a connection between the two - // nodes - let node_2 = MockPeerConnection::connect( - GenerationalNodeId::new(2, 2), - node_env.metadata.nodes_config_version(), - node_env - .metadata - .nodes_config_ref() - .cluster_name() - .to_owned(), - node_env.networking.connection_manager(), - 10, - ) - .await?; - // let node2 receive messages and use the same message handler as node1 - let (_node_2, _node2_reactor) = - node_2.process_with_message_handler(&node_env.tc, get_node_state_handler)?; - - let mut appender = bifrost.create_appender(LOG_ID)?; - for i in 1..=20 { - let lsn = appender.append(format!("record{}", i)).await?; - assert_eq!(Lsn::from(i), lsn); - } - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - - // report persisted lsn back to cluster controller - persisted_lsn.store(3, Ordering::Relaxed); + let tc = node_env.tc.clone(); + async move { + // simulate a connection from node 2 so we can have a connection between the two + // nodes + let node_2 = MockPeerConnection::connect( + GenerationalNodeId::new(2, 2), + node_env.metadata.nodes_config_version(), + node_env + .metadata + .nodes_config_ref() + .cluster_name() + .to_owned(), + node_env.networking.connection_manager(), + 10, + ) + .await?; + // let node2 receive messages and use the same message handler as node1 + let (_node_2, _node2_reactor) = + node_2.process_with_message_handler(&node_env.tc, get_node_state_handler)?; + + let mut appender = bifrost.create_appender(LOG_ID)?; + for i in 1..=20 { + let lsn = appender.append(format!("record{}", i)).await?; + assert_eq!(Lsn::from(i), lsn); + } + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - tokio::time::sleep(interval_duration * 10).await; - // everything before the persisted_lsn. 
- assert_eq!(bifrost.get_trim_point(LOG_ID).await?, Lsn::from(3)); - // we should be able to after the last persisted lsn - let v = bifrost.read(LOG_ID, Lsn::from(4)).await?.unwrap(); - assert_that!(v.sequence_number(), eq(Lsn::new(4))); - assert!(v.is_data_record()); - assert_that!(v.decode_unchecked::(), eq("record4".to_owned())); + // report persisted lsn back to cluster controller + persisted_lsn.store(3, Ordering::Relaxed); - persisted_lsn.store(20, Ordering::Relaxed); + tokio::time::sleep(interval_duration * 10).await; + // everything before the persisted_lsn. + assert_eq!(bifrost.get_trim_point(LOG_ID).await?, Lsn::from(3)); + // we should be able to after the last persisted lsn + let v = bifrost.read(LOG_ID, Lsn::from(4)).await?.unwrap(); + assert_that!(v.sequence_number(), eq(Lsn::new(4))); + assert!(v.is_data_record()); + assert_that!(v.decode_unchecked::(), eq("record4".to_owned())); - tokio::time::sleep(interval_duration * 10).await; - assert_eq!(Lsn::from(20), bifrost.get_trim_point(LOG_ID).await?); + persisted_lsn.store(20, Ordering::Relaxed); - Ok::<(), anyhow::Error>(()) - }) - .await?; + tokio::time::sleep(interval_duration * 10).await; + assert_eq!(Lsn::from(20), bifrost.get_trim_point(LOG_ID).await?); - Ok(()) + Ok::<(), anyhow::Error>(()) + } + .in_tc(&tc) + .await } #[test(tokio::test(start_paused = true))] @@ -788,25 +781,24 @@ mod tests { }) .await?; - node_env - .tc - .run_in_scope("test", None, async move { - let mut appender = bifrost.create_appender(LOG_ID)?; - for i in 1..=5 { - let lsn = appender.append(format!("record{}", i)).await?; - assert_eq!(Lsn::from(i), lsn); - } + async move { + let mut appender = bifrost.create_appender(LOG_ID)?; + for i in 1..=5 { + let lsn = appender.append(format!("record{}", i)).await?; + assert_eq!(Lsn::from(i), lsn); + } - // report persisted lsn back to cluster controller for a subset of the nodes - persisted_lsn.store(5, Ordering::Relaxed); + // report persisted lsn back to cluster controller for a subset of the nodes + persisted_lsn.store(5, Ordering::Relaxed); - tokio::time::sleep(interval_duration * 10).await; - // no trimming should have happened because one node did not report the persisted lsn - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + tokio::time::sleep(interval_duration * 10).await; + // no trimming should have happened because one node did not report the persisted lsn + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - Ok::<(), anyhow::Error>(()) - }) - .await?; + Ok::<(), anyhow::Error>(()) + } + .in_tc(&node_env.tc) + .await?; Ok(()) } @@ -819,53 +811,54 @@ mod tests { F: FnMut(TestCoreEnvBuilder) -> TestCoreEnvBuilder, { let mut builder = TestCoreEnvBuilder::with_incoming_only_connector(); - let metadata = builder.metadata.clone(); - - let bifrost_svc = BifrostService::new(builder.tc.clone(), metadata.clone()) - .with_factory(memory_loglet::Factory::default()); - let bifrost = bifrost_svc.handle(); - - let mut server_builder = NetworkServerBuilder::default(); - - let svc = Service::new( - Live::from_value(config), - HealthStatus::default(), - bifrost.clone(), - builder.metadata.clone(), - builder.networking.clone(), - &mut builder.router_builder, - &mut server_builder, - builder.metadata_writer.clone(), - builder.metadata_store_client.clone(), - ); + let tc = builder.tc.clone(); + async { + let bifrost_svc = BifrostService::new().with_factory(memory_loglet::Factory::default()); + let bifrost = bifrost_svc.handle(); - let mut nodes_config = 
NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); - nodes_config.upsert_node(NodeConfig::new( - "node-1".to_owned(), - GenerationalNodeId::new(1, 1), - AdvertisedAddress::Uds("foobar".into()), - Role::Worker.into(), - LogServerConfig::default(), - )); - nodes_config.upsert_node(NodeConfig::new( - "node-2".to_owned(), - GenerationalNodeId::new(2, 2), - AdvertisedAddress::Uds("bar".into()), - Role::Worker.into(), - LogServerConfig::default(), - )); - let builder = modify_builder(builder.set_nodes_config(nodes_config)); - - let node_env = builder.build().await; - bifrost_svc.start().await?; - - node_env.tc.spawn( - TaskKind::SystemService, - "cluster-controller", - None, - svc.run(), - )?; + let mut server_builder = NetworkServerBuilder::default(); - Ok((node_env, bifrost)) + let svc = Service::new( + Live::from_value(config), + HealthStatus::default(), + bifrost.clone(), + builder.metadata.clone(), + builder.networking.clone(), + &mut builder.router_builder, + &mut server_builder, + builder.metadata_writer.clone(), + builder.metadata_store_client.clone(), + ); + + let mut nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned()); + nodes_config.upsert_node(NodeConfig::new( + "node-1".to_owned(), + GenerationalNodeId::new(1, 1), + AdvertisedAddress::Uds("foobar".into()), + Role::Worker.into(), + LogServerConfig::default(), + )); + nodes_config.upsert_node(NodeConfig::new( + "node-2".to_owned(), + GenerationalNodeId::new(2, 2), + AdvertisedAddress::Uds("bar".into()), + Role::Worker.into(), + LogServerConfig::default(), + )); + let builder = modify_builder(builder.set_nodes_config(nodes_config)); + + let node_env = builder.build().await; + bifrost_svc.start().await?; + + node_env.tc.spawn( + TaskKind::SystemService, + "cluster-controller", + None, + svc.run(), + )?; + Ok((node_env, bifrost)) + } + .in_tc(&tc) + .await } } diff --git a/crates/bifrost/benches/append_throughput.rs b/crates/bifrost/benches/append_throughput.rs index 6a2d5ed43..679c3f9c4 100644 --- a/crates/bifrost/benches/append_throughput.rs +++ b/crates/bifrost/benches/append_throughput.rs @@ -8,22 +8,23 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
+mod util; + use std::ops::Range; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use futures::stream::{FuturesOrdered, FuturesUnordered}; use futures::StreamExt; +use tracing::info; +use tracing_subscriber::EnvFilter; + use restate_bifrost::{Bifrost, BifrostService}; -use restate_core::metadata; use restate_rocksdb::{DbName, RocksDbManager}; use restate_types::config::{ BifrostOptionsBuilder, CommonOptionsBuilder, ConfigurationBuilder, LocalLogletOptionsBuilder, }; use restate_types::live::Live; use restate_types::logs::LogId; -use tracing::info; -use tracing_subscriber::EnvFilter; -mod util; async fn append_records_multi_log(bifrost: Bifrost, log_id_range: Range, count_per_log: u64) { let mut appends = FuturesUnordered::new(); @@ -105,9 +106,7 @@ fn write_throughput_local_loglet(c: &mut Criterion) { )); let bifrost = tc.block_on(async { - let metadata = metadata(); - let bifrost_svc = BifrostService::new(restate_core::task_center(), metadata) - .enable_local_loglet(&Live::from_value(config)); + let bifrost_svc = BifrostService::new().enable_local_loglet(&Live::from_value(config)); let bifrost = bifrost_svc.handle(); // start bifrost service in the background diff --git a/crates/bifrost/src/appender.rs b/crates/bifrost/src/appender.rs index 4d4ca98a0..65feaa37b 100644 --- a/crates/bifrost/src/appender.rs +++ b/crates/bifrost/src/appender.rs @@ -11,14 +11,15 @@ use std::sync::Arc; use std::time::Instant; -use restate_types::storage::StorageEncode; use tracing::{debug, info, instrument, warn}; +use restate_core::Metadata; use restate_types::config::Configuration; use restate_types::live::Live; use restate_types::logs::metadata::SegmentIndex; use restate_types::logs::{LogId, Lsn, Record}; use restate_types::retries::RetryIter; +use restate_types::storage::StorageEncode; use crate::bifrost::BifrostInner; use crate::loglet::AppendError; @@ -167,8 +168,9 @@ impl Appender { ); return Ok(loglet); } else { + let log_version = Metadata::with_current(|m| m.logs_version()); debug!( - log_version = %bifrost_inner.metadata.logs_version(), + log_version = %log_version, "Still waiting for sealing to complete. 
Elapsed={:?}", start.elapsed(), ); diff --git a/crates/bifrost/src/bifrost.rs b/crates/bifrost/src/bifrost.rs index a810776cf..beeffa185 100644 --- a/crates/bifrost/src/bifrost.rs +++ b/crates/bifrost/src/bifrost.rs @@ -13,13 +13,12 @@ use std::sync::Arc; use std::sync::OnceLock; use enum_map::EnumMap; +use tracing::instrument; use restate_core::{Metadata, MetadataKind, TargetVersion}; use restate_types::logs::metadata::{MaybeSegment, ProviderKind, Segment}; use restate_types::logs::{KeyFilter, LogId, Lsn, SequenceNumber, TailState}; use restate_types::storage::StorageEncode; -use restate_types::Version; -use tracing::instrument; use crate::appender::Appender; use crate::background_appender::BackgroundAppender; @@ -44,21 +43,20 @@ impl Bifrost { } #[cfg(any(test, feature = "test-util"))] - pub async fn init_in_memory(metadata: Metadata) -> Self { + pub async fn init_in_memory() -> Self { use crate::providers::memory_loglet; - Self::init_with_factory(metadata, memory_loglet::Factory::default()).await + Self::init_with_factory(memory_loglet::Factory::default()).await } #[cfg(any(test, feature = "test-util"))] - pub async fn init_local(metadata: Metadata) -> Self { + pub async fn init_local() -> Self { use restate_types::config::Configuration; use crate::BifrostService; let config = Configuration::updateable(); - let bifrost_svc = - BifrostService::new(restate_core::task_center(), metadata).enable_local_loglet(&config); + let bifrost_svc = BifrostService::new().enable_local_loglet(&config); let bifrost = bifrost_svc.handle(); // start bifrost service in the background @@ -70,14 +68,10 @@ impl Bifrost { } #[cfg(any(test, feature = "test-util"))] - pub async fn init_with_factory( - metadata: Metadata, - factory: impl crate::loglet::LogletProviderFactory, - ) -> Self { + pub async fn init_with_factory(factory: impl crate::loglet::LogletProviderFactory) -> Self { use crate::BifrostService; - let bifrost_svc = - BifrostService::new(restate_core::task_center(), metadata).with_factory(factory); + let bifrost_svc = BifrostService::new().with_factory(factory); let bifrost = bifrost_svc.handle(); // start bifrost service in the background @@ -228,11 +222,6 @@ impl Bifrost { self.inner.get_trim_point(log_id).await } - /// The version of the currently loaded logs metadata - pub fn version(&self) -> Version { - self.inner.metadata.logs_version() - } - /// Read a full log with the given id. To be used only in tests!!! #[cfg(any(test, feature = "test-util"))] pub async fn read_all(&self, log_id: LogId) -> Result> { @@ -262,7 +251,6 @@ static_assertions::assert_impl_all!(Bifrost: Send, Sync, Clone); // Locks in this data-structure are held for very short time and should never be // held across an async boundary. pub struct BifrostInner { - pub(crate) metadata: Metadata, #[allow(unused)] watchdog: WatchdogSender, // Initialized after BifrostService::start completes. 
@@ -271,9 +259,8 @@ pub struct BifrostInner { } impl BifrostInner { - pub fn new(metadata: Metadata, watchdog: WatchdogSender) -> Self { + pub fn new(watchdog: WatchdogSender) -> Self { Self { - metadata, watchdog, providers: Default::default(), shutting_down: AtomicBool::new(false), @@ -338,7 +325,7 @@ impl BifrostInner { } async fn get_trim_point(&self, log_id: LogId) -> Result { - let log_metadata = self.metadata.logs_ref(); + let log_metadata = Metadata::with_current(|m| m.logs_ref()); let log_chain = log_metadata .chain(&log_id) @@ -366,7 +353,7 @@ impl BifrostInner { } pub async fn trim(&self, log_id: LogId, trim_point: Lsn) -> Result<(), Error> { - let log_metadata = self.metadata.logs_ref(); + let log_metadata = Metadata::with_current(|m| m.logs_ref()); let log_chain = log_metadata .chain(&log_id) @@ -396,7 +383,7 @@ impl BifrostInner { /// Immediately fetch new metadata from metadata store. pub async fn sync_metadata(&self) -> Result<()> { - self.metadata + Metadata::current() .sync(MetadataKind::Logs, TargetVersion::Latest) .await?; Ok(()) @@ -419,7 +406,7 @@ impl BifrostInner { /// Checks if the log_id exists and that the provider is not disabled (can be created). pub(crate) fn check_log_id(&self, log_id: LogId) -> Result<(), Error> { - let logs = self.metadata.logs_ref(); + let logs = Metadata::with_current(|metadata| metadata.logs_ref()); let chain = logs.chain(&log_id).ok_or(Error::UnknownLogId(log_id))?; let kind = chain.tail().config.kind; @@ -429,7 +416,7 @@ impl BifrostInner { } pub async fn writeable_loglet(&self, log_id: LogId) -> Result { - let log_metadata = self.metadata.logs_ref(); + let log_metadata = Metadata::with_current(|metadata| metadata.logs_ref()); let tail_segment = log_metadata .chain(&log_id) .ok_or(Error::UnknownLogId(log_id))? @@ -438,7 +425,7 @@ impl BifrostInner { } pub async fn find_loglet_for_lsn(&self, log_id: LogId, lsn: Lsn) -> Result { - let log_metadata = self.metadata.logs_ref(); + let log_metadata = Metadata::with_current(|metadata| metadata.logs_ref()); let maybe_segment = log_metadata .chain(&log_id) .ok_or(Error::UnknownLogId(log_id))? 
@@ -509,8 +496,8 @@ mod tests { use tracing::info; use tracing_test::traced_test; - use restate_core::{metadata, TaskKind, TestCoreEnv}; - use restate_core::{task_center, TestCoreEnvBuilder}; + use restate_core::TestCoreEnvBuilder; + use restate_core::{TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnv}; use restate_rocksdb::RocksDbManager; use restate_types::config::CommonOptions; use restate_types::live::Constant; @@ -518,7 +505,7 @@ mod tests { use restate_types::logs::SequenceNumber; use restate_types::metadata_store::keys::BIFROST_CONFIG_KEY; use restate_types::partition_table::PartitionTable; - use restate_types::Versioned; + use restate_types::{Version, Versioned}; use crate::providers::memory_loglet::{self}; use crate::BifrostAdmin; @@ -534,9 +521,8 @@ mod tests { )) .build() .await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { - let bifrost = Bifrost::init_in_memory(metadata()).await; + async { + let bifrost = Bifrost::init_in_memory().await; let clean_bifrost_clone = bifrost.clone(); @@ -588,7 +574,7 @@ mod tests { assert_eq!(max_lsn.next(), tail.offset()); // Initiate shutdown - task_center().shutdown_node("completed", 0).await; + TaskCenter::current().shutdown_node("completed", 0).await; // appends cannot succeed after shutdown let res = appender_0.append("").await; assert!(matches!(res, Err(Error::Shutdown(_)))); @@ -596,20 +582,20 @@ mod tests { assert!(logs_contain("Shutting down in-memory loglet provider")); assert!(logs_contain("Bifrost watchdog shutdown complete")); Ok(()) - }) + } + .in_tc(&node_env.tc) .await } #[tokio::test(start_paused = true)] async fn test_lazy_initialization() -> googletest::Result<()> { let node_env = TestCoreEnv::create_with_single_node(1, 1).await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { + async { let delay = Duration::from_secs(5); // This memory provider adds a delay to its loglet initialization, we want // to ensure that appends do not fail while waiting for the loglet; let factory = memory_loglet::Factory::with_init_delay(delay); - let bifrost = Bifrost::init_with_factory(metadata(), factory).await; + let bifrost = Bifrost::init_with_factory(factory).await; let start = tokio::time::Instant::now(); let lsn = bifrost.create_appender(LogId::new(0))?.append("").await?; @@ -617,7 +603,8 @@ mod tests { // The append was properly delayed assert_eq!(delay, start.elapsed()); Ok(()) - }) + } + .in_tc(&node_env.tc) .await } @@ -628,74 +615,73 @@ mod tests { .set_provider_kind(ProviderKind::Local) .build() .await; - node_env - .tc - .run_in_scope("test", None, async { - RocksDbManager::init(Constant::new(CommonOptions::default())); - - let bifrost = Bifrost::init_local(metadata()).await; - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); + async { + RocksDbManager::init(Constant::new(CommonOptions::default())); - assert_eq!(Lsn::OLDEST, bifrost.find_tail(LOG_ID).await?.offset()); + let bifrost = Bifrost::init_local().await; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + assert_eq!(Lsn::OLDEST, bifrost.find_tail(LOG_ID).await?.offset()); - let mut appender = bifrost.create_appender(LOG_ID)?; - // append 10 records - for _ in 1..=10 { - appender.append("").await?; - } + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - bifrost_admin.trim(LOG_ID, Lsn::from(5)).await?; + let mut 
appender = bifrost.create_appender(LOG_ID)?; + // append 10 records + for _ in 1..=10 { + appender.append("").await?; + } - let tail = bifrost.find_tail(LOG_ID).await?; - assert_eq!(tail.offset(), Lsn::from(11)); - assert!(!tail.is_sealed()); - assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); + bifrost_admin.trim(LOG_ID, Lsn::from(5)).await?; - // 5 itself is trimmed - for lsn in 1..=5 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + let tail = bifrost.find_tail(LOG_ID).await?; + assert_eq!(tail.offset(), Lsn::from(11)); + assert!(!tail.is_sealed()); + assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); - } + // 5 itself is trimmed + for lsn in 1..=5 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - for lsn in 6..=10 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - } + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); + } - // trimming beyond the release point will fall back to the release point - bifrost_admin.trim(LOG_ID, Lsn::MAX).await?; + for lsn in 6..=10 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); + } - assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); - let new_trim_point = bifrost.get_trim_point(LOG_ID).await?; - assert_eq!(Lsn::from(10), new_trim_point); + // trimming beyond the release point will fall back to the release point + bifrost_admin.trim(LOG_ID, Lsn::MAX).await?; - let record = bifrost.read(LOG_ID, Lsn::from(10)).await?.unwrap(); - assert!(record.is_trim_gap()); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); + assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); + let new_trim_point = bifrost.get_trim_point(LOG_ID).await?; + assert_eq!(Lsn::from(10), new_trim_point); - // Add 10 more records - for _ in 0..10 { - appender.append("").await?; - } + let record = bifrost.read(LOG_ID, Lsn::from(10)).await?.unwrap(); + assert!(record.is_trim_gap()); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); - for lsn in 11..20 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - } + // Add 10 more records + for _ in 0..10 { + appender.append("").await?; + } - Ok(()) - }) - .await + for lsn in 11..20 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); + } + + Ok(()) + } + .in_tc(&node_env.tc) + .await } #[tokio::test(start_paused = true)] @@ -708,9 +694,8 @@ mod tests { )) .build() .await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { - let bifrost = Bifrost::init_in_memory(metadata()).await; + async { + let bifrost = Bifrost::init_in_memory().await; let bifrost_admin = BifrostAdmin::new( &bifrost, &node_env.metadata_writer, @@ -759,9 +744,10 @@ mod tests { .await .is_err()); - let old_version = bifrost.inner.metadata.logs_version(); + let metadata = Metadata::current(); + let old_version = 
metadata.logs_version(); - let mut builder = bifrost.inner.metadata.logs_ref().clone().into_builder(); + let mut builder = metadata.logs_ref().clone().into_builder(); let mut chain_builder = builder.chain(LOG_ID).unwrap(); assert_eq!(1, chain_builder.num_segments()); let new_segment_params = new_single_node_loglet_params(ProviderKind::InMemory); @@ -785,14 +771,14 @@ mod tests { .await?; // make sure we have updated metadata. - metadata() + metadata .sync(MetadataKind::Logs, TargetVersion::Latest) .await?; - assert_eq!(new_version, bifrost.inner.metadata.logs_version()); + assert_eq!(new_version, metadata.logs_version()); { // validate that the stored metadata matches our expectations. - let new_metadata = bifrost.inner.metadata.logs_ref().clone(); + let new_metadata = metadata.logs_ref().clone(); let chain_builder = new_metadata.chain(&LOG_ID).unwrap(); assert_eq!(2, chain_builder.num_segments()); } @@ -887,7 +873,8 @@ mod tests { assert!(bifrost.read(LOG_ID, Lsn::new(8)).await?.is_none()); Ok(()) - }) + } + .in_tc(&node_env.tc) .await } @@ -903,10 +890,9 @@ mod tests { .set_provider_kind(ProviderKind::Local) .build() .await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { + async { RocksDbManager::init(Constant::new(CommonOptions::default())); - let bifrost = Bifrost::init_local(metadata()).await; + let bifrost = Bifrost::init_local().await; let bifrost_admin = BifrostAdmin::new( &bifrost, &node_env.metadata_writer, @@ -916,7 +902,7 @@ mod tests { // create an appender let stop_signal = Arc::new(AtomicBool::default()); let append_counter = Arc::new(AtomicUsize::new(0)); - let _ = tc.spawn(TaskKind::TestRunner, "append-records", None, { + let _ = TaskCenter::current().spawn(TaskKind::TestRunner, "append-records", None, { let append_counter = append_counter.clone(); let stop_signal = stop_signal.clone(); let bifrost = bifrost.clone(); @@ -1011,9 +997,11 @@ mod tests { } googletest::Result::Ok(()) - }) + } + .in_tc(&node_env.tc) .await?; - tc.shutdown_node("test completed", 0).await; + + node_env.tc.shutdown_node("test completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } diff --git a/crates/bifrost/src/bifrost_admin.rs b/crates/bifrost/src/bifrost_admin.rs index 8e67626fd..dd239de6d 100644 --- a/crates/bifrost/src/bifrost_admin.rs +++ b/crates/bifrost/src/bifrost_admin.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use tracing::{info, instrument}; -use restate_core::{MetadataKind, MetadataWriter}; +use restate_core::{Metadata, MetadataKind, MetadataWriter}; use restate_metadata_store::MetadataStoreClient; use restate_types::config::Configuration; use restate_types::logs::builder::BuilderError; @@ -93,22 +93,13 @@ impl<'a> BifrostAdmin<'a> { params: LogletParams, ) -> Result { self.bifrost.inner.fail_if_shutting_down()?; - let _ = self - .bifrost - .inner - .metadata + let metadata = Metadata::current(); + let _ = metadata .wait_for_version(MetadataKind::Logs, min_version) .await?; let segment_index = segment_index - .or_else(|| { - self.bifrost - .inner - .metadata - .logs_ref() - .chain(&log_id) - .map(|c| c.tail_index()) - }) + .or_else(|| metadata.logs_ref().chain(&log_id).map(|c| c.tail_index())) .ok_or(Error::UnknownLogId(log_id))?; let sealed_segment = self.seal(log_id, segment_index).await?; diff --git a/crates/bifrost/src/read_stream.rs b/crates/bifrost/src/read_stream.rs index 3ec2967ed..1c7ce3eef 100644 --- a/crates/bifrost/src/read_stream.rs +++ b/crates/bifrost/src/read_stream.rs @@ -21,6 +21,7 @@ use futures::Stream; use futures::StreamExt; use 
pin_project::pin_project; +use restate_core::Metadata; use restate_core::MetadataKind; use restate_core::ShutdownError; use restate_types::logs::metadata::MaybeSegment; @@ -348,7 +349,7 @@ impl Stream for LogReadStream { panic!("substream must be set at this point"); }; - let log_metadata = bifrost_inner.metadata.logs_ref(); + let log_metadata = Metadata::with_current(|metadata| metadata.logs_ref()); // The log is gone! let Some(chain) = log_metadata.chain(this.log_id) else { @@ -400,11 +401,11 @@ impl Stream for LogReadStream { let metadata_version = log_metadata.version(); // No hope at this metadata version, wait for the next update. - let metadata_watch_fut = Box::pin( - bifrost_inner - .metadata - .wait_for_version(MetadataKind::Logs, metadata_version.next()), - ); + let metadata_watch_fut = Box::pin(async move { + Metadata::current() + .wait_for_version(MetadataKind::Logs, metadata_version.next()) + .await + }); log_metadata_watch_fut.set(Some(metadata_watch_fut)); continue; } @@ -445,7 +446,7 @@ mod tests { use tracing_test::traced_test; use restate_core::{ - metadata, task_center, MetadataKind, TargetVersion, TaskKind, TestCoreEnvBuilder, + MetadataKind, TargetVersion, TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnvBuilder, }; use restate_rocksdb::RocksDbManager; use restate_types::config::{CommonOptions, Configuration}; @@ -468,14 +469,13 @@ mod tests { .build() .await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { + async { let read_from = Lsn::from(6); let config = Live::from_value(Configuration::default()); RocksDbManager::init(Constant::new(CommonOptions::default())); - let svc = BifrostService::new(task_center(), metadata()).enable_local_loglet(&config); + let svc = BifrostService::new().enable_local_loglet(&config); let bifrost = svc.handle(); svc.start().await.expect("loglet must start"); @@ -494,22 +494,28 @@ mod tests { let read_counter = Arc::new(AtomicUsize::new(0)); // spawn a reader that reads 5 records and exits. 
let counter_clone = read_counter.clone(); - let id = tc.spawn(TaskKind::TestRunner, "read-records", None, async move { - for i in 6..=10 { - let record = reader.next().await.expect("to never terminate")?; - let expected_lsn = Lsn::from(i); - counter_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - assert_that!(record.sequence_number(), eq(expected_lsn)); - assert_that!(reader.read_pointer(), ge(record.sequence_number())); - assert_that!( - record.decode_unchecked::(), - eq(format!("record{}", expected_lsn)) - ); - } - Ok(()) - })?; + let id = TaskCenter::current().spawn( + TaskKind::TestRunner, + "read-records", + None, + async move { + for i in 6..=10 { + let record = reader.next().await.expect("to never terminate")?; + let expected_lsn = Lsn::from(i); + counter_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + assert_that!(record.sequence_number(), eq(expected_lsn)); + assert_that!(reader.read_pointer(), ge(record.sequence_number())); + assert_that!( + record.decode_unchecked::(), + eq(format!("record{}", expected_lsn)) + ); + } + Ok(()) + }, + )?; - let reader_bg_handle = tc.take_task(id).expect("read-records task to exist"); + let reader_bg_handle = TaskCenter::with_current(|tc| tc.take_task(id)) + .expect("read-records task to exist"); tokio::task::yield_now().await; // Not finished, we still didn't append records @@ -538,7 +544,8 @@ mod tests { assert_eq!(5, read_counter.load(std::sync::atomic::Ordering::Relaxed)); anyhow::Ok(()) - }) + } + .in_tc(&node_env.tc) .await?; Ok(()) } @@ -553,94 +560,92 @@ mod tests { .set_provider_kind(ProviderKind::Local) .build() .await; - node_env - .tc - .run_in_scope("test", None, async { - let config = Live::from_value(Configuration::default()); - RocksDbManager::init(Constant::new(CommonOptions::default())); - - let svc = - BifrostService::new(task_center(), metadata()).enable_local_loglet(&config); - let bifrost = svc.handle(); - - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); - svc.start().await.expect("loglet must start"); + async { + let config = Live::from_value(Configuration::default()); + RocksDbManager::init(Constant::new(CommonOptions::default())); + + let svc = BifrostService::new().enable_local_loglet(&config); + let bifrost = svc.handle(); - let mut appender = bifrost.create_appender(LOG_ID)?; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); + svc.start().await.expect("loglet must start"); - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + let mut appender = bifrost.create_appender(LOG_ID)?; - // append 10 records [1..10] - for i in 1..=10 { - let lsn = appender.append("").await?; - assert_eq!(Lsn::from(i), lsn); - } + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + + // append 10 records [1..10] + for i in 1..=10 { + let lsn = appender.append("").await?; + assert_eq!(Lsn::from(i), lsn); + } + + // [1..5] trimmed. trim_point = 5 + bifrost_admin.trim(LOG_ID, Lsn::from(5)).await?; - // [1..5] trimmed. 
trim_point = 5 - bifrost_admin.trim(LOG_ID, Lsn::from(5)).await?; + assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); + assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); - assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); - assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); + let mut read_stream = + bifrost.create_reader(LOG_ID, KeyFilter::Any, Lsn::OLDEST, Lsn::MAX)?; - let mut read_stream = - bifrost.create_reader(LOG_ID, KeyFilter::Any, Lsn::OLDEST, Lsn::MAX)?; + let record = read_stream.next().await.unwrap()?; + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); + for lsn in 6..=7 { let record = read_stream.next().await.unwrap()?; - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); + } + assert!(!read_stream.is_terminated()); + assert_eq!(Lsn::from(8), read_stream.read_pointer()); + + let tail = bifrost.find_tail(LOG_ID).await?.offset(); + // trimming beyond the release point will fall back to the release point + bifrost_admin.trim(LOG_ID, Lsn::from(u64::MAX)).await?; + let trim_point = bifrost.get_trim_point(LOG_ID).await?; + assert_eq!(Lsn::from(10), bifrost.get_trim_point(LOG_ID).await?); + // trim point becomes the point before the next slot available for writes (aka. the + // tail) + assert_eq!(tail.prev(), trim_point); + + // append lsns [11..20] + for i in 11..=20 { + let lsn = appender.append(format!("record{}", i)).await?; + assert_eq!(Lsn::from(i), lsn); + } - for lsn in 6..=7 { - let record = read_stream.next().await.unwrap()?; - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - } - assert!(!read_stream.is_terminated()); - assert_eq!(Lsn::from(8), read_stream.read_pointer()); - - let tail = bifrost.find_tail(LOG_ID).await?.offset(); - // trimming beyond the release point will fall back to the release point - bifrost_admin.trim(LOG_ID, Lsn::from(u64::MAX)).await?; - let trim_point = bifrost.get_trim_point(LOG_ID).await?; - assert_eq!(Lsn::from(10), bifrost.get_trim_point(LOG_ID).await?); - // trim point becomes the point before the next slot available for writes (aka. the - // tail) - assert_eq!(tail.prev(), trim_point); - - // append lsns [11..20] - for i in 11..=20 { - let lsn = appender.append(format!("record{}", i)).await?; - assert_eq!(Lsn::from(i), lsn); - } + // read stream should send a gap from 8->10 + let record = read_stream.next().await.unwrap()?; + assert_that!(record.sequence_number(), eq(Lsn::new(8))); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); + + // read pointer is at 11 + assert_eq!(Lsn::from(11), read_stream.read_pointer()); - // read stream should send a gap from 8->10 + // read the rest of the records + for lsn in 11..=20 { let record = read_stream.next().await.unwrap()?; - assert_that!(record.sequence_number(), eq(Lsn::new(8))); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); - - // read pointer is at 11 - assert_eq!(Lsn::from(11), read_stream.read_pointer()); - - // read the rest of the records - for lsn in 11..=20 { - let record = read_stream.next().await.unwrap()?; - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq(format!("record{}", lsn)) - ); - } - // we are at tail. polling should return pending. 
- let pinned = std::pin::pin!(read_stream.next()); - let next_is_pending = futures::poll!(pinned); - assert!(matches!(next_is_pending, Poll::Pending)); - - Ok(()) - }) - .await + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq(format!("record{}", lsn)) + ); + } + // we are at tail. polling should return pending. + let pinned = std::pin::pin!(read_stream.next()); + let next_is_pending = futures::poll!(pinned); + assert!(matches!(next_is_pending, Poll::Pending)); + + Ok(()) + } + .in_tc(&node_env.tc) + .await } // Note: This test doesn't validate read stream behaviour with zombie records at seal boundary. @@ -654,13 +659,12 @@ mod tests { .build() .await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { + async { let config = Live::from_value(Configuration::default()); RocksDbManager::init(Constant::new(CommonOptions::default())); // enable both in-memory and local loglet types - let svc = BifrostService::new(task_center(), metadata()) + let svc = BifrostService::new() .enable_local_loglet(&config) .enable_in_memory_loglet(); let bifrost = svc.handle(); @@ -742,10 +746,11 @@ mod tests { assert!(tail.is_sealed()); assert_eq!(Lsn::from(11), tail.offset()); + let metadata = Metadata::current(); // perform manual reconfiguration (can be replaced with bifrost reconfiguration API // when it's implemented) - let old_version = bifrost.inner.metadata.logs_version(); - let mut builder = bifrost.inner.metadata.logs_ref().clone().into_builder(); + let old_version = metadata.logs_version(); + let mut builder = metadata.logs_ref().clone().into_builder(); let mut chain_builder = builder.chain(LOG_ID).unwrap(); assert_eq!(1, chain_builder.num_segments()); let new_segment_params = new_single_node_loglet_params(ProviderKind::InMemory); @@ -768,9 +773,7 @@ mod tests { .await?; // make sure we have updated metadata. 
- bifrost - .inner - .metadata + metadata .sync(MetadataKind::Logs, TargetVersion::Latest) .await?; @@ -807,7 +810,8 @@ mod tests { assert_that!(record.decode_unchecked::(), eq("segment-2-1000")); anyhow::Ok(()) - }) + } + .in_tc(&node_env.tc) .await?; Ok(()) } @@ -822,13 +826,12 @@ mod tests { .build() .await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { + async { let config = Live::from_value(Configuration::default()); RocksDbManager::init(Constant::new(CommonOptions::default())); // enable both in-memory and local loglet types - let svc = BifrostService::new(task_center(), metadata()) + let svc = BifrostService::new() .enable_local_loglet(&config) .enable_in_memory_loglet(); let bifrost = svc.handle(); @@ -932,7 +935,8 @@ mod tests { ); anyhow::Ok(()) - }) + } + .in_tc(&node_env.tc) .await?; Ok(()) } @@ -947,22 +951,22 @@ mod tests { .build() .await; - let tc = node_env.tc; - tc.run_in_scope("test", None, async { + async { let config = Live::from_value(Configuration::default()); RocksDbManager::init(Constant::new(CommonOptions::default())); // enable both in-memory and local loglet types - let svc = BifrostService::new(task_center(), metadata()) + let svc = BifrostService::new() .enable_local_loglet(&config) .enable_in_memory_loglet(); let bifrost = svc.handle(); svc.start().await.expect("loglet must start"); let mut appender = bifrost.create_appender(LOG_ID)?; + let metadata = Metadata::current(); // prepare a chain that starts from Lsn 10 (we expect trim from OLDEST -> 9) - let old_version = bifrost.inner.metadata.logs_version(); - let mut builder = bifrost.inner.metadata.logs_ref().clone().into_builder(); + let old_version = metadata.logs_version(); + let mut builder = metadata.logs_ref().clone().into_builder(); let mut chain_builder = builder.chain(LOG_ID).unwrap(); assert_eq!(1, chain_builder.num_segments()); let new_segment_params = new_single_node_loglet_params(ProviderKind::Local); @@ -984,9 +988,7 @@ mod tests { .await?; // make sure we have updated metadata. 
- bifrost - .inner - .metadata + metadata .sync(MetadataKind::Logs, TargetVersion::Latest) .await?; @@ -1016,7 +1018,8 @@ mod tests { } anyhow::Ok(()) - }) + } + .in_tc(&node_env.tc) .await?; Ok(()) } diff --git a/crates/bifrost/src/service.rs b/crates/bifrost/src/service.rs index adafde732..0a1ed2eef 100644 --- a/crates/bifrost/src/service.rs +++ b/crates/bifrost/src/service.rs @@ -13,11 +13,11 @@ use std::sync::Arc; use anyhow::Context; use enum_map::EnumMap; -use restate_types::config::Configuration; -use restate_types::live::Live; use tracing::{debug, error, trace}; -use restate_core::{cancellation_watcher, Metadata, TaskCenter, TaskKind}; +use restate_core::{cancellation_watcher, TaskCenter, TaskCenterFutureExt, TaskKind}; +use restate_types::config::Configuration; +use restate_types::live::Live; use restate_types::logs::metadata::ProviderKind; use crate::bifrost::BifrostInner; @@ -28,26 +28,25 @@ use crate::watchdog::{Watchdog, WatchdogCommand}; use crate::{loglet::LogletProviderFactory, Bifrost}; pub struct BifrostService { - task_center: TaskCenter, inner: Arc, bifrost: Bifrost, watchdog: Watchdog, factories: HashMap>, } +impl Default for BifrostService { + fn default() -> Self { + Self::new() + } +} + impl BifrostService { - pub fn new(task_center: TaskCenter, metadata: Metadata) -> Self { + pub fn new() -> Self { let (watchdog_sender, watchdog_receiver) = tokio::sync::mpsc::unbounded_channel(); - let inner = Arc::new(BifrostInner::new(metadata.clone(), watchdog_sender.clone())); + let inner = Arc::new(BifrostInner::new(watchdog_sender.clone())); let bifrost = Bifrost::new(inner.clone()); - let watchdog = Watchdog::new( - task_center.clone(), - inner.clone(), - watchdog_sender, - watchdog_receiver, - ); + let watchdog = Watchdog::new(inner.clone(), watchdog_sender, watchdog_receiver); Self { - task_center, inner, bifrost, watchdog, @@ -101,10 +100,9 @@ impl BifrostService { let mut tasks = tokio::task::JoinSet::new(); // Start all enabled providers. for (kind, factory) in self.factories { - let tc = self.task_center.clone(); let watchdog = self.watchdog.sender(); - tasks.spawn(async move { - tc.run_in_scope("loglet-provider-start", None, async move { + tasks.spawn( + async move { trace!("Starting loglet provider {}", kind); match factory.create().await { Err(e) => { @@ -124,9 +122,9 @@ impl BifrostService { Ok((kind, provider)) } } - }) - .await - }); + } + .in_current_tc_as_task(TaskKind::LogletProvider, "loglet-provider-start"), + ); } let mut shutdown = std::pin::pin!(cancellation_watcher()); @@ -160,7 +158,7 @@ impl BifrostService { .map_err(|_| anyhow::anyhow!("bifrost must be initialized only once"))?; // We spawn the watchdog as a background long-running task - self.task_center.spawn( + TaskCenter::current().spawn( TaskKind::BifrostBackgroundHighPriority, "bifrost-watchdog", None, diff --git a/crates/bifrost/src/watchdog.rs b/crates/bifrost/src/watchdog.rs index 916ae243d..c6fd1783c 100644 --- a/crates/bifrost/src/watchdog.rs +++ b/crates/bifrost/src/watchdog.rs @@ -30,7 +30,6 @@ type WatchdogReceiver = tokio::sync::mpsc::UnboundedReceiver; /// Tasks are expected to check for the cancellation token when appropriate and finalize their /// work before termination. 
pub struct Watchdog { - task_center: TaskCenter, inner: Arc, sender: WatchdogSender, inbound: WatchdogReceiver, @@ -39,13 +38,11 @@ pub struct Watchdog { impl Watchdog { pub fn new( - task_center: TaskCenter, inner: Arc, sender: WatchdogSender, inbound: WatchdogReceiver, ) -> Self { Self { - task_center, inner, sender, inbound, @@ -57,7 +54,7 @@ impl Watchdog { match cmd { WatchdogCommand::ScheduleMetadataSync => { let bifrost = self.inner.clone(); - let _ = self.task_center.spawn( + let _ = TaskCenter::current().spawn( TaskKind::MetadataBackgroundSync, "bifrost-metadata-sync", None, @@ -73,7 +70,7 @@ impl Watchdog { WatchdogCommand::WatchProvider(provider) => { self.live_providers.push(provider.clone()); // TODO: Convert to a managed background task - let _ = self.task_center.spawn( + let _ = TaskCenter::current().spawn( TaskKind::BifrostBackgroundHighPriority, "bifrost-provider-on-start", None, diff --git a/crates/node/src/lib.rs b/crates/node/src/lib.rs index d4db04716..29ab5e91a 100644 --- a/crates/node/src/lib.rs +++ b/crates/node/src/lib.rs @@ -178,8 +178,7 @@ impl Node { record_cache.clone(), &mut router_builder, ); - let bifrost_svc = BifrostService::new(tc.clone(), metadata.clone()) - .enable_local_loglet(&updateable_config); + let bifrost_svc = BifrostService::new().enable_local_loglet(&updateable_config); #[cfg(feature = "replicated-loglet")] let bifrost_svc = bifrost_svc.with_factory(replicated_loglet_factory); diff --git a/crates/worker/src/partition/cleaner.rs b/crates/worker/src/partition/cleaner.rs index 1121d12ab..0d4485df8 100644 --- a/crates/worker/src/partition/cleaner.rs +++ b/crates/worker/src/partition/cleaner.rs @@ -8,8 +8,15 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
+use std::ops::RangeInclusive; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; + use anyhow::Context; use futures::StreamExt; +use tokio::time::MissedTickBehavior; +use tracing::{debug, instrument, warn}; + use restate_bifrost::Bifrost; use restate_core::cancellation_watcher; use restate_storage_api::invocation_status_table::{ @@ -22,11 +29,6 @@ use restate_types::GenerationalNodeId; use restate_wal_protocol::{ append_envelope_to_bifrost, Command, Destination, Envelope, Header, Source, }; -use std::ops::RangeInclusive; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; -use tokio::time::MissedTickBehavior; -use tracing::{debug, instrument, warn}; pub(super) struct Cleaner { partition_id: PartitionId, @@ -170,7 +172,7 @@ mod tests { use futures::{stream, Stream}; use googletest::prelude::*; - use restate_core::{TaskKind, TestCoreEnvBuilder}; + use restate_core::{TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnvBuilder}; use restate_storage_api::invocation_status_table::{ CompletedInvocation, InFlightInvocationMetadata, InvocationStatus, }; @@ -222,90 +224,88 @@ mod tests { )) .build() .await; - let tc = &env.tc; - let bifrost = tc - .run_in_scope( - "init bifrost", - None, - Bifrost::init_in_memory(env.metadata.clone()), - ) - .await; + async { + let bifrost = Bifrost::init_in_memory().await; - let expired_invocation = - InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); - let not_expired_invocation_1 = - InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); - let not_expired_invocation_2 = - InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); - let not_completed_invocation = - InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); + let expired_invocation = + InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); + let not_expired_invocation_1 = + InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); + let not_expired_invocation_2 = + InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); + let not_completed_invocation = + InvocationId::from_parts(PartitionKey::MIN, InvocationUuid::mock_random()); - let mock_storage = MockInvocationStatusReader(vec![ - ( - expired_invocation, - InvocationStatus::Completed(CompletedInvocation { - completion_retention_duration: Duration::ZERO, - ..CompletedInvocation::mock_neo() - }), - ), - ( - not_expired_invocation_1, - InvocationStatus::Completed(CompletedInvocation { - completion_retention_duration: Duration::MAX, - ..CompletedInvocation::mock_neo() - }), - ), - ( - not_expired_invocation_2, - // Old status invocations are still processed with the cleanup timer in the PP - InvocationStatus::Completed(CompletedInvocation::mock_old()), - ), - ( - not_completed_invocation, - InvocationStatus::Invoked(InFlightInvocationMetadata::mock()), - ), - ]); + let mock_storage = MockInvocationStatusReader(vec![ + ( + expired_invocation, + InvocationStatus::Completed(CompletedInvocation { + completion_retention_duration: Duration::ZERO, + ..CompletedInvocation::mock_neo() + }), + ), + ( + not_expired_invocation_1, + InvocationStatus::Completed(CompletedInvocation { + completion_retention_duration: Duration::MAX, + ..CompletedInvocation::mock_neo() + }), + ), + ( + not_expired_invocation_2, + // Old status invocations are still processed with the cleanup timer in the PP + InvocationStatus::Completed(CompletedInvocation::mock_old()), + ), + ( + not_completed_invocation, + 
InvocationStatus::Invoked(InFlightInvocationMetadata::mock()), + ), + ]); - tc.spawn( - TaskKind::Cleaner, - "cleaner", - Some(PartitionId::MIN), - Cleaner::new( - PartitionId::MIN, - LeaderEpoch::INITIAL, - GenerationalNodeId::new(1, 1), - mock_storage, - bifrost.clone(), - RangeInclusive::new(PartitionKey::MIN, PartitionKey::MAX), - Duration::from_secs(1), - ) - .run(), - ) - .unwrap(); + TaskCenter::current() + .spawn( + TaskKind::Cleaner, + "cleaner", + Some(PartitionId::MIN), + Cleaner::new( + PartitionId::MIN, + LeaderEpoch::INITIAL, + GenerationalNodeId::new(1, 1), + mock_storage, + bifrost.clone(), + RangeInclusive::new(PartitionKey::MIN, PartitionKey::MAX), + Duration::from_secs(1), + ) + .run(), + ) + .unwrap(); - // By yielding once we let the cleaner task run, and perform the cleanup - tokio::task::yield_now().await; + // By yielding once we let the cleaner task run, and perform the cleanup + tokio::task::yield_now().await; - // All the invocation ids were created with same partition keys, hence same partition id. - let partition_id = env - .metadata - .partition_table_snapshot() - .find_partition_id(expired_invocation.partition_key()) - .unwrap(); + // All the invocation ids were created with same partition keys, hence same partition id. + let partition_id = env + .metadata + .partition_table_snapshot() + .find_partition_id(expired_invocation.partition_key()) + .unwrap(); - let mut log_entries = bifrost.read_all(partition_id.into()).await.unwrap(); - let bifrost_message = log_entries - .remove(0) - .try_decode::() - .unwrap() - .unwrap(); + let mut log_entries = bifrost.read_all(partition_id.into()).await.unwrap(); + let bifrost_message = log_entries + .remove(0) + .try_decode::() + .unwrap() + .unwrap(); - assert_that!( - bifrost_message.command, - pat!(Command::PurgeInvocation(pat!(PurgeInvocationRequest { - invocation_id: eq(expired_invocation) - }))) - ); - assert_that!(log_entries, empty()); + assert_that!( + bifrost_message.command, + pat!(Command::PurgeInvocation(pat!(PurgeInvocationRequest { + invocation_id: eq(expired_invocation) + }))) + ); + assert_that!(log_entries, empty()); + } + .in_tc(&env.tc) + .await; } } diff --git a/crates/worker/src/partition/leadership.rs b/crates/worker/src/partition/leadership.rs index f18d7da82..b84f8e950 100644 --- a/crates/worker/src/partition/leadership.rs +++ b/crates/worker/src/partition/leadership.rs @@ -1099,7 +1099,7 @@ mod tests { use crate::partition::leadership::{LeadershipState, PartitionProcessorMetadata, State}; use assert2::let_assert; use restate_bifrost::Bifrost; - use restate_core::{task_center, TestCoreEnv}; + use restate_core::{task_center, TaskCenterFutureExt, TestCoreEnv}; use restate_invoker_api::test_util::MockInvokerHandle; use restate_partition_store::{OpenMode, PartitionStoreManager}; use restate_rocksdb::RocksDbManager; @@ -1124,21 +1124,14 @@ mod tests { #[test(tokio::test)] async fn become_leader_then_step_down() -> googletest::Result<()> { let env = TestCoreEnv::create_with_single_node(0, 0).await; - let tc = env.tc.clone(); - let storage_options = StorageOptions::default(); - let rocksdb_options = RocksDbOptions::default(); + async { + let storage_options = StorageOptions::default(); + let rocksdb_options = RocksDbOptions::default(); - tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); + RocksDbManager::init(Constant::new(CommonOptions::default())); - let bifrost = tc - .run_in_scope( - "init bifrost", - None, - Bifrost::init_in_memory(env.metadata.clone()), - ) - 
.await; + let bifrost = Bifrost::init_in_memory().await; - tc.run_in_scope("test", None, async { let partition_store_manager = PartitionStoreManager::create( Constant::new(storage_options.clone()).boxed(), Constant::new(rocksdb_options.clone()).boxed(), @@ -1203,10 +1196,11 @@ mod tests { assert!(matches!(state.state, State::Follower)); googletest::Result::Ok(()) - }) + } + .in_tc(&env.tc) .await?; - tc.shutdown_node("test_completed", 0).await; + env.tc.shutdown_node("test_completed", 0).await; RocksDbManager::get().shutdown().await; Ok(()) } diff --git a/crates/worker/src/partition/shuffle.rs b/crates/worker/src/partition/shuffle.rs index 1ca1d1e2b..08d2283c2 100644 --- a/crates/worker/src/partition/shuffle.rs +++ b/crates/worker/src/partition/shuffle.rs @@ -645,7 +645,6 @@ mod tests { )) .build() .await; - let tc = &env.tc; let metadata = ShuffleMetadata::new( PartitionId::from(0), LeaderEpoch::from(0), @@ -654,13 +653,7 @@ mod tests { let (truncation_tx, _truncation_rx) = mpsc::channel(1); - let bifrost = tc - .run_in_scope( - "init bifrost", - None, - Bifrost::init_in_memory(env.metadata.clone()), - ) - .await; + let bifrost = Bifrost::init_in_memory().in_tc(&env.tc).await; let shuffle = Shuffle::new(metadata, outbox_reader, truncation_tx, 1, bifrost.clone()); ShuffleEnv { diff --git a/crates/worker/src/partition_processor_manager/mod.rs b/crates/worker/src/partition_processor_manager/mod.rs index f528de7e6..0720d7af2 100644 --- a/crates/worker/src/partition_processor_manager/mod.rs +++ b/crates/worker/src/partition_processor_manager/mod.rs @@ -912,7 +912,7 @@ mod tests { use restate_bifrost::providers::memory_loglet; use restate_bifrost::BifrostService; use restate_core::network::MockPeerConnection; - use restate_core::{TaskKind, TestCoreEnvBuilder}; + use restate_core::{TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnvBuilder}; use restate_partition_store::PartitionStoreManager; use restate_rocksdb::RocksDbManager; use restate_types::config::{CommonOptions, Configuration, RocksDbOptions, StorageOptions}; @@ -947,49 +947,44 @@ mod tests { let mut env_builder = TestCoreEnvBuilder::with_incoming_only_connector().set_nodes_config(nodes_config); - let health_status = HealthStatus::default(); + let tc = env_builder.tc.clone(); + async { + let health_status = HealthStatus::default(); - env_builder.tc.run_in_scope_sync(|| { RocksDbManager::init(Constant::new(CommonOptions::default())); - }); - let bifrost_svc = BifrostService::new(env_builder.tc.clone(), env_builder.metadata.clone()) - .with_factory(memory_loglet::Factory::default()); - let bifrost = bifrost_svc.handle(); + let bifrost_svc = BifrostService::new().with_factory(memory_loglet::Factory::default()); + let bifrost = bifrost_svc.handle(); - let partition_store_manager = PartitionStoreManager::create( - Constant::new(StorageOptions::default()), - Constant::new(RocksDbOptions::default()).boxed(), - &[(PartitionId::MIN, 0..=PartitionKey::MAX)], - ) - .await?; + let partition_store_manager = PartitionStoreManager::create( + Constant::new(StorageOptions::default()), + Constant::new(RocksDbOptions::default()).boxed(), + &[(PartitionId::MIN, 0..=PartitionKey::MAX)], + ) + .await?; - let partition_processor_manager = PartitionProcessorManager::new( - env_builder.tc.clone(), - health_status, - Live::from_value(Configuration::default()), - env_builder.metadata.clone(), - env_builder.metadata_store_client.clone(), - partition_store_manager, - &mut env_builder.router_builder, - bifrost, - ); + let partition_processor_manager = 
PartitionProcessorManager::new( + env_builder.tc.clone(), + health_status, + Live::from_value(Configuration::default()), + env_builder.metadata.clone(), + env_builder.metadata_store_client.clone(), + partition_store_manager, + &mut env_builder.router_builder, + bifrost, + ); - let env = env_builder.build().await; - let processors_manager_handle = partition_processor_manager.handle(); + let env = env_builder.build().await; + let processors_manager_handle = partition_processor_manager.handle(); + + bifrost_svc.start().await.into_test_result()?; + TaskCenter::current().spawn( + TaskKind::SystemService, + "partition-processor-manager", + None, + partition_processor_manager.run(), + )?; - env.tc - .run_in_scope("init-bifrost", None, bifrost_svc.start()) - .await - .into_test_result()?; - env.tc.spawn( - TaskKind::SystemService, - "partition-processor-manager", - None, - partition_processor_manager.run(), - )?; - let tc = env.tc.clone(); - tc.run_in_scope("test", None, async move { let connection = MockPeerConnection::connect( node_id, env.metadata.nodes_config_version(), @@ -1049,7 +1044,8 @@ mod tests { } googletest::Result::Ok(()) - }) + } + .in_tc(&tc) .await?; tc.shutdown_node("test completed", 0).await; diff --git a/tools/bifrost-benchpress/src/main.rs b/tools/bifrost-benchpress/src/main.rs index 0bdf283d5..902e4b68e 100644 --- a/tools/bifrost-benchpress/src/main.rs +++ b/tools/bifrost-benchpress/src/main.rs @@ -148,7 +148,6 @@ fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, .build() .expect("task_center builds"); - let task_center = tc.clone(); let bifrost = tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); @@ -174,7 +173,7 @@ fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, metadata_writer.submit(Arc::new(logs)); spawn_metadata_manager(metadata_manager).expect("metadata manager starts"); - let bifrost_svc = BifrostService::new(task_center, metadata) + let bifrost_svc = BifrostService::new() .enable_in_memory_loglet() .enable_local_loglet(&config); let bifrost = bifrost_svc.handle(); diff --git a/tools/restatectl/src/commands/log/dump_log.rs b/tools/restatectl/src/commands/log/dump_log.rs index d99e749f5..30ee43d5d 100644 --- a/tools/restatectl/src/commands/log/dump_log.rs +++ b/tools/restatectl/src/commands/log/dump_log.rs @@ -97,8 +97,7 @@ async fn dump_log(opts: &DumpLogOpts) -> anyhow::Result<()> { metadata_manager.run(), )?; - let bifrost_svc = BifrostService::new(TaskCenter::current(), metadata.clone()) - .enable_local_loglet(&Configuration::updateable()); + let bifrost_svc = BifrostService::new().enable_local_loglet(&Configuration::updateable()); let bifrost = bifrost_svc.handle(); // Ensures bifrost has initial metadata synced up before starting the worker. 
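The test migrations in this patch all follow the same shape: the explicit `run_in_scope(..)` / `run_in_scope_sync(..)` wrappers are removed and the future (or the whole async test body) is driven through the `TaskCenterFutureExt` extension trait instead. A compressed before/after sketch, assuming a test environment bound to `env` with a `tc` field as in the tests above:

```rust
use restate_bifrost::Bifrost;
use restate_core::TaskCenterFutureExt;

// before:
// let bifrost = env.tc
//     .run_in_scope("init bifrost", None, Bifrost::init_in_memory(env.metadata.clone()))
//     .await;

// after: run the future inside the environment's task-center via the extension trait.
let bifrost = Bifrost::init_in_memory().in_tc(&env.tc).await;
```

Larger test bodies use the same mechanism by wrapping everything in `async { ... }.in_tc(&env.tc).await`, as the cleaner and partition-processor-manager tests above do.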
diff --git a/tools/xtask/src/main.rs b/tools/xtask/src/main.rs index 2292b61bf..621975b36 100644 --- a/tools/xtask/src/main.rs +++ b/tools/xtask/src/main.rs @@ -18,8 +18,8 @@ use schemars::gen::SchemaSettings; use restate_admin::service::AdminService; use restate_bifrost::Bifrost; -use restate_core::TaskKind; use restate_core::TestCoreEnv; +use restate_core::{TaskCenterFutureExt, TaskKind}; use restate_service_client::{AssumeRoleCacheMode, ServiceClient}; use restate_service_protocol::discovery::ServiceDiscovery; use restate_storage_query_datafusion::table_docs; @@ -103,14 +103,7 @@ async fn generate_rest_api_doc() -> anyhow::Result<()> { // We start the Meta service, then download the openapi schema generated let node_env = TestCoreEnv::create_with_single_node(1, 1).await; - let bifrost = node_env - .tc - .run_in_scope( - "bifrost init", - None, - Bifrost::init_in_memory(node_env.metadata.clone()), - ) - .await; + let bifrost = Bifrost::init_in_memory().in_tc(&node_env.tc).await; let admin_service = AdminService::new( node_env.metadata_writer.clone(), From 976c909f4161194e23bd75c64ba690dced64700a Mon Sep 17 00:00:00 2001 From: Ahmed Farghal Date: Fri, 22 Nov 2024 18:46:55 +0000 Subject: [PATCH 4/4] [TaskCenter] Stage 5 - test macro --- Cargo.lock | 10 + Cargo.toml | 3 + crates/bifrost/src/bifrost.rs | 794 ++++++++++++------------- crates/core/Cargo.toml | 4 +- crates/core/derive/Cargo.toml | 16 + crates/core/derive/src/lib.rs | 39 ++ crates/core/derive/src/tc_test.rs | 681 +++++++++++++++++++++ crates/core/src/lib.rs | 10 + crates/core/src/task_center/builder.rs | 12 +- crates/core/src/task_center/mod.rs | 18 +- crates/core/src/test_env.rs | 3 +- crates/core/src/test_env2.rs | 303 ++++++++++ 12 files changed, 1469 insertions(+), 424 deletions(-) create mode 100644 crates/core/derive/Cargo.toml create mode 100644 crates/core/derive/src/lib.rs create mode 100644 crates/core/derive/src/tc_test.rs create mode 100644 crates/core/src/test_env2.rs diff --git a/Cargo.lock b/Cargo.lock index f99eba173..95a2dea10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6107,6 +6107,7 @@ dependencies = [ "prost", "prost-types", "rand", + "restate-core-derive", "restate-test-util", "restate-types", "schemars", @@ -6131,6 +6132,15 @@ dependencies = [ "xxhash-rust", ] +[[package]] +name = "restate-core-derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.85", +] + [[package]] name = "restate-errors" version = "1.1.4" diff --git a/Cargo.toml b/Cargo.toml index 7154a9645..1a5cc1e2f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "cli", "crates/*", + "crates/core/derive", "crates/codederror/derive", "server", "benchmarks", @@ -14,6 +15,7 @@ members = [ default-members = [ "cli", "crates/*", + "crates/core/derive", "crates/codederror/derive", "server", "tools/restatectl", @@ -38,6 +40,7 @@ restate-base64-util = { path = "crates/base64-util" } restate-bifrost = { path = "crates/bifrost" } restate-cli-util = { path = "crates/cli-util" } restate-core = { path = "crates/core" } +restate-core-derive = { path = "crates/core/derive" } restate-errors = { path = "crates/errors" } restate-fs-util = { path = "crates/fs-util" } restate-futures-util = { path = "crates/futures-util" } diff --git a/crates/bifrost/src/bifrost.rs b/crates/bifrost/src/bifrost.rs index beeffa185..f04cf4aca 100644 --- a/crates/bifrost/src/bifrost.rs +++ b/crates/bifrost/src/bifrost.rs @@ -496,8 +496,8 @@ mod tests { use tracing::info; use tracing_test::traced_test; - use 
restate_core::TestCoreEnvBuilder; - use restate_core::{TaskCenter, TaskCenterFutureExt, TaskKind, TestCoreEnv}; + use restate_core::TestCoreEnvBuilder2; + use restate_core::{TaskCenter, TaskKind, TestCoreEnv2}; use restate_rocksdb::RocksDbManager; use restate_types::config::CommonOptions; use restate_types::live::Constant; @@ -510,379 +510,359 @@ mod tests { use crate::providers::memory_loglet::{self}; use crate::BifrostAdmin; - #[tokio::test] + #[restate_core::test] #[traced_test] async fn test_append_smoke() -> googletest::Result<()> { let num_partitions = 5; - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let _ = TestCoreEnvBuilder2::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, num_partitions, )) .build() .await; - async { - let bifrost = Bifrost::init_in_memory().await; - - let clean_bifrost_clone = bifrost.clone(); - - let mut appender_0 = bifrost.create_appender(LogId::new(0))?; - let mut appender_3 = bifrost.create_appender(LogId::new(3))?; - let mut max_lsn = Lsn::INVALID; - for i in 1..=5 { - // Append a record to memory - let lsn = appender_0.append("").await?; - info!(%lsn, "Appended record to log"); - assert_eq!(Lsn::from(i), lsn); - max_lsn = lsn; - } - // Append to a log that doesn't exist. - let invalid_log = LogId::from(num_partitions + 1); - let resp = bifrost.create_appender(invalid_log); - - assert_that!(resp, pat!(Err(pat!(Error::UnknownLogId(eq(invalid_log)))))); - - // use a cloned bifrost. - let cloned_bifrost = bifrost.clone(); - let mut second_appender_0 = cloned_bifrost.create_appender(LogId::new(0))?; - for _ in 1..=5 { - // Append a record to memory - let lsn = second_appender_0.append("").await?; - info!(%lsn, "Appended record to log"); - assert_eq!(max_lsn + Lsn::from(1), lsn); - max_lsn = lsn; - } + let bifrost = Bifrost::init_in_memory().await; - // Ensure original clone writes to the same underlying loglet. - let lsn = clean_bifrost_clone - .create_appender(LogId::new(0))? - .append("") - .await?; - assert_eq!(max_lsn + Lsn::from(1), lsn); + let clean_bifrost_clone = bifrost.clone(); + + let mut appender_0 = bifrost.create_appender(LogId::new(0))?; + let mut appender_3 = bifrost.create_appender(LogId::new(3))?; + let mut max_lsn = Lsn::INVALID; + for i in 1..=5 { + // Append a record to memory + let lsn = appender_0.append("").await?; + info!(%lsn, "Appended record to log"); + assert_eq!(Lsn::from(i), lsn); max_lsn = lsn; + } - // Writes to another log don't impact original log - let lsn = appender_3.append("").await?; - assert_eq!(Lsn::from(1), lsn); + // Append to a log that doesn't exist. + let invalid_log = LogId::from(num_partitions + 1); + let resp = bifrost.create_appender(invalid_log); - let lsn = appender_0.append("").await?; + assert_that!(resp, pat!(Err(pat!(Error::UnknownLogId(eq(invalid_log)))))); + + // use a cloned bifrost. 
+ let cloned_bifrost = bifrost.clone(); + let mut second_appender_0 = cloned_bifrost.create_appender(LogId::new(0))?; + for _ in 1..=5 { + // Append a record to memory + let lsn = second_appender_0.append("").await?; + info!(%lsn, "Appended record to log"); assert_eq!(max_lsn + Lsn::from(1), lsn); max_lsn = lsn; - - let tail = bifrost.find_tail(LogId::new(0)).await?; - assert_eq!(max_lsn.next(), tail.offset()); - - // Initiate shutdown - TaskCenter::current().shutdown_node("completed", 0).await; - // appends cannot succeed after shutdown - let res = appender_0.append("").await; - assert!(matches!(res, Err(Error::Shutdown(_)))); - // Validate the watchdog has called the provider::start() function. - assert!(logs_contain("Shutting down in-memory loglet provider")); - assert!(logs_contain("Bifrost watchdog shutdown complete")); - Ok(()) } - .in_tc(&node_env.tc) - .await + + // Ensure original clone writes to the same underlying loglet. + let lsn = clean_bifrost_clone + .create_appender(LogId::new(0))? + .append("") + .await?; + assert_eq!(max_lsn + Lsn::from(1), lsn); + max_lsn = lsn; + + // Writes to another log don't impact original log + let lsn = appender_3.append("").await?; + assert_eq!(Lsn::from(1), lsn); + + let lsn = appender_0.append("").await?; + assert_eq!(max_lsn + Lsn::from(1), lsn); + max_lsn = lsn; + + let tail = bifrost.find_tail(LogId::new(0)).await?; + assert_eq!(max_lsn.next(), tail.offset()); + + // Initiate shutdown + TaskCenter::current().shutdown_node("completed", 0).await; + // appends cannot succeed after shutdown + let res = appender_0.append("").await; + assert!(matches!(res, Err(Error::Shutdown(_)))); + // Validate the watchdog has called the provider::start() function. + assert!(logs_contain("Shutting down in-memory loglet provider")); + assert!(logs_contain("Bifrost watchdog shutdown complete")); + Ok(()) } - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] async fn test_lazy_initialization() -> googletest::Result<()> { - let node_env = TestCoreEnv::create_with_single_node(1, 1).await; - async { - let delay = Duration::from_secs(5); - // This memory provider adds a delay to its loglet initialization, we want - // to ensure that appends do not fail while waiting for the loglet; - let factory = memory_loglet::Factory::with_init_delay(delay); - let bifrost = Bifrost::init_with_factory(factory).await; - - let start = tokio::time::Instant::now(); - let lsn = bifrost.create_appender(LogId::new(0))?.append("").await?; - assert_eq!(Lsn::from(1), lsn); - // The append was properly delayed - assert_eq!(delay, start.elapsed()); - Ok(()) - } - .in_tc(&node_env.tc) - .await + let _ = TestCoreEnv2::create_with_single_node(1, 1).await; + let delay = Duration::from_secs(5); + // This memory provider adds a delay to its loglet initialization, we want + // to ensure that appends do not fail while waiting for the loglet; + let factory = memory_loglet::Factory::with_init_delay(delay); + let bifrost = Bifrost::init_with_factory(factory).await; + + let start = tokio::time::Instant::now(); + let lsn = bifrost.create_appender(LogId::new(0))?.append("").await?; + assert_eq!(Lsn::from(1), lsn); + // The append was properly delayed + assert_eq!(delay, start.elapsed()); + Ok(()) } - #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + #[test(restate_core::test(flavor = "multi_thread", worker_threads = 2))] async fn trim_log_smoke_test() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = 
TestCoreEnvBuilder::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; - async { - RocksDbManager::init(Constant::new(CommonOptions::default())); - - let bifrost = Bifrost::init_local().await; - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); + RocksDbManager::init(Constant::new(CommonOptions::default())); - assert_eq!(Lsn::OLDEST, bifrost.find_tail(LOG_ID).await?.offset()); + let bifrost = Bifrost::init_local().await; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); - assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); + assert_eq!(Lsn::OLDEST, bifrost.find_tail(LOG_ID).await?.offset()); - let mut appender = bifrost.create_appender(LOG_ID)?; - // append 10 records - for _ in 1..=10 { - appender.append("").await?; - } + assert_eq!(Lsn::INVALID, bifrost.get_trim_point(LOG_ID).await?); - bifrost_admin.trim(LOG_ID, Lsn::from(5)).await?; + let mut appender = bifrost.create_appender(LOG_ID)?; + // append 10 records + for _ in 1..=10 { + appender.append("").await?; + } - let tail = bifrost.find_tail(LOG_ID).await?; - assert_eq!(tail.offset(), Lsn::from(11)); - assert!(!tail.is_sealed()); - assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); + bifrost_admin.trim(LOG_ID, Lsn::from(5)).await?; - // 5 itself is trimmed - for lsn in 1..=5 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + let tail = bifrost.find_tail(LOG_ID).await?; + assert_eq!(tail.offset(), Lsn::from(11)); + assert!(!tail.is_sealed()); + assert_eq!(Lsn::from(5), bifrost.get_trim_point(LOG_ID).await?); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); - } + // 5 itself is trimmed + for lsn in 1..=5 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - for lsn in 6..=10 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - } + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(5)))); + } - // trimming beyond the release point will fall back to the release point - bifrost_admin.trim(LOG_ID, Lsn::MAX).await?; + for lsn in 6..=10 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); + } - assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); - let new_trim_point = bifrost.get_trim_point(LOG_ID).await?; - assert_eq!(Lsn::from(10), new_trim_point); + // trimming beyond the release point will fall back to the release point + bifrost_admin.trim(LOG_ID, Lsn::MAX).await?; - let record = bifrost.read(LOG_ID, Lsn::from(10)).await?.unwrap(); - assert!(record.is_trim_gap()); - assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); + assert_eq!(Lsn::from(11), bifrost.find_tail(LOG_ID).await?.offset()); + let new_trim_point = bifrost.get_trim_point(LOG_ID).await?; + assert_eq!(Lsn::from(10), new_trim_point); - // Add 10 more records - for _ in 0..10 { - appender.append("").await?; - } + let record = bifrost.read(LOG_ID, Lsn::from(10)).await?.unwrap(); + assert!(record.is_trim_gap()); + 
assert_that!(record.trim_gap_to_sequence_number(), eq(Some(Lsn::new(10)))); - for lsn in 11..20 { - let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); - assert!(record.is_data_record()); - } + // Add 10 more records + for _ in 0..10 { + appender.append("").await?; + } - Ok(()) + for lsn in 11..20 { + let record = bifrost.read(LOG_ID, Lsn::from(lsn)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(lsn))); + assert!(record.is_data_record()); } - .in_tc(&node_env.tc) - .await + + Ok(()) } - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] async fn test_read_across_segments() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )) .build() .await; - async { - let bifrost = Bifrost::init_in_memory().await; - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); - - let mut appender = bifrost.create_appender(LOG_ID)?; - // Lsns [1..5] - for i in 1..=5 { - // Append a record to memory - let lsn = appender.append(format!("segment-1-{i}")).await?; - assert_eq!(Lsn::from(i), lsn); - } - - // not sealed, tail is what we expect - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Open(eq(Lsn::new(6)))) - ); - - let segment_1 = bifrost - .inner - .find_loglet_for_lsn(LOG_ID, Lsn::OLDEST) - .await? - .unwrap(); - - // seal the segment - bifrost_admin - .seal(LOG_ID, segment_1.segment_index()) - .await?; + let bifrost = Bifrost::init_in_memory().await; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); + + let mut appender = bifrost.create_appender(LOG_ID)?; + // Lsns [1..5] + for i in 1..=5 { + // Append a record to memory + let lsn = appender.append(format!("segment-1-{i}")).await?; + assert_eq!(Lsn::from(i), lsn); + } - // sealed, tail is what we expect - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Sealed(eq(Lsn::new(6)))) - ); + // not sealed, tail is what we expect + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Open(eq(Lsn::new(6)))) + ); + + let segment_1 = bifrost + .inner + .find_loglet_for_lsn(LOG_ID, Lsn::OLDEST) + .await? + .unwrap(); + + // seal the segment + bifrost_admin + .seal(LOG_ID, segment_1.segment_index()) + .await?; - println!("attempting to read during reconfiguration"); - // attempting to read from bifrost will result in a timeout since metadata sees this as an open - // segment but the segment itself is sealed. This means reconfiguration is in-progress - // and we can't confidently read records. - assert!(tokio::time::timeout( - Duration::from_secs(5), - bifrost.read(LOG_ID, Lsn::new(2)) + // sealed, tail is what we expect + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Sealed(eq(Lsn::new(6)))) + ); + + println!("attempting to read during reconfiguration"); + // attempting to read from bifrost will result in a timeout since metadata sees this as an open + // segment but the segment itself is sealed. This means reconfiguration is in-progress + // and we can't confidently read records. 
+ assert!( + tokio::time::timeout(Duration::from_secs(5), bifrost.read(LOG_ID, Lsn::new(2))) + .await + .is_err() + ); + + let metadata = Metadata::current(); + let old_version = metadata.logs_version(); + + let mut builder = metadata.logs_ref().clone().into_builder(); + let mut chain_builder = builder.chain(LOG_ID).unwrap(); + assert_eq!(1, chain_builder.num_segments()); + let new_segment_params = new_single_node_loglet_params(ProviderKind::InMemory); + // deliberately skips Lsn::from(6) to create a zombie record in segment 1. Segment 1 now has 4 records. + chain_builder.append_segment(Lsn::new(5), ProviderKind::InMemory, new_segment_params)?; + + let new_metadata = builder.build(); + let new_version = new_metadata.version(); + assert_eq!(new_version, old_version.next()); + node_env + .metadata_store_client + .put( + BIFROST_CONFIG_KEY.clone(), + &new_metadata, + restate_metadata_store::Precondition::MatchesVersion(old_version), ) - .await - .is_err()); - - let metadata = Metadata::current(); - let old_version = metadata.logs_version(); - - let mut builder = metadata.logs_ref().clone().into_builder(); - let mut chain_builder = builder.chain(LOG_ID).unwrap(); - assert_eq!(1, chain_builder.num_segments()); - let new_segment_params = new_single_node_loglet_params(ProviderKind::InMemory); - // deliberately skips Lsn::from(6) to create a zombie record in segment 1. Segment 1 now has 4 records. - chain_builder.append_segment( - Lsn::new(5), - ProviderKind::InMemory, - new_segment_params, - )?; - - let new_metadata = builder.build(); - let new_version = new_metadata.version(); - assert_eq!(new_version, old_version.next()); - node_env - .metadata_store_client - .put( - BIFROST_CONFIG_KEY.clone(), - &new_metadata, - restate_metadata_store::Precondition::MatchesVersion(old_version), - ) - .await?; - - // make sure we have updated metadata. - metadata - .sync(MetadataKind::Logs, TargetVersion::Latest) - .await?; - assert_eq!(new_version, metadata.logs_version()); - - { - // validate that the stored metadata matches our expectations. - let new_metadata = metadata.logs_ref().clone(); - let chain_builder = new_metadata.chain(&LOG_ID).unwrap(); - assert_eq!(2, chain_builder.num_segments()); - } - - // find_tail() on the underlying loglet returns (6) but for bifrost it should be (5) after - // the new segment was created at tail of the chain with base_lsn=5 - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Open(eq(Lsn::new(5)))) - ); - - // appends should go to the new segment - let mut appender = bifrost.create_appender(LOG_ID)?; - // Lsns [5..7] - for i in 5..=7 { - // Append a record to memory - let lsn = appender.append(format!("segment-2-{i}")).await?; - assert_eq!(Lsn::from(i), lsn); - } - - // tail is now 8 and open. - assert_that!( - bifrost.find_tail(LOG_ID).await?, - pat!(TailState::Open(eq(Lsn::new(8)))) - ); - - // validating that segment 1 is still sealed and has its own tail at Lsn (6) - assert_that!( - segment_1.find_tail().await?, - pat!(TailState::Sealed(eq(Lsn::new(6)))) - ); - - let segment_2 = bifrost - .inner - .find_loglet_for_lsn(LOG_ID, Lsn::new(5)) - .await? - .unwrap(); - - assert_ne!(segment_1, segment_2); - - // segment 2 is open and at 8 as previously validated through bifrost interface - assert_that!( - segment_2.find_tail().await?, - pat!(TailState::Open(eq(Lsn::new(8)))) - ); - - // Reading the log. 
(OLDEST) - let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(1))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-1-1".to_owned()) - ); - - let record = bifrost.read(LOG_ID, Lsn::new(2)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(2))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-1-2".to_owned()) - ); + .await?; - // border of segment 1 - let record = bifrost.read(LOG_ID, Lsn::new(4)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(4))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-1-4".to_owned()) - ); + // make sure we have updated metadata. + metadata + .sync(MetadataKind::Logs, TargetVersion::Latest) + .await?; + assert_eq!(new_version, metadata.logs_version()); - // start of segment 2 - let record = bifrost.read(LOG_ID, Lsn::new(5)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(5))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-2-5".to_owned()) - ); + { + // validate that the stored metadata matches our expectations. + let new_metadata = metadata.logs_ref().clone(); + let chain_builder = new_metadata.chain(&LOG_ID).unwrap(); + assert_eq!(2, chain_builder.num_segments()); + } - // last record - let record = bifrost.read(LOG_ID, Lsn::new(7)).await?.unwrap(); - assert_that!(record.sequence_number(), eq(Lsn::new(7))); - assert!(record.is_data_record()); - assert_that!( - record.decode_unchecked::(), - eq("segment-2-7".to_owned()) - ); + // find_tail() on the underlying loglet returns (6) but for bifrost it should be (5) after + // the new segment was created at tail of the chain with base_lsn=5 + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Open(eq(Lsn::new(5)))) + ); + + // appends should go to the new segment + let mut appender = bifrost.create_appender(LOG_ID)?; + // Lsns [5..7] + for i in 5..=7 { + // Append a record to memory + let lsn = appender.append(format!("segment-2-{i}")).await?; + assert_eq!(Lsn::from(i), lsn); + } - // 8 doesn't exist yet. - assert!(bifrost.read(LOG_ID, Lsn::new(8)).await?.is_none()); + // tail is now 8 and open. + assert_that!( + bifrost.find_tail(LOG_ID).await?, + pat!(TailState::Open(eq(Lsn::new(8)))) + ); + + // validating that segment 1 is still sealed and has its own tail at Lsn (6) + assert_that!( + segment_1.find_tail().await?, + pat!(TailState::Sealed(eq(Lsn::new(6)))) + ); + + let segment_2 = bifrost + .inner + .find_loglet_for_lsn(LOG_ID, Lsn::new(5)) + .await? + .unwrap(); + + assert_ne!(segment_1, segment_2); + + // segment 2 is open and at 8 as previously validated through bifrost interface + assert_that!( + segment_2.find_tail().await?, + pat!(TailState::Open(eq(Lsn::new(8)))) + ); + + // Reading the log. 
(OLDEST) + let record = bifrost.read(LOG_ID, Lsn::OLDEST).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(1))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-1-1".to_owned()) + ); + + let record = bifrost.read(LOG_ID, Lsn::new(2)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(2))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-1-2".to_owned()) + ); + + // border of segment 1 + let record = bifrost.read(LOG_ID, Lsn::new(4)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(4))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-1-4".to_owned()) + ); + + // start of segment 2 + let record = bifrost.read(LOG_ID, Lsn::new(5)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(5))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-2-5".to_owned()) + ); + + // last record + let record = bifrost.read(LOG_ID, Lsn::new(7)).await?.unwrap(); + assert_that!(record.sequence_number(), eq(Lsn::new(7))); + assert!(record.is_data_record()); + assert_that!( + record.decode_unchecked::(), + eq("segment-2-7".to_owned()) + ); + + // 8 doesn't exist yet. + assert!(bifrost.read(LOG_ID, Lsn::new(8)).await?.is_none()); - Ok(()) - } - .in_tc(&node_env.tc) - .await + Ok(()) } - #[tokio::test(start_paused = true)] + #[restate_core::test(start_paused = true)] #[traced_test] async fn test_appends_correctly_handle_reconfiguration() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + let node_env = TestCoreEnvBuilder2::with_incoming_only_connector() .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, @@ -890,118 +870,112 @@ mod tests { .set_provider_kind(ProviderKind::Local) .build() .await; - async { - RocksDbManager::init(Constant::new(CommonOptions::default())); - let bifrost = Bifrost::init_local().await; - let bifrost_admin = BifrostAdmin::new( - &bifrost, - &node_env.metadata_writer, - &node_env.metadata_store_client, - ); - - // create an appender - let stop_signal = Arc::new(AtomicBool::default()); - let append_counter = Arc::new(AtomicUsize::new(0)); - let _ = TaskCenter::current().spawn(TaskKind::TestRunner, "append-records", None, { - let append_counter = append_counter.clone(); - let stop_signal = stop_signal.clone(); - let bifrost = bifrost.clone(); - let mut appender = bifrost.create_appender(LOG_ID)?; - async move { - let mut i = 0; - while !stop_signal.load(Ordering::Relaxed) { - i += 1; - if i % 2 == 0 { - // append individual record - let lsn = appender.append(format!("record{}", i)).await?; - println!("Appended {}", lsn); - } else { - // append batch - let mut payloads = Vec::with_capacity(10); - for j in 1..=10 { - payloads.push(format!("record-in-batch{}-{}", i, j)); - } - let lsn = appender.append_batch(payloads).await?; - println!("Appended batch {}", lsn); + RocksDbManager::init(Constant::new(CommonOptions::default())); + let bifrost = Bifrost::init_local().await; + let bifrost_admin = BifrostAdmin::new( + &bifrost, + &node_env.metadata_writer, + &node_env.metadata_store_client, + ); + + // create an appender + let stop_signal = Arc::new(AtomicBool::default()); + let append_counter = Arc::new(AtomicUsize::new(0)); + let _ = TaskCenter::current().spawn(TaskKind::TestRunner, "append-records", None, { + let 
append_counter = append_counter.clone(); + let stop_signal = stop_signal.clone(); + let bifrost = bifrost.clone(); + let mut appender = bifrost.create_appender(LOG_ID)?; + async move { + let mut i = 0; + while !stop_signal.load(Ordering::Relaxed) { + i += 1; + if i % 2 == 0 { + // append individual record + let lsn = appender.append(format!("record{}", i)).await?; + println!("Appended {}", lsn); + } else { + // append batch + let mut payloads = Vec::with_capacity(10); + for j in 1..=10 { + payloads.push(format!("record-in-batch{}-{}", i, j)); } - append_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - tokio::time::sleep(Duration::from_millis(1)).await; + let lsn = appender.append_batch(payloads).await?; + println!("Appended batch {}", lsn); } - println!("Appender terminated"); - Ok(()) - } - })?; - - let mut append_counter_before_seal; - loop { - append_counter_before_seal = append_counter.load(Ordering::Relaxed); - if append_counter_before_seal < 100 { - tokio::time::sleep(Duration::from_millis(10)).await; - } else { - break; + append_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(1)).await; } + println!("Appender terminated"); + Ok(()) } - - // seal and don't extend the chain. - let _ = bifrost_admin.seal(LOG_ID, SegmentIndex::from(0)).await?; - - // appends should stall! - tokio::time::sleep(Duration::from_millis(100)).await; - let append_counter_during_seal = append_counter.load(Ordering::Relaxed); - for _ in 0..5 { - tokio::time::sleep(Duration::from_millis(500)).await; - let counter_now = append_counter.load(Ordering::Relaxed); - assert_that!(counter_now, eq(append_counter_during_seal)); - println!("Appends are stalling, counter={}", counter_now); + })?; + + let mut append_counter_before_seal; + loop { + append_counter_before_seal = append_counter.load(Ordering::Relaxed); + if append_counter_before_seal < 100 { + tokio::time::sleep(Duration::from_millis(10)).await; + } else { + break; } + } + + // seal and don't extend the chain. + let _ = bifrost_admin.seal(LOG_ID, SegmentIndex::from(0)).await?; + + // appends should stall! + tokio::time::sleep(Duration::from_millis(100)).await; + let append_counter_during_seal = append_counter.load(Ordering::Relaxed); + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(500)).await; + let counter_now = append_counter.load(Ordering::Relaxed); + assert_that!(counter_now, eq(append_counter_during_seal)); + println!("Appends are stalling, counter={}", counter_now); + } - for i in 1..=5 { - let last_segment = bifrost + for i in 1..=5 { + let last_segment = bifrost + .inner + .writeable_loglet(LOG_ID) + .await? + .segment_index(); + // allow appender to run a little. + tokio::time::sleep(Duration::from_millis(500)).await; + // seal the loglet and extend with an in-memory one + let new_segment_params = new_single_node_loglet_params(ProviderKind::Local); + bifrost_admin + .seal_and_extend_chain( + LOG_ID, + None, + Version::MIN, + ProviderKind::Local, + new_segment_params, + ) + .await?; + println!("Seal {}", i); + assert_that!( + bifrost .inner .writeable_loglet(LOG_ID) .await? - .segment_index(); - // allow appender to run a little. 
- tokio::time::sleep(Duration::from_millis(500)).await; - // seal the loglet and extend with an in-memory one - let new_segment_params = new_single_node_loglet_params(ProviderKind::Local); - bifrost_admin - .seal_and_extend_chain( - LOG_ID, - None, - Version::MIN, - ProviderKind::Local, - new_segment_params, - ) - .await?; - println!("Seal {}", i); - assert_that!( - bifrost - .inner - .writeable_loglet(LOG_ID) - .await? - .segment_index(), - gt(last_segment) - ); - } - - // make sure that appends are still happening. - let mut append_counter_after_seal = append_counter.load(Ordering::Relaxed); - tokio::time::sleep(Duration::from_millis(100)).await; - assert_that!(append_counter_after_seal, gt(append_counter_before_seal)); - for _ in 0..5 { - tokio::time::sleep(Duration::from_millis(50)).await; - let counter_now = append_counter.load(Ordering::Relaxed); - assert_that!(counter_now, gt(append_counter_after_seal)); - append_counter_after_seal = counter_now; - } + .segment_index(), + gt(last_segment) + ); + } - googletest::Result::Ok(()) + // make sure that appends are still happening. + let mut append_counter_after_seal = append_counter.load(Ordering::Relaxed); + tokio::time::sleep(Duration::from_millis(100)).await; + assert_that!(append_counter_after_seal, gt(append_counter_before_seal)); + for _ in 0..5 { + tokio::time::sleep(Duration::from_millis(50)).await; + let counter_now = append_counter.load(Ordering::Relaxed); + assert_that!(counter_now, gt(append_counter_after_seal)); + append_counter_after_seal = counter_now; } - .in_tc(&node_env.tc) - .await?; - node_env.tc.shutdown_node("test completed", 0).await; + // questionable. RocksDbManager::get().shutdown().await; Ok(()) } diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 2aaab89b0..b406faee4 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -9,11 +9,12 @@ publish = false [features] default = [] -test-util = ["tokio/test-util"] +test-util = ["tokio/test-util", "restate-core-derive"] options_schema = ["dep:schemars"] [dependencies] restate-types = { workspace = true } +restate-core-derive = { workspace = true, optional = true } anyhow = { workspace = true } axum = { workspace = true, default-features = false } @@ -67,6 +68,7 @@ tonic-build = { workspace = true } [dev-dependencies] restate-test-util = { workspace = true } restate-types = { workspace = true, features = ["test-util"] } +restate-core-derive = { workspace = true } googletest = { workspace = true } test-log = { workspace = true } diff --git a/crates/core/derive/Cargo.toml b/crates/core/derive/Cargo.toml new file mode 100644 index 000000000..355f8beb6 --- /dev/null +++ b/crates/core/derive/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "restate-core-derive" +version = "0.1.0" +authors.workspace = true +edition.workspace = true +rust-version.workspace = true +license.workspace = true +publish = false + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0" +quote = "1" +syn = { version = "2.0", features = ["full"] } diff --git a/crates/core/derive/src/lib.rs b/crates/core/derive/src/lib.rs new file mode 100644 index 000000000..b4cac8b07 --- /dev/null +++ b/crates/core/derive/src/lib.rs @@ -0,0 +1,39 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. 
+// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +extern crate proc_macro; + +mod tc_test; + +use proc_macro::TokenStream; + +/// Run tests within task-center +/// +/// +/// You can configure the underlying runtime(s) as you would do with tokio +/// ```no_run +/// #[restate_core::test(_args_of_tokio_test)] +/// async fn test_name() { +/// TaskCenter::current(); +/// } +/// ``` +/// +/// A generalised example is +/// ```no_run +/// #[restate_core::test(start_paused = true)]` +/// async fn test_name() { +/// TaskCenter::current(); +/// } +/// ``` +/// +#[proc_macro_attribute] +pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { + tc_test::test(args.into(), item.into(), true).into() +} diff --git a/crates/core/derive/src/tc_test.rs b/crates/core/derive/src/tc_test.rs new file mode 100644 index 000000000..d7d317a84 --- /dev/null +++ b/crates/core/derive/src/tc_test.rs @@ -0,0 +1,681 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +//! Some parts of this codebase were taken from https://github.com/tokio-rs/tokio/blob/master/tokio-macros/src/entry.rs +//! MIT License + +use proc_macro2::{Span, TokenStream, TokenTree}; +use quote::{quote, quote_spanned, ToTokens}; +use syn::parse::{Parse, ParseStream, Parser}; +use syn::{braced, Attribute, Ident, Path, Signature, Visibility}; + +// syn::AttributeArgs does not implement syn::Parse +type AttributeArgs = syn::punctuated::Punctuated; + +#[derive(Clone, Copy, PartialEq)] +enum RuntimeFlavor { + CurrentThread, + Threaded, +} + +impl RuntimeFlavor { + fn from_str(s: &str) -> Result { + match s { + "current_thread" => Ok(RuntimeFlavor::CurrentThread), + "multi_thread" => Ok(RuntimeFlavor::Threaded), + "single_thread" => Err("The single threaded runtime flavor is called `current_thread`.".to_string()), + "basic_scheduler" => Err("The `basic_scheduler` runtime flavor has been renamed to `current_thread`.".to_string()), + "threaded_scheduler" => Err("The `threaded_scheduler` runtime flavor has been renamed to `multi_thread`.".to_string()), + _ => Err(format!("No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.", s)), + } + } +} + +#[derive(Clone, Copy, PartialEq)] +enum UnhandledPanic { + Ignore, + ShutdownRuntime, +} + +impl UnhandledPanic { + fn from_str(s: &str) -> Result { + match s { + "ignore" => Ok(UnhandledPanic::Ignore), + "shutdown_runtime" => Ok(UnhandledPanic::ShutdownRuntime), + _ => Err(format!("No such unhandled panic behavior `{}`. The unhandled panic behaviors are `ignore` and `shutdown_runtime`.", s)), + } + } + + fn into_tokens(self, crate_path: &TokenStream) -> TokenStream { + match self { + UnhandledPanic::Ignore => quote! { #crate_path::runtime::UnhandledPanic::Ignore }, + UnhandledPanic::ShutdownRuntime => { + quote! 
{ #crate_path::runtime::UnhandledPanic::ShutdownRuntime } + } + } + } +} + +struct FinalConfig { + flavor: RuntimeFlavor, + worker_threads: Option, + start_paused: Option, + crate_name: Option, + unhandled_panic: Option, +} + +/// Config used in case of the attribute not being able to build a valid config +const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { + flavor: RuntimeFlavor::CurrentThread, + worker_threads: None, + start_paused: None, + crate_name: None, + unhandled_panic: None, +}; + +struct Configuration { + rt_multi_thread_available: bool, + default_flavor: RuntimeFlavor, + flavor: Option, + worker_threads: Option<(usize, Span)>, + start_paused: Option<(bool, Span)>, + crate_name: Option, + unhandled_panic: Option<(UnhandledPanic, Span)>, +} + +impl Configuration { + fn new(rt_multi_thread: bool) -> Self { + Configuration { + rt_multi_thread_available: rt_multi_thread, + default_flavor: RuntimeFlavor::CurrentThread, + flavor: None, + worker_threads: None, + start_paused: None, + crate_name: None, + unhandled_panic: None, + } + } + + fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.flavor.is_some() { + return Err(syn::Error::new(span, "`flavor` set multiple times.")); + } + + let runtime_str = parse_string(runtime, span, "flavor")?; + let runtime = + RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?; + self.flavor = Some(runtime); + Ok(()) + } + + fn set_worker_threads( + &mut self, + worker_threads: syn::Lit, + span: Span, + ) -> Result<(), syn::Error> { + if self.worker_threads.is_some() { + return Err(syn::Error::new( + span, + "`worker_threads` set multiple times.", + )); + } + + let worker_threads = parse_int(worker_threads, span, "worker_threads")?; + if worker_threads == 0 { + return Err(syn::Error::new(span, "`worker_threads` may not be 0.")); + } + self.worker_threads = Some((worker_threads, span)); + Ok(()) + } + + fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.start_paused.is_some() { + return Err(syn::Error::new(span, "`start_paused` set multiple times.")); + } + + let start_paused = parse_bool(start_paused, span, "start_paused")?; + self.start_paused = Some((start_paused, span)); + Ok(()) + } + + fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.crate_name.is_some() { + return Err(syn::Error::new(span, "`crate` set multiple times.")); + } + let name_path = parse_path(name, span, "crate")?; + self.crate_name = Some(name_path); + Ok(()) + } + + fn set_unhandled_panic( + &mut self, + unhandled_panic: syn::Lit, + span: Span, + ) -> Result<(), syn::Error> { + if self.unhandled_panic.is_some() { + return Err(syn::Error::new( + span, + "`unhandled_panic` set multiple times.", + )); + } + + let unhandled_panic = parse_string(unhandled_panic, span, "unhandled_panic")?; + let unhandled_panic = + UnhandledPanic::from_str(&unhandled_panic).map_err(|err| syn::Error::new(span, err))?; + self.unhandled_panic = Some((unhandled_panic, span)); + Ok(()) + } + + fn macro_name(&self) -> &'static str { + "restate_core::test" + } + + fn build(&self) -> Result { + use RuntimeFlavor as F; + + let flavor = self.flavor.unwrap_or(self.default_flavor); + let worker_threads = match (flavor, self.worker_threads) { + (F::CurrentThread, Some((_, worker_threads_span))) => { + let msg = format!( + "The `worker_threads` option requires the `multi_thread` runtime flavor. 
Use `#[{}(flavor = \"multi_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(worker_threads_span, msg)); + } + (F::CurrentThread, None) => None, + (F::Threaded, worker_threads) if self.rt_multi_thread_available => { + worker_threads.map(|(val, _span)| val) + } + (F::Threaded, _) => { + let msg = if self.flavor.is_none() { + "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled." + } else { + "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature." + }; + return Err(syn::Error::new(Span::call_site(), msg)); + } + }; + + let start_paused = match (flavor, self.start_paused) { + (F::Threaded, Some((_, start_paused_span))) => { + let msg = format!( + "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(start_paused_span, msg)); + } + (F::CurrentThread, Some((start_paused, _))) => Some(start_paused), + (_, None) => None, + }; + + let unhandled_panic = match (flavor, self.unhandled_panic) { + (F::Threaded, Some((_, unhandled_panic_span))) => { + let msg = format!( + "The `unhandled_panic` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`", + self.macro_name(), + ); + return Err(syn::Error::new(unhandled_panic_span, msg)); + } + (F::CurrentThread, Some((unhandled_panic, _))) => Some(unhandled_panic), + (_, None) => None, + }; + + Ok(FinalConfig { + crate_name: self.crate_name.clone(), + flavor, + worker_threads, + start_paused, + unhandled_panic, + }) + } +} + +fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result { + match int { + syn::Lit::Int(lit) => match lit.base10_parse::() { + Ok(value) => Ok(value), + Err(e) => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as integer: {}", field, e), + )), + }, + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as integer.", field), + )), + } +} + +fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result { + match int { + syn::Lit::Str(s) => Ok(s.value()), + syn::Lit::Verbatim(s) => Ok(s.to_string()), + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as string.", field), + )), + } +} + +fn parse_path(lit: syn::Lit, span: Span, field: &str) -> Result { + match lit { + syn::Lit::Str(s) => { + let err = syn::Error::new( + span, + format!( + "Failed to parse value of `{}` as path: \"{}\"", + field, + s.value() + ), + ); + s.parse::().map_err(|_| err.clone()) + } + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as path.", field), + )), + } +} + +fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result { + match bool { + syn::Lit::Bool(b) => Ok(b.value), + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as bool.", field), + )), + } +} + +fn build_config( + input: &ItemFn, + args: AttributeArgs, + rt_multi_thread: bool, +) -> Result { + if input.sig.asyncness.is_none() { + let msg = "the `async` keyword is missing from the function declaration"; + return Err(syn::Error::new_spanned(input.sig.fn_token, msg)); + } + + let mut config = Configuration::new(rt_multi_thread); + let macro_name = config.macro_name(); + + for arg in args { + match arg { + syn::Meta::NameValue(namevalue) => { + let ident = namevalue + .path + .get_ident() + .ok_or_else(|| { + syn::Error::new_spanned(&namevalue, "Must have specified ident") + })? 
+ .to_string() + .to_lowercase(); + let lit = match &namevalue.value { + syn::Expr::Lit(syn::ExprLit { lit, .. }) => lit, + expr => return Err(syn::Error::new_spanned(expr, "Must be a literal")), + }; + match ident.as_str() { + "worker_threads" => { + config.set_worker_threads(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "flavor" => { + config.set_flavor(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "start_paused" => { + config.set_start_paused(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "core_threads" => { + let msg = "Attribute `core_threads` is renamed to `worker_threads`"; + return Err(syn::Error::new_spanned(namevalue, msg)); + } + "crate" => { + config.set_crate_name(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + "unhandled_panic" => { + config + .set_unhandled_panic(lit.clone(), syn::spanned::Spanned::span(lit))?; + } + name => { + let msg = format!( + "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`", + name, + ); + return Err(syn::Error::new_spanned(namevalue, msg)); + } + } + } + syn::Meta::Path(path) => { + let name = path + .get_ident() + .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))? + .to_string() + .to_lowercase(); + let msg = match name.as_str() { + "threaded_scheduler" | "multi_thread" => { + format!( + "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].", + macro_name + ) + } + "basic_scheduler" | "current_thread" | "single_threaded" => { + format!( + "Set the runtime flavor with #[{}(flavor = \"current_thread\")].", + macro_name + ) + } + "flavor" | "worker_threads" | "start_paused" | "crate" | "unhandled_panic" => { + format!("The `{}` attribute requires an argument.", name) + } + name => { + format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`, `unhandled_panic`.", name) + } + }; + return Err(syn::Error::new_spanned(path, msg)); + } + other => { + return Err(syn::Error::new_spanned( + other, + "Unknown attribute inside the macro", + )); + } + } + } + + config.build() +} + +fn parse_knobs(mut input: ItemFn, config: FinalConfig) -> TokenStream { + input.sig.asyncness = None; + + // If type mismatch occurs, the current rustc points to the last statement. + let (last_stmt_start_span, last_stmt_end_span) = { + let mut last_stmt = input.stmts.last().cloned().unwrap_or_default().into_iter(); + + // `Span` on stable Rust has a limitation that only points to the first + // token, not the whole tokens. We can work around this limitation by + // using the first/last span of the tokens like + // `syn::Error::new_spanned` does. + let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span()); + let end = last_stmt.last().map_or(start, |t| t.span()); + (start, end) + }; + + let crate_path = config + .crate_name + .map(ToTokens::into_token_stream) + .unwrap_or_else(|| Ident::new("restate_core", last_stmt_start_span).into_token_stream()); + + let mut tc_builder = quote_spanned! {last_stmt_start_span=> + #crate_path::TaskCenterBuilder::default() + .ingress_runtime_handle(rt.handle().clone()) + .default_runtime_handle(rt.handle().clone()) + }; + let mut rt = match config.flavor { + RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> + ::tokio::runtime::Builder::new_current_thread() + }, + RuntimeFlavor::Threaded => quote_spanned! 
{last_stmt_start_span=>
+            ::tokio::runtime::Builder::new_multi_thread()
+        },
+    };
+    if let Some(v) = config.worker_threads {
+        rt = quote_spanned! {last_stmt_start_span=> #rt.worker_threads(#v) };
+    }
+    if let Some(v) = config.start_paused {
+        rt = quote_spanned! {last_stmt_start_span=> #rt.start_paused(#v) };
+        tc_builder = quote_spanned! {last_stmt_start_span=> #tc_builder.pause_time(#v) };
+    }
+    if let Some(v) = config.unhandled_panic {
+        let unhandled_panic = v.into_tokens(&crate_path);
+        rt = quote_spanned! {last_stmt_start_span=> #rt.unhandled_panic(#unhandled_panic) };
+    }
+
+    let generated_attrs = quote! {
+        #[::core::prelude::v1::test]
+    };
+
+    let body_ident = quote! { body };
+    let last_block = quote_spanned! {last_stmt_end_span=>
+        #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
+        {
+            use restate_core::TaskCenterFutureExt as _;
+            // Make sure that panics exit the process.
+            let orig_hook = std::panic::take_hook();
+            std::panic::set_hook(Box::new(move |panic_info| {
+                // invoke the default handler and exit the process
+                orig_hook(panic_info);
+                std::process::exit(1);
+            }));
+            let rt = #rt
+                .enable_all()
+                .build()
+                .expect("Failed building the Runtime");
+
+            let task_center = #tc_builder
+                .build()
+                .expect("Failed building task-center");
+
+            let ret = rt.block_on(#body_ident.in_tc(&task_center));
+            rt.block_on(task_center.shutdown_node("completed", 0));
+            ret
+        }
+    };
+
+    let body = input.body();
+
+    // For test functions pin the body to the stack and use `Pin<&mut dyn
+    // Future>` to reduce the amount of `Runtime::block_on` (and related
+    // functions) copies we generate during compilation due to the generic
+    // parameter `F` (the future to block on). This could have an impact on
+    // performance, but because it's only for testing it's unlikely to be very
+    // large.
+    //
+    // We don't do this for the main function as it should only be used once so
+    // there will be no benefit.
+    let body = {
+        let output_type = match &input.sig.output {
+            // For functions with no return value syn doesn't print anything,
+            // but that doesn't work as `Output` for our boxed `Future`, so
+            // default to `()` (the same type as the function output).
+            syn::ReturnType::Default => quote! { () },
+            syn::ReturnType::Type(_, ret_type) => quote! { #ret_type },
+        };
+        quote! {
+            let body = async #body;
+            ::tokio::pin!(body);
+            let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = #output_type>> = body;
+        }
+    };
+
+    input.into_tokens(generated_attrs, body, last_block)
+}
+
+fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream {
+    tokens.extend(error.into_compile_error());
+    tokens
+}
+
+pub(crate) fn test(args: TokenStream, item: TokenStream, rt_multi_thread: bool) -> TokenStream {
+    // If any of the steps for this macro fail, we still want to expand to an item that is as close
+    // to the expected output as possible. This helps out IDEs such that completions and other
+    // related features keep working.
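+    // For example, if the attribute arguments fail to parse (say, an unknown flavor
+    // value), we still expand the function below using DEFAULT_ERROR_CONFIG and only
+    // append the compile error to the output, instead of erasing the test item.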
+    let input: ItemFn = match syn::parse2(item.clone()) {
+        Ok(it) => it,
+        Err(e) => return token_stream_with_error(item, e),
+    };
+    let config = if let Some(attr) = input.attrs().find(|attr| is_test_attribute(attr)) {
+        let msg = "second test attribute is supplied, consider removing or changing the order of your test attributes";
+        Err(syn::Error::new_spanned(attr, msg))
+    } else {
+        AttributeArgs::parse_terminated
+            .parse2(args)
+            .and_then(|args| build_config(&input, args, rt_multi_thread))
+    };
+
+    match config {
+        Ok(config) => parse_knobs(input, config),
+        Err(e) => token_stream_with_error(parse_knobs(input, DEFAULT_ERROR_CONFIG), e),
+    }
+}
+
+// Check whether given attribute is a test attribute of forms:
+// * `#[test]`
+// * `#[core::prelude::*::test]` or `#[::core::prelude::*::test]`
+// * `#[std::prelude::*::test]` or `#[::std::prelude::*::test]`
+fn is_test_attribute(attr: &Attribute) -> bool {
+    let path = match &attr.meta {
+        syn::Meta::Path(path) => path,
+        _ => return false,
+    };
+    let candidates = [
+        ["core", "prelude", "*", "test"],
+        ["std", "prelude", "*", "test"],
+    ];
+    if path.leading_colon.is_none()
+        && path.segments.len() == 1
+        && path.segments[0].arguments.is_none()
+        && path.segments[0].ident == "test"
+    {
+        return true;
+    } else if path.segments.len() != candidates[0].len() {
+        return false;
+    }
+    candidates.into_iter().any(|segments| {
+        path.segments.iter().zip(segments).all(|(segment, path)| {
+            segment.arguments.is_none() && (path == "*" || segment.ident == path)
+        })
+    })
+}
+
+struct ItemFn {
+    outer_attrs: Vec<Attribute>,
+    vis: Visibility,
+    sig: Signature,
+    brace_token: syn::token::Brace,
+    inner_attrs: Vec<Attribute>,
+    stmts: Vec<proc_macro2::TokenStream>,
+}
+
+impl ItemFn {
+    /// Access all attributes of the function item.
+    fn attrs(&self) -> impl Iterator<Item = &Attribute> {
+        self.outer_attrs.iter().chain(self.inner_attrs.iter())
+    }
+
+    /// Get the body of the function item in a manner so that it can be
+    /// conveniently used with the `quote!` macro.
+    fn body(&self) -> Body<'_> {
+        Body {
+            brace_token: self.brace_token,
+            stmts: &self.stmts,
+        }
+    }
+
+    /// Convert our local function item into a token stream.
+    fn into_tokens(
+        self,
+        generated_attrs: proc_macro2::TokenStream,
+        body: proc_macro2::TokenStream,
+        last_block: proc_macro2::TokenStream,
+    ) -> TokenStream {
+        let mut tokens = proc_macro2::TokenStream::new();
+        // Outer attributes are simply streamed as-is.
+        for attr in self.outer_attrs {
+            attr.to_tokens(&mut tokens);
+        }
+
+        // Inner attributes require extra care, since they're not supported on
+        // blocks (which is what we're expanded into) we instead lift them
+        // outside of the function. This matches the behavior of `syn`.
+        for mut attr in self.inner_attrs {
+            attr.style = syn::AttrStyle::Outer;
+            attr.to_tokens(&mut tokens);
+        }
+
+        // Add generated macros at the end, so macros processed later are aware of them.
+        generated_attrs.to_tokens(&mut tokens);
+
+        self.vis.to_tokens(&mut tokens);
+        self.sig.to_tokens(&mut tokens);
+
+        self.brace_token.surround(&mut tokens, |tokens| {
+            body.to_tokens(tokens);
+            last_block.to_tokens(tokens);
+        });
+
+        tokens
+    }
+}
+
+impl Parse for ItemFn {
+    #[inline]
+    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
+        // This parse implementation has been largely lifted from `syn`, with
+        // the exception of:
+        // * We don't have access to the plumbing necessary to parse inner
+        //   attributes in-place.
+        // * We do our own statements parsing to avoid recursively parsing
+        //   entire statements and only look for the parts we're interested in.
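+        //   For example, a body `let x = 1; x + 1` is collected as the two raw
+        //   token runs `let x = 1;` and `x + 1`, not parsed into `syn::Stmt`s.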
+
+        let outer_attrs = input.call(Attribute::parse_outer)?;
+        let vis: Visibility = input.parse()?;
+        let sig: Signature = input.parse()?;
+
+        let content;
+        let brace_token = braced!(content in input);
+        let inner_attrs = Attribute::parse_inner(&content)?;
+
+        let mut buf = proc_macro2::TokenStream::new();
+        let mut stmts = Vec::new();
+
+        while !content.is_empty() {
+            if let Some(semi) = content.parse::<Option<syn::Token![;]>>()? {
+                semi.to_tokens(&mut buf);
+                stmts.push(buf);
+                buf = proc_macro2::TokenStream::new();
+                continue;
+            }
+
+            // Parse a single token tree and extend our current buffer with it.
+            // This avoids parsing the entire content of the sub-tree.
+            buf.extend([content.parse::<proc_macro2::TokenTree>()?]);
+        }
+
+        if !buf.is_empty() {
+            stmts.push(buf);
+        }
+
+        Ok(Self {
+            outer_attrs,
+            vis,
+            sig,
+            brace_token,
+            inner_attrs,
+            stmts,
+        })
+    }
+}
+
+struct Body<'a> {
+    brace_token: syn::token::Brace,
+    // Statements, with terminating `;`.
+    stmts: &'a [TokenStream],
+}
+
+impl ToTokens for Body<'_> {
+    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
+        self.brace_token.surround(tokens, |tokens| {
+            for stmt in self.stmts {
+                stmt.to_tokens(tokens);
+            }
+        });
+    }
+}
diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs
index 93856f624..6c85c1a32 100644
--- a/crates/core/src/lib.rs
+++ b/crates/core/src/lib.rs
@@ -18,6 +18,10 @@ mod task_center;
 pub mod worker_api;
 
 pub use error::*;
+#[cfg(any(test, feature = "test-util"))]
+#[doc(inline)]
+pub use restate_core_derive::test;
+
 pub use metadata::{
     spawn_metadata_manager, Metadata, MetadataBuilder, MetadataKind, MetadataManager,
     MetadataWriter, SyncError, TargetVersion,
@@ -27,5 +31,11 @@ pub use task_center::*;
 #[cfg(any(test, feature = "test-util"))]
 mod test_env;
 
+#[cfg(any(test, feature = "test-util"))]
+mod test_env2;
+
 #[cfg(any(test, feature = "test-util"))]
 pub use test_env::{create_mock_nodes_config, NoOpMessageHandler, TestCoreEnv, TestCoreEnvBuilder};
+
+#[cfg(any(test, feature = "test-util"))]
+pub use test_env2::{TestCoreEnv2, TestCoreEnvBuilder2};
diff --git a/crates/core/src/task_center/builder.rs b/crates/core/src/task_center/builder.rs
index dfb3787f6..7325bbfa9 100644
--- a/crates/core/src/task_center/builder.rs
+++ b/crates/core/src/task_center/builder.rs
@@ -33,7 +33,6 @@ pub struct TaskCenterBuilder {
     ingress_runtime_handle: Option<tokio::runtime::Handle>,
     ingress_runtime: Option<tokio::runtime::Runtime>,
     options: Option<CommonOptions>,
-    #[cfg(any(test, feature = "test-util"))]
     pause_time: bool,
 }
 
@@ -67,13 +66,11 @@ impl TaskCenterBuilder {
         self
     }
 
-    #[cfg(any(test, feature = "test-util"))]
     pub fn pause_time(mut self, pause_time: bool) -> Self {
         self.pause_time = pause_time;
         self
     }
 
-    #[cfg(any(test, feature = "test-util"))]
     pub fn default_for_tests() -> Self {
         Self::default()
             .ingress_runtime_handle(tokio::runtime::Handle::current())
@@ -85,10 +82,6 @@ impl TaskCenterBuilder {
         let options = self.options.unwrap_or_default();
         if self.default_runtime_handle.is_none() {
             let mut default_runtime_builder = tokio_builder("worker", &options);
-            #[cfg(any(test, feature = "test-util"))]
-            if self.pause_time {
-                default_runtime_builder.start_paused(self.pause_time);
-            }
             let default_runtime = default_runtime_builder.build()?;
             self.default_runtime_handle = Some(default_runtime.handle().clone());
             self.default_runtime = Some(default_runtime);
@@ -96,10 +89,6 @@ impl TaskCenterBuilder {
 
         if self.ingress_runtime_handle.is_none() {
             let mut ingress_runtime_builder = tokio_builder("ingress", &options);
-            #[cfg(any(test, feature = "test-util"))]
-            if self.pause_time {
-                
ingress_runtime_builder.start_paused(self.pause_time); - } let ingress_runtime = ingress_runtime_builder.build()?; self.ingress_runtime_handle = Some(ingress_runtime.handle().clone()); self.ingress_runtime = Some(ingress_runtime); @@ -113,6 +102,7 @@ impl TaskCenterBuilder { self.ingress_runtime_handle.unwrap(), self.default_runtime, self.ingress_runtime, + self.pause_time, )) } } diff --git a/crates/core/src/task_center/mod.rs b/crates/core/src/task_center/mod.rs index 1c48897d0..dd40530d6 100644 --- a/crates/core/src/task_center/mod.rs +++ b/crates/core/src/task_center/mod.rs @@ -72,7 +72,8 @@ pub enum RuntimeError { } /// Task center is used to manage long-running and background tasks and their lifecycle. -#[derive(Clone)] +#[derive(Clone, derive_more::Debug)] +#[debug("TaskCenter({})", inner.id)] pub struct TaskCenter { inner: Arc, } @@ -85,6 +86,9 @@ impl TaskCenter { ingress_runtime_handle: tokio::runtime::Handle, default_runtime: Option, ingress_runtime: Option, + // used in tests to start all runtimes with clock paused. Note that this only impacts + // partition processor runtimes + pause_time: bool, ) -> Self { metric_definitions::describe_metrics(); let root_task_context = TaskContext { @@ -96,6 +100,7 @@ impl TaskCenter { }; Self { inner: Arc::new(TaskCenterInner { + id: rand::random(), start_time: Instant::now(), default_runtime_handle, default_runtime, @@ -108,6 +113,7 @@ impl TaskCenter { global_metadata: OnceLock::new(), managed_runtimes: Mutex::new(HashMap::with_capacity(64)), root_task_context, + pause_time, }), } } @@ -508,6 +514,10 @@ impl TaskCenter { // todo: configure the runtime according to a new runtime kind perhaps? let thread_builder = std::thread::Builder::new().name(format!("rt:{}", runtime_name)); let mut builder = tokio::runtime::Builder::new_current_thread(); + + #[cfg(any(test, feature = "test-util"))] + builder.start_paused(self.inner.pause_time); + let rt = builder .enable_all() .build() @@ -886,6 +896,12 @@ impl TaskCenter { } struct TaskCenterInner { + #[allow(dead_code)] + /// used in Debug impl to distinguish between multiple task-centers + id: u16, + /// Should we start new runtimes with paused clock? + #[allow(dead_code)] + pause_time: bool, default_runtime_handle: tokio::runtime::Handle, ingress_runtime_handle: tokio::runtime::Handle, managed_runtimes: Mutex>>, diff --git a/crates/core/src/test_env.rs b/crates/core/src/test_env.rs index 5ebdd702b..1d7143aab 100644 --- a/crates/core/src/test_env.rs +++ b/crates/core/src/test_env.rs @@ -217,7 +217,7 @@ impl TestCoreEnvBuilder { Precondition::None, ) .await - .expect("sot store scheduling plan in metadata store"); + .expect("to store scheduling plan in metadata store"); let _ = self .metadata @@ -227,6 +227,7 @@ impl TestCoreEnvBuilder { ) .await .unwrap(); + self.metadata_writer.set_my_node_id(self.my_node_id); TestCoreEnv { diff --git a/crates/core/src/test_env2.rs b/crates/core/src/test_env2.rs new file mode 100644 index 000000000..17f966446 --- /dev/null +++ b/crates/core/src/test_env2.rs @@ -0,0 +1,303 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. 
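+
+//! Test fixtures for tests that run inside a task-center scope.
+//!
+//! `TestCoreEnvBuilder2::with_networking` registers its metadata as the global
+//! metadata of the current task-center (`TaskCenter::try_set_global_metadata`),
+//! and the environment is meant to be driven from a test annotated with the
+//! `#[restate_core::test]` attribute re-exported by this patch series (see the
+//! usage sketch on `TestCoreEnv2` below).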
+
+use std::marker::PhantomData;
+use std::str::FromStr;
+use std::sync::Arc;
+
+use futures::Stream;
+
+use restate_types::cluster_controller::{ReplicationStrategy, SchedulingPlan};
+use restate_types::config::NetworkingOptions;
+use restate_types::logs::metadata::{bootstrap_logs_metadata, ProviderKind};
+use restate_types::metadata_store::keys::{
+    BIFROST_CONFIG_KEY, NODES_CONFIG_KEY, PARTITION_TABLE_KEY, SCHEDULING_PLAN_KEY,
+};
+use restate_types::net::codec::{Targeted, WireDecode};
+use restate_types::net::metadata::MetadataKind;
+use restate_types::net::AdvertisedAddress;
+use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguration, Role};
+use restate_types::partition_table::PartitionTable;
+use restate_types::protobuf::node::Message;
+use restate_types::{GenerationalNodeId, Version};
+
+use crate::metadata_store::{MetadataStoreClient, Precondition};
+use crate::network::{
+    ConnectionManager, FailingConnector, Incoming, MessageHandler, MessageRouterBuilder,
+    NetworkError, Networking, ProtocolError, TransportConnect,
+};
+use crate::TaskCenter;
+use crate::{spawn_metadata_manager, MetadataBuilder, TaskId};
+use crate::{Metadata, MetadataManager, MetadataWriter};
+
+pub struct TestCoreEnvBuilder2<T> {
+    pub my_node_id: GenerationalNodeId,
+    pub metadata_manager: MetadataManager,
+    pub metadata_writer: MetadataWriter,
+    pub metadata: Metadata,
+    pub networking: Networking<T>,
+    pub nodes_config: NodesConfiguration,
+    pub provider_kind: ProviderKind,
+    pub router_builder: MessageRouterBuilder,
+    pub partition_table: PartitionTable,
+    pub scheduling_plan: SchedulingPlan,
+    pub metadata_store_client: MetadataStoreClient,
+}
+
+impl TestCoreEnvBuilder2<FailingConnector> {
+    pub fn with_incoming_only_connector() -> Self {
+        let metadata_builder = MetadataBuilder::default();
+        let net_opts = NetworkingOptions::default();
+        let connection_manager =
+            ConnectionManager::new_incoming_only(metadata_builder.to_metadata());
+        let networking = Networking::with_connection_manager(
+            metadata_builder.to_metadata(),
+            net_opts,
+            connection_manager,
+        );
+
+        TestCoreEnvBuilder2::with_networking(networking, metadata_builder)
+    }
+}
+impl<T: TransportConnect> TestCoreEnvBuilder2<T> {
+    pub fn with_transport_connector(connector: Arc<T>) -> TestCoreEnvBuilder2<T> {
+        let metadata_builder = MetadataBuilder::default();
+        let net_opts = NetworkingOptions::default();
+        let connection_manager =
+            ConnectionManager::new(metadata_builder.to_metadata(), connector, net_opts.clone());
+        let networking = Networking::with_connection_manager(
+            metadata_builder.to_metadata(),
+            net_opts,
+            connection_manager,
+        );
+
+        TestCoreEnvBuilder2::with_networking(networking, metadata_builder)
+    }
+
+    pub fn with_networking(networking: Networking<T>, metadata_builder: MetadataBuilder) -> Self {
+        let my_node_id = GenerationalNodeId::new(1, 1);
+        let metadata_store_client = MetadataStoreClient::new_in_memory();
+        let metadata = metadata_builder.to_metadata();
+        let metadata_manager =
+            MetadataManager::new(metadata_builder, metadata_store_client.clone());
+        let metadata_writer = metadata_manager.writer();
+        let router_builder = MessageRouterBuilder::default();
+        let nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned());
+        let partition_table = PartitionTable::with_equally_sized_partitions(Version::MIN, 10);
+        let scheduling_plan =
+            SchedulingPlan::from(&partition_table, ReplicationStrategy::OnAllNodes);
+        TaskCenter::try_set_global_metadata(metadata.clone());
+
+        // Use memory-loglet as a default if in test-mode
+        #[cfg(any(test,
feature = "test-util"))] + let provider_kind = ProviderKind::InMemory; + #[cfg(not(any(test, feature = "test-util")))] + let provider_kind = ProviderKind::Local; + + TestCoreEnvBuilder2 { + my_node_id, + metadata_manager, + metadata_writer, + metadata, + networking, + nodes_config, + router_builder, + partition_table, + scheduling_plan, + metadata_store_client, + provider_kind, + } + } + + pub fn set_nodes_config(mut self, nodes_config: NodesConfiguration) -> Self { + self.nodes_config = nodes_config; + self + } + + pub fn set_partition_table(mut self, partition_table: PartitionTable) -> Self { + self.partition_table = partition_table; + self + } + + pub fn set_scheduling_plan(mut self, scheduling_plan: SchedulingPlan) -> Self { + self.scheduling_plan = scheduling_plan; + self + } + + pub fn set_my_node_id(mut self, my_node_id: GenerationalNodeId) -> Self { + self.my_node_id = my_node_id; + self + } + + pub fn set_provider_kind(mut self, provider_kind: ProviderKind) -> Self { + self.provider_kind = provider_kind; + self + } + + pub fn add_mock_nodes_config(mut self) -> Self { + self.nodes_config = + create_mock_nodes_config(self.my_node_id.raw_id(), self.my_node_id.raw_generation()); + self + } + + pub fn add_message_handler(mut self, handler: H) -> Self + where + H: MessageHandler + Send + Sync + 'static, + { + self.router_builder.add_message_handler(handler); + self + } + + pub async fn build(mut self) -> TestCoreEnv2 { + self.metadata_manager + .register_in_message_router(&mut self.router_builder); + self.networking + .connection_manager() + .set_message_router(self.router_builder.build()); + + let metadata_manager_task = + spawn_metadata_manager(self.metadata_manager).expect("metadata manager should start"); + + self.metadata_store_client + .put( + NODES_CONFIG_KEY.clone(), + &self.nodes_config, + Precondition::None, + ) + .await + .expect("to store nodes config in metadata store"); + self.metadata_writer + .submit(Arc::new(self.nodes_config.clone())); + + let logs = bootstrap_logs_metadata( + self.provider_kind, + None, + self.partition_table.num_partitions(), + ); + self.metadata_store_client + .put(BIFROST_CONFIG_KEY.clone(), &logs, Precondition::None) + .await + .expect("to store bifrost config in metadata store"); + self.metadata_writer.submit(Arc::new(logs)); + + self.metadata_store_client + .put( + PARTITION_TABLE_KEY.clone(), + &self.partition_table, + Precondition::None, + ) + .await + .expect("to store partition table in metadata store"); + self.metadata_writer.submit(Arc::new(self.partition_table)); + + self.metadata_store_client + .put( + SCHEDULING_PLAN_KEY.clone(), + &self.scheduling_plan, + Precondition::None, + ) + .await + .expect("to store scheduling plan in metadata store"); + + let _ = self + .metadata + .wait_for_version( + MetadataKind::NodesConfiguration, + self.nodes_config.version(), + ) + .await + .unwrap(); + + self.metadata_writer.set_my_node_id(self.my_node_id); + + TestCoreEnv2 { + metadata: self.metadata, + metadata_manager_task, + metadata_writer: self.metadata_writer, + networking: self.networking, + metadata_store_client: self.metadata_store_client, + } + } +} + +// This might need to be moved to a better place in the future. 
+pub struct TestCoreEnv2<T> {
+    pub metadata: Metadata,
+    pub metadata_writer: MetadataWriter,
+    pub networking: Networking<T>,
+    pub metadata_manager_task: TaskId,
+    pub metadata_store_client: MetadataStoreClient,
+}
+
+impl TestCoreEnv2<FailingConnector> {
+    pub async fn create_with_single_node(node_id: u32, generation: u32) -> Self {
+        TestCoreEnvBuilder2::with_incoming_only_connector()
+            .set_my_node_id(GenerationalNodeId::new(node_id, generation))
+            .add_mock_nodes_config()
+            .build()
+            .await
+    }
+}
+
+impl<T: TransportConnect> TestCoreEnv2<T> {
+    pub async fn accept_incoming_connection<S>(
+        &self,
+        incoming: S,
+    ) -> Result<impl Stream<Item = Message> + Unpin + Send + 'static, NetworkError>
+    where
+        S: Stream<Item = Result<Message, ProtocolError>> + Unpin + Send + 'static,
+    {
+        self.networking
+            .connection_manager()
+            .accept_incoming_connection(incoming)
+            .await
+    }
+}
+
+pub fn create_mock_nodes_config(node_id: u32, generation: u32) -> NodesConfiguration {
+    let mut nodes_config = NodesConfiguration::new(Version::MIN, "test-cluster".to_owned());
+    let address = AdvertisedAddress::from_str("http://127.0.0.1:5122/").unwrap();
+    let node_id = GenerationalNodeId::new(node_id, generation);
+    let roles = Role::Admin | Role::Worker;
+    let my_node = NodeConfig::new(
+        format!("MyNode-{}", node_id),
+        node_id,
+        address,
+        roles,
+        LogServerConfig::default(),
+    );
+    nodes_config.upsert_node(my_node);
+    nodes_config
+}
+
+/// No-op message handler which simply drops the received messages. Useful if you don't want to
+/// react to network messages.
+pub struct NoOpMessageHandler<M> {
+    phantom_data: PhantomData<M>,
+}
+
+impl<M> Default for NoOpMessageHandler<M> {
+    fn default() -> Self {
+        NoOpMessageHandler {
+            phantom_data: PhantomData,
+        }
+    }
+}
+
+impl<M> MessageHandler for NoOpMessageHandler<M>
+where
+    M: WireDecode + Targeted + Send + Sync,
+{
+    type MessageType = M;
+
+    async fn on_message(&self, _msg: Incoming<Self::MessageType>) {
+        // no-op
+    }
+}
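
For reference, a rough sketch of what the new `#[restate_core::test]` attribute expands to
for a plain `async fn my_test() { .. }` with the default multi-thread flavor (spans elided
and the crate path fixed at `restate_core`; the real expansion is generated by `parse_knobs`
above):

    #[::core::prelude::v1::test]
    fn my_test() {
        let body = async { /* original test body */ };
        ::tokio::pin!(body);
        let body: ::core::pin::Pin<&mut dyn ::core::future::Future<Output = ()>> = body;
        #[allow(clippy::expect_used, clippy::diverging_sub_expression)]
        {
            use restate_core::TaskCenterFutureExt as _;
            // The macro installs a panic hook so that panics exit the process.
            let orig_hook = std::panic::take_hook();
            std::panic::set_hook(Box::new(move |panic_info| {
                orig_hook(panic_info);
                std::process::exit(1);
            }));
            let rt = ::tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .expect("Failed building the Runtime");
            let task_center = restate_core::TaskCenterBuilder::default()
                .ingress_runtime_handle(rt.handle().clone())
                .default_runtime_handle(rt.handle().clone())
                .build()
                .expect("Failed building task-center");
            let ret = rt.block_on(body.in_tc(&task_center));
            rt.block_on(task_center.shutdown_node("completed", 0));
            ret
        }
    }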