diff --git a/kernel/src/executor/mod.rs b/kernel/src/executor/mod.rs index 3b7cb877..2dfa2745 100644 --- a/kernel/src/executor/mod.rs +++ b/kernel/src/executor/mod.rs @@ -78,12 +78,13 @@ mod queue; mod scheduler; mod task; +mod wake_list; mod yield_now; +use crate::executor::task::JoinHandle; use core::future::Future; use rand::RngCore; use sync::OnceLock; -pub use task::JoinHandle; static EXECUTOR: OnceLock = OnceLock::new(); @@ -91,7 +92,7 @@ pub struct Executor { /// Handle to the scheduler used by this runtime // If we ever want to support multiple runtimes, this should become an enum over the different // variants. For now, we only support the multithreaded scheduler. - scheduler: scheduler::multi_thread::Handle, + scheduler: scheduler::Handle, } /// Get a reference to the current executor. @@ -109,7 +110,7 @@ pub fn current() -> &'static Executor { pub fn init(num_cores: usize, rng: &mut impl RngCore, shutdown_on_idle: bool) -> &'static Executor { #[expect(tail_expr_drop_order, reason = "")] EXECUTOR.get_or_init(|| Executor { - scheduler: scheduler::multi_thread::Handle::new(num_cores, rng, shutdown_on_idle), + scheduler: scheduler::Handle::new(num_cores, rng, shutdown_on_idle), }) } @@ -118,7 +119,7 @@ pub fn init(num_cores: usize, rng: &mut impl RngCore, shutdown_on_idle: bool) -> /// This function will not return until the runtime is shut down. #[inline] pub fn run(rt: &'static Executor, hartid: usize, initial: impl FnOnce()) -> Result<(), ()> { - scheduler::multi_thread::worker::run(&rt.scheduler, hartid, initial) + scheduler::worker::run(&rt.scheduler, hartid, initial) } impl Executor { diff --git a/kernel/src/executor/scheduler/multi_thread/idle.rs b/kernel/src/executor/scheduler/idle.rs similarity index 100% rename from kernel/src/executor/scheduler/multi_thread/idle.rs rename to kernel/src/executor/scheduler/idle.rs diff --git a/kernel/src/executor/scheduler/mod.rs b/kernel/src/executor/scheduler/mod.rs index 72432c45..02fc68dd 100644 --- a/kernel/src/executor/scheduler/mod.rs +++ b/kernel/src/executor/scheduler/mod.rs @@ -5,4 +5,116 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. 
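
The executor entry points touched above (`init`, `run`, `current`) are wired together per hart roughly as in the sketch below; `NUM_HARTS` and `boot_rng()` are placeholders for whatever the kernel already uses to count harts and seed a `rand::RngCore`, and are not part of this patch:

    fn hart_main(hartid: usize) -> ! {
        // Idempotent: only the first caller actually constructs the Executor
        // behind the OnceLock, every other hart just gets the shared reference.
        let rt = crate::executor::init(NUM_HARTS, &mut boot_rng(), false);
        // Enters this hart's scheduling loop; returns only once the executor
        // has been shut down.
        crate::executor::run(rt, hartid, || {}).expect("worker loop failed");
        loop {}
    }
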
-pub mod multi_thread; +mod idle; +pub mod worker; + +use crate::executor::scheduler::idle::Idle; +use crate::executor::task::{JoinHandle, OwnedTasks, TaskRef}; +use crate::executor::{queue, task}; +use crate::hart_local::HartLocal; +use crate::util::condvar::Condvar; +use crate::util::fast_rand::FastRand; +use crate::util::parking_spot::ParkingSpot; +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::future::Future; +use core::sync::atomic::{AtomicBool, Ordering}; +use core::task::Waker; +use rand::RngCore; +use sync::Mutex; +pub struct Handle { + shared: worker::Shared, +} + +impl Handle { + #[expect(tail_expr_drop_order, reason = "")] + pub fn new(num_cores: usize, rand: &mut impl RngCore, shutdown_on_idle: bool) -> Self { + let mut cores = Vec::with_capacity(num_cores); + let mut remotes = Vec::with_capacity(num_cores); + + for i in 0..num_cores { + let (steal, run_queue) = queue::new(); + + cores.push(Box::new(worker::Core { + index: i, + run_queue, + lifo_slot: None, + is_searching: false, + rand: FastRand::new(rand.next_u64()), + })); + remotes.push(worker::Remote { steal }); + } + + let (idle, idle_synced) = Idle::new(cores); + + let stub = TaskRef::new_stub(); + let run_queue = mpsc_queue::MpscQueue::new_with_stub(stub); + + Self { + shared: worker::Shared { + shutdown: AtomicBool::new(false), + remotes: remotes.into_boxed_slice(), + owned: OwnedTasks::new(), + synced: Mutex::new(worker::Synced { + assigned_cores: (0..num_cores).map(|_| None).collect(), + idle: idle_synced, + shutdown_cores: Vec::with_capacity(num_cores), + }), + run_queue, + idle, + condvars: (0..num_cores).map(|_| Condvar::new()).collect(), + parking_spot: ParkingSpot::default(), + per_hart: HartLocal::with_capacity(num_cores), + shutdown_on_idle, + }, + } + } + + pub fn spawn(&'static self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let id = task::Id::next(); + let (handle, maybe_task) = self.shared.owned.bind(future, self, id); + + if let Some(task) = maybe_task { + self.shared.schedule_task(task, false); + } + + handle + } + + pub fn shutdown(&self) { + if !self.shared.shutdown.swap(true, Ordering::AcqRel) { + let mut synced = self.shared.synced.lock(); + + // Set the shutdown flag on all available cores + self.shared.idle.shutdown(&mut synced, &self.shared); + + // Any unassigned cores need to be shutdown, but we have to first drop + // the lock + drop(synced); + self.shared.idle.shutdown_unassigned_cores(&self.shared); + } + } + + #[inline] + pub(in crate::executor) fn defer(&self, waker: &Waker) { + self.shared.per_hart.get().unwrap().defer(waker); + } +} + +impl task::Schedule for &'static Handle { + fn schedule(&self, task: TaskRef) { + self.shared.schedule_task(task, false); + } + + fn release(&self, task: &TaskRef) -> Option { + self.shared.owned.remove(task) + } + + fn yield_now(&self, task: TaskRef) { + self.shared.schedule_task(task, true); + } +} diff --git a/kernel/src/executor/scheduler/multi_thread/mod.rs b/kernel/src/executor/scheduler/multi_thread/mod.rs deleted file mode 100644 index 4379782f..00000000 --- a/kernel/src/executor/scheduler/multi_thread/mod.rs +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2025 Jonas Kruckenberg -// -// Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be -// copied, modified, or distributed except according to those terms. 
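
A usage sketch for the consolidated `Handle` above. The `&'static` borrow mirrors the `spawn(&'static self, ..)` signature, and the example assumes the `JoinHandle` future resolves to `task::Result<T>`:

    async fn answer(handle: &'static scheduler::Handle) -> i32 {
        // `spawn` binds the future into the shared `OwnedTasks` list and, if the
        // task is immediately runnable, pushes it onto a run queue via `schedule_task`.
        let join = handle.spawn(async { 40 + 2 });
        join.await.expect("task was cancelled")
    }
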
- -mod idle; -pub mod worker; - -use crate::executor::task::{OwnedTasks, TaskRef}; -use crate::executor::{queue, task, JoinHandle}; -use crate::hart_local::HartLocal; -use crate::util::condvar::Condvar; -use crate::util::fast_rand::FastRand; -use crate::util::parking_spot::ParkingSpot; -use alloc::boxed::Box; -use alloc::vec::Vec; -use core::future::Future; -use core::sync::atomic::{AtomicBool, Ordering}; -use core::task::Waker; -use idle::Idle; -use rand::RngCore; -use sync::Mutex; -use worker::{Core, Remote, Shared, Synced}; - -pub struct Handle { - shared: Shared, -} - -impl Handle { - #[expect(tail_expr_drop_order, reason = "")] - pub fn new(num_cores: usize, rand: &mut impl RngCore, shutdown_on_idle: bool) -> Self { - let mut cores = Vec::with_capacity(num_cores); - let mut remotes = Vec::with_capacity(num_cores); - - for i in 0..num_cores { - let (steal, run_queue) = queue::new(); - - cores.push(Box::new(Core { - index: i, - run_queue, - lifo_slot: None, - is_searching: false, - rand: FastRand::new(rand.next_u64()), - })); - remotes.push(Remote { steal }); - } - - let (idle, idle_synced) = Idle::new(cores); - - let stub = TaskRef::new_stub(); - let run_queue = mpsc_queue::MpscQueue::new_with_stub(stub); - - Self { - shared: Shared { - shutdown: AtomicBool::new(false), - remotes: remotes.into_boxed_slice(), - owned: OwnedTasks::new(), - synced: Mutex::new(Synced { - assigned_cores: (0..num_cores).map(|_| None).collect(), - idle: idle_synced, - shutdown_cores: Vec::with_capacity(num_cores), - }), - run_queue, - idle, - condvars: (0..num_cores).map(|_| Condvar::new()).collect(), - parking_spot: ParkingSpot::default(), - tls: HartLocal::with_capacity(num_cores), - shutdown_on_idle, - }, - } - } - - pub fn spawn(&'static self, future: F) -> JoinHandle - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - let id = task::Id::next(); - let (handle, maybe_task) = self.shared.owned.bind(future, self, id); - - if let Some(task) = maybe_task { - self.shared.schedule_task(task, false); - } - - handle - } - - pub fn shutdown(&self) { - if !self.shared.shutdown.swap(true, Ordering::AcqRel) { - let mut synced = self.shared.synced.lock(); - - // Set the shutdown flag on all available cores - self.shared.idle.shutdown(&mut synced, &self.shared); - - // Any unassigned cores need to be shutdown, but we have to first drop - // the lock - drop(synced); - self.shared.idle.shutdown_unassigned_cores(&self.shared); - } - } - - #[inline] - pub(in crate::executor) fn defer(&self, waker: &Waker) { - self.shared.tls.get().unwrap().defer(waker); - } -} - -impl task::Schedule for &'static Handle { - fn schedule(&self, task: TaskRef) { - self.shared.schedule_task(task, false); - } - - fn release(&self, task: &TaskRef) -> Option { - self.shared.owned.remove(task) - } - - fn yield_now(&self, task: TaskRef) { - self.shared.schedule_task(task, true); - } -} diff --git a/kernel/src/executor/scheduler/multi_thread/worker.rs b/kernel/src/executor/scheduler/worker.rs similarity index 95% rename from kernel/src/executor/scheduler/multi_thread/worker.rs rename to kernel/src/executor/scheduler/worker.rs index 49ea12c0..94578ca2 100644 --- a/kernel/src/executor/scheduler/multi_thread/worker.rs +++ b/kernel/src/executor/scheduler/worker.rs @@ -5,10 +5,12 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -//! A scheduler is initialized with a fixed number of workers. Each worker is -//! driven by a thread. 
Each worker has a "core" which contains data such as the -//! run queue and other state. When `block_in_place` is called, the worker's -//! "core" is handed off to a new thread allowing the scheduler to continue to +//! Scheduler worker implementation. +//! +//! A scheduler worker is a hart that is running the scheduling loop and which we therefore can +//! schedule work on. A scheduler is initialized with a fixed number of workers. Each worker has +//! a "core" which contains data such as the run queue and other state. When `block_in_place` is called, +//! the worker's "core" is handed off to a new thread allowing the scheduler to continue to //! make progress while the originating thread blocks. //! //! # Shutdown @@ -63,8 +65,8 @@ //! the global queue indefinitely. This would be a ref-count cycle and a memory //! leak. -use super::{idle, Handle}; use crate::executor::queue::Overflow; +use crate::executor::scheduler::{idle, Handle}; use crate::executor::task::{OwnedTasks, TaskRef}; use crate::executor::{queue, task}; use crate::hart_local::HartLocal; @@ -95,8 +97,8 @@ static NUM_NOTIFY_LOCAL: Counter = counter!("scheduler.num-notify-local"); /// A scheduler worker /// -/// Data is stack-allocated and never migrates threads -pub struct Worker { +/// Data is stack-allocated and never migrates harts. +struct Worker { hartid: usize, /// True if the scheduler is being shutdown is_shutdown: bool, @@ -113,9 +115,14 @@ pub struct Worker { /// Core data /// -/// Data is heap-allocated and migrates threads. +/// Data is heap-allocated and migrates harts. +/// +/// You can think of `Core`s and `Worker`s a bit like robots with pluggable batteries. Just like a +/// robot needs the battery to operate, a `Worker` needs a `Core` to operate. Workers are cooperative +/// and will give up their `Core` if they are done with their work or become blocked waiting for +/// interrupts. This allows other `Worker`s to pick up the `Core` and continue work. #[repr(align(128))] -pub(super) struct Core { +pub struct Core { /// Index holding this core's remote/shared state. pub(super) index: usize, /// The worker-local run queue. @@ -150,7 +157,7 @@ pub(super) struct Shared { /// Per-hart thread-local data. Logically this is part of the [`Worker`] struct, but placed here /// into a TLS slot instead of stack allocated so we can access it from other places (i.e. we only /// need access to the scheduler handle instead of access to the workers stack which wouldn't work). - pub(super) tls: HartLocal, + pub(super) per_hart: HartLocal, /// Signal to workers that they should be shutting down. pub(super) shutdown: AtomicBool, /// Whether to shut down the executor when all tasks are processed, used in tests. @@ -178,7 +185,10 @@ pub(super) struct Remote { pub(super) steal: queue::Steal, } -/// Thread-local context +/// Hart-local context +/// +/// Logically this is part of the [`Worker`] struct, but is kept separate to allow access from +/// other parts of the code. 
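
To make the battery analogy concrete: the handoff is essentially moving the boxed `Core` out of the hart-local `Context` so another worker can claim it. In the sketch below `release_core` is a made-up stand-in for the real idle-set bookkeeping:

    fn unplug(cx: &Context) {
        // Take the "battery" out of this worker...
        if let Some(core) = cx.core.borrow_mut().take() {
            // ...and hand it back so another worker can pick it up
            // (hypothetical helper; the real path goes through the idle/Synced state).
            release_core(cx.handle, core);
        }
    }
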
pub(super) struct Context { /// Handle to the current scheduler handle: &'static Handle, @@ -208,7 +218,7 @@ pub fn run(handle: &'static Handle, hartid: usize, initial: impl FnOnce()) -> Re }; #[expect(tail_expr_drop_order, reason = "")] - let cx = handle.shared.tls.get_or(|| Context { + let cx = handle.shared.per_hart.get_or(|| Context { handle, core: RefCell::new(None), lifo_enabled: Cell::new(true), @@ -518,6 +528,7 @@ impl Worker { // Safety: we're parking only for a very small amount of time, this is fine unsafe { + // TODO cleanup log::trace!("spin stalling for {:?}", Duration::from_micros(i as u64)); arch::hart_park_timeout(Duration::from_micros(i as u64)); log::trace!("after spin stall"); @@ -740,7 +751,7 @@ impl Core { impl Shared { pub(in crate::executor) fn schedule_task(&self, task: TaskRef, is_yield: bool) { - if let Some(cx) = self.tls.get() { + if let Some(cx) = self.per_hart.get() { // And the current thread still holds a core if let Some(core) = cx.core.borrow_mut().as_mut() { if is_yield { diff --git a/kernel/src/executor/task/join_handle.rs b/kernel/src/executor/task/join_handle.rs index f01946df..461dd261 100644 --- a/kernel/src/executor/task/join_handle.rs +++ b/kernel/src/executor/task/join_handle.rs @@ -5,7 +5,8 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use super::raw::{Header, TaskRef}; +use super::raw::Header; +use crate::executor::task::TaskRef; use core::fmt; use core::future::Future; use core::marker::PhantomData; diff --git a/kernel/src/executor/task/mod.rs b/kernel/src/executor/task/mod.rs index 28846cb6..6cec683b 100644 --- a/kernel/src/executor/task/mod.rs +++ b/kernel/src/executor/task/mod.rs @@ -5,191 +5,19 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -//! The task module. -//! -//! The task module contains the code that manages spawned tasks and provides a -//! safe API for the rest of the runtime to use. Each task in a runtime is -//! stored in an `OwnedTasks` or `LocalOwnedTasks` object. -//! -//! # Task reference types -//! -//! A task is usually referenced by multiple handles, and there are several -//! types of handles. -//! -//! * `OwnedTask` - tasks stored in an `OwnedTasks` or `LocalOwnedTasks` are of this -//! reference type. -//! -//! * `JoinHandle` - each task has a `JoinHandle` that allows access to the output -//! of the task. -//! -//! * `Waker` - every waker for a task has this reference type. There can be any -//! number of waker references. -//! -//! * `Notified` - tracks whether the task is notified. -//! -//! * `Unowned` - this task reference type is used for tasks not stored in any -//! runtime. Mainly used for blocking tasks, but also in tests. -//! -//! The task uses a reference count to keep track of how many active references -//! exist. The `Unowned` reference type takes up two ref-counts. All other -//! reference types take up a single ref-count. -//! -//! Besides the waker type, each task has at most one of each reference type. -//! //! # State //! -//! The task stores its state in an atomic `usize` with various bitfields for the -//! necessary information. The state has the following bitfields: -//! -//! * `RUNNING` - Tracks whether the task is currently being polled or cancelled. -//! This bit functions as a lock around the task. -//! -//! * `COMPLETE` - Is one once the future has fully completed and has been -//! 
dropped. Never unset once set. Never set together with RUNNING. -//! -//! * `NOTIFIED` - Tracks whether a Notified object currently exists. -//! -//! * `CANCELLED` - Is set to one for tasks that should be cancelled as soon as -//! possible. May take any value for completed tasks. -//! -//! * `JOIN_INTEREST` - Is set to one if there exists a `JoinHandle`. -//! -//! * `JOIN_WAKER` - Acts as an access control bit for the join handle waker. The -//! protocol for its usage is described below. -//! -//! The rest of the bits are used for the ref-count. -//! -//! # Fields in the task + //! //! The task has various fields. This section describes how and when it is safe //! to access a field. -//! -//! * The state field is accessed with atomic instructions. -//! -//! * The `OwnedTask` reference has exclusive access to the `owned` field. -//! -//! * The Notified reference has exclusive access to the `queue_next` field. -//! -//! * The `owner_id` field can be set as part of construction of the task, but -//! is otherwise immutable and anyone can access the field immutably without -//! synchronization. -//! -//! * If COMPLETE is one, then the `JoinHandle` has exclusive access to the -//! stage field. If COMPLETE is zero, then the RUNNING bitfield functions as -//! a lock for the stage field, and it can be accessed only by the thread -//! that set RUNNING to one. -//! -//! * The waker field may be concurrently accessed by different threads: in one -//! thread the runtime may complete a task and *read* the waker field to -//! invoke the waker, and in another thread the task's `JoinHandle` may be -//! polled, and if the task hasn't yet completed, the `JoinHandle` may *write* -//! a waker to the waker field. The `JOIN_WAKER` bit ensures safe access by -//! multiple threads to the waker field using the following rules: -//! -//! 1. `JOIN_WAKER` is initialized to zero. -//! -//! 2. If `JOIN_WAKER` is zero, then the `JoinHandle` has exclusive (mutable) -//! access to the waker field. -//! -//! 3. If `JOIN_WAKER` is one, then the `JoinHandle` has shared (read-only) -//! access to the waker field. -//! -//! 4. If `JOIN_WAKER` is one and COMPLETE is one, then the runtime has shared -//! (read-only) access to the waker field. -//! -//! 5. If the `JoinHandle` needs to write to the waker field, then the -//! `JoinHandle` needs to (i) successfully set `JOIN_WAKER` to zero if it is -//! not already zero to gain exclusive access to the waker field per rule -//! 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER` to one. -//! If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped -//! to clear the waker field, only steps (i) and (ii) are relevant. -//! -//! 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e. -//! the task hasn't yet completed). The runtime can change `JOIN_WAKER` only -//! if COMPLETE is one. -//! -//! 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the runtime has -//! exclusive (mutable) access to the waker field. This might happen if the -//! `JoinHandle` gets dropped right after the task completes and the runtime -//! sets the `COMPLETE` bit. In this case the runtime needs the mutable access -//! to the waker field to drop it. -//! -//! Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a -//! race. If step (i) fails, then the attempt to write a waker is aborted. If -//! step (iii) fails because COMPLETE is set to one by another thread after -//! step (i), then the waker field is cleared. Once COMPLETE is one (i.e. -//! 
task has completed), the `JoinHandle` will not modify `JOIN_WAKER`. After the -//! runtime sets COMPLETE to one, it invokes the waker if there is one so in this -//! case when a task completes the `JOIN_WAKER` bit implicates to the runtime -//! whether it should invoke the waker or not. After the runtime is done with -//! using the waker during task completion, it unsets the `JOIN_WAKER` bit to give -//! the `JoinHandle` exclusive access again so that it is able to drop the waker -//! at a later point. -//! -//! All other fields are immutable and can be accessed immutably without -//! synchronization by anyone. -//! -//! # Safety -//! -//! This section goes through various situations and explains why the API is -//! safe in that situation. -//! -//! ## Polling or dropping the future -//! -//! Any mutable access to the future happens after obtaining a lock by modifying -//! the RUNNING field, so exclusive access is ensured. -//! -//! When the task completes, exclusive access to the output is transferred to -//! the `JoinHandle`. If the `JoinHandle` is already dropped when the transition to -//! complete happens, the thread performing that transition retains exclusive -//! access to the output and should immediately drop it. -//! -//! ## Non-Send futures -//! -//! If a future is not Send, then it is bound to a `LocalOwnedTasks`. The future -//! will only ever be polled or dropped given a `LocalNotified` or inside a call -//! to `LocalOwnedTasks::shutdown_all`. In either case, it is guaranteed that the -//! future is on the right thread. -//! -//! If the task is never removed from the `LocalOwnedTasks`, then it is leaked, so -//! there is no risk that the task is dropped on some other thread when the last -//! ref-count drops. -//! -//! ## Non-Send output -//! -//! When a task completes, the output is placed in the stage of the task. Then, -//! a transition that sets COMPLETE to true is performed, and the value of -//! `JOIN_INTEREST` when this transition happens is read. -//! -//! If `JOIN_INTEREST` is zero when the transition to COMPLETE happens, then the -//! output is immediately dropped. -//! -//! If `JOIN_INTEREST` is one when the transition to COMPLETE happens, then the -//! `JoinHandle` is responsible for cleaning up the output. If the output is not -//! Send, then this happens: -//! -//! 1. The output is created on the thread that the future was polled on. Since -//! only non-Send futures can have non-Send output, the future was polled on -//! the thread that the future was spawned from. -//! 2. Since `JoinHandle` is not Send if Output is not Send, the -//! `JoinHandle` is also on the thread that the future was spawned from. -//! 3. Thus, the `JoinHandle` will not move the output across threads when it -//! takes or drops the output. -//! -//! ## Recursive poll/shutdown -//! -//! Calling poll from inside a shutdown call or vice-versa is not prevented by -//! the API exposed by the task module, so this has to be safe. In either case, -//! the lock in the RUNNING bitfield makes the inner call return immediately. If -//! the inner call is a `shutdown` call, then the CANCELLED bit is set, and the -//! poll call will notice it when the poll finishes, and the task is cancelled -//! at that point. 
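
The deleted module-level walkthrough above survives in spirit: the task's state is still a single atomic word whose low bits (`RUNNING`, `COMPLETE`, `NOTIFIED`, `CANCELLED`, `JOIN_INTEREST`, `JOIN_WAKER`) gate access and whose remaining bits hold the ref-count. A compressed sketch of the "RUNNING acts as a lock" idea, with illustrative bit positions only (the real layout lives in `task/state.rs`, untouched by this diff):

    use core::sync::atomic::{AtomicUsize, Ordering};

    const RUNNING: usize = 1 << 0;
    const COMPLETE: usize = 1 << 1;

    struct StateSketch(AtomicUsize);

    impl StateSketch {
        /// Try to take the poll "lock": set RUNNING unless the task is
        /// already being polled or its future has already completed.
        fn transition_to_running(&self) -> bool {
            self.0
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |s| {
                    if s & (RUNNING | COMPLETE) != 0 {
                        None
                    } else {
                        Some(s | RUNNING)
                    }
                })
                .is_ok()
        }
    }
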
mod error; mod id; mod join_handle; mod owned_tasks; pub(crate) mod raw; +mod references; mod state; mod waker; @@ -198,17 +26,10 @@ pub use error::JoinError; pub use id::Id; pub use join_handle::JoinHandle; pub use owned_tasks::OwnedTasks; -pub use raw::TaskRef; +pub use references::TaskRef; pub type Result = core::result::Result; -pub enum PollResult { - Complete, - Notified, - Done, - Dealloc, -} - pub trait Schedule { /// Schedule the task to run. fn schedule(&self, task: TaskRef); diff --git a/kernel/src/executor/task/owned_tasks.rs b/kernel/src/executor/task/owned_tasks.rs index 5146254d..f98686ad 100644 --- a/kernel/src/executor/task/owned_tasks.rs +++ b/kernel/src/executor/task/owned_tasks.rs @@ -5,7 +5,9 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use super::{raw, Id, JoinHandle, Schedule, TaskRef}; +use super::{raw, Schedule, TaskRef}; +use crate::executor::task::id::Id; +use crate::executor::task::join_handle::JoinHandle; use core::future::Future; use core::sync::atomic::{AtomicBool, Ordering}; use sync::Mutex; diff --git a/kernel/src/executor/task/raw.rs b/kernel/src/executor/task/raw.rs index d3c28313..7462b9d5 100644 --- a/kernel/src/executor/task/raw.rs +++ b/kernel/src/executor/task/raw.rs @@ -5,49 +5,17 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use super::error::JoinError; -use super::id::Id; -use super::state::{ - Snapshot, State, TransitionToIdle, TransitionToNotifiedByRef, TransitionToNotifiedByVal, - TransitionToRunning, -}; -use super::waker::waker_ref; -use super::{PollResult, Schedule}; -use crate::panic; -use alloc::boxed::Box; -use core::alloc::Layout; -use core::any::Any; +use crate::executor::task::id::Id; +use crate::executor::task::state::State; +use crate::executor::task::TaskRef; use core::cell::UnsafeCell; use core::future::Future; -use core::marker::PhantomData; use core::mem; -use core::mem::{offset_of, ManuallyDrop}; -use core::panic::AssertUnwindSafe; +use core::mem::offset_of; use core::pin::Pin; use core::ptr::NonNull; use core::task::{Context, Poll, Waker}; -/// A type-erased, reference-counted pointer to a spawned [`Task`]. -/// -/// `TaskRef`s are reference-counted, and the task will be deallocated when the -/// last `TaskRef` pointing to it is dropped. -#[derive(Eq, PartialEq)] -pub struct TaskRef(NonNull
); - -/// A non-Send variant of Notified with the invariant that it is on a thread -/// where it is safe to poll it. -#[repr(transparent)] -pub struct LocalTaskRef { - pub(super) task: TaskRef, - pub(super) _not_send: PhantomData<*const ()>, -} - -/// A typed pointer to a spawned [`Task`]. It's roughly a lower-level version of [`TaskRef`] -/// that is not reference counted and tied to a specific tasks future type and scheduler. -struct RawTaskRef { - ptr: NonNull>, -} - /// A task. /// /// This struct holds the various parts of a task: the [future][`Future`] @@ -70,7 +38,7 @@ struct RawTaskRef { /// storage). Therefore, operations that are specific to the task's `S`-typed /// [scheduler], `F`-typed [`Future`] are performed via [dynamic dispatch]. /// -/// [scheduler]: crate::executor::scheduler::multi_thread::Handle +/// [scheduler]: crate::executor::scheduler::Handle /// [dynamic dispatch]: https://en.wikipedia.org/wiki/Dynamic_dispatch // # This struct should be cache padded to avoid false sharing. The cache padding rules are copied // from crossbeam-utils/src/cache_padded.rs @@ -154,55 +122,92 @@ struct RawTaskRef { repr(align(64)) )] #[repr(C)] -struct Task { - header: Header, - core: Core, - trailer: Trailer, +pub(super) struct Task { + pub(super) header: Header, + pub(super) core: Core, + pub(super) trailer: Trailer, } #[repr(C)] #[derive(Debug)] -pub struct Header { - /// Task state which can be atomically updated. - pub(super) state: State, - /// The task vtable for this task. +pub(crate) struct Header { + /// The task's state. /// - /// Note that this is different from the [waker vtable], which contains - /// pointers to the waker methods (and depends primarily on the task's - /// scheduler type). The task vtable instead contains methods for - /// interacting with the task's future, such as polling it and reading the - /// task's output. These depend primarily on the type of the future rather - /// than the scheduler. - /// - /// [waker vtable]: core::task::RawWakerVTable + /// This field is access with atomic instructions, so it's always safe to access it. + pub(super) state: State, pub(super) vtable: &'static Vtable, } #[repr(C)] #[derive(Debug)] -pub struct Core { +pub(super) struct Core { pub(super) scheduler: S, - /// Either the future or the output. - stage: UnsafeCell>, - /// The task's ID, used for populating `JoinError`s. + /// The future that the task is running. + /// + /// If `COMPLETE` is one, then the `JoinHandle` has exclusive access to this field + /// If COMPLETE is zero, then the RUNNING bitfield functions as + /// a lock for the stage field, and it can be accessed only by the thread + /// that set RUNNING to one. + pub(super) stage: UnsafeCell>, pub(super) task_id: Id, } -/// Either the future or the output. -#[repr(C)] // https://github.com/rust-lang/miri/issues/3780 -pub(super) enum Stage { - Running(T), - Finished(super::Result), - Consumed, -} - #[repr(C)] #[derive(Debug)] -pub struct Trailer { +pub(super) struct Trailer { /// Consumer task waiting on completion of this task. - waker: UnsafeCell>, - run_queue_links: mpsc_queue::Links
<Header>,
-    owned_tasks_links: linked_list::Links<Header>,
+    ///
+    /// This field may be accessed by different harts: on one hart we may complete a task and *read*
+    /// the waker field to invoke the waker, while on another hart the task's `JoinHandle` may be
+    /// polled, and if the task hasn't yet completed, the `JoinHandle` may *write* a waker to the
+    /// waker field. The `JOIN_WAKER` bit in the header's `state` field ensures safe access by multiple
+    /// harts to the waker field using the following rules:
+    ///
+    /// 1. `JOIN_WAKER` is initialized to zero.
+    ///
+    /// 2. If `JOIN_WAKER` is zero, then the `JoinHandle` has exclusive (mutable)
+    ///    access to the waker field.
+    ///
+    /// 3. If `JOIN_WAKER` is one, then the `JoinHandle` has shared (read-only) access to the waker
+    ///    field.
+    ///
+    /// 4. If `JOIN_WAKER` is one and COMPLETE is one, then the executor has shared (read-only) access
+    ///    to the waker field.
+    ///
+    /// 5. If the `JoinHandle` needs to write to the waker field, then the `JoinHandle` needs to
+    ///    (i) successfully set `JOIN_WAKER` to zero if it is not already zero to gain exclusive access
+    ///    to the waker field per rule 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER`
+    ///    to one. If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped
+    ///    to clear the waker field, only steps (i) and (ii) are relevant.
+    ///
+    /// 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e.
+    ///    the task hasn't yet completed). The executor can change `JOIN_WAKER` only
+    ///    if COMPLETE is one.
+    ///
+    /// 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the executor has
+    ///    exclusive (mutable) access to the waker field. This might happen if the
+    ///    `JoinHandle` gets dropped right after the task completes and the executor
+    ///    sets the `COMPLETE` bit. In this case the executor needs the mutable access
+    ///    to the waker field to drop it.
+    ///
+    /// Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a
+    /// race. If step (i) fails, then the attempt to write a waker is aborted. If step (iii) fails
+    /// because `COMPLETE` is set to one by another hart after step (i), then the waker field is
+    /// cleared. Once `COMPLETE` is one (i.e. the task has completed), the `JoinHandle` will not
+    /// modify `JOIN_WAKER`. After the executor sets COMPLETE to one, it invokes the waker if there
+    /// is one, so in this case the `JOIN_WAKER` bit indicates to the executor whether it should
+    /// invoke the waker or not when a task completes. After the executor is done using the waker
+    /// during task completion, it unsets the `JOIN_WAKER` bit to give the `JoinHandle` exclusive
+    /// access again so that it is able to drop the waker at a later point.
+    pub(super) waker: UnsafeCell<Option<Waker>>,
+    /// Links to other tasks in the intrusive global run queue.
+    ///
+    /// TODO ownership
+    pub(super) run_queue_links: mpsc_queue::Links<Header>,
+    /// Links to other tasks in the global "owned tasks" list.
+    ///
+    /// The `OwnedTask` reference has exclusive access to this field.
+    pub(super) owned_tasks_links: linked_list::Links<Header>,
 }
@@ -210,7 +215,7 @@ pub(super) struct Vtable {
     /// Polls the future.
     pub(super) poll: unsafe fn(NonNull<Header>),
     /// Schedules the task for execution on the runtime.
-    schedule: unsafe fn(NonNull<Header>),
+    pub(super) schedule: unsafe fn(NonNull<Header>),
     /// Deallocates the memory.
     pub(super) dealloc: unsafe fn(NonNull<Header>),
     /// Reads the task output, if complete.
@@ -220,645 +225,17 @@ pub(super) struct Vtable {
     /// Scheduler is being shutdown.
     pub(super) shutdown: unsafe fn(NonNull<Header>
), /// The number of bytes that the `id` field is offset from the header. - id_offset: usize, + pub(super) id_offset: usize, /// The number of bytes that the `trailer` field is offset from the header. - trailer_offset: usize, -} - -impl TaskRef { - pub(crate) fn new_stub() -> Self { - Self(RawTaskRef::new_stub().ptr.cast()) - } - - #[expect(tail_expr_drop_order, reason = "")] - pub(crate) fn new(future: F, scheduler: S, task_id: Id) -> (Self, Self, Self) - where - F: Future, - S: Schedule + 'static, - { - let ptr = RawTaskRef::new(future, scheduler, task_id).ptr.cast(); - (Self(ptr), Self(ptr), Self(ptr)) - } - - pub(crate) fn clone_from_raw(ptr: NonNull
<Header>) -> Self {
-        let this = Self(ptr);
-        this.state().ref_inc();
-        this
-    }
-
-    pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Self {
-        Self(ptr)
-    }
-
-    pub(in crate::executor) fn header_ptr(&self) -> NonNull<Header>
{ - self.0 - } - pub(super) fn header(&self) -> &Header { - // Safety: constructor ensures the pointer is always valid - unsafe { self.0.as_ref() } - } - /// Returns a reference to the task's state. - pub(super) fn state(&self) -> &State { - &self.header().state - } - - pub(in crate::executor) fn run(self) { - self.poll(); - mem::forget(self); - } - - pub(in crate::executor) fn poll(&self) { - let vtable = self.header().vtable; - // Safety: constructor ensures the pointer is always valid - unsafe { - (vtable.poll)(self.0); - } - } - pub(super) fn schedule(&self) { - let vtable = self.header().vtable; - // Safety: constructor ensures the pointer is always valid - unsafe { - (vtable.schedule)(self.0); - } - } - pub(super) fn dealloc(&self) { - let vtable = self.header().vtable; - // Safety: constructor ensures the pointer is always valid - unsafe { - (vtable.dealloc)(self.0); - } - } - pub(super) unsafe fn try_read_output(&self, dst: *mut (), waker: &Waker) { - let vtable = self.header().vtable; - // Safety: constructor ensures the pointer is always valid - unsafe { - (vtable.try_read_output)(self.0, dst, waker); - } - } - pub(super) fn drop_join_handle_slow(&self) { - let vtable = self.header().vtable; - // Safety: constructor ensures the pointer is always valid - unsafe { (vtable.drop_join_handle_slow)(self.0) } - } - pub(super) fn shutdown(&self) { - let vtable = self.header().vtable; - // Safety: constructor ensures the pointer is always valid - unsafe { (vtable.shutdown)(self.0) } - } - pub(super) fn drop_reference(&self) { - if self.state().ref_dec() { - self.dealloc(); - } - } - /// This call consumes a ref-count and notifies the task. This will create a - /// new Notified and submit it if necessary. - /// - /// The caller does not need to hold a ref-count besides the one that was - /// passed to this call. - pub(super) fn wake_by_val(&self) { - match self.state().transition_to_notified_by_val() { - TransitionToNotifiedByVal::Submit => { - // The caller has given us a ref-count, and the transition has - // created a new ref-count, so we now hold two. We turn the new - // ref-count Notified and pass it to the call to `schedule`. - // - // The old ref-count is retained for now to ensure that the task - // is not dropped during the call to `schedule` if the call - // drops the task it was given. - self.schedule(); - - // Now that we have completed the call to schedule, we can - // release our ref-count. - self.drop_reference(); - } - TransitionToNotifiedByVal::Dealloc => { - self.dealloc(); - } - TransitionToNotifiedByVal::DoNothing => {} - } - } - - /// This call notifies the task. It will not consume any ref-counts, but the - /// caller should hold a ref-count. This will create a new Notified and - /// submit it if necessary. - pub(super) fn wake_by_ref(&self) { - match self.state().transition_to_notified_by_ref() { - TransitionToNotifiedByRef::Submit => { - // The transition above incremented the ref-count for a new task - // and the caller also holds a ref-count. The caller's ref-count - // ensures that the task is not destroyed even if the new task - // is dropped before `schedule` returns. - self.schedule(); - } - TransitionToNotifiedByRef::DoNothing => {} - } - } - - /// Remotely aborts the task. - /// - /// The caller should hold a ref-count, but we do not consume it. - /// - /// This is similar to `shutdown` except that it asks the runtime to perform - /// the shutdown. This is necessary to avoid the shutdown happening in the - /// wrong thread for non-Send tasks. 
- pub(super) fn remote_abort(&self) { - if self.state().transition_to_notified_and_cancel() { - // The transition has created a new ref-count, which we turn into - // a Notified and pass to the task. - // - // Since the caller holds a ref-count, the task cannot be destroyed - // before the call to `schedule` returns even if the call drops the - // `Notified` internally. - self.schedule(); - } - } -} - -impl Clone for TaskRef { - #[inline] - #[track_caller] - fn clone(&self) -> Self { - log::trace!("TaskRef::clone {:?}", self.0); - self.state().ref_inc(); - Self(self.0) - } -} - -impl Drop for TaskRef { - #[inline] - #[track_caller] - fn drop(&mut self) { - // log::trace!("TaskRef::drop {:?}", self.0); - if self.state().ref_dec() { - self.dealloc(); - } - } -} - -// Safety: task refs are "just" atomically reference counted pointers and the state lifecycle system ensures mutual -// exclusion for mutating methods, thus this type is always Send -unsafe impl Send for TaskRef {} -// Safety: task refs are "just" atomically reference counted pointers and the state lifecycle system ensures mutual -// exclusion for mutating methods, thus this type is always Sync -unsafe impl Sync for TaskRef {} - -impl LocalTaskRef { - /// Runs the task. - pub(crate) fn poll(self) { - let raw = self.task; - raw.poll(); - } -} - -impl RawTaskRef -where - F: Future, - S: Schedule + 'static, -{ - const TASK_VTABLE: Vtable = Vtable { - poll: Self::poll, - schedule: Self::schedule, - dealloc: Self::dealloc, - try_read_output: Self::try_read_output, - drop_join_handle_slow: Self::drop_join_handle_slow, - shutdown: Self::shutdown, - id_offset: get_id_offset::(), - trailer_offset: get_trailer_offset::(), - }; - - pub fn new(future: F, scheduler: S, task_id: Id) -> Self { - let ptr = Box::into_raw(Box::new(Task { - header: Header { - state: State::new(), - vtable: &Self::TASK_VTABLE, - }, - core: Core { - scheduler, - stage: UnsafeCell::new(Stage::Running(future)), - task_id, - }, - trailer: Trailer { - waker: UnsafeCell::new(None), - run_queue_links: mpsc_queue::Links::default(), - owned_tasks_links: linked_list::Links::default(), - }, - })); - - log::trace!( - "allocated task ptr {ptr:?} with layout {:?}", - Layout::new::>() - ); - Self { - // Safety: we just allocated the pointer, it is always valid - ptr: unsafe { NonNull::new_unchecked(ptr) }, - } - } - - unsafe fn poll(ptr: NonNull
) { - // Safety: this method gets called through the vtable ensuring that the pointer is valid - // for this `RawTaskRef`'s `F` and `S` generics. - unsafe { - let this = Self::from_raw(ptr); - - // We pass our ref-count to `poll_inner`. - match Self::poll_inner(ptr) { - PollResult::Notified => { - debug_assert!(this.state().load().ref_count() >= 2); - // The `poll_inner` call has given us two ref-counts back. - // We give one of them to a new task and call `yield_now`. - this.core().scheduler.yield_now(this.get_new_task()); - - // The remaining ref-count is now dropped. We kept the extra - // ref-count until now to ensure that even if the `yield_now` - // call drops the provided task, the task isn't deallocated - // before after `yield_now` returns. - this.drop_reference(); - } - PollResult::Complete => { - this.complete(); - } - PollResult::Dealloc => { - Self::dealloc(ptr); - } - PollResult::Done => (), - } - } - } - - unsafe fn poll_inner(ptr: NonNull
) -> PollResult { - // Safety: caller has to ensure `ptr` is valid - let this = unsafe { Self::from_raw(ptr) }; - - match this.state().transition_to_running() { - TransitionToRunning::Success => { - // Separated to reduce LLVM codegen - fn transition_result_to_poll_result(result: TransitionToIdle) -> PollResult { - match result { - TransitionToIdle::Ok => PollResult::Done, - TransitionToIdle::OkNotified => PollResult::Notified, - TransitionToIdle::OkDealloc => PollResult::Dealloc, - TransitionToIdle::Cancelled => PollResult::Complete, - } - } - let header_ptr = this.header_ptr(); - let waker_ref = waker_ref::(&header_ptr); - let cx = Context::from_waker(&waker_ref); - // Safety: `transition_to_running` returns `Success` only when we have exclusive - // access - let res = unsafe { poll_future(this.core(), cx) }; - - if res == Poll::Ready(()) { - // The future completed. Move on to complete the task. - return PollResult::Complete; - } - - let transition_res = this.state().transition_to_idle(); - if let TransitionToIdle::Cancelled = transition_res { - // The transition to idle failed because the task was - // cancelled during the poll. - // Safety: `transition_to_running` returns `Success` only when we have exclusive - // access - unsafe { cancel_task(this.core()) }; - } - transition_result_to_poll_result(transition_res) - } - TransitionToRunning::Cancelled => { - // Safety: `transition_to_running` returns `Cancelled` only when we have exclusive - // access - unsafe { cancel_task(this.core()) }; - PollResult::Complete - } - TransitionToRunning::Failed => PollResult::Done, - TransitionToRunning::Dealloc => PollResult::Dealloc, - } - } - - unsafe fn schedule(ptr: NonNull
) { - // Safety: this method gets called through the vtable ensuring that the pointer is valid - // for this `RawTaskRef`'s `F` and `S` generics. - unsafe { - let this = Self::from_raw(ptr); - this.core().scheduler.schedule(this.get_new_task()); - } - } - - unsafe fn dealloc(ptr: NonNull
) { - // Safety: The caller of this method just transitioned our ref-count to - // zero, so it is our responsibility to release the allocation. - // - // We don't hold any references into the allocation at this point, but - // it is possible for another thread to still hold a `&State` into the - // allocation if that other thread has decremented its last ref-count, - // but has not yet returned from the relevant method on `State`. - // - // However, the `State` type consists of just an `AtomicUsize`, and an - // `AtomicUsize` wraps the entirety of its contents in an `UnsafeCell`. - // As explained in the documentation for `UnsafeCell`, such references - // are allowed to be dangling after their last use, even if the - // reference has not yet gone out of scope. - // - // Additionally, this method gets called through the vtable ensuring that - // the pointer is valid for this `RawTaskRef`'s `F` and `S` generics. - unsafe { - log::trace!( - "about to dealloc task ptr {:?} with layout {:?}", - ptr.as_ptr(), - Layout::new::>() - ); - drop(Box::from_raw(ptr.cast::>().as_ptr())); - log::trace!("deallocated task"); - } - } - - unsafe fn try_read_output(ptr: NonNull
, dst: *mut (), waker: &Waker) { - // Safety: this method gets called through the vtable ensuring that the pointer is valid - // for this `RawTaskRef`'s `F` and `S` generics. The caller has to ensure the `dst` pointer - // is valid. - unsafe { - let this = Self::from_raw(ptr); - let dst = dst.cast::>>(); - if can_read_output(this.header(), this.trailer(), waker) { - *dst = Poll::Ready(this.core().take_output()); - } - } - } - - unsafe fn drop_join_handle_slow(ptr: NonNull
) { - // Safety: this method gets called through the vtable ensuring that the pointer is valid - // for this `RawTaskRef`'s `F` and `S` generics - unsafe { - let this = Self::from_raw(ptr); - - // Try to unset `JOIN_INTEREST` and `JOIN_WAKER`. This must be done as a first step in - // case the task concurrently completed. - let transition = this.state().transition_to_join_handle_dropped(); - - if transition.drop_output { - // It is our responsibility to drop the output. This is critical as - // the task output may not be `Send` and as such must remain with - // the scheduler or `JoinHandle`. i.e. if the output remains in the - // task structure until the task is deallocated, it may be dropped - // by a Waker on any arbitrary thread. - // - // Panics are delivered to the user via the `JoinHandle`. Given that - // they are dropping the `JoinHandle`, we assume they are not - // interested in the panic and swallow it. - let _ = panic::catch_unwind(AssertUnwindSafe(|| { - this.core().drop_future_or_output(); - })); - } - - if transition.drop_waker { - // If the JOIN_WAKER flag is unset at this point, the task is either - // already terminal or not complete so the `JoinHandle` is responsible - // for dropping the waker. - // Safety: - // If the JOIN_WAKER bit is not set the join handle has exclusive - // access to the waker as per rule 2 in task/mod.rs. - // This can only be the case at this point in two scenarios: - // 1. The task completed and the runtime unset `JOIN_WAKER` flag - // after accessing the waker during task completion. So the - // `JoinHandle` is the only one to access the join waker here. - // 2. The task is not completed so the `JoinHandle` was able to unset - // `JOIN_WAKER` bit itself to get mutable access to the waker. - // The runtime will not access the waker when this flag is unset. - this.trailer().set_waker(None); - } - - // Drop the `JoinHandle` reference, possibly deallocating the task - this.drop_reference(); - } - } - - unsafe fn shutdown(ptr: NonNull
) { - // Safety: this method gets called through the vtable ensuring that the pointer is valid - // for this `RawTaskRef`'s `F` and `S` generics - unsafe { - let this = Self::from_raw(ptr); - - if !this.state().transition_to_shutdown() { - // The task is concurrently running. No further work needed. - this.drop_reference(); - return; - } - - // By transitioning the lifecycle to `Running`, we have permission to - // drop the future. - cancel_task(this.core()); - this.complete(); - } - } - - /// Construct a typed task reference from an untyped pointer to a task. - /// - /// # Safety - /// - /// The caller has to ensure `ptr` is a valid task AND that the tasks output and scheduler - /// match this types generic arguments `F` and `S`. Getting this wrong e.g. calling - /// `RawTaskRef::<(), S>::from_raw` on a task that has the output type `i32` will likely lead - /// to stack corruption. - unsafe fn from_raw(ptr: NonNull
) -> Self { - Self { ptr: ptr.cast() } - } - - fn header_ptr(&self) -> NonNull
{ - self.ptr.cast() - } - - fn header(&self) -> &Header { - // Safety: constructor ensures the pointer is always valid - unsafe { &*self.header_ptr().as_ptr() } - } - - fn state(&self) -> &State { - &self.header().state - } - - fn core(&self) -> &Core { - // Safety: constructor ensures the pointer is always valid - unsafe { &self.ptr.as_ref().core } - } - - fn trailer(&self) -> &Trailer { - // Safety: constructor ensures the pointer is always valid - unsafe { &self.ptr.as_ref().trailer } - } - - fn drop_reference(self) { - if self.state().ref_dec() { - // Safety: `ref_dec` returns true if no other references exist, so deallocation is safe - unsafe { - Self::dealloc(self.ptr.cast()); - } - } - } - - fn complete(&self) { - // The future has completed and its output has been written to the task - // stage. We transition from running to complete. - let snapshot = self.state().transition_to_complete(); - - // We catch panics here in case dropping the future or waking the - // JoinHandle panics. - let _ = panic::catch_unwind(AssertUnwindSafe(|| { - if !snapshot.is_join_interested() { - // The `JoinHandle` is not interested in the output of - // this task. It is our responsibility to drop the - // output. The join waker was already dropped by the - // `JoinHandle` before. - // Safety: the COMPLETE bit has been set above and the JOIN_INTEREST bit is unset - // so according to rule 7 we have mutable exclusive access - unsafe { - self.core().drop_future_or_output(); - } - } else if snapshot.is_join_waker_set() { - // Notify the waker. - // Safety: Reading the waker field is safe per rule 4 - // in task/mod.rs, since the JOIN_WAKER bit is set and the call - // to transition_to_complete() above set the COMPLETE bit. - unsafe { - self.trailer().wake_join(); - } - - // Inform the `JoinHandle` that we are done waking the waker by - // unsetting the `JOIN_WAKER` bit. If the `JoinHandle` has - // already been dropped and `JOIN_INTEREST` is unset, then we must - // drop the waker ourselves. - if !self - .state() - .unset_waker_after_complete() - .is_join_interested() - { - // SAFETY: We have COMPLETE=1 and JOIN_INTEREST=0, so - // we have exclusive access to the waker. - unsafe { self.trailer().set_waker(None) }; - } - } - })); - } - - /// Releases the task from the scheduler. Returns the number of ref-counts - /// that should be decremented. - fn release(&self) -> usize { - // We don't actually increment the ref-count here, but the new task is - // never destroyed, so that's ok. - let me = ManuallyDrop::new(self.get_new_task()); - - if let Some(task) = self.core().scheduler.release(&me) { - mem::forget(task); - 2 - } else { - 1 - } - } - - fn get_new_task(&self) -> TaskRef { - // safety: The header is at the beginning of the cell, so this cast is - // safe. 
- unsafe { TaskRef::from_raw(self.ptr.cast()) } - } + pub(super) trailer_offset: usize, } -struct Stub; -impl Future for Stub { - type Output = (); - - fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { - unreachable!("poll called on a stub future") - } -} - -impl Schedule for Stub { - fn schedule(&self, _task: TaskRef) { - unreachable!("schedule called on a stub scheduler") - } - fn release(&self, _task: &TaskRef) -> Option { - unreachable!("release called on a stub scheduler") - } - fn yield_now(&self, _task: TaskRef) { - unreachable!("yield_now called on a stub scheduler") - } -} - -impl RawTaskRef { - const STUB_VTABLE: Vtable = Vtable { - poll: Self::poll_stub, - schedule: Self::schedule_stub, - dealloc: Self::dealloc, - try_read_output: Self::try_read_output_stub, - drop_join_handle_slow: Self::drop_join_handle_slow_stub, - shutdown: Self::shutdown_stub, - id_offset: get_id_offset::(), - trailer_offset: get_trailer_offset::(), - }; - - pub fn new_stub() -> Self { - let ptr = Box::into_raw(Box::new(Task { - header: Header { - state: State::new(), - vtable: &Self::STUB_VTABLE, - }, - core: Core { - scheduler: Stub, - stage: UnsafeCell::new(Stage::Running(Stub)), - task_id: Id::stub(), - }, - trailer: Trailer { - waker: UnsafeCell::new(None), - run_queue_links: mpsc_queue::Links::default(), - owned_tasks_links: linked_list::Links::default(), - }, - })); - log::trace!("allocated stub ptr {ptr:?}"); - - Self { - // Safety: we just allocated the pointer, it is always valid - ptr: unsafe { NonNull::new_unchecked(ptr) }, - } - } - - unsafe fn poll_stub(_ptr: NonNull
) { - // Safety: this method should never be called - unsafe { - debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub()); - unreachable!("poll_stub called on a stub task"); - } - } - - unsafe fn schedule_stub(_ptr: NonNull
) { - // Safety: this method should never be called - unsafe { - debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub()); - unreachable!("schedule_stub called on a stub task"); - } - } - - unsafe fn try_read_output_stub(_ptr: NonNull
, _dst: *mut (), _waker: &Waker) { - // Safety: this method should never be called - unsafe { - debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub()); - unreachable!("try_read_output_stub called on a stub task"); - } - } - - unsafe fn drop_join_handle_slow_stub(_ptr: NonNull
) { - // Safety: this method should never be called - unsafe { - debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub()); - unreachable!("drop_join_handle_slow_stub called on a stub task"); - } - } - - /// # Safety - /// - /// The caller must ensure the pointer is valid - unsafe fn shutdown_stub(_ptr: NonNull
) { - // Safety: this method should never be called - unsafe { - debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub()); - unreachable!("shutdown_stub called on a stub task"); - } - } +/// Either the future or the output. +#[repr(C)] // https://github.com/rust-lang/miri/issues/3780 +pub(super) enum Stage { + Running(T), + Finished(super::Result), + Consumed, } impl Header { @@ -902,18 +279,17 @@ unsafe impl linked_list::Linked for Header { type Handle = TaskRef; fn into_ptr(task: Self::Handle) -> NonNull { - let ptr = task.0; + let ptr = task.header_ptr(); // converting a `TaskRef` into a pointer to enqueue it assigns ownership // of the ref count to the list, so we don't want to run its `Drop` // impl. mem::forget(task); ptr } - unsafe fn from_ptr(ptr: NonNull) -> Self::Handle { - TaskRef(ptr) + // Safety: ensured by the caller + unsafe { TaskRef::from_raw(ptr) } } - unsafe fn links(ptr: NonNull) -> NonNull> { // Safety: `TaskRef` is just a newtype wrapper around `NonNull
` unsafe { @@ -934,18 +310,17 @@ unsafe impl mpsc_queue::Linked for Header { type Handle = TaskRef; fn into_ptr(task: Self::Handle) -> NonNull { - let ptr = task.0; + let ptr = task.header_ptr(); // converting a `TaskRef` into a pointer to enqueue it assigns ownership // of the ref count to the queue, so we don't want to run its `Drop` // impl. mem::forget(task); ptr } - unsafe fn from_ptr(ptr: NonNull) -> Self::Handle { - TaskRef(ptr) + // Safety: ensured by the caller + unsafe { TaskRef::from_raw(ptr) } } - unsafe fn links(ptr: NonNull) -> NonNull> where Self: Sized, @@ -1081,167 +456,12 @@ impl Trailer { } } -#[expect(tail_expr_drop_order, reason = "TODO")] -fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { - // Load a snapshot of the current task state - let snapshot = header.state.load(); - - debug_assert!(snapshot.is_join_interested()); - - if !snapshot.is_complete() { - // If the task is not complete, try storing the provided waker in the - // task's waker field. - - let res = if snapshot.is_join_waker_set() { - // If JOIN_WAKER is set, then JoinHandle has previously stored a - // waker in the waker field per step (iii) of rule 5 in task/mod.rs. - - // Optimization: if the stored waker and the provided waker wake the - // same task, then return without touching the waker field. - // Safety: Reading the waker field below is safe per rule 3 in task/mod.rs. - if unsafe { trailer.will_wake(waker) } { - return false; - } - - // Otherwise swap the stored waker with the provided waker by - // following the rule 5 in task/mod.rs. - header - .state - .unset_waker() - .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot)) - } else { - // If JOIN_WAKER is unset, then JoinHandle has mutable access to the - // waker field per rule 2 in task/mod.rs; therefore, skip step (i) - // of rule 5 and try to store the provided waker in the waker field. - // Safety: absence of JOIN_WAKER means we have exclusive access - set_join_waker(header, trailer, waker.clone(), snapshot) - }; - - match res { - Ok(_) => return false, - Err(snapshot) => { - assert!(snapshot.is_complete()); - } - } - } - true -} - -fn set_join_waker( - header: &Header, - trailer: &Trailer, - waker: Waker, - snapshot: Snapshot, -) -> Result { - assert!(snapshot.is_join_interested()); - assert!(!snapshot.is_join_waker_set()); - - // Safety: Only the `JoinHandle` may set the `waker` field. When - // `JOIN_INTEREST` is **not** set, nothing else will touch the field. - unsafe { - trailer.set_waker(Some(waker)); - - // Update the `JoinWaker` state accordingly - let res = header.state.set_join_waker(); - - // If the state could not be updated, then clear the join waker - if res.is_err() { - trailer.set_waker(None); - } - - res - } -} - -/// Cancels the task and store the appropriate error in the stage field. -/// -/// # Safety -/// -/// The caller has to ensure this hart has exclusive mutable access to the tasks `stage` field (ie the -/// future or output). -unsafe fn cancel_task(core: &Core) { - // Safety: caller has to ensure mutual exclusion - unsafe { - // Drop the future from a panic guard. - let res = panic::catch_unwind(AssertUnwindSafe(|| { - core.drop_future_or_output(); - })); - - core.store_output(Err(panic_result_to_join_error(core.task_id, res))); - } -} - -/// Polls the future. If the future completes, the output is written to the -/// stage field. 
-/// -/// # Safety -/// -/// The caller has to ensure this hart has exclusive mutable access to the tasks `stage` field (ie the -/// future or output). -unsafe fn poll_future(core: &Core, cx: Context<'_>) -> Poll<()> { - // Poll the future. - let output = panic::catch_unwind(AssertUnwindSafe(|| { - struct Guard<'a, T: Future, S: Schedule> { - core: &'a Core, - } - impl Drop for Guard<'_, T, S> { - fn drop(&mut self) { - // If the future panics on poll, we drop it inside the panic - // guard. - // Safety: caller has to ensure mutual exclusion - unsafe { - self.core.drop_future_or_output(); - } - } - } - let guard = Guard { core }; - // Safety: caller has to ensure mutual exclusion - let res = unsafe { guard.core.poll(cx) }; - mem::forget(guard); - res - })); - - // Prepare output for being placed in the core stage. - let output = match output { - Ok(Poll::Pending) => return Poll::Pending, - Ok(Poll::Ready(output)) => Ok(output), - Err(panic) => Err(panic_to_error(core.task_id, panic)), - }; - - // Catch and ignore panics if the future panics on drop. - // Safety: caller has to ensure mutual exclusion - let res = panic::catch_unwind(AssertUnwindSafe(|| unsafe { - core.store_output(output); - })); - - assert!(res.is_ok(), "unhandled panic {res:?}"); - - Poll::Ready(()) -} - -fn panic_result_to_join_error( - task_id: Id, - res: Result<(), Box>, -) -> JoinError { - match res { - Ok(()) => JoinError::cancelled(task_id), - Err(panic) => JoinError::panic(task_id, panic), - } -} - -#[cold] -fn panic_to_error(task_id: Id, panic: Box) -> JoinError { - log::error!("unhandled panic"); - // scheduler().unhandled_panic(); - JoinError::panic(task_id, panic) -} - /// Compute the offset of the `Core` field in `Task` using the /// `#[repr(C)]` algorithm. /// /// Pseudo-code for the `#[repr(C)]` algorithm can be found here: /// -const fn get_core_offset() -> usize { +pub const fn get_core_offset() -> usize { let mut offset = size_of::
(); let core_misalign = offset % align_of::>(); @@ -1257,7 +477,7 @@ const fn get_core_offset() -> usize { /// /// Pseudo-code for the `#[repr(C)]` algorithm can be found here: /// -const fn get_id_offset() -> usize { +pub const fn get_id_offset() -> usize { let mut offset = get_core_offset::(); offset += size_of::(); @@ -1274,7 +494,7 @@ const fn get_id_offset() -> usize { /// /// Pseudo-code for the `#[repr(C)]` algorithm can be found here: /// -const fn get_trailer_offset() -> usize { +pub const fn get_trailer_offset() -> usize { let mut offset = size_of::
(); let core_misalign = offset % align_of::>(); diff --git a/kernel/src/executor/task/references.rs b/kernel/src/executor/task/references.rs new file mode 100644 index 00000000..ae0c85d9 --- /dev/null +++ b/kernel/src/executor/task/references.rs @@ -0,0 +1,836 @@ +// Copyright 2025 Jonas Kruckenberg +// +// Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use crate::executor::task::error::JoinError; +use crate::executor::task::id::Id; +use crate::executor::task::raw::{ + get_id_offset, get_trailer_offset, Core, Header, Stage, Task, Trailer, Vtable, +}; +use crate::executor::task::state::{ + Snapshot, State, TransitionToIdle, TransitionToNotifiedByRef, TransitionToNotifiedByVal, + TransitionToRunning, +}; +use crate::executor::task::waker::waker_ref; +use crate::executor::task::Schedule; +use crate::panic; +use alloc::boxed::Box; +use core::alloc::Layout; +use core::any::Any; +use core::cell::UnsafeCell; +use core::future::Future; +use core::mem; +use core::panic::AssertUnwindSafe; +use core::pin::Pin; +use core::ptr::NonNull; +use core::task::{Context, Poll, Waker}; + +/// A type-erased, reference-counted pointer to a spawned [`Task`]. +/// +/// `TaskRef`s are reference-counted, and the task will be deallocated when the +/// last `TaskRef` pointing to it is dropped. +#[derive(Eq, PartialEq)] +pub struct TaskRef(NonNull
); + +/// A typed pointer to a spawned [`Task`]. It's roughly a lower-level version of [`TaskRef`] +/// that is not reference counted and tied to a specific tasks future type and scheduler. +struct RawTaskRef { + ptr: NonNull>, +} + +impl TaskRef { + pub(crate) fn new_stub() -> Self { + Self(RawTaskRef::new_stub().ptr.cast()) + } + + #[expect(tail_expr_drop_order, reason = "")] + pub(crate) fn new(future: F, scheduler: S, task_id: Id) -> (Self, Self, Self) + where + F: Future, + S: Schedule + 'static, + { + let ptr = RawTaskRef::new(future, scheduler, task_id).ptr.cast(); + (Self(ptr), Self(ptr), Self(ptr)) + } + + pub(crate) fn clone_from_raw(ptr: NonNull
<Header>) -> Self {
+        let this = Self(ptr);
+        this.state().ref_inc();
+        this
+    }
+
+    pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> Self {
+        Self(ptr)
+    }
+
+    pub(in crate::executor) fn header_ptr(&self) -> NonNull<Header>
{ + self.0 + } + pub(super) fn header(&self) -> &Header { + // Safety: constructor ensures the pointer is always valid + unsafe { self.0.as_ref() } + } + /// Returns a reference to the task's state. + pub(super) fn state(&self) -> &State { + &self.header().state + } + + pub(in crate::executor) fn run(self) { + self.poll(); + mem::forget(self); + } + + pub(in crate::executor) fn poll(&self) { + let vtable = self.header().vtable; + // Safety: constructor ensures the pointer is always valid + unsafe { + (vtable.poll)(self.0); + } + } + pub(super) fn schedule(&self) { + let vtable = self.header().vtable; + // Safety: constructor ensures the pointer is always valid + unsafe { + (vtable.schedule)(self.0); + } + } + pub(super) fn dealloc(&self) { + let vtable = self.header().vtable; + // Safety: constructor ensures the pointer is always valid + unsafe { + (vtable.dealloc)(self.0); + } + } + pub(super) unsafe fn try_read_output(&self, dst: *mut (), waker: &Waker) { + let vtable = self.header().vtable; + // Safety: constructor ensures the pointer is always valid + unsafe { + (vtable.try_read_output)(self.0, dst, waker); + } + } + pub(super) fn drop_join_handle_slow(&self) { + let vtable = self.header().vtable; + // Safety: constructor ensures the pointer is always valid + unsafe { (vtable.drop_join_handle_slow)(self.0) } + } + pub(super) fn shutdown(&self) { + let vtable = self.header().vtable; + // Safety: constructor ensures the pointer is always valid + unsafe { (vtable.shutdown)(self.0) } + } + pub(super) fn drop_reference(&self) { + if self.state().ref_dec() { + self.dealloc(); + } + } + /// This call consumes a ref-count and notifies the task. This will create a + /// new Notified and submit it if necessary. + /// + /// The caller does not need to hold a ref-count besides the one that was + /// passed to this call. + pub(super) fn wake_by_val(&self) { + match self.state().transition_to_notified_by_val() { + TransitionToNotifiedByVal::Submit => { + todo!() + + // // The caller has given us a ref-count, and the transition has + // // created a new ref-count, so we now hold two. We turn the new + // // ref-count Notified and pass it to the call to `schedule`. + // // + // // The old ref-count is retained for now to ensure that the task + // // is not dropped during the call to `schedule` if the call + // // drops the task it was given. + // self.schedule(); + // + // // Now that we have completed the call to schedule, we can + // // release our ref-count. + // self.drop_reference(); + } + TransitionToNotifiedByVal::Dealloc => { + self.dealloc(); + } + TransitionToNotifiedByVal::DoNothing => {} + } + } + + /// This call notifies the task. It will not consume any ref-counts, but the + /// caller should hold a ref-count. This will create a new Notified and + /// submit it if necessary. + pub(super) fn wake_by_ref(&self) { + match self.state().transition_to_notified_by_ref() { + TransitionToNotifiedByRef::Submit => { + // The transition above incremented the ref-count for a new task + // and the caller also holds a ref-count. The caller's ref-count + // ensures that the task is not destroyed even if the new task + // is dropped before `schedule` returns. + self.schedule(); + } + TransitionToNotifiedByRef::DoNothing => {} + } + } + + /// Remotely aborts the task. + /// + /// The caller should hold a ref-count, but we do not consume it. + /// + /// This is similar to `shutdown` except that it asks the runtime to perform + /// the shutdown. 
This is necessary to avoid the shutdown happening in the + /// wrong thread for non-Send tasks. + pub(super) fn remote_abort(&self) { + if self.state().transition_to_notified_and_cancel() { + // The transition has created a new ref-count, which we turn into + // a Notified and pass to the task. + // + // Since the caller holds a ref-count, the task cannot be destroyed + // before the call to `schedule` returns even if the call drops the + // `Notified` internally. + self.schedule(); + } + } +} + +impl Clone for TaskRef { + #[inline] + #[track_caller] + fn clone(&self) -> Self { + log::trace!("TaskRef::clone {:?}", self.0); + self.state().ref_inc(); + Self(self.0) + } +} + +impl Drop for TaskRef { + #[inline] + #[track_caller] + fn drop(&mut self) { + // log::trace!("TaskRef::drop {:?}", self.0); + if self.state().ref_dec() { + self.dealloc(); + } + } +} + +// Safety: task refs are "just" atomically reference counted pointers and the state lifecycle system ensures mutual +// exclusion for mutating methods, thus this type is always Send +unsafe impl Send for TaskRef {} +// Safety: task refs are "just" atomically reference counted pointers and the state lifecycle system ensures mutual +// exclusion for mutating methods, thus this type is always Sync +unsafe impl Sync for TaskRef {} + +impl RawTaskRef +where + F: Future, + S: Schedule + 'static, +{ + const TASK_VTABLE: Vtable = Vtable { + poll: Self::poll, + schedule: Self::schedule, + dealloc: Self::dealloc, + try_read_output: Self::try_read_output, + drop_join_handle_slow: Self::drop_join_handle_slow, + shutdown: Self::shutdown, + id_offset: get_id_offset::(), + trailer_offset: get_trailer_offset::(), + }; + + pub fn new(future: F, scheduler: S, task_id: Id) -> Self { + let ptr = Box::into_raw(Box::new(Task { + header: Header { + state: State::new(), + vtable: &Self::TASK_VTABLE, + }, + core: Core { + scheduler, + stage: UnsafeCell::new(Stage::Running(future)), + task_id, + }, + trailer: Trailer { + waker: UnsafeCell::new(None), + run_queue_links: mpsc_queue::Links::default(), + owned_tasks_links: linked_list::Links::default(), + }, + })); + + log::trace!( + "allocated task ptr {ptr:?} with layout {:?}", + Layout::new::>() + ); + Self { + // Safety: we just allocated the pointer, it is always valid + ptr: unsafe { NonNull::new_unchecked(ptr) }, + } + } + + unsafe fn poll(ptr: NonNull
) { + // Safety: this method gets called through the vtable ensuring that the pointer is valid + // for this `RawTaskRef`'s `F` and `S` generics. + unsafe { + let this = Self::from_raw(ptr); + + // We pass our ref-count to `poll_inner`. + match Self::poll_inner(ptr) { + PollResult::Notified => { + debug_assert!(this.state().load().ref_count() >= 2); + // The `poll_inner` call has given us two ref-counts back. + // We give one of them to a new task and call `yield_now`. + this.core().scheduler.yield_now(this.get_new_task()); + + // The remaining ref-count is now dropped. We kept the extra + // ref-count until now to ensure that even if the `yield_now` + // call drops the provided task, the task isn't deallocated + // before after `yield_now` returns. + this.drop_reference(); + } + PollResult::Complete => { + this.complete(); + } + PollResult::Dealloc => { + Self::dealloc(ptr); + } + PollResult::Done => (), + } + } + } + + unsafe fn poll_inner(ptr: NonNull
) -> PollResult { + // Safety: caller has to ensure `ptr` is valid + let this = unsafe { Self::from_raw(ptr) }; + + match this.state().transition_to_running() { + TransitionToRunning::Success => { + // Separated to reduce LLVM codegen + fn transition_result_to_poll_result(result: TransitionToIdle) -> PollResult { + match result { + TransitionToIdle::Ok => PollResult::Done, + TransitionToIdle::OkNotified => PollResult::Notified, + TransitionToIdle::OkDealloc => PollResult::Dealloc, + TransitionToIdle::Cancelled => PollResult::Complete, + } + } + let header_ptr = this.header_ptr(); + let waker_ref = waker_ref::(&header_ptr); + let cx = Context::from_waker(&waker_ref); + // Safety: `transition_to_running` returns `Success` only when we have exclusive + // access + let res = unsafe { poll_future(this.core(), cx) }; + + if res == Poll::Ready(()) { + // The future completed. Move on to complete the task. + return PollResult::Complete; + } + + let transition_res = this.state().transition_to_idle(); + if let TransitionToIdle::Cancelled = transition_res { + // The transition to idle failed because the task was + // cancelled during the poll. + // Safety: `transition_to_running` returns `Success` only when we have exclusive + // access + unsafe { cancel_task(this.core()) }; + } + transition_result_to_poll_result(transition_res) + } + TransitionToRunning::Cancelled => { + // Safety: `transition_to_running` returns `Cancelled` only when we have exclusive + // access + unsafe { cancel_task(this.core()) }; + PollResult::Complete + } + TransitionToRunning::Failed => PollResult::Done, + TransitionToRunning::Dealloc => PollResult::Dealloc, + } + } + + unsafe fn schedule(ptr: NonNull
) { + // Safety: this method gets called through the vtable ensuring that the pointer is valid + // for this `RawTaskRef`'s `F` and `S` generics. + unsafe { + let this = Self::from_raw(ptr); + this.core().scheduler.schedule(this.get_new_task()); + } + } + + unsafe fn dealloc(ptr: NonNull
) { + // Safety: The caller of this method just transitioned our ref-count to + // zero, so it is our responsibility to release the allocation. + // + // We don't hold any references into the allocation at this point, but + // it is possible for another thread to still hold a `&State` into the + // allocation if that other thread has decremented its last ref-count, + // but has not yet returned from the relevant method on `State`. + // + // However, the `State` type consists of just an `AtomicUsize`, and an + // `AtomicUsize` wraps the entirety of its contents in an `UnsafeCell`. + // As explained in the documentation for `UnsafeCell`, such references + // are allowed to be dangling after their last use, even if the + // reference has not yet gone out of scope. + // + // Additionally, this method gets called through the vtable ensuring that + // the pointer is valid for this `RawTaskRef`'s `F` and `S` generics. + unsafe { + debug_assert_eq!(ptr.as_ref().state.load().ref_count(), 0); + log::trace!( + "about to dealloc task ptr {:?} with layout {:?}", + ptr.as_ptr(), + Layout::new::>() + ); + drop(Box::from_raw(ptr.cast::>().as_ptr())); + log::trace!("deallocated task"); + } + } + + unsafe fn try_read_output(ptr: NonNull
, dst: *mut (), waker: &Waker) { + // Safety: this method gets called through the vtable ensuring that the pointer is valid + // for this `RawTaskRef`'s `F` and `S` generics. The caller has to ensure the `dst` pointer + // is valid. + unsafe { + let this = Self::from_raw(ptr); + let dst = dst.cast::>>(); + if can_read_output(this.header(), this.trailer(), waker) { + *dst = Poll::Ready(this.core().take_output()); + } + } + } + + unsafe fn drop_join_handle_slow(ptr: NonNull
) { + // Safety: this method gets called through the vtable ensuring that the pointer is valid + // for this `RawTaskRef`'s `F` and `S` generics + unsafe { + let this = Self::from_raw(ptr); + + // Try to unset `JOIN_INTEREST` and `JOIN_WAKER`. This must be done as a first step in + // case the task concurrently completed. + let transition = this.state().transition_to_join_handle_dropped(); + + if transition.drop_output { + // It is our responsibility to drop the output. This is critical as + // the task output may not be `Send` and as such must remain with + // the scheduler or `JoinHandle`. i.e. if the output remains in the + // task structure until the task is deallocated, it may be dropped + // by a Waker on any arbitrary thread. + // + // Panics are delivered to the user via the `JoinHandle`. Given that + // they are dropping the `JoinHandle`, we assume they are not + // interested in the panic and swallow it. + let _ = panic::catch_unwind(AssertUnwindSafe(|| { + this.core().drop_future_or_output(); + })); + } + + if transition.drop_waker { + // If the JOIN_WAKER flag is unset at this point, the task is either + // already terminal or not complete so the `JoinHandle` is responsible + // for dropping the waker. + // Safety: + // If the JOIN_WAKER bit is not set the join handle has exclusive + // access to the waker as per rule 2 in task/mod.rs. + // This can only be the case at this point in two scenarios: + // 1. The task completed and the runtime unset `JOIN_WAKER` flag + // after accessing the waker during task completion. So the + // `JoinHandle` is the only one to access the join waker here. + // 2. The task is not completed so the `JoinHandle` was able to unset + // `JOIN_WAKER` bit itself to get mutable access to the waker. + // The runtime will not access the waker when this flag is unset. + this.trailer().set_waker(None); + } + + // Drop the `JoinHandle` reference, possibly deallocating the task + this.drop_reference(); + } + } + + unsafe fn shutdown(ptr: NonNull
) { + // Safety: this method gets called through the vtable ensuring that the pointer is valid + // for this `RawTaskRef`'s `F` and `S` generics + unsafe { + let this = Self::from_raw(ptr); + + if !this.state().transition_to_shutdown() { + // The task is concurrently running. No further work needed. + this.drop_reference(); + return; + } + + // By transitioning the lifecycle to `Running`, we have permission to + // drop the future. + cancel_task(this.core()); + this.complete(); + } + } + + /// Construct a typed task reference from an untyped pointer to a task. + /// + /// # Safety + /// + /// The caller has to ensure `ptr` is a valid task AND that the tasks output and scheduler + /// match this types generic arguments `F` and `S`. Getting this wrong e.g. calling + /// `RawTaskRef::<(), S>::from_raw` on a task that has the output type `i32` will likely lead + /// to stack corruption. + unsafe fn from_raw(ptr: NonNull
<Header>) -> Self {
+        Self { ptr: ptr.cast() }
+    }
+
+    fn header_ptr(&self) -> NonNull<Header>
{ + self.ptr.cast() + } + + fn header(&self) -> &Header { + // Safety: constructor ensures the pointer is always valid + unsafe { &*self.header_ptr().as_ptr() } + } + + fn state(&self) -> &State { + &self.header().state + } + + fn core(&self) -> &Core { + // Safety: constructor ensures the pointer is always valid + unsafe { &self.ptr.as_ref().core } + } + + fn trailer(&self) -> &Trailer { + // Safety: constructor ensures the pointer is always valid + unsafe { &self.ptr.as_ref().trailer } + } + + fn complete(&self) { + // The future has completed and its output has been written to the task + // stage. We transition from running to complete. + let snapshot = self.state().transition_to_complete(); + + // We catch panics here in case dropping the future or waking the + // JoinHandle panics. + let _ = panic::catch_unwind(AssertUnwindSafe(|| { + if !snapshot.is_join_interested() { + // The `JoinHandle` is not interested in the output of + // this task. It is our responsibility to drop the + // output. The join waker was already dropped by the + // `JoinHandle` before. + // Safety: the COMPLETE bit has been set above and the JOIN_INTEREST bit is unset + // so according to rule 7 we have mutable exclusive access + unsafe { + self.core().drop_future_or_output(); + } + } else if snapshot.is_join_waker_set() { + // Notify the waker. + // Safety: Reading the waker field is safe per rule 4 + // in task/mod.rs, since the JOIN_WAKER bit is set and the call + // to transition_to_complete() above set the COMPLETE bit. + unsafe { + self.trailer().wake_join(); + } + + // Inform the `JoinHandle` that we are done waking the waker by + // unsetting the `JOIN_WAKER` bit. If the `JoinHandle` has + // already been dropped and `JOIN_INTEREST` is unset, then we must + // drop the waker ourselves. + if !self + .state() + .unset_waker_after_complete() + .is_join_interested() + { + // SAFETY: We have COMPLETE=1 and JOIN_INTEREST=0, so + // we have exclusive access to the waker. + unsafe { self.trailer().set_waker(None) }; + } + } + })); + } + + fn drop_reference(self) { + if self.state().ref_dec() { + // Safety: `ref_dec` returns true if no other references exist, so deallocation is safe + unsafe { + Self::dealloc(self.ptr.cast()); + } + } + } + + fn get_new_task(&self) -> TaskRef { + // safety: The header is at the beginning of the cell, so this cast is + // safe. + unsafe { TaskRef::from_raw(self.ptr.cast()) } + } + + // /// Releases the task from the scheduler. Returns the number of ref-counts + // /// that should be decremented. + // fn release(&self) -> usize { + // // We don't actually increment the ref-count here, but the new task is + // // never destroyed, so that's ok. 
+ // let me = ManuallyDrop::new(self.get_new_task()); + // + // if let Some(task) = self.core().scheduler.release(&me) { + // mem::forget(task); + // 2 + // } else { + // 1 + // } + // } + // +} + +struct Stub; +impl Future for Stub { + type Output = (); + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + unreachable!("poll called on a stub future") + } +} + +impl Schedule for Stub { + fn schedule(&self, _task: TaskRef) { + unreachable!("schedule called on a stub scheduler") + } + fn release(&self, _task: &TaskRef) -> Option { + unreachable!("release called on a stub scheduler") + } + fn yield_now(&self, _task: TaskRef) { + unreachable!("yield_now called on a stub scheduler") + } +} + +impl RawTaskRef { + const STUB_VTABLE: Vtable = Vtable { + poll: Self::poll_stub, + schedule: Self::schedule_stub, + dealloc: Self::dealloc, + try_read_output: Self::try_read_output_stub, + drop_join_handle_slow: Self::drop_join_handle_slow_stub, + shutdown: Self::shutdown_stub, + id_offset: get_id_offset::(), + trailer_offset: get_trailer_offset::(), + }; + + pub fn new_stub() -> Self { + let ptr = Box::into_raw(Box::new(Task { + header: Header { + state: State::new(), + vtable: &Self::STUB_VTABLE, + }, + core: Core { + scheduler: Stub, + stage: UnsafeCell::new(Stage::Running(Stub)), + task_id: Id::stub(), + }, + trailer: Trailer { + waker: UnsafeCell::new(None), + run_queue_links: mpsc_queue::Links::default(), + owned_tasks_links: linked_list::Links::default(), + }, + })); + log::trace!("allocated stub ptr {ptr:?}"); + + Self { + // Safety: we just allocated the pointer, it is always valid + ptr: unsafe { NonNull::new_unchecked(ptr) }, + } + } + + unsafe fn poll_stub(_ptr: NonNull
<Header>) {
+        // Safety: this method should never be called
+        unsafe {
+            debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub());
+            unreachable!("poll_stub called on a stub task");
+        }
+    }
+
+    unsafe fn schedule_stub(_ptr: NonNull<Header>) {
+        // Safety: this method should never be called
+        unsafe {
+            debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub());
+            unreachable!("schedule_stub called on a stub task");
+        }
+    }
+
+    unsafe fn try_read_output_stub(_ptr: NonNull<Header>, _dst: *mut (), _waker: &Waker) {
+        // Safety: this method should never be called
+        unsafe {
+            debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub());
+            unreachable!("try_read_output_stub called on a stub task");
+        }
+    }
+
+    unsafe fn drop_join_handle_slow_stub(_ptr: NonNull<Header>) {
+        // Safety: this method should never be called
+        unsafe {
+            debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub());
+            unreachable!("drop_join_handle_slow_stub called on a stub task");
+        }
+    }
+
+    /// # Safety
+    ///
+    /// The caller must ensure the pointer is valid
+    unsafe fn shutdown_stub(_ptr: NonNull<Header>
) { + // Safety: this method should never be called + unsafe { + debug_assert!(Header::get_id_ptr(_ptr).as_ref().is_stub()); + unreachable!("shutdown_stub called on a stub task"); + } + } +} + +#[expect(tail_expr_drop_order, reason = "TODO")] +fn can_read_output(header: &Header, trailer: &Trailer, waker: &Waker) -> bool { + // Load a snapshot of the current task state + let snapshot = header.state.load(); + + debug_assert!(snapshot.is_join_interested()); + + if !snapshot.is_complete() { + // If the task is not complete, try storing the provided waker in the + // task's waker field. + + let res = if snapshot.is_join_waker_set() { + // If JOIN_WAKER is set, then JoinHandle has previously stored a + // waker in the waker field per step (iii) of rule 5 in task/mod.rs. + + // Optimization: if the stored waker and the provided waker wake the + // same task, then return without touching the waker field. + // Safety: Reading the waker field below is safe per rule 3 in task/mod.rs. + if unsafe { trailer.will_wake(waker) } { + return false; + } + + // Otherwise swap the stored waker with the provided waker by + // following the rule 5 in task/mod.rs. + header + .state + .unset_waker() + .and_then(|snapshot| set_join_waker(header, trailer, waker.clone(), snapshot)) + } else { + // If JOIN_WAKER is unset, then JoinHandle has mutable access to the + // waker field per rule 2 in task/mod.rs; therefore, skip step (i) + // of rule 5 and try to store the provided waker in the waker field. + // Safety: absence of JOIN_WAKER means we have exclusive access + set_join_waker(header, trailer, waker.clone(), snapshot) + }; + + match res { + Ok(_) => return false, + Err(snapshot) => { + assert!(snapshot.is_complete()); + } + } + } + true +} + +fn set_join_waker( + header: &Header, + trailer: &Trailer, + waker: Waker, + snapshot: Snapshot, +) -> Result { + assert!(snapshot.is_join_interested()); + assert!(!snapshot.is_join_waker_set()); + + // Safety: Only the `JoinHandle` may set the `waker` field. When + // `JOIN_INTEREST` is **not** set, nothing else will touch the field. + unsafe { + trailer.set_waker(Some(waker)); + + // Update the `JoinWaker` state accordingly + let res = header.state.set_join_waker(); + + // If the state could not be updated, then clear the join waker + if res.is_err() { + trailer.set_waker(None); + } + + res + } +} + +pub enum PollResult { + Complete, + Notified, + Done, + Dealloc, +} + +/// Cancels the task and store the appropriate error in the stage field. +/// +/// # Safety +/// +/// The caller has to ensure this hart has exclusive mutable access to the tasks `stage` field (ie the +/// future or output). +unsafe fn cancel_task(core: &Core) { + // Safety: caller has to ensure mutual exclusion + unsafe { + // Drop the future from a panic guard. + let res = panic::catch_unwind(AssertUnwindSafe(|| { + core.drop_future_or_output(); + })); + + core.store_output(Err(panic_result_to_join_error(core.task_id, res))); + } +} + +/// Polls the future. If the future completes, the output is written to the +/// stage field. +/// +/// # Safety +/// +/// The caller has to ensure this hart has exclusive mutable access to the tasks `stage` field (ie the +/// future or output). +unsafe fn poll_future(core: &Core, cx: Context<'_>) -> Poll<()> { + // Poll the future. 
+ let output = panic::catch_unwind(AssertUnwindSafe(|| { + struct Guard<'a, T: Future, S: Schedule> { + core: &'a Core, + } + impl Drop for Guard<'_, T, S> { + fn drop(&mut self) { + // If the future panics on poll, we drop it inside the panic + // guard. + // Safety: caller has to ensure mutual exclusion + unsafe { + self.core.drop_future_or_output(); + } + } + } + let guard = Guard { core }; + // Safety: caller has to ensure mutual exclusion + let res = unsafe { guard.core.poll(cx) }; + mem::forget(guard); + res + })); + + // Prepare output for being placed in the core stage. + let output = match output { + Ok(Poll::Pending) => return Poll::Pending, + Ok(Poll::Ready(output)) => Ok(output), + Err(panic) => Err(panic_to_error(core.task_id, panic)), + }; + + // Catch and ignore panics if the future panics on drop. + // Safety: caller has to ensure mutual exclusion + let res = panic::catch_unwind(AssertUnwindSafe(|| unsafe { + core.store_output(output); + })); + + assert!(res.is_ok(), "unhandled panic {res:?}"); + + Poll::Ready(()) +} + +fn panic_result_to_join_error( + task_id: Id, + res: Result<(), Box>, +) -> JoinError { + match res { + Ok(()) => JoinError::cancelled(task_id), + Err(panic) => JoinError::panic(task_id, panic), + } +} + +#[cold] +fn panic_to_error(task_id: Id, panic: Box) -> JoinError { + log::error!("unhandled panic"); + // scheduler().unhandled_panic(); + JoinError::panic(task_id, panic) +} diff --git a/kernel/src/executor/task/state.rs b/kernel/src/executor/task/state.rs index 2fe9e75e..1c2f37f9 100644 --- a/kernel/src/executor/task/state.rs +++ b/kernel/src/executor/task/state.rs @@ -9,16 +9,6 @@ use crate::arch; use core::fmt; use core::sync::atomic::{AtomicUsize, Ordering}; -pub(super) struct State { - val: AtomicUsize, -} - -/// Current state value. -#[derive(Copy, Clone)] -pub(super) struct Snapshot(usize); - -type UpdateResult = Result; - /// The task is currently being run. const RUNNING: usize = 0b0001; @@ -66,6 +56,35 @@ const REF_ONE: usize = 1 << REF_COUNT_SHIFT; /// As the task starts with a `Notified`, `NOTIFIED` is set. const INITIAL_STATE: usize = (REF_ONE * 3) | JOIN_INTEREST | NOTIFIED; +/// Task state. The task stores its state in an atomic `usize` with various bitfields for the +/// necessary information. The state has the following layout: +/// +/// ```text +/// | 63 6 | 5 5 | 4 4 | 3 3 | 2 2 | 1 0 | +/// | refcount | cancelled | join waker | join interest | notified | lifecycle | +/// ``` +/// +/// - `lifecycle` (bit 0n and 1) +/// - `RUNNING` (bit 0) - Tracks whether the task is currently being polled or cancelled. +/// - `COMPLETE` (bit 1) - Is one once the future has fully completed and has been dropped. +/// Never unset once set. Never set together with `RUNNING`. +/// - `NOTIFIED` (bit 2) - Tracks whether a Notified object currently exists. +/// - `JOIN_INTEREST` - Is set to one if there exists a `JoinHandle`. +/// - `JOIN_WAKER` - Acts as an access control bit for the join handle waker. The +/// protocol for its usage is described below. +/// - `CANCELLED` (bit 3) - Is set to one for tasks that should be cancelled as soon as possible. +/// +/// The rest of the bits are used for the ref-count. +pub(super) struct State { + val: AtomicUsize, +} + +/// Current state value. 
+#[derive(Copy, Clone)] +pub(super) struct Snapshot(usize); + +type UpdateResult = Result; + #[must_use] pub(super) enum TransitionToRunning { /// We successfully transitioned the task to the RUNNING state @@ -106,8 +125,6 @@ pub(super) struct TransitionToJoinHandleDrop { pub(super) drop_output: bool, } -/// All transitions are performed via RMW operations. This establishes an -/// unambiguous modification order. impl State { /// Returns a task's initial state. pub(super) fn new() -> State { @@ -123,8 +140,6 @@ impl State { Snapshot(self.val.load(Ordering::Acquire)) } - /// Attempts to transition the lifecycle to `Running`. This sets the - /// notified bit to false so notifications during the poll can be detected. pub(super) fn transition_to_running(&self) -> TransitionToRunning { self.fetch_update_action(|mut next| { let action; @@ -203,11 +218,6 @@ impl State { } /// Transitions the state to `NOTIFIED`. - /// - /// If no task needs to be submitted, a ref-count is consumed. - /// - /// If a task needs to be submitted, the ref-count is incremented for the - /// new Notified. pub(super) fn transition_to_notified_by_val(&self) -> TransitionToNotifiedByVal { self.fetch_update_action(|mut snapshot| { let action; @@ -246,6 +256,11 @@ impl State { } /// Transitions the state to `NOTIFIED`. + /// + /// If no task needs to be submitted, a ref-count is consumed. + /// + /// If a task needs to be submitted, the ref-count is incremented for the + /// new Notified. pub(super) fn transition_to_notified_by_ref(&self) -> TransitionToNotifiedByRef { self.fetch_update_action(|mut snapshot| { if snapshot.is_complete() || snapshot.is_notified() { @@ -267,30 +282,6 @@ impl State { }) } - // /// Transitions the state to `NOTIFIED`, unconditionally increasing the ref - // /// count. - // /// - // /// Returns `true` if the notified bit was transitioned from `0` to `1`; - // /// otherwise `false.` - // #[cfg(all( - // tokio_unstable, - // tokio_taskdump, - // feature = "rt", - // target_os = "linux", - // any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64") - // ))] - // pub(super) fn transition_to_notified_for_tracing(&self) -> bool { - // self.fetch_update_action(|mut snapshot| { - // if snapshot.is_notified() { - // (false, None) - // } else { - // snapshot.set_notified(); - // snapshot.ref_inc(); - // (true, Some(snapshot)) - // } - // }) - // } - /// Sets the cancelled bit and transitions the state to `NOTIFIED` if idle. /// /// Returns `true` if the task needs to be submitted to the pool for @@ -466,6 +457,7 @@ impl State { Snapshot(prev.0 & !JOIN_WAKER) } + /// Increases the reference count by one. pub(super) fn ref_inc(&self) { // Using a relaxed ordering is alright here, as knowledge of the // original reference prevents other threads from erroneously deleting @@ -486,20 +478,13 @@ impl State { } } - /// Returns `true` if the task should be released. + /// Decreases the reference count by one, returning `true` if the task should be released. pub(super) fn ref_dec(&self) -> bool { let prev = Snapshot(self.val.fetch_sub(REF_ONE, Ordering::AcqRel)); assert!(prev.ref_count() >= 1); prev.ref_count() == 1 } - // /// Returns `true` if the task should be released. 
- // pub(super) fn ref_dec_twice(&self) -> bool { - // let prev = Snapshot(self.val.fetch_sub(2 * REF_ONE, Ordering::AcqRel)); - // assert!(prev.ref_count() >= 2); - // prev.ref_count() == 2 - // } - fn fetch_update_action(&self, mut f: F) -> T where F: FnMut(Snapshot) -> (T, Option), @@ -546,8 +531,6 @@ impl State { } } -// ===== impl Snapshot ===== - impl Snapshot { /// Returns `true` if the task is in an idle state. pub(super) fn is_idle(self) -> bool { @@ -612,7 +595,7 @@ impl Snapshot { self.0 &= !JOIN_WAKER; } - pub fn ref_count(self) -> usize { + pub(super) fn ref_count(self) -> usize { (self.0 & REF_COUNT_MASK) >> REF_COUNT_SHIFT } @@ -621,7 +604,7 @@ impl Snapshot { self.0 += REF_ONE; } - pub(super) fn ref_dec(&mut self) { + fn ref_dec(&mut self) { assert!(self.ref_count() > 0); self.0 -= REF_ONE; } diff --git a/kernel/src/executor/task/waker.rs b/kernel/src/executor/task/waker.rs index b8deb8ec..890b8bca 100644 --- a/kernel/src/executor/task/waker.rs +++ b/kernel/src/executor/task/waker.rs @@ -6,7 +6,7 @@ // copied, modified, or distributed except according to those terms. use super::raw::Header; -use super::raw::TaskRef; +use crate::executor::task::TaskRef; use core::marker::PhantomData; use core::mem::ManuallyDrop; use core::ptr::NonNull; diff --git a/kernel/src/executor/wake_list.rs b/kernel/src/executor/wake_list.rs new file mode 100644 index 00000000..acbade7b --- /dev/null +++ b/kernel/src/executor/wake_list.rs @@ -0,0 +1,90 @@ +// Copyright 2025 Jonas Kruckenberg +// +// Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use core::mem::MaybeUninit; +use core::ptr; +use core::task::Waker; + +const NUM_WAKERS: usize = 32; + +/// A list of wakers to be woken. +/// +/// # Invariants +/// +/// The first `curr` elements of `inner` are initialized. +pub(crate) struct WakeList { + inner: [MaybeUninit; NUM_WAKERS], + curr: usize, +} + +impl WakeList { + pub(crate) fn new() -> Self { + const UNINIT_WAKER: MaybeUninit = MaybeUninit::uninit(); + + Self { + inner: [UNINIT_WAKER; NUM_WAKERS], + curr: 0, + } + } + + #[inline] + pub(crate) fn can_push(&self) -> bool { + self.curr < NUM_WAKERS + } + + pub(crate) fn push(&mut self, val: Waker) { + debug_assert!(self.can_push()); + + self.inner[self.curr] = MaybeUninit::new(val); + self.curr += 1; + } + + pub(crate) fn wake_all(&mut self) { + struct DropGuard { + start: *mut Waker, + end: *mut Waker, + } + + impl Drop for DropGuard { + fn drop(&mut self) { + // SAFETY: Both pointers are part of the same object, with `start <= end`. + let len = usize::try_from(unsafe { self.end.offset_from(self.start) }).unwrap(); + let slice = ptr::slice_from_raw_parts_mut(self.start, len); + // SAFETY: All elements in `start..len` are initialized, so we can drop them. + unsafe { ptr::drop_in_place(slice) }; + } + } + + debug_assert!(self.curr <= NUM_WAKERS); + + let mut guard = { + let start = self.inner.as_mut_ptr().cast::(); + // SAFETY: The resulting pointer is in bounds or one after the length of the same object. + let end = unsafe { start.add(self.curr) }; + // Transfer ownership of the wakers in `inner` to `DropGuard`. + self.curr = 0; + DropGuard { start, end } + }; + while !ptr::eq(guard.start, guard.end) { + // SAFETY: `start` is always initialized if `start != end`. + let waker = unsafe { ptr::read(guard.start) }; + // SAFETY: The resulting pointer is in bounds or one after the length of the same object. 
+            guard.start = unsafe { guard.start.add(1) };
+            // If this panics, then `guard` will clean up the remaining wakers.
+            waker.wake();
+        }
+    }
+}
+
+impl Drop for WakeList {
+    fn drop(&mut self) {
+        let slice =
+            ptr::slice_from_raw_parts_mut(self.inner.as_mut_ptr().cast::<Waker>(), self.curr);
+        // SAFETY: The first `curr` elements are initialized, so we can drop them.
+        unsafe { ptr::drop_in_place(slice) };
+    }
+}
diff --git a/kernel/src/panic.rs b/kernel/src/panic.rs
index 08531248..04a70c06 100644
--- a/kernel/src/panic.rs
+++ b/kernel/src/panic.rs
@@ -75,6 +75,10 @@ where
     })
 }
 
+pub(crate) fn resume_unwind(payload: Box<dyn Any + Send>) {
+    rust_panic(payload)
+}
+
 /// Entry point for panics from the `core` crate.
 #[panic_handler]
 fn begin_panic_handler(info: &core::panic::PanicInfo<'_>) -> ! {
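
Note (not part of the diff): the new `WakeList` is sized at `NUM_WAKERS = 32` entries, and `can_push`/`wake_all` are meant to be used together so wakers are drained in fixed-size batches and woken only after any lock protecting them has been released. A minimal usage sketch follows; the `pending` queue, the `wake_all_pending` function, and the exact import path are hypothetical and only illustrate the intended pattern.

```rust
// Illustrative sketch only -- not part of this change. `pending` is a
// hypothetical queue of wakers guarded by the kernel's `sync::Mutex`.
use alloc::collections::VecDeque;
use core::task::Waker;
use sync::Mutex;

use crate::executor::wake_list::WakeList;

fn wake_all_pending(pending: &Mutex<VecDeque<Waker>>) {
    loop {
        let mut batch = WakeList::new();

        // Fill one batch (up to NUM_WAKERS entries) while holding the lock...
        let mut queue = pending.lock();
        while batch.can_push() {
            match queue.pop_front() {
                Some(waker) => batch.push(waker),
                None => break,
            }
        }
        let done = queue.is_empty();
        drop(queue);

        // ...but wake outside of it, so arbitrary `Waker::wake` impls cannot
        // re-enter or deadlock against the queue. If a waker panics, the
        // remaining wakeups are lost but `WakeList`'s drop guard still frees
        // the stored wakers.
        batch.wake_all();

        if done {
            break;
        }
    }
}
```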