From 1c2ee9c7653a9961ac51a50d464c93e658e3e137 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 11 Jun 2024 08:52:40 +0300 Subject: [PATCH] CASE re-arm for a FailSafe armed over a PASE session initially (#170) * PASE session upgrade on AddNOC; make fab_idx NonZeroU8 * Changes after code review feedback --- rs-matter/src/acl.rs | 197 ++++++++++++------ rs-matter/src/data_model/core.rs | 17 +- rs-matter/src/data_model/sdm/failsafe.rs | 46 ++-- .../data_model/sdm/general_commissioning.rs | 6 +- rs-matter/src/data_model/sdm/noc.rs | 73 ++++--- rs-matter/src/data_model/subscriptions.rs | 19 +- .../data_model/system_model/access_control.rs | 51 +++-- rs-matter/src/fabric.rs | 42 ++-- rs-matter/src/secure_channel/case.rs | 23 +- rs-matter/src/secure_channel/pake.rs | 2 +- rs-matter/src/tlv/traits.rs | 26 +++ rs-matter/src/transport/session.rs | 84 ++++---- rs-matter/tests/common/im_engine.rs | 11 +- rs-matter/tests/data_model/acl_and_dataver.rs | 31 +-- 14 files changed, 402 insertions(+), 226 deletions(-) diff --git a/rs-matter/src/acl.rs b/rs-matter/src/acl.rs index 5896a563..21a7c432 100644 --- a/rs-matter/src/acl.rs +++ b/rs-matter/src/acl.rs @@ -15,7 +15,7 @@ * limitations under the License. */ -use core::{cell::RefCell, fmt::Display}; +use core::{cell::RefCell, fmt::Display, num::NonZeroU8}; use crate::{ data_model::objects::{Access, ClusterId, EndptId, Privilege}, @@ -170,18 +170,20 @@ pub struct Accessor<'a> { impl<'a> Accessor<'a> { pub fn for_session(session: &Session, acl_mgr: &'a RefCell) -> Self { match session.get_session_mode() { - SessionMode::Case(c) => { + SessionMode::Case { + fab_idx, cat_ids, .. + } => { let mut subject = AccessorSubjects::new(session.get_peer_node_id().unwrap_or_default()); - for i in c.cat_ids { + for i in *cat_ids { if i != 0 { let _ = subject.add_catid(i); } } - Accessor::new(c.fab_idx, subject, AuthMode::Case, acl_mgr) + Accessor::new(fab_idx.get(), subject, AuthMode::Case, acl_mgr) } - SessionMode::Pase => { - Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, acl_mgr) + SessionMode::Pase { fab_idx } => { + Accessor::new(*fab_idx, AccessorSubjects::new(1), AuthMode::Pase, acl_mgr) } SessionMode::PlainText => { @@ -300,14 +302,14 @@ pub struct AclEntry { targets: Targets, // TODO: Instead of the direct value, we should consider GlobalElements::FabricIndex #[tagval(0xFE)] - pub fab_idx: Option, + pub fab_idx: NonZeroU8, } impl AclEntry { - pub fn new(fab_idx: u8, privilege: Privilege, auth_mode: AuthMode) -> Self { + pub fn new(fab_idx: NonZeroU8, privilege: Privilege, auth_mode: AuthMode) -> Self { const INIT_SUBJECTS: Option = None; Self { - fab_idx: Some(fab_idx), + fab_idx, privilege, auth_mode, subjects: [INIT_SUBJECTS; SUBJECTS_PER_ENTRY], @@ -366,7 +368,7 @@ impl AclEntry { } // true if both are true - allow && self.fab_idx == Some(accessor.fab_idx) + allow && self.fab_idx.get() == accessor.fab_idx } fn match_access_desc(&self, object: &AccessDesc) -> bool { @@ -438,37 +440,46 @@ impl AclMgr { Ok(()) } - pub fn add(&mut self, entry: AclEntry) -> Result<(), Error> { - let cnt = self - .entries - .iter() - .flatten() - .filter(|a| a.fab_idx == entry.fab_idx) - .count(); - if cnt >= ENTRIES_PER_FABRIC { + pub fn add(&mut self, entry: AclEntry) -> Result { + if entry.auth_mode == AuthMode::Pase { + // Reserved for future use + // TODO: Should be something that results in IMStatusCode::ConstraintError + Err(ErrorCode::Invalid)?; + } + + let cnt = self.get_index_in_fabric(MAX_ACL_ENTRIES, entry.fab_idx); + if cnt >= ENTRIES_PER_FABRIC as u8 { Err(ErrorCode::NoSpace)?; } let slot = self.entries.iter().position(|a| a.is_none()); if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES { - if let Some(index) = slot { - self.entries[index] = Some(entry); + let fab_idx = entry.fab_idx; + + let slot = if let Some(slot) = slot { + self.entries[slot] = Some(entry); + + slot } else { self.entries .push(Some(entry)) .map_err(|_| ErrorCode::NoSpace) .unwrap(); - } + + self.entries.len() - 1 + }; self.changed = true; - } - Ok(()) + Ok(self.get_index_in_fabric(slot, fab_idx)) + } else { + Err(ErrorCode::NoSpace.into()) + } } // Since the entries are fabric-scoped, the index is only for entries with the matching fabric index - pub fn edit(&mut self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> { + pub fn edit(&mut self, index: u8, fab_idx: NonZeroU8, new: AclEntry) -> Result<(), Error> { let old = self.for_index_in_fabric(index, fab_idx)?; *old = Some(new); @@ -477,7 +488,7 @@ impl AclMgr { Ok(()) } - pub fn delete(&mut self, index: u8, fab_idx: u8) -> Result<(), Error> { + pub fn delete(&mut self, index: u8, fab_idx: NonZeroU8) -> Result<(), Error> { let old = self.for_index_in_fabric(index, fab_idx)?; *old = None; @@ -486,11 +497,11 @@ impl AclMgr { Ok(()) } - pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> { + pub fn delete_for_fabric(&mut self, fab_idx: NonZeroU8) -> Result<(), Error> { for entry in &mut self.entries { if entry .as_ref() - .map(|e| e.fab_idx == Some(fab_idx)) + .map(|e| e.fab_idx == fab_idx) .unwrap_or(false) { *entry = None; @@ -513,10 +524,34 @@ impl AclMgr { } pub fn allow(&self, req: &AccessReq) -> bool { - // PASE Sessions have implicit access grant - if req.accessor.auth_mode == AuthMode::Pase { + // PASE Sessions with no fabric index have implicit access grant, + // but only as long as the ACL list is empty + // + // As per the spec: + // The Access Control List is able to have an initial entry added because the Access Control Privilege + // Granting algorithm behaves as if, over a PASE commissioning channel during the commissioning + // phase, the following implicit Access Control Entry were present on the Commissionee (but not on + // the Commissioner): + // Access Control Cluster: { + // ACL: [ + // 0: { + // // implicit entry only; does not explicitly exist! + // FabricIndex: 0, // not fabric-specific + // Privilege: Administer, + // AuthMode: PASE, + // Subjects: [], + // Targets: [] // entire node + // } + // ], + // Extension: [] + // } + if req.accessor.auth_mode == AuthMode::Pase + && req.accessor.fab_idx == 0 + && self.entries.iter().all(Option::is_none) + { return true; } + for e in self.entries.iter().flatten() { if e.allow(req) { return true; @@ -568,17 +603,13 @@ impl AclMgr { fn for_index_in_fabric( &mut self, index: u8, - fab_idx: u8, + fab_idx: NonZeroU8, ) -> Result<&mut Option, Error> { // Can't use flatten as we need to borrow the Option<> not the 'AclEntry' for (curr_index, entry) in self .entries .iter_mut() - .filter(|e| { - e.as_ref() - .filter(|e1| e1.fab_idx == Some(fab_idx)) - .is_some() - }) + .filter(|e| e.as_ref().filter(|e1| e1.fab_idx == fab_idx).is_some()) .enumerate() { if curr_index == index as usize { @@ -587,6 +618,19 @@ impl AclMgr { } Err(ErrorCode::NotFound.into()) } + + /// Traverse fabric specific entries to find the index of an entry relative to its fabric. + /// + /// If the ACL Mgr has 3 entries with fabric indexes, 1, 2, 1, then the actual + /// index 2 in the ACL Mgr will be the list index 1 for Fabric 1 + fn get_index_in_fabric(&self, till_slot_index: usize, fab_idx: NonZeroU8) -> u8 { + self.entries + .iter() + .take(till_slot_index) + .flatten() + .filter(|e| e.fab_idx == fab_idx) + .count() as u8 + } } impl core::fmt::Display for AclMgr { @@ -601,8 +645,8 @@ impl core::fmt::Display for AclMgr { #[cfg(test)] #[allow(clippy::bool_assert_comparison)] -mod tests { - use core::cell::RefCell; +pub(crate) mod tests { + use core::{cell::RefCell, num::NonZeroU8}; use crate::{ acl::{gen_noc_cat, AccessorSubjects}, @@ -612,31 +656,52 @@ mod tests { use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target}; + pub(crate) const FAB_1: NonZeroU8 = match NonZeroU8::new(1) { + Some(f) => f, + None => unreachable!(), + }; + pub(crate) const FAB_2: NonZeroU8 = match NonZeroU8::new(2) { + Some(f) => f, + None => unreachable!(), + }; + pub(crate) const FAB_3: NonZeroU8 = match NonZeroU8::new(3) { + Some(f) => f, + None => unreachable!(), + }; + #[test] fn test_basic_empty_subject_target() { let am = RefCell::new(AclMgr::new()); am.borrow_mut().erase_all().unwrap(); + + let accessor = Accessor::new(0, AccessorSubjects::new(112233), AuthMode::Pase, &am); + let path = GenericPath::new(Some(1), Some(1234), None); + let mut req = AccessReq::new(&accessor, path, Access::READ); + req.set_target_perms(Access::RWVA); + + // Default allow for PASE if no entries yet + assert!(req.allow()); + let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am); let path = GenericPath::new(Some(1), Some(1234), None); let mut req = AccessReq::new(&accessor, path, Access::READ); req.set_target_perms(Access::RWVA); - // Default deny + // Default deny for CASE assert_eq!(req.allow(), false); - // Deny for session mode mismatch - let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Pase); - am.borrow_mut().add(new).unwrap(); - assert_eq!(req.allow(), false); + // Deny adding invalid auth mode (PASE is reserved for future) + let new = AclEntry::new(FAB_1, Privilege::VIEW, AuthMode::Pase); + assert!(am.borrow_mut().add(new).is_err()); // Deny for fab idx mismatch - let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); - am.borrow_mut().add(new).unwrap(); + let new = AclEntry::new(FAB_1, Privilege::VIEW, AuthMode::Case); + assert_eq!(am.borrow_mut().add(new).unwrap(), 0); assert_eq!(req.allow(), false); // Allow - let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); - am.borrow_mut().add(new).unwrap(); + let new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); + assert_eq!(am.borrow_mut().add(new).unwrap(), 0); assert_eq!(req.allow(), true); } @@ -650,15 +715,15 @@ mod tests { req.set_target_perms(Access::RWVA); // Deny for subject mismatch - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject(112232).unwrap(); - am.borrow_mut().add(new).unwrap(); + assert_eq!(am.borrow_mut().add(new).unwrap(), 0); assert_eq!(req.allow(), false); // Allow for subject match - target is wildcard - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.borrow_mut().add(new).unwrap(); + assert_eq!(am.borrow_mut().add(new).unwrap(), 1); assert_eq!(req.allow(), true); } @@ -681,20 +746,20 @@ mod tests { req.set_target_perms(Access::RWVA); // Deny for CAT id mismatch - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) .unwrap(); am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Deny of CAT version mismatch - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v3)).unwrap(); am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for CAT match - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); @@ -719,14 +784,14 @@ mod tests { req.set_target_perms(Access::RWVA); // Deny for CAT id mismatch - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) .unwrap(); am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), false); // Allow for CAT match and version more than ACL version - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); am.borrow_mut().add(new).unwrap(); assert_eq!(req.allow(), true); @@ -742,7 +807,7 @@ mod tests { req.set_target_perms(Access::RWVA); // Deny for target mismatch - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_target(Target { cluster: Some(2), endpoint: Some(4567), @@ -753,7 +818,7 @@ mod tests { assert_eq!(req.allow(), false); // Allow for cluster match - subject wildcard - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_target(Target { cluster: Some(1234), endpoint: None, @@ -767,7 +832,7 @@ mod tests { am.borrow_mut().erase_all().unwrap(); // Allow for endpoint match - subject wildcard - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_target(Target { cluster: None, endpoint: Some(1), @@ -781,7 +846,7 @@ mod tests { am.borrow_mut().erase_all().unwrap(); // Allow for exact match - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_target(Target { cluster: Some(1234), endpoint: Some(1), @@ -801,7 +866,7 @@ mod tests { let path = GenericPath::new(Some(1), Some(1234), None); // Create an Exact Match ACL with View privilege - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_target(Target { cluster: Some(1234), endpoint: Some(1), @@ -817,7 +882,7 @@ mod tests { assert_eq!(req.allow(), false); // Create an Exact Match ACL with Admin privilege - let mut new = AclEntry::new(2, Privilege::ADMIN, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::ADMIN, AuthMode::Case); new.add_target(Target { cluster: Some(1234), endpoint: Some(1), @@ -846,19 +911,19 @@ mod tests { req3.set_target_perms(Access::RWVA); // Allow for subject match - target is wildcard - Fabric idx 2 - let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.borrow_mut().add(new).unwrap(); + assert_eq!(am.borrow_mut().add(new).unwrap(), 0); // Allow for subject match - target is wildcard - Fabric idx 3 - let mut new = AclEntry::new(3, Privilege::VIEW, AuthMode::Case); + let mut new = AclEntry::new(FAB_3, Privilege::VIEW, AuthMode::Case); new.add_subject(112233).unwrap(); - am.borrow_mut().add(new).unwrap(); + assert_eq!(am.borrow_mut().add(new).unwrap(), 0); // Req for Fabric idx 2 gets denied, and that for Fabric idx 3 is allowed assert_eq!(req2.allow(), true); assert_eq!(req3.allow(), true); - am.borrow_mut().delete_for_fabric(2).unwrap(); + am.borrow_mut().delete_for_fabric(FAB_2).unwrap(); assert_eq!(req2.allow(), false); assert_eq!(req3.allow(), true); } diff --git a/rs-matter/src/data_model/core.rs b/rs-matter/src/data_model/core.rs index 162ba3f8..6b2d1638 100644 --- a/rs-matter/src/data_model/core.rs +++ b/rs-matter/src/data_model/core.rs @@ -17,6 +17,7 @@ use core::cell::{Cell, RefCell}; use core::iter::Peekable; +use core::num::NonZeroU8; use core::pin::pin; use core::time::Duration; @@ -53,7 +54,7 @@ const MAX_WRITE_ATTRS_IN_ONE_TRANS: usize = 7; pub type IMBuffer = heapless::Vec; struct SubscriptionBuffer { - fabric_idx: u8, + fabric_idx: NonZeroU8, peer_node_id: u64, subscription_id: u32, buffer: B, @@ -369,7 +370,8 @@ where debug!("IM: Subscribe request: {:?}", req); let (fabric_idx, peer_node_id) = exchange.with_session(|sess| { - let fabric_idx = sess.get_local_fabric_idx().ok_or(ErrorCode::Invalid)?; + let fabric_idx = + NonZeroU8::new(sess.get_local_fabric_idx()).ok_or(ErrorCode::Invalid)?; let peer_node_id = sess.get_peer_node_id().ok_or(ErrorCode::Invalid)?; Ok((fabric_idx, peer_node_id)) @@ -407,7 +409,7 @@ where }); let primed = self - .report_data(id, fabric_idx, peer_node_id, &rx, &mut tx, exchange) + .report_data(id, fabric_idx.get(), peer_node_id, &rx, &mut tx, exchange) .await?; if primed { @@ -523,7 +525,14 @@ where if let Some(mut tx) = self.buffers.get().await { let primed = self - .report_data(id, fabric_idx, peer_node_id, &rx, &mut tx, &mut exchange) + .report_data( + id, + fabric_idx.get(), + peer_node_id, + &rx, + &mut tx, + &mut exchange, + ) .await?; exchange.acknowledge().await?; diff --git a/rs-matter/src/data_model/sdm/failsafe.rs b/rs-matter/src/data_model/sdm/failsafe.rs index 19bfe0cd..449428c0 100644 --- a/rs-matter/src/data_model/sdm/failsafe.rs +++ b/rs-matter/src/data_model/sdm/failsafe.rs @@ -15,6 +15,8 @@ * limitations under the License. */ +use core::num::NonZeroU8; + use crate::{ error::{Error, ErrorCode}, transport::session::SessionMode, @@ -27,13 +29,12 @@ use log::error; enum NocState { NocNotRecvd, // This is the local fabric index - AddNocRecvd(u8), - UpdateNocRecvd(u8), + AddNocRecvd(NonZeroU8), + UpdateNocRecvd(NonZeroU8), } #[derive(PartialEq)] pub struct ArmedCtx { - session_mode: SessionMode, timeout: u16, noc_state: NocState, } @@ -58,16 +59,26 @@ impl FailSafe { match &mut self.state { State::Idle => { self.state = State::Armed(ArmedCtx { - session_mode, timeout, noc_state: NocState::NocNotRecvd, }) } State::Armed(c) => { - if c.session_mode != session_mode { - error!("Received Fail-Safe Arm with different session modes; current {:?}, incoming {:?}", c.session_mode, session_mode); - Err(ErrorCode::Invalid)?; + match c.noc_state { + NocState::NocNotRecvd => (), + NocState::AddNocRecvd(fab_idx) | NocState::UpdateNocRecvd(fab_idx) => { + if let Some(sess_fab_idx) = NonZeroU8::new(session_mode.fab_idx()) { + if sess_fab_idx != fab_idx { + error!("Received Fail-Safe Re-arm with a different fabric index from a previous Add/Update NOC"); + Err(ErrorCode::Invalid)?; + } + } else { + error!("Received Fail-Safe Re-arm from a session that does not have a fabric index"); + Err(ErrorCode::Invalid)?; + } + } } + // re-arm c.timeout = timeout; } @@ -83,17 +94,20 @@ impl FailSafe { } State::Armed(c) => { match c.noc_state { - NocState::NocNotRecvd => Err(ErrorCode::Invalid)?, - NocState::AddNocRecvd(idx) | NocState::UpdateNocRecvd(idx) => { - if let SessionMode::Case(c) = session_mode { - if c.fab_idx != idx { - error!( - "Received disarm in separate session from previous Add/Update NOC" - ); + NocState::NocNotRecvd => { + error!("Received Fail-Safe Disarm, yet the failsafe has not received Add/Update NOC first"); + Err(ErrorCode::Invalid)?; + } + NocState::AddNocRecvd(fab_idx) | NocState::UpdateNocRecvd(fab_idx) => { + if let Some(sess_fab_idx) = NonZeroU8::new(session_mode.fab_idx()) { + if sess_fab_idx != fab_idx { + error!("Received disarm with different fabric index from a previous Add/Update NOC"); Err(ErrorCode::Invalid)?; } } else { - error!("Received disarm in a non-CASE session"); + error!( + "Received disarm from a session that does not have a fabric index" + ); Err(ErrorCode::Invalid)?; } } @@ -108,7 +122,7 @@ impl FailSafe { self.state != State::Idle } - pub fn record_add_noc(&mut self, fabric_index: u8) -> Result<(), Error> { + pub fn record_add_noc(&mut self, fabric_index: NonZeroU8) -> Result<(), Error> { match &mut self.state { State::Idle => Err(ErrorCode::Invalid.into()), State::Armed(c) => { diff --git a/rs-matter/src/data_model/sdm/general_commissioning.rs b/rs-matter/src/data_model/sdm/general_commissioning.rs index 93ce2c18..5d185b35 100644 --- a/rs-matter/src/data_model/sdm/general_commissioning.rs +++ b/rs-matter/src/data_model/sdm/general_commissioning.rs @@ -21,6 +21,7 @@ use crate::data_model::objects::*; use crate::data_model::sdm::failsafe::FailSafe; use crate::tlv::{FromTLV, TLVElement, ToTLV, UtfStr}; use crate::transport::exchange::Exchange; +use crate::transport::session::SessionMode; use crate::utils::rand::Rand; use crate::{attribute_enum, cmd_enter}; use crate::{command_enum, error::*}; @@ -267,9 +268,8 @@ impl<'a> GenCommCluster<'a> { let mut status: u8 = CommissioningErrorEnum::OK as u8; // Has to be a Case Session - if exchange - .with_session(|sess| Ok(sess.get_local_fabric_idx()))? - .is_none() + if !exchange + .with_session(|sess| Ok(matches!(sess.get_session_mode(), SessionMode::Case { .. })))? { status = CommissioningErrorEnum::InvalidAuthentication as u8; } diff --git a/rs-matter/src/data_model/sdm/noc.rs b/rs-matter/src/data_model/sdm/noc.rs index eee2cc5f..a339477c 100644 --- a/rs-matter/src/data_model/sdm/noc.rs +++ b/rs-matter/src/data_model/sdm/noc.rs @@ -16,6 +16,7 @@ */ use core::cell::RefCell; +use core::num::NonZeroU8; use crate::acl::{AclEntry, AclMgr, AuthMode}; use crate::cert::{Cert, MAX_CERT_TLV_LEN}; @@ -211,7 +212,7 @@ struct CertChainReq { #[derive(FromTLV)] struct RemoveFabricReq { - fab_idx: u8, + fab_idx: NonZeroU8, } #[derive(Clone)] @@ -265,7 +266,7 @@ impl<'a> NocCluster<'a> { Attributes::Fabrics(_) => { writer.start_array(AttrDataWriter::TAG)?; self.fabric_mgr.borrow().for_each(|entry, fab_idx| { - if !attr.fab_filter || attr.fab_idx == fab_idx { + if !attr.fab_filter || attr.fab_idx == fab_idx.get() { let root_ca_cert = entry.get_root_ca()?; entry @@ -321,17 +322,11 @@ impl<'a> NocCluster<'a> { Ok(()) } - fn add_acl(&self, fab_idx: u8, admin_subject: u64) -> Result<(), Error> { - let mut acl = AclEntry::new(fab_idx, Privilege::ADMIN, AuthMode::Case); - acl.add_subject(admin_subject)?; - self.acl_mgr.borrow_mut().add(acl) - } - fn _handle_command_addnoc( &self, exchange: &Exchange, data: &TLVElement, - ) -> Result { + ) -> Result { let noc_data = exchange .with_session(|sess| Ok(sess.take_noc_data()))? .ok_or(NocStatus::MissingCsr)?; @@ -346,16 +341,6 @@ impl<'a> NocCluster<'a> { Err(NocStatus::InsufficientPrivlege)?; } - // TODO - // // This command's processing may take longer, send a stand alone ACK to the peer to avoid any retranmissions - // let ack_send = secure_channel::common::send_mrp_standalone_ack( - // trans.exch, - // trans.session, - // ); - // if ack_send.is_err() { - // error!("Error sending Standalone ACK, falling back to piggybacked ACK"); - // } - let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?; let noc_cert = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; @@ -394,10 +379,37 @@ impl<'a> NocCluster<'a> { .add(fabric, self.mdns) .map_err(|_| NocStatus::TableFull)?; - self.add_acl(fab_idx, r.case_admin_subject)?; + let _fab_guard = scopeguard::guard(fab_idx, |fab_idx| { + // Remove the fabric if we fail further down this function + self.fabric_mgr + .borrow_mut() + .remove(fab_idx, self.mdns) + .unwrap(); + }); + + let mut acl = AclEntry::new(fab_idx, Privilege::ADMIN, AuthMode::Case); + acl.add_subject(r.case_admin_subject)?; + let acl_entry_index = self.acl_mgr.borrow_mut().add(acl)?; + + let _acl_guard = scopeguard::guard(fab_idx, |fab_idx| { + // Remove the ACL entry if we fail further down this function + self.acl_mgr + .borrow_mut() + .delete(acl_entry_index, fab_idx) + .unwrap(); + }); self.failsafe.borrow_mut().record_add_noc(fab_idx)?; + // Finally, upgrade our session with the new fabric index + exchange.with_session(|sess| { + if matches!(sess.get_session_mode(), SessionMode::Pase { .. }) { + sess.upgrade_fabric_idx(fab_idx)?; + } + + Ok(()) + })?; + Ok(fab_idx) } @@ -426,21 +438,21 @@ impl<'a> NocCluster<'a> { ) -> Result<(), Error> { cmd_enter!("Update Fabric Label"); let req = UpdateFabricLabelReq::from_tlv(data).map_err(Error::map_invalid_data_type)?; - let (result, fab_idx) = if let SessionMode::Case(c) = + let (result, fab_idx) = if let SessionMode::Case { fab_idx, .. } = exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))? { if self .fabric_mgr .borrow_mut() .set_label( - c.fab_idx, + fab_idx, req.label.as_str().map_err(Error::map_invalid_data_type)?, ) .is_err() { - (NocStatus::LabelConflict, c.fab_idx) + (NocStatus::LabelConflict, fab_idx.get()) } else { - (NocStatus::Ok, c.fab_idx) + (NocStatus::Ok, fab_idx.get()) } } else { // Update Fabric Label not allowed @@ -478,7 +490,12 @@ impl<'a> NocCluster<'a> { Ok(()) } else { - Self::create_nocresponse(encoder, NocStatus::InvalidFabricIndex, req.fab_idx, "") + Self::create_nocresponse( + encoder, + NocStatus::InvalidFabricIndex, + req.fab_idx.get(), + "", + ) } } @@ -491,7 +508,7 @@ impl<'a> NocCluster<'a> { cmd_enter!("AddNOC"); let (status, fab_idx) = match self._handle_command_addnoc(exchange, data) { - Ok(fab_idx) => (NocStatus::Ok, fab_idx), + Ok(fab_idx) => (NocStatus::Ok, fab_idx.get()), Err(NocError::Status(status)) => (status, 0), Err(NocError::Error(error)) => Err(error)?, }; @@ -644,11 +661,11 @@ impl<'a> NocCluster<'a> { // This may happen on CASE or PASE. For PASE, the existence of NOC Data is necessary match exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))? { - SessionMode::Case(_) => { + SessionMode::Case { .. } => { // TODO - Updating the Trusted RCA of an existing Fabric Self::add_rca_to_session_noc_data(exchange, data)?; } - SessionMode::Pase => { + SessionMode::Pase { .. } => { Self::add_rca_to_session_noc_data(exchange, data)?; } _ => (), diff --git a/rs-matter/src/data_model/subscriptions.rs b/rs-matter/src/data_model/subscriptions.rs index ecb14fa0..e207c511 100644 --- a/rs-matter/src/data_model/subscriptions.rs +++ b/rs-matter/src/data_model/subscriptions.rs @@ -16,6 +16,7 @@ */ use core::cell::RefCell; +use core::num::NonZeroU8; use embassy_sync::blocking_mutex::raw::NoopRawMutex; use embassy_time::Instant; @@ -25,7 +26,7 @@ use portable_atomic::{AtomicU32, Ordering}; use crate::utils::notification::Notification; struct Subscription { - fabric_idx: u8, + fabric_idx: NonZeroU8, peer_node_id: u64, session_id: Option, id: u32, @@ -96,7 +97,7 @@ impl Subscriptions { pub(crate) fn add( &self, - fabric_idx: u8, + fabric_idx: NonZeroU8, peer_node_id: u64, session_id: u32, min_int_secs: u16, @@ -139,7 +140,7 @@ impl Subscriptions { pub(crate) fn remove( &self, - fabric_idx: Option, + fabric_idx: Option, peer_node_id: Option, id: Option, ) { @@ -153,7 +154,10 @@ impl Subscriptions { } } - pub(crate) fn find_removed_session(&self, session_removed: F) -> Option<(u8, u64, u32, u32)> + pub(crate) fn find_removed_session( + &self, + session_removed: F, + ) -> Option<(NonZeroU8, u64, u32, u32)> where F: Fn(u32) -> bool, { @@ -170,7 +174,7 @@ impl Subscriptions { }) } - pub(crate) fn find_expired(&self, now: Instant) -> Option<(u8, u64, Option, u32)> { + pub(crate) fn find_expired(&self, now: Instant) -> Option<(NonZeroU8, u64, Option, u32)> { self.subscriptions.borrow().iter().find_map(|sub| { sub.is_expired(now).then_some(( sub.fabric_idx, @@ -183,7 +187,10 @@ impl Subscriptions { /// Note that this method has a side effect: /// it updates the `reported_at` field of the subscription that is returned. - pub(crate) fn find_report_due(&self, now: Instant) -> Option<(u8, u64, Option, u32)> { + pub(crate) fn find_report_due( + &self, + now: Instant, + ) -> Option<(NonZeroU8, u64, Option, u32)> { self.subscriptions .borrow_mut() .iter_mut() diff --git a/rs-matter/src/data_model/system_model/access_control.rs b/rs-matter/src/data_model/system_model/access_control.rs index b16f5918..aedf3a80 100644 --- a/rs-matter/src/data_model/system_model/access_control.rs +++ b/rs-matter/src/data_model/system_model/access_control.rs @@ -16,6 +16,7 @@ */ use core::cell::RefCell; +use core::num::NonZeroU8; use strum::{EnumDiscriminants, FromRepr}; @@ -99,7 +100,7 @@ impl<'a> AccessControlCluster<'a> { Attributes::Acl(_) => { writer.start_array(AttrDataWriter::TAG)?; self.acl_mgr.borrow().for_each_acl(|entry| { - if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx { + if !attr.fab_filter || attr.fab_idx == entry.fab_idx.get() { entry.to_tlv(&mut writer, TagType::Anonymous)?; } @@ -136,7 +137,11 @@ impl<'a> AccessControlCluster<'a> { match attr.attr_id.try_into()? { Attributes::Acl(_) => { attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| { - self.write_acl_attr(&op, data, attr.fab_idx) + self.write_acl_attr( + &op, + data, + NonZeroU8::new(attr.fab_idx).ok_or(ErrorCode::Invalid)?, + ) }) } _ => { @@ -154,7 +159,7 @@ impl<'a> AccessControlCluster<'a> { &self, op: &ListOperation, data: &TLVElement, - fab_idx: u8, + fab_idx: NonZeroU8, ) -> Result<(), Error> { info!("Performing ACL operation {:?}", op); match op { @@ -162,15 +167,17 @@ impl<'a> AccessControlCluster<'a> { let mut acl_entry = AclEntry::from_tlv(data)?; info!("ACL {:?}", acl_entry); // Overwrite the fabric index with our accessing fabric index - acl_entry.fab_idx = Some(fab_idx); + acl_entry.fab_idx = fab_idx; if let ListOperation::EditItem(index) = op { self.acl_mgr .borrow_mut() - .edit(*index as u8, fab_idx, acl_entry) + .edit(*index as u8, fab_idx, acl_entry)?; } else { - self.acl_mgr.borrow_mut().add(acl_entry) + self.acl_mgr.borrow_mut().add(acl_entry)?; } + + Ok(()) } ListOperation::DeleteItem(index) => { self.acl_mgr.borrow_mut().delete(*index as u8, fab_idx) @@ -212,6 +219,8 @@ mod tests { use super::AccessControlCluster; + use crate::acl::tests::{FAB_1, FAB_2}; + #[test] /// Add an ACL entry fn acl_cluster_add() { @@ -222,16 +231,16 @@ mod tests { let acl_mgr = RefCell::new(AclMgr::new()); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); - let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, ACL has fabric index 2, but the accessing fabric is 1 // the fabric index in the TLV should be ignored and the ACL should be created with entry 1 - let result = acl.write_acl_attr(&ListOperation::AddItem, &data, 1); + let result = acl.write_acl_attr(&ListOperation::AddItem, &data, FAB_1); assert!(result.is_ok()); - let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); + let verifier = AclEntry::new(FAB_1, Privilege::VIEW, AuthMode::Case); acl_mgr .borrow() .for_each_acl(|a| { @@ -251,21 +260,21 @@ mod tests { // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order let acl_mgr = RefCell::new(AclMgr::new()); let mut verifier = [ - AclEntry::new(2, Privilege::VIEW, AuthMode::Case), - AclEntry::new(1, Privilege::VIEW, AuthMode::Case), - AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), + AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case), + AclEntry::new(FAB_1, Privilege::VIEW, AuthMode::Case), + AclEntry::new(FAB_2, Privilege::ADMIN, AuthMode::Case), ]; for i in &verifier { acl_mgr.borrow_mut().add(i.clone()).unwrap(); } let acl = AccessControlCluster::new(&acl_mgr, dummy_rand); - let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); + let new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, Edit Fabric 2's index 1 - with accessing fabring as 2 - allow - let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2); + let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, FAB_2); // Fabric 2's index 1, is actually our index 2, update the verifier verifier[2] = new; assert!(result.is_ok()); @@ -288,9 +297,9 @@ mod tests { // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order let acl_mgr = RefCell::new(AclMgr::new()); let input = [ - AclEntry::new(2, Privilege::VIEW, AuthMode::Case), - AclEntry::new(1, Privilege::VIEW, AuthMode::Case), - AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), + AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case), + AclEntry::new(FAB_1, Privilege::VIEW, AuthMode::Case), + AclEntry::new(FAB_2, Privilege::ADMIN, AuthMode::Case), ]; for i in &input { acl_mgr.borrow_mut().add(i.clone()).unwrap(); @@ -300,7 +309,7 @@ mod tests { let data = TLVElement::new(TagType::Anonymous, ElementType::True); // Test , Delete Fabric 1's index 0 - let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1); + let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, FAB_1); assert!(result.is_ok()); let verifier = [input[0].clone(), input[2].clone()]; @@ -325,9 +334,9 @@ mod tests { // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order let acl_mgr = RefCell::new(AclMgr::new()); let input = [ - AclEntry::new(2, Privilege::VIEW, AuthMode::Case), - AclEntry::new(1, Privilege::VIEW, AuthMode::Case), - AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), + AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case), + AclEntry::new(FAB_1, Privilege::VIEW, AuthMode::Case), + AclEntry::new(FAB_2, Privilege::ADMIN, AuthMode::Case), ]; for i in input { acl_mgr.borrow_mut().add(i).unwrap(); diff --git a/rs-matter/src/fabric.rs b/rs-matter/src/fabric.rs index 6db99607..cf71a63c 100644 --- a/rs-matter/src/fabric.rs +++ b/rs-matter/src/fabric.rs @@ -16,6 +16,7 @@ */ use core::fmt::Write; +use core::num::NonZeroU8; use byteorder::{BigEndian, ByteOrder, LittleEndian}; use heapless::{String, Vec}; @@ -43,7 +44,7 @@ pub struct FabricDescriptor<'a> { label: UtfStr<'a>, // TODO: Instead of the direct value, we should consider GlobalElements::FabricIndex #[tagval(0xFE)] - pub fab_idx: Option, + pub fab_idx: NonZeroU8, } #[derive(Debug, ToTLV, FromTLV)] @@ -165,7 +166,7 @@ impl Fabric { pub fn get_fabric_desc<'a>( &'a self, - fab_idx: u8, + fab_idx: NonZeroU8, root_ca_cert: &'a Cert, ) -> Result, Error> { let desc = FabricDescriptor { @@ -174,7 +175,7 @@ impl Fabric { fabric_id: self.fabric_id, node_id: self.node_id, label: UtfStr(self.label.as_bytes()), - fab_idx: Some(fab_idx), + fab_idx, }; Ok(desc) @@ -246,7 +247,7 @@ impl FabricMgr { self.changed } - pub fn add(&mut self, f: Fabric, mdns: &dyn Mdns) -> Result { + pub fn add(&mut self, f: Fabric, mdns: &dyn Mdns) -> Result { // Do not re-use slots (if possible) because currently we use the // position of the fabric in the array as a `fabric_index` as per the Matter Core spec // TODO: In future introduce a new field in Fabric to store the fabric index, as @@ -262,23 +263,25 @@ impl FabricMgr { if let Some(index) = slot { self.fabrics[index] = Some(f); - Ok((index + 1) as u8) + // Unwrapping is safe because we explicitly add + 1 here + Ok(NonZeroU8::new(index as u8 + 1).unwrap()) } else { self.fabrics .push(Some(f)) .map_err(|_| ErrorCode::NoSpace) .unwrap(); - Ok(self.fabrics.len() as u8) + // Unwrapping is safe because we just added the entry + Ok(NonZeroU8::new(self.fabrics.len() as u8).unwrap()) } } else { Err(ErrorCode::NoSpace.into()) } } - pub fn remove(&mut self, fab_idx: u8, mdns: &dyn Mdns) -> Result<(), Error> { - if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() { - if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() { + pub fn remove(&mut self, fab_idx: NonZeroU8, mdns: &dyn Mdns) -> Result<(), Error> { + if fab_idx.get() as usize <= self.fabrics.len() { + if let Some(f) = self.fabrics[(fab_idx.get() - 1) as usize].take() { mdns.remove(&f.mdns_service_name)?; self.changed = true; Ok(()) @@ -290,23 +293,20 @@ impl FabricMgr { } } - pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result { + pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result { for (index, fabric) in self.fabrics.iter().enumerate() { if let Some(fabric) = fabric { if fabric.match_dest_id(random, target).is_ok() { - return Ok(index + 1); + // Unwrapping is safe because we explicitly add + 1 here + return Ok(NonZeroU8::new(index as u8 + 1).unwrap()); } } } Err(ErrorCode::NotFound.into()) } - pub fn get_fabric(&self, idx: usize) -> Result, Error> { - if idx == 0 { - Ok(None) - } else { - Ok(self.fabrics[idx - 1].as_ref()) - } + pub fn get_fabric(&self, idx: NonZeroU8) -> Option<&Fabric> { + self.fabrics[idx.get() as usize - 1].as_ref() } pub fn is_empty(&self) -> bool { @@ -320,17 +320,17 @@ impl FabricMgr { // Parameters to T are the Fabric and its Fabric Index pub fn for_each(&self, mut f: T) -> Result<(), Error> where - T: FnMut(&Fabric, u8) -> Result<(), Error>, + T: FnMut(&Fabric, NonZeroU8) -> Result<(), Error>, { for (index, fabric) in self.fabrics.iter().enumerate() { if let Some(fabric) = fabric { - f(fabric, (index + 1) as u8)?; + f(fabric, NonZeroU8::new(index as u8 + 1).unwrap())?; } } Ok(()) } - pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> { + pub fn set_label(&mut self, index: NonZeroU8, label: &str) -> Result<(), Error> { if !label.is_empty() && self .fabrics @@ -341,7 +341,7 @@ impl FabricMgr { return Err(ErrorCode::Invalid.into()); } - let index = (index - 1) as usize; + let index = (index.get() - 1) as usize; if let Some(fabric) = &mut self.fabrics[index] { fabric.label = label.try_into().unwrap(); self.changed = true; diff --git a/rs-matter/src/secure_channel/case.rs b/rs-matter/src/secure_channel/case.rs index 77384ea8..3052cadb 100644 --- a/rs-matter/src/secure_channel/case.rs +++ b/rs-matter/src/secure_channel/case.rs @@ -15,6 +15,8 @@ * limitations under the License. */ +use core::num::NonZeroU8; + use log::{error, trace}; use crate::{ @@ -27,7 +29,7 @@ use crate::{ tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, transport::{ exchange::Exchange, - session::{CaseDetails, NocCatIds, ReservedSession, SessionMode}, + session::{NocCatIds, ReservedSession, SessionMode}, }, utils::{rand::Rand, writebuf::WriteBuf}, }; @@ -40,7 +42,7 @@ pub struct CaseSession { shared_secret: [u8; crypto::ECDH_SHARED_SECRET_LEN_BYTES], our_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], - local_fabric_idx: usize, + local_fabric_idx: u8, } impl Default for CaseSession { @@ -103,7 +105,8 @@ impl Case { let status = { let fabric_mgr = exchange.matter().fabric_mgr.borrow(); - let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; + let fabric = NonZeroU8::new(case_session.local_fabric_idx) + .and_then(|fabric_idx| fabric_mgr.get_fabric(fabric_idx)); if let Some(fabric) = fabric { let root = get_root_node_struct(exchange.rx()?.payload())?; let encrypted = root.find_tag(1)?.slice()?; @@ -173,10 +176,11 @@ impl Case { case_session.peer_sessid, case_session.local_sessid, peer_addr, - SessionMode::Case(CaseDetails::new( - case_session.local_fabric_idx as u8, - &peer_catids, - )), + SessionMode::Case { + // Unwrapping is safe, because if the fabric index was 0, we would not be in here + fab_idx: NonZeroU8::new(case_session.local_fabric_idx).unwrap(), + cat_ids: peer_catids, + }, Some(&session_keys[0..16]), Some(&session_keys[16..32]), Some(&session_keys[32..48]), @@ -237,7 +241,7 @@ impl Case { .as_mut() .unwrap() .update(exchange.rx()?.payload())?; - case_session.local_fabric_idx = local_fabric_idx?; + case_session.local_fabric_idx = local_fabric_idx?.get(); if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { error!("Invalid public key length"); Err(ErrorCode::Invalid)?; @@ -272,7 +276,8 @@ impl Case { let encrypted_len = { let fabric_mgr = exchange.matter().fabric_mgr.borrow(); - let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; + let fabric = NonZeroU8::new(case_session.local_fabric_idx) + .and_then(|fabric_idx| fabric_mgr.get_fabric(fabric_idx)); if let Some(fabric) = fabric { #[cfg(feature = "alloc")] let signature_mut = &mut *signature; diff --git a/rs-matter/src/secure_channel/pake.rs b/rs-matter/src/secure_channel/pake.rs index 8a53239b..d00434c9 100644 --- a/rs-matter/src/secure_channel/pake.rs +++ b/rs-matter/src/secure_channel/pake.rs @@ -204,7 +204,7 @@ impl Pake { peer_sessid, local_sessid, peer_addr, - SessionMode::Pase, + SessionMode::Pase { fab_idx: 0 }, Some(&session_keys[0..16]), Some(&session_keys[16..32]), Some(&session_keys[32..48]), diff --git a/rs-matter/src/tlv/traits.rs b/rs-matter/src/tlv/traits.rs index 168fb503..acb47c01 100644 --- a/rs-matter/src/tlv/traits.rs +++ b/rs-matter/src/tlv/traits.rs @@ -91,7 +91,20 @@ macro_rules! fromtlv_for { }; } +macro_rules! fromtlv_for_nonzero { + ($($t:ident:$n:ty)*) => { + $( + impl<'a> FromTLV<'a> for $n { + fn from_tlv(t: &TLVElement) -> Result { + <$n>::new(t.$t()?).ok_or_else(|| ErrorCode::Invalid.into()) + } + } + )* + }; +} + fromtlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); +fromtlv_for_nonzero!(i8:core::num::NonZeroI8 u8:core::num::NonZeroU8 i16:core::num::NonZeroI16 u16:core::num::NonZeroU16 i32:core::num::NonZeroI32 u32:core::num::NonZeroU32 i64:core::num::NonZeroI64 u64:core::num::NonZeroU64); pub trait ToTLV { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>; @@ -118,6 +131,18 @@ macro_rules! totlv_for { }; } +macro_rules! totlv_for_nonzero { + ($($t:ident:$n:ty)*) => { + $( + impl ToTLV for $n { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.$t(tag, self.get()) + } + } + )* + }; +} + impl ToTLV for [T; N] { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { tw.start_array(tag)?; @@ -140,6 +165,7 @@ impl<'a, T: ToTLV> ToTLV for &'a [T] { // Generate ToTLV for standard data types totlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); +totlv_for_nonzero!(i8:core::num::NonZeroI8 u8:core::num::NonZeroU8 i16:core::num::NonZeroI16 u16:core::num::NonZeroU16 i32:core::num::NonZeroI32 u32:core::num::NonZeroU32 i64:core::num::NonZeroI64 u64:core::num::NonZeroU64); // We define a few common data types that will be required here // diff --git a/rs-matter/src/transport/session.rs b/rs-matter/src/transport/session.rs index 731eee03..70110059 100644 --- a/rs-matter/src/transport/session.rs +++ b/rs-matter/src/transport/session.rs @@ -17,6 +17,7 @@ use core::cell::RefCell; use core::fmt; +use core::num::NonZeroU8; use core::time::Duration; use log::{error, info, trace, warn}; @@ -44,30 +45,34 @@ pub type NocCatIds = [u32; MAX_CAT_IDS_PER_NOC]; const MATTER_AES128_KEY_SIZE: usize = 16; -#[derive(Debug, Default, Clone, PartialEq)] -pub struct CaseDetails { - pub fab_idx: u8, - pub cat_ids: NocCatIds, -} - -impl CaseDetails { - pub fn new(fab_idx: u8, cat_ids: &NocCatIds) -> Self { - Self { - fab_idx, - cat_ids: *cat_ids, - } - } -} - -#[derive(Debug, PartialEq, Clone, Default)] +#[derive(Debug, PartialEq, Eq, Clone, Default)] pub enum SessionMode { // The Case session will capture the local fabric index - Case(CaseDetails), - Pase, + // and the local fabric index + Case { + fab_idx: NonZeroU8, + cat_ids: NocCatIds, + }, + // The Pase session always starts with a fabric index of 0 + // (i.e. no fabric) but will be upgraded to the actual fabric index + // once AddNOC or UpdateNOC is received + Pase { + fab_idx: u8, + }, #[default] PlainText, } +impl SessionMode { + pub fn fab_idx(&self) -> u8 { + match self { + SessionMode::Case { fab_idx, .. } => fab_idx.get(), + SessionMode::Pase { fab_idx, .. } => *fab_idx, + SessionMode::PlainText => 0, + } + } +} + pub struct Session { // Internal ID which is guaranteeed to be unique accross all sessions and not change when sessions are added/removed pub(crate) id: u32, @@ -154,7 +159,7 @@ impl Session { pub fn is_encrypted(&self) -> bool { match self.mode { - SessionMode::Case(_) | SessionMode::Pase => true, + SessionMode::Case { .. } | SessionMode::Pase { .. } => true, SessionMode::PlainText => false, } } @@ -163,18 +168,8 @@ impl Session { self.peer_nodeid } - pub fn get_peer_cat_ids(&self) -> Option<&NocCatIds> { - match &self.mode { - SessionMode::Case(a) => Some(&a.cat_ids), - _ => None, - } - } - - pub fn get_local_fabric_idx(&self) -> Option { - match &self.mode { - SessionMode::Case(a) => Some(a.fab_idx), - _ => None, - } + pub fn get_local_fabric_idx(&self) -> u8 { + self.mode.fab_idx() } pub fn get_session_mode(&self) -> &SessionMode { @@ -189,14 +184,14 @@ impl Session { pub fn get_dec_key(&self) -> Option<&[u8]> { match self.mode { - SessionMode::Case(_) | SessionMode::Pase => Some(&self.dec_key), + SessionMode::Case { .. } | SessionMode::Pase { .. } => Some(&self.dec_key), SessionMode::PlainText => None, } } pub fn get_enc_key(&self) -> Option<&[u8]> { match self.mode { - SessionMode::Case(_) | SessionMode::Pase => Some(&self.enc_key), + SessionMode::Case { .. } | SessionMode::Pase { .. } => Some(&self.enc_key), SessionMode::PlainText => None, } } @@ -206,7 +201,7 @@ impl Session { } pub(crate) fn is_for_node(&self, fabric_idx: u8, peer_node_id: u64, secure: bool) -> bool { - self.get_local_fabric_idx() == Some(fabric_idx) + self.get_local_fabric_idx() == fabric_idx && self.peer_nodeid == Some(peer_node_id) && self.is_encrypted() == secure && !self.reserved @@ -224,6 +219,23 @@ impl Session { && !self.reserved } + pub fn upgrade_fabric_idx(&mut self, fabric_idx: NonZeroU8) -> Result<(), Error> { + if let SessionMode::Pase { fab_idx } = &mut self.mode { + if *fab_idx == 0 { + *fab_idx = fabric_idx.get(); + } else { + // Upgrading a PASE session can happen only once + Err(ErrorCode::Invalid)?; + } + } else { + // CASE sessions are not upgradeable, as per spec + // And for plain text sessions - we shoudn't even get here in the first place + Err(ErrorCode::Invalid)?; + } + + Ok(()) + } + /// Update the session state with the data in the received packet headers. /// /// Return `true` if a new exchange was created, and `false` otherwise. @@ -653,12 +665,12 @@ impl SessionMgr { /// This assumes that the higher layer has taken care of doing anything required /// as per the spec before the sessions are removed - pub fn remove_for_fabric(&mut self, fabric_idx: u8) { + pub fn remove_for_fabric(&mut self, fabric_idx: NonZeroU8) { loop { let Some(index) = self .sessions .iter() - .position(|sess| sess.get_local_fabric_idx() == Some(fabric_idx)) + .position(|sess| sess.get_local_fabric_idx() == fabric_idx.get()) else { break; }; diff --git a/rs-matter/tests/common/im_engine.rs b/rs-matter/tests/common/im_engine.rs index 2886d184..bee9e19f 100644 --- a/rs-matter/tests/common/im_engine.rs +++ b/rs-matter/tests/common/im_engine.rs @@ -17,6 +17,7 @@ use crate::common::echo_cluster; use core::borrow::Borrow; +use core::num::NonZeroU8; use embassy_futures::{block_on, join::join, select::select3}; @@ -62,7 +63,7 @@ use rs_matter::{ Address, Ipv4Addr, NetworkReceive, NetworkSend, SocketAddr, SocketAddrV4, MAX_RX_PACKET_SIZE, MAX_TX_PACKET_SIZE, }, - session::{CaseDetails, NocCatIds, ReservedSession, SessionMode}, + session::{NocCatIds, ReservedSession, SessionMode}, }, utils::{buf::PooledBuffers, select::Coalesce}, Matter, MATTER_PORT, @@ -222,7 +223,8 @@ impl<'a> ImEngine<'a> { pub fn add_default_acl(&self) { // Only allow the standard peer node id of the IM Engine - let mut default_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut default_acl = + AclEntry::new(NonZeroU8::new(1).unwrap(), Privilege::ADMIN, AuthMode::Case); default_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); self.matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); } @@ -270,7 +272,10 @@ impl<'a> ImEngine<'a> { 1, 1, ADDR, - SessionMode::Case(CaseDetails::new(1, cat_ids)), + SessionMode::Case { + fab_idx: NonZeroU8::new(1).unwrap(), + cat_ids: cat_ids.clone(), + }, None, None, None, diff --git a/rs-matter/tests/data_model/acl_and_dataver.rs b/rs-matter/tests/data_model/acl_and_dataver.rs index 18acf303..7cedf18e 100644 --- a/rs-matter/tests/data_model/acl_and_dataver.rs +++ b/rs-matter/tests/data_model/acl_and_dataver.rs @@ -15,6 +15,8 @@ * limitations under the License. */ +use core::num::NonZeroU8; + use rs_matter::{ acl::{gen_noc_cat, AclEntry, AuthMode, Target}, data_model::{ @@ -39,6 +41,11 @@ use crate::{ }, }; +const FAB_1: NonZeroU8 = match NonZeroU8::new(1) { + Some(f) => f, + None => unreachable!(), +}; + #[test] /// Ensure that wildcard read attributes don't include error response /// and silently drop the data when access is not granted @@ -70,7 +77,7 @@ fn wc_read_attribute() { im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to only access endpoint 0 - let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -81,7 +88,7 @@ fn wc_read_attribute() { im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to also access endpoint 1 - let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -121,7 +128,7 @@ fn exact_read_attribute() { im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to access any endpoint - let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -184,7 +191,7 @@ fn wc_write_attribute() { ); // Add ACL to allow our peer to access one endpoint - let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -203,7 +210,7 @@ fn wc_write_attribute() { ); // Add ACL to allow our peer to access another endpoint - let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(1), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -262,7 +269,7 @@ fn exact_write_attribute() { ); // Add ACL to allow our peer to access any endpoint - let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -317,7 +324,7 @@ fn exact_write_attribute_noc_cat() { ); // Add ACL to allow our peer to access any endpoint - let mut acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); acl.add_subject_catid(cat_in_acl).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -350,7 +357,7 @@ fn insufficient_perms_write() { let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission - let mut acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); + let mut acl = AclEntry::new(FAB_1, Privilege::OPERATE, AuthMode::Case); acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); acl.add_target(Target::new(Some(0), None, None)).unwrap(); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); @@ -412,7 +419,7 @@ fn write_with_runtime_acl_add() { ); // Create ACL to allow our peer ADMIN on everything - let mut allow_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut allow_acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); allow_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); let acl_att = GenericPath::new( @@ -427,7 +434,7 @@ fn write_with_runtime_acl_add() { ); // Create ACL that only allows write to the ACL Cluster - let mut basic_acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let mut basic_acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); basic_acl.add_subject(IM_ENGINE_PEER_ID).unwrap(); basic_acl .add_target(Target::new(Some(0), Some(access_control::ID), None)) @@ -462,7 +469,7 @@ fn test_read_data_ver() { let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission - let acl = AclEntry::new(1, Privilege::OPERATE, AuthMode::Case); + let acl = AclEntry::new(FAB_1, Privilege::OPERATE, AuthMode::Case); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); let wc_ep_att1 = GenericPath::new( @@ -566,7 +573,7 @@ fn test_write_data_ver() { let handler = im.handler(); // Add ACL to allow our peer with only OPERATE permission - let acl = AclEntry::new(1, Privilege::ADMIN, AuthMode::Case); + let acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); let wc_ep_attwrite = GenericPath::new(