diff --git a/Cargo.lock b/Cargo.lock index c87af8888..f99eba173 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6103,7 +6103,7 @@ dependencies = [ "once_cell", "opentelemetry", "parking_lot", - "pin-project", + "pin-project-lite", "prost", "prost-types", "rand", diff --git a/Cargo.toml b/Cargo.toml index dff64d45e..7154a9645 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -141,6 +141,7 @@ opentelemetry_sdk = { version = "0.24.0" } parking_lot = { version = "0.12" } paste = "1.0" pin-project = "1.0" +pin-project-lite = { version = "0.2" } prost = { version = "0.13.1" } prost-build = { version = "0.13.1" } priority-queue = "2.0.3" diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index 7826bdc82..f14e3ad05 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -99,7 +99,7 @@ pub fn spawn_restate(config: Configuration) -> TaskCenter { restate_types::config::set_current_config(config.clone()); let updateable_config = Configuration::updateable(); - tc.block_on("benchmark", None, async { + tc.block_on(async { RocksDbManager::init(Constant::new(config.common)); tc.spawn(TaskKind::SystemBoot, "restate", None, async move { diff --git a/crates/bifrost/benches/append_throughput.rs b/crates/bifrost/benches/append_throughput.rs index f868cc3c0..b074cb1e2 100644 --- a/crates/bifrost/benches/append_throughput.rs +++ b/crates/bifrost/benches/append_throughput.rs @@ -104,7 +104,7 @@ fn write_throughput_local_loglet(c: &mut Criterion) { provider, )); - let bifrost = tc.block_on("bifrost-init", None, async { + let bifrost = tc.block_on(async { let metadata = metadata(); let bifrost_svc = BifrostService::new(restate_core::task_center(), metadata) .enable_local_loglet(&Live::from_value(config)); diff --git a/crates/bifrost/benches/util.rs b/crates/bifrost/benches/util.rs index e8ff89538..d4d22f46d 100644 --- a/crates/bifrost/benches/util.rs +++ b/crates/bifrost/benches/util.rs @@ -44,9 +44,7 @@ pub async fn spawn_environment( let metadata_writer = metadata_manager.writer(); tc.try_set_global_metadata(metadata.clone()); - tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(config.common)) - }); + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(config.common))); let logs = restate_types::logs::metadata::bootstrap_logs_metadata(provider, None, num_logs); diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 89e61849e..2aaab89b0 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -40,7 +40,7 @@ metrics = { workspace = true } opentelemetry = { workspace = true } once_cell = { workspace = true } parking_lot = { workspace = true } -pin-project = { workspace = true } +pin-project-lite = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } rand = { workspace = true } diff --git a/crates/core/src/metadata/manager.rs b/crates/core/src/metadata/manager.rs index 58afdee2e..e2d156525 100644 --- a/crates/core/src/metadata/manager.rs +++ b/crates/core/src/metadata/manager.rs @@ -617,7 +617,7 @@ mod tests { S: Fn(&mut T, Version), { let tc = TaskCenterBuilder::default().build()?; - tc.block_on("test", None, async move { + tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); @@ -689,7 +689,7 @@ mod tests { I: Fn(&mut T), { let tc = TaskCenterBuilder::default().build()?; - tc.block_on("test", None, async move { + tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); diff --git a/crates/core/src/metadata/mod.rs b/crates/core/src/metadata/mod.rs index 382e80501..837d616e9 100644 --- a/crates/core/src/metadata/mod.rs +++ b/crates/core/src/metadata/mod.rs @@ -73,6 +73,30 @@ pub struct Metadata { } impl Metadata { + pub fn try_with_current(f: F) -> Option + where + F: Fn(&Metadata) -> R, + { + TaskCenter::with_metadata(|m| f(m)) + } + + pub fn try_current() -> Option { + TaskCenter::with_current(|tc| tc.metadata()) + } + + #[track_caller] + pub fn with_current(f: F) -> R + where + F: FnOnce(&Metadata) -> R, + { + TaskCenter::with_metadata(|m| f(m)).expect("called outside task-center scope") + } + + #[track_caller] + pub fn current() -> Metadata { + TaskCenter::with_current(|tc| tc.metadata()).expect("called outside task-center scope") + } + #[inline(always)] pub fn nodes_config_snapshot(&self) -> Arc { self.inner.nodes_config.load_full() diff --git a/crates/core/src/task_center/extensions.rs b/crates/core/src/task_center/extensions.rs new file mode 100644 index 000000000..6f66a9a61 --- /dev/null +++ b/crates/core/src/task_center/extensions.rs @@ -0,0 +1,171 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::future::Future; +use std::pin::Pin; + +use pin_project_lite::pin_project; +use tokio::task::futures::TaskLocalFuture; +use tokio_util::sync::CancellationToken; + +use crate::task_center::TaskContext; +use crate::Metadata; + +use super::{ + GlobalOverrides, TaskCenter, TaskId, TaskKind, CURRENT_TASK_CENTER, OVERRIDES, TASK_CONTEXT, +}; + +type TaskCenterFuture = + TaskLocalFuture>>; + +/// Adds the ability to override task-center for a future and all its children +pub trait TaskCenterFutureExt: Sized { + /// Ensures that a future will run within a task-center context. This will inherit the current + /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0). + fn in_tc(self, task_center: &TaskCenter) -> WithTaskCenter; + + /// Lets task-center treat this future as a psuedo-task. It gets its own TaskId and an + /// independent cancellation token. However, task-center will not spawn this as a task nor + /// manage its lifecycle. + fn in_tc_as_task( + self, + task_center: &TaskCenter, + kind: TaskKind, + name: &'static str, + ) -> WithTaskCenter; + + /// Ensures that a future will run within the task-center in current scope. This will inherit the current + /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0). + /// + /// This is useful when running dispatching a future as a task on an external runtime/thread, + /// or when running a future on tokio's JoinSet without representing those tokio tasks as + /// task-center tasks. However, in the latter case, it's preferred to use + /// [`Self::in_current_ts_as_task`] instead. + fn in_current_tc(self) -> WithTaskCenter; + + /// Attaches current task-center and lets it treat the future as a psuedo-task. It gets its own TaskId and an + /// independent cancellation token. However, task-center will not spawn this as a task nor + /// manage its lifecycle. + fn in_current_tc_as_task(self, kind: TaskKind, name: &'static str) -> WithTaskCenter; +} + +pin_project! { + pub struct WithTaskCenter { + #[pin] + inner_fut: TaskCenterFuture, + } +} + +impl TaskCenterFutureExt for F +where + F: Future, +{ + fn in_tc(self, task_center: &TaskCenter) -> WithTaskCenter { + let ctx = task_center.with_task_context(Clone::clone); + + let inner = CURRENT_TASK_CENTER.scope( + task_center.clone(), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope(ctx, self), + ), + ); + WithTaskCenter { inner_fut: inner } + } + + fn in_tc_as_task( + self, + task_center: &TaskCenter, + kind: TaskKind, + name: &'static str, + ) -> WithTaskCenter { + let ctx = task_center.with_task_context(move |parent| TaskContext { + id: TaskId::default(), + name, + kind, + cancellation_token: CancellationToken::new(), + partition_id: parent.partition_id, + }); + + let inner = CURRENT_TASK_CENTER.scope( + task_center.clone(), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope(ctx, self), + ), + ); + WithTaskCenter { inner_fut: inner } + } + + /// Ensures that a future will run within a task-center context. This will inherit the current + /// task context (if there is one). Otherwise, it'll run in the context of the root task (task-id=0). + fn in_current_tc(self) -> WithTaskCenter { + TaskCenter::with_current(|tc| self.in_tc(tc)) + } + + fn in_current_tc_as_task(self, kind: TaskKind, name: &'static str) -> WithTaskCenter { + TaskCenter::with_current(|tc| self.in_tc_as_task(tc, kind, name)) + } +} + +impl Future for WithTaskCenter { + type Output = T::Output; + + fn poll( + self: Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.inner_fut.poll(ctx) + } +} + +/// Adds the ability to override Metadata for a future and all its children +pub trait MetadataFutureExt: Sized { + /// Attaches restate's Metadata as an override on a future and all children futures or + /// task-center tasks spawned from it. + fn with_metadata(self, metadata: &Metadata) -> WithMetadata; +} + +pin_project! { + pub struct WithMetadata { + #[pin] + inner_fut: TaskLocalFuture, + } +} + +impl MetadataFutureExt for F +where + F: Future, +{ + fn with_metadata(self, metadata: &Metadata) -> WithMetadata { + let current_overrides = OVERRIDES.try_with(Clone::clone).unwrap_or_default(); + // temporary mute until overrides include more fields + #[allow(clippy::needless_update)] + let overrides = GlobalOverrides { + metadata: Some(metadata.clone()), + ..current_overrides + }; + let inner = OVERRIDES.scope(overrides, self); + WithMetadata { inner_fut: inner } + } +} + +impl Future for WithMetadata { + type Output = T::Output; + + fn poll( + self: Pin<&mut Self>, + ctx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.inner_fut.poll(ctx) + } +} diff --git a/crates/core/src/task_center/mod.rs b/crates/core/src/task_center/mod.rs index abec0df8c..c0fa59b79 100644 --- a/crates/core/src/task_center/mod.rs +++ b/crates/core/src/task_center/mod.rs @@ -9,22 +9,25 @@ // by the Apache License, Version 2.0. mod builder; +mod extensions; mod runtime; mod task; mod task_kind; pub use builder::*; +pub use extensions::*; pub use runtime::*; pub use task::*; pub use task_kind::*; use std::collections::HashMap; +use std::future::Future; use std::panic::AssertUnwindSafe; use std::sync::atomic::{AtomicBool, AtomicI32, Ordering}; use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant}; -use futures::{Future, FutureExt}; +use futures::FutureExt; use metrics::{counter, gauge}; use parking_lot::Mutex; use tokio::runtime::RuntimeMetrics; @@ -46,9 +49,18 @@ const EXIT_CODE_FAILURE: i32 = 1; task_local! { // Current task center - static CURRENT_TASK_CENTER: TaskCenter; + pub(self) static CURRENT_TASK_CENTER: TaskCenter; // Tasks provide access to their context - static CONTEXT: TaskContext; + static TASK_CONTEXT: TaskContext; + + /// Access to a task-level global overrides. + static OVERRIDES: GlobalOverrides; +} + +#[derive(Default, Clone)] +struct GlobalOverrides { + metadata: Option, + //config: Arc, } #[derive(Debug, thiserror::Error)] @@ -75,6 +87,13 @@ impl TaskCenter { ingress_runtime: Option, ) -> Self { metric_definitions::describe_metrics(); + let root_task_context = TaskContext { + id: TaskId::ROOT, + name: "::", + kind: TaskKind::InPlace, + cancellation_token: CancellationToken::new(), + partition_id: None, + }; Self { inner: Arc::new(TaskCenterInner { start_time: Instant::now(), @@ -88,10 +107,39 @@ impl TaskCenter { managed_tasks: Mutex::new(HashMap::new()), global_metadata: OnceLock::new(), managed_runtimes: Mutex::new(HashMap::with_capacity(64)), + root_task_context, }), } } + pub fn try_current() -> Option { + Self::try_with_current(Clone::clone) + } + + pub fn try_with_current(f: F) -> Option + where + F: FnOnce(&TaskCenter) -> R, + { + CURRENT_TASK_CENTER.try_with(|tc| f(tc)).ok() + } + + /// Get the current task center. Use this to spawn tasks on the current task center. + /// This must be called from within a task-center task. + #[track_caller] + pub fn current() -> TaskCenter { + Self::with_current(Clone::clone) + } + + #[track_caller] + pub fn with_current(f: F) -> R + where + F: FnOnce(&TaskCenter) -> R, + { + CURRENT_TASK_CENTER + .try_with(|tc| f(tc)) + .expect("called outside task-center task") + } + pub fn default_runtime_metrics(&self) -> RuntimeMetrics { self.inner.default_runtime_handle.metrics() } @@ -137,10 +185,6 @@ impl TaskCenter { self.inner.current_exit_code.load(Ordering::Relaxed) } - pub fn metadata(&self) -> Option<&Metadata> { - self.inner.global_metadata.get() - } - fn submit_runtime_metrics(runtime: &'static str, stats: RuntimeMetrics) { gauge!("restate.tokio.num_workers", "runtime" => runtime).set(stats.num_workers() as f64); gauge!("restate.tokio.blocking_threads", "runtime" => runtime) @@ -173,15 +217,40 @@ impl TaskCenter { } } - /// Clone the currently set METADATA (and if Some()). Otherwise falls back to global metadata. + pub(crate) fn metadata(&self) -> Option { + match OVERRIDES.try_with(|overrides| overrides.metadata.clone()) { + Ok(Some(o)) => Some(o), + // No metadata override, use task-center-level metadata + _ => self.inner.global_metadata.get().cloned(), + } + } + #[track_caller] - fn clone_metadata(&self) -> Option { - CONTEXT - .try_with(|m| m.metadata.clone()) + /// Attempt to access task-level overridden metadata first, if we don't have an override, + /// fallback to task-center's level metadata. + pub(crate) fn with_metadata(f: F) -> Option + where + F: FnOnce(&Metadata) -> R, + { + OVERRIDES + .try_with(|overrides| match &overrides.metadata { + Some(m) => Some(f(m)), + // No metadata override, use task-center-level metadata + None => CURRENT_TASK_CENTER.with(|tc| tc.inner.global_metadata.get().map(f)), + }) .ok() .flatten() - .or_else(|| self.inner.global_metadata.get().cloned()) } + + fn with_task_context(&self, f: F) -> R + where + F: Fn(&TaskContext) -> R, + { + TASK_CONTEXT + .try_with(|ctx| f(ctx)) + .unwrap_or_else(|_| f(&self.inner.root_task_context)) + } + /// Triggers a shutdown of the system. All running tasks will be asked gracefully /// to cancel but we will only wait for tasks with a TaskKind that has the property /// "OnCancel" set to "wait". @@ -231,14 +300,12 @@ impl TaskCenter { { let inner = self.inner.clone(); let id = TaskId::default(); - let metadata = self.clone_metadata(); let context = TaskContext { id, name, kind, partition_id, cancellation_token: cancel.clone(), - metadata, }; let task = Arc::new(Task { context: context.clone(), @@ -323,14 +390,12 @@ impl TaskCenter { let cancel = CancellationToken::new(); let id = TaskId::default(); - let metadata = self.clone_metadata(); let context = TaskContext { id, name, kind, partition_id, cancellation_token: cancel.clone(), - metadata, }; let fut = unmanaged_wrapper(self.clone(), context, future); @@ -374,8 +439,7 @@ impl TaskCenter { kind, cancellation_token: cancellation_token.clone(), // We must be within task-context already. let's get inherit partition_id - partition_id: CONTEXT.with(|c| c.partition_id), - metadata: Some(metadata()), + partition_id: self.with_task_context(|c| c.partition_id), }; let task = Arc::new(Task { @@ -458,7 +522,6 @@ impl TaskCenter { kind: root_task_kind, cancellation_token: cancel.clone(), partition_id, - metadata: Some(metadata()), }; let (result_tx, result_rx) = oneshot::channel(); @@ -513,17 +576,15 @@ impl TaskCenter { return Err(ShutdownError); } - let parent_id = - current_task_id().expect("spawn_child called outside of a task-center task"); - // From this point onwards, we unwrap() directly with the assumption that we are in task-center - // context and that the previous (expect) guards against reaching this point if we are - // outside task-center. - let parent_kind = current_task_kind().unwrap(); - let parent_name = CONTEXT.try_with(|ctx| ctx.name).unwrap(); + let (parent_id, parent_name, parent_kind, cancel) = self.with_task_context(|ctx| { + ( + ctx.id, + ctx.name, + ctx.kind, + ctx.cancellation_token.child_token(), + ) + }); - let cancel = CONTEXT - .try_with(|ctx| ctx.cancellation_token.child_token()) - .unwrap(); let result = self.spawn_inner(kind, name, partition_id, cancel, future); trace!( @@ -540,17 +601,17 @@ impl TaskCenter { pub fn spawn_blocking_unmanaged( &self, name: &'static str, - partition_id: Option, future: F, ) -> tokio::task::JoinHandle where F: Future + Send + 'static, O: Send + 'static, { - let tc = self.clone(); + let rt_handle = self.inner.default_runtime_handle.clone(); + let future = future.in_tc_as_task(self, TaskKind::InPlace, name); self.inner .default_runtime_handle - .spawn_blocking(move || tc.block_on(name, partition_id, future)) + .spawn_blocking(move || rt_handle.block_on(future)) } /// Cancelling the child will not cancel the parent. Note that parent task will not @@ -566,9 +627,7 @@ impl TaskCenter { where F: Future> + Send + 'static, { - let cancel = CONTEXT - .try_with(|ctx| ctx.cancellation_token.child_token()) - .expect("spawning inside task-center context"); + let cancel = self.with_task_context(|ctx| ctx.cancellation_token.child_token()); self.spawn_inner(kind, name, partition_id, cancel, future) } @@ -644,18 +703,13 @@ impl TaskCenter { /// Sets the current task_center but doesn't create a task. Use this when you need to run a /// future within task_center scope. - pub fn block_on( - &self, - name: &'static str, - partition_id: Option, - future: F, - ) -> O + pub fn block_on(&self, future: F) -> O where F: Future, { self.inner .default_runtime_handle - .block_on(self.run_in_scope(name, partition_id, future)) + .block_on(future.in_tc(self)) } /// Sets the current task_center but doesn't create a task. Use this when you need to run a @@ -669,46 +723,38 @@ impl TaskCenter { where F: Future, { - let cancel = CancellationToken::new(); + let cancellation_token = CancellationToken::new(); let id = TaskId::default(); - let metadata = self.clone_metadata(); - let context = TaskContext { + let ctx = TaskContext { id, name, kind: TaskKind::InPlace, - cancellation_token: cancel.clone(), + cancellation_token: cancellation_token.clone(), partition_id, - metadata, }; CURRENT_TASK_CENTER - .scope(self.clone(), CONTEXT.scope(context, future)) + .scope( + self.clone(), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope(ctx, future), + ), + ) .await } /// Sets the current task_center but doesn't create a task. Use this when you need to run a /// closure within task_center scope. - pub fn run_in_scope_sync( - &self, - name: &'static str, - partition_id: Option, - f: F, - ) -> O + pub fn run_in_scope_sync(&self, f: F) -> O where F: FnOnce() -> O, { - let cancel = CancellationToken::new(); - let id = TaskId::default(); - let metadata = self.clone_metadata(); - let context = TaskContext { - id, - name, - kind: TaskKind::InPlace, - partition_id, - cancellation_token: cancel.clone(), - metadata, - }; - CURRENT_TASK_CENTER.sync_scope(self.clone(), || CONTEXT.sync_scope(context, f)) + CURRENT_TASK_CENTER.sync_scope(self.clone(), || { + OVERRIDES.sync_scope(OVERRIDES.try_with(Clone::clone).unwrap_or_default(), || { + TASK_CONTEXT.sync_scope(self.with_task_context(Clone::clone), f) + }) + }) } /// Take control over the running task from task-center. This returns None if the task was not @@ -835,10 +881,11 @@ struct TaskCenterInner { current_exit_code: AtomicI32, managed_tasks: Mutex>>, global_metadata: OnceLock, + root_task_context: TaskContext, } /// This wrapper function runs in a newly-spawned task. It initializes the -/// task-local variables and calls the payload function. +/// task-local variables and wraps the inner future. async fn wrapper(task_center: TaskCenter, context: TaskContext, future: F) where F: Future> + 'static, @@ -849,12 +896,16 @@ where let result = CURRENT_TASK_CENTER .scope( task_center.clone(), - CONTEXT.scope(context, { - // We use AssertUnwindSafe here so that the wrapped function - // doesn't need to be UnwindSafe. We should not do anything after - // unwinding that'd risk us being in unwind-unsafe behavior. - AssertUnwindSafe(future).catch_unwind() - }), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope( + context, + // We use AssertUnwindSafe here so that the wrapped function + // doesn't need to be UnwindSafe. We should not do anything after + // unwinding that'd risk us being in unwind-unsafe behavior. + AssertUnwindSafe(future).catch_unwind(), + ), + ), ) .await; task_center.on_finish(id, result).await; @@ -868,29 +919,21 @@ where trace!(kind = ?context.kind, name = ?context.name, "Starting task {}", context.id); CURRENT_TASK_CENTER - .scope(task_center.clone(), CONTEXT.scope(context, future)) + .scope( + task_center.clone(), + OVERRIDES.scope( + OVERRIDES.try_with(Clone::clone).unwrap_or_default(), + TASK_CONTEXT.scope(context, future), + ), + ) .await } -/// The current task-center task kind. This returns None if we are not in the scope -/// of a task-center task. -pub fn current_task_kind() -> Option { - CONTEXT.try_with(|ctx| ctx.kind).ok() -} - -/// The current task-center task Id. This returns None if we are not in the scope -/// of a task-center task. -pub fn current_task_id() -> Option { - CONTEXT.try_with(|ctx| ctx.id).ok() -} - /// Access to global metadata handle. This is available in task-center tasks only! #[track_caller] pub fn metadata() -> Metadata { - CONTEXT - .try_with(|ctx| ctx.metadata.clone()) - .expect("metadata() called outside task-center scope") - .expect("metadata() called before global metadata was set") + // todo: migrate call-sites + Metadata::current() } #[track_caller] @@ -898,32 +941,41 @@ pub fn with_metadata(f: F) -> R where F: FnOnce(&Metadata) -> R, { - CURRENT_TASK_CENTER.with(|tc| { - f(tc.metadata() - .expect("metadata must be set. Is global metadata set?")) - }) + Metadata::with_current(f) } /// Access to this node id. This is available in task-center tasks only! #[track_caller] pub fn my_node_id() -> GenerationalNodeId { - CONTEXT - .try_with(|ctx| ctx.metadata.as_ref().map(|m| m.my_node_id())) - .expect("my_node_id() called outside task-center scope") - .expect("my_node_id() called before global metadata was set") + // todo: migrate call-sites + Metadata::with_current(|m| m.my_node_id()) +} + +/// The current task-center task Id. This returns None if we are not in the scope +/// of a task-center task. +pub fn current_task_id() -> Option { + TASK_CONTEXT + .try_with(|ctx| Some(ctx.id)) + .unwrap_or(TaskCenter::try_with_current(|tc| { + tc.inner.root_task_context.id + })) } /// The current partition Id associated to the running task-center task. pub fn current_task_partition_id() -> Option { - CONTEXT.try_with(|ctx| ctx.partition_id).ok().flatten() + TASK_CONTEXT + .try_with(|ctx| Some(ctx.partition_id)) + .unwrap_or(TaskCenter::try_with_current(|tc| { + tc.inner.root_task_context.partition_id + })) + .flatten() } /// Get the current task center. Use this to spawn tasks on the current task center. /// This must be called from within a task-center task. pub fn task_center() -> TaskCenter { - CURRENT_TASK_CENTER - .try_with(|t| t.clone()) - .expect("task_center() called in a task-center task") + // migrate call-sites + TaskCenter::current() } /// A Future that can be used to check if the current task has been requested to @@ -939,7 +991,7 @@ pub async fn cancellation_watcher() { /// cancel_task() call, or if it's a child and the parent is being cancelled by a /// cancel_task() call, this cancellation token will be set to cancelled. pub fn cancellation_token() -> CancellationToken { - let res = CONTEXT.try_with(|ctx| ctx.cancellation_token.clone()); + let res = TASK_CONTEXT.try_with(|ctx| ctx.cancellation_token.clone()); if cfg!(any(test, feature = "test-util")) { // allow in tests to call from non-task-center tasks. @@ -951,7 +1003,7 @@ pub fn cancellation_token() -> CancellationToken { /// Has the current task been requested to cancel? pub fn is_cancellation_requested() -> bool { - CONTEXT + TASK_CONTEXT .try_with(|ctx| ctx.cancellation_token.is_cancelled()) .unwrap_or_else(|_| { if cfg!(any(test, feature = "test-util")) { @@ -966,7 +1018,6 @@ mod tests { use super::*; use googletest::prelude::*; - use restate_test_util::assert_eq; use restate_types::config::CommonOptionsBuilder; use tracing_test::traced_test; @@ -986,7 +1037,6 @@ mod tests { tc.spawn(TaskKind::RoleRunner, "worker-role", None, async { info!("Hello async"); tokio::time::sleep(Duration::from_secs(10)).await; - assert_eq!(TaskKind::RoleRunner, current_task_kind().unwrap()); info!("Bye async"); Ok(()) }) diff --git a/crates/core/src/task_center/task.rs b/crates/core/src/task_center/task.rs index c595e5674..beb1b08f8 100644 --- a/crates/core/src/task_center/task.rs +++ b/crates/core/src/task_center/task.rs @@ -18,7 +18,7 @@ use tokio_util::sync::CancellationToken; use restate_types::identifiers::PartitionId; use super::{TaskId, TaskKind}; -use crate::{Metadata, ShutdownError}; +use crate::ShutdownError; #[derive(Clone)] pub(super) struct TaskContext { @@ -31,8 +31,6 @@ pub(super) struct TaskContext { /// Tasks associated with a specific partition ID will have this set. This allows /// for cancellation of tasks associated with that partition. pub(super) partition_id: Option, - /// Access to a locally-cached metadata view. - pub(super) metadata: Option, } pub(super) struct Task { diff --git a/crates/core/src/task_center/task_kind.rs b/crates/core/src/task_center/task_kind.rs index 463000a3c..2b486d555 100644 --- a/crates/core/src/task_center/task_kind.rs +++ b/crates/core/src/task_center/task_kind.rs @@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use strum::EnumProperty; -static NEXT_TASK_ID: AtomicU64 = const { AtomicU64::new(0) }; +static NEXT_TASK_ID: AtomicU64 = const { AtomicU64::new(1) }; #[derive( Clone, @@ -36,6 +36,8 @@ impl Default for TaskId { } impl TaskId { + pub const ROOT: TaskId = TaskId(0); + pub fn new() -> Self { Default::default() } diff --git a/crates/ingress-http/Cargo.toml b/crates/ingress-http/Cargo.toml index fcc77e4fd..8b47165b4 100644 --- a/crates/ingress-http/Cargo.toml +++ b/crates/ingress-http/Cargo.toml @@ -37,7 +37,7 @@ hyper-util = { workspace = true, features = ["http1", "http2", "server", "tokio" metrics = { workspace = true } opentelemetry = { workspace = true } opentelemetry_sdk = { workspace = true } -pin-project-lite = "0.2.13" +pin-project-lite = { workspace = true } schemars = { workspace = true, optional = true } serde = { workspace = true } serde_with = { workspace = true } diff --git a/crates/metadata-store/src/local/tests.rs b/crates/metadata-store/src/local/tests.rs index 15de671ec..9c6023155 100644 --- a/crates/metadata-store/src/local/tests.rs +++ b/crates/metadata-store/src/local/tests.rs @@ -331,9 +331,7 @@ async fn create_test_environment( let task_center = &env.tc; - task_center.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(config.clone().map(|c| &c.common)) - }); + task_center.run_in_scope_sync(|| RocksDbManager::init(config.clone().map(|c| &c.common))); let client = start_metadata_store( config.pinned().common.metadata_store_client.clone(), diff --git a/crates/node/src/network_server/grpc_svc_handler.rs b/crates/node/src/network_server/grpc_svc_handler.rs index 835192b6b..de7824450 100644 --- a/crates/node/src/network_server/grpc_svc_handler.rs +++ b/crates/node/src/network_server/grpc_svc_handler.rs @@ -21,7 +21,9 @@ use restate_core::network::protobuf::node_svc::{ }; use restate_core::network::ConnectionManager; use restate_core::network::{ProtocolError, TransportConnect}; -use restate_core::{metadata, MetadataKind, TargetVersion, TaskCenter}; +use restate_core::{ + metadata, MetadataKind, TargetVersion, TaskCenter, TaskCenterFutureExt, TaskKind, +}; use restate_types::health::Health; use restate_types::nodes_config::Role; use restate_types::protobuf::node::Message; @@ -61,7 +63,7 @@ impl NodeSvc for NodeSvcHandler { let metadata_server_status = self.health.current_metadata_server_status(); let log_server_status = self.health.current_log_server_status(); let age_s = self.task_center.age().as_secs(); - self.task_center.run_in_scope_sync("get_ident", None, || { + self.task_center.run_in_scope_sync(|| { let metadata = metadata(); Ok(Response::new(IdentResponse { status: node_status.into(), @@ -98,12 +100,9 @@ impl NodeSvc for NodeSvcHandler { let incoming = request.into_inner(); let transformed = incoming.map(|x| x.map_err(ProtocolError::from)); let output_stream = self - .task_center - .run_in_scope( - "accept-connection", - None, - self.connections.accept_incoming_connection(transformed), - ) + .connections + .accept_incoming_connection(transformed) + .in_current_tc_as_task(TaskKind::InPlace, "accept-connection") .await?; // For uniformity with outbound connections, we map all responses to Ok, we never rely on diff --git a/crates/partition-store/benches/basic_benchmark.rs b/crates/partition-store/benches/basic_benchmark.rs index 906983ebe..4d9c87b62 100644 --- a/crates/partition-store/benches/basic_benchmark.rs +++ b/crates/partition-store/benches/basic_benchmark.rs @@ -47,10 +47,8 @@ fn basic_writing_reading_benchmark(c: &mut Criterion) { .expect("task_center builds"); let worker_options = WorkerOptions::default(); - tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); - let rocksdb = tc.block_on("test-setup", None, async { + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); + let rocksdb = tc.block_on(async { // // setup // diff --git a/crates/partition-store/src/tests/mod.rs b/crates/partition-store/src/tests/mod.rs index 1f365387f..07102e22c 100644 --- a/crates/partition-store/src/tests/mod.rs +++ b/crates/partition-store/src/tests/mod.rs @@ -52,9 +52,7 @@ async fn storage_test_environment_with_manager() -> (PartitionStoreManager, Part .ingress_runtime_handle(tokio::runtime::Handle::current()) .build() .expect("task_center builds"); - tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let worker_options = Live::from_value(WorkerOptions::default()); let manager = PartitionStoreManager::create( worker_options.clone().map(|c| &c.storage), diff --git a/crates/service-client/Cargo.toml b/crates/service-client/Cargo.toml index c3835836b..ca85eb334 100644 --- a/crates/service-client/Cargo.toml +++ b/crates/service-client/Cargo.toml @@ -51,7 +51,7 @@ aws-smithy-async = {version = "1.2.1", default-features = false} aws-smithy-runtime = {version = "1.6.2", default-features = false} aws-smithy-runtime-api = {version = "1.7.1", default-features = false} aws-smithy-types = { version = "1.2.0", default-features = false} -pin-project-lite = "0.2.13" +pin-project-lite = { workspace = true } [dev-dependencies] tempfile = { workspace = true } diff --git a/crates/storage-query-datafusion/src/mocks.rs b/crates/storage-query-datafusion/src/mocks.rs index 9eb16d458..031fbe7c1 100644 --- a/crates/storage-query-datafusion/src/mocks.rs +++ b/crates/storage-query-datafusion/src/mocks.rs @@ -161,9 +161,8 @@ impl MockQueryEngine { + 'static, ) -> Self { // Prepare Rocksdb - task_center().run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + task_center() + .run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let worker_options = Live::from_value(WorkerOptions::default()); let manager = PartitionStoreManager::create( worker_options.clone().map(|c| &c.storage), diff --git a/crates/worker/src/partition/leadership.rs b/crates/worker/src/partition/leadership.rs index d8be3056b..0810ad1b7 100644 --- a/crates/worker/src/partition/leadership.rs +++ b/crates/worker/src/partition/leadership.rs @@ -1136,9 +1136,7 @@ mod tests { let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); - tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + tc.run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let bifrost = tc .run_in_scope( diff --git a/crates/worker/src/partition_processor_manager/mod.rs b/crates/worker/src/partition_processor_manager/mod.rs index 12a27bd13..f711376ce 100644 --- a/crates/worker/src/partition_processor_manager/mod.rs +++ b/crates/worker/src/partition_processor_manager/mod.rs @@ -651,7 +651,6 @@ impl PartitionProcessorManager { // where doing otherwise appears to starve the Tokio event loop, causing very slow startup. let handle = self.task_center.spawn_blocking_unmanaged( "starting-partition-processor", - Some(partition_id), starting_task.run(), ); @@ -955,11 +954,9 @@ mod tests { TestCoreEnvBuilder::with_incoming_only_connector().set_nodes_config(nodes_config); let health_status = HealthStatus::default(); - env_builder - .tc - .run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())); - }); + env_builder.tc.run_in_scope_sync(|| { + RocksDbManager::init(Constant::new(CommonOptions::default())); + }); let bifrost_svc = BifrostService::new(env_builder.tc.clone(), env_builder.metadata.clone()) .with_factory(memory_loglet::Factory::default()); diff --git a/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs b/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs index 20cfe936f..4301bd62f 100644 --- a/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs +++ b/crates/worker/src/partition_processor_manager/persisted_lsn_watchdog.rs @@ -178,9 +178,9 @@ mod tests { let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); - node_env.tc.run_in_scope_sync("db-manager-init", None, || { - RocksDbManager::init(Constant::new(CommonOptions::default())) - }); + node_env + .tc + .run_in_scope_sync(|| RocksDbManager::init(Constant::new(CommonOptions::default()))); let all_partition_keys = RangeInclusive::new(0, PartitionKey::MAX); let partition_store_manager = PartitionStoreManager::create( diff --git a/server/src/main.rs b/server/src/main.rs index 56d81e54e..fc1d67d5f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -161,7 +161,7 @@ fn main() { .options(Configuration::pinned().common.clone()) .build() .expect("task_center builds"); - tc.block_on("main", None, { + tc.block_on({ let tc = tc.clone(); async move { // Apply tracing config globally diff --git a/server/tests/common/replicated_loglet.rs b/server/tests/common/replicated_loglet.rs index 9604c1005..90fed46c2 100644 --- a/server/tests/common/replicated_loglet.rs +++ b/server/tests/common/replicated_loglet.rs @@ -7,6 +7,7 @@ use googletest::IntoTestResult; use restate_bifrost::{loglet::Loglet, Bifrost, BifrostAdmin}; use restate_core::metadata_store::Precondition; +use restate_core::TaskCenterFutureExt; use restate_core::{metadata_store::MetadataStoreClient, MetadataWriter, TaskCenterBuilder}; use restate_local_cluster_runner::{ cluster::{Cluster, MaybeTempDir, StartedCluster}, @@ -120,7 +121,7 @@ where // this will still respect LOCAL_CLUSTER_RUNNER_RETAIN_TEMPDIR=true let base_dir: MaybeTempDir = tempfile::tempdir()?.into(); - tc.run_in_scope("test", None, async { + async { RocksDbManager::init(Configuration::mapped_updateable(|c| &c.common)); let cluster = Cluster::builder() @@ -167,19 +168,16 @@ where .await?; // global metadata should now be set, running in scope sets it in the task center context - tc.run_in_scope( - "test-fn", - None, - future(TestEnv { - bifrost, - loglet, - cluster, - metadata_writer, - metadata_store_client, - }), - ) + future(TestEnv { + bifrost, + loglet, + cluster, + metadata_writer, + metadata_store_client, + }) .await - }) + } + .with_tc(&tc) .await?; tc.shutdown_node("test completed", 0).await; diff --git a/tools/bifrost-benchpress/Cargo.toml b/tools/bifrost-benchpress/Cargo.toml index 7563ace25..433a8642f 100644 --- a/tools/bifrost-benchpress/Cargo.toml +++ b/tools/bifrost-benchpress/Cargo.toml @@ -26,7 +26,7 @@ restate-metadata-store = { workspace = true } restate-rocksdb = { workspace = true } restate-test-util = { workspace = true } restate-tracing-instrumentation = { workspace = true, features = ["rt-tokio"] } -restate-types = { workspace = true, features = ["test-util"] } +restate-types = { workspace = true, features = ["test-util", "clap"] } anyhow = { workspace = true } bytes = { workspace = true } diff --git a/tools/bifrost-benchpress/src/main.rs b/tools/bifrost-benchpress/src/main.rs index 43b28a18c..5ef260e63 100644 --- a/tools/bifrost-benchpress/src/main.rs +++ b/tools/bifrost-benchpress/src/main.rs @@ -92,7 +92,7 @@ fn main() -> anyhow::Result<()> { let (tc, bifrost) = spawn_environment(Configuration::updateable(), 1); let task_center = tc.clone(); let args = cli_args.clone(); - tc.block_on("benchpress", None, async move { + tc.block_on(async move { let tracing_guard = init_tracing_and_logging(&config.common, "Bifrost benchpress") .expect("failed to configure logging and tracing!"); @@ -149,7 +149,7 @@ fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, .expect("task_center builds"); let task_center = tc.clone(); - let bifrost = tc.block_on("spawn", None, async move { + let bifrost = tc.block_on(async move { let metadata_builder = MetadataBuilder::default(); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); diff --git a/tools/bifrost-benchpress/src/write_to_read.rs b/tools/bifrost-benchpress/src/write_to_read.rs index e5a4ad65d..76759982d 100644 --- a/tools/bifrost-benchpress/src/write_to_read.rs +++ b/tools/bifrost-benchpress/src/write_to_read.rs @@ -17,7 +17,7 @@ use hdrhistogram::Histogram; use tracing::info; use restate_bifrost::Bifrost; -use restate_core::{TaskCenter, TaskHandle, TaskKind}; +use restate_core::{Metadata, TaskCenter, TaskHandle, TaskKind}; use restate_types::logs::{KeyFilter, LogId, Lsn, SequenceNumber, WithKeys}; use crate::util::{print_latencies, DummyPayload}; @@ -140,7 +140,7 @@ pub async fn run( println!( "Log Chain: {:#?}", - tc.metadata().unwrap().logs_ref().chain(&LOG_ID).unwrap() + Metadata::current().logs_ref().chain(&LOG_ID).unwrap() ); println!("Payload size per record: {} bytes", args.payload_size); println!();