Skip to content

Commit

Permalink
fix: handle REMOVE event
Browse files Browse the repository at this point in the history
  • Loading branch information
Berrysoft committed Jun 11, 2024
1 parent 16b6fa6 commit e2d17ed
Showing 1 changed file with 115 additions and 16 deletions.
131 changes: 115 additions & 16 deletions src/iocp/psn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod wait;

use std::collections::HashMap;
use std::io;
use std::mem::MaybeUninit;
use std::os::windows::io::{
AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, FromRawHandle, OwnedHandle,
RawHandle, RawSocket,
Expand All @@ -30,15 +31,15 @@ use wait::WaitCompletionPacket;
use windows_sys::Win32::Foundation::{ERROR_SUCCESS, INVALID_HANDLE_VALUE, WAIT_TIMEOUT};
use windows_sys::Win32::Networking::WinSock::{
ProcessSocketNotifications, SOCK_NOTIFY_EVENT_ERR, SOCK_NOTIFY_EVENT_HANGUP,
SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_OP_DISABLE, SOCK_NOTIFY_OP_ENABLE,
SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, SOCK_NOTIFY_REGISTER_EVENT_IN,
SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, SOCK_NOTIFY_REGISTRATION,
SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, SOCK_NOTIFY_TRIGGER_ONESHOT,
SOCK_NOTIFY_TRIGGER_PERSISTENT,
SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_EVENT_REMOVE, SOCK_NOTIFY_OP_DISABLE,
SOCK_NOTIFY_OP_ENABLE, SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP,
SOCK_NOTIFY_REGISTER_EVENT_IN, SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT,
SOCK_NOTIFY_REGISTRATION, SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL,
SOCK_NOTIFY_TRIGGER_ONESHOT, SOCK_NOTIFY_TRIGGER_PERSISTENT,
};
use windows_sys::Win32::System::Threading::INFINITE;
use windows_sys::Win32::System::IO::{
CreateIoCompletionPort, PostQueuedCompletionStatus, OVERLAPPED_ENTRY,
CreateIoCompletionPort, PostQueuedCompletionStatus, OVERLAPPED, OVERLAPPED_ENTRY,
};

use super::dur2timeout;
Expand Down Expand Up @@ -143,12 +144,18 @@ impl Poller {

let socket = socket.as_raw_socket();

lock!(self.sources.read())
let sources = lock!(self.sources.read());
let oldkey = sources
.get(&socket)
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;

if oldkey != &interest.key {
// To change the key, remove the old registration and wait for REMOVE event.
let info = create_registration(socket, Event::none(*oldkey), PollMode::Oneshot, false);
self.update_and_wait_for_remove(info, *oldkey)?;
}
let info = create_registration(socket, interest, mode, true);
unsafe { self.update_source(info) }
self.update_source(info)
}

/// Deletes a socket.
Expand All @@ -166,7 +173,7 @@ impl Poller {
.remove(&socket)
.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?;
let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false);
unsafe { self.update_source(info) }
self.update_and_wait_for_remove(info, key)
}

/// Add a new waitable to the poller.
Expand Down Expand Up @@ -260,7 +267,7 @@ impl Poller {
}

/// Add or modify the registration.
unsafe fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> {
fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> {
let res = unsafe {
ProcessSocketNotifications(
self.port.as_raw_handle() as _,
Expand All @@ -283,6 +290,94 @@ impl Poller {
}
}

/// Attempt to remove a registration, and wait for the `SOCK_NOTIFY_EVENT_REMOVE` event.
fn update_and_wait_for_remove(
&self,
mut reg: SOCK_NOTIFY_REGISTRATION,
key: usize,
) -> io::Result<()> {
debug_assert_eq!(reg.operation, SOCK_NOTIFY_OP_REMOVE as _);
let mut received = 0;
let mut entry: MaybeUninit<OVERLAPPED_ENTRY> = MaybeUninit::uninit();

let repost = |entry: OVERLAPPED_ENTRY| {
self.post_raw(
entry.dwNumberOfBytesTransferred,
entry.lpCompletionKey,
entry.lpOverlapped,
)
};

// Update the registration and wait for the event in the same time.
// However, the returned completion entry may not be the wanted REMOVE event.
let res = unsafe {
ProcessSocketNotifications(
self.port.as_raw_handle() as _,
1,
&mut reg,
0,
1,
entry.as_mut_ptr().cast(),
&mut received,
)
};
match res {
ERROR_SUCCESS | WAIT_TIMEOUT => {
if reg.registrationResult != ERROR_SUCCESS {
// If the registration is not successful, the received entry should be reposted.
if received == 1 {
repost(unsafe { entry.assume_init() })?;
}
return Err(io::Error::from_raw_os_error(reg.registrationResult as _));
}
}
_ => return Err(io::Error::from_raw_os_error(res as _)),
}
if received == 1 {
// The registration is successful, and check the received entry.
let entry = unsafe { entry.assume_init() };
if entry.lpCompletionKey == key {
// If the entry is current key but not the remove event, just ignore it.
if (entry.dwNumberOfBytesTransferred & SOCK_NOTIFY_EVENT_REMOVE) != 0 {
return Ok(());
}
} else {
repost(entry)?;
}
}

// No wanted event, start a loop to wait for it.
// TODO: any better solutions?
loop {
let res = unsafe {
ProcessSocketNotifications(
self.port.as_raw_handle() as _,
0,
null_mut(),
0,
1,
entry.as_mut_ptr().cast(),
&mut received,
)
};
match res {
ERROR_SUCCESS => {
debug_assert_eq!(received, 1);
let entry = unsafe { entry.assume_init() };
if entry.lpCompletionKey == key {
if (entry.dwNumberOfBytesTransferred & SOCK_NOTIFY_EVENT_REMOVE) != 0 {
return Ok(());
}
} else {
repost(entry)?;
}
}
WAIT_TIMEOUT => {}
_ => return Err(io::Error::from_raw_os_error(res as _)),
}
}
}

/// Waits for I/O events with an optional timeout.
pub fn wait(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
let span = tracing::trace_span!(
Expand Down Expand Up @@ -337,13 +432,17 @@ impl Poller {
let _enter = span.enter();

let event = packet.event();
self.post_raw(interest_to_events(event), event.key, null_mut())
}

fn post_raw(
&self,
transferred: u32,
key: usize,
overlapped: *mut OVERLAPPED,
) -> io::Result<()> {
let res = unsafe {
PostQueuedCompletionStatus(
self.port.as_raw_handle() as _,
interest_to_events(event),
event.key,
null_mut(),
)
PostQueuedCompletionStatus(self.port.as_raw_handle() as _, transferred, key, overlapped)
};
if res == 0 {
Err(io::Error::last_os_error())
Expand Down

0 comments on commit e2d17ed

Please sign in to comment.