Skip to content

Commit

Permalink
feat: add PacketWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
huster-zhangpeng committed Nov 4, 2024
1 parent a9633a6 commit b5344d3
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 170 deletions.
225 changes: 129 additions & 96 deletions qbase/src/packet.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use bytes::BytesMut;
use bytes::{buf::UninitSlice, BufMut, BytesMut};
use deref_derive::{Deref, DerefMut};
use encrypt::{encode_long_first_byte, encode_short_first_byte, encrypt_packet, protect_header};
use enum_dispatch::enum_dispatch;
use header::io::WriteHeader;

use crate::cid::ConnectionId;
use crate::{
cid::ConnectionId,
varint::{EncodeBytes, VarInt, WriteVarInt},
};

/// QUIC packet parse error definitions.
pub mod error;
Expand All @@ -28,6 +33,11 @@ pub use header::{
LongHeaderBuilder, OneRttHeader, RetryHeader, VersionNegotiationHeader, ZeroRttHeader,
};

/// The io module provides the functions to parse the QUIC packet.
///
/// The writing of the QUIC packet is not provided here, they are written in place.
pub mod io;

/// Encoding and decoding of packet number
pub mod number;
#[doc(hidden)]
Expand Down Expand Up @@ -129,103 +139,126 @@ impl Iterator for PacketReader {
}
}

/// The io module provides the functions to parse the QUIC packet.
///
/// The writing of the QUIC packet is not provided here, they are written in place.
pub mod io {
use bytes::BytesMut;
use nom::multi::length_data;

use super::{
error::Error,
header::io::be_header,
r#type::{io::be_packet_type, Type},
*,
};
use crate::varint::be_varint;

/// Parse the payload of a packet.
///
/// - For long packets, the payload is a [`nom::multi::length_data`].
/// - For 1-RTT packet, the payload is the remaining content of the datagram.
fn be_payload(
pkty: Type,
datagram: &mut BytesMut,
remain_len: usize,
) -> Result<(BytesMut, usize), Error> {
let offset = datagram.len() - remain_len;
let input = &datagram[offset..];
let (remain, payload) = length_data(be_varint)(input).map_err(|e| match e {
ne @ nom::Err::Incomplete(_) => Error::IncompleteHeader(pkty, ne.to_string()),
_ => unreachable!("parsing packet header never generates error or failure"),
})?;
let payload_len = payload.len();
if payload_len < 20 {
// The payload needs at least 20 bytes to have enough samples to remove the packet header protection.
return Err(Error::UnderSampling(payload.len()));
pub struct PacketWriter<'b> {
buffer: &'b mut [u8],
hdr_len: usize,
len_encoding: usize,
pn: (u64, PacketNumber),
cursor: usize,
end: usize,
tag_len: usize,
}

impl<'b> PacketWriter<'b> {
pub fn new<H>(
header: &H,
buffer: &'b mut [u8],
pn: (u64, PacketNumber),
tag_len: usize,
) -> Option<Self>
where
H: EncodeHeader,
for<'a> &'a mut [u8]: WriteHeader<H>,
{
let hdr_len = header.size();
let len_encoding = header.length_encoding();
if buffer.len() < hdr_len + len_encoding + 20 {
return None;
}
let packet_length = datagram.len() - remain.len();
let bytes = datagram.split_to(packet_length);
Ok((bytes, packet_length - payload_len))

let (mut hdr_buf, mut payload_buf) = buffer.split_at_mut(hdr_len + len_encoding);
let encoded_pn = pn.1;
hdr_buf.put_header(header);
payload_buf.put_packet_number(encoded_pn);

let end = buffer.len() - tag_len;
Some(Self {
buffer,
hdr_len,
len_encoding,
pn,
cursor: hdr_len + len_encoding + encoded_pn.size(),
end,
tag_len,
})
}

/// Parse the QUIC packet from the datagram, given the length of the DCID.
/// Returns the parsed packet or an error, and the datagram removed the packet's content.
pub fn be_packet(datagram: &mut BytesMut, dcid_len: usize) -> Result<Packet, Error> {
let input = datagram.as_ref();
let (remain, pkty) = be_packet_type(input).map_err(|e| match e {
ne @ nom::Err::Incomplete(_) => Error::IncompleteType(ne.to_string()),
nom::Err::Error(e) => e,
_ => unreachable!("parsing packet type never generates failure"),
})?;
let (remain, header) = be_header(pkty, dcid_len, remain).map_err(|e| match e {
ne @ nom::Err::Incomplete(_) => Error::IncompleteHeader(pkty, ne.to_string()),
_ => unreachable!("parsing packet header never generates error or failure"),
})?;
match header {
Header::VN(header) => Ok(Packet::VN(header)),
Header::Retry(header) => Ok(Packet::Retry(header)),
Header::Initial(header) => {
let (bytes, offset) = be_payload(pkty, datagram, remain.len())?;
Ok(Packet::Data(DataPacket {
header: DataHeader::Long(long::DataHeader::Initial(header)),
bytes,
offset,
}))
}
Header::ZeroRtt(header) => {
let (bytes, offset) = be_payload(pkty, datagram, remain.len())?;
Ok(Packet::Data(DataPacket {
header: DataHeader::Long(long::DataHeader::ZeroRtt(header)),
bytes,
offset,
}))
}
Header::Handshake(header) => {
let (bytes, offset) = be_payload(pkty, datagram, remain.len())?;
Ok(Packet::Data(DataPacket {
header: DataHeader::Long(long::DataHeader::Handshake(header)),
bytes,
offset,
}))
}
Header::OneRtt(header) => {
if remain.len() < 20 {
// The payload needs at least 20 bytes to have enough samples to remove the packet header protection.
return Err(Error::UnderSampling(remain.len()));
}
let bytes = datagram.clone();
let offset = bytes.len() - remain.len();
datagram.clear();
Ok(Packet::Data(DataPacket {
header: DataHeader::Short(header),
bytes,
offset,
}))
}
}
pub fn pad(&mut self, cnt: usize) {
self.put_bytes(0, cnt);
}

pub fn encrypt_long_packet(
&mut self,
hpk: &dyn rustls::quic::HeaderProtectionKey,
pk: &dyn rustls::quic::PacketKey,
) -> usize {
let (actual_pn, encoded_pn) = self.pn;
encode_long_first_byte(&mut self.buffer[0], encoded_pn.size());

let pkt_size = self.cursor + self.tag_len;
let payload_len = pkt_size - self.hdr_len - self.len_encoding;
let mut pn_buf = &mut self.buffer[self.hdr_len..self.hdr_len + self.len_encoding];
pn_buf.encode_varint(&VarInt::try_from(payload_len).unwrap(), EncodeBytes::Two);

encrypt_packet(
pk,
actual_pn,
&mut self.buffer[..pkt_size],
self.hdr_len + self.len_encoding + encoded_pn.size(),
);
protect_header(
hpk,
&mut self.buffer[..pkt_size],
self.hdr_len,
encoded_pn.size(),
);
pkt_size
}

pub fn encrypt_short_packet(
&mut self,
key_phase: KeyPhaseBit,
hpk: &dyn rustls::quic::HeaderProtectionKey,
pk: &dyn rustls::quic::PacketKey,
) -> usize {
let (actual_pn, encoded_pn) = self.pn;
encode_short_first_byte(&mut self.buffer[0], encoded_pn.size(), key_phase);

let pkt_size = self.cursor + self.tag_len;
encrypt_packet(
pk,
actual_pn,
&mut self.buffer[..pkt_size],
self.hdr_len + self.len_encoding + encoded_pn.size(),
);
protect_header(
hpk,
&mut self.buffer[..pkt_size],
self.hdr_len,
encoded_pn.size(),
);
pkt_size
}
}

#[cfg(test)]
mod tests {}
unsafe impl BufMut for PacketWriter<'_> {
fn remaining_mut(&self) -> usize {
self.end - self.cursor
}

unsafe fn advance_mut(&mut self, cnt: usize) {
if self.remaining_mut() < cnt {
panic!(
"advance out of bounds: the len is {} but advancing by {}",
cnt,
self.remaining_mut()
);
}

self.cursor += cnt;
}

fn chunk_mut(&mut self) -> &mut UninitSlice {
UninitSlice::new(&mut self.buffer[self.cursor..self.end])
}
}
46 changes: 23 additions & 23 deletions qbase/src/packet/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ pub mod short;

#[doc(hidden)]
pub use long::{
io::{LongHeaderBuilder, WriteLongHeader, WriteSpecific},
io::{LongHeaderBuilder, WriteSpecific},
DataHeader, HandshakeHeader, InitialHeader, LongHeader, RetryHeader, VersionNegotiationHeader,
ZeroRttHeader,
};
#[doc(hidden)]
pub use short::{io::WriteShortHeader, OneRttHeader};
pub use short::OneRttHeader;

use super::r#type::{
long::{v1, Type as LongType, Version},
Expand All @@ -40,6 +40,10 @@ pub trait EncodeHeader {
fn size(&self) -> usize {
0
}

fn length_encoding(&self) -> usize {
0
}
}

/// Get the Destination Connection ID (DCID) of the packet, each packet has a DCID.
Expand Down Expand Up @@ -73,12 +77,8 @@ pub enum Header {
/// how to write the header into a UDP packet.
pub mod io {
use super::{
long::{
io::{LongHeaderBuilder, WriteLongHeader},
Handshake, Initial, Retry, VersionNegotiation, ZeroRtt,
},
short::io::WriteShortHeader,
Header,
long::{io::LongHeaderBuilder, Handshake, Initial, Retry, VersionNegotiation, ZeroRtt},
Header, LongHeader, OneRttHeader,
};
use crate::{
cid::be_connection_id,
Expand Down Expand Up @@ -115,29 +115,29 @@ pub mod io {
/// When sending packets, it is necessary to organize the data and write
/// various types of QUIC packets into an UDP datagram. This trait will
/// be used to write the packet header.
pub trait WriteHeader: bytes::BufMut {
pub trait WriteHeader<H>: bytes::BufMut {
/// Write a packet header to the buffer.
fn put_header(&mut self, header: &Header);
fn put_header(&mut self, header: &H);
}

impl<T> WriteHeader for T
impl<T> WriteHeader<Header> for T
where
T: bytes::BufMut
+ WriteLongHeader<VersionNegotiation>
+ WriteLongHeader<Retry>
+ WriteLongHeader<Initial>
+ WriteLongHeader<ZeroRtt>
+ WriteLongHeader<Handshake>
+ WriteShortHeader,
+ WriteHeader<LongHeader<VersionNegotiation>>
+ WriteHeader<LongHeader<Retry>>
+ WriteHeader<LongHeader<Initial>>
+ WriteHeader<LongHeader<ZeroRtt>>
+ WriteHeader<LongHeader<Handshake>>
+ WriteHeader<OneRttHeader>,
{
fn put_header(&mut self, header: &Header) {
match header {
Header::VN(vn) => self.put_long_header(vn),
Header::Retry(retry) => self.put_long_header(retry),
Header::Initial(initial) => self.put_long_header(initial),
Header::ZeroRtt(zero_rtt) => self.put_long_header(zero_rtt),
Header::Handshake(handshake) => self.put_long_header(handshake),
Header::OneRtt(one_rtt) => self.put_short_header(one_rtt),
Header::VN(vn) => self.put_header(vn),
Header::Retry(retry) => self.put_header(retry),
Header::Initial(initial) => self.put_header(initial),
Header::ZeroRtt(zero_rtt) => self.put_header(zero_rtt),
Header::Handshake(handshake) => self.put_header(handshake),
Header::OneRtt(one_rtt) => self.put_header(one_rtt),
}
}
}
Expand Down
39 changes: 16 additions & 23 deletions qbase/src/packet/header/long.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ impl<S: EncodeHeader> EncodeHeader for LongHeader<S> {
+ 1 + self.scid.len() // scid一样
+ self.specific.size()
}

fn length_encoding(&self) -> usize {
2 // 长包头都带有length字段,统一2字节,能表达1~16KB的长度的包
}
}

macro_rules! bind_type {
Expand Down Expand Up @@ -200,9 +204,12 @@ pub mod io {
use super::*;
use crate::{
cid::WriteConnectionId,
packet::r#type::{
io::WritePacketType,
long::{v1::Type as LongV1Type, Type as LongType},
packet::{
header::io::WriteHeader,
r#type::{
io::WritePacketType,
long::{v1::Type as LongV1Type, Type as LongType},
},
},
varint::{be_varint, WriteVarInt},
};
Expand Down Expand Up @@ -371,31 +378,17 @@ pub mod io {
/// Handshake headers are empty, so there is nothing to write.
impl<T: BufMut> WriteSpecific<Handshake> for T {}

/// A [`bytes::BufMut`] extension trait, makes buffer more friendly to write long headers.
///
/// Write the long header content, including the packet type, destination connection ID,
/// source connection ID, and specific header content.
///
/// ## Note
///
/// It does not write the payload Length of the packet, and leaves it to be filled in when
/// collecting data to send.
pub trait WriteLongHeader<S>: BufMut {
/// Write the long header.
fn put_long_header(&mut self, wrapper: &LongHeader<S>);
}

impl<T, S> WriteLongHeader<S> for T
impl<T, S> WriteHeader<LongHeader<S>> for T
where
T: BufMut + WriteSpecific<S>,
LongHeader<S>: GetType,
{
fn put_long_header(&mut self, long_header: &LongHeader<S>) {
let ty = long_header.get_type();
fn put_header(&mut self, header: &LongHeader<S>) {
let ty = header.get_type();
self.put_packet_type(&ty);
self.put_connection_id(&long_header.dcid);
self.put_connection_id(&long_header.scid);
self.put_specific(&long_header.specific);
self.put_connection_id(&header.dcid);
self.put_connection_id(&header.scid);
self.put_specific(&header.specific);
}
}
}
Expand Down
Loading

0 comments on commit b5344d3

Please sign in to comment.