From d6f9c5e6d3ab8f6c601056229b47ea2df814f134 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 12 Jun 2024 09:21:18 +0000 Subject: [PATCH] mDNS: update domain lib; fix various issues --- examples/onoff_light/src/main.rs | 2 +- rs-matter/Cargo.toml | 6 +- rs-matter/src/core.rs | 6 +- rs-matter/src/data_model/core.rs | 4 +- rs-matter/src/data_model/objects/handler.rs | 6 +- rs-matter/src/mdns/builtin.rs | 70 +- rs-matter/src/mdns/proto.rs | 990 +++++++++++++++----- rs-matter/src/transport/core.rs | 5 +- 8 files changed, 845 insertions(+), 244 deletions(-) diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 192d62f4..f3321921 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -292,7 +292,7 @@ async fn run_mdns(matter: &Matter<'_>) -> Result<(), Error> { .run_builtin_mdns( &socket, &socket, - Host { + &Host { id: 0, hostname: "rs-matter-demo", ip: ipv4_addr.octets(), diff --git a/rs-matter/Cargo.toml b/rs-matter/Cargo.toml index d6689157..900dd6f2 100644 --- a/rs-matter/Cargo.toml +++ b/rs-matter/Cargo.toml @@ -9,7 +9,7 @@ readme = "README.md" keywords = ["matter", "smart", "smart-home", "IoT", "ESP32"] categories = ["embedded", "network-programming"] license = "Apache-2.0" -rust-version = "1.77" +rust-version = "1.78" [features] default = ["os", "mbedtls"] @@ -27,7 +27,6 @@ rs-matter-macros = { version = "0.1", path = "../rs-matter-macros" } bitflags = { version = "2.5", default-features = false } # TODO: Update to 2.x byteorder = { version = "1.5", default-features = false } heapless = "0.8" -heapless07 = { package = "heapless", version = "0.7" } # Necessary for domain 0.9 num = { version = "0.4", default-features = false } num-derive = "0.4" num-traits = { version = "0.2", default-features = false } @@ -42,8 +41,7 @@ embassy-futures = "0.1" embassy-time = "0.3" embassy-sync = "0.5" critical-section = "1.1" -domain = { version = "0.9", default-features = false, features = ["heapless"] } -octseq = { version = "0.3", default-features = false } +domain = { version = "0.10", default-features = false, features = ["heapless"] } portable-atomic = "1" qrcodegen-no-heap = "1.8" scopeguard = "1" diff --git a/rs-matter/src/core.rs b/rs-matter/src/core.rs index 8c667887..0898731e 100644 --- a/rs-matter/src/core.rs +++ b/rs-matter/src/core.rs @@ -83,8 +83,8 @@ impl<'a> Matter<'a> { /// /// # Parameters /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device - /// requires a set of device attestation certificates and keys. It is the responsibility of - /// this object to return the device attestation details when queried upon. + /// requires a set of device attestation certificates and keys. It is the responsibility of + /// this object to return the device attestation details when queried upon. #[inline(always)] pub const fn new( dev_det: &'a BasicInfoConfig<'a>, @@ -203,7 +203,7 @@ impl<'a> Matter<'a> { &self, send: S, recv: R, - host: crate::mdns::Host<'_>, + host: &crate::mdns::Host<'_>, interface: Option, ) -> Result<(), Error> where diff --git a/rs-matter/src/data_model/core.rs b/rs-matter/src/data_model/core.rs index 6b2d1638..49c96129 100644 --- a/rs-matter/src/data_model/core.rs +++ b/rs-matter/src/data_model/core.rs @@ -82,9 +82,9 @@ where /// The parameters are as follows: /// * `buffers` - a reference to an implementation of `BufferAccess` which is used for allocating RX and TX buffers on the fly, when necessary /// * `subscriptions` - a reference to a `Subscriptions` struct which is used for managing subscriptions. `N` designates the maximum - /// number of subscriptions that can be managed by this handler. + /// number of subscriptions that can be managed by this handler. /// * `handler` - an instance of type `T` which implements the `DataModelHandler` trait. This instance is used for interacting with the underlying - /// clusters of the data model. + /// clusters of the data model. #[inline(always)] pub const fn new(buffers: &'a B, subscriptions: &'a Subscriptions, handler: T) -> Self { Self { diff --git a/rs-matter/src/data_model/objects/handler.rs b/rs-matter/src/data_model/objects/handler.rs index b76b0ec7..53a1d47b 100644 --- a/rs-matter/src/data_model/objects/handler.rs +++ b/rs-matter/src/data_model/objects/handler.rs @@ -143,7 +143,7 @@ impl EmptyHandler { /// /// The returned chained handler works as follows: /// - It will call the provided `handler` instance if the endpoint and cluster - /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. + /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. /// - Otherwise, the empty handler would be invoked, causing the operation to error out. pub const fn chain( self, @@ -185,7 +185,7 @@ pub struct ChainedHandler { impl ChainedHandler { /// Construct a chained handler that works as follows: /// - It will call the provided `handler` instance if the endpoint and cluster - /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. + /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. /// - Otherwise, it will call the `next` handler pub const fn new(handler_endpoint: u16, handler_cluster: u32, handler: H, next: T) -> Self { Self { @@ -200,7 +200,7 @@ impl ChainedHandler { /// /// The returned chained handler works as follows: /// - It will call the provided `handler` instance if the endpoint and cluster - /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. + /// of the incoming request do match the `handler_endpoint` and `handler_cluster` provided here. /// - Otherwise, it will call the `self` handler pub const fn chain

( self, diff --git a/rs-matter/src/mdns/builtin.rs b/rs-matter/src/mdns/builtin.rs index 2dae97db..235f1bc3 100644 --- a/rs-matter/src/mdns/builtin.rs +++ b/rs-matter/src/mdns/builtin.rs @@ -1,4 +1,6 @@ -use core::{cell::RefCell, pin::pin}; +use core::cell::RefCell; +use core::net::IpAddr; +use core::pin::pin; use embassy_futures::select::select; use embassy_sync::blocking_mutex::raw::{NoopRawMutex, RawMutex}; @@ -12,6 +14,7 @@ use crate::transport::network::{ Address, Ipv4Addr, Ipv6Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV4, SocketAddrV6, }; +use crate::utils::rand::Rand; use crate::utils::{buf::BufferAccess, notification::Notification, select::Coalesce}; use super::{Service, ServiceMode}; @@ -91,14 +94,16 @@ impl<'a> MdnsImpl<'a> { Ok(()) } + #[allow(clippy::too_many_arguments)] pub async fn run( &self, send: S, recv: R, tx_buf: SB, rx_buf: RB, - host: Host<'_>, + host: &Host<'_>, interface: Option, + rand: Rand, ) -> Result<(), Error> where S: NetworkSend, @@ -108,8 +113,8 @@ impl<'a> MdnsImpl<'a> { { let send = Mutex::::new(send); - let mut broadcast = pin!(self.broadcast(&send, &tx_buf, &host, interface)); - let mut respond = pin!(self.respond(&send, recv, &tx_buf, &rx_buf, &host, interface)); + let mut broadcast = pin!(self.broadcast(&send, &tx_buf, host, interface)); + let mut respond = pin!(self.respond(&send, recv, &tx_buf, &rx_buf, host, interface, rand)); select(&mut broadcast, &mut respond).coalesce().await } @@ -160,6 +165,7 @@ impl<'a> MdnsImpl<'a> { } } + #[allow(clippy::too_many_arguments)] async fn respond( &self, send: &Mutex, @@ -167,7 +173,8 @@ impl<'a> MdnsImpl<'a> { tx_buf: SB, rx_buf: RB, host: &Host<'_>, - _interface: Option, + interface: Option, + rand: Rand, ) -> Result<(), Error> where S: NetworkSend, @@ -185,8 +192,8 @@ impl<'a> MdnsImpl<'a> { let mut tx = tx_buf.get().await.ok_or(ErrorCode::NoSpace)?; let mut send = send.lock().await; - let len = match host.respond(self, &rx[..len], &mut tx, 60) { - Ok(len) => len, + let (len, delay) = match host.respond(self, &rx[..len], &mut tx, 60) { + Ok((len, delay)) => (len, delay), Err(err) => { warn!("mDNS protocol error {err} while replying to {addr}"); continue; @@ -194,19 +201,44 @@ impl<'a> MdnsImpl<'a> { }; if len > 0 { - info!("Replying to mDNS query from {addr}"); - - match send.send_to(&tx[..len], addr).await { - Ok(_) => (), - Err(err) => { - // Turns out we might receive queries from Ipv6 addresses which are actually unreachable by us - // Still to be investigated why, but it does seem that we are receiving packets which contain - // non-link-local Ipv6 addresses, to which we cannot respond - // - // A possible reason for this might be that we are receiving these packets via the broadcast group - // - yet - it is still unclear how these arrive given that we are only listening on the link-local address - warn!("IO error {err:?} while replying to {addr}"); + let ipv4 = addr + .udp() + .map(|addr| matches!(addr.ip(), IpAddr::V4(_))) + .unwrap_or(true); + + let reply_addr = if ipv4 { + Some(SocketAddr::V4(SocketAddrV4::new( + MDNS_IPV4_BROADCAST_ADDR, + MDNS_PORT, + ))) + } else { + interface.map(|interface| { + SocketAddr::V6(SocketAddrV6::new( + MDNS_IPV6_BROADCAST_ADDR, + MDNS_PORT, + 0, + interface, + )) + }) + }; + + if let Some(reply_addr) = reply_addr { + if delay { + let mut b = [0]; + rand(&mut b); + + // Generate a delay between 20 and 120 ms, as per spec + let delay_ms = 20 + (b[0] as u32 * 100 / 256); + + info!("Replying to mDNS query from {addr} on {reply_addr}, delay {delay_ms}ms"); + Timer::after(Duration::from_millis(delay_ms as _)).await; + } else { + info!("Replying to mDNS query from {addr} on {reply_addr}"); } + + send.send_to(&tx[..len], Address::Udp(reply_addr)).await?; + } else { + info!("Cannot reply to mDNS query from {addr}: no suitable broadcast address found"); } } } diff --git a/rs-matter/src/mdns/proto.rs b/rs-matter/src/mdns/proto.rs index af90d163..dd56c50a 100644 --- a/rs-matter/src/mdns/proto.rs +++ b/rs-matter/src/mdns/proto.rs @@ -1,20 +1,19 @@ use core::fmt::Write; -use domain::{ - base::{ - header::Flags, - iana::Class, - message::ShortMessage, - message_builder::{AnswerBuilder, PushError}, - name::FromStrError, - wire::{Composer, ParseError}, - Dname, Message, MessageBuilder, Rtype, ToDname, - }, - dep::octseq::{OctetsBuilder, ShortBuf}, - rdata::{Aaaa, Ptr, Srv, Txt, A}, -}; +use bitflags::bitflags; + +use domain::base::header::Flags; +use domain::base::iana::{Class, Opcode, Rcode}; +use domain::base::message::ShortMessage; +use domain::base::message_builder::{AdditionalBuilder, AnswerBuilder, PushError}; +use domain::base::name::FromStrError; +use domain::base::wire::{Composer, ParseError}; +use domain::base::{Message, MessageBuilder, Name, RecordSectionBuilder, Rtype, ToName}; +use domain::dep::octseq::Truncate; +use domain::dep::octseq::{OctetsBuilder, ShortBuf}; +use domain::rdata::{Aaaa, Ptr, Srv, Txt, A}; + use log::trace; -use octseq::Truncate; use crate::error::{Error, ErrorCode}; @@ -80,6 +79,17 @@ where } } +// What additional data to be set in the mDNS reply +bitflags! { + #[repr(transparent)] + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub struct AdditionalData: u8 { + const IPS = 0x01; + const SRV = 0x02; + const TXT = 0x04; + } +} + pub struct Host<'a> { pub id: u16, pub hostname: &'a str, @@ -88,6 +98,10 @@ pub struct Host<'a> { } impl<'a> Host<'a> { + /// Broadcast an mDNS packet with the host and its services + /// + /// Should be done pro-actively every time there is a change in the host + /// data itself, or in the data of one of its services pub fn broadcast( &self, services: T, @@ -100,196 +114,295 @@ impl<'a> Host<'a> { let mut answer = message.answer(); - self.set_broadcast(services, &mut answer, ttl_sec)?; + self.set_answer_header(&mut answer); + + self.add_ipv4(&mut answer, ttl_sec)?; + self.add_ipv6(&mut answer, ttl_sec)?; + + services.for_each(|service| { + service.add_service(&mut answer, self.hostname, ttl_sec)?; + service.add_service_type(&mut answer, ttl_sec)?; + service.add_service_subtypes(&mut answer, ttl_sec)?; + service.add_dns_sd_service_type(&mut answer, ttl_sec)?; + service.add_dns_sd_service_subtypes(&mut answer, ttl_sec)?; + service.add_txt(&mut answer, ttl_sec)?; + + Ok(()) + })?; let buf = answer.finish(); Ok(buf.1) } + /// Respond to an mDNS packet as long as it contains at least one question + /// which is applicable to the hoswt and its services + /// + /// Returns the number of bytes written to the buffer and a boolean indicating + /// whether the response should be delayed by a random interval of 20 - 120ms, + /// as per the mDNS spec pub fn respond( &self, services: T, data: &[u8], buf: &mut [u8], ttl_sec: u32, - ) -> Result { + ) -> Result<(usize, bool), Error> { let buf = Buf(buf, 0); let message = MessageBuilder::from_target(buf)?; let mut answer = message.answer(); + let mut ad = AdditionalData::empty(); + let mut delay = false; + + if self.answer(data, &services, &mut answer, &mut ad, &mut delay, ttl_sec)? { + let mut additional = answer.additional(); - if self.set_response(data, services, &mut answer, ttl_sec)? { - let buf = answer.finish(); + self.additional(ad, &services, &mut additional, ttl_sec)?; - Ok(buf.1) + let buf = additional.finish(); + + Ok((buf.1, delay)) } else { - Ok(0) + Ok((0, false)) } } - fn set_broadcast( + /// Generate answers for queries in the message which are applicable to the host and + /// the services registered in it + /// + /// Returns true if any answers were generated + /// + /// Updates the `AdditionalData` parameter with indications of what additional data + /// to be set in the "additional data" DNS record + /// + /// Updates the `delay` parameter to indicate if the reply should be delayed to avoid + /// collissions with other mDNS responders + #[allow(clippy::too_many_arguments)] + fn answer( &self, + data: &[u8], services: F, answer: &mut AnswerBuilder, + ad: &mut AdditionalData, + delay: &mut bool, ttl_sec: u32, - ) -> Result<(), Error> + ) -> Result where T: Composer, F: Services, { - self.set_header(answer); + self.set_answer_header(answer); - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; + let message = Message::from_octets(data)?; - services.for_each(|service| { - service.add_service(answer, self.hostname, ttl_sec)?; - service.add_service_type(answer, ttl_sec)?; - service.add_service_subtypes(answer, ttl_sec)?; - service.add_txt(answer, ttl_sec)?; + let mut replied = false; - Ok(()) - })?; + for question in message.question() { + trace!("Handling question {:?}", question); - Ok(()) + let question = question?; + + replied |= self.answer_one( + question.qname(), + question.qtype(), + &services, + answer, + ad, + delay, + ttl_sec, + )?; + } + + Ok(replied) } - fn set_response( + /// Generate additional data records as indicated in the `AdditionalData` parameter + /// + /// Note that we are not 100% compliant with the spec, because for efficiency purposes + /// (to avoid extra allocations) we are putting more data in the additional data section + /// than strictly needed (i.e. we answer with _all_ SRV and _all_ TXT records for _all_ + /// registered services, even when we get a query for a specific service). + /// + /// Given that the additional data section is optional and provisional, this is not expected + /// to be an issue. + fn additional( &self, - data: &[u8], + ad: AdditionalData, services: F, - answer: &mut AnswerBuilder, + additional: &mut AdditionalBuilder, ttl_sec: u32, ) -> Result where T: Composer, F: Services, { - self.set_header(answer); + let mut replied = false; - let message = Message::from_octets(data)?; + if ad.contains(AdditionalData::IPS) { + self.add_ipv4(additional, ttl_sec)?; + self.add_ipv6(additional, ttl_sec)?; + replied = true; + } - let mut replied = false; + if ad.contains(AdditionalData::SRV) { + services.for_each(|service| { + service.add_service(additional, self.hostname, ttl_sec)?; + replied = true; - for question in message.question() { - trace!("Handling question {:?}", question); + Ok(()) + })?; + } - let question = question?; + if ad.contains(AdditionalData::TXT) { + services.for_each(|service| { + service.add_txt(additional, ttl_sec)?; + replied = true; - match question.qtype() { - Rtype::A - if question - .qname() - .name_eq(&Host::host_fqdn(self.hostname, true)?) => - { - self.add_ipv4(answer, ttl_sec)?; - replied = true; - } - Rtype::Aaaa - if question - .qname() - .name_eq(&Host::host_fqdn(self.hostname, true)?) => - { - self.add_ipv6(answer, ttl_sec)?; - replied = true; - } - Rtype::Srv => { - services.for_each(|service| { - if question.qname().name_eq(&service.service_fqdn(true)?) { - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - replied = true; - } + Ok(()) + })?; + } - Ok(()) - })?; - } - Rtype::Ptr => { - services.for_each(|service| { - if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { - service.add_service_type(answer, ttl_sec)?; - replied = true; - } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { - // TODO - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - service.add_service_type(answer, ttl_sec)?; - service.add_service_subtypes(answer, ttl_sec)?; - service.add_txt(answer, ttl_sec)?; - replied = true; - } + Ok(replied) + } - Ok(()) - })?; - } - Rtype::Txt => { - services.for_each(|service| { - if question.qname().name_eq(&service.service_fqdn(true)?) { - service.add_txt(answer, ttl_sec)?; - replied = true; - } + /// Append the answer to a specific question in the message as long as the host can + /// answer that question + /// + /// Updates the `AdditionalData` parameter with information what additional data to be + /// set in the "additional data" DNS record + /// + /// Updates the `delay` parameter to indicate if the reply should be delayed + /// (i.e. when answering DNS-SD queries with the DNS-SD FQDN + /// to avoid collissions with other mDNS responders) + #[allow(clippy::too_many_arguments)] + fn answer_one( + &self, + name: N, + rtype: Rtype, + services: F, + answer: &mut R, + ad: &mut AdditionalData, + delay: &mut bool, + ttl_sec: u32, + ) -> Result + where + N: ToName, + F: Services, + R: RecordSectionBuilder, + T: Composer, + { + if matches!(rtype, Rtype::ANY) { + let mut replied = false; + + replied |= + self.answer_simple(&name, Rtype::A, &services, answer, ad, delay, ttl_sec)?; + replied |= + self.answer_simple(&name, Rtype::AAAA, &services, answer, ad, delay, ttl_sec)?; + replied |= + self.answer_simple(&name, Rtype::PTR, &services, answer, ad, delay, ttl_sec)?; + replied |= + self.answer_simple(&name, Rtype::SRV, &services, answer, ad, delay, ttl_sec)?; + replied |= + self.answer_simple(&name, Rtype::TXT, services, answer, ad, delay, ttl_sec)?; + + Ok(replied) + } else { + self.answer_simple(name, rtype, services, answer, ad, delay, ttl_sec) + } + } - Ok(()) - })?; - } - Rtype::Any => { - // A / AAAA - if question - .qname() - .name_eq(&Host::host_fqdn(self.hostname, true)?) - { - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; + /// Same as `answer_question` but does not answer questions of type "Any" + #[allow(clippy::too_many_arguments)] + fn answer_simple( + &self, + name: N, + rtype: Rtype, + services: F, + answer: &mut R, + ad: &mut AdditionalData, + delay: &mut bool, + ttl_sec: u32, + ) -> Result + where + N: ToName, + F: Services, + R: RecordSectionBuilder, + T: Composer, + { + let mut replied = false; + + match rtype { + Rtype::A if name.name_eq(&Host::host_fqdn(self.hostname, true)?) => { + self.add_ipv4(answer, ttl_sec)?; + replied = true; + } + Rtype::AAAA if name.name_eq(&Host::host_fqdn(self.hostname, true)?) => { + self.add_ipv6(answer, ttl_sec)?; + replied = true; + } + Rtype::SRV => { + services.for_each(|service| { + if name.name_eq(&service.service_fqdn(true)?) { + service.add_service(answer, self.hostname, ttl_sec)?; + *ad |= AdditionalData::IPS; replied = true; } - // PTR - services.for_each(|service| { - if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { - service.add_service_type(answer, ttl_sec)?; - replied = true; - } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { - // TODO - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - service.add_service_type(answer, ttl_sec)?; - service.add_service_subtypes(answer, ttl_sec)?; - service.add_txt(answer, ttl_sec)?; - replied = true; + Ok(()) + })?; + } + Rtype::PTR => { + services.for_each(|service| { + if name.name_eq(&Service::dns_sd_fqdn(true)?) { + service.add_dns_sd_service_type(answer, ttl_sec)?; + service.add_dns_sd_service_subtypes(answer, ttl_sec)?; + *ad |= AdditionalData::SRV; + *ad |= AdditionalData::TXT; + *delay = true; // As we reply to a shared resource question, hence we need to avoid collissions + replied = true; + } else if name.name_eq(&service.service_type_fqdn(true)?) { + service.add_service(answer, self.hostname, ttl_sec)?; + *ad |= AdditionalData::SRV; + *ad |= AdditionalData::TXT; + replied = true; + } else { + for subtype in service.service_subtypes { + if name.name_eq(&service.service_subtype_fqdn(subtype, true)?) { + service.add_service_subtype(answer, subtype, ttl_sec)?; + replied = true; + *ad |= AdditionalData::SRV; + *ad |= AdditionalData::TXT; + break; + } } + } - Ok(()) - })?; - - // SRV - services.for_each(|service| { - if question.qname().name_eq(&service.service_fqdn(true)?) { - self.add_ipv4(answer, ttl_sec)?; - self.add_ipv6(answer, ttl_sec)?; - service.add_service(answer, self.hostname, ttl_sec)?; - replied = true; - } + Ok(()) + })?; + } + Rtype::TXT => { + services.for_each(|service| { + if name.name_eq(&service.service_fqdn(true)?) { + service.add_txt(answer, ttl_sec)?; + replied = true; + } - Ok(()) - })?; - } - _ => (), + Ok(()) + })?; } + _ => (), } Ok(replied) } - fn set_header(&self, answer: &mut AnswerBuilder) { + fn set_answer_header(&self, answer: &mut AnswerBuilder) { let header = answer.header_mut(); header.set_id(self.id); - header.set_opcode(domain::base::iana::Opcode::Query); - header.set_rcode(domain::base::iana::Rcode::NoError); + header.set_opcode(Opcode::QUERY); + header.set_rcode(Rcode::NOERROR); let mut flags = Flags::new(); flags.qr = true; @@ -297,28 +410,28 @@ impl<'a> Host<'a> { header.set_flags(flags); } - fn add_ipv4( - &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { + fn add_ipv4(&self, answer: &mut R, ttl_sec: u32) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { answer.push(( Self::host_fqdn(self.hostname, false).unwrap(), - Class::In, + Class::IN, ttl_sec, A::from_octets(self.ip[0], self.ip[1], self.ip[2], self.ip[3]), )) } - fn add_ipv6( - &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { + fn add_ipv6(&self, answer: &mut R, ttl_sec: u32) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { if let Some(ip) = &self.ipv6 { answer.push(( Self::host_fqdn(self.hostname, false).unwrap(), - Class::In, + Class::IN, ttl_sec, Aaaa::new((*ip).into()), )) @@ -327,111 +440,156 @@ impl<'a> Host<'a> { } } - fn host_fqdn(hostname: &str, suffix: bool) -> Result { + fn host_fqdn(hostname: &str, suffix: bool) -> Result { let suffix = if suffix { "." } else { "" }; - let mut host_fqdn = heapless07::String::<60>::new(); + let mut host_fqdn = heapless::String::<60>::new(); write!(host_fqdn, "{}.local{}", hostname, suffix,).unwrap(); - Dname::>::from_chars(host_fqdn.chars()) + Name::>::from_chars(host_fqdn.chars()) } } impl<'a> Service<'a> { - fn add_service( + fn add_service( &self, - answer: &mut AnswerBuilder, + answer: &mut R, hostname: &str, ttl_sec: u32, - ) -> Result<(), PushError> { + ) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { answer.push(( self.service_fqdn(false).unwrap(), - Class::In, + Class::IN, ttl_sec, Srv::new(0, 0, self.port, Host::host_fqdn(hostname, false).unwrap()), )) } - fn add_service_type( - &self, - answer: &mut AnswerBuilder, - ttl_sec: u32, - ) -> Result<(), PushError> { + fn add_service_type(&self, answer: &mut R, ttl_sec: u32) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { answer.push(( - Self::dns_sd_fqdn(false).unwrap(), - Class::In, + self.service_type_fqdn(false).unwrap(), + Class::IN, ttl_sec, - Ptr::new(self.service_type_fqdn(false).unwrap()), - ))?; + Ptr::new(self.service_fqdn(false).unwrap()), + )) + } + fn add_dns_sd_service_type(&self, answer: &mut R, ttl_sec: u32) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { answer.push(( - self.service_type_fqdn(false).unwrap(), - Class::In, + Self::dns_sd_fqdn(false).unwrap(), + Class::IN, ttl_sec, - Ptr::new(self.service_fqdn(false).unwrap()), + Ptr::new(self.service_type_fqdn(false).unwrap()), )) } - fn add_service_subtypes( + fn add_service_subtypes(&self, answer: &mut R, ttl_sec: u32) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { + for service_subtype in self.service_subtypes { + self.add_service_subtype(answer, service_subtype, ttl_sec)?; + } + + Ok(()) + } + + fn add_dns_sd_service_subtypes( &self, - answer: &mut AnswerBuilder, + answer: &mut R, ttl_sec: u32, - ) -> Result<(), PushError> { + ) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { for service_subtype in self.service_subtypes { - self.add_service_subtype(answer, service_subtype, ttl_sec)?; + self.add_dns_sd_service_subtype(answer, service_subtype, ttl_sec)?; } Ok(()) } - fn add_service_subtype( + fn add_service_subtype( &self, - answer: &mut AnswerBuilder, + answer: &mut R, service_subtype: &str, ttl_sec: u32, - ) -> Result<(), PushError> { - answer.push(( - Self::dns_sd_fqdn(false).unwrap(), - Class::In, - ttl_sec, - Ptr::new(self.service_subtype_fqdn(service_subtype, false).unwrap()), - ))?; - + ) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { answer.push(( self.service_subtype_fqdn(service_subtype, false).unwrap(), - Class::In, + Class::IN, ttl_sec, Ptr::new(self.service_fqdn(false).unwrap()), )) } - fn add_txt( + fn add_dns_sd_service_subtype( &self, - answer: &mut AnswerBuilder, + answer: &mut R, + service_subtype: &str, ttl_sec: u32, - ) -> Result<(), PushError> { - // only way I found to create multiple parts in a Txt - // each slice is the length and then the data - let mut octets = heapless07::Vec::<_, 256>::new(); - //octets.append_slice(&[1u8, b'X'])?; - //octets.append_slice(&[2u8, b'A', b'B'])?; - //octets.append_slice(&[0u8])?; - for (k, v) in self.txt_kvs { - octets.append_slice(&[(k.len() + v.len() + 1) as u8])?; - octets.append_slice(k.as_bytes())?; - octets.append_slice(&[b'='])?; - octets.append_slice(v.as_bytes())?; - } + ) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { + answer.push(( + Self::dns_sd_fqdn(false).unwrap(), + Class::IN, + ttl_sec, + Ptr::new(self.service_subtype_fqdn(service_subtype, false).unwrap()), + )) + } + + fn add_txt(&self, answer: &mut R, ttl_sec: u32) -> Result<(), PushError> + where + R: RecordSectionBuilder, + T: Composer, + { + if self.txt_kvs.is_empty() { + let txt = Txt::from_octets(&[0]).unwrap(); - let txt = Txt::from_octets(&mut octets).unwrap(); + answer.push((self.service_fqdn(false).unwrap(), Class::IN, ttl_sec, txt)) + } else { + let mut octets = heapless::Vec::<_, 256>::new(); + + // only way I found to create multiple parts in a Txt + // each slice is the length and then the data + for (k, v) in self.txt_kvs { + octets.append_slice(&[(k.len() + v.len() + 1) as u8])?; + octets.append_slice(k.as_bytes())?; + octets.append_slice(&[b'='])?; + octets.append_slice(v.as_bytes())?; + } - answer.push((self.service_fqdn(false).unwrap(), Class::In, ttl_sec, txt)) + let txt = Txt::from_octets(&octets).unwrap(); + + answer.push((self.service_fqdn(false).unwrap(), Class::IN, ttl_sec, txt)) + } } - fn service_fqdn(&self, suffix: bool) -> Result { + fn service_fqdn(&self, suffix: bool) -> Result { let suffix = if suffix { "." } else { "" }; - let mut service_fqdn = heapless07::String::<60>::new(); + let mut service_fqdn = heapless::String::<60>::new(); write!( service_fqdn, "{}.{}.{}.local{}", @@ -439,13 +597,13 @@ impl<'a> Service<'a> { ) .unwrap(); - Dname::>::from_chars(service_fqdn.chars()) + Name::>::from_chars(service_fqdn.chars()) } - fn service_type_fqdn(&self, suffix: bool) -> Result { + fn service_type_fqdn(&self, suffix: bool) -> Result { let suffix = if suffix { "." } else { "" }; - let mut service_type_fqdn = heapless07::String::<60>::new(); + let mut service_type_fqdn = heapless::String::<60>::new(); write!( service_type_fqdn, "{}.{}.local{}", @@ -453,17 +611,17 @@ impl<'a> Service<'a> { ) .unwrap(); - Dname::>::from_chars(service_type_fqdn.chars()) + Name::>::from_chars(service_type_fqdn.chars()) } fn service_subtype_fqdn( &self, service_subtype: &str, suffix: bool, - ) -> Result { + ) -> Result { let suffix = if suffix { "." } else { "" }; - let mut service_subtype_fqdn = heapless07::String::<40>::new(); + let mut service_subtype_fqdn = heapless::String::<40>::new(); write!( service_subtype_fqdn, "{}._sub.{}.{}.local{}", @@ -471,11 +629,11 @@ impl<'a> Service<'a> { ) .unwrap(); - Dname::>::from_chars(service_subtype_fqdn.chars()) + Name::>::from_chars(service_subtype_fqdn.chars()) } - fn dns_sd_fqdn(suffix: bool) -> Result { - Dname::>::from_chars( + fn dns_sd_fqdn(suffix: bool) -> Result { + Name::>::from_chars( if suffix { "_services._dns-sd._udp.local." } else { @@ -523,3 +681,413 @@ impl<'a> AsRef<[u8]> for Buf<'a> { &self.0[..self.1] } } + +#[cfg(test)] +mod tests { + use core::net::{Ipv4Addr, Ipv6Addr}; + + use domain::base::header::Flags; + use domain::base::iana::{Class, Opcode, Rcode}; + use domain::base::{Message, MessageBuilder, Name, RecordSection, Rtype, ToName}; + use domain::rdata::AllRecordData; + + use crate::error::Error; + use crate::mdns::Service; + + use super::{Buf, Host, Services}; + + static TEST_HOST_ONLY: TestRun = TestRun { + host: Host { + id: 0, + hostname: "foo", + ip: [192, 168, 0, 1], + ipv6: None, + }, + services: &[], + + tests: &[ + // No questions - no answers + (&[], &[], &[]), + // Other domain - no answers + ( + &[Question { + name: "foo1.local", + qtype: Rtype::A, + }], + &[], + &[], + ), + // Our domain + ( + &[Question { + name: "foo.local", + qtype: Rtype::A, + }], + &[Answer { + owner: "foo.local", + details: AnswerDetails::A(Ipv4Addr::new(192, 168, 0, 1)), + }], + &[], + ), + // ipv6 - no answer (TODO: We should return negative answer here in future) + ( + &[Question { + name: "foo.local", + qtype: Rtype::AAAA, + }], + &[], + &[], + ), + ], + }; + + static TEST_SERVICES: TestRun = TestRun { + host: Host { + id: 1, + hostname: "foo", + ip: [192, 168, 0, 1], + ipv6: Some(Ipv6Addr::new(0xfb, 0, 0, 0, 0, 0, 0, 1).octets()), + }, + services: &[ + Service { + name: "bar", + service: "_matterc", + protocol: "_udp", + port: 1234, + service_subtypes: &["L", "R"], + txt_kvs: &[("a", "b"), ("c", "d")], + }, + Service { + name: "ddd", + service: "_matter", + protocol: "_tcp", + port: 1235, + service_subtypes: &[], + txt_kvs: &[], + }, + ], + + tests: &[ + // No questions - no answers + (&[], &[], &[]), + // Other domain - no answers + ( + &[Question { + name: "foo1.local", + qtype: Rtype::A, + }], + &[], + &[], + ), + // SRV - no answer + ( + &[Question { + name: "foo.bar.local", + qtype: Rtype::SRV, + }], + &[], + &[], + ), + // SRV - Answer + ( + &[Question { + name: "bar._matterc._udp.local", + qtype: Rtype::SRV, + }], + &[Answer { + owner: "bar._matterc._udp.local", + details: AnswerDetails::Srv { + port: 1234, + target: "foo.local", + }, + }], + &[ + Answer { + owner: "foo.local", + details: AnswerDetails::A(Ipv4Addr::new(192, 168, 0, 1)), + }, + Answer { + owner: "foo.local", + details: AnswerDetails::Aaaa(Ipv6Addr::new(0xfb, 0, 0, 0, 0, 0, 0, 1)), + }, + ], + ), + // PTR + ( + &[Question { + name: "_services._dns-sd._udp.local", + qtype: Rtype::PTR, + }], + &[ + Answer { + owner: "_services._dns-sd._udp.local", + details: AnswerDetails::Ptr("_matterc._udp.local"), + }, + Answer { + owner: "_services._dns-sd._udp.local", + details: AnswerDetails::Ptr("L._sub._matterc._udp.local"), + }, + Answer { + owner: "_services._dns-sd._udp.local", + details: AnswerDetails::Ptr("R._sub._matterc._udp.local"), + }, + Answer { + owner: "_services._dns-sd._udp.local", + details: AnswerDetails::Ptr("_matter._tcp.local"), + }, + ], + &[ + Answer { + owner: "bar._matterc._udp.local", + details: AnswerDetails::Srv { + port: 1234, + target: "foo.local", + }, + }, + Answer { + owner: "ddd._matter._tcp.local", + details: AnswerDetails::Srv { + port: 1235, + target: "foo.local", + }, + }, + Answer { + owner: "bar._matterc._udp.local", + details: AnswerDetails::Txt(&[("a", "b"), ("c", "d")]), + }, + Answer { + owner: "ddd._matter._tcp.local", + details: AnswerDetails::Txt(&[]), + }, + ], + ), + ], + }; + + #[test] + fn test_host_only() { + TEST_HOST_ONLY.run(); + } + + #[test] + fn test_services() { + TEST_SERVICES.run(); + } + + struct TestRun<'a> { + host: Host<'a>, + services: &'a [Service<'a>], + + tests: &'a [(&'a [Question<'a>], &'a [Answer<'a>], &'a [Answer<'a>])], + } + + impl<'a> TestRun<'a> { + fn run(&self) { + let mut buf1 = [0; 1500]; + let mut buf2 = [0; 1500]; + + for (questions, expected_answers, expected_additional) in self.tests { + let data = Question::prep(&mut buf1, self.host.id, questions); + + let (len, _) = self + .host + .respond(self.services, data, &mut buf2, 0) + .unwrap(); + + if len > 0 { + Answer::validate( + &buf2[..len], + self.host.id, + expected_answers, + expected_additional, + ); + } else { + assert!(expected_answers.is_empty()); + assert!(expected_additional.is_empty()); + } + } + } + } + + #[derive(Debug)] + struct Question<'a> { + name: &'a str, + qtype: Rtype, + } + + impl<'a> Question<'a> { + fn prep<'b>(buf: &'b mut [u8], id: u16, questions: &[Question]) -> &'b [u8] { + let message = MessageBuilder::from_target(Buf(buf, 0)).unwrap(); + + let mut qb = message.question(); + + let header = qb.header_mut(); + header.set_id(id); + header.set_opcode(Opcode::QUERY); + header.set_rcode(Rcode::NOERROR); + + let mut flags = Flags::new(); + flags.qr = false; + flags.aa = true; + header.set_flags(flags); + + for question in questions { + let dname = + Name::>::from_chars(question.name.chars()).unwrap(); + + qb.push((dname, question.qtype, Class::IN)).unwrap(); + } + + let len = qb.finish().as_ref().len(); + + &buf[..len] + } + } + + #[derive(Debug)] + enum AnswerDetails<'a> { + A(Ipv4Addr), + Aaaa(Ipv6Addr), + Srv { port: u16, target: &'a str }, + Ptr(&'a str), + Txt(&'a [(&'a str, &'a str)]), + } + + #[derive(Debug)] + struct Answer<'a> { + owner: &'a str, + details: AnswerDetails<'a>, + } + + impl<'a> Answer<'a> { + fn validate( + data: &[u8], + expected_id: u16, + expected_answers: &[Answer], + expected_additional: &[Answer], + ) { + let message = Message::from_octets(data).unwrap(); + + let header = message.header(); + assert_eq!(header.id(), expected_id); + assert_eq!(header.opcode(), Opcode::QUERY); + assert_eq!(header.rcode(), Rcode::NOERROR); + + Answer::validate_section(&message.answer().unwrap(), expected_answers); + Answer::validate_section(&message.additional().unwrap(), expected_additional); + } + + fn validate_section(section: &RecordSection<&[u8]>, expected_answers: &[Answer]) { + let mut answers = section.peekable(); + let mut expectations = expected_answers.iter().peekable(); + + while answers.peek().is_some() && expectations.peek().is_some() { + let answer = answers + .next() + .unwrap() + .unwrap() + .to_any_record::>() + .unwrap(); + + let expected = expectations.next().unwrap(); + + assert!( + answer.owner().name_eq( + &Name::>::from_chars(expected.owner.chars()).unwrap() + ), + "OWNER {} (answer) != {} (expected)", + answer.owner(), + expected.owner + ); + + match (answer.data(), &expected.details) { + (AllRecordData::A(a), AnswerDetails::A(ip)) => { + assert_eq!(Ipv4Addr::from(a.addr().octets()), *ip); + } + (AllRecordData::Aaaa(a), AnswerDetails::Aaaa(ip)) => { + assert_eq!(Ipv6Addr::from(a.addr().octets()), *ip); + } + (AllRecordData::Srv(s), AnswerDetails::Srv { port, target }) => { + assert_eq!(s.port(), *port); + assert!( + s.target().name_eq( + &Name::>::from_chars(target.chars()).unwrap() + ), + "SRV {} (answer) != {} (expected)", + s.target(), + target + ); + } + (AllRecordData::Ptr(p), AnswerDetails::Ptr(name)) => { + assert!( + p.ptrdname().name_eq( + &Name::>::from_chars(name.chars()).unwrap() + ), + "PTR {} (answer) != {} (expected)", + p.ptrdname(), + name, + ); + } + (AllRecordData::Txt(txt), AnswerDetails::Txt(kvs)) => { + use core::fmt::Write; + + let mut txt = txt.iter().peekable(); + let mut kvs = kvs.iter().peekable(); + + while txt.peek().is_some() && kvs.peek().is_some() { + let t = txt.next().unwrap(); + + if t.is_empty() || t.len() == 1 && t[0] == 0 { + continue; + } + + let (k, v) = kvs.next().unwrap(); + + let mut str = heapless::Vec::::new(); + write!(&mut str, "{k}={v}").unwrap(); + + assert_eq!(t, str); + } + + while let Some(t) = txt.next() { + if !t.is_empty() { + panic!("Unexpected TXT string {:?} for {}", t, expected.owner); + } + } + + if let Some((k, v)) = kvs.next() { + panic!("Missing TXT string {k}={v} for {}", expected.owner); + } + } + other => panic!("Unexpected record type: {:?}", other), + } + } + + if let Some(answer) = answers.next() { + let answer = answer + .unwrap() + .to_any_record::>() + .unwrap(); + + panic!("Unexpected answer {:?}", answer); + } + + if let Some(expected) = expectations.next() { + panic!("Missing answer {:?}", expected); + } + } + } + + impl<'a> Services for &'a [Service<'a>] { + fn for_each(&self, mut callback: F) -> Result<(), Error> + where + F: FnMut(&Service) -> Result<(), Error>, + { + for service in self.iter() { + callback(service)?; + } + + Ok(()) + } + } +} diff --git a/rs-matter/src/transport/core.rs b/rs-matter/src/transport/core.rs index 2acfc1a4..a2afde78 100644 --- a/rs-matter/src/transport/core.rs +++ b/rs-matter/src/transport/core.rs @@ -80,6 +80,7 @@ pub struct TransportMgr<'m> { pub(crate) session_removed: Notification, pub session_mgr: RefCell, // For testing pub(crate) mdns: MdnsImpl<'m>, + rand: Rand, } impl<'m> TransportMgr<'m> { @@ -92,6 +93,7 @@ impl<'m> TransportMgr<'m> { session_removed: Notification::new(), session_mgr: RefCell::new(SessionMgr::new(epoch, rand)), mdns, + rand, } } @@ -258,7 +260,7 @@ impl<'m> TransportMgr<'m> { &self, send: S, recv: R, - host: crate::mdns::Host<'_>, + host: &crate::mdns::Host<'_>, interface: Option, ) -> Result<(), Error> where @@ -275,6 +277,7 @@ impl<'m> TransportMgr<'m> { PacketBufferExternalAccess(&self.rx), host, interface, + self.rand, ) .await } else {