diff --git a/src/iocp/psn/mod.rs b/src/iocp/psn/mod.rs index f0adbf6..0deb382 100644 --- a/src/iocp/psn/mod.rs +++ b/src/iocp/psn/mod.rs @@ -18,6 +18,7 @@ mod wait; use std::collections::HashMap; use std::io; +use std::mem::MaybeUninit; use std::os::windows::io::{ AsHandle, AsRawHandle, AsRawSocket, BorrowedHandle, BorrowedSocket, FromRawHandle, OwnedHandle, RawHandle, RawSocket, @@ -30,15 +31,15 @@ use wait::WaitCompletionPacket; use windows_sys::Win32::Foundation::{ERROR_SUCCESS, INVALID_HANDLE_VALUE, WAIT_TIMEOUT}; use windows_sys::Win32::Networking::WinSock::{ ProcessSocketNotifications, SOCK_NOTIFY_EVENT_ERR, SOCK_NOTIFY_EVENT_HANGUP, - SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_OP_DISABLE, SOCK_NOTIFY_OP_ENABLE, - SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, SOCK_NOTIFY_REGISTER_EVENT_IN, - SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, SOCK_NOTIFY_REGISTRATION, - SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, SOCK_NOTIFY_TRIGGER_ONESHOT, - SOCK_NOTIFY_TRIGGER_PERSISTENT, + SOCK_NOTIFY_EVENT_IN, SOCK_NOTIFY_EVENT_OUT, SOCK_NOTIFY_EVENT_REMOVE, SOCK_NOTIFY_OP_DISABLE, + SOCK_NOTIFY_OP_ENABLE, SOCK_NOTIFY_OP_REMOVE, SOCK_NOTIFY_REGISTER_EVENT_HANGUP, + SOCK_NOTIFY_REGISTER_EVENT_IN, SOCK_NOTIFY_REGISTER_EVENT_NONE, SOCK_NOTIFY_REGISTER_EVENT_OUT, + SOCK_NOTIFY_REGISTRATION, SOCK_NOTIFY_TRIGGER_EDGE, SOCK_NOTIFY_TRIGGER_LEVEL, + SOCK_NOTIFY_TRIGGER_ONESHOT, SOCK_NOTIFY_TRIGGER_PERSISTENT, }; use windows_sys::Win32::System::Threading::INFINITE; use windows_sys::Win32::System::IO::{ - CreateIoCompletionPort, PostQueuedCompletionStatus, OVERLAPPED_ENTRY, + CreateIoCompletionPort, PostQueuedCompletionStatus, OVERLAPPED, OVERLAPPED_ENTRY, }; use super::dur2timeout; @@ -143,12 +144,18 @@ impl Poller { let socket = socket.as_raw_socket(); - lock!(self.sources.read()) + let sources = lock!(self.sources.read()); + let oldkey = sources .get(&socket) .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; + if oldkey != &interest.key { + // To change the key, remove the old registration and wait for REMOVE event. + let info = create_registration(socket, Event::none(*oldkey), PollMode::Oneshot, false); + self.update_and_wait_for_remove(info, *oldkey)?; + } let info = create_registration(socket, interest, mode, true); - unsafe { self.update_source(info) } + self.update_source(info) } /// Deletes a socket. @@ -166,7 +173,7 @@ impl Poller { .remove(&socket) .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound))?; let info = create_registration(socket, Event::none(key), PollMode::Oneshot, false); - unsafe { self.update_source(info) } + self.update_and_wait_for_remove(info, key) } /// Add a new waitable to the poller. @@ -260,7 +267,7 @@ impl Poller { } /// Add or modify the registration. - unsafe fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> { + fn update_source(&self, mut reg: SOCK_NOTIFY_REGISTRATION) -> io::Result<()> { let res = unsafe { ProcessSocketNotifications( self.port.as_raw_handle() as _, @@ -283,6 +290,94 @@ impl Poller { } } + /// Attempt to remove a registration, and wait for the `SOCK_NOTIFY_EVENT_REMOVE` event. + fn update_and_wait_for_remove( + &self, + mut reg: SOCK_NOTIFY_REGISTRATION, + key: usize, + ) -> io::Result<()> { + debug_assert_eq!(reg.operation, SOCK_NOTIFY_OP_REMOVE as _); + let mut received = 0; + let mut entry: MaybeUninit = MaybeUninit::uninit(); + + let repost = |entry: OVERLAPPED_ENTRY| { + self.post_raw( + entry.dwNumberOfBytesTransferred, + entry.lpCompletionKey, + entry.lpOverlapped, + ) + }; + + // Update the registration and wait for the event in the same time. + // However, the returned completion entry may not be the wanted REMOVE event. + let res = unsafe { + ProcessSocketNotifications( + self.port.as_raw_handle() as _, + 1, + &mut reg, + 0, + 1, + entry.as_mut_ptr().cast(), + &mut received, + ) + }; + match res { + ERROR_SUCCESS | WAIT_TIMEOUT => { + if reg.registrationResult != ERROR_SUCCESS { + // If the registration is not successful, the received entry should be reposted. + if received == 1 { + repost(unsafe { entry.assume_init() })?; + } + return Err(io::Error::from_raw_os_error(reg.registrationResult as _)); + } + } + _ => return Err(io::Error::from_raw_os_error(res as _)), + } + if received == 1 { + // The registration is successful, and check the received entry. + let entry = unsafe { entry.assume_init() }; + if entry.lpCompletionKey == key { + // If the entry is current key but not the remove event, just ignore it. + if (entry.dwNumberOfBytesTransferred & SOCK_NOTIFY_EVENT_REMOVE) != 0 { + return Ok(()); + } + } else { + repost(entry)?; + } + } + + // No wanted event, start a loop to wait for it. + // TODO: any better solutions? + loop { + let res = unsafe { + ProcessSocketNotifications( + self.port.as_raw_handle() as _, + 0, + null_mut(), + 0, + 1, + entry.as_mut_ptr().cast(), + &mut received, + ) + }; + match res { + ERROR_SUCCESS => { + debug_assert_eq!(received, 1); + let entry = unsafe { entry.assume_init() }; + if entry.lpCompletionKey == key { + if (entry.dwNumberOfBytesTransferred & SOCK_NOTIFY_EVENT_REMOVE) != 0 { + return Ok(()); + } + } else { + repost(entry)?; + } + } + WAIT_TIMEOUT => {} + _ => return Err(io::Error::from_raw_os_error(res as _)), + } + } + } + /// Waits for I/O events with an optional timeout. pub fn wait(&self, events: &mut Events, timeout: Option) -> io::Result<()> { let span = tracing::trace_span!( @@ -337,13 +432,17 @@ impl Poller { let _enter = span.enter(); let event = packet.event(); + self.post_raw(interest_to_events(event), event.key, null_mut()) + } + + fn post_raw( + &self, + transferred: u32, + key: usize, + overlapped: *mut OVERLAPPED, + ) -> io::Result<()> { let res = unsafe { - PostQueuedCompletionStatus( - self.port.as_raw_handle() as _, - interest_to_events(event), - event.key, - null_mut(), - ) + PostQueuedCompletionStatus(self.port.as_raw_handle() as _, transferred, key, overlapped) }; if res == 0 { Err(io::Error::last_os_error())