Skip to content

Commit 453dbd2

Browse files
committed
rt: add support for non-send closures for thread (un)parking
Add support for non `Send`+`Sync` closures for thread parking and unparking callbacks when using a `LocalRuntime`. Since a `LocalRuntime` will always run its tasks on the same thread, its safe to accept a non `Send`+`Sync` closure. Signed-off-by: Sanskar Jaiswal <[email protected]>
1 parent 925c614 commit 453dbd2

File tree

3 files changed

+194
-14
lines changed

3 files changed

+194
-14
lines changed

tokio/src/runtime/builder.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,7 @@ impl Builder {
927927
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
928928
pub fn build_local(&mut self, options: LocalOptions) -> io::Result<LocalRuntime> {
929929
match &self.kind {
930-
Kind::CurrentThread => self.build_current_thread_local_runtime(),
930+
Kind::CurrentThread => self.build_current_thread_local_runtime(options),
931931
#[cfg(feature = "rt-multi-thread")]
932932
Kind::MultiThread => panic!("multi_thread is not supported for LocalRuntime"),
933933
}
@@ -1439,11 +1439,16 @@ impl Builder {
14391439
}
14401440

14411441
#[cfg(tokio_unstable)]
1442-
fn build_current_thread_local_runtime(&mut self) -> io::Result<LocalRuntime> {
1442+
fn build_current_thread_local_runtime(
1443+
&mut self,
1444+
opts: LocalOptions,
1445+
) -> io::Result<LocalRuntime> {
14431446
use crate::runtime::local_runtime::LocalRuntimeScheduler;
14441447

14451448
let tid = std::thread::current().id();
14461449

1450+
self.before_park = opts.before_park;
1451+
self.after_unpark = opts.after_unpark;
14471452
let (scheduler, handle, blocking_pool) =
14481453
self.build_current_thread_runtime_components(Some(tid))?;
14491454

Lines changed: 145 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,159 @@
11
use std::marker::PhantomData;
22

3+
use crate::runtime::Callback;
4+
35
/// [`LocalRuntime`]-only config options
46
///
5-
/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may
6-
/// be added.
7-
///
87
/// Use `LocalOptions::default()` to create the default set of options. This type is used with
98
/// [`Builder::build_local`].
109
///
10+
/// When using [`Builder::build_local`], this overrides any pre-configured options set on the
11+
/// [`Builder`].
12+
///
1113
/// [`Builder::build_local`]: crate::runtime::Builder::build_local
1214
/// [`LocalRuntime`]: crate::runtime::LocalRuntime
13-
#[derive(Default, Debug)]
15+
/// [`Builder`]: crate::runtime::Builder
16+
#[derive(Default)]
1417
#[non_exhaustive]
18+
#[allow(missing_debug_implementations)]
1519
pub struct LocalOptions {
1620
/// Marker used to make this !Send and !Sync.
1721
_phantom: PhantomData<*mut u8>,
22+
23+
/// To run before the local runtime is parked.
24+
pub(crate) before_park: Option<Callback>,
25+
26+
/// To run before the local runtime is spawned.
27+
pub(crate) after_unpark: Option<Callback>,
28+
}
29+
30+
impl std::fmt::Debug for LocalOptions {
31+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32+
f.debug_struct("LocalOptions")
33+
.field("before_park", &self.before_park.as_ref().map(|_| "..."))
34+
.field("after_unpark", &self.after_unpark.as_ref().map(|_| "..."))
35+
.finish()
36+
}
37+
}
38+
39+
impl LocalOptions {
40+
/// Executes function `f` just before the local runtime is parked (goes idle).
41+
/// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn)
42+
/// can be called, and may result in this thread being unparked immediately.
43+
///
44+
/// This can be used to start work only when the executor is idle, or for bookkeeping
45+
/// and monitoring purposes.
46+
///
47+
/// This differs from the [`Builder::on_thread_park`] method in that it accepts a non Send + Sync
48+
/// closure.
49+
///
50+
/// Note: There can only be one park callback for a runtime; calling this function
51+
/// more than once replaces the last callback defined, rather than adding to it.
52+
///
53+
/// # Examples
54+
///
55+
/// ```
56+
/// # use tokio::runtime::{Builder, LocalOptions};
57+
/// # pub fn main() {
58+
/// let (tx, rx) = std::sync::mpsc::channel();
59+
/// let mut opts = LocalOptions::default();
60+
/// opts.on_thread_park(move || match rx.recv() {
61+
/// Ok(x) => println!("Received from channel: {}", x),
62+
/// Err(e) => println!("Error receiving from channel: {}", e),
63+
/// });
64+
///
65+
/// let runtime = Builder::new_current_thread()
66+
/// .enable_time()
67+
/// .build_local(opts)
68+
/// .unwrap();
69+
///
70+
/// runtime.block_on(async {
71+
/// tokio::task::spawn_local(async move {
72+
/// tx.send(42).unwrap();
73+
/// });
74+
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
75+
/// })
76+
/// # }
77+
/// ```
78+
///
79+
/// [`Builder`]: crate::runtime::Builder
80+
/// [`Builder::on_thread_park`]: crate::runtime::Builder::on_thread_park
81+
pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self
82+
where
83+
F: Fn() + 'static,
84+
{
85+
self.before_park = Some(std::sync::Arc::new(to_send_sync(f)));
86+
self
87+
}
88+
89+
/// Executes function `f` just after the local runtime unparks (starts executing tasks).
90+
///
91+
/// This is intended for bookkeeping and monitoring use cases; note that work
92+
/// in this callback will increase latencies when the application has allowed one or
93+
/// more runtime threads to go idle.
94+
///
95+
/// This differs from the [`Builder::on_thread_unpark`] method in that it accepts a non Send + Sync
96+
/// closure.
97+
///
98+
/// Note: There can only be one unpark callback for a runtime; calling this function
99+
/// more than once replaces the last callback defined, rather than adding to it.
100+
///
101+
/// # Examples
102+
///
103+
/// ```
104+
/// # use tokio::runtime::{Builder, LocalOptions};
105+
/// # pub fn main() {
106+
/// let (tx, rx) = std::sync::mpsc::channel();
107+
/// let mut opts = LocalOptions::default();
108+
/// opts.on_thread_unpark(move || match rx.recv() {
109+
/// Ok(x) => println!("Received from channel: {}", x),
110+
/// Err(e) => println!("Error receiving from channel: {}", e),
111+
/// });
112+
///
113+
/// let runtime = Builder::new_current_thread()
114+
/// .enable_time()
115+
/// .build_local(opts)
116+
/// .unwrap();
117+
///
118+
/// runtime.block_on(async {
119+
/// tokio::task::spawn_local(async move {
120+
/// tx.send(42).unwrap();
121+
/// });
122+
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
123+
/// })
124+
/// # }
125+
/// ```
126+
///
127+
/// [`Builder`]: crate::runtime::Builder
128+
/// [`Builder::on_thread_unpark`]: crate::runtime::Builder::on_thread_unpark
129+
pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self
130+
where
131+
F: Fn() + 'static,
132+
{
133+
self.after_unpark = Some(std::sync::Arc::new(to_send_sync(f)));
134+
self
135+
}
136+
}
137+
138+
// A wrapper type to allow non-Send + Sync closures to be used in a Send + Sync context.
139+
// This is specifically used for executing callbacks when using a `LocalRuntime`.
140+
struct UnsafeSendSync<T>(T);
141+
142+
// SAFETY: This type is only used in a context where it is guaranteed that the closure will not be
143+
// sent across threads.
144+
unsafe impl<T> Send for UnsafeSendSync<T> {}
145+
unsafe impl<T> Sync for UnsafeSendSync<T> {}
146+
147+
impl<T: Fn()> UnsafeSendSync<T> {
148+
fn call(&self) {
149+
(self.0)()
150+
}
151+
}
152+
153+
fn to_send_sync<F>(f: F) -> impl Fn() + Send + Sync
154+
where
155+
F: Fn(),
156+
{
157+
let f = UnsafeSendSync(f);
158+
move || f.call()
18159
}

tokio/tests/rt_local.rs

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use tokio::task::spawn_local;
66

77
#[test]
88
fn test_spawn_local_in_runtime() {
9-
let rt = rt();
9+
let rt = rt(LocalOptions::default());
1010

1111
let res = rt.block_on(async move {
1212
let (tx, rx) = tokio::sync::oneshot::channel();
@@ -22,9 +22,43 @@ fn test_spawn_local_in_runtime() {
2222
assert_eq!(res, 5);
2323
}
2424

25+
#[test]
26+
fn test_on_thread_park_unpark_in_runtime() {
27+
let mut opts = LocalOptions::default();
28+
29+
// the refcell makes the below callbacks `!Send + !Sync`
30+
let on_park_called = std::rc::Rc::new(std::cell::RefCell::new(false));
31+
let on_park_cc = on_park_called.clone();
32+
opts.on_thread_park(move || {
33+
*on_park_cc.borrow_mut() = true;
34+
});
35+
36+
let on_unpark_called = std::rc::Rc::new(std::cell::RefCell::new(false));
37+
let on_unpark_cc = on_unpark_called.clone();
38+
opts.on_thread_unpark(move || {
39+
*on_unpark_cc.borrow_mut() = true;
40+
});
41+
let rt = rt(opts);
42+
43+
rt.block_on(async move {
44+
let (tx, rx) = tokio::sync::oneshot::channel();
45+
46+
spawn_local(async {
47+
tokio::task::yield_now().await;
48+
tx.send(5).unwrap();
49+
});
50+
51+
// this ensures on_thread_park is called
52+
rx.await.unwrap()
53+
});
54+
55+
assert!(*on_park_called.borrow());
56+
assert!(*on_unpark_called.borrow());
57+
}
58+
2559
#[test]
2660
fn test_spawn_from_handle() {
27-
let rt = rt();
61+
let rt = rt(LocalOptions::default());
2862

2963
let (tx, rx) = tokio::sync::oneshot::channel();
3064

@@ -40,7 +74,7 @@ fn test_spawn_from_handle() {
4074

4175
#[test]
4276
fn test_spawn_local_on_runtime_object() {
43-
let rt = rt();
77+
let rt = rt(LocalOptions::default());
4478

4579
let (tx, rx) = tokio::sync::oneshot::channel();
4680

@@ -56,7 +90,7 @@ fn test_spawn_local_on_runtime_object() {
5690

5791
#[test]
5892
fn test_spawn_local_from_guard() {
59-
let rt = rt();
93+
let rt = rt(LocalOptions::default());
6094

6195
let (tx, rx) = tokio::sync::oneshot::channel();
6296

@@ -78,7 +112,7 @@ fn test_spawn_from_guard_other_thread() {
78112
let (tx, rx) = std::sync::mpsc::channel();
79113

80114
std::thread::spawn(move || {
81-
let rt = rt();
115+
let rt = rt(LocalOptions::default());
82116
let handle = rt.handle().clone();
83117

84118
tx.send(handle).unwrap();
@@ -98,7 +132,7 @@ fn test_spawn_local_from_guard_other_thread() {
98132
let (tx, rx) = std::sync::mpsc::channel();
99133

100134
std::thread::spawn(move || {
101-
let rt = rt();
135+
let rt = rt(LocalOptions::default());
102136
let handle = rt.handle().clone();
103137

104138
tx.send(handle).unwrap();
@@ -111,9 +145,9 @@ fn test_spawn_local_from_guard_other_thread() {
111145
spawn_local(async {});
112146
}
113147

114-
fn rt() -> tokio::runtime::LocalRuntime {
148+
fn rt(opts: LocalOptions) -> tokio::runtime::LocalRuntime {
115149
tokio::runtime::Builder::new_current_thread()
116150
.enable_all()
117-
.build_local(LocalOptions::default())
151+
.build_local(opts)
118152
.unwrap()
119153
}

0 commit comments

Comments
 (0)