From b4f5f39245a7ed35aa0bfb2cb3ee6d88250811d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20CORTIER?= Date: Thu, 7 Dec 2023 22:25:28 -0500 Subject: [PATCH] fix(connector): handle requested_channel_id != joined_channel_id Some RDP servers may join channels used a different channel ID than requested by the client. This patch adds proper handling for this. --- .../src/channel_connection.rs | 45 ++++++++++++++++--- crates/ironrdp-connector/src/connection.rs | 14 ++++-- 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/crates/ironrdp-connector/src/channel_connection.rs b/crates/ironrdp-connector/src/channel_connection.rs index c6fcd7d1b..d1ac68a9d 100644 --- a/crates/ironrdp-connector/src/channel_connection.rs +++ b/crates/ironrdp-connector/src/channel_connection.rs @@ -2,6 +2,7 @@ use std::mem; use ironrdp_pdu::write_buf::WriteBuf; use ironrdp_pdu::{mcs, PduHint}; +use ironrdp_svc::StaticChannelSet; use crate::{ConnectorError, ConnectorErrorExt as _, ConnectorResult, Sequence, State, Written}; @@ -54,16 +55,18 @@ impl State for ChannelConnectionState { #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct ChannelConnectionSequence { pub state: ChannelConnectionState, + pub static_channels: StaticChannelSet, pub channel_ids: Vec, } impl ChannelConnectionSequence { - pub fn new(io_channel_id: u16, mut channel_ids: Vec) -> Self { + pub fn new(static_channels: StaticChannelSet, io_channel_id: u16, mut channel_ids: Vec) -> Self { // I/O channel ID must be joined as well channel_ids.push(io_channel_id); Self { state: ChannelConnectionState::SendErectDomainRequest, + static_channels, channel_ids, } } @@ -173,11 +176,41 @@ impl Sequence for ChannelConnectionSequence { debug!(message = ?channel_join_confirm, "Received"); - if channel_join_confirm.initiator_id != user_channel_id - || channel_join_confirm.channel_id != channel_join_confirm.requested_channel_id - || channel_join_confirm.channel_id != channel_id - { - return Err(general_err!("received bad MCS Channel Join Confirm")); + if channel_join_confirm.initiator_id != user_channel_id { + warn!( + channel_join_confirm.initiator_id, + user_channel_id, "Inconsistent initiator ID for MCS Channel Join Confirm", + ) + } + + if channel_id != channel_join_confirm.requested_channel_id { + return Err(reason_err!( + "ChannelJoinConfirm", + "unexpected requested_channel_id in MCS Channel Join Confirm: received {}, got {}", + channel_id, + channel_join_confirm.requested_channel_id, + )); + } + + if channel_id != channel_join_confirm.channel_id { + warn!( + channel_join_confirm.channel_id, + channel_id, "Server changed the ID of the joined channel" + ); + + // Update the channel ID of the static channel. + let channel = self + .static_channels + .get_type_id_by_channel_id(channel_id) + .ok_or_else(|| { + reason_err!( + "ChannelJoinConfirm", + "failed to retrieve the channel type for channel ID {channel_id} (this is a bug)" + ) + })?; + + self.static_channels + .attach_channel_id(channel, channel_join_confirm.channel_id); } let next_index = index.checked_add(1).unwrap(); diff --git a/crates/ironrdp-connector/src/connection.rs b/crates/ironrdp-connector/src/connection.rs index 364522415..6c2303798 100644 --- a/crates/ironrdp-connector/src/connection.rs +++ b/crates/ironrdp-connector/src/connection.rs @@ -348,13 +348,13 @@ impl Sequence for ClientConnector { debug!(?static_channel_ids, io_channel_id); - let joined: Vec<_> = self + let zipped: Vec<_> = self .static_channels .type_ids() .zip(static_channel_ids.iter().copied()) .collect(); - joined.into_iter().for_each(|(channel, channel_id)| { + zipped.into_iter().for_each(|(channel, channel_id)| { self.static_channels.attach_channel_id(channel, channel_id); }); @@ -363,7 +363,12 @@ impl Sequence for ClientConnector { ClientConnectorState::ChannelConnection { selected_protocol, io_channel_id, - channel_connection: ChannelConnectionSequence::new(io_channel_id, static_channel_ids), + channel_connection: ChannelConnectionSequence::new( + // The channel connection sequence will update the channel ID of each joined channel as necessary. + std::mem::take(&mut self.static_channels), + io_channel_id, + static_channel_ids, + ), }, ) } @@ -382,6 +387,9 @@ impl Sequence for ClientConnector { { debug_assert!(channel_connection.state.is_terminal()); + // Take back the static channel set. + self.static_channels = channel_connection.static_channels; + ClientConnectorState::RdpSecurityCommencement { selected_protocol, io_channel_id,