Skip to content

Commit

Permalink
fix(qcongestion): ensure ack is sent for out-of-order packets
Browse files Browse the repository at this point in the history
  • Loading branch information
metah3m authored and huster-zhangpeng committed Jan 26, 2025
1 parent c6311a2 commit c12af73
Showing 1 changed file with 177 additions and 53 deletions.
230 changes: 177 additions & 53 deletions qcongestion/src/congestion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
cmp::Ordering,
collections::VecDeque,
collections::{HashSet, VecDeque},
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
time::{Duration, Instant},
Expand Down Expand Up @@ -572,66 +572,67 @@ impl super::CongestionControl for ArcCC {
/// It also retires packets that have been acknowledged by an ACK frame that has already sent and which has been confirmed by the peer.
struct RcvdRecords {
epoch: Epoch,
need_ack: bool,
last_ack_sent: Option<(u64, u64)>,
largest_recv_time: Option<(u64, Instant)>,
rcvd_queue: VecDeque<u64>,
ack_immedietly: bool,
last_ack_sent: Option<(u64, HashSet<u64>)>,
rcvd_queue: VecDeque<(u64, Instant)>,
}

impl RcvdRecords {
fn new(epoch: Epoch) -> Self {
Self {
epoch,
need_ack: false,
ack_immedietly: false,
last_ack_sent: None,
largest_recv_time: None,
rcvd_queue: VecDeque::new(),
}
}

fn on_pkt_rcvd(&mut self, pn: u64) {
// An endpoint MUST acknowledge all ack-eliciting Initial and Handshake packets immediately
if self.epoch == Epoch::Initial || self.epoch == Epoch::Handshake {
self.need_ack = true;
self.ack_immedietly = true;
}
// See [Section 13.2.1](https://www.rfc-editor.org/rfc/rfc9000.html#name-sending-ack-frames)
// An endpoint SHOULD generate and send an ACK frame without delay when it receives an ack-eliciting packet either:
// 1. When the received packet has a packet number less than another ack-eliciting packet that has been received
// 2. when the packet has a packet number larger than the highest-numbered ack-eliciting packet that has been
// received and there are missing packets between that packet and this packet.
if let Some(&largest) = self.rcvd_queue.back() {
if pn < largest || pn - largest > 1 {
self.need_ack = true;
}
if pn >= largest {
self.largest_recv_time = Some((pn, Instant::now()));
if pn > largest {
self.rcvd_queue.push_back(pn);
return;
if let Some(&(largest_pn, _)) = self.rcvd_queue.back() {
self.ack_immedietly = pn < largest_pn || pn.saturating_sub(largest_pn) > 1;

let idx = self.rcvd_queue.partition_point(|&(x, _)| x < pn);
match self.rcvd_queue.get(idx) {
Some(&(n, _)) if n != pn => self.rcvd_queue.insert(idx, (pn, Instant::now())),
None => {
self.rcvd_queue.push_back((pn, Instant::now()));
}
_ => (),
}
} else {
self.largest_recv_time = Some((pn, Instant::now()));
};

let index = self.rcvd_queue.partition_point(|&x| x < pn);

if self.rcvd_queue.is_empty() || self.rcvd_queue[index] != pn {
self.rcvd_queue.insert(index, pn);
self.rcvd_queue.push_back((pn, Instant::now()));
}
}

/// Checks whether an ACK frame needs to be sent.
/// Returns [`Some`] if it's time to send an ACK based on the maximum delay.
fn requires_ack(&self, max_delay: Duration) -> Option<(u64, Instant)> {
if self.need_ack {
return self.largest_recv_time;
let largest_pn = self.rcvd_queue.back().map(|&(pn, time)| (pn, time));
if self.ack_immedietly {
return largest_pn;
}

// All ack-eliciting 0-RTT and 1-RTT packets MUST acknowledge within its advertised max_ack_delay
if let Some((largest, recv_time)) = self.largest_recv_time {
let now = Instant::now();
if now - recv_time >= max_delay {
return Some((largest, recv_time));
let empty_set = HashSet::new();
let pending_ack = self
.last_ack_sent
.as_ref()
.map(|(_, set)| set)
.unwrap_or(&empty_set);

let now = Instant::now();
for (pn, rec_time) in self.rcvd_queue.iter() {
if now - *rec_time >= max_delay && !pending_ack.contains(pn) {
return largest_pn;
}
}
None
Expand All @@ -640,26 +641,26 @@ impl RcvdRecords {
/// Called when an ACK is sent.
/// Updates the last ACK sent information and resets the `need_ack` flag.
fn on_ack_sent(&mut self, pn: u64, largest_acked: u64) {
self.last_ack_sent = Some((pn, largest_acked));
self.largest_recv_time = None;
self.need_ack = false;
let pending_retire = self
.rcvd_queue
.iter()
.filter(|&(pn, _)| *pn <= largest_acked)
.map(|&(pn, _)| pn);
self.last_ack_sent = Some((pn, pending_retire.collect()));
self.ack_immedietly = false;
}

/// Processes an acknowledged (ACK) packet.
/// If the ACKed packet number matches the last sent ACK number, retires all acknowledged packets.
fn ack(&mut self, ack: u64, trackers: &[Arc<dyn TrackPackets>; 3]) {
if let Some((pn, largest_acked)) = self.last_ack_sent {
if ack == pn {
trackers[self.epoch].retire(
&mut self
.rcvd_queue
.iter()
.filter(|&&pn| pn <= largest_acked)
.copied(),
);
self.rcvd_queue.retain(|&pn| pn > largest_acked);
}
}
let pending_retire = match self.last_ack_sent {
Some((pn, ref list)) if ack == pn => list.clone(),
_ => return,
};

trackers[self.epoch].retire(&mut pending_retire.iter().cloned());
self.rcvd_queue
.retain(|&(pn, _)| !pending_retire.contains(&pn));
}
}

Expand Down Expand Up @@ -938,41 +939,164 @@ mod tests {
assert_eq!(ack_reocrd.rcvd_queue.len(), 1);

ack_reocrd.on_ack_sent(1, 1);
assert_eq!(ack_reocrd.last_ack_sent, Some((1, 1)));

assert!(ack_reocrd.requires_ack(max_ack_delay).is_none());

ack_reocrd.on_pkt_rcvd(3);
assert_eq!(ack_reocrd.rcvd_queue, vec![1, 3]);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![1, 3]
);

ack_reocrd.on_pkt_rcvd(0);
assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3]);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![0, 1, 3]
);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 3);

ack_reocrd.on_pkt_rcvd(5);
ack_reocrd.on_pkt_rcvd(7);
assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7]);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![0, 1, 3, 5, 7]
);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 7);

// pn 2 ack 0,1,3,5,7
ack_reocrd.on_ack_sent(2, 7);
ack_reocrd.on_pkt_rcvd(9);
assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7, 9]);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![0, 1, 3, 5, 7, 9]
);

assert_eq!(
ack_reocrd.last_ack_sent,
Some((
2,
HashSet::from_iter(vec![0, 1, 3, 5, 7].into_iter().map(|x| x as u64))
))
);

// pn 3 ack 0,1,3,5,7,9
ack_reocrd.on_ack_sent(3, 9);
assert_eq!(
ack_reocrd.last_ack_sent,
Some((
3,
HashSet::from_iter(vec![0, 1, 3, 5, 7, 9].into_iter().map(|x| x as u64))
))
);

// recv pn 2 ack, ingore
ack_reocrd.ack(2, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]);
assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7, 9]);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![0, 1, 3, 5, 7, 9]
);

ack_reocrd.on_pkt_rcvd(11);
assert_eq!(ack_reocrd.rcvd_queue, vec![0, 1, 3, 5, 7, 9, 11]);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![0, 1, 3, 5, 7, 9, 11]
);
// recv pn 3 ack, ret

ack_reocrd.ack(3, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]);
assert_eq!(ack_reocrd.rcvd_queue, vec![11]);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![11]
);
}

#[test]
fn test_ack_record_reversed() {
let max_ack_delay = Duration::from_millis(100);
let mut ack_reocrd = RcvdRecords::new(Epoch::Initial);

ack_reocrd.on_pkt_rcvd(10);
ack_reocrd.on_pkt_rcvd(9);
ack_reocrd.on_pkt_rcvd(8);

assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 10);
ack_reocrd.on_ack_sent(1, 10);

ack_reocrd.on_pkt_rcvd(7);
ack_reocrd.on_pkt_rcvd(6);
ack_reocrd.on_pkt_rcvd(5);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 10);
ack_reocrd.on_ack_sent(2, 10);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None);

// ingnore ack 1
ack_reocrd.ack(1, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![5, 6, 7, 8, 9, 10]
);

ack_reocrd.on_pkt_rcvd(4);

// ack 2
ack_reocrd.ack(2, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay).unwrap().0, 4);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![4]
);
ack_reocrd.on_ack_sent(3, 4);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None);

// ack 3
ack_reocrd.ack(3, &[Arc::new(Mock), Arc::new(Mock), Arc::new(Mock)]);
assert_eq!(ack_reocrd.requires_ack(max_ack_delay), None);
assert_eq!(
ack_reocrd
.rcvd_queue
.iter()
.map(|&(pn, _)| pn)
.collect::<Vec<_>>(),
vec![]
);
}
struct Mock;
impl TrackPackets for Mock {
fn may_loss(&self, _: &mut dyn Iterator<Item = u64>) {}
Expand Down

0 comments on commit c12af73

Please sign in to comment.