diff --git a/qcongestion/src/congestion.rs b/qcongestion/src/congestion.rs index 2ec3a1c0..7f322487 100644 --- a/qcongestion/src/congestion.rs +++ b/qcongestion/src/congestion.rs @@ -1,6 +1,6 @@ use std::{ cmp::Ordering, - collections::VecDeque, + collections::{HashSet, VecDeque}, sync::{Arc, Mutex}, task::{Context, Poll, Waker}, time::{Duration, Instant}, @@ -572,19 +572,17 @@ impl super::CongestionControl for ArcCC { /// It also retires packets that have been acknowledged by an ACK frame that has already sent and which has been confirmed by the peer. struct RcvdRecords { epoch: Epoch, - need_ack: bool, - last_ack_sent: Option<(u64, u64)>, - largest_recv_time: Option<(u64, Instant)>, - rcvd_queue: VecDeque, + ack_immedietly: bool, + last_ack_sent: Option<(u64, HashSet)>, + rcvd_queue: VecDeque<(u64, Instant)>, } impl RcvdRecords { fn new(epoch: Epoch) -> Self { Self { epoch, - need_ack: false, + ack_immedietly: false, last_ack_sent: None, - largest_recv_time: None, rcvd_queue: VecDeque::new(), } } @@ -592,46 +590,49 @@ impl RcvdRecords { fn on_pkt_rcvd(&mut self, pn: u64) { // An endpoint MUST acknowledge all ack-eliciting Initial and Handshake packets immediately if self.epoch == Epoch::Initial || self.epoch == Epoch::Handshake { - self.need_ack = true; + self.ack_immedietly = true; } // See [Section 13.2.1](https://www.rfc-editor.org/rfc/rfc9000.html#name-sending-ack-frames) // An endpoint SHOULD generate and send an ACK frame without delay when it receives an ack-eliciting packet either: // 1. When the received packet has a packet number less than another ack-eliciting packet that has been received // 2. when the packet has a packet number larger than the highest-numbered ack-eliciting packet that has been // received and there are missing packets between that packet and this packet. - if let Some(&largest) = self.rcvd_queue.back() { - if pn < largest || pn - largest > 1 { - self.need_ack = true; - } - if pn >= largest { - self.largest_recv_time = Some((pn, Instant::now())); - if pn > largest { - self.rcvd_queue.push_back(pn); - return; + if let Some(&(largest_pn, _)) = self.rcvd_queue.back() { + self.ack_immedietly = pn < largest_pn || pn.saturating_sub(largest_pn) > 1; + + let idx = self.rcvd_queue.partition_point(|&(x, _)| x < pn); + match self.rcvd_queue.get(idx) { + Some(&(n, _)) if n != pn => self.rcvd_queue.insert(idx, (pn, Instant::now())), + None => { + self.rcvd_queue.push_back((pn, Instant::now())); } + _ => (), } } else { - self.largest_recv_time = Some((pn, Instant::now())); - }; - - let index = self.rcvd_queue.partition_point(|&x| x < pn); - - if self.rcvd_queue.is_empty() || self.rcvd_queue[index] != pn { - self.rcvd_queue.insert(index, pn); + self.rcvd_queue.push_back((pn, Instant::now())); } } /// Checks whether an ACK frame needs to be sent. /// Returns [`Some`] if it's time to send an ACK based on the maximum delay. fn requires_ack(&self, max_delay: Duration) -> Option<(u64, Instant)> { - if self.need_ack { - return self.largest_recv_time; + let largest_pn = self.rcvd_queue.back().map(|&(pn, time)| (pn, time)); + if self.ack_immedietly { + return largest_pn; } + // All ack-eliciting 0-RTT and 1-RTT packets MUST acknowledge within its advertised max_ack_delay - if let Some((largest, recv_time)) = self.largest_recv_time { - let now = Instant::now(); - if now - recv_time >= max_delay { - return Some((largest, recv_time)); + let empty_set = HashSet::new(); + let pending_ack = self + .last_ack_sent + .as_ref() + .map(|(_, set)| set) + .unwrap_or(&empty_set); + + let now = Instant::now(); + for (pn, rec_time) in self.rcvd_queue.iter() { + if now - *rec_time >= max_delay && !pending_ack.contains(pn) { + return largest_pn; } } None @@ -640,26 +641,26 @@ impl RcvdRecords { /// Called when an ACK is sent. /// Updates the last ACK sent information and resets the `need_ack` flag. fn on_ack_sent(&mut self, pn: u64, largest_acked: u64) { - self.last_ack_sent = Some((pn, largest_acked)); - self.largest_recv_time = None; - self.need_ack = false; + let pending_retire = self + .rcvd_queue + .iter() + .filter(|&(pn, _)| *pn <= largest_acked) + .map(|&(pn, _)| pn); + self.last_ack_sent = Some((pn, pending_retire.collect())); + self.ack_immedietly = false; } /// Processes an acknowledged (ACK) packet. /// If the ACKed packet number matches the last sent ACK number, retires all acknowledged packets. fn ack(&mut self, ack: u64, trackers: &[Arc; 3]) { - if let Some((pn, largest_acked)) = self.last_ack_sent { - if ack == pn { - trackers[self.epoch].retire( - &mut self - .rcvd_queue - .iter() - .filter(|&&pn| pn <= largest_acked) - .copied(), - ); - self.rcvd_queue.retain(|&pn| pn > largest_acked); - } - } + let pending_retire = match self.last_ack_sent { + Some((pn, ref list)) if ack == pn => list.clone(), + _ => return, + }; + + trackers[self.epoch].retire(&mut pending_retire.iter().cloned()); + self.rcvd_queue + .retain(|&(pn, _)| !pending_retire.contains(&pn)); } } @@ -938,41 +939,164 @@ mod tests { assert_eq!(ack_reocrd.rcvd_queue.len(), 1); ack_reocrd.on_ack_sent(1, 1); - assert_eq!(ack_reocrd.last_ack_sent, Some((1, 1))); + assert!(ack_reocrd.requires_ack(max_ack_delay).is_none()); ack_reocrd.on_pkt_rcvd(3); - assert_eq!(ack_reocrd.rcvd_queue, vec![1, 3]); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![1, 3] + ); ack_reocrd.on_pkt_rcvd(0); - assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3]); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![0, 1, 3] + ); assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 3); ack_reocrd.on_pkt_rcvd(5); ack_reocrd.on_pkt_rcvd(7); - assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7]); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![0, 1, 3, 5, 7] + ); assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 7); // pn 2 ack 0,1,3,5,7 ack_reocrd.on_ack_sent(2, 7); ack_reocrd.on_pkt_rcvd(9); - assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7, 9]); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![0, 1, 3, 5, 7, 9] + ); + + assert_eq!( + ack_reocrd.last_ack_sent, + Some(( + 2, + HashSet::from_iter(vec![0, 1, 3, 5, 7].into_iter().map(|x| x as u64)) + )) + ); // pn 3 ack 0,1,3,5,7,9 ack_reocrd.on_ack_sent(3, 9); + assert_eq!( + ack_reocrd.last_ack_sent, + Some(( + 3, + HashSet::from_iter(vec![0, 1, 3, 5, 7, 9].into_iter().map(|x| x as u64)) + )) + ); // recv pn 2 ack, ingore ack_reocrd.ack(2, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]); - assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7, 9]); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![0, 1, 3, 5, 7, 9] + ); ack_reocrd.on_pkt_rcvd(11); - assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7, 9, 11]); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![0, 1, 3, 5, 7, 9, 11] + ); // recv pn 3 ack, ret ack_reocrd.ack(3, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]); - assert_eq!(ack_reocrd.rcvd_queue, vec![11]); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![11] + ); } + #[test] + fn test_ack_record_reversed() { + let max_ack_delay = Duration::from_millis(100); + let mut ack_reocrd = RcvdRecords::new(Epoch::Initial); + + ack_reocrd.on_pkt_rcvd(10); + ack_reocrd.on_pkt_rcvd(9); + ack_reocrd.on_pkt_rcvd(8); + + assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 10); + ack_reocrd.on_ack_sent(1, 10); + + ack_reocrd.on_pkt_rcvd(7); + ack_reocrd.on_pkt_rcvd(6); + ack_reocrd.on_pkt_rcvd(5); + assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 10); + ack_reocrd.on_ack_sent(2, 10); + assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None); + + // ingnore ack 1 + ack_reocrd.ack(1, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]); + assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![5, 6, 7, 8, 9, 10] + ); + + ack_reocrd.on_pkt_rcvd(4); + + // ack 2 + ack_reocrd.ack(2, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]); + assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 4); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![4] + ); + ack_reocrd.on_ack_sent(3, 4); + assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None); + + // ack 3 + ack_reocrd.ack(3, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]); + assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None); + assert_eq!( + ack_reocrd + .rcvd_queue + .iter() + .map(|&(pn, _)| pn) + .collect::>(), + vec![] + ); + } struct Mock; impl TrackPackets for Mock { fn may_loss(&self, _: &mut dyn Iterator) {}