Skip to content

Commit

Permalink
feat: apply transaction to assemble packets in different spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
huster-zhangpeng committed Nov 18, 2024
1 parent d88848d commit 993164f
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 97 deletions.
66 changes: 52 additions & 14 deletions qbase/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use bytes::{buf::UninitSlice, BufMut, BytesMut};
use deref_derive::{Deref, DerefMut};
use encrypt::{encode_long_first_byte, encode_short_first_byte, encrypt_packet, protect_header};
use enum_dispatch::enum_dispatch;
use getset::Getters;
use header::io::WriteHeader;

use crate::{
Expand Down Expand Up @@ -229,19 +230,28 @@ impl<'b> PacketWriter<'b> {
self.in_flight
}

pub fn is_empty(&self) -> bool {
self.cursor == self.hdr_len + self.len_encoding
}

pub fn encrypt_long_packet(
&mut self,
mut self,
hpk: &dyn rustls::quic::HeaderProtectionKey,
pk: &dyn rustls::quic::PacketKey,
) -> usize {
let (actual_pn, encoded_pn) = self.pn;
encode_long_first_byte(&mut self.buffer[0], encoded_pn.size());
) -> AssembledPacket<'b> {
let mut payload_len = self.cursor - self.hdr_len - self.len_encoding;
debug_assert!(payload_len > 0);
if payload_len + self.tag_len < 20 {
let padding_len = 20 - payload_len - self.tag_len;
self.pad(padding_len);
payload_len += padding_len;
}

let mut len_buf = &mut self.buffer[self.hdr_len..self.hdr_len + self.len_encoding];
let (actual_pn, encoded_pn) = self.pn;
let pkt_size = self.cursor + self.tag_len;
let payload_len = pkt_size - self.hdr_len - self.len_encoding;
let mut pn_buf = &mut self.buffer[self.hdr_len..self.hdr_len + self.len_encoding];
pn_buf.encode_varint(&VarInt::try_from(payload_len).unwrap(), EncodeBytes::Two);

len_buf.encode_varint(&VarInt::try_from(payload_len).unwrap(), EncodeBytes::Two);
encode_long_first_byte(&mut self.buffer[0], encoded_pn.size());
encrypt_packet(
pk,
actual_pn,
Expand All @@ -254,19 +264,30 @@ impl<'b> PacketWriter<'b> {
self.hdr_len,
encoded_pn.size(),
);
pkt_size
AssembledPacket {
buffer: self.buffer,
size: pkt_size,
is_ack_eliciting: self.ack_eliciting,
in_flight: self.in_flight,
}
}

pub fn encrypt_short_packet(
&mut self,
mut self,
key_phase: KeyPhaseBit,
hpk: &dyn rustls::quic::HeaderProtectionKey,
pk: &dyn rustls::quic::PacketKey,
) -> usize {
let (actual_pn, encoded_pn) = self.pn;
encode_short_first_byte(&mut self.buffer[0], encoded_pn.size(), key_phase);
) -> AssembledPacket<'b> {
let payload_len = self.cursor - self.hdr_len - self.len_encoding;
debug_assert!(payload_len > 0);
if payload_len + self.tag_len < 20 {
let padding_len = 20 - payload_len - self.tag_len;
self.pad(padding_len);
}

let pkt_size = self.cursor + self.tag_len;
let (actual_pn, encoded_pn) = self.pn;
encode_short_first_byte(&mut self.buffer[0], encoded_pn.size(), key_phase);
encrypt_packet(
pk,
actual_pn,
Expand All @@ -279,10 +300,27 @@ impl<'b> PacketWriter<'b> {
self.hdr_len,
encoded_pn.size(),
);
pkt_size
AssembledPacket {
buffer: self.buffer,
size: pkt_size,
is_ack_eliciting: self.ack_eliciting,
in_flight: self.in_flight,
}
}
}

#[derive(Debug, Getters)]
pub struct AssembledPacket<'b> {
#[getset(get = "pub")]
buffer: &'b mut [u8],
#[getset(get = "pub")]
size: usize,
#[getset(get = "pub")]
is_ack_eliciting: bool,
#[getset(get = "pub")]
in_flight: bool,
}

impl<F> MarshalFrame<F> for PacketWriter<'_>
where
F: BeFrame,
Expand Down
49 changes: 21 additions & 28 deletions qbase/src/packet/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,22 @@ pub mod io {

#[cfg(test)]
mod tests {
use super::{
io::be_header,
long::{Handshake, Initial, Retry, VersionNegotiation, ZeroRtt},
Header, LongHeaderBuilder,
};
use crate::{
cid::ConnectionId,
packet::{
header::io::WriteHeader,
r#type::{long, long::Ver1, short::OneRtt, Type},
OneRttHeader, SpinBit,
},
};

#[test]
fn test_read_header() {
use super::{io::be_header, Header};
use crate::{
cid::ConnectionId,
packet::{
r#type::{long, long::Ver1, short::OneRtt, Type},
SpinBit,
},
};

// VersionNegotiation Header
let buf = vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02];
let (remain, vn_long_header) =
Expand Down Expand Up @@ -252,24 +256,17 @@ mod tests {
assert_eq!(remain.len(), 0);
match one_rtt_header {
Header::OneRtt(one_rtt) => {
assert_eq!(one_rtt.dcid, ConnectionId::default());
assert_eq!(one_rtt.spin, SpinBit::One);
assert_eq!(
one_rtt,
OneRttHeader::new(SpinBit::One, ConnectionId::default())
);
}
_ => panic!("unexpected header type"),
}
}

#[test]
fn test_write_header() {
use super::{
long::{Handshake, Initial, Retry, VersionNegotiation, ZeroRtt},
LongHeaderBuilder,
};
use crate::{
cid::ConnectionId,
packet::{header::io::WriteHeader, Header, OneRttHeader, SpinBit},
};

// VersionNegotiation Header
let mut buf = vec![];
let vn_long_header = Header::VN(
Expand Down Expand Up @@ -345,19 +342,15 @@ mod tests {

// OneRtt Header with SpinBit::On
let mut buf = vec![];
let one_rtt_header = Header::OneRtt(OneRttHeader {
spin: SpinBit::One,
dcid: ConnectionId::default(),
});
let one_rtt_header =
Header::OneRtt(OneRttHeader::new(SpinBit::One, ConnectionId::default()));
buf.put_header(&one_rtt_header);
assert_eq!(buf, [0x60]);

// OneRtt Header with SpinBit::Off
let mut buf = vec![];
let one_rtt_header = Header::OneRtt(OneRttHeader {
spin: SpinBit::Zero,
dcid: ConnectionId::default(),
});
let one_rtt_header =
Header::OneRtt(OneRttHeader::new(SpinBit::Zero, ConnectionId::default()));
buf.put_header(&one_rtt_header);
assert_eq!(buf, [0x40]);
}
Expand Down
13 changes: 10 additions & 3 deletions qbase/src/packet/header/short.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,18 @@ use crate::{cid::ConnectionId, packet::SpinBit};
///
/// See [1-RTT Packet](https://www.rfc-editor.org/rfc/rfc9000.html#name-1-rtt-packet)
/// in [RFC9000](https://www.rfc-editor.org/rfc/rfc9000.html) for more details.
#[derive(Debug, Default, Clone)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct OneRttHeader {
// For simplicity, the spin bit is also part of the 1RTT header.
pub spin: SpinBit,
pub dcid: ConnectionId,
spin: SpinBit,
dcid: ConnectionId,
}

impl OneRttHeader {
/// Create a new 1RTT header.
pub fn new(spin: SpinBit, dcid: ConnectionId) -> Self {
Self { spin, dcid }
}
}

impl EncodeHeader for OneRttHeader {
Expand Down
6 changes: 2 additions & 4 deletions qbase/src/packet/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ impl PacketNumber {

#[cfg(test)]
mod tests {
use super::{PacketNumber, WritePacketNumber};

#[test]
fn test_read_packet_number() {
Expand Down Expand Up @@ -177,9 +178,6 @@ mod tests {

#[test]
fn test_write_packet_number() {
use super::WritePacketNumber;
use crate::packet::PacketNumber;

let mut buf = vec![];
buf.put_packet_number(PacketNumber::encode(0, 0));
assert_eq!(buf, [0x00]);
Expand Down Expand Up @@ -209,6 +207,6 @@ mod tests {
#[test]
#[should_panic]
fn test_encode_packet_number_overflow() {
super::PacketNumber::encode(1 << 31, 0);
PacketNumber::encode(1 << 31, 0);
}
}
91 changes: 88 additions & 3 deletions qconnection/src/conn/space/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@ use qbase::{
decrypt_packet, remove_protection_of_long_packet, remove_protection_of_short_packet,
},
encrypt::{encode_short_first_byte, encrypt_packet, protect_header},
header::{io::WriteHeader, EncodeHeader, GetType, OneRttHeader},
header::{
io::WriteHeader, long::io::LongHeaderBuilder, EncodeHeader, GetType, OneRttHeader,
},
keys::{ArcKeys, ArcOneRttKeys, ArcOneRttPacketKeys, HeaderProtectionKeys},
number::WritePacketNumber,
r#type::Type,
DataPacket, PacketNumber,
signal::SpinBit,
AssembledPacket, DataPacket, PacketNumber, PacketWriter,
},
token::ArcTokenRegistry,
Epoch,
Expand All @@ -42,6 +45,7 @@ use crate::{
path::{ArcPaths, Path, SendBuffer},
pipe,
router::Router,
tx::{PacketMemory, Transaction},
};

#[derive(Clone)]
Expand Down Expand Up @@ -377,6 +381,87 @@ impl DataSpace {
});
}

pub fn try_assemble_0rtt<'b>(
&self,
tx: &mut Transaction<'_>,
buf: &'b mut [u8],
) -> Option<(AssembledPacket<'b>, Option<u64>)> {
if self.one_rtt_keys.get_local_keys().is_some() {
return None;
}

let keys = self.zero_rtt_keys.get_local_keys()?;
let sent_journal = self.journal.sent();
let mut packet = PacketMemory::new(
LongHeaderBuilder::with_cid(tx.dcid(), ConnectionId::default()).zero_rtt(),
buf,
keys.local.packet.tag_len(),
&sent_journal,
)?;

let mut ack = None;
if let Some((largest, rcvd_time)) = tx.need_ack(Epoch::Handshake) {
let rcvd_pkt_records = self.journal.rcvd();
if let Some(ack_frame) =
rcvd_pkt_records.gen_ack_frame_util(largest, rcvd_time, packet.remaining_mut())
{
packet.dump_ack_frame(ack_frame);
ack = Some(largest);
}
}

// TODO: 可以封装在CryptoStream中,当成一个函数
// crypto_stream.try_load_data_into(&mut packet);
let crypto_stream_outgoing = self.crypto_stream.outgoing();
crypto_stream_outgoing.try_load_data_into(&mut packet);

let packet: PacketWriter<'b> = packet.try_into().ok()?;
Some((
packet.encrypt_long_packet(keys.local.header.as_ref(), keys.local.packet.as_ref()),
ack,
))
}

pub fn try_assemble_1rtt<'b>(
&self,
tx: &mut Transaction<'_>,
spin: SpinBit,
buf: &'b mut [u8],
) -> Option<(AssembledPacket<'b>, Option<u64>)> {
let (hpk, pk) = self.one_rtt_keys.get_local_keys()?;
let sent_journal = self.journal.sent();
let mut packet = PacketMemory::new(
OneRttHeader::new(spin, tx.dcid()),
buf,
pk.tag_len(),
&sent_journal,
)?;

let mut ack = None;
if let Some((largest, rcvd_time)) = tx.need_ack(Epoch::Handshake) {
let rcvd_pkt_records = self.journal.rcvd();
if let Some(ack_frame) =
rcvd_pkt_records.gen_ack_frame_util(largest, rcvd_time, packet.remaining_mut())
{
packet.dump_ack_frame(ack_frame);
ack = Some(largest);
}
}

// TODO: 可以封装在CryptoStream中,当成一个函数
// crypto_stream.try_load_data_into(&mut packet);
let crypto_stream_outgoing = self.crypto_stream.outgoing();
crypto_stream_outgoing.try_load_data_into(&mut packet);

let packet: PacketWriter<'b> = packet.try_into().ok()?;
let pk_guard = pk.lock_guard();
let (key_phase, pk) = pk_guard.get_local();
Some((
packet.encrypt_short_packet(key_phase, hpk.as_ref(), pk.as_ref()),
ack,
))
}

pub fn reader(
&self,
challenge_sndbuf: SendBuffer<PathChallengeFrame>,
Expand Down Expand Up @@ -458,7 +543,7 @@ impl ClosingOneRttScope {
let hpk = &hpk.local;

let spin = Default::default();
let hdr = OneRttHeader { spin, dcid };
let hdr = OneRttHeader::new(spin, dcid);
let (mut hdr_buf, payload_tag) = buf.split_at_mut(hdr.size());
let payload_tag_len = payload_tag.len();
let tag_len = pk.tag_len();
Expand Down
Loading

0 comments on commit 993164f

Please sign in to comment.