Skip to content

Commit ee6ae5a

Browse files
committed
io: add tokio_util::io::simplex
Signed-off-by: ADD-SP <[email protected]>
1 parent 925c614 commit ee6ae5a

File tree

4 files changed

+391
-1
lines changed

4 files changed

+391
-1
lines changed

tokio-util/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ net = ["tokio/net"]
2727
compat = ["futures-io"]
2828
codec = []
2929
time = ["tokio/time", "slab"]
30-
io = []
30+
io = ["tokio/rt"]
3131
io-util = ["io", "tokio/rt", "tokio/io-util"]
3232
rt = ["tokio/rt", "tokio/sync", "futures-util"]
3333
join-map = ["rt", "hashbrown"]

tokio-util/src/io/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ mod copy_to_bytes;
1414
mod inspect;
1515
mod read_buf;
1616
mod reader_stream;
17+
pub mod simplex;
1718
mod sink_writer;
1819
mod stream_reader;
1920

tokio-util/src/io/simplex.rs

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
//! Unidirectional byte-oriented channel.
2+
3+
use bytes::Buf;
4+
use bytes::BytesMut;
5+
use futures_core::ready;
6+
use std::io::Error as IoError;
7+
use std::io::ErrorKind as IoErrorKind;
8+
use std::pin::Pin;
9+
use std::sync::{Arc, Mutex};
10+
use std::task::{Context, Poll, Waker};
11+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12+
use tokio::task::coop::poll_proceed;
13+
14+
type IoResult<T> = Result<T, IoError>;
15+
16+
#[derive(Debug)]
17+
struct Inner {
18+
/// `poll_*` will return [`Poll::Pending`] if the backpressure boundary is reached
19+
backpressure_boundary: usize,
20+
21+
/// either [`Sender`] or [`Receiver`] is closed
22+
is_closed: bool,
23+
24+
/// Waker used to wake the [`Receiver`]
25+
receiver_waker: Option<Waker>,
26+
27+
/// Waker used to wake the [`Sender`]
28+
sender_waker: Option<Waker>,
29+
30+
/// Buffer used to read and write data
31+
buf: BytesMut,
32+
}
33+
34+
impl Inner {
35+
fn with_capacity(backpressure_boundary: usize) -> Self {
36+
Self {
37+
backpressure_boundary,
38+
is_closed: false,
39+
receiver_waker: None,
40+
sender_waker: None,
41+
buf: BytesMut::new(),
42+
}
43+
}
44+
45+
fn register_receiver_waker(&mut self, waker: &Waker) {
46+
match self.receiver_waker.as_mut() {
47+
Some(old) if old.will_wake(waker) => {}
48+
Some(old) => old.clone_from(waker),
49+
None => self.receiver_waker = Some(waker.clone()),
50+
}
51+
}
52+
53+
fn register_sender_waker(&mut self, waker: &Waker) {
54+
match self.sender_waker.as_mut() {
55+
Some(old) if old.will_wake(waker) => {}
56+
Some(old) => old.clone_from(waker),
57+
None => self.sender_waker = Some(waker.clone()),
58+
}
59+
}
60+
61+
fn wake_receiver(&mut self) {
62+
if let Some(waker) = self.receiver_waker.take() {
63+
waker.wake();
64+
}
65+
}
66+
67+
fn wake_sender(&mut self) {
68+
if let Some(waker) = self.sender_waker.take() {
69+
waker.wake();
70+
}
71+
}
72+
73+
fn is_closed(&self) -> bool {
74+
self.is_closed
75+
}
76+
77+
fn close_receiver(&mut self) {
78+
self.is_closed = true;
79+
self.wake_sender();
80+
}
81+
82+
fn close_sender(&mut self) {
83+
self.is_closed = true;
84+
self.wake_receiver();
85+
}
86+
}
87+
88+
/// Receiver of the simplex channel.
89+
///
90+
/// You can still read the remaining data from the buffer
91+
/// even if the write half has been dropped.
92+
/// See [`Sender::poll_shutdown`] and [`Sender::drop`] for more details.
93+
#[derive(Debug)]
94+
pub struct Receiver {
95+
inner: Arc<Mutex<Inner>>,
96+
}
97+
98+
impl Drop for Receiver {
99+
/// This also wakes up the [`Sender`].
100+
fn drop(&mut self) {
101+
self.inner.lock().unwrap().close_receiver();
102+
}
103+
}
104+
105+
impl AsyncRead for Receiver {
106+
fn poll_read(
107+
self: Pin<&mut Self>,
108+
cx: &mut Context<'_>,
109+
buf: &mut ReadBuf<'_>,
110+
) -> Poll<IoResult<()>> {
111+
let mut inner = self.inner.lock().unwrap();
112+
113+
let to_read = buf.remaining().min(inner.buf.remaining());
114+
if to_read == 0 {
115+
return if inner.is_closed() {
116+
Poll::Ready(Ok(()))
117+
} else {
118+
inner.register_receiver_waker(cx.waker());
119+
inner.wake_sender();
120+
Poll::Pending
121+
};
122+
}
123+
124+
ready!(poll_proceed(cx)).made_progress();
125+
126+
buf.put_slice(&inner.buf[..to_read]);
127+
inner.buf.advance(to_read);
128+
inner.wake_sender();
129+
Poll::Ready(Ok(()))
130+
}
131+
}
132+
133+
/// Sender of the simplex channel.
134+
///
135+
/// ## Shutdown
136+
///
137+
/// See [`Sender::poll_shutdown`].
138+
#[derive(Debug)]
139+
pub struct Sender {
140+
inner: Arc<Mutex<Inner>>,
141+
}
142+
143+
impl Drop for Sender {
144+
/// This also wakes up the [`Receiver`].
145+
fn drop(&mut self) {
146+
self.inner.lock().unwrap().close_sender();
147+
}
148+
}
149+
150+
impl AsyncWrite for Sender {
151+
/// # Error
152+
///
153+
/// This method will return [`IoErrorKind::BrokenPipe`]
154+
/// if the channel has been closed.
155+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
156+
let mut inner = self.inner.lock().unwrap();
157+
158+
if inner.is_closed() {
159+
return Poll::Ready(Err(IoError::new(
160+
IoErrorKind::BrokenPipe,
161+
"simplex has been closed",
162+
)));
163+
}
164+
165+
let free = inner
166+
.backpressure_boundary
167+
.checked_sub(inner.buf.len())
168+
.expect("backpressure boundary overflow");
169+
let to_write = buf.len().min(free);
170+
if to_write == 0 {
171+
inner.register_sender_waker(cx.waker());
172+
inner.wake_receiver();
173+
return Poll::Pending;
174+
}
175+
176+
// this is to avoid starving other tasks
177+
ready!(poll_proceed(cx)).made_progress();
178+
179+
inner.buf.extend_from_slice(&buf[..to_write]);
180+
inner.wake_receiver();
181+
Poll::Ready(Ok(to_write))
182+
}
183+
184+
/// # Error
185+
///
186+
/// This method will return [`IoErrorKind::BrokenPipe`]
187+
/// if the channel has been closed.
188+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
189+
let inner = self.inner.lock().unwrap();
190+
if inner.is_closed() {
191+
Poll::Ready(Err(IoError::new(
192+
IoErrorKind::BrokenPipe,
193+
"simplex has been shut down",
194+
)))
195+
} else {
196+
Poll::Ready(Ok(()))
197+
}
198+
}
199+
200+
/// After returns [`Poll::Ready`], all the following call to
201+
/// [`Sender::poll_write`] and [`Sender::poll_flush`]
202+
/// will return error.
203+
///
204+
/// The [`Receiver`] can still be used to read remaining data
205+
/// until all bytes have been consumed.
206+
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<IoResult<()>> {
207+
let mut inner = self.inner.lock().unwrap();
208+
209+
if inner.is_closed() {
210+
Poll::Ready(Err(IoError::new(
211+
IoErrorKind::BrokenPipe,
212+
"simplex has already been shut down, cannot be shut down again",
213+
)))
214+
} else {
215+
inner.close_sender();
216+
Poll::Ready(Ok(()))
217+
}
218+
}
219+
}
220+
221+
/// Create a simplex channel.
222+
///
223+
/// The `capacity` parameter specifies the maximum number of bytes that can be
224+
/// stored in the channel without making the [`Sender::poll_write`]
225+
/// return [`Poll::Pending`].
226+
pub fn new(capacity: usize) -> (Sender, Receiver) {
227+
let inner = Arc::new(Mutex::new(Inner::with_capacity(capacity)));
228+
let tx = Sender {
229+
inner: Arc::clone(&inner),
230+
};
231+
let rx = Receiver { inner };
232+
(tx, rx)
233+
}

0 commit comments

Comments
 (0)