Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for using CompletionPacket for overlapped I/O #137

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ features = [
[dev-dependencies]
easy-parallel = "3.1.0"
fastrand = "2.0.0"
tracing-subscriber = "0.3"

[target.'cfg(any(unix, target_os = "fuchsia", target_os = "vxworks"))'.dev_dependencies]
libc = "0.2"

[target.'cfg(windows)'.dev_dependencies]
tempfile = "3.7"
4 changes: 2 additions & 2 deletions src/iocp/afd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use windows_sys::Win32::System::WindowsProgramming::{IO_STATUS_BLOCK, OBJECT_ATT

#[derive(Default)]
#[repr(C)]
pub(super) struct AfdPollInfo {
pub(crate) struct AfdPollInfo {
/// The timeout for this poll.
timeout: i64,

Expand Down Expand Up @@ -561,7 +561,7 @@ impl<T> OnceCell<T> {
pin_project_lite::pin_project! {
/// An I/O status block paired with some auxillary data.
#[repr(C)]
pub(super) struct IoStatusBlock<T> {
pub(crate) struct IoStatusBlock<T> {
// The I/O status block.
iosb: UnsafeCell<IO_STATUS_BLOCK>,

Expand Down
46 changes: 19 additions & 27 deletions src/iocp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
mod afd;
mod port;

use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo, IoStatusBlock};
use afd::{base_socket, Afd, AfdPollInfo, AfdPollMask, HasAfdInfo};
use port::{IoCompletionPort, OverlappedEntry};

pub(crate) use afd::IoStatusBlock;
pub(crate) use port::{Completion, CompletionHandle};

use windows_sys::Win32::Foundation::{
BOOLEAN, ERROR_INVALID_HANDLE, ERROR_IO_PENDING, STATUS_CANCELLED,
};
Expand Down Expand Up @@ -497,7 +500,7 @@ impl Poller {
}

/// Push an IOCP packet into the queue.
pub(super) fn post(&self, packet: CompletionPacket) -> io::Result<()> {
pub(super) fn post(&self, packet: crate::os::iocp::CompletionPacket) -> io::Result<()> {
self.port.post(0, 0, packet.0)
}

Expand Down Expand Up @@ -682,38 +685,17 @@ impl EventExtra {
}
}

/// A packet used to wake up the poller with an event.
#[derive(Debug, Clone)]
pub struct CompletionPacket(Packet);

impl CompletionPacket {
/// Create a new completion packet with a custom event.
pub fn new(event: Event) -> Self {
Self(Arc::pin(IoStatusBlock::from(PacketInner::Custom { event })))
}

/// Get the event associated with this packet.
pub fn event(&self) -> &Event {
let data = self.0.as_ref().data().project_ref();

match data {
PacketInnerProj::Custom { event } => event,
_ => unreachable!(),
}
}
}

/// The type of our completion packet.
///
/// It needs to be pinned, since it contains data that is expected by IOCP not to be moved.
type Packet = Pin<Arc<PacketUnwrapped>>;
pub(crate) type Packet = Pin<Arc<PacketUnwrapped>>;
type PacketUnwrapped = IoStatusBlock<PacketInner>;

pin_project! {
/// The inner type of the packet.
#[project_ref = PacketInnerProj]
#[project = PacketInnerProjMut]
enum PacketInner {
pub(crate) enum PacketInner {
// A packet for a socket.
Socket {
// The AFD packet state.
Expand Down Expand Up @@ -769,6 +751,16 @@ impl HasAfdInfo for PacketInner {
}

impl PacketUnwrapped {
/// If this is an event packet, get the event.
pub(crate) fn event(self: Pin<&Self>) -> &Event {
let data = self.data().project_ref();

match data {
PacketInnerProj::Custom { event } => event,
_ => unreachable!(),
}
}

/// Set the new events that this socket is waiting on.
///
/// Returns `true` if we need to be updated.
Expand Down Expand Up @@ -1085,7 +1077,7 @@ impl PacketUnwrapped {

/// Per-socket state.
#[derive(Debug)]
struct SocketState {
pub(crate) struct SocketState {
/// The raw socket handle.
socket: RawSocket,

Expand Down Expand Up @@ -1130,7 +1122,7 @@ enum SocketStatus {

/// Per-waitable handle state.
#[derive(Debug)]
struct WaitableState {
pub(crate) struct WaitableState {
/// The handle that this state is for.
handle: RawHandle,

Expand Down
4 changes: 2 additions & 2 deletions src/iocp/port.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use windows_sys::Win32::System::IO::{
/// # Safety
///
/// This must be a valid completion block.
pub(super) unsafe trait Completion {
pub(crate) unsafe trait Completion {
/// Signal to the completion block that we are about to start an operation.
fn try_lock(self: Pin<&Self>) -> bool;

Expand All @@ -40,7 +40,7 @@ pub(super) unsafe trait Completion {
/// # Safety
///
/// This must be a valid completion block.
pub(super) unsafe trait CompletionHandle: Deref + Sized {
pub(crate) unsafe trait CompletionHandle: Deref + Sized {
/// Type of the completion block.
type Completion: Completion;

Expand Down
41 changes: 40 additions & 1 deletion src/os/iocp.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,52 @@
//! Functionality that is only availale for IOCP-based platforms.

pub use crate::sys::CompletionPacket;
use crate::sys::{Completion, CompletionHandle, IoStatusBlock, Packet, PacketInner};

use super::__private::PollerSealed;
use crate::{Event, PollMode, Poller};

use std::io;
use std::os::windows::io::{AsRawHandle, RawHandle};
use std::os::windows::prelude::{AsHandle, BorrowedHandle};
use std::sync::Arc;

/// A packet used to wake up the poller with an event.
#[derive(Debug, Clone)]
pub struct CompletionPacket(pub(crate) Packet);

impl CompletionPacket {
/// Create a new completion packet with a custom event.
pub fn new(event: Event) -> Self {
Self(Arc::pin(IoStatusBlock::from(PacketInner::Custom { event })))
}

/// Get the event associated with this packet.
pub fn event(&self) -> &Event {
self.0.as_ref().event()
}

/// Get a pointer to the underlying I/O status block.
///
/// This pointer can be used as an `OVERLAPPED` block in Windows APIs. Calling this function
/// marks the block as "in use". Trying to call this function again before the operation is
/// indicated as complete by the poller will result in a panic.
pub fn as_ptr(&self) -> *mut () {
if !self.0.as_ref().get().try_lock() {
panic!("completion packet is already in use");
}

self.0.as_ref().get_ref() as *const _ as *const () as *mut ()
}

/// Cancel the in flight operation.
///
/// # Safety
///
/// The packet must be in flight and the operation must be cancelled already.
pub unsafe fn cancel(&mut self) {
self.0.as_ref().get().unlock();
}
}

/// Extension trait for the [`Poller`] type that provides functionality specific to IOCP-based
/// platforms.
Expand Down
179 changes: 179 additions & 0 deletions tests/windows_overlapped.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
//! Take advantage of overlapped I/O on Windows using CompletionPacket.

#![cfg(windows)]

use polling::os::iocp::CompletionPacket;
use polling::{Event, Events, Poller};

use std::io;
use std::os::windows::ffi::OsStrExt;
use std::os::windows::io::{AsRawHandle, FromRawHandle, OwnedHandle};

use windows_sys::Win32::{Foundation as wf, Storage::FileSystem as wfs, System::IO as wio};

#[test]
fn win32_file_io() {
// Create two completion packets: one for reading, one for writing.
let read_packet = CompletionPacket::new(Event::readable(1));
let write_packet = CompletionPacket::new(Event::writable(2));

// Create a poller.
let poller = Poller::new().unwrap();
let mut events = Events::new();

// Open a file for writing.
let dir = tempfile::tempdir().unwrap();
let file_path = dir.path().join("test.txt");
let fname = file_path
.as_os_str()
.encode_wide()
.chain(Some(0))
.collect::<Vec<_>>();
let file_handle = unsafe {
let raw_handle = wfs::CreateFileW(
fname.as_ptr(),
wf::GENERIC_WRITE | wf::GENERIC_READ,
0,
std::ptr::null_mut(),
wfs::CREATE_ALWAYS,
wfs::FILE_FLAG_OVERLAPPED,
0,
);

if raw_handle == wf::INVALID_HANDLE_VALUE {
panic!("CreateFileW failed: {}", io::Error::last_os_error());
}

OwnedHandle::from_raw_handle(raw_handle as _)
};

// Associate this file with the poller.
unsafe {
let poller_handle = poller.as_raw_handle();
if wio::CreateIoCompletionPort(file_handle.as_raw_handle() as _, poller_handle as _, 1, 0)
== 0
{
panic!(
"CreateIoCompletionPort failed: {}",
io::Error::last_os_error()
);
}
}

// Repeatedly write to the pipe.
let input_text = "Now is the time for all good men to come to the aid of their party";
let mut len = input_text.len();
while len > 0 {
// Begin the write.
let ptr = write_packet.as_ptr() as *mut _;
unsafe {
if wfs::WriteFile(
file_handle.as_raw_handle() as _,
input_text.as_ptr() as _,
len as _,
std::ptr::null_mut(),
ptr,
) == 0
&& wf::GetLastError() != wf::ERROR_IO_PENDING
{
panic!("WriteFile failed: {}", io::Error::last_os_error());
}
}

// Wait for the overlapped operation to complete.
'waiter: loop {
events.clear();
println!("Starting wait...");
poller.wait(&mut events, None).unwrap();
println!("Got events");

for event in events.iter() {
if event.writable && event.key == 2 {
break 'waiter;
}
}
}

// Decrement the length by the number of bytes written.
let bytes_written = input_text.len();
len -= bytes_written;
}

// Close the file and re-open it for reading.
drop(file_handle);
let file_handle = unsafe {
let raw_handle = wfs::CreateFileW(
fname.as_ptr(),
wf::GENERIC_READ | wf::GENERIC_WRITE,
0,
std::ptr::null_mut(),
wfs::OPEN_EXISTING,
wfs::FILE_FLAG_OVERLAPPED,
0,
);

if raw_handle == wf::INVALID_HANDLE_VALUE {
panic!("CreateFileW failed: {}", io::Error::last_os_error());
}

OwnedHandle::from_raw_handle(raw_handle as _)
};

// Associate this file with the poller.
unsafe {
let poller_handle = poller.as_raw_handle();
if wio::CreateIoCompletionPort(file_handle.as_raw_handle() as _, poller_handle as _, 2, 0)
== 0
{
panic!(
"CreateIoCompletionPort failed: {}",
io::Error::last_os_error()
);
}
}

// Repeatedly read from the pipe.
let mut buffer = vec![0u8; 1024];
let mut buffer_cursor = &mut *buffer;
let mut len = 1024;
let mut bytes_received = 0;

while bytes_received < input_text.len() {
// Begin the read.
let ptr = read_packet.as_ptr().cast();
unsafe {
if wfs::ReadFile(
file_handle.as_raw_handle() as _,
buffer_cursor.as_mut_ptr() as _,
len as _,
std::ptr::null_mut(),
ptr,
) == 0
&& wf::GetLastError() != wf::ERROR_IO_PENDING
{
panic!("ReadFile failed: {}", io::Error::last_os_error());
}
}

// Wait for the overlapped operation to complete.
'waiter: loop {
events.clear();
poller.wait(&mut events, None).unwrap();

for event in events.iter() {
if event.readable && event.key == 1 {
break 'waiter;
}
}
}

// Increment the cursor and decrement the length by the number of bytes read.
let bytes_read = input_text.len();
buffer_cursor = &mut buffer_cursor[bytes_read..];
len -= bytes_read;
bytes_received += bytes_read;
}

assert_eq!(bytes_received, input_text.len());
assert_eq!(&buffer[..bytes_received], input_text.as_bytes());
}