diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index cc67a468..3ba694b2 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -78,7 +78,7 @@ fn main() -> Result<(), Error> { // e.g., an opt-level of "0" will require a several times' larger stack. // // Optimizing/lowering `rs-matter` memory consumption is an ongoing topic. - .stack_size(65 * 1024) + .stack_size(54 * 1024) .spawn(run) .unwrap(); diff --git a/examples/onoff_light_bt/src/comm.rs b/examples/onoff_light_bt/src/comm.rs index 3c1f2135..88d2d9d8 100644 --- a/examples/onoff_light_bt/src/comm.rs +++ b/examples/onoff_light_bt/src/comm.rs @@ -30,7 +30,7 @@ use rs_matter::data_model::sdm::nw_commissioning::{ use rs_matter::error::{Error, ErrorCode}; use rs_matter::interaction_model::core::IMStatusCode; use rs_matter::interaction_model::messages::ib::Status; -use rs_matter::tlv::{FromTLV, OctetStr, TLVElement}; +use rs_matter::tlv::{FromTLV, Octets, TLVElement, TLVWrite}; use rs_matter::transport::exchange::Exchange; use rs_matter::utils::sync::Notification; @@ -70,7 +70,7 @@ impl<'a> WifiNwCommCluster<'a> { match attr.attr_id.try_into()? { Attributes::MaxNetworks => AttrType::::new().encode(writer, 1_u8), Attributes::Networks => { - writer.start_array(AttrDataWriter::TAG)?; + writer.start_array(&AttrDataWriter::TAG)?; writer.end_container()?; writer.complete() @@ -79,9 +79,7 @@ impl<'a> WifiNwCommCluster<'a> { Attributes::ConnectMaxTimeSecs => AttrType::new().encode(writer, 60_u8), Attributes::InterfaceEnabled => AttrType::new().encode(writer, true), Attributes::LastNetworkingStatus => AttrType::new().encode(writer, 0_u8), - Attributes::LastNetworkID => { - AttrType::new().encode(writer, OctetStr("ssid".as_bytes())) - } + Attributes::LastNetworkID => AttrType::new().encode(writer, Octets("ssid".as_bytes())), Attributes::LastConnectErrorValue => AttrType::new().encode(writer, 0), } } diff --git a/rs-matter-macros-impl/src/tlv.rs b/rs-matter-macros-impl/src/tlv.rs index 101a316c..650bfcc0 100644 --- a/rs-matter-macros-impl/src/tlv.rs +++ b/rs-matter-macros-impl/src/tlv.rs @@ -4,7 +4,8 @@ use proc_macro2::{Ident, Literal, Span, TokenStream}; use quote::{format_ident, quote}; use syn::meta::ParseNestedMeta; use syn::parse::ParseStream; -use syn::{DeriveInput, Lifetime, LitInt, LitStr, Type}; +use syn::token::{Gt, Lt}; +use syn::{DeriveInput, Lifetime, LifetimeParam, LitInt, LitStr, Type}; #[derive(PartialEq, Debug)] struct TlvArgs { @@ -13,6 +14,7 @@ struct TlvArgs { datatype: String, unordered: bool, lifetime: syn::Lifetime, + lifetime_explicit: bool, } impl Default for TlvArgs { @@ -23,6 +25,7 @@ impl Default for TlvArgs { datatype: "struct".to_string(), unordered: false, lifetime: Lifetime::new("'_", Span::call_site()), + lifetime_explicit: false, } } } @@ -37,6 +40,7 @@ impl TlvArgs { } else if meta.path.is_ident("lifetime") { self.lifetime = Lifetime::new(&meta.value()?.parse::()?.value(), Span::call_site()); + self.lifetime_explicit = true; } else if meta.path.is_ident("datatype") { self.datatype = meta.value()?.parse::()?.value(); } else if meta.path.is_ident("unordered") { @@ -91,7 +95,7 @@ fn parse_tag_val(attrs: &[syn::Attribute]) -> Option { /// Given a data type and existing tags, convert them into /// a function to call for read/write (like u8/u16) and a list -/// of numeric liternals of tags (which may be u8 or u16) +/// of numeric literals of tags (which may be u8 or u16) /// /// Ideally we would also be able to figure out the writing type using "repr" data /// however for now we require a "datatype" to be valid @@ -126,6 +130,51 @@ fn get_unit_enum_func_and_tags( /// Generate a ToTlv implementation for a structure fn gen_totlv_for_struct( + data_struct: &syn::DataStruct, + struct_name: &proc_macro2::Ident, + tlvargs: &TlvArgs, + generics: &syn::Generics, +) -> TokenStream { + match &data_struct.fields { + syn::Fields::Named(fields) => { + gen_totlv_for_struct_named(fields, struct_name, tlvargs, generics) + } + syn::Fields::Unnamed(fields) => { + gen_totlv_for_struct_unnamed(fields, struct_name, tlvargs, generics) + } + _ => panic!("Union structs are not supported"), + } +} + +/// Generate a ToTlv implementation for a structure with a single unnamed field +/// The structure is behaving as a Newtype over the unnamed field +fn gen_totlv_for_struct_unnamed( + fields: &syn::FieldsUnnamed, + struct_name: &proc_macro2::Ident, + tlvargs: &TlvArgs, + generics: &syn::Generics, +) -> TokenStream { + if fields.unnamed.len() != 1 { + panic!("Only a single unnamed field supported for unnamed structures"); + } + + let krate = Ident::new(&tlvargs.rs_matter_crate, Span::call_site()); + + quote! { + impl #generics #krate::tlv::ToTLV for #struct_name #generics { + fn to_tlv(&self, tag: &#krate::tlv::TLVTag, mut tw: W) -> Result<(), #krate::error::Error> { + #krate::tlv::ToTLV::to_tlv(&self.0, tag, &mut tw) + } + + fn tlv_iter(&self, tag: #krate::tlv::TLVTag) -> impl Iterator> { + #krate::tlv::ToTLV::tlv_iter(&self.0, tag) + } + } + } +} + +/// Generate a ToTlv implementation for a structure with named fields +fn gen_totlv_for_struct_named( fields: &syn::FieldsNamed, struct_name: &proc_macro2::Ident, tlvargs: &TlvArgs, @@ -151,13 +200,13 @@ fn gen_totlv_for_struct( quote! { impl #generics #krate::tlv::ToTLV for #struct_name #generics { - fn to_tlv(&self, tw: &mut #krate::tlv::TLVWriter, tag_type: #krate::tlv::TagType) -> Result<(), #krate::error::Error> { + fn to_tlv(&self, tag: &#krate::tlv::TLVTag, mut tw: W) -> Result<(), #krate::error::Error> { let anchor = tw.get_tail(); if let Err(err) = (|| { - tw. #datatype (tag_type)?; + tw.#datatype(tag)?; #( - self.#idents.to_tlv(tw, #krate::tlv::TagType::Context(#tags))?; + #krate::tlv::ToTLV::to_tlv(&self.#idents, &#krate::tlv::TLVTag::Context(#tags), &mut tw)?; )* tw.end_container() })() { @@ -167,6 +216,14 @@ fn gen_totlv_for_struct( Ok(()) } } + + fn tlv_iter(&self, tag: #krate::tlv::TLVTag) -> impl Iterator> { + let iter = #krate::tlv::TLV::structure(tag).into_tlv_iter(); + + #(let iter = Iterator::chain(iter, #krate::tlv::ToTLV::tlv_iter(&self.#idents, #krate::tlv::TLVTag::Context(#tags)));)* + + Iterator::chain(iter, #krate::tlv::TLV::end_container().into_tlv_iter()) + } } } } @@ -230,12 +287,12 @@ fn gen_totlv_for_enum( quote! { impl #generics #krate::tlv::ToTLV for #enum_name #generics { - fn to_tlv(&self, tw: &mut #krate::tlv::TLVWriter, tag_type: #krate::tlv::TagType) -> Result<(), #krate::error::Error> { + fn to_tlv(&self, tag: &#krate::tlv::TLVTag, mut tw: W) -> Result<(), #krate::error::Error> { let anchor = tw.get_tail(); if let Err(err) = (|| { match self { - #( Self::#variant_names => tw.#write_func(tag_type, #tags), )* + #( Self::#variant_names => tw.#write_func(tag, #tags), )* } })() { tw.rewind_to(anchor); @@ -244,40 +301,112 @@ fn gen_totlv_for_enum( Ok(()) } } + + fn tlv_iter(&self, tag: #krate::tlv::TLVTag) -> impl Iterator> { + match self { + #( Self::#variant_names => #krate::tlv::TLV::#write_func(tag, #tags).into_tlv_iter(), )* + } + } } } } else { // tags MUST be context-tags (up to u8 range) - if tags.iter().any(|v| *v > 0xFF) { + if tags.iter().any(|v| *v > u8::MAX as _) { panic!( "Enum discriminator value larger that 0xFF for {:?}", enum_name ) } + if tags.len() > 6 { + panic!("More than 6 enum variants for {:?}", enum_name) + } + + let either_ident = if tags.len() != 2 { + format_ident!("Either{}Iter", tags.len()) + } else { + format_ident!("EitherIter") + }; + + let either_variants = (0..tags.len()) + .map(|t| match t { + 0 => "First", + 1 => "Second", + 2 => "Third", + 3 => "Fourth", + 4 => "Fifth", + 5 => "Sixth", + _ => unreachable!(), + }) + .map(|t| format_ident!("{}", t)) + .collect::>(); + let tags = tags .into_iter() .map(|v| Literal::u8_suffixed(v as u8)) .collect::>(); - quote! { - impl #generics #krate::tlv::ToTLV for #enum_name #generics { - fn to_tlv(&self, tw: &mut #krate::tlv::TLVWriter, tag_type: #krate::tlv::TagType) -> Result<(), #krate::error::Error> { - let anchor = tw.get_tail(); + if tlvargs.datatype == "naked" { + quote! { + impl #generics #krate::tlv::ToTLV for #enum_name #generics { + fn to_tlv(&self, tag: &#krate::tlv::TLVTag, mut tw: W) -> Result<(), #krate::error::Error> { + let anchor = tw.get_tail(); - if let Err(err) = (|| { - tw.start_struct(tag_type)?; + if let Err(err) = (|| { + match self { + #( + Self::#variant_names(c) => { #krate::tlv::ToTLV::to_tlv(c, &#krate::tlv::TLVTag::Context(#tags), &mut tw) } + )* + } + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + + fn tlv_iter(&self, tag: #krate::tlv::TLVTag) -> impl Iterator> { match self { #( - Self::#variant_names(c) => { c.to_tlv(tw, #krate::tlv::TagType::Context(#tags))?; }, + Self::#variant_names(c) => #krate::tlv::#either_ident::#either_variants(#krate::tlv::ToTLV::tlv_iter(c, #krate::tlv::TLVTag::Context(#tags))), )* } - tw.end_container() - })() { - tw.rewind_to(anchor); - Err(err) - } else { - Ok(()) + } + } + } + } else { + quote! { + impl #generics #krate::tlv::ToTLV for #enum_name #generics { + fn to_tlv(&self, tag: &#krate::tlv::TLVTag, mut tw: W) -> Result<(), #krate::error::Error> { + let anchor = tw.get_tail(); + + if let Err(err) = (|| { + tw.start_struct(tag)?; + match self { + #( + Self::#variant_names(c) => #krate::tlv::ToTLV::to_tlv(c, &#krate::tlv::TLVTag::Context(#tags), &mut tw), + )* + }?; + tw.end_container() + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + + fn tlv_iter(&self, tag: #krate::tlv::TLVTag) -> impl Iterator> { + let iter = #krate::tlv::TLV::structure(tag).into_tlv_iter(); + + let iter = Iterator::chain(iter, match self { + #( + Self::#variant_names(c) => #krate::tlv::#either_ident::#either_variants(#krate::tlv::ToTLV::tlv_iter(c, #krate::tlv::TLVTag::Context(#tags))), + )* + }); + + Iterator::chain(iter, #krate::tlv::TLV::end_container().into_tlv_iter()) } } } @@ -316,39 +445,108 @@ pub fn derive_totlv(ast: DeriveInput, rs_matter_crate: String) -> TokenStream { let tlvargs = parse_tlvargs(&ast, rs_matter_crate); let generics = ast.generics; - if let syn::Data::Struct(syn::DataStruct { - fields: syn::Fields::Named(ref fields), - .. - }) = ast.data - { - gen_totlv_for_struct(fields, name, &tlvargs, &generics) - } else if let syn::Data::Enum(data_enum) = ast.data { - gen_totlv_for_enum(&data_enum, name, &tlvargs, &generics) - } else { - panic!( - "Derive ToTLV - Only supported struct and enum for now {:?}", - ast.data - ); + match &ast.data { + syn::Data::Struct(data_struct) => { + gen_totlv_for_struct(data_struct, name, &tlvargs, &generics) + } + syn::Data::Enum(data_enum) => gen_totlv_for_enum(data_enum, name, &tlvargs, &generics), + _ => panic!("Derive ToTLV - Only supported struct and enum for now"), } } /// Generate a FromTlv implementation for a structure fn gen_fromtlv_for_struct( + data_struct: &syn::DataStruct, + struct_name: &proc_macro2::Ident, + tlvargs: TlvArgs, + generics: &syn::Generics, +) -> TokenStream { + match &data_struct.fields { + syn::Fields::Named(fields) => { + gen_fromtlv_for_struct_named(fields, struct_name, tlvargs, generics) + } + syn::Fields::Unnamed(fields) => { + gen_fromtlv_for_struct_unnamed(fields, struct_name, tlvargs, generics) + } + _ => panic!("Union structs are not supported"), + } +} + +/// Generate a FromTlv implementation for a structure with a single unnamed field +/// The structure is behaving as a Newtype over the unnamed field +fn gen_fromtlv_for_struct_unnamed( + fields: &syn::FieldsUnnamed, + struct_name: &proc_macro2::Ident, + tlvargs: TlvArgs, + generics: &syn::Generics, +) -> TokenStream { + if fields.unnamed.len() != 1 { + panic!("Only a single unnamed field supported for unnamed structures"); + } + + let krate = Ident::new(&tlvargs.rs_matter_crate, Span::call_site()); + let lifetime = tlvargs.lifetime; + let ty = normalize_fromtlv_type(&fields.unnamed[0].ty); + + quote! { + impl #generics #krate::tlv::FromTLV<#lifetime> for #struct_name #generics { + fn from_tlv(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + Ok(Self(#ty::from_tlv(element)?)) + } + } + + impl #generics TryFrom<&#krate::tlv::TLVElement<#lifetime>> for #struct_name #generics { + type Error = #krate::error::Error; + + fn try_from(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + use #krate::tlv::FromTLV; + + Self::from_tlv(element) + } + } + } +} + +/// Generate a ToTlv implementation for a structure with named fields +fn gen_fromtlv_for_struct_named( fields: &syn::FieldsNamed, struct_name: &proc_macro2::Ident, tlvargs: TlvArgs, generics: &syn::Generics, ) -> TokenStream { let mut tag_start = tlvargs.start; - let lifetime = tlvargs.lifetime; - let datatype = format_ident!("confirm_{}", tlvargs.datatype); + + let (lifetime, impl_generics) = if tlvargs.lifetime_explicit { + (tlvargs.lifetime, generics.clone()) + } else { + // The `'_` default lifetime from tlvargs won't do. + // We need a named lifetime that has to be part of the `impl<>` block. + + let lifetime = Lifetime::new("'__from_tlv", Span::call_site()); + + let mut impl_generics = generics.clone(); + + if impl_generics.gt_token.is_none() { + impl_generics.gt_token = Some(Gt::default()); + impl_generics.lt_token = Some(Lt::default()); + } + + impl_generics + .params + .push(syn::GenericParam::Lifetime(LifetimeParam::new( + lifetime.clone(), + ))); + + (lifetime, impl_generics) + }; + + let datatype = format_ident!("r#{}", tlvargs.datatype); let mut idents = Vec::new(); let mut types = Vec::new(); let mut tags = Vec::new(); for field in fields.named.iter() { - let type_name = &field.ty; if let Some(a) = parse_tag_val(&field.attrs) { // TODO: The current limitation with this is that a hard-coded integer // value has to be mentioned in the tagval attribute. This is because @@ -361,67 +559,47 @@ fn gen_fromtlv_for_struct( } idents.push(&field.ident); - if let Type::Path(path) = type_name { - // When paths are like `matter_rs::tlv::Nullable` - // this ignores the arguments and just does: - // `matter_rs::tlv::Nullable` - let idents = path - .path - .segments - .iter() - .map(|s| s.ident.clone()) - .collect::>(); - types.push(quote!(#(#idents)::*)); - } else { - panic!("Don't know what to do {:?}", type_name); - } + types.push(normalize_fromtlv_type(&field.ty)); } let krate = Ident::new(&tlvargs.rs_matter_crate, Span::call_site()); + let seq_method = format_ident!("{}_ctx", if tlvargs.unordered { "find" } else { "scan" }); - // Currently we don't use find_tag() because the tags come in sequential - // order. If ever the tags start coming out of order, we can use find_tag() - // instead - if !tlvargs.unordered { - quote! { - impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics { - fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { - let mut t_iter = t.#datatype ()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; - let mut item = t_iter.next(); - #( - let #idents = if Some(true) == item.as_ref().map(|x| x.check_ctx_tag(#tags)) { - let backup = item; - item = t_iter.next(); - #types::from_tlv(&backup.unwrap()) - } else { - #types::tlv_not_found() - }?; - )* - Ok(Self { - #(#idents, - )* - }) - } - } + quote! { + impl #impl_generics #krate::tlv::FromTLV<#lifetime> for #struct_name #generics { + fn from_tlv(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + #[allow(unused_mut)] + let mut seq = element.#datatype()?; + + Ok(Self { + #(#idents: #types::from_tlv(&seq.#seq_method(#tags)?)?, + )* + }) + } + + fn init_from_tlv(element: #krate::tlv::TLVElement<#lifetime>) -> impl #krate::utils::init::Init { + #krate::utils::init::into_init(move || { + #[allow(unused_mut)] + let mut seq = element.#datatype()?; + + let init = #krate::utils::init::try_init!(Self { + #(#idents <- #types::init_from_tlv(seq.#seq_method(#tags)?), + )* + }? #krate::error::Error); + + Ok(init) + }) + } } - } else { - quote! { - impl #generics #krate::tlv::FromTLV <#lifetime> for #struct_name #generics { - fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { - #( - let #idents = if let Ok(s) = t.find_tag(#tags as u32) { - #types::from_tlv(&s) - } else { - #types::tlv_not_found() - }?; - )* - - Ok(Self { - #(#idents, - )* - }) - } - } + + impl #impl_generics TryFrom<&#krate::tlv::TLVElement<#lifetime>> for #struct_name #generics { + type Error = #krate::error::Error; + + fn try_from(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + use #krate::tlv::FromTLV; + + Self::from_tlv(element) + } } } } @@ -464,7 +642,7 @@ fn gen_fromtlv_for_enum( } if variant_types.contains(&FieldTypes::Unnamed) && variant_types.contains(&FieldTypes::Unit) { - // You should have enum Foo {A,B,C} OR Foo{A(X), B(Y), ...} + // You should have enum Foo { A, B, C } OR Foo { A(X), B(Y), .. } // Combining them does not work panic!("Enum contains both unit and unnamed fields. This is not supported."); } @@ -481,18 +659,29 @@ fn gen_fromtlv_for_enum( let krate = Ident::new(&tlvargs.rs_matter_crate, Span::call_site()); if variant_types.contains(&FieldTypes::Unit) { - let (read_func, tags) = + let (elem_read_method, tags) = get_unit_enum_func_and_tags(enum_name, tlvargs.datatype.as_str(), tags); quote! { - impl #generics #krate::tlv::FromTLV <#lifetime> for #enum_name #generics { - fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { - Ok(match t.#read_func()? { - #( #tags => Self::#variant_names, )* - _ => return Err(#krate::error::Error::new(#krate::error::ErrorCode::Invalid)), - }) - } - } + impl #generics #krate::tlv::FromTLV<#lifetime> for #enum_name #generics { + fn from_tlv(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + Ok(match element.#elem_read_method()? { + #(#tags => Self::#variant_names, + )* + _ => Err(#krate::error::ErrorCode::Invalid)?, + }) + } + } + + impl #generics TryFrom<&#krate::tlv::TLVElement<#lifetime>> for #enum_name #generics { + type Error = #krate::error::Error; + + fn try_from(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + use #krate::tlv::FromTLV; + + Self::from_tlv(element) + } + } } } else { // tags MUST be context-tags (up to u8 range) @@ -521,27 +710,66 @@ fn gen_fromtlv_for_enum( } } + let enter = (tlvargs.datatype != "naked") + .then(|| { + quote! { + let element = element + .r#struct()? + .iter() + .next() + .ok_or(#krate::error::ErrorCode::TLVTypeMismatch)??; + } + }) + .unwrap_or(TokenStream::new()); + quote! { - impl #generics #krate::tlv::FromTLV <#lifetime> for #enum_name #generics { - fn from_tlv(t: &#krate::tlv::TLVElement<#lifetime>) -> Result { - let mut t_iter = t.confirm_struct()?.enter().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; - let mut item = t_iter.next().ok_or_else(|| #krate::error::Error::new(#krate::error::ErrorCode::Invalid))?; - if let TagType::Context(tag) = item.get_tag() { - match tag { - #( - #tags => Ok(Self::#variant_names(#types::from_tlv(&item)?)), - )* - _ => Err(#krate::error::Error::new(#krate::error::ErrorCode::Invalid)), - } - } else { - Err(#krate::error::Error::new(#krate::error::ErrorCode::TLVTypeMismatch)) - } - } - } + impl #generics #krate::tlv::FromTLV<#lifetime> for #enum_name #generics { + fn from_tlv(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + #enter + + let tag = element + .try_ctx()? + .ok_or(#krate::error::ErrorCode::TLVTypeMismatch)?; + + Ok(match tag { + #(#tags => Self::#variant_names(#types::from_tlv(&element)?), + )* + _ => Err(#krate::error::ErrorCode::Invalid)?, + }) + } + } + + impl #generics TryFrom<&#krate::tlv::TLVElement<#lifetime>> for #enum_name #generics { + type Error = #krate::error::Error; + + fn try_from(element: &#krate::tlv::TLVElement<#lifetime>) -> Result { + use #krate::tlv::FromTLV; + + Self::from_tlv(element) + } + } } } } +fn normalize_fromtlv_type(ty: &syn::Type) -> TokenStream { + let Type::Path(type_path) = ty else { + panic!("Don't know what to do {:?}", ty); + }; + + // When paths are like `matter_rs::tlv::Nullable` + // this ignores the arguments and just does: + // `matter_rs::tlv::Nullable` + let type_idents = type_path + .path + .segments + .iter() + .map(|s| s.ident.clone()) + .collect::>(); + + quote!(#(#type_idents)::*) +} + /// Derive FromTLV Macro /// /// This macro works for structures. It will create an implementation @@ -576,19 +804,12 @@ pub fn derive_fromtlv(ast: DeriveInput, rs_matter_crate: String) -> TokenStream let generics = ast.generics; - if let syn::Data::Struct(syn::DataStruct { - fields: syn::Fields::Named(ref fields), - .. - }) = ast.data - { - gen_fromtlv_for_struct(fields, name, tlvargs, &generics) - } else if let syn::Data::Enum(data_enum) = ast.data { - gen_fromtlv_for_enum(&data_enum, name, tlvargs, &generics) - } else { - panic!( - "Derive FromTLV - Only supported Struct for now {:?}", - ast.data - ) + match &ast.data { + syn::Data::Struct(data_struct) => { + gen_fromtlv_for_struct(data_struct, name, tlvargs, &generics) + } + syn::Data::Enum(data_enum) => gen_fromtlv_for_enum(data_enum, name, tlvargs, &generics), + _ => panic!("Derive FromTLV - Only supported struct and enum for now"), } } @@ -664,27 +885,44 @@ mod tests { &derive_totlv(ast, "rs_matter_maybe_renamed".to_string()), "e!( impl rs_matter_maybe_renamed::tlv::ToTLV for TestS { - fn to_tlv( - &self, - tw: &mut rs_matter_maybe_renamed::tlv::TLVWriter, - tag_type: rs_matter_maybe_renamed::tlv::TagType + fn to_tlv( + &self, + tag: &rs_matter_maybe_renamed::tlv::TLVTag, + mut tw: W, ) -> Result<(), rs_matter_maybe_renamed::error::Error> { - let anchor = tw.get_tail(); - if let Err(err) = (|| { - tw.start_struct(tag_type)?; - self.field1 - .to_tlv(tw, rs_matter_maybe_renamed::tlv::TagType::Context(0u8))?; - self.field2 - .to_tlv(tw, rs_matter_maybe_renamed::tlv::TagType::Context(1u8))?; - tw.end_container() - })() { - tw.rewind_to(anchor); - Err(err) - } else { - Ok(()) - } - } - } + let anchor = tw.get_tail(); + if let Err(err) = (|| { + tw.start_struct(tag)?; + rs_matter_maybe_renamed::tlv::ToTLV::to_tlv(&self.field1, &rs_matter_maybe_renamed::tlv::TLVTag::Context(0u8), &mut tw)?; + rs_matter_maybe_renamed::tlv::ToTLV::to_tlv(&self.field2, &rs_matter_maybe_renamed::tlv::TLVTag::Context(1u8), &mut tw)?; + tw.end_container() + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + + fn tlv_iter( + &self, + tag: rs_matter_maybe_renamed::tlv::TLVTag, + ) -> impl Iterator> { + let iter = rs_matter_maybe_renamed::tlv::TLV::structure(tag).into_tlv_iter(); + + let iter = Iterator::chain( + iter, + rs_matter_maybe_renamed::tlv::ToTLV::tlv_iter(&self.field1,rs_matter_maybe_renamed::tlv::TLVTag::Context(0u8)), + ); + + let iter = Iterator::chain( + iter, + rs_matter_maybe_renamed::tlv::ToTLV::tlv_iter(&self.field2, rs_matter_maybe_renamed::tlv::TLVTag::Context(1u8)), + ); + + Iterator::chain(iter, rs_matter_maybe_renamed::tlv::TLV::end_container().into_tlv_iter()) + } + } ) ); } @@ -704,52 +942,55 @@ mod tests { assert_tokenstreams_eq!( &derive_fromtlv(ast, "rs_matter_maybe_renamed".to_string()), "e!( - impl rs_matter_maybe_renamed::tlv::FromTLV<'_> for TestS { - fn from_tlv( - t: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, - ) -> Result { - let mut t_iter = t.confirm_struct()?.enter().ok_or_else(|| { - rs_matter_maybe_renamed::error::Error::new( - rs_matter_maybe_renamed::error::ErrorCode::Invalid, - ) - })?; - let mut item = t_iter.next(); - let field1 = if Some(true) == item.as_ref().map(|x| x.check_ctx_tag(0u8)) { - let backup = item; - item = t_iter.next(); - u8::from_tlv(&backup.unwrap()) - } else { - u8::tlv_not_found() - }?; - let field2 = if Some(true) == item.as_ref().map(|x| x.check_ctx_tag(1u8)) { - let backup = item; - item = t_iter.next(); - u32::from_tlv(&backup.unwrap()) - } else { - u32::tlv_not_found() - }?; - let field_opt = if Some(true) == item.as_ref().map(|x| x.check_ctx_tag(2u8)) { - let backup = item; - item = t_iter.next(); - Option::from_tlv(&backup.unwrap()) - } else { - Option::tlv_not_found() - }?; - let field_null = if Some(true) == item.as_ref().map(|x| x.check_ctx_tag(3u8)) { - let backup = item; - item = t_iter.next(); - rs_matter_maybe_renamed::tlv::Nullable::from_tlv(&backup.unwrap()) - } else { - rs_matter_maybe_renamed::tlv::Nullable::tlv_not_found() - }?; - Ok(Self { - field1, - field2, - field_opt, - field_null, - }) - } - } + impl<'__from_tlv> rs_matter_maybe_renamed::tlv::FromTLV<'__from_tlv> for TestS { + fn from_tlv( + element: &rs_matter_maybe_renamed::tlv::TLVElement<'__from_tlv>, + ) -> Result { + #[allow(unused_mut)] + let mut seq = element.r#struct()?; + + Ok(Self { + field1: u8::from_tlv(&seq.scan_ctx(0u8)?)?, + field2: u32::from_tlv(&seq.scan_ctx(1u8)?)?, + field_opt: Option::from_tlv(&seq.scan_ctx(2u8)?)?, + field_null: rs_matter_maybe_renamed::tlv::Nullable::from_tlv( + &seq.scan_ctx(3u8)?, + )?, + }) + } + + fn init_from_tlv( + element: rs_matter_maybe_renamed::tlv::TLVElement<'__from_tlv>, + ) -> impl rs_matter_maybe_renamed::utils::init::Init< + Self, + rs_matter_maybe_renamed::error::Error, + > { + rs_matter_maybe_renamed::utils::init::into_init(move || { + #[allow(unused_mut)] + let mut seq = element.r#struct()?; + + let init = rs_matter_maybe_renamed::utils::init::try_init!(Self { + field1 <- u8::init_from_tlv(seq.scan_ctx(0u8)?), + field2 <- u32::init_from_tlv(seq.scan_ctx(1u8)?), + field_opt <- Option::init_from_tlv(seq.scan_ctx(2u8)?), + field_null <- rs_matter_maybe_renamed::tlv::Nullable::init_from_tlv(seq.scan_ctx(3u8)?), + }? rs_matter_maybe_renamed::error::Error); + + Ok(init) + }) + } + } + + impl<'__from_tlv> TryFrom<&rs_matter_maybe_renamed::tlv::TLVElement<'__from_tlv>> for TestS { + type Error = rs_matter_maybe_renamed::error::Error; + + fn try_from( + element: &rs_matter_maybe_renamed::tlv::TLVElement<'__from_tlv>, + ) -> Result { + use rs_matter_maybe_renamed::tlv::FromTLV; + Self::from_tlv(element) + } + } ) ); } @@ -768,29 +1009,46 @@ mod tests { &derive_totlv(ast, "rs_matter_maybe_renamed".to_string()), "e!( impl rs_matter_maybe_renamed::tlv::ToTLV for TestEnum { - fn to_tlv( + fn to_tlv( &self, - tw: &mut rs_matter_maybe_renamed::tlv::TLVWriter, - tag_type: rs_matter_maybe_renamed::tlv::TagType, + tag: &rs_matter_maybe_renamed::tlv::TLVTag, + mut tw: W, ) -> Result<(), rs_matter_maybe_renamed::error::Error> { let anchor = tw.get_tail(); if let Err(err) = (|| { - tw.start_struct(tag_type)?; + tw.start_struct(tag)?; + match self { + Self::ValueA(c) => rs_matter_maybe_renamed::tlv::ToTLV::to_tlv(c, &rs_matter_maybe_renamed::tlv::TLVTag::Context(0u8), &mut tw), + Self::ValueB(c) => rs_matter_maybe_renamed::tlv::ToTLV::to_tlv(c, &rs_matter_maybe_renamed::tlv::TLVTag::Context(1u8), &mut tw), + }?; + tw.end_container() + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + + fn tlv_iter( + &self, + tag: rs_matter_maybe_renamed::tlv::TLVTag, + ) -> impl Iterator> { + let iter = rs_matter_maybe_renamed::tlv::TLV::structure(tag).into_tlv_iter(); + + let iter = Iterator::chain( + iter, match self { - Self::ValueA(c) => { - c.to_tlv(tw, rs_matter_maybe_renamed::tlv::TagType::Context(0u8))?; - } - Self::ValueB(c) => { - c.to_tlv(tw, rs_matter_maybe_renamed::tlv::TagType::Context(1u8))?; - } - } - tw.end_container() - })() { - tw.rewind_to(anchor); - Err(err) - } else { - Ok(()) - } + Self::ValueA(c) => rs_matter_maybe_renamed::tlv::EitherIter::First( + rs_matter_maybe_renamed::tlv::ToTLV::tlv_iter(c, rs_matter_maybe_renamed::tlv::TLVTag::Context(0u8)), + ), + Self::ValueB(c) => rs_matter_maybe_renamed::tlv::EitherIter::Second( + rs_matter_maybe_renamed::tlv::ToTLV::tlv_iter(c, rs_matter_maybe_renamed::tlv::TLVTag::Context(1u8)), + ), + }, + ); + + Iterator::chain(iter, rs_matter_maybe_renamed::tlv::TLV::end_container().into_tlv_iter()) } } ) @@ -811,23 +1069,33 @@ mod tests { &derive_totlv(ast, "rs_matter_maybe_renamed".to_string()), "e!( impl rs_matter_maybe_renamed::tlv::ToTLV for TestEnum { - fn to_tlv( + fn to_tlv( &self, - tw: &mut rs_matter_maybe_renamed::tlv::TLVWriter, - tag_type: rs_matter_maybe_renamed::tlv::TagType, + tag: &rs_matter_maybe_renamed::tlv::TLVTag, + mut tw: W, ) -> Result<(), rs_matter_maybe_renamed::error::Error> { let anchor = tw.get_tail(); if let Err(err) = (|| { match self { - Self::ValueA => tw.u8(tag_type, 0u8), - Self::ValueB => tw.u8(tag_type, 1u8), + Self::ValueA => tw.u8(tag, 0u8), + Self::ValueB => tw.u8(tag, 1u8), } - })() { - tw.rewind_to(anchor); - Err(err) - } else { - Ok(()) - } + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + + fn tlv_iter( + &self, + tag: rs_matter_maybe_renamed::tlv::TLVTag, + ) -> impl Iterator> { + match self { + Self::ValueA => rs_matter_maybe_renamed::tlv::TLV::u8(tag, 0u8).into_tlv_iter(), + Self::ValueB => rs_matter_maybe_renamed::tlv::TLV::u8(tag, 1u8).into_tlv_iter(), + } } } ) @@ -853,25 +1121,37 @@ mod tests { &derive_totlv(ast, "rs_matter_maybe_renamed".to_string()), "e!( impl rs_matter_maybe_renamed::tlv::ToTLV for TestEnum { - fn to_tlv( + fn to_tlv( &self, - tw: &mut rs_matter_maybe_renamed::tlv::TLVWriter, - tag_type: rs_matter_maybe_renamed::tlv::TagType, + tag: &rs_matter_maybe_renamed::tlv::TLVTag, + mut tw: W, ) -> Result<(), rs_matter_maybe_renamed::error::Error> { let anchor = tw.get_tail(); if let Err(err) = (|| { match self { - Self::ValueA => tw.u16(tag_type, 0u16), - Self::ValueB => tw.u16(tag_type, 1u16), - Self::ValueC => tw.u16(tag_type, 100u16), - Self::ValueD => tw.u16(tag_type, 4660u16), + Self::ValueA => tw.u16(tag, 0u16), + Self::ValueB => tw.u16(tag, 1u16), + Self::ValueC => tw.u16(tag, 100u16), + Self::ValueD => tw.u16(tag, 4660u16), } - })() { - tw.rewind_to(anchor); - Err(err) - } else { - Ok(()) - } + })() { + tw.rewind_to(anchor); + Err(err) + } else { + Ok(()) + } + } + + fn tlv_iter( + &self, + tag: rs_matter_maybe_renamed::tlv::TLVTag, + ) -> impl Iterator> { + match self { + Self::ValueA => rs_matter_maybe_renamed::tlv::TLV::u16(tag, 0u16).into_tlv_iter(), + Self::ValueB => rs_matter_maybe_renamed::tlv::TLV::u16(tag, 1u16).into_tlv_iter(), + Self::ValueC => rs_matter_maybe_renamed::tlv::TLV::u16(tag, 100u16).into_tlv_iter(), + Self::ValueD => rs_matter_maybe_renamed::tlv::TLV::u16(tag, 4660u16).into_tlv_iter(), + } } } ) @@ -892,34 +1172,37 @@ mod tests { &derive_fromtlv(ast, "rs_matter_maybe_renamed".to_string()), "e!( impl rs_matter_maybe_renamed::tlv::FromTLV<'_> for TestEnum { - fn from_tlv( - t: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, - ) -> Result { - let mut t_iter = t.confirm_struct()?.enter().ok_or_else(|| { - rs_matter_maybe_renamed::error::Error::new( - rs_matter_maybe_renamed::error::ErrorCode::Invalid, - ) - })?; - let mut item = t_iter.next().ok_or_else(|| { - rs_matter_maybe_renamed::error::Error::new( - rs_matter_maybe_renamed::error::ErrorCode::Invalid, - ) - })?; - if let TagType::Context(tag) = item.get_tag() { - match tag { - 0u8 => Ok(Self::ValueA(u32::from_tlv(&item)?)), - 1u8 => Ok(Self::ValueB(u32::from_tlv(&item)?)), - _ => Err(rs_matter_maybe_renamed::error::Error::new( - rs_matter_maybe_renamed::error::ErrorCode::Invalid, - )), - } - } else { - Err(rs_matter_maybe_renamed::error::Error::new( - rs_matter_maybe_renamed::error::ErrorCode::TLVTypeMismatch, - )) - } - } - } + fn from_tlv( + element: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, + ) -> Result { + let element = element + .r#struct()? + .iter() + .next() + .ok_or(rs_matter_maybe_renamed::error::ErrorCode::TLVTypeMismatch)??; + + let tag = element + .try_ctx()? + .ok_or(rs_matter_maybe_renamed::error::ErrorCode::TLVTypeMismatch)?; + + Ok(match tag { + 0u8 => Self::ValueA(u32::from_tlv(&element)?), + 1u8 => Self::ValueB(u32::from_tlv(&element)?), + _ => Err(rs_matter_maybe_renamed::error::ErrorCode::Invalid)?, + }) + } + } + + impl TryFrom<&rs_matter_maybe_renamed::tlv::TLVElement<'_>> for TestEnum { + type Error = rs_matter_maybe_renamed::error::Error; + + fn try_from( + element: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, + ) -> Result { + use rs_matter_maybe_renamed::tlv::FromTLV; + Self::from_tlv(element) + } + } ) ); } @@ -939,16 +1222,26 @@ mod tests { "e!( impl rs_matter_maybe_renamed::tlv::FromTLV<'_> for TestEnum { fn from_tlv( - t: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, - ) -> Result { - Ok(match t.u8()? { - 0u8 => Self::ValueA, - 1u8 => Self::ValueB, - _ => return Err(rs_matter_maybe_renamed::error::Error::new( - rs_matter_maybe_renamed::error::ErrorCode::Invalid)), - }) - } - } + element: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, + ) -> Result { + Ok(match element.u8()? { + 0u8 => Self::ValueA, + 1u8 => Self::ValueB, + _ => Err(rs_matter_maybe_renamed::error::ErrorCode::Invalid)?, + }) + } + } + + impl TryFrom<&rs_matter_maybe_renamed::tlv::TLVElement<'_>> for TestEnum { + type Error = rs_matter_maybe_renamed::error::Error; + + fn try_from( + element: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, + ) -> Result { + use rs_matter_maybe_renamed::tlv::FromTLV; + Self::from_tlv(element) + } + } ) ); } @@ -972,19 +1265,29 @@ mod tests { &derive_fromtlv(ast, "rs_matter_maybe_renamed".to_string()), "e!( impl rs_matter_maybe_renamed::tlv::FromTLV<'_> for TestEnum { - fn from_tlv( - t: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, - ) -> Result { - Ok(match t.u16()? { - 0u16 => Self::A, - 1u16 => Self::B, - 100u16 => Self::C, - 4660u16 => Self::D, - _ => return Err(rs_matter_maybe_renamed::error::Error::new( - rs_matter_maybe_renamed::error::ErrorCode::Invalid)), + fn from_tlv( + element: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, + ) -> Result { + Ok(match element.u16()? { + 0u16 => Self::A, + 1u16 => Self::B, + 100u16 => Self::C, + 4660u16 => Self::D, + _ => Err(rs_matter_maybe_renamed::error::ErrorCode::Invalid)?, }) - } - } + } + } + + impl TryFrom<&rs_matter_maybe_renamed::tlv::TLVElement<'_>> for TestEnum { + type Error = rs_matter_maybe_renamed::error::Error; + + fn try_from( + element: &rs_matter_maybe_renamed::tlv::TLVElement<'_>, + ) -> Result { + use rs_matter_maybe_renamed::tlv::FromTLV; + Self::from_tlv(element) + } + } ) ); } diff --git a/rs-matter/Cargo.toml b/rs-matter/Cargo.toml index 61548f9d..98ee8976 100644 --- a/rs-matter/Cargo.toml +++ b/rs-matter/Cargo.toml @@ -12,7 +12,8 @@ license = "Apache-2.0" rust-version = "1.78" [features] -default = ["os", "mbedtls"] +default = ["os", "rustcrypto"] +#default = ["os", "mbedtls"] os = ["std", "backtrace", "critical-section/std", "embassy-sync/std", "embassy-time/std", "embassy-time/generic-queue"] std = ["alloc", "rand"] backtrace = [] @@ -85,6 +86,7 @@ nix = { version = "0.27", features = ["net"] } futures-lite = "2" async-channel = "2" static_cell = "2" +similar = "2.6" [[example]] name = "onoff_light" diff --git a/rs-matter/src/acl.rs b/rs-matter/src/acl.rs index 5289b157..ad274102 100644 --- a/rs-matter/src/acl.rs +++ b/rs-matter/src/acl.rs @@ -25,7 +25,9 @@ use crate::data_model::objects::{Access, ClusterId, EndptId, Privilege}; use crate::error::{Error, ErrorCode}; use crate::fabric; use crate::interaction_model::messages::GenericPath; -use crate::tlv::{self, FromTLV, Nullable, TLVElement, TLVList, TLVWriter, TagType, ToTLV}; +use crate::tlv::{ + EitherIter, FromTLV, Nullable, TLVElement, TLVTag, TLVWrite, TLVWriter, ToTLV, TLV, +}; use crate::transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}; use crate::utils::cell::RefCell; use crate::utils::init::{init, Init}; @@ -57,16 +59,19 @@ impl FromTLV<'_> for AuthMode { } impl ToTLV for AuthMode { - fn to_tlv( - &self, - tw: &mut crate::tlv::TLVWriter, - tag: crate::tlv::TagType, - ) -> Result<(), Error> { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { match self { AuthMode::Invalid => Ok(()), _ => tw.u8(tag, *self as u8), } } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + match self { + AuthMode::Invalid => EitherIter::First(core::iter::empty()), + _ => EitherIter::Second(TLV::u8(tag, *self as u8).into_tlv_iter()), + } + } } /// An accessor can have as many identities: one node id and Upto MAX_CAT_IDS_PER_NOC @@ -291,7 +296,7 @@ type Targets = Nullable<[Option; TARGETS_PER_ENTRY]>; impl Targets { fn init_notnull() -> Self { const INIT_TARGETS: Option = None; - Nullable::NotNull([INIT_TARGETS; TARGETS_PER_ENTRY]) + Nullable::some([INIT_TARGETS; TARGETS_PER_ENTRY]) } } @@ -336,19 +341,18 @@ impl AclEntry { } pub fn add_target(&mut self, target: Target) -> Result<(), Error> { - if self.targets.is_null() { + if self.targets.is_none() { self.targets = Targets::init_notnull(); } let index = self .targets .as_ref() - .notnull() .unwrap() .iter() .position(|s| s.is_none()) .ok_or(ErrorCode::NoSpace)?; - self.targets.as_mut().notnull().unwrap()[index] = Some(target); + self.targets.as_mut().unwrap()[index] = Some(target); Ok(()) } @@ -382,7 +386,7 @@ impl AclEntry { fn match_access_desc(&self, object: &AccessDesc) -> bool { let mut allow = false; let mut entries_exist = false; - match self.targets.as_ref().notnull() { + match self.targets.as_ref() { None => allow = true, // Allow if targets are NULL Some(targets) => { for t in targets.iter().flatten() { @@ -583,9 +587,18 @@ impl AclMgr { } pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { - let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; + let entries = TLVElement::new(data).array()?.iter(); + + self.entries.clear(); + + for entry in entries { + let entry = entry?; + + self.entries + .push(Option::::from_tlv(&entry)?) + .map_err(|_| ErrorCode::NoSpace)?; + } - tlv::vec_from_tlv(&mut self.entries, &root)?; self.changed = false; Ok(()) @@ -597,7 +610,7 @@ impl AclMgr { let mut tw = TLVWriter::new(&mut wb); self.entries .as_slice() - .to_tlv(&mut tw, TagType::Anonymous)?; + .to_tlv(&TLVTag::Anonymous, &mut tw)?; self.changed = false; diff --git a/rs-matter/src/cert/mod.rs b/rs-matter/src/cert/mod.rs index fb23f1a2..78cee4a3 100644 --- a/rs-matter/src/cert/mod.rs +++ b/rs-matter/src/cert/mod.rs @@ -17,18 +17,24 @@ use core::fmt::{self, Write}; -use crate::{ - crypto::KeyPair, - error::{Error, ErrorCode}, - tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, - utils::{epoch::MATTER_CERT_DOESNT_EXPIRE, storage::WriteBuf}, -}; use log::error; + +use num::FromPrimitive; use num_derive::FromPrimitive; -pub use self::asn1_writer::ASN1Writer; +use crate::crypto::KeyPair; +use crate::error::{Error, ErrorCode}; +use crate::tlv::{FromTLV, Octets, TLVArray, TLVElement, TLVList, ToTLV}; +use crate::utils::epoch::MATTER_CERT_DOESNT_EXPIRE; +use crate::utils::iter::TryFindIterator; + use self::printer::CertPrinter; +pub use self::asn1_writer::ASN1Writer; + +mod asn1_writer; +mod printer; + // As per section 6.1.3 "Certificate Sizes" of the Matter 1.1 spec pub const MAX_CERT_TLV_LEN: usize = 400; @@ -38,8 +44,10 @@ const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01] const OID_EC_TYPE_PRIME256V1: [u8; 8] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x03, 0x01, 0x07]; const OID_ECDSA_WITH_SHA256: [u8; 8] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x02]; +const MAX_DEPTH: usize = 10; + #[derive(FromPrimitive)] -pub enum CertTags { +pub enum CertTag { SerialNum = 1, SignAlgo = 2, Issuer = 3, @@ -80,103 +88,7 @@ pub fn get_sign_algo(algo: u8) -> Option { num::FromPrimitive::from_u8(algo) } -const KEY_USAGE_DIGITAL_SIGN: u16 = 0x0001; -const KEY_USAGE_NON_REPUDIATION: u16 = 0x0002; -const KEY_USAGE_KEY_ENCIPHERMENT: u16 = 0x0004; -const KEY_USAGE_DATA_ENCIPHERMENT: u16 = 0x0008; -const KEY_USAGE_KEY_AGREEMENT: u16 = 0x0010; -const KEY_USAGE_KEY_CERT_SIGN: u16 = 0x0020; -const KEY_USAGE_CRL_SIGN: u16 = 0x0040; -const KEY_USAGE_ENCIPHER_ONLY: u16 = 0x0080; -const KEY_USAGE_DECIPHER_ONLY: u16 = 0x0100; - -fn reverse_byte(byte: u8) -> u8 { - const LOOKUP: [u8; 16] = [ - 0x00, 0x08, 0x04, 0x0c, 0x02, 0x0a, 0x06, 0x0e, 0x01, 0x09, 0x05, 0x0d, 0x03, 0x0b, 0x07, - 0x0f, - ]; - (LOOKUP[(byte & 0x0f) as usize] << 4) | LOOKUP[(byte >> 4) as usize] -} - -fn int_to_bitstring(mut a: u16, buf: &mut [u8]) { - if buf.len() >= 2 { - buf[0] = reverse_byte((a & 0xff) as u8); - a >>= 8; - buf[1] = reverse_byte((a & 0xff) as u8); - } -} - -macro_rules! add_if { - ($key:ident, $bit:ident,$str:literal) => { - if ($key & $bit) != 0 { - $str - } else { - "" - } - }; -} - -fn get_print_str(key_usage: u16) -> heapless::String<256> { - let mut string = heapless::String::new(); - write!( - &mut string, - "{}{}{}{}{}{}{}{}{}", - add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "), - add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "), - add_if!(key_usage, KEY_USAGE_KEY_ENCIPHERMENT, "keyEncipherment "), - add_if!(key_usage, KEY_USAGE_DATA_ENCIPHERMENT, "dataEncipherment "), - add_if!(key_usage, KEY_USAGE_KEY_AGREEMENT, "keyAgreement "), - add_if!(key_usage, KEY_USAGE_KEY_CERT_SIGN, "keyCertSign "), - add_if!(key_usage, KEY_USAGE_CRL_SIGN, "CRLSign "), - add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "), - add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "), - ) - .unwrap(); - - string -} - -#[allow(unused_assignments)] -fn encode_key_usage(key_usage: u16, w: &mut dyn CertConsumer) -> Result<(), Error> { - let mut key_usage_str = [0u8; 2]; - int_to_bitstring(key_usage, &mut key_usage_str); - w.bitstr(&get_print_str(key_usage), true, &key_usage_str)?; - Ok(()) -} - -fn encode_extended_key_usage( - list: impl Iterator, - w: &mut dyn CertConsumer, -) -> Result<(), Error> { - const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01]; - const OID_CLIENT_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02]; - const OID_CODE_SIGN: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x03]; - const OID_EMAIL_PROT: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x04]; - const OID_TIMESTAMP: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x08]; - const OID_OCSP_SIGN: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x09]; - let encoding = [ - ("", &[0; 8]), - ("ServerAuth", &OID_SERVER_AUTH), - ("ClientAuth", &OID_CLIENT_AUTH), - ("CodeSign", &OID_CODE_SIGN), - ("EmailProtection", &OID_EMAIL_PROT), - ("Timestamp", &OID_TIMESTAMP), - ("OCSPSign", &OID_OCSP_SIGN), - ]; - - w.start_seq("")?; - for t in list { - let t = t as usize; - if t > 0 && t <= encoding.len() { - w.oid(encoding[t].0, encoding[t].1)?; - } else { - error!("Skipping encoding key usage out of bounds"); - } - } - w.end_seq() -} - -#[derive(FromTLV, ToTLV, Default, Debug, PartialEq)] +#[derive(Default, Debug, Clone, FromTLV, ToTLV, PartialEq, Eq, Hash)] #[tlvargs(start = 1)] struct BasicConstraints { is_ca: bool, @@ -197,39 +109,38 @@ impl BasicConstraints { } } -fn encode_extension_start( - tag: &str, - critical: bool, - oid: &[u8], - w: &mut dyn CertConsumer, -) -> Result<(), Error> { - w.start_seq(tag)?; - w.oid("", oid)?; - if critical { - w.bool("critical:", true)?; - } - w.start_compound_ostr("value:") -} - -fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> { - w.end_compound_ostr()?; - w.end_seq() +// #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +// #[repr(u8)] +// enum ExtTag { +// BasicConstraints = 1, +// KeyUsage = 2, +// ExtKeyUsage = 3, +// SubjectKeyId = 4, +// AuthorityKeyId = 5, +// FutureExtensions = 6, +// } + +#[derive(Debug, Clone, FromTLV, ToTLV, PartialEq, Eq, Hash)] +#[tlvargs(start = 1, lifetime = "'a", datatype = "naked", unordered)] +enum Extension<'a> { + BasicConstraints(BasicConstraints), + KeyUsage(u16), + ExtKeyUsage(TLVArray<'a, u8>), + SubjectKeyId(Octets<'a>), + AuthorityKeyId(Octets<'a>), + FutureExtensions(Octets<'a>), } -const MAX_EXTENSION_ENTRIES: usize = 6; - -// The order in which the extensions arrive is important, as the signing -// requires that the ASN1 notation retain the same order -#[derive(Default, Debug, PartialEq)] -struct Extensions<'a>(heapless::Vec, MAX_EXTENSION_ENTRIES>); - -impl<'a> Extensions<'a> { - fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { +impl<'a> Extension<'a> { + fn encode_all( + iter: impl Iterator> + 'a, + w: &mut dyn CertConsumer, + ) -> Result<(), Error> { w.start_ctx("X509v3 extensions:", 3)?; w.start_seq("")?; - for extension in &self.0 { - extension.encode(w)?; + for extension in iter { + extension?.encode(w)?; } w.end_seq()?; @@ -237,71 +148,7 @@ impl<'a> Extensions<'a> { Ok(()) } -} - -impl<'a> FromTLV<'a> for Extensions<'a> { - fn from_tlv(t: &TLVElement<'a>) -> Result { - let tlv_iter = t - .confirm_list()? - .enter() - .ok_or_else(|| Error::new(ErrorCode::Invalid))?; - - let mut extensions = heapless::Vec::new(); - - for item in tlv_iter { - let TagType::Context(tag) = item.get_tag() else { - return Err(ErrorCode::Invalid.into()); - }; - let extension = match tag { - 1 => Extension::BasicConstraints(BasicConstraints::from_tlv(&item)?), - 2 => Extension::KeyUsage(item.u16()?), - 3 => Extension::ExtKeyUsage(TLVArray::from_tlv(&item)?), - 4 => Extension::SubjectKeyId(OctetStr::from_tlv(&item)?), - 5 => Extension::AuthorityKeyId(OctetStr::from_tlv(&item)?), - 6 => Extension::FutureExtensions(OctetStr::from_tlv(&item)?), - _ => Err(ErrorCode::Invalid)?, - }; - - extensions - .push(extension) - .map_err(|_| Error::new(ErrorCode::NoSpace))?; - } - - Ok(Self(extensions)) - } -} - -impl<'a> ToTLV for Extensions<'a> { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.start_list(tag)?; - - for extension in &self.0 { - match extension { - Extension::BasicConstraints(t) => t.to_tlv(tw, TagType::Context(1))?, - Extension::KeyUsage(t) => tw.u16(TagType::Context(2), *t)?, - Extension::ExtKeyUsage(t) => t.to_tlv(tw, TagType::Context(3))?, - Extension::SubjectKeyId(t) => t.to_tlv(tw, TagType::Context(4))?, - Extension::AuthorityKeyId(t) => t.to_tlv(tw, TagType::Context(5))?, - Extension::FutureExtensions(t) => t.to_tlv(tw, TagType::Context(6))?, - } - } - - tw.end_container() - } -} - -#[derive(Debug, PartialEq)] -enum Extension<'a> { - BasicConstraints(BasicConstraints), - KeyUsage(u16), - ExtKeyUsage(TLVArray<'a, u8>), - SubjectKeyId(OctetStr<'a>), - AuthorityKeyId(OctetStr<'a>), - FutureExtensions(OctetStr<'a>), -} - -impl<'a> Extension<'a> { fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13]; const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F]; @@ -311,36 +158,41 @@ impl<'a> Extension<'a> { match self { Extension::BasicConstraints(t) => { - encode_extension_start( + Self::encode_extension_start( "X509v3 Basic Constraints", true, &OID_BASIC_CONSTRAINTS, w, )?; t.encode(w)?; - encode_extension_end(w)?; + Self::encode_extension_end(w)?; } Extension::KeyUsage(t) => { - encode_extension_start("X509v3 Key Usage", true, &OID_KEY_USAGE, w)?; - encode_key_usage(*t, w)?; - encode_extension_end(w)?; + Self::encode_extension_start("X509v3 Key Usage", true, &OID_KEY_USAGE, w)?; + Self::encode_key_usage(*t, w)?; + Self::encode_extension_end(w)?; } Extension::ExtKeyUsage(t) => { - encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?; - encode_extended_key_usage(t.iter(), w)?; - encode_extension_end(w)?; + Self::encode_extension_start( + "X509v3 Extended Key Usage", + true, + &OID_EXT_KEY_USAGE, + w, + )?; + Self::encode_extended_key_usage(t.iter(), w)?; + Self::encode_extension_end(w)?; } Extension::SubjectKeyId(t) => { - encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?; + Self::encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?; w.ostr("", t.0)?; - encode_extension_end(w)?; + Self::encode_extension_end(w)?; } Extension::AuthorityKeyId(t) => { - encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?; + Self::encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?; w.start_seq("")?; w.ctx("", 0, t.0)?; w.end_seq()?; - encode_extension_end(w)?; + Self::encode_extension_end(w)?; } Extension::FutureExtensions(t) => { error!("Future Extensions Not Yet Supported: {:x?}", t.0) @@ -349,10 +201,126 @@ impl<'a> Extension<'a> { Ok(()) } + + fn encode_extension_start( + tag: &str, + critical: bool, + oid: &[u8], + w: &mut dyn CertConsumer, + ) -> Result<(), Error> { + w.start_seq(tag)?; + w.oid("", oid)?; + if critical { + w.bool("critical:", true)?; + } + w.start_compound_ostr("value:") + } + + fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> { + w.end_compound_ostr()?; + w.end_seq() + } + + #[allow(unused_assignments)] + fn encode_key_usage(key_usage: u16, w: &mut dyn CertConsumer) -> Result<(), Error> { + let mut key_usage_str = [0u8; 2]; + Self::int_to_bitstring(key_usage, &mut key_usage_str); + w.bitstr(&Self::get_print_str(key_usage), true, &key_usage_str)?; + Ok(()) + } + + fn encode_extended_key_usage( + list: impl Iterator>, + w: &mut dyn CertConsumer, + ) -> Result<(), Error> { + const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01]; + const OID_CLIENT_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02]; + const OID_CODE_SIGN: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x03]; + const OID_EMAIL_PROT: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x04]; + const OID_TIMESTAMP: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x08]; + const OID_OCSP_SIGN: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x09]; + let encoding = [ + ("", &[0; 8]), + ("ServerAuth", &OID_SERVER_AUTH), + ("ClientAuth", &OID_CLIENT_AUTH), + ("CodeSign", &OID_CODE_SIGN), + ("EmailProtection", &OID_EMAIL_PROT), + ("Timestamp", &OID_TIMESTAMP), + ("OCSPSign", &OID_OCSP_SIGN), + ]; + + w.start_seq("")?; + for t in list { + let t = t? as usize; + if t > 0 && t <= encoding.len() { + w.oid(encoding[t].0, encoding[t].1)?; + } else { + error!("Skipping encoding key usage out of bounds"); + } + } + w.end_seq() + } + + fn get_print_str(key_usage: u16) -> heapless::String<256> { + const KEY_USAGE_DIGITAL_SIGN: u16 = 0x0001; + const KEY_USAGE_NON_REPUDIATION: u16 = 0x0002; + const KEY_USAGE_KEY_ENCIPHERMENT: u16 = 0x0004; + const KEY_USAGE_DATA_ENCIPHERMENT: u16 = 0x0008; + const KEY_USAGE_KEY_AGREEMENT: u16 = 0x0010; + const KEY_USAGE_KEY_CERT_SIGN: u16 = 0x0020; + const KEY_USAGE_CRL_SIGN: u16 = 0x0040; + const KEY_USAGE_ENCIPHER_ONLY: u16 = 0x0080; + const KEY_USAGE_DECIPHER_ONLY: u16 = 0x0100; + + macro_rules! add_if { + ($key:ident, $bit:ident,$str:literal) => { + if ($key & $bit) != 0 { + $str + } else { + "" + } + }; + } + + let mut string = heapless::String::new(); + write!( + &mut string, + "{}{}{}{}{}{}{}{}{}", + add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "), + add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "), + add_if!(key_usage, KEY_USAGE_KEY_ENCIPHERMENT, "keyEncipherment "), + add_if!(key_usage, KEY_USAGE_DATA_ENCIPHERMENT, "dataEncipherment "), + add_if!(key_usage, KEY_USAGE_KEY_AGREEMENT, "keyAgreement "), + add_if!(key_usage, KEY_USAGE_KEY_CERT_SIGN, "keyCertSign "), + add_if!(key_usage, KEY_USAGE_CRL_SIGN, "CRLSign "), + add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "), + add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "), + ) + .unwrap(); + + string + } + + fn int_to_bitstring(mut a: u16, buf: &mut [u8]) { + if buf.len() >= 2 { + buf[0] = Self::reverse_byte((a & 0xff) as u8); + a >>= 8; + buf[1] = Self::reverse_byte((a & 0xff) as u8); + } + } + + fn reverse_byte(byte: u8) -> u8 { + const LOOKUP: [u8; 16] = [ + 0x00, 0x08, 0x04, 0x0c, 0x02, 0x0a, 0x06, 0x0e, 0x01, 0x09, 0x05, 0x0d, 0x03, 0x0b, + 0x07, 0x0f, + ]; + (LOOKUP[(byte & 0x0f) as usize] << 4) | LOOKUP[(byte >> 4) as usize] + } } -#[derive(FromPrimitive, Copy, Clone)] -enum DnTags { +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, FromPrimitive)] +#[repr(u8)] +enum DNTag { CommonName = 1, Surname = 2, SerialNum = 3, @@ -377,99 +345,51 @@ enum DnTags { NocCat = 22, } -#[derive(Debug, PartialEq)] -enum DistNameValue<'a> { +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum DNValue<'a> { Uint(u64), - Utf8Str(&'a [u8]), - PrintableStr(&'a [u8]), + Utf8(&'a str), + PrintableStr(&'a str), } -const MAX_DN_ENTRIES: usize = 5; +#[derive(FromTLV, ToTLV, Debug, Clone, PartialEq, Eq, Hash)] +#[tlvargs(lifetime = "'a")] +struct DN<'a>(TLVElement<'a>); -#[derive(Default, Debug, PartialEq)] -struct DistNames<'a> { - // The order in which the DNs arrive is important, as the signing - // requires that the ASN1 notation retains the same order - dn: heapless::Vec<(u8, DistNameValue<'a>), MAX_DN_ENTRIES>, -} +impl<'a> DN<'a> { + pub fn tag(&self) -> Result { + let ctx = self.0.try_ctx()?.ok_or(ErrorCode::Invalid)? & 0x7f; -impl<'a> DistNames<'a> { - fn u64(&self, match_id: DnTags) -> Option { - self.dn - .iter() - .find(|(id, _)| *id == match_id as u8) - .and_then(|(_, value)| { - if let DistNameValue::Uint(u) = *value { - Some(u) - } else { - None - } - }) + Ok(DNTag::from_u8(ctx).ok_or(ErrorCode::Invalid)?) } - fn u32_arr(&self, match_id: DnTags, output: &mut [u32]) { - let mut out_index = 0; - for (_, val) in self.dn.iter().filter(|(id, _)| *id == match_id as u8) { - if let DistNameValue::Uint(a) = val { - if out_index < output.len() { - // CatIds are actually just 32-bit - output[out_index] = *a as u32; - out_index += 1; - } - } - } - } -} + pub fn is_printable(&self) -> Result { + let ctx = self.0.try_ctx()?.ok_or(ErrorCode::Invalid)?; -const PRINTABLE_STR_THRESHOLD: u8 = 0x80; + Ok(ctx >= 0x80) + } -impl<'a> FromTLV<'a> for DistNames<'a> { - fn from_tlv(t: &TLVElement<'a>) -> Result { - let mut d = Self { - dn: heapless::Vec::new(), - }; - let iter = t.confirm_list()?.enter().ok_or(ErrorCode::Invalid)?; - for t in iter { - if let TagType::Context(tag) = t.get_tag() { - if let Ok(value) = t.u64() { - d.dn.push((tag, DistNameValue::Uint(value))) - .map_err(|_| ErrorCode::BufferTooSmall)?; - } else if let Ok(value) = t.slice() { - if tag > PRINTABLE_STR_THRESHOLD { - d.dn.push(( - tag - PRINTABLE_STR_THRESHOLD, - DistNameValue::PrintableStr(value), - )) - .map_err(|_| ErrorCode::BufferTooSmall)?; - } else { - d.dn.push((tag, DistNameValue::Utf8Str(value))) - .map_err(|_| ErrorCode::BufferTooSmall)?; - } - } - } - } - Ok(d) + fn uint(&self) -> Result { + self.0.u64() } -} -impl<'a> ToTLV for DistNames<'a> { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.start_list(tag)?; - for (name, value) in &self.dn { - match value { - DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?, - DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v)?, - DistNameValue::PrintableStr(v) => { - tw.utf8(TagType::Context(*name + PRINTABLE_STR_THRESHOLD), v)? - } + fn value(&self) -> Result, Error> { + if let Ok(value) = self.0.utf8() { + if self.is_printable()? { + Ok(DNValue::PrintableStr(value)) + } else { + Ok(DNValue::Utf8(value)) } + } else { + self.0.u64().map(DNValue::Uint) } - tw.end_container() } -} -impl<'a> DistNames<'a> { - fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> { + fn encode_all( + values: impl Iterator> + 'a, + tag: &str, + w: &mut dyn CertConsumer, + ) -> Result<(), Error> { const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03]; const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04]; const OID_SERIAL_NUMBER: [u8; 3] = [0x55_u8, 0x04, 0x05]; @@ -557,25 +477,66 @@ impl<'a> DistNames<'a> { ]; w.start_seq(tag)?; - for (id, value) in &self.dn { - let tag: Option = num::FromPrimitive::from_u8(*id); - if tag.is_some() { - let index = (id - 1) as usize; + for dn in values { + let dn = dn?; + let tag = dn.tag(); + + if let Ok(tag) = &tag { + let index = *tag as usize - 1; if index <= DN_ENCODING.len() { let this = &DN_ENCODING[index]; - encode_dn_value(value, this.0, this.1, w, this.2)?; + dn.encode(this.0, this.1, w, this.2)?; } else { // Non Matter DNs are encoded as - error!("Invalid DN, too high {}", id); + error!("Invalid DN, too high {:?}", tag); } } else { // Non Matter DNs are encoded as - error!("Non Matter DNs are not yet supported {}", id); + error!("Non Matter DNs are not yet supported {:?}", tag); } } w.end_seq()?; Ok(()) } + + fn encode( + &self, + name: &str, + oid: &[u8], + w: &mut dyn CertConsumer, + // Only applicable for integer values + expected_len: Option, + ) -> Result<(), Error> { + w.start_set("")?; + w.start_seq("")?; + w.oid(name, oid)?; + match self.value()? { + DNValue::Uint(v) => match expected_len { + Some(IntToStringLen::Len16) => { + let mut string = heapless::String::<32>::new(); + write!(&mut string, "{:016X}", v).unwrap(); + w.utf8str("", &string)? + } + Some(IntToStringLen::Len8) => { + let mut string = heapless::String::<32>::new(); + write!(&mut string, "{:08X}", v).unwrap(); + w.utf8str("", &string)? + } + _ => { + error!("Invalid encoding"); + Err(ErrorCode::Invalid)? + } + }, + DNValue::Utf8(v) => { + w.utf8str("", v)?; + } + DNValue::PrintableStr(v) => { + w.printstr("", v)?; + } + } + w.end_seq()?; + w.end_set() + } } #[derive(Copy, Clone)] @@ -585,131 +546,124 @@ enum IntToStringLen { Len8, } -fn encode_dn_value( - value: &DistNameValue, - name: &str, - oid: &[u8], - w: &mut dyn CertConsumer, - // Only applicable for integer values - expected_len: Option, -) -> Result<(), Error> { - w.start_set("")?; - w.start_seq("")?; - w.oid(name, oid)?; - match value { - DistNameValue::Uint(v) => match expected_len { - Some(IntToStringLen::Len16) => { - let mut string = heapless::String::<32>::new(); - write!(&mut string, "{:016X}", v).unwrap(); - w.utf8str("", &string)? - } - Some(IntToStringLen::Len8) => { - let mut string = heapless::String::<32>::new(); - write!(&mut string, "{:08X}", v).unwrap(); - w.utf8str("", &string)? - } - _ => { - error!("Invalid encoding"); - Err(ErrorCode::Invalid)? - } - }, - DistNameValue::Utf8Str(v) => { - w.utf8str("", core::str::from_utf8(v)?)?; - } - DistNameValue::PrintableStr(v) => { - w.printstr("", core::str::from_utf8(v)?)?; - } +#[derive(FromTLV, ToTLV, Debug, Clone, PartialEq, Eq, Hash)] +#[tlvargs(lifetime = "'a")] +pub struct CertRef<'a>(TLVElement<'a>); + +impl<'a> CertRef<'a> { + pub const fn new(tlv: TLVElement<'a>) -> Self { + Self(tlv) } - w.end_seq()?; - w.end_set() -} -#[derive(FromTLV, ToTLV, Default, Debug, PartialEq)] -#[tlvargs(lifetime = "'a", start = 1)] -pub struct Cert<'a> { - serial_no: OctetStr<'a>, - sign_algo: u8, - issuer: DistNames<'a>, - not_before: u32, - not_after: u32, - subject: DistNames<'a>, - pubkey_algo: u8, - ec_curve_id: u8, - pubkey: OctetStr<'a>, - extensions: Extensions<'a>, - signature: OctetStr<'a>, -} + fn serial_no(&self) -> Result<&[u8], Error> { + self.0.structure()?.find_ctx(1)?.str() + } -// TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding -// rules in terms of sequence may get complicated. Need to look into this -impl<'a> Cert<'a> { - pub fn new(cert_bin: &'a [u8]) -> Result { - let root = tlv::get_root_node(cert_bin)?; - Cert::from_tlv(&root) + fn sign_algo(&self) -> Result { + self.0.structure()?.find_ctx(2)?.u8() } - pub fn get_node_id(&self) -> Result { - self.subject - .u64(DnTags::NodeId) - .ok_or_else(|| Error::from(ErrorCode::NoNodeId)) + fn issuer(&self) -> Result>, Error> { + TLVList::new(self.0.structure()?.find_ctx(3)?) } - pub fn get_cat_ids(&self, output: &mut [u32]) { - self.subject.u32_arr(DnTags::NocCat, output) + fn not_before(&self) -> Result { + self.0.structure()?.find_ctx(4)?.u32() } - pub fn get_fabric_id(&self) -> Result { - self.subject - .u64(DnTags::FabricId) - .ok_or_else(|| Error::from(ErrorCode::NoFabricId)) + fn not_after(&self) -> Result { + self.0.structure()?.find_ctx(5)?.u32() } - pub fn get_pubkey(&self) -> &[u8] { - self.pubkey.0 + fn subject(&self) -> Result>, Error> { + TLVList::new(self.0.structure()?.find_ctx(6)?) } - pub fn get_subject_key_id(&self) -> Result<&[u8], Error> { - self.extensions - .0 - .iter() - .find_map(|extension| { - if let Extension::SubjectKeyId(id) = extension { - Some(id.0) - } else { - None - } - }) - .ok_or_else(|| Error::from(ErrorCode::Invalid)) + fn pubkey_algo(&self) -> Result { + self.0.structure()?.find_ctx(7)?.u8() } - pub fn is_authority(&self, their: &Cert) -> Result { - let their_subject = their.get_subject_key_id()?; + fn ec_curve_id(&self) -> Result { + self.0.structure()?.find_ctx(8)?.u8() + } + + pub fn pubkey(&self) -> Result<&[u8], Error> { + self.0.structure()?.find_ctx(9)?.str() + } + + fn extensions(&self) -> Result>, Error> { + TLVList::new(self.0.structure()?.find_ctx(10)?) + } - let authority = self - .extensions - .0 + fn signature(&self) -> Result<&[u8], Error> { + self.0.structure()?.find_ctx(11)?.str() + } + + pub fn get_node_id(&self) -> Result { + let dn = self + .subject()? .iter() - .find_map(|extension| { - if let Extension::AuthorityKeyId(id) = extension { - Some(id.0 == their_subject) - } else { - None + .do_try_find(|dn| Ok(dn.tag()? == DNTag::NodeId))? + .ok_or(ErrorCode::NoNodeId)?; + + dn.uint() + } + + pub fn get_cat_ids(&self, output: &mut [u32]) -> Result<(), Error> { + let mut offset = 0; + + self.subject()?.iter().try_for_each(|dn| { + let dn = dn?; + + if dn.tag()? == DNTag::NocCat { + if offset == output.len() { + Err(ErrorCode::NoSpace)?; } - }) - .unwrap_or(false); - Ok(authority) + output[offset] = dn.uint()? as u32; + offset += 1; + } + + Ok(()) + }) } - pub fn get_signature(&self) -> &[u8] { - self.signature.0 + pub fn get_fabric_id(&self) -> Result { + let dn = self + .subject()? + .iter() + .do_try_find(|dn| Ok(dn.tag()? == DNTag::FabricId))? + .ok_or(ErrorCode::NoFabricId)?; + + dn.uint() } - pub fn as_tlv(&self, buf: &mut [u8]) -> Result { - let mut wb = WriteBuf::new(buf); - let mut tw = TLVWriter::new(&mut wb); - self.to_tlv(&mut tw, TagType::Anonymous)?; - Ok(wb.as_slice().len()) + fn get_subject_key_id(&self) -> Result<&[u8], Error> { + let extension = self + .extensions()? + .iter() + .do_try_find(|extension| Ok(matches!(extension, Extension::SubjectKeyId(_))))? + .ok_or(Error::new(ErrorCode::Invalid))?; + + let Extension::SubjectKeyId(id) = extension else { + unreachable!(); + }; + + Ok(id.0) + } + + fn is_authority(&self, their: &CertRef) -> Result { + let their_subject = their.get_subject_key_id()?; + + let authority = self.extensions()?.iter().do_try_find(|extension| { + Ok(if let Extension::AuthorityKeyId(id) = extension { + id.0 == their_subject + } else { + false + }) + })?; + + Ok(authority.is_some()) } pub fn as_asn1(&self, buf: &mut [u8]) -> Result { @@ -729,46 +683,47 @@ impl<'a> Cert<'a> { w.integer("", &[2])?; w.end_ctx()?; - w.integer("Serial Num:", self.serial_no.0)?; + w.integer("Serial Num:", self.serial_no()?)?; w.start_seq("Signature Algorithm:")?; - let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(ErrorCode::Invalid)? { + let (str, oid) = match get_sign_algo(self.sign_algo()?).ok_or(ErrorCode::Invalid)? { SignAlgoValue::ECDSAWithSHA256 => ("ECDSA with SHA256", OID_ECDSA_WITH_SHA256), }; w.oid(str, &oid)?; w.end_seq()?; - self.issuer.encode("Issuer:", w)?; + DN::encode_all(self.issuer()?.iter(), "Issuer:", w)?; w.start_seq("Validity:")?; - w.utctime("Not Before:", self.not_before.into())?; - if self.not_after == 0 { + w.utctime("Not Before:", self.not_before()?.into())?; + if self.not_after()? == 0 { // As per the spec a Not-After value of 0, indicates no well-defined // expiration date and should return in GeneralizedTime of 99991231235959Z w.utctime("Not After:", MATTER_CERT_DOESNT_EXPIRE)?; } else { - w.utctime("Not After:", self.not_after.into())?; + w.utctime("Not After:", self.not_after()?.into())?; } w.end_seq()?; - self.subject.encode("Subject:", w)?; + DN::encode_all(self.subject()?.iter(), "Subject:", w)?; w.start_seq("")?; w.start_seq("Public Key Algorithm")?; - let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(ErrorCode::Invalid)? { + let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo()?).ok_or(ErrorCode::Invalid)? { PubKeyAlgoValue::EcPubKey => ("ECPubKey", OID_PUB_KEY_ECPUBKEY), }; w.oid(str, &pub_key)?; - let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(ErrorCode::Invalid)? { - EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1), - }; + let (str, curve_id) = + match get_ec_curve_id(self.ec_curve_id()?).ok_or(ErrorCode::Invalid)? { + EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1), + }; w.oid(str, &curve_id)?; w.end_seq()?; - w.bitstr("Public-Key:", false, self.pubkey.0)?; + w.bitstr("Public-Key:", false, self.pubkey()?)?; w.end_seq()?; - self.extensions.encode(w)?; + Extension::encode_all(self.extensions()?.iter(), w)?; // We do not encode the Signature in the DER certificate @@ -776,36 +731,39 @@ impl<'a> Cert<'a> { } } -impl<'a> fmt::Display for Cert<'a> { +impl<'a> fmt::Display for CertRef<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut printer = CertPrinter::new(f); - let _ = self - .encode(&mut printer) - .map_err(|e| error!("Error decoding certificate: {}", e)); + + self.encode(&mut printer).map_err(|_| fmt::Error)?; + // Signature is not encoded by the Cert Decoder - writeln!(f, "Signature: {:x?}", self.get_signature()) + writeln!( + f, + "Signature: {:x?}", + self.signature().map_err(|_| fmt::Error)? + ) } } pub struct CertVerifier<'a> { - cert: &'a Cert<'a>, + cert: &'a CertRef<'a>, } impl<'a> CertVerifier<'a> { - pub fn new(cert: &'a Cert) -> Self { + pub fn new(cert: &'a CertRef<'a>) -> Self { Self { cert } } - pub fn add_cert(self, parent: &'a Cert) -> Result, Error> { + pub fn add_cert(self, parent: &'a CertRef<'a>, buf: &mut [u8]) -> Result { if !self.cert.is_authority(parent)? { Err(ErrorCode::InvalidAuthKey)?; } - let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE]; - let len = self.cert.as_asn1(&mut asn1)?; - let asn1 = &asn1[..len]; + let len = self.cert.as_asn1(buf)?; + let asn1 = &buf[..len]; - let k = KeyPair::new_from_public(parent.get_pubkey())?; - k.verify_msg(asn1, self.cert.get_signature()) + let k = KeyPair::new_from_public(parent.pubkey()?)?; + k.verify_msg(asn1, self.cert.signature()?) .inspect_err(|e| { error!( "Error {e} in signature verification of certificate: {:x?} by {:x?}", @@ -818,9 +776,10 @@ impl<'a> CertVerifier<'a> { Ok(CertVerifier::new(parent)) } - pub fn finalise(self) -> Result<(), Error> { + pub fn finalise(self, buf: &mut [u8]) -> Result<(), Error> { let cert = self.cert; - self.add_cert(cert)?; + self.add_cert(cert, buf)?; + Ok(()) } } @@ -845,39 +804,34 @@ pub trait CertConsumer { fn utctime(&mut self, tag: &str, epoch: u64) -> Result<(), Error>; } -const MAX_DEPTH: usize = 10; -const MAX_ASN1_CERT_SIZE: usize = 1000; - -mod asn1_writer; -mod printer; - #[cfg(test)] mod tests { use log::info; - use crate::cert::Cert; - use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; + use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; use crate::utils::storage::WriteBuf; + use super::CertRef; + #[test] fn test_asn1_encode_success() { { let mut asn1_buf = [0u8; 1000]; - let c = Cert::new(&test_vectors::CHIP_CERT_INPUT1).unwrap(); + let c = CertRef::new(TLVElement::new(&test_vectors::CHIP_CERT_INPUT1)); let len = c.as_asn1(&mut asn1_buf).unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT1, &asn1_buf[..len]); } { let mut asn1_buf = [0u8; 1000]; - let c = Cert::new(&test_vectors::CHIP_CERT_INPUT2).unwrap(); + let c = CertRef::new(TLVElement::new(&test_vectors::CHIP_CERT_INPUT2)); let len = c.as_asn1(&mut asn1_buf).unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT2, &asn1_buf[..len]); } { let mut asn1_buf = [0u8; 1000]; - let c = Cert::new(&test_vectors::CHIP_CERT_TXT_IN_DN).unwrap(); + let c = CertRef::new(TLVElement::new(&test_vectors::CHIP_CERT_TXT_IN_DN)); let len = c.as_asn1(&mut asn1_buf).unwrap(); assert_eq!(&test_vectors::ASN1_OUTPUT_TXT_IN_DN, &asn1_buf[..len]); } @@ -885,15 +839,16 @@ mod tests { #[test] fn test_verify_chain_success() { - let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); - let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); - let rca = Cert::new(&test_vectors::RCA1_SUCCESS).unwrap(); + let mut buf = [0; 1000]; + let noc = CertRef::new(TLVElement::new(&test_vectors::NOC1_SUCCESS)); + let icac = CertRef::new(TLVElement::new(&test_vectors::ICAC1_SUCCESS)); + let rca = CertRef::new(TLVElement::new(&test_vectors::RCA1_SUCCESS)); let a = noc.verify_chain_start(); - a.add_cert(&icac) + a.add_cert(&icac, &mut buf) .unwrap() - .add_cert(&rca) + .add_cert(&rca, &mut buf) .unwrap() - .finalise() + .finalise(&mut buf) .unwrap(); } @@ -902,12 +857,16 @@ mod tests { // The chain doesn't lead up to a self-signed certificate use crate::error::ErrorCode; - let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); - let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); + let mut buf = [0; 1000]; + let noc = CertRef::new(TLVElement::new(&test_vectors::NOC1_SUCCESS)); + let icac = CertRef::new(TLVElement::new(&test_vectors::ICAC1_SUCCESS)); let a = noc.verify_chain_start(); assert_eq!( Err(ErrorCode::InvalidAuthKey), - a.add_cert(&icac).unwrap().finalise().map_err(|e| e.code()) + a.add_cert(&icac, &mut buf) + .unwrap() + .finalise(&mut buf) + .map_err(|e| e.code()) ); } @@ -915,35 +874,42 @@ mod tests { fn test_auth_key_chain_incorrect() { use crate::error::ErrorCode; - let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap(); - let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); + let mut buf = [0; 1000]; + let noc = CertRef::new(TLVElement::new(&test_vectors::NOC1_AUTH_KEY_FAIL)); + let icac = CertRef::new(TLVElement::new(&test_vectors::ICAC1_SUCCESS)); let a = noc.verify_chain_start(); assert_eq!( Err(ErrorCode::InvalidAuthKey), - a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) + a.add_cert(&icac, &mut buf) + .map(|_| ()) + .map_err(|e| e.code()) ); } #[test] fn test_zero_value_of_not_after_field() { - let noc = Cert::new(&test_vectors::NOC_NOT_AFTER_ZERO).unwrap(); - let rca = Cert::new(&test_vectors::RCA_FOR_NOC_NOT_AFTER_ZERO).unwrap(); + let mut buf = [0; 1000]; + let noc = CertRef::new(TLVElement::new(&test_vectors::NOC_NOT_AFTER_ZERO)); + let rca = CertRef::new(TLVElement::new(&test_vectors::RCA_FOR_NOC_NOT_AFTER_ZERO)); let v = noc.verify_chain_start(); - let v = v.add_cert(&rca).unwrap(); - v.finalise().unwrap(); + let v = v.add_cert(&rca, &mut buf).unwrap(); + v.finalise(&mut buf).unwrap(); } #[test] fn test_cert_corrupted() { use crate::error::ErrorCode; - let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap(); - let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); + let mut buf = [0; 1000]; + let noc = CertRef::new(TLVElement::new(&test_vectors::NOC1_CORRUPT_CERT)); + let icac = CertRef::new(TLVElement::new(&test_vectors::ICAC1_SUCCESS)); let a = noc.verify_chain_start(); assert_eq!( Err(ErrorCode::InvalidSignature), - a.add_cert(&icac).map(|_| ()).map_err(|e| e.code()) + a.add_cert(&icac, &mut buf) + .map(|_| ()) + .map_err(|e| e.code()) ); } @@ -958,15 +924,15 @@ mod tests { for input in test_input.iter() { info!("Testing next input..."); - let root = tlv::get_root_node(input).unwrap(); - let cert = Cert::from_tlv(&root).unwrap(); + let root = TLVElement::new(input); + let cert = CertRef::from_tlv(&root).unwrap(); let mut buf = [0u8; 1024]; let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); - cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); + cert.to_tlv(&TagType::Anonymous, &mut tw).unwrap(); - let root2 = tlv::get_root_node(wb.as_slice()).unwrap(); - let cert2 = Cert::from_tlv(&root2).unwrap(); + let root2 = TLVElement::new(wb.as_slice()); + let cert2 = CertRef::from_tlv(&root2).unwrap(); assert_eq!(cert, cert2); } } @@ -975,16 +941,7 @@ mod tests { fn test_unordered_extensions() { let mut buf = [0; 1000]; - let cert = Cert::new(test_vectors::UNORDERED_EXTENSIONS_CHIP).unwrap(); - - let mut writer = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writer); - - cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!( - tw.get_buf().as_slice(), - test_vectors::UNORDERED_EXTENSIONS_CHIP - ); + let cert = CertRef::new(TLVElement::new(test_vectors::UNORDERED_EXTENSIONS_CHIP)); let asn1_len = cert.as_asn1(&mut buf).unwrap(); assert_eq!(&buf[..asn1_len], test_vectors::UNORDERED_EXTENSIONS_DER); diff --git a/rs-matter/src/crypto/crypto_rustcrypto.rs b/rs-matter/src/crypto/crypto_rustcrypto.rs index 19c288ea..5d29b9cf 100644 --- a/rs-matter/src/crypto/crypto_rustcrypto.rs +++ b/rs-matter/src/crypto/crypto_rustcrypto.rs @@ -16,6 +16,7 @@ */ use core::convert::{TryFrom, TryInto}; +use core::mem::MaybeUninit; use aes::Aes128; use alloc::vec; @@ -43,7 +44,7 @@ use x509_cert::{ use crate::{ error::{Error, ErrorCode}, secure_channel::crypto_rustcrypto::RandRngCore, - utils::rand::Rand, + utils::{init::InitMaybeUninit, rand::Rand}, }; type HmacSha256I = hmac::Hmac; @@ -182,8 +183,9 @@ impl KeyPair { .try_into() .unwrap(), )]); - let mut pubkey = [0; 65]; - self.get_public_key(&mut pubkey).unwrap(); + let mut pubkey = MaybeUninit::<[u8; 65]>::uninit(); // TODO MEDIUM BUFFER + let pubkey = pubkey.init_zeroed(); + self.get_public_key(pubkey).unwrap(); let info = x509_cert::request::CertReqInfo { version: x509_cert::request::Version::V1, subject, @@ -200,7 +202,7 @@ impl KeyPair { .unwrap(), ), }, - subject_public_key: BitString::from_bytes(&pubkey).unwrap(), + subject_public_key: BitString::from_bytes(&*pubkey).unwrap(), }, attributes: Default::default(), }; diff --git a/rs-matter/src/crypto/mod.rs b/rs-matter/src/crypto/mod.rs index 04584c8c..7ace2bbc 100644 --- a/rs-matter/src/crypto/mod.rs +++ b/rs-matter/src/crypto/mod.rs @@ -14,9 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +use core::mem::MaybeUninit; + use crate::{ error::{Error, ErrorCode}, - tlv::{FromTLV, TLVWriter, TagType, ToTLV}, + tlv::{FromTLV, TLVTag, TLVWrite, ToTLV, TLV}, + utils::init::InitMaybeUninit, }; pub const SYMM_KEY_LEN_BITS: usize = 128; @@ -67,33 +71,41 @@ impl<'a> FromTLV<'a> for KeyPair { where Self: Sized, { - t.confirm_array()?.enter(); + let array = t.array()?; + let mut iter = array.iter(); - if let Some(mut array) = t.enter() { - let pub_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?; - let priv_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?; + let pub_key = iter.next().ok_or(ErrorCode::Invalid)??.str()?; + let priv_key = iter.next().ok_or(ErrorCode::Invalid)??.str()?; - KeyPair::new_from_components(pub_key, priv_key) - } else { - Err(ErrorCode::Invalid.into()) - } + KeyPair::new_from_components(pub_key, priv_key) } } impl ToTLV for KeyPair { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - let mut buf = [0; 1024]; // TODO + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + let mut pubkey_buf = MaybeUninit::<[u8; EC_POINT_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let pubkey_buf = pubkey_buf.init_zeroed(); - tw.start_array(tag)?; + let mut privkey_buf = MaybeUninit::<[u8; BIGNUM_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let privkey_buf = privkey_buf.init_zeroed(); + + let pubkey_len = self.get_public_key(pubkey_buf)?; + let privkey_len = self.get_private_key(privkey_buf)?; - let size = self.get_public_key(&mut buf)?; - tw.str16(TagType::Anonymous, &buf[..size])?; + tw.start_array(tag)?; - let size = self.get_private_key(&mut buf)?; - tw.str16(TagType::Anonymous, &buf[..size])?; + tw.str(&TLVTag::Anonymous, &pubkey_buf[..pubkey_len])?; + tw.str(&TLVTag::Anonymous, &privkey_buf[..privkey_len])?; tw.end_container() } + + fn tlv_iter(&self, _tag: TLVTag) -> impl Iterator> { + unimplemented!("Not implemented for `KeyPair`"); + + #[allow(unreachable_code)] + core::iter::empty() + } } #[cfg(test)] diff --git a/rs-matter/src/data_model/core.rs b/rs-matter/src/data_model/core.rs index e17115f3..64a6f60b 100644 --- a/rs-matter/src/data_model/core.rs +++ b/rs-matter/src/data_model/core.rs @@ -33,11 +33,11 @@ use crate::interaction_model::core::{ IMStatusCode, OpCode, ReportDataReq, PROTO_ID_INTERACTION_MODEL, }; use crate::interaction_model::messages::msg::{ - InvReq, InvRespTag, ReadReq, ReportDataTag, StatusResp, SubscribeReq, SubscribeResp, TimedReq, - WriteReq, WriteRespTag, + InvReqRef, InvRespTag, ReadReqRef, ReportDataTag, StatusResp, SubscribeReqRef, SubscribeResp, + TimedReq, WriteReqRef, WriteRespTag, }; use crate::respond::ExchangeHandler; -use crate::tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType}; +use crate::tlv::{get_root_node_struct, FromTLV, TLVElement, TLVTag, TLVWrite, TLVWriter}; use crate::transport::exchange::{Exchange, MAX_EXCHANGE_RX_BUF_SIZE, MAX_EXCHANGE_TX_BUF_SIZE}; use crate::utils::storage::WriteBuf; @@ -147,7 +147,7 @@ where let metadata = self.handler.lock().await; - let req = ReadReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + let req = ReadReqRef::new(TLVElement::new(exchange.rx()?.payload())); debug!("IM: Read request: {:?}", req); let req = ReportDataReq::Read(&req); @@ -155,17 +155,27 @@ where let accessor = exchange.accessor()?; // Will the clusters that are to be invoked await? - let awaits = metadata.node().read(&req, None, &accessor).any(|item| { - item.map(|attr| self.handler.read_awaits(exchange, &attr)) + let mut awaits = false; + + for item in metadata + .node() + .read(&req, None, &exchange.accessor()?, true)? + { + if item? + .map(|attr| self.handler.read_awaits(exchange, &attr)) .unwrap_or(false) - }); + { + awaits = true; + break; + } + } if !awaits { // No, they won't. Answer the request by directly using the RX packet // of the transport layer, as the operation won't await. let node = metadata.node(); - let mut attrs = node.read(&req, None, &accessor).peekable(); + let mut attrs = node.read(&req, None, &accessor, true)?.peekable(); if !req .respond(&self.handler, exchange, None, &mut attrs, &mut wb, true) @@ -201,11 +211,11 @@ where return Ok(()); }; - let req = ReadReq::from_tlv(&get_root_node_struct(&rx)?)?; + let req = ReadReqRef::new(TLVElement::new(&rx)); let req = ReportDataReq::Read(&req); let node = metadata.node(); - let mut attrs = node.read(&req, None, &accessor).peekable(); + let mut attrs = node.read(&req, None, &accessor, true)?.peekable(); loop { let more_chunks = req @@ -231,10 +241,10 @@ where exchange: &mut Exchange<'_>, timeout_instant: Option, ) -> Result { - let req = WriteReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + let req = WriteReqRef::new(TLVElement::new(exchange.rx()?.payload())); debug!("IM: Write request: {:?}", req); - let timed = req.timed_request.unwrap_or(false); + let timed = req.timed_request()?; if self.timed_out(exchange, timeout_instant, timed).await? { return Ok(false); @@ -248,16 +258,20 @@ where let metadata = self.handler.lock().await; - let req = WriteReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + let req = WriteReqRef::new(TLVElement::new(exchange.rx()?.payload())); // Will the clusters that are to be invoked await? - let awaits = metadata - .node() - .write(&req, &exchange.accessor()?) - .any(|item| { - item.map(|(attr, _)| self.handler.write_awaits(exchange, &attr)) - .unwrap_or(false) - }); + let mut awaits = false; + + for item in metadata.node().write(&req, &exchange.accessor()?)? { + if item? + .map(|(attr, _)| self.handler.write_awaits(exchange, &attr)) + .unwrap_or(false) + { + awaits = true; + break; + } + } let more_chunks = if awaits { // Yes, they will @@ -269,7 +283,7 @@ where return Ok(false); }; - let req = WriteReq::from_tlv(&get_root_node_struct(&rx)?)?; + let req = WriteReqRef::new(TLVElement::new(&rx)); req.respond(&self.handler, exchange, &metadata.node(), &mut wb) .await? @@ -291,10 +305,10 @@ where exchange: &mut Exchange<'_>, timeout_instant: Option, ) -> Result<(), Error> { - let req = InvReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + let req = InvReqRef::new(TLVElement::new(exchange.rx()?.payload())); debug!("IM: Invoke request: {:?}", req); - let timed = req.timed_request.unwrap_or(false); + let timed = req.timed_request()?; if self.timed_out(exchange, timeout_instant, timed).await? { return Ok(()); @@ -308,16 +322,20 @@ where let metadata = self.handler.lock().await; - let req = InvReq::from_tlv(&get_root_node_struct(exchange.rx()?.payload())?)?; + let req = InvReqRef::new(TLVElement::new(exchange.rx()?.payload())); // Will the clusters that are to be invoked await? - let awaits = metadata - .node() - .invoke(&req, &exchange.accessor()?) - .any(|item| { - item.map(|(cmd, _)| self.handler.invoke_awaits(exchange, &cmd)) - .unwrap_or(false) - }); + let mut awaits = false; + + for item in metadata.node().invoke(&req, &exchange.accessor()?)? { + if item? + .map(|(cmd, _)| self.handler.invoke_awaits(exchange, &cmd)) + .unwrap_or(false) + { + awaits = true; + break; + } + } if awaits { // Yes, they will @@ -329,15 +347,15 @@ where return Ok(()); }; - let req = InvReq::from_tlv(&get_root_node_struct(&rx)?)?; + let req = InvReqRef::new(TLVElement::new(&rx)); - req.respond(&self.handler, exchange, &metadata.node(), &mut wb) + req.respond(&self.handler, exchange, &metadata.node(), &mut wb, false) .await?; } else { // No, they won't. Answer the request by directly using the RX packet // of the transport layer, as the operation won't await. - req.respond(&self.handler, exchange, &metadata.node(), &mut wb) + req.respond(&self.handler, exchange, &metadata.node(), &mut wb, false) .await?; } @@ -355,7 +373,7 @@ where return Ok(()); }; - let req = SubscribeReq::from_tlv(&get_root_node_struct(&rx)?)?; + let req = SubscribeReqRef::new(TLVElement::new(&rx)); debug!("IM: Subscribe request: {:?}", req); let (fabric_idx, peer_node_id) = exchange.with_session(|sess| { @@ -366,7 +384,7 @@ where Ok((fabric_idx, peer_node_id)) })?; - if !req.keep_subs { + if !req.keep_subs()? { self.subscriptions .remove(Some(fabric_idx), Some(peer_node_id), None); self.subscriptions_buffers @@ -376,8 +394,8 @@ where info!("All subscriptions for [F:{fabric_idx:x},P:{peer_node_id:x}] removed"); } - let max_int_secs = core::cmp::max(req.max_int_ceil, 40); // Say we need at least 4 secs for potential latencies - let min_int_secs = req.min_int_floor; + let max_int_secs = core::cmp::max(req.max_int_ceil()?, 40); // Say we need at least 4 secs for potential latencies + let min_int_secs = req.min_int_floor()?; let Some(id) = self.subscriptions.add( fabric_idx, @@ -398,7 +416,15 @@ where }); let primed = self - .report_data(id, fabric_idx.get(), peer_node_id, &rx, &mut tx, exchange) + .report_data( + id, + fabric_idx.get(), + peer_node_id, + &rx, + &mut tx, + exchange, + true, + ) .await?; if primed { @@ -498,11 +524,6 @@ where .unwrap(); let rx = self.subscriptions_buffers.borrow_mut().remove(index).buffer; - let mut req = SubscribeReq::from_tlv(&get_root_node_struct(&rx)?)?; - - // Only used when priming the subscription - req.dataver_filters = None; - let mut exchange = if let Some(session_id) = session_id { Exchange::initiate_for_session(matter, session_id)? } else { @@ -521,6 +542,7 @@ where &rx, &mut tx, &mut exchange, + false, ) .await?; @@ -585,6 +607,7 @@ where } } + #[allow(clippy::too_many_arguments)] async fn report_data( &self, id: u32, @@ -593,13 +616,14 @@ where rx: &[u8], tx: &mut [u8], exchange: &mut Exchange<'_>, + with_dataver: bool, ) -> Result where T: DataModelHandler, { let mut wb = WriteBuf::new(tx); - let req = SubscribeReq::from_tlv(&get_root_node_struct(rx)?)?; + let req = SubscribeReqRef::new(TLVElement::new(rx)); let req = ReportDataReq::Subscribe(&req); let metadata = self.handler.lock().await; @@ -608,7 +632,7 @@ where { let node = metadata.node(); - let mut attrs = node.read(&req, None, &accessor).peekable(); + let mut attrs = node.read(&req, None, &accessor, with_dataver)?.peekable(); loop { let more_chunks = req @@ -746,41 +770,48 @@ impl<'a> ReportDataReq<'a> { ) -> Result where T: DataModelHandler, - I: Iterator, AttrStatus>>, + I: Iterator, AttrStatus>, Error>>, { wb.reset(); wb.shrink(Self::LONG_READS_TLV_RESERVE_SIZE)?; let mut tw = TLVWriter::new(wb); - tw.start_struct(TagType::Anonymous)?; + tw.start_struct(&TLVTag::Anonymous)?; if let Some(subscription_id) = subscription_id { assert!(matches!(self, ReportDataReq::Subscribe(_))); tw.u32( - TagType::Context(ReportDataTag::SubscriptionId as u8), + &TLVTag::Context(ReportDataTag::SubscriptionId as u8), subscription_id, )?; } else { assert!(matches!(self, ReportDataReq::Read(_))); } - let has_requests = self.attr_requests().is_some(); + let has_requests = self.attr_requests()?.is_some(); if has_requests { - tw.start_array(TagType::Context(ReportDataTag::AttributeReports as u8))?; + tw.start_array(&TLVTag::Context(ReportDataTag::AttributeReports as u8))?; } while let Some(item) = attrs.peek() { - if AttrDataEncoder::handle_read(exchange, item, &handler, &mut tw).await? { - attrs.next(); - } else { - break; + match item { + Ok(item) => { + if AttrDataEncoder::handle_read(exchange, item, &handler, &mut tw).await? { + attrs.next(); + } else { + break; + } + } + Err(_) => { + attrs.next().transpose()?; + } } } wb.expand(Self::LONG_READS_TLV_RESERVE_SIZE)?; - let mut tw = TLVWriter::new(wb); + let tw = wb; if has_requests { tw.end_container()?; @@ -789,11 +820,11 @@ impl<'a> ReportDataReq<'a> { let more_chunks = attrs.peek().is_some(); if more_chunks { - tw.bool(TagType::Context(ReportDataTag::MoreChunkedMsgs as u8), true)?; + tw.bool(&TLVTag::Context(ReportDataTag::MoreChunkedMsgs as u8), true)?; } if !more_chunks && suppress_resp { - tw.bool(TagType::Context(ReportDataTag::SupressResponse as u8), true)?; + tw.bool(&TLVTag::Context(ReportDataTag::SupressResponse as u8), true)?; } tw.end_container()?; @@ -802,7 +833,7 @@ impl<'a> ReportDataReq<'a> { } } -impl<'a> WriteReq<'a> { +impl<'a> WriteReqRef<'a> { async fn respond( &self, handler: T, @@ -819,8 +850,8 @@ impl<'a> WriteReq<'a> { let mut tw = TLVWriter::new(wb); - tw.start_struct(TagType::Anonymous)?; - tw.start_array(TagType::Context(WriteRespTag::WriteResponses as u8))?; + tw.start_struct(&TLVTag::Anonymous)?; + tw.start_array(&TLVTag::Context(WriteRespTag::WriteResponses as u8))?; // The spec expects that a single write request like DeleteList + AddItem // should cause all ACLs of that fabric to be deleted and the new one to be added (Case 1). @@ -833,26 +864,27 @@ impl<'a> WriteReq<'a> { // Thus we support the Case1 by doing this. It does come at the cost of maintaining an // additional list of expanded write requests as we start processing those. let write_attrs: heapless::Vec<_, MAX_WRITE_ATTRS_IN_ONE_TRANS> = - node.write(self, &accessor).collect(); + node.write(self, &accessor)?.collect(); for item in write_attrs { - AttrDataEncoder::handle_write(exchange, &item, &handler, &mut tw).await?; + AttrDataEncoder::handle_write(exchange, &item?, &handler, &mut tw).await?; } tw.end_container()?; tw.end_container()?; - Ok(self.more_chunked.unwrap_or(false)) + self.more_chunked() } } -impl<'a> InvReq<'a> { +impl<'a> InvReqRef<'a> { async fn respond( &self, handler: T, exchange: &Exchange<'_>, node: &Node<'_>, wb: &mut WriteBuf<'_>, + suppress_resp: bool, ) -> Result<(), Error> where T: DataModelHandler, @@ -861,21 +893,24 @@ impl<'a> InvReq<'a> { let mut tw = TLVWriter::new(wb); - tw.start_struct(TagType::Anonymous)?; + tw.start_struct(&TLVTag::Anonymous)?; // Suppress Response -> TODO: Need to revisit this for cases where we send a command back - tw.bool(TagType::Context(InvRespTag::SupressResponse as u8), false)?; + tw.bool( + &TLVTag::Context(InvRespTag::SupressResponse as u8), + suppress_resp, + )?; - let has_requests = self.inv_requests.is_some(); + let has_requests = self.inv_requests()?.is_some(); if has_requests { - tw.start_array(TagType::Context(InvRespTag::InvokeResponses as u8))?; + tw.start_array(&TLVTag::Context(InvRespTag::InvokeResponses as u8))?; } let accessor = exchange.accessor()?; - for item in node.invoke(self, &accessor) { - CmdDataEncoder::handle(&item, &handler, &mut tw, exchange).await?; + for item in node.invoke(self, &accessor)? { + CmdDataEncoder::handle(&item?, &handler, &mut tw, exchange).await?; } if has_requests { diff --git a/rs-matter/src/data_model/objects/cluster.rs b/rs-matter/src/data_model/objects/cluster.rs index 01da6bf9..e5575aca 100644 --- a/rs-matter/src/data_model/objects/cluster.rs +++ b/rs-matter/src/data_model/objects/cluster.rs @@ -31,7 +31,7 @@ use crate::{ }, }, // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer - tlv::{Nullable, TLVWriter, TagType}, + tlv::{Nullable, TLVTag, TLVWrite}, }; use core::fmt::{self, Debug}; @@ -311,7 +311,7 @@ impl<'a> Cluster<'a> { pub fn read(&self, attr: AttrId, mut writer: AttrDataWriter) -> Result<(), Error> { match attr.try_into()? { GlobalElements::AttributeList => { - self.encode_attribute_ids(AttrDataWriter::TAG, &mut writer)?; + self.encode_attribute_ids(&AttrDataWriter::TAG, &mut *writer)?; writer.complete() } GlobalElements::FeatureMap => writer.set(self.feature_map), @@ -322,10 +322,10 @@ impl<'a> Cluster<'a> { } } - fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) -> Result<(), Error> { + fn encode_attribute_ids(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { tw.start_array(tag)?; for a in self.attributes { - tw.u16(TagType::Anonymous, a.id)?; + tw.u16(&TLVTag::Anonymous, a.id)?; } tw.end_container() diff --git a/rs-matter/src/data_model/objects/encoder.rs b/rs-matter/src/data_model/objects/encoder.rs index ff5aaf86..50935032 100644 --- a/rs-matter/src/data_model/objects/encoder.rs +++ b/rs-matter/src/data_model/objects/encoder.rs @@ -15,108 +15,25 @@ * limitations under the License. */ -use core::fmt::{Debug, Formatter}; +use core::fmt::Debug; use core::marker::PhantomData; use core::ops::{Deref, DerefMut}; use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::messages::ib::{ - AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag, + AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdResp, CmdRespTag, CmdStatus, }; -use crate::tlv::UtfStr; +use crate::tlv::TLVTag; use crate::transport::exchange::Exchange; use crate::{ error::{Error, ErrorCode}, interaction_model::messages::ib::{AttrDataTag, AttrRespTag}, - tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{FromTLV, TLVElement, TLVWrite, TLVWriter, TagType, ToTLV}, }; use log::error; use super::{AttrDetails, CmdDetails, DataModelHandler}; -// TODO: Should this return an IMStatusCode Error? But if yes, the higher layer -// may have already started encoding the 'success' headers, we might not want to manage -// the tw.rewind() in that case, if we add this support -pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter); - -#[derive(Clone)] -/// A structure for encoding various types of values -pub enum EncodeValue<'a> { - /// This indicates a value that is dynamically generated. This variant - /// is typically used in the transmit/to-tlv path where we want to encode a value at - /// run time - Closure(EncodeValueGen<'a>), - /// This indicates a value that is in the TLVElement form. this variant is - /// typically used in the receive/from-tlv path where we don't want to decode the - /// full value but it can be done at the time of its usage - Tlv(TLVElement<'a>), - /// This indicates a static value. This variant is typically used in the transmit/ - /// to-tlv path - Value(&'a dyn ToTLV), -} - -impl<'a> EncodeValue<'a> { - pub fn unwrap_tlv(self) -> Option> { - match self { - EncodeValue::Tlv(t) => Some(t), - _ => None, - } - } -} - -impl<'a> PartialEq for EncodeValue<'a> { - fn eq(&self, other: &Self) -> bool { - match self { - EncodeValue::Closure(_) => { - error!("PartialEq not yet supported"); - false - } - EncodeValue::Tlv(a) => { - if let EncodeValue::Tlv(b) = other { - a == b - } else { - false - } - } - // Just claim false for now - EncodeValue::Value(_) => { - error!("PartialEq not yet supported"); - false - } - } - } -} - -impl<'a> Debug for EncodeValue<'a> { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> { - match self { - EncodeValue::Closure(_) => write!(f, "Contains closure"), - EncodeValue::Tlv(t) => write!(f, "{:?}", t), - EncodeValue::Value(_) => write!(f, "Contains EncodeValue"), - }?; - Ok(()) - } -} - -impl<'a> ToTLV for EncodeValue<'a> { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - match self { - EncodeValue::Closure(f) => { - (f)(tag_type, tw); - Ok(()) - } - EncodeValue::Tlv(_) => panic!("This looks invalid"), - EncodeValue::Value(v) => v.to_tlv(tw, tag_type), - } - } -} - -impl<'a> FromTLV<'a> for EncodeValue<'a> { - fn from_tlv(data: &TLVElement<'a>) -> Result { - Ok(EncodeValue::Tlv(data.clone())) - } -} - pub struct AttrDataEncoder<'a, 'b, 'c> { dataver_filter: Option, path: AttrPath, @@ -150,7 +67,7 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { }; if let Some(status) = status { - AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + AttrResp::Status(status).to_tlv(&TagType::Anonymous, tw)?; } Ok(true) @@ -176,7 +93,7 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { }; if let Some(status) = status { - status.to_tlv(tw, TagType::Anonymous)?; + status.to_tlv(&TagType::Anonymous, tw)?; } Ok(()) @@ -198,11 +115,11 @@ impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> { { let mut writer = AttrDataWriter::new(self.tw); - writer.start_struct(TagType::Anonymous)?; - writer.start_struct(TagType::Context(AttrRespTag::Data as _))?; - writer.u32(TagType::Context(AttrDataTag::DataVer as _), dataver)?; + writer.start_struct(&TLVTag::Anonymous)?; + writer.start_struct(&TLVTag::Context(AttrRespTag::Data as _))?; + writer.u32(&TLVTag::Context(AttrDataTag::DataVer as _), dataver)?; self.path - .to_tlv(&mut writer, TagType::Context(AttrDataTag::Path as _))?; + .to_tlv(&TagType::Context(AttrDataTag::Path as _), &mut *writer)?; Ok(Some(writer)) } else { @@ -218,7 +135,7 @@ pub struct AttrDataWriter<'a, 'b, 'c> { } impl<'a, 'b, 'c> AttrDataWriter<'a, 'b, 'c> { - pub const TAG: TagType = TagType::Context(AttrDataTag::Data as _); + pub const TAG: TLVTag = TLVTag::Context(AttrDataTag::Data as _); fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self { let anchor = tw.get_tail(); @@ -230,8 +147,8 @@ impl<'a, 'b, 'c> AttrDataWriter<'a, 'b, 'c> { } } - pub fn set(self, value: T) -> Result<(), Error> { - value.to_tlv(self.tw, Self::TAG)?; + pub fn set(mut self, value: T) -> Result<(), Error> { + value.to_tlv(&Self::TAG, &mut self.tw)?; self.complete() } @@ -345,7 +262,7 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { }; if let Some(status) = status { - InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?; + CmdResp::Status(status).to_tlv(&TagType::Anonymous, tw)?; } Ok(()) @@ -366,12 +283,12 @@ impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> { pub fn with_command(mut self, cmd: u16) -> Result, Error> { let mut writer = CmdDataWriter::new(self.tracker, self.tw); - writer.start_struct(TagType::Anonymous)?; - writer.start_struct(TagType::Context(InvRespTag::Cmd as _))?; + writer.start_struct(&TLVTag::Anonymous)?; + writer.start_struct(&TLVTag::Context(CmdRespTag::Cmd as _))?; self.path.path.leaf = Some(cmd as _); self.path - .to_tlv(&mut writer, TagType::Context(CmdDataTag::Path as _))?; + .to_tlv(&TagType::Context(CmdDataTag::Path as _), &mut *writer)?; Ok(writer) } @@ -398,8 +315,8 @@ impl<'a, 'b, 'c> CmdDataWriter<'a, 'b, 'c> { } } - pub fn set(self, value: T) -> Result<(), Error> { - value.to_tlv(self.tw, Self::TAG)?; + pub fn set(mut self, value: T) -> Result<(), Error> { + value.to_tlv(&Self::TAG, &mut self.tw)?; self.complete() } @@ -478,11 +395,11 @@ impl AttrUtfType { } pub fn encode(&self, writer: AttrDataWriter, value: &str) -> Result<(), Error> { - writer.set(UtfStr::new(value.as_bytes())) + writer.set(value) } - pub fn decode<'a>(&self, data: &'a TLVElement) -> Result<&'a str, IMStatusCode> { - data.str().map_err(|_| IMStatusCode::InvalidDataType) + pub fn decode<'a>(&self, data: &TLVElement<'a>) -> Result<&'a str, IMStatusCode> { + data.utf8().map_err(|_| IMStatusCode::InvalidDataType) } } diff --git a/rs-matter/src/data_model/objects/node.rs b/rs-matter/src/data_model/objects/node.rs index 4054bb0a..4b27fa09 100644 --- a/rs-matter/src/data_model/objects/node.rs +++ b/rs-matter/src/data_model/objects/node.rs @@ -17,17 +17,16 @@ use crate::{ acl::Accessor, - alloc, data_model::objects::Endpoint, + error::Error, interaction_model::{ core::{IMStatusCode, ReportDataReq}, messages::{ - ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter}, - msg::{InvReq, WriteReq}, + ib::{AttrStatus, CmdStatus, DataVersionFilter}, + msg::{InvReqRef, WriteReqRef}, GenericPath, }, }, - // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer tlv::{TLVArray, TLVElement}, }; use core::{ @@ -65,206 +64,205 @@ pub struct Node<'a> { } impl<'a> Node<'a> { - pub fn read<'s, 'm>( - &'s self, + pub fn read<'m>( + &'m self, req: &'m ReportDataReq, from: Option, accessor: &'m Accessor<'m>, - ) -> impl Iterator> + 'm - where - 's: 'm, + with_dataver_filters: bool, + ) -> Result, Error>> + 'm, Error> { - self.read_attr_requests( - req.attr_requests() - .iter() - .flat_map(|attr_requests| attr_requests.iter()), - req.dataver_filters(), - req.fabric_filtered(), - accessor, - from, - ) - } + let dataver_filters = req.dataver_filters()?; + let fabric_filtered = req.fabric_filtered()?; + + let iter = req + .attr_requests()? + .into_iter() + .flat_map(|reqs| reqs.into_iter()) + .flat_map(move |path| { + let path = match path { + Ok(path) => path, + Err(e) => return WildcardIter::Single(once(Err(e))), + }; - fn read_attr_requests<'s, 'm, P>( - &'s self, - attr_requests: P, - dataver_filters: Option<&'m TLVArray>, - fabric_filtered: bool, - accessor: &'m Accessor<'m>, - from: Option, - ) -> impl Iterator> + 'm - where - 's: 'm, - P: Iterator + 'm, - { - alloc!(attr_requests.flat_map(move |path| { - if path.to_gp().is_wildcard() { - let from = from.clone(); - - let iter = self - .match_attributes(path.endpoint, path.cluster, path.attr) - .skip_while(move |(ep, cl, attr)| { - !Self::matches(from.as_ref(), ep.id, cl.id, attr.id as _) - }) - .filter(move |(ep, cl, attr)| { - Cluster::check_attr_access( - accessor, - GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), - false, - attr.access, - ) - .is_ok() - }) - .map(move |(ep, cl, attr)| { - let dataver = if let Some(dataver_filters) = dataver_filters { - dataver_filters.iter().find_map(|filter| { - (filter.path.endpoint == ep.id && filter.path.cluster == cl.id) - .then_some(filter.data_ver) - }) - } else { - None - }; - - Ok(AttrDetails { - node: self, - endpoint_id: ep.id, - cluster_id: cl.id, - attr_id: attr.id, - list_index: path.list_index, - fab_idx: accessor.fab_idx, - fab_filter: fabric_filtered, - dataver, - wildcard: true, + if path.to_gp().is_wildcard() { + let from = from.clone(); + let dataver_filters = dataver_filters.clone(); + + let iter = self + .match_attributes(path.endpoint, path.cluster, path.attr) + .skip_while(move |(ep, cl, attr)| { + !Self::matches(from.as_ref(), ep.id, cl.id, attr.id as _) }) - }); - - WildcardIter::Wildcard(iter) - } else { - let ep = path.endpoint.unwrap(); - let cl = path.cluster.unwrap(); - let attr = path.attr.unwrap(); - - let result = match self.check_attribute(accessor, ep, cl, attr, false) { - Ok(()) => { - let dataver = if let Some(dataver_filters) = dataver_filters { - dataver_filters.iter().find_map(|filter| { - (filter.path.endpoint == ep && filter.path.cluster == cl) - .then_some(filter.data_ver) - }) - } else { - None - }; - - Ok(AttrDetails { - node: self, - endpoint_id: ep, - cluster_id: cl, - attr_id: attr, - list_index: path.list_index, - fab_idx: accessor.fab_idx, - fab_filter: fabric_filtered, - dataver, - wildcard: false, + .filter(|(ep, cl, attr)| { + Cluster::check_attr_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), + false, + attr.access, + ) + .is_ok() }) - } - Err(err) => Err(AttrStatus::new(&path.to_gp(), err, 0)), - }; + .map(move |(ep, cl, attr)| { + let dataver = with_dataver_filters + .then(|| Self::dataver(dataver_filters.as_ref(), ep.id, cl.id)) + .transpose()? + .flatten(); - WildcardIter::Single(once(result)) - } - })) + Ok(Ok(AttrDetails { + node: self, + endpoint_id: ep.id, + cluster_id: cl.id, + attr_id: attr.id, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: fabric_filtered, + dataver, + wildcard: true, + })) + }); + + WildcardIter::Wildcard(iter) + } else { + let ep = path.endpoint.unwrap(); + let cl = path.cluster.unwrap(); + let attr = path.attr.unwrap(); + + let result = match self.check_attribute(accessor, ep, cl, attr, false) { + Ok(()) => Self::dataver(dataver_filters.as_ref(), ep, cl).map(|dataver| { + Ok(AttrDetails { + node: self, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, + list_index: path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: fabric_filtered, + dataver, + wildcard: false, + }) + }), + Err(err) => Ok(Err(AttrStatus::new(&path.to_gp(), err, 0))), + }; + + WildcardIter::Single(once(result)) + } + }); + + Ok(iter) } pub fn write<'m>( &'m self, - req: &'m WriteReq, + req: &'m WriteReqRef, accessor: &'m Accessor<'m>, - ) -> impl Iterator), AttrStatus>> + 'm { - alloc!(req.write_requests.iter().flat_map(move |attr_data| { - if attr_data.path.cluster.is_none() { - WildcardIter::Single(once(Err(AttrStatus::new( - &attr_data.path.to_gp(), - IMStatusCode::UnsupportedCluster, - 0, - )))) - } else if attr_data.path.attr.is_none() { - WildcardIter::Single(once(Err(AttrStatus::new( - &attr_data.path.to_gp(), - IMStatusCode::UnsupportedAttribute, - 0, - )))) - } else if attr_data.path.to_gp().is_wildcard() { - let iter = self - .match_attributes( - attr_data.path.endpoint, - attr_data.path.cluster, - attr_data.path.attr, - ) - .filter(move |(ep, cl, attr)| { - Cluster::check_attr_access( - accessor, - GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), - true, - attr.access, + ) -> Result< + impl Iterator), AttrStatus>, Error>> + 'm, + Error, + > { + let iter = req + .write_requests()? + .into_iter() + .flat_map(move |attr_data| { + let attr_data = match attr_data { + Ok(attr_data) => attr_data, + Err(e) => return WildcardIter::Single(once(Err(e))), + }; + + if attr_data.path.cluster.is_none() { + WildcardIter::Single(once(Ok(Err(AttrStatus::new( + &attr_data.path.to_gp(), + IMStatusCode::UnsupportedCluster, + 0, + ))))) + } else if attr_data.path.attr.is_none() { + WildcardIter::Single(once(Ok(Err(AttrStatus::new( + &attr_data.path.to_gp(), + IMStatusCode::UnsupportedAttribute, + 0, + ))))) + } else if attr_data.path.to_gp().is_wildcard() { + let iter = self + .match_attributes( + attr_data.path.endpoint, + attr_data.path.cluster, + attr_data.path.attr, ) - .is_ok() - }) - .map(move |(ep, cl, attr)| { - Ok(( + .filter(move |(ep, cl, attr)| { + Cluster::check_attr_access( + accessor, + GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)), + true, + attr.access, + ) + .is_ok() + }) + .map(move |(ep, cl, attr)| { + Ok(Ok(( + AttrDetails { + node: self, + endpoint_id: ep.id, + cluster_id: cl.id, + attr_id: attr.id, + list_index: attr_data.path.list_index, + fab_idx: accessor.fab_idx, + fab_filter: false, + dataver: attr_data.data_ver, + wildcard: true, + }, + attr_data.data.clone(), + ))) + }); + + WildcardIter::Wildcard(iter) + } else { + let ep = attr_data.path.endpoint.unwrap(); + let cl = attr_data.path.cluster.unwrap(); + let attr = attr_data.path.attr.unwrap(); + + let result = match self.check_attribute(accessor, ep, cl, attr, true) { + Ok(()) => Ok(Ok(( AttrDetails { node: self, - endpoint_id: ep.id, - cluster_id: cl.id, - attr_id: attr.id, + endpoint_id: ep, + cluster_id: cl, + attr_id: attr, list_index: attr_data.path.list_index, fab_idx: accessor.fab_idx, fab_filter: false, dataver: attr_data.data_ver, - wildcard: true, + wildcard: false, }, - attr_data.data.clone().unwrap_tlv().unwrap(), - )) - }); - - WildcardIter::Wildcard(iter) - } else { - let ep = attr_data.path.endpoint.unwrap(); - let cl = attr_data.path.cluster.unwrap(); - let attr = attr_data.path.attr.unwrap(); - - let result = match self.check_attribute(accessor, ep, cl, attr, true) { - Ok(()) => Ok(( - AttrDetails { - node: self, - endpoint_id: ep, - cluster_id: cl, - attr_id: attr, - list_index: attr_data.path.list_index, - fab_idx: accessor.fab_idx, - fab_filter: false, - dataver: attr_data.data_ver, - wildcard: false, - }, - attr_data.data.unwrap_tlv().unwrap(), - )), - Err(err) => Err(AttrStatus::new(&attr_data.path.to_gp(), err, 0)), - }; + attr_data.data, + ))), + Err(err) => Ok(Err(AttrStatus::new(&attr_data.path.to_gp(), err, 0))), + }; - WildcardIter::Single(once(result)) - } - })) + WildcardIter::Single(once(result)) + } + }); + + Ok(iter) } pub fn invoke<'m>( &'m self, - req: &'m InvReq, + req: &'m InvReqRef, accessor: &'m Accessor<'m>, - ) -> impl Iterator), CmdStatus>> + 'm { - alloc!(req - .inv_requests - .iter() - .flat_map(|inv_requests| inv_requests.iter()) + ) -> Result< + impl Iterator), CmdStatus>, Error>> + 'm, + Error, + > { + let iter = req + .inv_requests()? + .into_iter() + .flat_map(|reqs| reqs.into_iter()) .flat_map(move |cmd_data| { + let cmd_data = match cmd_data { + Ok(cmd_data) => cmd_data, + Err(e) => return WildcardIter::Single(once(Err(e))), + }; + if cmd_data.path.path.is_wildcard() { let iter = self .match_commands( @@ -280,7 +278,7 @@ impl<'a> Node<'a> { .is_ok() }) .map(move |(ep, cl, cmd)| { - Ok(( + Ok(Ok(( CmdDetails { node: self, endpoint_id: ep.id, @@ -288,8 +286,8 @@ impl<'a> Node<'a> { cmd_id: cmd, wildcard: true, }, - cmd_data.data.clone().unwrap_tlv().unwrap(), - )) + cmd_data.data.clone(), + ))) }); WildcardIter::Wildcard(iter) @@ -299,7 +297,7 @@ impl<'a> Node<'a> { let cmd = cmd_data.path.path.leaf.unwrap(); let result = match self.check_command(accessor, ep, cl, cmd) { - Ok(()) => Ok(( + Ok(()) => Ok(Ok(( CmdDetails { node: self, endpoint_id: cmd_data.path.path.endpoint.unwrap(), @@ -307,14 +305,16 @@ impl<'a> Node<'a> { cmd_id: cmd_data.path.path.leaf.unwrap(), wildcard: false, }, - cmd_data.data.unwrap_tlv().unwrap(), - )), - Err(err) => Err(CmdStatus::new(cmd_data.path, err, 0)), + cmd_data.data, + ))), + Err(err) => Ok(Err(CmdStatus::new(cmd_data.path, err, 0))), }; WildcardIter::Single(once(result)) } - })) + }); + + Ok(iter) } fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool { @@ -388,6 +388,24 @@ impl<'a> Node<'a> { .find(|endpoint| endpoint.id == ep) .ok_or(IMStatusCode::UnsupportedEndpoint) } + + fn dataver( + dataver_filters: Option<&TLVArray>, + ep: EndptId, + cl: ClusterId, + ) -> Result, Error> { + if let Some(dataver_filters) = dataver_filters { + for filter in dataver_filters { + let filter = filter?; + + if filter.path.endpoint == ep && filter.path.cluster == cl { + return Ok(Some(filter.data_ver)); + } + } + } + + Ok(None) + } } impl<'a> core::fmt::Display for Node<'a> { diff --git a/rs-matter/src/data_model/objects/privilege.rs b/rs-matter/src/data_model/objects/privilege.rs index 466a9e39..7cfa27d5 100644 --- a/rs-matter/src/data_model/objects/privilege.rs +++ b/rs-matter/src/data_model/objects/privilege.rs @@ -17,7 +17,7 @@ use crate::{ error::{Error, ErrorCode}, - tlv::{FromTLV, TLVElement, ToTLV}, + tlv::{FromTLV, TLVElement, TLVTag, TLVWrite, ToTLV, TLV}, }; use log::error; @@ -39,6 +39,22 @@ bitflags! { } } +impl Privilege { + pub fn raw_value(&self) -> u8 { + if self.contains(Privilege::ADMIN) { + 5 + } else if self.contains(Privilege::OPERATE) { + 4 + } else if self.contains(Privilege::MANAGE) { + 3 + } else if self.contains(Privilege::VIEW) { + 1 + } else { + 0 + } + } +} + impl FromTLV<'_> for Privilege { fn from_tlv(t: &TLVElement) -> Result where @@ -59,23 +75,11 @@ impl FromTLV<'_> for Privilege { } impl ToTLV for Privilege { - #[allow(clippy::bool_to_int_with_if)] - fn to_tlv( - &self, - tw: &mut crate::tlv::TLVWriter, - tag: crate::tlv::TagType, - ) -> Result<(), Error> { - let val = if self.contains(Privilege::ADMIN) { - 5 - } else if self.contains(Privilege::OPERATE) { - 4 - } else if self.contains(Privilege::MANAGE) { - 3 - } else if self.contains(Privilege::VIEW) { - 1 - } else { - 0 - }; - tw.u8(tag, val) + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + tw.u8(tag, self.raw_value()) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + TLV::u8(tag, self.raw_value()).into_tlv_iter() } } diff --git a/rs-matter/src/data_model/sdm/admin_commissioning.rs b/rs-matter/src/data_model/sdm/admin_commissioning.rs index d050fe22..a5cc825a 100644 --- a/rs-matter/src/data_model/sdm/admin_commissioning.rs +++ b/rs-matter/src/data_model/sdm/admin_commissioning.rs @@ -118,10 +118,8 @@ impl AdminCommCluster { } else { match attr.attr_id.try_into()? { Attributes::WindowStatus(codec) => codec.encode(writer, 1), - Attributes::AdminVendorId(codec) => codec.encode(writer, Nullable::NotNull(1)), - Attributes::AdminFabricIndex(codec) => { - codec.encode(writer, Nullable::NotNull(1)) - } + Attributes::AdminVendorId(codec) => codec.encode(writer, Nullable::some(1)), + Attributes::AdminFabricIndex(codec) => codec.encode(writer, Nullable::some(1)), } } } else { diff --git a/rs-matter/src/data_model/sdm/dev_att.rs b/rs-matter/src/data_model/sdm/dev_att.rs index f3dbc2ac..87dca3cc 100644 --- a/rs-matter/src/data_model/sdm/dev_att.rs +++ b/rs-matter/src/data_model/sdm/dev_att.rs @@ -36,10 +36,15 @@ pub enum DataType { /// Objects that implement this trait allow the Matter subsystem to query the object /// for the Device Attestation data that is programmed in the Matter device. pub trait DevAttDataFetcher { - /// Get Device Attestation Data - /// - /// This API is expected to return the particular Device Attestation data as is - /// requested by the Matter subsystem. - /// The type of data that can be queried is defined in the [DataType] enum. - fn get_devatt_data(&self, data_type: DataType, data: &mut [u8]) -> Result; + /// Get the data in the provided buffer + fn get_devatt_data(&self, data_type: DataType, buf: &mut [u8]) -> Result; +} + +impl DevAttDataFetcher for &T +where + T: DevAttDataFetcher, +{ + fn get_devatt_data(&self, data_type: DataType, buf: &mut [u8]) -> Result { + (*self).get_devatt_data(data_type, buf) + } } diff --git a/rs-matter/src/data_model/sdm/general_commissioning.rs b/rs-matter/src/data_model/sdm/general_commissioning.rs index 8340dd2d..30befea1 100644 --- a/rs-matter/src/data_model/sdm/general_commissioning.rs +++ b/rs-matter/src/data_model/sdm/general_commissioning.rs @@ -22,7 +22,7 @@ use rs_matter_macros::idl_import; use strum::{EnumDiscriminants, FromRepr}; use crate::data_model::objects::*; -use crate::tlv::{FromTLV, TLVElement, ToTLV, UtfStr}; +use crate::tlv::{FromTLV, TLVElement, ToTLV, Utf8Str}; use crate::transport::exchange::Exchange; use crate::transport::session::SessionMode; use crate::{attribute_enum, cmd_enter}; @@ -60,7 +60,7 @@ pub enum RespCommands { #[tlvargs(lifetime = "'a")] struct CommonResponse<'a> { error_code: u8, - debug_txt: UtfStr<'a>, + debug_txt: Utf8Str<'a>, } pub const CLUSTER: Cluster<'static> = Cluster { @@ -172,7 +172,7 @@ impl GenCommCluster { } Attributes::BasicCommissioningInfo(_) => { self.basic_comm_info - .to_tlv(&mut writer, AttrDataWriter::TAG)?; + .to_tlv(&AttrDataWriter::TAG, &mut *writer)?; writer.complete() } Attributes::SupportsConcurrentConnection(codec) => { @@ -234,7 +234,7 @@ impl GenCommCluster { let cmd_data = CommonResponse { error_code: status, - debug_txt: UtfStr::new(b""), + debug_txt: "", }; encoder @@ -252,15 +252,17 @@ impl GenCommCluster { ) -> Result<(), Error> { cmd_enter!("Set Regulatory Config"); let country_code = data - .find_tag(1) + .r#struct() .map_err(|_| ErrorCode::InvalidCommand)? - .slice() + .find_ctx(1) + .map_err(|_| ErrorCode::InvalidCommand)? + .utf8() .map_err(|_| ErrorCode::InvalidCommand)?; - info!("Received country code: {:?}", country_code); + info!("Received country code: {}", country_code); let cmd_data = CommonResponse { error_code: 0, - debug_txt: UtfStr::new(b""), + debug_txt: "", }; encoder @@ -299,7 +301,7 @@ impl GenCommCluster { let cmd_data = CommonResponse { error_code: status, - debug_txt: UtfStr::new(b""), + debug_txt: "", }; encoder diff --git a/rs-matter/src/data_model/sdm/noc.rs b/rs-matter/src/data_model/sdm/noc.rs index 52793766..f3a13730 100644 --- a/rs-matter/src/data_model/sdm/noc.rs +++ b/rs-matter/src/data_model/sdm/noc.rs @@ -16,6 +16,7 @@ */ use core::cell::Cell; +use core::mem::MaybeUninit; use core::num::NonZeroU8; use log::{error, info, warn}; @@ -23,15 +24,16 @@ use log::{error, info, warn}; use strum::{EnumDiscriminants, FromRepr}; use crate::acl::{AclEntry, AuthMode}; -use crate::cert::{Cert, MAX_CERT_TLV_LEN}; +use crate::cert::{CertRef, MAX_CERT_TLV_LEN}; use crate::crypto::{self, KeyPair}; use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; use crate::fabric::{Fabric, MAX_SUPPORTED_FABRICS}; -use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVTag, TLVWrite, TLVWriter, ToTLV, UtfStr}; use crate::transport::exchange::Exchange; use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; +use crate::utils::init::InitMaybeUninit; use crate::utils::storage::WriteBuf; use crate::{attribute_enum, cmd_enter, command_enum, error::*}; @@ -73,13 +75,6 @@ impl From for NocError { } } -// Some placeholder value for now -const MAX_CERT_DECLARATION_LEN: usize = 600; -// Some placeholder value for now -const MAX_CSR_LEN: usize = 300; -// As defined in the Matter Spec -const RESP_MAX: usize = 900; - pub const ID: u32 = 0x003E; #[derive(FromRepr)] @@ -169,11 +164,6 @@ impl NocData { } } -#[derive(ToTLV)] -struct CertChainResp<'a> { - cert: OctetStr<'a>, -} - #[derive(ToTLV)] struct NocResp<'a> { status_code: u8, @@ -239,7 +229,7 @@ impl NocCluster { } Attributes::CurrentFabricIndex(codec) => codec.encode(writer, attr.fab_idx), Attributes::Fabrics(_) => { - writer.start_array(AttrDataWriter::TAG)?; + writer.start_array(&AttrDataWriter::TAG)?; exchange .matter() .fabric_mgr @@ -250,7 +240,7 @@ impl NocCluster { entry .get_fabric_desc(fab_idx, &root_ca_cert)? - .to_tlv(&mut writer, TagType::Anonymous)?; + .to_tlv(&TLVTag::Anonymous, &mut *writer)?; } Ok(()) @@ -324,16 +314,20 @@ impl NocCluster { let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?; - let noc_cert = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received NOC as: {}", noc_cert); + info!( + "Received NOC as: {}", + CertRef::new(TLVElement::new(r.noc_value.0)) + ); let noc = crate::utils::storage::Vec::from_slice(r.noc_value.0) .map_err(|_| NocStatus::InvalidNOC)?; let icac = if let Some(icac_value) = r.icac_value { if !icac_value.0.is_empty() { - let icac_cert = Cert::new(icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; - info!("Received ICAC as: {}", icac_cert); + info!( + "Received ICAC as: {}", + CertRef::new(TLVElement::new(icac_value.0)) + ); let icac = crate::utils::storage::Vec::from_slice(icac_value.0) .map_err(|_| NocStatus::InvalidNOC)?; @@ -431,7 +425,7 @@ impl NocCluster { let cmd_data = NocResp { status_code: status_code as u8, fab_idx, - debug_txt: UtfStr::new(debug_txt.as_bytes()), + debug_txt, }; encoder @@ -454,10 +448,7 @@ impl NocCluster { .matter() .fabric_mgr .borrow_mut() - .set_label( - fab_idx, - req.label.as_str().map_err(Error::map_invalid_data_type)?, - ) + .set_label(fab_idx, req.label) .is_err() { (NocStatus::LabelConflict, fab_idx.get()) @@ -554,19 +545,11 @@ impl NocCluster { let mut writer = encoder.with_command(RespCommands::AttReqResp as _)?; - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let mut attest_element = WriteBuf::new(&mut buf); - writer.start_struct(CmdDataWriter::TAG)?; + writer.start_struct(&CmdDataWriter::TAG)?; add_attestation_element( exchange.matter().epoch(), exchange.matter().dev_att(), req.str.0, - &mut attest_element, - &mut writer, - )?; - add_attestation_signature( - exchange.matter().dev_att(), - &mut attest_element, &attest_challenge, &mut writer, )?; @@ -588,22 +571,15 @@ impl NocCluster { info!("Received data: {}", data); let cert_type = get_certchainrequest_params(data).map_err(Error::map_invalid_command)?; - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let len = exchange - .matter() - .dev_att() - .get_devatt_data(cert_type, &mut buf)?; - let buf = &buf[0..len]; - - let cmd_data = CertChainResp { - cert: OctetStr::new(buf), - }; + let mut writer = encoder.with_command(RespCommands::CertChainResp as _)?; - encoder - .with_command(RespCommands::CertChainResp as _)? - .set(cmd_data)?; + writer.start_struct(&CmdDataWriter::TAG)?; + writer.str_cb(&TLVTag::Context(0), |buf| { + exchange.matter().dev_att().get_devatt_data(cert_type, buf) + })?; + writer.end_container()?; - Ok(()) + writer.complete() } fn handle_command_csrrequest( @@ -630,13 +606,11 @@ impl NocCluster { let mut writer = encoder.with_command(RespCommands::CSRResp as _)?; - let mut buf: [u8; RESP_MAX] = [0; RESP_MAX]; - let mut nocsr_element = WriteBuf::new(&mut buf); - writer.start_struct(CmdDataWriter::TAG)?; - add_nocsrelement(&noc_keypair, req.str.0, &mut nocsr_element, &mut writer)?; - add_attestation_signature( + writer.start_struct(&CmdDataWriter::TAG)?; + add_nocsrelement( exchange.matter().dev_att(), - &mut nocsr_element, + &noc_keypair, + req.str.0, &attest_challenge, &mut writer, )?; @@ -728,58 +702,85 @@ fn add_attestation_element( epoch: Epoch, dev_att: &dyn DevAttDataFetcher, att_nonce: &[u8], - write_buf: &mut WriteBuf, + attest_challenge: &[u8], t: &mut TLVWriter, ) -> Result<(), Error> { - let mut cert_dec: [u8; MAX_CERT_DECLARATION_LEN] = [0; MAX_CERT_DECLARATION_LEN]; - let len = dev_att.get_devatt_data(dev_att::DataType::CertDeclaration, &mut cert_dec)?; - let cert_dec = &cert_dec[0..len]; - let epoch = epoch().as_secs() as u32; - let mut writer = TLVWriter::new(write_buf); - writer.start_struct(TagType::Anonymous)?; - writer.str16(TagType::Context(1), cert_dec)?; - writer.str8(TagType::Context(2), att_nonce)?; - writer.u32(TagType::Context(3), epoch)?; - writer.end_container()?; - - t.str16(TagType::Context(0), write_buf.as_slice()) + + let mut signature_buf = MaybeUninit::<[u8; crypto::EC_SIGNATURE_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let signature_buf = signature_buf.init_zeroed(); + let mut signature_len = 0; + + t.str_cb(&TLVTag::Context(0), |buf| { + let mut wb = WriteBuf::new(buf); + wb.start_struct(&TLVTag::Anonymous)?; + wb.str_cb(&TLVTag::Context(1), |buf| { + dev_att.get_devatt_data(dev_att::DataType::CertDeclaration, buf) + })?; + wb.str(&TLVTag::Context(2), att_nonce)?; + wb.u32(&TLVTag::Context(3), epoch)?; + wb.end_container()?; + + let len = wb.get_tail(); + + signature_len = + compute_attestation_signature(dev_att, &mut wb, attest_challenge, signature_buf)?.len(); + + Ok(len) + })?; + t.str(&TLVTag::Context(1), &signature_buf[..signature_len]) } -fn add_attestation_signature( +fn add_nocsrelement( dev_att: &dyn DevAttDataFetcher, - attest_element: &mut WriteBuf, + noc_keypair: &KeyPair, + csr_nonce: &[u8], attest_challenge: &[u8], - resp: &mut TLVWriter, + t: &mut TLVWriter, ) -> Result<(), Error> { + let mut signature_buf = MaybeUninit::<[u8; crypto::EC_SIGNATURE_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let signature_buf = signature_buf.init_zeroed(); + let mut signature_len = 0; + + t.str_cb(&TLVTag::Context(0), |buf| { + let mut wb = WriteBuf::new(buf); + + wb.start_struct(&TLVTag::Anonymous)?; + wb.str_cb(&TLVTag::Context(1), |buf| { + Ok(noc_keypair.get_csr(buf)?.len()) + })?; + wb.str(&TLVTag::Context(2), csr_nonce)?; + wb.end_container()?; + + let len = wb.get_tail(); + + signature_len = + compute_attestation_signature(dev_att, &mut wb, attest_challenge, signature_buf)?.len(); + + Ok(len) + })?; + t.str(&TLVTag::Context(1), &signature_buf[..signature_len]) +} + +fn compute_attestation_signature<'a>( + dev_att: &dyn DevAttDataFetcher, + attest_element: &mut WriteBuf, + attest_challenge: &[u8], + signature_buf: &'a mut [u8], +) -> Result<&'a [u8], Error> { let dac_key = { - let mut pubkey = [0_u8; crypto::EC_POINT_LEN_BYTES]; - let mut privkey = [0_u8; crypto::BIGNUM_LEN_BYTES]; - dev_att.get_devatt_data(dev_att::DataType::DACPubKey, &mut pubkey)?; - dev_att.get_devatt_data(dev_att::DataType::DACPrivKey, &mut privkey)?; - KeyPair::new_from_components(&pubkey, &privkey) + let mut pubkey_buf = MaybeUninit::<[u8; crypto::EC_POINT_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let pubkey_buf = pubkey_buf.init_zeroed(); + let mut privkey_buf = MaybeUninit::<[u8; crypto::BIGNUM_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let privkey_buf = privkey_buf.init_zeroed(); + let pubkey_len = dev_att.get_devatt_data(dev_att::DataType::DACPubKey, pubkey_buf)?; + let privkey_len = dev_att.get_devatt_data(dev_att::DataType::DACPrivKey, privkey_buf)?; + KeyPair::new_from_components(&pubkey_buf[..pubkey_len], &privkey_buf[..privkey_len]) }?; attest_element.copy_from_slice(attest_challenge)?; - let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; - dac_key.sign_msg(attest_element.as_slice(), &mut signature)?; - resp.str8(TagType::Context(1), &signature) -} + let len = dac_key.sign_msg(attest_element.as_slice(), signature_buf)?; -fn add_nocsrelement( - noc_keypair: &KeyPair, - csr_nonce: &[u8], - write_buf: &mut WriteBuf, - resp: &mut TLVWriter, -) -> Result<(), Error> { - let mut csr: [u8; MAX_CSR_LEN] = [0; MAX_CSR_LEN]; - let csr = noc_keypair.get_csr(&mut csr)?; - let mut writer = TLVWriter::new(write_buf); - writer.start_struct(TagType::Anonymous)?; - writer.str8(TagType::Context(1), csr)?; - writer.str8(TagType::Context(2), csr_nonce)?; - writer.end_container()?; - - resp.str8(TagType::Context(0), write_buf.as_slice()) + Ok(&signature_buf[..len]) } fn get_certchainrequest_params(data: &TLVElement) -> Result { diff --git a/rs-matter/src/data_model/sdm/nw_commissioning.rs b/rs-matter/src/data_model/sdm/nw_commissioning.rs index 50501ef2..2d15913e 100644 --- a/rs-matter/src/data_model/sdm/nw_commissioning.rs +++ b/rs-matter/src/data_model/sdm/nw_commissioning.rs @@ -24,7 +24,7 @@ use crate::data_model::objects::{ Cluster, Dataver, Handler, NonBlockingHandler, Quality, ATTRIBUTE_LIST, FEATURE_MAP, }; use crate::error::{Error, ErrorCode}; -use crate::tlv::{OctetStr, TLVArray, TagType, ToTLV}; +use crate::tlv::{OctetStr, TLVArray, TLVTag, TLVWrite, ToTLV}; use crate::transport::exchange::Exchange; use crate::{attribute_enum, command_enum}; @@ -293,8 +293,8 @@ impl EthNwCommCluster { match attr.attr_id.try_into()? { Attributes::MaxNetworks => AttrType::::new().encode(writer, 1), Attributes::Networks => { - writer.start_array(AttrDataWriter::TAG)?; - info.nw_info.to_tlv(&mut writer, TagType::Anonymous)?; + writer.start_array(&AttrDataWriter::TAG)?; + info.nw_info.to_tlv(&TLVTag::Anonymous, &mut *writer)?; writer.end_container()?; writer.complete() } @@ -310,11 +310,11 @@ impl EthNwCommCluster { Attributes::LastNetworkID => { info.nw_info .network_id - .to_tlv(&mut writer, AttrDataWriter::TAG)?; + .to_tlv(&AttrDataWriter::TAG, &mut *writer)?; writer.complete() } Attributes::LastConnectErrorValue => { - writer.null(AttrDataWriter::TAG)?; + writer.null(&AttrDataWriter::TAG)?; writer.complete() } _ => Err(ErrorCode::AttributeNotFound.into()), diff --git a/rs-matter/src/data_model/sdm/wifi_nw_diagnostics.rs b/rs-matter/src/data_model/sdm/wifi_nw_diagnostics.rs index 868009b3..352255fe 100644 --- a/rs-matter/src/data_model/sdm/wifi_nw_diagnostics.rs +++ b/rs-matter/src/data_model/sdm/wifi_nw_diagnostics.rs @@ -25,7 +25,7 @@ use strum::{EnumDiscriminants, FromRepr}; use crate::data_model::objects::*; use crate::error::{Error, ErrorCode}; -use crate::tlv::{TLVElement, TagType}; +use crate::tlv::{TLVElement, TLVTag, TLVWrite}; use crate::transport::exchange::Exchange; use crate::{attribute_enum, command_enum}; @@ -182,7 +182,7 @@ impl WifiNwDiagCluster { let data = self.data.borrow(); match attr.attr_id.try_into()? { - Attributes::Bssid => writer.str8(TagType::Anonymous, &data.bssid), + Attributes::Bssid => writer.str(&TLVTag::Anonymous, &data.bssid), Attributes::SecurityType(codec) => codec.encode(writer, data.security_type), Attributes::WifiVersion(codec) => codec.encode(writer, data.wifi_version), Attributes::ChannelNumber(codec) => codec.encode(writer, data.channel_number), diff --git a/rs-matter/src/data_model/system_model/access_control.rs b/rs-matter/src/data_model/system_model/access_control.rs index 7ee58133..30f86439 100644 --- a/rs-matter/src/data_model/system_model/access_control.rs +++ b/rs-matter/src/data_model/system_model/access_control.rs @@ -24,7 +24,7 @@ use log::{error, info}; use crate::acl::{self, AclEntry, AclMgr}; use crate::data_model::objects::*; use crate::interaction_model::messages::ib::{attr_list_write, ListOperation}; -use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV}; +use crate::tlv::{FromTLV, TLVElement, TLVTag, TLVWrite, ToTLV}; use crate::transport::exchange::Exchange; use crate::{attribute_enum, error::*}; @@ -132,7 +132,7 @@ impl AccessControlCluster { } else { match attr.attr_id.try_into()? { Attributes::Acl(_) => { - writer.start_array(AttrDataWriter::TAG)?; + writer.start_array(&AttrDataWriter::TAG)?; acl_mgr.for_each_acl(|entry| { if !attr.fab_filter || entry @@ -140,7 +140,7 @@ impl AccessControlCluster { .map(|fi| fi.get() == attr.fab_idx) .unwrap_or(false) { - entry.to_tlv(&mut writer, TagType::Anonymous)?; + entry.to_tlv(&TLVTag::Anonymous, &mut *writer)?; } Ok(()) @@ -151,7 +151,7 @@ impl AccessControlCluster { } Attributes::Extension(_) => { // Empty for now - writer.start_array(AttrDataWriter::TAG)?; + writer.start_array(&AttrDataWriter::TAG)?; writer.end_container()?; writer.complete() @@ -234,7 +234,10 @@ mod tests { use crate::data_model::objects::{AttrDataEncoder, AttrDetails, Node, Privilege}; use crate::data_model::system_model::access_control::Dataver; use crate::interaction_model::messages::ib::ListOperation; - use crate::tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV}; + use crate::tlv::{ + get_root_node_struct, TLVControl, TLVElement, TLVTag, TLVTagType, TLVValueType, TLVWriter, + ToTLV, + }; use crate::utils::storage::WriteBuf; use super::AccessControlCluster; @@ -252,7 +255,7 @@ mod tests { let acl = AccessControlCluster::new(Dataver::new(0)); let new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); - new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); + new.to_tlv(&TLVTag::Anonymous, &mut tw).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, ACL has fabric index 2, but the accessing fabric is 1 @@ -289,7 +292,7 @@ mod tests { let acl = AccessControlCluster::new(Dataver::new(0)); let new = AclEntry::new(FAB_2, Privilege::VIEW, AuthMode::Case); - new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); + new.to_tlv(&TLVTag::Anonymous, &mut tw).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap(); // Test, Edit Fabric 2's index 1 - with accessing fabring as 2 - allow @@ -324,7 +327,8 @@ mod tests { } let acl = AccessControlCluster::new(Dataver::new(0)); // data is don't-care actually - let data = TLVElement::new(TagType::Anonymous, ElementType::True); + let data = &[TLVControl::new(TLVTagType::Anonymous, TLVValueType::Null).as_raw()]; + let data = TLVElement::new(data.as_slice()); // Test , Delete Fabric 1's index 0 let result = acl.write_acl_attr(&mut acl_mgr, &ListOperation::DeleteItem(0), &data, FAB_1); diff --git a/rs-matter/src/data_model/system_model/descriptor.rs b/rs-matter/src/data_model/system_model/descriptor.rs index 21d29572..a5501a8a 100644 --- a/rs-matter/src/data_model/system_model/descriptor.rs +++ b/rs-matter/src/data_model/system_model/descriptor.rs @@ -22,7 +22,8 @@ use strum::FromRepr; use crate::attribute_enum; use crate::data_model::objects::*; use crate::error::Error; -use crate::tlv::{TLVWriter, TagType, ToTLV}; +use crate::tlv::TLVTag; +use crate::tlv::{TLVWrite, TLVWriter, TagType, ToTLV}; use crate::transport::exchange::Exchange; pub const ID: u32 = 0x001D; @@ -132,7 +133,7 @@ impl<'a> DescriptorCluster<'a> { self.encode_devtype_list( attr.node, attr.endpoint_id, - AttrDataWriter::TAG, + &AttrDataWriter::TAG, &mut writer, )?; writer.complete() @@ -141,7 +142,7 @@ impl<'a> DescriptorCluster<'a> { self.encode_server_list( attr.node, attr.endpoint_id, - AttrDataWriter::TAG, + &AttrDataWriter::TAG, &mut writer, )?; writer.complete() @@ -150,7 +151,7 @@ impl<'a> DescriptorCluster<'a> { self.encode_parts_list( attr.node, attr.endpoint_id, - AttrDataWriter::TAG, + &AttrDataWriter::TAG, &mut writer, )?; writer.complete() @@ -159,7 +160,7 @@ impl<'a> DescriptorCluster<'a> { self.encode_client_list( attr.node, attr.endpoint_id, - AttrDataWriter::TAG, + &AttrDataWriter::TAG, &mut writer, )?; writer.complete() @@ -175,14 +176,14 @@ impl<'a> DescriptorCluster<'a> { &self, node: &Node, endpoint_id: u16, - tag: TagType, + tag: &TLVTag, tw: &mut TLVWriter, ) -> Result<(), Error> { tw.start_array(tag)?; for endpoint in node.endpoints { if endpoint.id == endpoint_id { let dev_type = endpoint.device_type; - dev_type.to_tlv(tw, TagType::Anonymous)?; + dev_type.to_tlv(&TagType::Anonymous, &mut *tw)?; } } @@ -193,14 +194,14 @@ impl<'a> DescriptorCluster<'a> { &self, node: &Node, endpoint_id: u16, - tag: TagType, + tag: &TLVTag, tw: &mut TLVWriter, ) -> Result<(), Error> { tw.start_array(tag)?; for endpoint in node.endpoints { if endpoint.id == endpoint_id { for cluster in endpoint.clusters { - tw.u32(TagType::Anonymous, cluster.id as _)?; + tw.u32(&TLVTag::Anonymous, cluster.id as _)?; } } } @@ -212,14 +213,14 @@ impl<'a> DescriptorCluster<'a> { &self, node: &Node, endpoint_id: u16, - tag: TagType, + tag: &TLVTag, tw: &mut TLVWriter, ) -> Result<(), Error> { tw.start_array(tag)?; for endpoint in node.endpoints { if self.matcher.describe(endpoint_id, endpoint.id) { - tw.u16(TagType::Anonymous, endpoint.id)?; + tw.u16(&TLVTag::Anonymous, endpoint.id)?; } } @@ -230,7 +231,7 @@ impl<'a> DescriptorCluster<'a> { &self, _node: &Node, _endpoint_id: u16, - tag: TagType, + tag: &TLVTag, tw: &mut TLVWriter, ) -> Result<(), Error> { // No Clients supported diff --git a/rs-matter/src/fabric.rs b/rs-matter/src/fabric.rs index f6d8ba19..3b2eb9d3 100644 --- a/rs-matter/src/fabric.rs +++ b/rs-matter/src/fabric.rs @@ -16,21 +16,22 @@ */ use core::fmt::Write; +use core::mem::MaybeUninit; use core::num::NonZeroU8; -use byteorder::{BigEndian, ByteOrder, LittleEndian}; +use byteorder::{BigEndian, ByteOrder}; use heapless::String; use log::info; -use crate::cert::{Cert, MAX_CERT_TLV_LEN}; +use crate::cert::{CertRef, MAX_CERT_TLV_LEN}; use crate::crypto::{self, hkdf_sha256, HmacSha256, KeyPair}; use crate::error::{Error, ErrorCode}; use crate::group_keys::KeySet; use crate::mdns::{Mdns, ServiceMode}; -use crate::tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}; -use crate::utils::init::{init, Init}; +use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; +use crate::utils::init::{init, Init, InitMaybeUninit}; use crate::utils::storage::{Vec, WriteBuf}; const COMPRESSED_FABRIC_ID_LEN: usize = 8; @@ -73,15 +74,15 @@ impl Fabric { label: &str, ) -> Result { let (node_id, fabric_id) = { - let noc_p = Cert::new(&noc)?; + let noc_p = CertRef::new(TLVElement::new(&noc)); (noc_p.get_node_id()?, noc_p.get_fabric_id()?) }; let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN]; let ipk = { - let root_ca_p = Cert::new(&root_ca)?; - Fabric::get_compressed_id(root_ca_p.get_pubkey(), fabric_id, &mut compressed_id)?; + let root_ca_p = CertRef::new(TLVElement::new(&root_ca)); + Fabric::get_compressed_id(root_ca_p.pubkey()?, fabric_id, &mut compressed_id)?; KeySet::new(ipk, &compressed_id)? }; @@ -131,17 +132,14 @@ impl Fabric { let mut mac = HmacSha256::new(self.ipk.op_key())?; mac.update(random)?; - mac.update(self.get_root_ca()?.get_pubkey())?; + mac.update(self.get_root_ca()?.pubkey()?)?; - let mut buf: [u8; 8] = [0; 8]; - LittleEndian::write_u64(&mut buf, self.fabric_id); - mac.update(&buf)?; + mac.update(&self.fabric_id.to_le_bytes())?; + mac.update(&self.node_id.to_le_bytes())?; - LittleEndian::write_u64(&mut buf, self.node_id); - mac.update(&buf)?; - - let mut id = [0_u8; crypto::SHA256_HASH_LEN_BYTES]; - mac.finish(&mut id)?; + let mut id = MaybeUninit::<[u8; crypto::SHA256_HASH_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let id = id.init_zeroed(); + mac.finish(id)?; if id.as_slice() == target { Ok(()) } else { @@ -161,21 +159,21 @@ impl Fabric { self.fabric_id } - pub fn get_root_ca(&self) -> Result, Error> { - Cert::new(&self.root_ca) + pub fn get_root_ca(&self) -> Result, Error> { + Ok(CertRef::new(TLVElement::new(&self.root_ca))) } pub fn get_fabric_desc<'a>( &'a self, fab_idx: NonZeroU8, - root_ca_cert: &'a Cert, + root_ca_cert: &'a CertRef<'a>, ) -> Result, Error> { let desc = FabricDescriptor { - root_public_key: OctetStr::new(root_ca_cert.get_pubkey()), + root_public_key: OctetStr::new(root_ca_cert.pubkey()?), vendor_id: self.vendor_id, fabric_id: self.fabric_id, node_id: self.node_id, - label: UtfStr(self.label.as_bytes()), + label: self.label.as_str(), fab_idx, }; @@ -215,13 +213,19 @@ impl FabricMgr { } pub fn load(&mut self, data: &[u8], mdns: &dyn Mdns) -> Result<(), Error> { + let entries = TLVElement::new(data).array()?.iter(); + for fabric in self.fabrics.iter().flatten() { mdns.remove(&fabric.mdns_service_name)?; } - let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; + for entry in entries { + let entry = entry?; - tlv::vec_from_tlv(&mut self.fabrics, &root)?; + self.fabrics + .push(Option::::from_tlv(&entry)?) + .map_err(|_| ErrorCode::NoSpace)?; + } for fabric in self.fabrics.iter().flatten() { mdns.add(&fabric.mdns_service_name, ServiceMode::Commissioned)?; @@ -239,11 +243,11 @@ impl FabricMgr { self.fabrics .as_slice() - .to_tlv(&mut tw, TagType::Anonymous)?; + .to_tlv(&TagType::Anonymous, &mut tw)?; self.changed = false; - let len = tw.get_tail(); + let len = wb.get_tail(); Ok(Some(&buf[..len])) } else { diff --git a/rs-matter/src/group_keys.rs b/rs-matter/src/group_keys.rs index 6487c671..fbe73dc2 100644 --- a/rs-matter/src/group_keys.rs +++ b/rs-matter/src/group_keys.rs @@ -19,6 +19,7 @@ use crate::{ crypto::{self, SYMM_KEY_LEN_BYTES}, error::{Error, ErrorCode}, tlv::{FromTLV, ToTLV}, + utils::init::{init, zeroed, Init}, }; type KeySetKey = [u8; SYMM_KEY_LEN_BYTES]; @@ -30,6 +31,20 @@ pub struct KeySet { } impl KeySet { + pub const fn new0() -> Self { + Self { + epoch_key: [0; SYMM_KEY_LEN_BYTES], + op_key: [0; SYMM_KEY_LEN_BYTES], + } + } + + pub fn init() -> impl Init { + init!(Self { + epoch_key <- zeroed(), + op_key <- zeroed(), + }) + } + pub fn new(epoch_key: &[u8], compressed_id: &[u8]) -> Result { let mut ks = KeySet::default(); KeySet::op_key_from_ipk(epoch_key, compressed_id, &mut ks.op_key)?; diff --git a/rs-matter/src/interaction_model/core.rs b/rs-matter/src/interaction_model/core.rs index 7266e0a3..897509a1 100644 --- a/rs-matter/src/interaction_model/core.rs +++ b/rs-matter/src/interaction_model/core.rs @@ -19,7 +19,7 @@ use core::time::Duration; use crate::{ error::*, - tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{FromTLV, TLVArray, TLVElement, TLVTag, TLVWrite, TagType, ToTLV, TLV}, transport::exchange::MessageMeta, utils::{epoch::Epoch, storage::WriteBuf}, }; @@ -27,7 +27,7 @@ use num::FromPrimitive; use num_derive::FromPrimitive; use super::messages::ib::{AttrPath, DataVersionFilter}; -use super::messages::msg::{ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq}; +use super::messages::msg::{ReadReqRef, StatusResp, SubscribeReqRef, SubscribeResp, TimedReq}; #[macro_export] macro_rules! cmd_enter { @@ -37,7 +37,7 @@ macro_rules! cmd_enter { }}; } -#[derive(FromPrimitive, Debug, Clone, Copy, PartialEq)] +#[derive(FromPrimitive, Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum IMStatusCode { Success = 0, Failure = 1, @@ -98,8 +98,12 @@ impl FromTLV<'_> for IMStatusCode { } impl ToTLV for IMStatusCode { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw.u16(tag_type, *self as u16) + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + tw.u16(tag, *self as _) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + TLV::u16(tag, *self as _).into_tlv_iter() } } @@ -143,40 +147,39 @@ pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01; /// A wrapper enum for `ReadReq` and `SubscribeReq` that allows downstream code to /// treat the two in a unified manner with regards to `OpCode::ReportDataResp` type responses. +#[derive(Debug, Clone)] pub enum ReportDataReq<'a> { - Read(&'a ReadReq<'a>), - Subscribe(&'a SubscribeReq<'a>), + Read(&'a ReadReqRef<'a>), + Subscribe(&'a SubscribeReqRef<'a>), } impl<'a> ReportDataReq<'a> { - pub fn attr_requests(&self) -> &Option> { + pub fn attr_requests(&self) -> Result>, Error> { match self { - ReportDataReq::Read(req) => &req.attr_requests, - ReportDataReq::Subscribe(req) => &req.attr_requests, + ReportDataReq::Read(req) => req.attr_requests(), + ReportDataReq::Subscribe(req) => req.attr_requests(), } } - pub fn dataver_filters(&self) -> Option<&TLVArray<'_, DataVersionFilter>> { + pub fn dataver_filters(&self) -> Result>, Error> { match self { - ReportDataReq::Read(req) => req.dataver_filters.as_ref(), - ReportDataReq::Subscribe(req) => req.dataver_filters.as_ref(), + ReportDataReq::Read(req) => req.dataver_filters(), + ReportDataReq::Subscribe(req) => req.dataver_filters(), } } - pub fn fabric_filtered(&self) -> bool { + pub fn fabric_filtered(&self) -> Result { match self { - ReportDataReq::Read(req) => req.fabric_filtered, - ReportDataReq::Subscribe(req) => req.fabric_filtered, + ReportDataReq::Read(req) => req.fabric_filtered(), + ReportDataReq::Subscribe(req) => req.fabric_filtered(), } } } impl StatusResp { pub fn write(wb: &mut WriteBuf, status: IMStatusCode) -> Result<(), Error> { - let mut tw = TLVWriter::new(wb); - let status = Self { status }; - status.to_tlv(&mut tw, TagType::Anonymous) + status.to_tlv(&TagType::Anonymous, wb) } } @@ -194,10 +197,8 @@ impl SubscribeResp { subscription_id: u32, max_int: u16, ) -> Result<&'a [u8], Error> { - let mut tw = TLVWriter::new(wb); - let resp = Self::new(subscription_id, max_int); - resp.to_tlv(&mut tw, TagType::Anonymous)?; + resp.to_tlv(&TagType::Anonymous, &mut *wb)?; Ok(wb.as_slice()) } diff --git a/rs-matter/src/interaction_model/messages.rs b/rs-matter/src/interaction_model/messages.rs index a1372965..1ab36089 100644 --- a/rs-matter/src/interaction_model/messages.rs +++ b/rs-matter/src/interaction_model/messages.rs @@ -32,7 +32,11 @@ pub struct GenericPath { } impl GenericPath { - pub fn new(endpoint: Option, cluster: Option, leaf: Option) -> Self { + pub const fn new( + endpoint: Option, + cluster: Option, + leaf: Option, + ) -> Self { Self { endpoint, cluster, @@ -66,9 +70,12 @@ impl GenericPath { pub mod msg { + use core::fmt; + use crate::{ + error::Error, interaction_model::core::IMStatusCode, - tlv::{FromTLV, TLVArray, ToTLV}, + tlv::{FromTLV, TLVArray, TLVElement, ToTLV}, }; use super::ib::{ @@ -76,42 +83,83 @@ pub mod msg { EventPath, }; - #[derive(Debug, Default, FromTLV, ToTLV)] + #[derive(Debug, Default, Clone, FromTLV, ToTLV)] #[tlvargs(lifetime = "'a")] pub struct SubscribeReq<'a> { pub keep_subs: bool, pub min_int_floor: u16, pub max_int_ceil: u16, pub attr_requests: Option>, - event_requests: Option>, - event_filters: Option>, + pub event_requests: Option>, + pub event_filters: Option>, // The Context Tags are discontiguous for some reason - _dummy: Option, + pub _dummy: Option, pub fabric_filtered: bool, pub dataver_filters: Option>, } - impl<'a> SubscribeReq<'a> { - pub fn new(fabric_filtered: bool, min_int_floor: u16, max_int_ceil: u16) -> Self { - Self { - fabric_filtered, - min_int_floor, - max_int_ceil, - ..Default::default() - } + #[derive(FromTLV, ToTLV, Clone, PartialEq, Eq, Hash)] + #[tlvargs(lifetime = "'a")] + pub struct SubscribeReqRef<'a>(TLVElement<'a>); + + impl<'a> SubscribeReqRef<'a> { + pub const fn new(element: TLVElement<'a>) -> Self { + Self(element) + } + + pub fn keep_subs(&self) -> Result { + self.0.r#struct()?.find_ctx(0)?.bool() + } + + pub fn min_int_floor(&self) -> Result { + self.0.r#struct()?.find_ctx(1)?.u16() + } + + pub fn max_int_ceil(&self) -> Result { + self.0.r#struct()?.find_ctx(2)?.u16() } - pub fn set_attr_requests(mut self, requests: &'a [AttrPath]) -> Self { - self.attr_requests = Some(TLVArray::new(requests)); - self + pub fn attr_requests(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(3)?) + } + + pub fn event_requests(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(4)?) + } + + pub fn event_filters(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(5)?) + } + + pub fn fabric_filtered(&self) -> Result { + self.0.r#struct()?.find_ctx(7)?.bool() + } + + pub fn dataver_filters(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(8)?) + } + } + + impl fmt::Debug for SubscribeReqRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SubscribeReqRef") + .field("keep_subs", &self.keep_subs()) + .field("min_int_floor", &self.min_int_floor()) + .field("max_int_ceil", &self.max_int_ceil()) + .field("attr_requests", &self.attr_requests()) + .field("event_requests", &self.event_requests()) + .field("event_filters", &self.event_filters()) + .field("fabric_filtered", &self.fabric_filtered()) + .field("dataver_filters", &self.dataver_filters()) + .finish() } } - #[derive(Debug, FromTLV, ToTLV)] + #[derive(Debug, Default, Clone, FromTLV, ToTLV)] pub struct SubscribeResp { pub subs_id: u32, // The Context Tags are discontiguous for some reason - _dummy: Option, + pub _dummy: Option, pub max_int: u16, } @@ -125,12 +173,12 @@ pub mod msg { } } - #[derive(FromTLV, ToTLV, Debug)] + #[derive(Debug, Clone, FromTLV, ToTLV)] pub struct TimedReq { pub timeout: u16, } - #[derive(FromTLV, ToTLV)] + #[derive(Debug, Clone, FromTLV, ToTLV)] pub struct StatusResp { pub status: IMStatusCode, } @@ -141,7 +189,7 @@ pub mod msg { InvokeRequests = 2, } - #[derive(FromTLV, ToTLV, Debug)] + #[derive(Debug, Default, Clone, FromTLV, ToTLV)] #[tlvargs(lifetime = "'a")] pub struct InvReq<'a> { pub suppress_response: Option, @@ -149,45 +197,129 @@ pub mod msg { pub inv_requests: Option>>, } + #[derive(FromTLV, ToTLV, Clone, PartialEq, Eq, Hash)] + #[tlvargs(lifetime = "'a")] + pub struct InvReqRef<'a>(TLVElement<'a>); + + impl<'a> InvReqRef<'a> { + pub const fn new(element: TLVElement<'a>) -> Self { + Self(element) + } + + pub fn suppress_response(&self) -> Result { + self.0 + .r#struct()? + .find_ctx(0)? + .non_empty() + .map(|t| t.bool()) + .unwrap_or(Ok(false)) + } + + pub fn timed_request(&self) -> Result { + self.0 + .r#struct()? + .find_ctx(1)? + .non_empty() + .map(|t| t.bool()) + .unwrap_or(Ok(false)) + } + + pub fn inv_requests(&self) -> Result>>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(2)?) + } + } + + impl fmt::Debug for InvReqRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InvReqRef") + .field("suppress_response", &self.suppress_response()) + .field("timed_request", &self.timed_request()) + .field("inv_requests", &self.inv_requests()) + .finish() + } + } + #[derive(FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] pub struct InvResp<'a> { pub suppress_response: Option, - pub inv_responses: Option>>, + pub inv_responses: Option>>, } // This enum is helpful when we are constructing the response // step by step in incremental manner + #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] + #[repr(u8)] pub enum InvRespTag { SupressResponse = 0, InvokeResponses = 1, } - #[derive(Default, ToTLV, FromTLV, Debug)] + #[derive(Debug, Default, Clone, ToTLV, FromTLV)] #[tlvargs(lifetime = "'a")] pub struct ReadReq<'a> { pub attr_requests: Option>, - event_requests: Option>, - event_filters: Option>, + pub event_requests: Option>, + pub event_filters: Option>, pub fabric_filtered: bool, pub dataver_filters: Option>, } - impl<'a> ReadReq<'a> { - pub fn new(fabric_filtered: bool) -> Self { - Self { - fabric_filtered, - ..Default::default() - } + #[derive(FromTLV, ToTLV, Clone, PartialEq, Eq, Hash)] + #[tlvargs(lifetime = "'a")] + pub struct ReadReqRef<'a>(TLVElement<'a>); + + impl<'a> ReadReqRef<'a> { + pub const fn new(element: TLVElement<'a>) -> Self { + Self(element) + } + + pub fn attr_requests(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(0)?) } - pub fn set_attr_requests(mut self, requests: &'a [AttrPath]) -> Self { - self.attr_requests = Some(TLVArray::new(requests)); - self + pub fn event_requests(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(1)?) + } + + pub fn event_filters(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(2)?) + } + + pub fn fabric_filtered(&self) -> Result { + self.0.r#struct()?.find_ctx(3)?.bool() + } + + pub fn dataver_filters(&self) -> Result>, Error> { + Option::from_tlv(&self.0.r#struct()?.find_ctx(4)?) } } - #[derive(FromTLV, ToTLV, Debug)] + impl fmt::Debug for ReadReqRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ReadReqRef") + .field("attr_requests", &self.attr_requests()) + .field("event_requests", &self.event_requests()) + .field("event_filters", &self.event_filters()) + .field("fabric_filtered", &self.fabric_filtered()) + .field("dataver_filters", &self.dataver_filters()) + .finish() + } + } + + // This enum is helpful when we are constructing the request + // step by step in incremental manner + #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] + #[repr(u8)] + pub enum ReadReqTag { + AttrRequests = 0, + EventRequests = 1, + EventFilters = 2, + FabricFiltered = 3, + DataVersionFilters = 4, + } + + #[derive(Debug, Clone, FromTLV, ToTLV)] #[tlvargs(lifetime = "'a")] pub struct WriteReq<'a> { pub supress_response: Option, @@ -196,18 +328,66 @@ pub mod msg { pub more_chunked: Option, } - impl<'a> WriteReq<'a> { - pub fn new(supress_response: bool, write_requests: &'a [AttrData<'a>]) -> Self { - let mut w = Self { - supress_response: None, - write_requests: TLVArray::new(write_requests), - timed_request: None, - more_chunked: None, - }; - if supress_response { - w.supress_response = Some(true); - } - w + // This enum is helpful when we are constructing the request + // step by step in incremental manner + #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] + #[repr(u8)] + pub enum WriteReqTag { + SuppressResponse = 0, + TimedRequest = 1, + WriteRequests = 2, + MoreChunked = 3, + } + + #[derive(FromTLV, ToTLV, Clone, PartialEq, Eq, Hash)] + #[tlvargs(lifetime = "'a")] + pub struct WriteReqRef<'a>(TLVElement<'a>); + + impl<'a> WriteReqRef<'a> { + pub const fn new(element: TLVElement<'a>) -> Self { + Self(element) + } + + pub fn supress_response(&self) -> Result { + self.0 + .r#struct()? + .find_ctx(0)? + .non_empty() + .map(|t| t.bool()) + .unwrap_or(Ok(false)) + } + + pub fn timed_request(&self) -> Result { + self.0 + .r#struct()? + .find_ctx(1)? + .non_empty() + .map(|t| t.bool()) + .unwrap_or(Ok(false)) + } + + pub fn write_requests(&self) -> Result, Error> { + TLVArray::new(self.0.r#struct()?.find_ctx(2)?) + } + + pub fn more_chunked(&self) -> Result { + self.0 + .r#struct()? + .find_ctx(3)? + .non_empty() + .map(|t| t.bool()) + .unwrap_or(Ok(false)) + } + } + + impl fmt::Debug for WriteReqRef<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WriteReqRef") + .field("supress_response", &self.supress_response()) + .field("timed_request", &self.timed_request()) + .field("write_requests", &self.write_requests()) + .field("more_chunked", &self.more_chunked()) + .finish() } } @@ -247,10 +427,10 @@ pub mod ib { use core::fmt::Debug; use crate::{ - data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EncodeValue, EndptId}, + data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EndptId}, error::{Error, ErrorCode}, interaction_model::core::IMStatusCode, - tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV}, + tlv::{FromTLV, Nullable, TLVElement, TLVTag, TLVWrite, ToTLV, TLV}, }; use log::error; @@ -259,12 +439,12 @@ pub mod ib { // Command Response #[derive(Clone, FromTLV, ToTLV, Debug)] #[tlvargs(lifetime = "'a")] - pub enum InvResp<'a> { + pub enum CmdResp<'a> { Cmd(CmdData<'a>), Status(CmdStatus), } - impl<'a> InvResp<'a> { + impl<'a> CmdResp<'a> { pub fn status_new(cmd_path: CmdPath, status: IMStatusCode, cluster_status: u16) -> Self { Self::Status(CmdStatus { path: cmd_path, @@ -273,18 +453,18 @@ pub mod ib { } } - impl<'a> From> for InvResp<'a> { + impl<'a> From> for CmdResp<'a> { fn from(value: CmdData<'a>) -> Self { Self::Cmd(value) } } - pub enum InvRespTag { + pub enum CmdRespTag { Cmd = 0, Status = 1, } - impl<'a> From for InvResp<'a> { + impl<'a> From for CmdResp<'a> { fn from(value: CmdStatus) -> Self { Self::Status(value) } @@ -312,11 +492,11 @@ pub mod ib { #[tlvargs(lifetime = "'a")] pub struct CmdData<'a> { pub path: CmdPath, - pub data: EncodeValue<'a>, + pub data: TLVElement<'a>, } impl<'a> CmdData<'a> { - pub fn new(path: CmdPath, data: EncodeValue<'a>) -> Self { + pub const fn new(path: CmdPath, data: TLVElement<'a>) -> Self { Self { path, data } } } @@ -327,7 +507,7 @@ pub mod ib { } // Status - #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)] + #[derive(Debug, Clone, PartialEq, Eq, Hash, FromTLV, ToTLV)] pub struct Status { pub status: IMStatusCode, pub cluster_status: u16, @@ -379,16 +559,16 @@ pub mod ib { } // Attribute Data - #[derive(Clone, PartialEq, FromTLV, ToTLV, Debug)] + #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)] #[tlvargs(lifetime = "'a")] pub struct AttrData<'a> { pub data_ver: Option, pub path: AttrPath, - pub data: EncodeValue<'a>, + pub data: TLVElement<'a>, } impl<'a> AttrData<'a> { - pub fn new(data_ver: Option, path: AttrPath, data: EncodeValue<'a>) -> Self { + pub fn new(data_ver: Option, path: AttrPath, data: TLVElement<'a>) -> Self { Self { data_ver, path, @@ -421,7 +601,7 @@ pub mod ib { where F: FnMut(ListOperation, &TLVElement) -> Result<(), Error>, { - if let Some(Nullable::NotNull(index)) = attr.list_index { + if let Some(Some(index)) = attr.list_index.map(Into::into) { // If list index is valid, // - this is a modify item or delete item operation if data.null().is_ok() { @@ -430,15 +610,14 @@ pub mod ib { } else { f(ListOperation::EditItem(index), data) } - } else if data.confirm_array().is_ok() { + } else if let Ok(array) = data.array() { // If data is list, this is either Delete List or OverWrite List operation // in either case, we have to first delete the whole list f(ListOperation::DeleteList, data)?; // Now the data must be a list, that should be added item by item - let container = data.enter().ok_or(ErrorCode::Invalid)?; - for d in container { - f(ListOperation::AddItem, &d)?; + for d in array.iter() { + f(ListOperation::AddItem, &d?)?; } Ok(()) } else { @@ -447,7 +626,7 @@ pub mod ib { } } - #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)] + #[derive(Debug, Clone, PartialEq, Eq, Hash, FromTLV, ToTLV)] pub struct AttrStatus { path: AttrPath, status: Status, @@ -463,7 +642,7 @@ pub mod ib { } // Attribute Path - #[derive(Default, Clone, Debug, PartialEq, FromTLV, ToTLV)] + #[derive(Default, Clone, Debug, PartialEq, Eq, Hash, FromTLV, ToTLV)] #[tlvargs(datatype = "list")] pub struct AttrPath { pub tag_compression: Option, @@ -475,12 +654,18 @@ pub mod ib { } impl AttrPath { - pub fn new(path: &GenericPath) -> Self { + pub const fn new(path: &GenericPath) -> Self { Self { endpoint: path.endpoint, cluster: path.cluster, - attr: path.leaf.map(|x| x as u16), - ..Default::default() + attr: if let Some(leaf) = path.leaf { + Some(leaf as u16) + } else { + None + }, + tag_compression: None, + node: None, + list_index: None, } } @@ -541,8 +726,12 @@ pub mod ib { } impl ToTLV for CmdPath { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - self.path.to_tlv(tw, tag_type) + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + self.path.to_tlv(tag, tw) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + self.path.tlv_iter(tag) } } diff --git a/rs-matter/src/lib.rs b/rs-matter/src/lib.rs index aa12f3a6..114e4db3 100644 --- a/rs-matter/src/lib.rs +++ b/rs-matter/src/lib.rs @@ -71,6 +71,7 @@ //! Start off exploring by going to the [Matter] object. #![cfg_attr(not(feature = "std"), no_std)] #![allow(async_fn_in_trait)] +#![recursion_limit = "256"] pub mod acl; pub mod cert; diff --git a/rs-matter/src/pairing/mod.rs b/rs-matter/src/pairing/mod.rs index b10081a1..c26e886d 100644 --- a/rs-matter/src/pairing/mod.rs +++ b/rs-matter/src/pairing/mod.rs @@ -17,22 +17,22 @@ //! This module contains the logic for generating the pairing code and the QR code for easy pairing. -pub mod code; -pub mod qr; -pub mod vendor_identifiers; - use log::info; + +use qr::no_optional_data; use verhoeff::Verhoeff; -use crate::{ - codec::base38, data_model::cluster_basic_information::BasicInfoConfig, error::Error, - secure_channel::spake2p::VerifierOption, CommissioningData, -}; +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::Error; +use crate::secure_channel::spake2p::VerifierOption; +use crate::CommissioningData; -use self::{ - code::{compute_pairing_code, pretty_print_pairing_code}, - qr::{compute_qr_code_text, print_qr_code}, -}; +use self::code::{compute_pairing_code, pretty_print_pairing_code}; +use self::qr::{compute_qr_code_text, print_qr_code}; + +pub mod code; +pub mod qr; +pub mod vendor_identifiers; // TODO: Rework as a `bitflags!` enum #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -90,10 +90,17 @@ pub fn print_pairing_code_and_qr( buf: &mut [u8], ) -> Result<(), Error> { let pairing_code = compute_pairing_code(comm_data); + pretty_print_pairing_code(&pairing_code); - let (qr_code, remaining_buf) = - compute_qr_code_text(dev_det, comm_data, discovery_capabilities, &[], buf)?; + let (qr_code, remaining_buf) = compute_qr_code_text( + dev_det, + comm_data, + discovery_capabilities, + no_optional_data, + buf, + )?; + print_qr_code(qr_code, remaining_buf)?; Ok(()) diff --git a/rs-matter/src/pairing/qr.rs b/rs-matter/src/pairing/qr.rs index 44c14af1..55d65c3b 100644 --- a/rs-matter/src/pairing/qr.rs +++ b/rs-matter/src/pairing/qr.rs @@ -15,18 +15,17 @@ * limitations under the License. */ +use core::iter::Empty; + use qrcodegen_no_heap::{QrCode, QrCodeEcc, Version}; -use crate::{ - error::ErrorCode, - tlv::{ElementType, TLVElement, TLVWriter, TagType, ToTLV}, - utils::storage::WriteBuf, -}; +use crate::codec::base38; +use crate::error::ErrorCode; +use crate::tlv::{EitherIter, TLVTag, TLV}; +use crate::utils::storage::WriteBuf; -use super::{ - vendor_identifiers::{is_vendor_id_valid_operationally, VendorId}, - *, -}; +use super::vendor_identifiers::{is_vendor_id_valid_operationally, VendorId}; +use super::*; // See section 5.1.2. QR Code in the Matter specification const LONG_BITS: usize = 12; @@ -56,23 +55,27 @@ pub const BPKFSALT_TAG: u8 = 0x02; pub const NUMBER_OFDEVICES_TAG: u8 = 0x03; pub const COMMISSIONING_TIMEOUT_TAG: u8 = 0x04; -pub struct QrSetupPayload<'data> { +pub struct QrSetupPayload<'data, T> { version: u8, flow_type: CommissionningFlowType, discovery_capabilities: DiscoveryCapabilities, dev_det: &'data BasicInfoConfig<'data>, comm_data: &'data CommissioningData, - // The slice must be ordered by the tag of each `TLVElement` in ascending order. - optional_data: &'data [TLVElement<'data>], + // The data written by the optional data provider must be ordered by the tag of each TLV element in ascending order. + optional_data: T, } -impl<'data> QrSetupPayload<'data> { +impl<'data, T, I> QrSetupPayload<'data, T> +where + T: Fn() -> I, + I: Iterator> + 'data, +{ /// `optional_data` should be ordered by tag number in ascending order. pub fn new( dev_det: &'data BasicInfoConfig, comm_data: &'data CommissioningData, discovery_capabilities: DiscoveryCapabilities, - optional_data: &'data [TLVElement<'data>], + optional_data: T, ) -> Self { const DEFAULT_VERSION: u8 = 0; @@ -161,104 +164,27 @@ impl<'data> QrSetupPayload<'data> { } pub fn try_as_str<'a>(&self, buf: &'a mut [u8]) -> Result<(&'a str, &'a mut [u8]), Error> { - let str_len = self.try_iter(buf)?.count(); + let str_len = self.emit_chars().count(); let (str_buf, remaining_buf) = buf.split_at_mut(str_len); let mut wb = WriteBuf::new(str_buf); - for ch in self.try_iter(remaining_buf)? { - wb.le_u8(ch as u8)?; + for ch in self.emit_chars() { + wb.le_u8(ch? as u8)?; } let str = unsafe { core::str::from_utf8_unchecked(str_buf) }; Ok((str, remaining_buf)) } - pub fn try_iter<'a>( - &'a self, - tlv_buf: &'a mut [u8], - ) -> Result + 'a, Error> { - let iter = self.emit_chars(self.optional_data_to_tlv(tlv_buf)?.iter().copied()); - - Ok(iter) - } - - pub fn estimate_optional_data_tlv(&self) -> Result { - let mut estimate = 0; - - let data_item_size_estimate = |info: &TLVElement| { - // Each data item needs a control byte and a context tag. - let mut size: usize = 2; - - if let &ElementType::Utf8l(data) = info.get_element_type() { - // We'll need to encode the string length and then the string data. - // Length is at most 8 bytes. - size += 8; - size += data.len() - } else { - // Integer. Assume it might need up to 8 bytes, for simplicity. - size += 8; - } - - size - }; - - for data in self.optional_data { - estimate += data_item_size_estimate(data); - } - - // Estimate 4 bytes of overhead per field. This can happen for a large - // octet string field: 1 byte control, 1 byte context tag, 2 bytes - // length. - // - // The struct itself has a control byte and an end-of-struct marker. - estimate += 4 + 2; - - if estimate > u32::MAX as usize { - Err(ErrorCode::NoMemory)?; - } - - Ok(estimate) - } - - pub fn optional_data_to_tlv<'a>(&self, buf: &'a mut [u8]) -> Result<&'a [u8], Error> { - if self.optional_data.is_empty() && self.dev_det.serial_no.is_empty() { - Ok(&[]) - } else { - let mut wb = WriteBuf::new(buf); - let mut tw = TLVWriter::new(&mut wb); - - tw.start_struct(TagType::Anonymous)?; - - if !self.dev_det.serial_no.is_empty() { - tw.utf8( - TagType::Context(SERIAL_NUMBER_TAG), - self.dev_det.serial_no.as_bytes(), - )?; - } - - for elem in self.optional_data { - elem.to_tlv(&mut tw, TagType::Anonymous)?; - } - - tw.end_container()?; - - let end = wb.get_tail(); - Ok(&buf[..end]) - } - } - - fn emit_chars<'a, T>(&'a self, tlv_data: T) -> impl Iterator + 'a - where - T: Iterator + 'a, - { + pub fn emit_chars(&self) -> impl Iterator> + '_ { struct PackedBitsIterator(I); impl Iterator for PackedBitsIterator where - I: Iterator, + I: Iterator>, { - type Item = (u32, u8); + type Item = Result<(u32, u8), Error>; fn next(&mut self) -> Option { let mut chunk = 0; @@ -267,6 +193,11 @@ impl<'data> QrSetupPayload<'data> { for index in 0..24 { // Up to 24 bits as we are enclding with Base38, which means up to 3 bytes at once if let Some(bit) = self.0.next() { + let bit = match bit { + Ok(bit) => bit, + Err(err) => return Some(Err(err)), + }; + chunk |= (bit as u32) << index; packed_bits += 1; } else { @@ -277,23 +208,27 @@ impl<'data> QrSetupPayload<'data> { if packed_bits > 0 { assert!(packed_bits % 8 == 0); - Some((chunk, packed_bits)) + Some(Ok((chunk, packed_bits))) } else { None } } } - "MT:".chars().chain( - PackedBitsIterator(self.emit_all_bits(tlv_data)) - .flat_map(|(bits, bits_count)| base38::encode_bits(bits, bits_count)), - ) + "MT:" + .chars() + .map(Result::Ok) + .chain( + PackedBitsIterator(self.emit_all_bits()).flat_map(|bits| match bits { + Ok((bits, bits_count)) => { + EitherIter::First(base38::encode_bits(bits, bits_count).map(Result::Ok)) + } + Err(err) => EitherIter::Second(core::iter::once(Err(err))), + }), + ) } - fn emit_all_bits<'a, I>(&'a self, tlv_data: I) -> impl Iterator + 'a - where - I: Iterator + 'a, - { + fn emit_all_bits(&self) -> impl Iterator> + '_ { let passwd = passwd_from_comm_data(self.comm_data); Self::emit_bits(self.version as _, VERSION_FIELD_LENGTH_IN_BITS) @@ -322,11 +257,51 @@ impl<'data> QrSetupPayload<'data> { SETUP_PINCODE_FIELD_LENGTH_IN_BITS, )) .chain(Self::emit_bits(0, PADDING_FIELD_LENGTH_IN_BITS)) - .chain(tlv_data.flat_map(|b| Self::emit_bits(b as _, 8))) + .chain( + self.emit_optional_tlv_data() + .flat_map(|bits| Self::emit_maybe_bits(bits.map(|bits| (bits as _, 8)))), + ) + } + + fn emit_bits(input: u32, len: usize) -> impl Iterator> { + (0..len).map(move |i| Ok((input >> i) & 1 == 1)) + } + + fn emit_maybe_bits( + bits: Result<(u32, usize), Error>, + ) -> impl Iterator> { + match bits { + Ok((input, len)) => EitherIter::First(Self::emit_bits(input, len)), + Err(err) => EitherIter::Second(core::iter::once(Err(err))), + } } - fn emit_bits(input: u32, len: usize) -> impl Iterator { - (0..len).map(move |i| (input >> i) & 1 == 1) + fn emit_optional_tlv_data(&self) -> impl Iterator> + '_ { + if self.dev_det.serial_no.is_empty() && (self.optional_data)().next().is_none() { + return EitherIter::First(core::iter::empty()); + } + + let serial_no = if self.dev_det.serial_no.is_empty() { + EitherIter::First(core::iter::empty()) + } else { + EitherIter::Second( + TLV::utf8(TLVTag::Context(SERIAL_NUMBER_TAG), self.dev_det.serial_no) + .into_tlv_iter(), + ) + }; + + EitherIter::Second( + TLV::structure(TLVTag::Anonymous) + .into_tlv_iter() + .chain(serial_no) + .flat_map(TLV::result_into_bytes_iter) + .chain((self.optional_data)()) + .chain( + TLV::end_container() + .into_tlv_iter() + .flat_map(TLV::result_into_bytes_iter), + ), + ) } } @@ -544,23 +519,33 @@ pub fn compute_qr_code_version(qr_code_text: &str) -> u8 { } } -pub fn compute_qr_code_text<'a>( +pub fn compute_qr_code_text<'a, T, I>( dev_det: &BasicInfoConfig, comm_data: &CommissioningData, discovery_capabilities: DiscoveryCapabilities, - optional_data: &[TLVElement], + optional_data: T, buf: &'a mut [u8], -) -> Result<(&'a str, &'a mut [u8]), Error> { +) -> Result<(&'a str, &'a mut [u8]), Error> +where + T: Fn() -> I, + I: Iterator>, +{ let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities, optional_data); qr_code_data.try_as_str(buf) } +pub type NoOptionalData = fn() -> Empty>; + +pub fn no_optional_data() -> Empty> { + core::iter::empty() +} + #[cfg(test)] mod tests { use super::*; - use crate::{secure_channel::spake2p::VerifierData, tlv::ElementType, utils::rand::dummy_rand}; + use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand}; #[test] fn can_base38_encode() { @@ -577,7 +562,8 @@ mod tests { }; let disc_cap = DiscoveryCapabilities::new(false, true, false); - let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap, &[]); + let qr_code_data = + QrSetupPayload::::new(&dev_det, &comm_data, disc_cap, no_optional_data); let mut buf = [0; 1024]; let data_str = qr_code_data .try_as_str(&mut buf) @@ -602,7 +588,8 @@ mod tests { }; let disc_cap = DiscoveryCapabilities::new(true, false, false); - let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap, &[]); + let qr_code_data = + QrSetupPayload::::new(&dev_det, &comm_data, disc_cap, no_optional_data); let mut buf = [0; 1024]; let data_str = qr_code_data .try_as_str(&mut buf) @@ -633,18 +620,23 @@ mod tests { }; let disc_cap = DiscoveryCapabilities::new(true, false, false); - let optional_data = [ - TLVElement::new( - TagType::Context(OPTIONAL_DEFAULT_STRING_TAG), - ElementType::Utf8l(OPTIONAL_DEFAULT_STRING_VALUE.as_bytes()), - ), - // todo: check why unsigned ints are not accepted by 'chip-tool payload parse-setup-payload' - TLVElement::new( - TagType::Context(OPTIONAL_DEFAULT_INT_TAG), - ElementType::S32(OPTIONAL_DEFAULT_INT_VALUE), - ), - ]; - let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap, &optional_data); + let optional_data = || { + TLV::utf8( + TLVTag::Context(OPTIONAL_DEFAULT_STRING_TAG), + OPTIONAL_DEFAULT_STRING_VALUE, + ) + .into_tlv_iter() + .chain( + TLV::i32( + TLVTag::Context(OPTIONAL_DEFAULT_INT_TAG), + OPTIONAL_DEFAULT_INT_VALUE, + ) + .into_tlv_iter(), + ) + .flat_map(TLV::result_into_bytes_iter) + }; + + let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap, optional_data); let mut buf = [0; 1024]; let data_str = qr_code_data diff --git a/rs-matter/src/secure_channel/case.rs b/rs-matter/src/secure_channel/case.rs index e1c4fe7c..01630be0 100644 --- a/rs-matter/src/secure_channel/case.rs +++ b/rs-matter/src/secure_channel/case.rs @@ -15,23 +15,23 @@ * limitations under the License. */ -use core::num::NonZeroU8; +use core::{mem::MaybeUninit, num::NonZeroU8}; use log::{error, trace}; use crate::{ alloc, - cert::Cert, + cert::CertRef, crypto::{self, KeyPair, Sha256}, error::{Error, ErrorCode}, fabric::Fabric, - secure_channel::common::{complete_with_status, OpCode, SCStatusCodes}, - tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType}, + secure_channel::common::{complete_with_status, sc_write, OpCode, SCStatusCodes}, + tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVTag, TLVWrite}, transport::{ exchange::Exchange, session::{NocCatIds, ReservedSession, SessionMode}, }, - utils::{rand::Rand, storage::WriteBuf}, + utils::{init::InitMaybeUninit, rand::Rand, storage::WriteBuf}, }; #[derive(Debug, Clone)] @@ -109,9 +109,9 @@ impl Case { .and_then(|fabric_idx| fabric_mgr.get_fabric(fabric_idx)); if let Some(fabric) = fabric { let root = get_root_node_struct(exchange.rx()?.payload())?; - let encrypted = root.find_tag(1)?.slice()?; + let encrypted = root.structure()?.ctx(1)?.str()?; - let mut decrypted = alloc!([0; 800]); + let mut decrypted = alloc!([0; 800]); // TODO LARGE BUFFER if encrypted.len() > decrypted.len() { error!("Data too large"); Err(ErrorCode::NoSpace)?; @@ -126,19 +126,19 @@ impl Case { let root = get_root_node_struct(decrypted)?; let d = Sigma3Decrypt::from_tlv(&root)?; - let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?); - let mut initiator_icac = None; - if let Some(icac) = d.initiator_icac { - initiator_icac = Some(alloc!(Cert::new(icac.0)?)); - } - - #[cfg(feature = "alloc")] - let initiator_icac_mut = initiator_icac.as_deref(); - - #[cfg(not(feature = "alloc"))] - let initiator_icac_mut = initiator_icac.as_ref(); + let initiator_noc = CertRef::new(TLVElement::new(d.initiator_noc.0)); + let initiator_icac = d + .initiator_icac + .map(|icac| CertRef::new(TLVElement::new(icac.0))); - if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { + let mut validate_certs_buf = alloc!([0; 800]); // TODO LARGE BUFFER + let validate_certs_buf = &mut validate_certs_buf[..]; + if let Err(e) = Case::validate_certs( + fabric, + &initiator_noc, + initiator_icac.as_ref(), + validate_certs_buf, + ) { error!("Certificate Chain doesn't match: {}", e); SCStatusCodes::InvalidParameter } else if let Err(e) = Case::validate_sigma3_sign( @@ -153,19 +153,21 @@ impl Case { } else { // Only now do we add this message to the TT Hash let mut peer_catids: NocCatIds = Default::default(); - initiator_noc.get_cat_ids(&mut peer_catids); + initiator_noc.get_cat_ids(&mut peer_catids)?; case_session .tt_hash .as_mut() .unwrap() .update(exchange.rx()?.payload())?; - let mut session_keys = [0_u8; 3 * crypto::SYMM_KEY_LEN_BYTES]; + let mut session_keys = + MaybeUninit::<[u8; 3 * crypto::SYMM_KEY_LEN_BYTES]>::uninit(); // TODO MEDIM BUFFER + let session_keys = session_keys.init_zeroed(); Case::get_session_keys( fabric.ipk.op_key(), case_session.tt_hash.as_ref().unwrap(), &case_session.shared_secret, - &mut session_keys, + session_keys, )?; let peer_addr = exchange.with_session(|sess| Ok(sess.get_peer_addr()))?; @@ -264,113 +266,100 @@ impl Case { } // println!("Derived secret: {:x?} len: {}", secret, len); - let mut our_random: [u8; 32] = [0; 32]; - (exchange.matter().rand())(&mut our_random); + let mut our_random = MaybeUninit::<[u8; 32]>::uninit(); // TODO MEDIUM BUFFER + let our_random = our_random.init_zeroed(); + (exchange.matter().rand())(our_random); - // Derive the Encrypted Part - const MAX_ENCRYPTED_SIZE: usize = 800; + let mut hash_updated = false; + exchange + .send_with(|exchange, tw| { + let fabric_mgr = exchange.matter().fabric_mgr.borrow(); - let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]); - let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]); + let fabric = NonZeroU8::new(case_session.local_fabric_idx) + .and_then(|fabric_idx| fabric_mgr.get_fabric(fabric_idx)); - let encrypted_len = { - let fabric_mgr = exchange.matter().fabric_mgr.borrow(); + let Some(fabric) = fabric else { + return sc_write(tw, SCStatusCodes::NoSharedTrustRoots, &[]); + }; - let fabric = NonZeroU8::new(case_session.local_fabric_idx) - .and_then(|fabric_idx| fabric_mgr.get_fabric(fabric_idx)); - if let Some(fabric) = fabric { - #[cfg(feature = "alloc")] - let signature_mut = &mut *signature; + tw.start_struct(&TLVTag::Anonymous)?; + tw.str(&TLVTag::Context(1), &*our_random)?; + tw.u16(&TLVTag::Context(2), local_sessid)?; + tw.str(&TLVTag::Context(3), &case_session.our_pub_key)?; + + // Use the remainder of the TX buffer as scratch space for performing signature + let sign_buf = tw.empty_as_mut_slice(); - #[cfg(not(feature = "alloc"))] - let signature_mut = &mut signature; + let mut signature = MaybeUninit::<[u8; crypto::EC_SIGNATURE_LEN_BYTES]>::uninit(); // TODO MEDIUM BUFFER + let signature = signature.init_zeroed(); let sign_len = Case::get_sigma2_sign( fabric, &case_session.our_pub_key, &case_session.peer_pub_key, - signature_mut, - )?; - let signature = &signature[..sign_len]; - - #[cfg(feature = "alloc")] - let encrypted_mut = &mut *encrypted; - - #[cfg(not(feature = "alloc"))] - let encrypted_mut = &mut encrypted; - - let encrypted_len = Case::get_sigma2_encryption( - fabric, - exchange.matter().rand(), - &our_random, - case_session, + sign_buf, signature, - encrypted_mut, )?; - Some(encrypted_len) - } else { - None - } - }; + let signature = &signature[..sign_len]; - if let Some(encrypted_len) = encrypted_len { - let mut hash_updated = false; - let encrypted = &encrypted[0..encrypted_len]; - - exchange - .send_with(|_, wb| { - let mut tw = TLVWriter::new(wb); - tw.start_struct(TagType::Anonymous)?; - tw.str8(TagType::Context(1), &our_random)?; - tw.u16(TagType::Context(2), local_sessid)?; - tw.str8(TagType::Context(3), &case_session.our_pub_key)?; - tw.str16(TagType::Context(4), encrypted)?; - tw.end_container()?; - - if !hash_updated { - case_session - .tt_hash - .as_mut() - .unwrap() - .update(wb.as_mut_slice())?; - hash_updated = true; - } - - Ok(Some(OpCode::CASESigma2.into())) - }) - .await - } else { - complete_with_status(exchange, SCStatusCodes::NoSharedTrustRoots, &[]).await - } + tw.str_cb(&TLVTag::Context(4), |buf| { + Case::get_sigma2_encryption( + fabric, + exchange.matter().rand(), + &*our_random, + case_session, + signature, + buf, + ) + })?; + tw.end_container()?; + + if !hash_updated { + case_session + .tt_hash + .as_mut() + .unwrap() + .update(tw.as_slice())?; + hash_updated = true; + } + + Ok(Some(OpCode::CASESigma2.into())) + }) + .await } fn validate_sigma3_sign( initiator_noc: &[u8], initiator_icac: Option<&[u8]>, - initiator_noc_cert: &Cert, + initiator_noc_cert: &CertRef, sign: &[u8], case_session: &CaseSession, ) -> Result<(), Error> { const MAX_TBS_SIZE: usize = 800; let mut buf = [0; MAX_TBS_SIZE]; let mut write_buf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut write_buf); - tw.start_struct(TagType::Anonymous)?; - tw.str16(TagType::Context(1), initiator_noc)?; + let tw = &mut write_buf; + tw.start_struct(&TLVTag::Anonymous)?; + tw.str(&TLVTag::Context(1), initiator_noc)?; if let Some(icac) = initiator_icac { - tw.str16(TagType::Context(2), icac)?; + tw.str(&TLVTag::Context(2), icac)?; } - tw.str8(TagType::Context(3), &case_session.peer_pub_key)?; - tw.str8(TagType::Context(4), &case_session.our_pub_key)?; + tw.str(&TLVTag::Context(3), &case_session.peer_pub_key)?; + tw.str(&TLVTag::Context(4), &case_session.our_pub_key)?; tw.end_container()?; - let key = KeyPair::new_from_public(initiator_noc_cert.get_pubkey())?; + let key = KeyPair::new_from_public(initiator_noc_cert.pubkey()?)?; key.verify_msg(write_buf.as_slice(), sign)?; Ok(()) } - fn validate_certs(fabric: &Fabric, noc: &Cert, icac: Option<&Cert>) -> Result<(), Error> { + fn validate_certs( + fabric: &Fabric, + noc: &CertRef, + icac: Option<&CertRef>, + buf: &mut [u8], + ) -> Result<(), Error> { let mut verifier = noc.verify_chain_start(); if fabric.get_fabric_id() != noc.get_fabric_id()? { @@ -384,12 +373,12 @@ impl Case { Err(ErrorCode::Invalid)?; } } - verifier = verifier.add_cert(icac)?; + verifier = verifier.add_cert(icac, buf)?; } verifier - .add_cert(&Cert::new(&fabric.root_ca)?)? - .finalise()?; + .add_cert(&CertRef::new(TLVElement::new(&fabric.root_ca)), buf)? + .finalise(buf)?; Ok(()) } @@ -519,15 +508,15 @@ impl Case { )?; let mut write_buf = WriteBuf::new(out); - let mut tw = TLVWriter::new(&mut write_buf); - tw.start_struct(TagType::Anonymous)?; - tw.str16(TagType::Context(1), &fabric.noc)?; + let tw = &mut write_buf; + tw.start_struct(&TLVTag::Anonymous)?; + tw.str(&TLVTag::Context(1), &fabric.noc)?; if let Some(icac_cert) = fabric.icac.as_ref() { - tw.str16(TagType::Context(2), icac_cert)? + tw.str(&TLVTag::Context(2), icac_cert)? }; - tw.str8(TagType::Context(3), signature)?; - tw.str8(TagType::Context(4), &resumption_id)?; + tw.str(&TLVTag::Context(3), signature)?; + tw.str(&TLVTag::Context(4), &resumption_id)?; tw.end_container()?; //println!("TBE is {:x?}", write_buf.as_borrow_slice()); let nonce: [u8; crypto::AEAD_NONCE_LEN_BYTES] = [ @@ -555,20 +544,18 @@ impl Case { fabric: &Fabric, our_pub_key: &[u8], peer_pub_key: &[u8], + buf: &mut [u8], signature: &mut [u8], ) -> Result { - // We are guaranteed this unwrap will work - const MAX_TBS_SIZE: usize = 800; - let mut buf = [0; MAX_TBS_SIZE]; - let mut write_buf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut write_buf); - tw.start_struct(TagType::Anonymous)?; - tw.str16(TagType::Context(1), &fabric.noc)?; + let mut write_buf = WriteBuf::new(buf); + let tw = &mut write_buf; + tw.start_struct(&TLVTag::Anonymous)?; + tw.str(&TLVTag::Context(1), &fabric.noc)?; if let Some(icac_cert) = fabric.icac.as_deref() { - tw.str16(TagType::Context(2), icac_cert)?; + tw.str(&TLVTag::Context(2), icac_cert)?; } - tw.str8(TagType::Context(3), our_pub_key)?; - tw.str8(TagType::Context(4), peer_pub_key)?; + tw.str(&TLVTag::Context(3), our_pub_key)?; + tw.str(&TLVTag::Context(4), peer_pub_key)?; tw.end_container()?; //println!("TBS is {:x?}", write_buf.as_borrow_slice()); fabric.sign_msg(write_buf.as_slice(), signature) diff --git a/rs-matter/src/secure_channel/core.rs b/rs-matter/src/secure_channel/core.rs index b5d77259..6dc3069f 100644 --- a/rs-matter/src/secure_channel/core.rs +++ b/rs-matter/src/secure_channel/core.rs @@ -53,11 +53,11 @@ impl SecureChannel { match meta.opcode()? { OpCode::PBKDFParamRequest => { - let mut spake2p = alloc!(Spake2P::new()); + let mut spake2p = alloc!(Spake2P::new()); // TODO LARGE BUFFER Pake::new().handle(exchange, &mut spake2p).await } OpCode::CASESigma1 => { - let mut case_session = alloc!(CaseSession::new()); + let mut case_session = alloc!(CaseSession::new()); // TODO LARGE BUFFER Case::new().handle(exchange, &mut case_session).await } opcode => { diff --git a/rs-matter/src/secure_channel/pake.rs b/rs-matter/src/secure_channel/pake.rs index 9e7297fa..d999914e 100644 --- a/rs-matter/src/secure_channel/pake.rs +++ b/rs-matter/src/secure_channel/pake.rs @@ -23,7 +23,7 @@ use crate::crypto; use crate::error::{Error, ErrorCode}; use crate::mdns::{Mdns, ServiceMode}; use crate::secure_channel::common::{complete_with_status, OpCode}; -use crate::tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; +use crate::tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TagType, ToTLV}; use crate::transport::{ exchange::{Exchange, ExchangeId}, session::{ReservedSession, SessionMode}, @@ -278,10 +278,10 @@ impl Pake { exchange .send_with(|_, wb| { let resp = Pake1Resp { - pb: OctetStr(&pB), - cb: OctetStr(&cB), + pb: OctetStr::new(&pB), + cb: OctetStr::new(&cB), }; - resp.to_tlv(&mut TLVWriter::new(wb), TagType::Anonymous)?; + resp.to_tlv(&TagType::Anonymous, wb)?; Ok(Some(OpCode::PASEPake2.into())) }) @@ -304,8 +304,7 @@ impl Pake { let pase = exchange.matter().pase_mgr.borrow(); let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; - let root = tlv::get_root_node(rx.payload())?; - let a = PBKDFParamReq::from_tlv(&root)?; + let a = PBKDFParamReq::from_tlv(&TLVElement::new(rx.payload()))?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); Err(ErrorCode::Invalid)?; @@ -330,14 +329,14 @@ impl Pake { // Generate response let mut resp = PBKDFParamResp { init_random: OctetStr::new(initiator_random), - our_random: OctetStr(&our_random), + our_random: OctetStr::new(&our_random), local_sessid, params: None, }; if !a.has_params { let params_resp = PBKDFParamRespParams { count: session.verifier.count, - salt: OctetStr(&salt), + salt: OctetStr::new(&salt), }; resp.params = Some(params_resp); } @@ -351,7 +350,7 @@ impl Pake { let mut context_set = false; exchange .send_with(|_, wb| { - resp.to_tlv(&mut TLVWriter::new(wb), TagType::Anonymous)?; + resp.to_tlv(&TagType::Anonymous, &mut *wb)?; if !context_set { spake2p.update_context(wb.as_slice())?; @@ -455,7 +454,7 @@ struct PBKDFParamResp<'a> { #[allow(non_snake_case)] fn extract_pasepake_1_or_3_params(buf: &[u8]) -> Result<&[u8], Error> { let root = get_root_node_struct(buf)?; - let pA = root.find_tag(1)?.slice()?; + let pA = root.structure()?.ctx(1)?.str()?; Ok(pA) } diff --git a/rs-matter/src/tlv.rs b/rs-matter/src/tlv.rs new file mode 100644 index 00000000..6bf66a17 --- /dev/null +++ b/rs-matter/src/tlv.rs @@ -0,0 +1,1135 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::borrow::Borrow; +use core::fmt; +use core::iter::Once; +use core::marker::PhantomData; + +use num::FromPrimitive; +use num_traits::ToBytes; + +use crate::error::{Error, ErrorCode}; + +pub use rs_matter_macros::{FromTLV, ToTLV}; + +pub use read::*; +pub use toiter::*; +pub use traits::*; +pub use write::*; + +mod read; +mod toiter; +mod traits; +mod write; + +/// Represents the TLV tag type encoded in the control byte of each TLV element. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, num_derive::FromPrimitive)] +#[repr(u8)] +pub enum TLVTagType { + Anonymous = 0, + Context = 1, + CommonPrf16 = 2, + CommonPrf32 = 3, + ImplPrf16 = 4, + ImplPrf32 = 5, + FullQual48 = 6, + FullQual64 = 7, +} + +impl TLVTagType { + /// Return the size of the tag data following the control byte + /// in the TLV element representation. + pub const fn size(&self) -> usize { + match self { + Self::Anonymous => 0, + Self::Context => 1, + Self::CommonPrf16 => 2, + Self::CommonPrf32 => 4, + Self::ImplPrf16 => 2, + Self::ImplPrf32 => 4, + Self::FullQual48 => 6, + Self::FullQual64 => 8, + } + } +} + +impl fmt::Display for TLVTagType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Anonymous => write!(f, "Anonymous"), + Self::Context => write!(f, "Context"), + Self::CommonPrf16 => write!(f, "CommonPrf16"), + Self::CommonPrf32 => write!(f, "CommonPrf32"), + Self::ImplPrf16 => write!(f, "ImplPrf16"), + Self::ImplPrf32 => write!(f, "ImplPrf32"), + Self::FullQual48 => write!(f, "FullQual48"), + Self::FullQual64 => write!(f, "FullQual64"), + } + } +} + +/// Represents the TLV value type encoded in the control byte of each TLV element. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, num_derive::FromPrimitive)] +#[repr(u8)] +pub enum TLVValueType { + S8 = 0, + S16 = 1, + S32 = 2, + S64 = 3, + U8 = 4, + U16 = 5, + U32 = 6, + U64 = 7, + False = 8, + True = 9, + F32 = 10, + F64 = 11, + Utf8l = 12, + Utf16l = 13, + Utf32l = 14, + Utf64l = 15, + Str8l = 16, + Str16l = 17, + Str32l = 18, + Str64l = 19, + Null = 20, + Struct = 21, + Array = 22, + List = 23, + EndCnt = 24, +} + +impl TLVValueType { + /// Return the size of the value corresponding to this value type. + /// + /// If the value type has a variable size (i.e. octet and Utf8 strings), this function returns `None`. + pub const fn fixed_size(&self) -> Option { + match self { + Self::S8 => Some(1), + Self::S16 => Some(2), + Self::S32 => Some(4), + Self::S64 => Some(8), + Self::U8 => Some(1), + Self::U16 => Some(2), + Self::U32 => Some(4), + Self::U64 => Some(8), + Self::F32 => Some(4), + Self::F64 => Some(8), + Self::Utf8l + | Self::Utf16l + | Self::Utf32l + | Self::Utf64l + | Self::Str8l + | Self::Str16l + | Self::Str32l + | Self::Str64l => None, + _ => Some(0), + } + } + + /// Return the size of the length field for variable size value types. + /// + /// if the value type has a fixed size, this function returns 0. + /// Variable size types are only octet strings and utf8 strings. + pub const fn variable_size_len(&self) -> usize { + match self { + Self::Utf8l | Self::Str8l => 1, + Self::Utf16l | Self::Str16l => 2, + Self::Utf32l | Self::Str32l => 4, + Self::Utf64l | Self::Str64l => 8, + _ => 0, + } + } + + /// Convenience method to check if the value type is a container type + /// (container start or end). + pub const fn is_container(&self) -> bool { + self.is_container_start() || self.is_container_end() + } + + /// Convenience method to check if the value type is a container start type. + pub const fn is_container_start(&self) -> bool { + matches!(self, Self::Struct | Self::Array | Self::List) + } + + /// Convenience method to check if the value type is a container end type. + pub const fn is_container_end(&self) -> bool { + matches!(self, Self::EndCnt) + } + + /// Convenience method to check if the value type is an Octet String type. + pub const fn is_str(&self) -> bool { + matches!( + self, + Self::Str8l | Self::Str16l | Self::Str32l | Self::Str64l + ) + } + + /// Convenience method to check if the value type is a UTF-8 String type. + pub const fn is_utf8(&self) -> bool { + matches!( + self, + Self::Utf8l | Self::Utf16l | Self::Utf32l | Self::Utf64l + ) + } + + pub fn container_value<'a>(&self) -> Result, Error> { + Ok(match self { + Self::Struct => TLVValue::Struct, + Self::Array => TLVValue::Array, + Self::List => TLVValue::List, + Self::EndCnt => TLVValue::EndCnt, + _ => Err(ErrorCode::TLVTypeMismatch)?, + }) + } +} + +impl fmt::Display for TLVValueType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::S8 => write!(f, "S8"), + Self::S16 => write!(f, "S16"), + Self::S32 => write!(f, "S32"), + Self::S64 => write!(f, "S64"), + Self::U8 => write!(f, "U8"), + Self::U16 => write!(f, "U16"), + Self::U32 => write!(f, "U32"), + Self::U64 => write!(f, "U64"), + Self::False => write!(f, "False"), + Self::True => write!(f, "True"), + Self::F32 => write!(f, "F32"), + Self::F64 => write!(f, "F64"), + Self::Utf8l => write!(f, "Utf8l"), + Self::Utf16l => write!(f, "Utf16l"), + Self::Utf32l => write!(f, "Utf32l"), + Self::Utf64l => write!(f, "Utf64l"), + Self::Str8l => write!(f, "Str8l"), + Self::Str16l => write!(f, "Str16l"), + Self::Str32l => write!(f, "Str32l"), + Self::Str64l => write!(f, "Str64l"), + Self::Null => write!(f, "Null"), + Self::Struct => write!(f, "Struct"), + Self::Array => write!(f, "Array"), + Self::List => write!(f, "List"), + Self::EndCnt => write!(f, "EndCnt"), + } + } +} + +/// Represents the control byte of a TLV element (i.e. the tag type and the value type). +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] +pub struct TLVControl { + pub tag_type: TLVTagType, + pub value_type: TLVValueType, +} + +impl TLVControl { + const TAG_SHIFT_BITS: u8 = 5; + const TAG_MASK: u8 = 0xe0; + const TYPE_MASK: u8 = 0x1f; + + /// Create a new TLV control byte by parsing the provided tag type and value type. + #[inline(always)] + pub const fn new(tag_type: TLVTagType, value_type: TLVValueType) -> Self { + Self { + tag_type, + value_type, + } + } + + /// Create a new TLV control byte by parsing the provided control byte + /// into a tag type and a value type. + /// + /// The function will return an error if the provided control byte is invalid. + #[inline(always)] + pub fn parse(control: u8) -> Result { + let tag_type = FromPrimitive::from_u8((control & Self::TAG_MASK) >> Self::TAG_SHIFT_BITS) + .ok_or(ErrorCode::TLVTypeMismatch)?; + let value_type = + FromPrimitive::from_u8(control & Self::TYPE_MASK).ok_or(ErrorCode::TLVTypeMismatch)?; + + Ok(Self::new(tag_type, value_type)) + } + + /// Return the raw control byte. + #[inline(always)] + pub const fn as_raw(&self) -> u8 { + ((self.tag_type as u8) << Self::TAG_SHIFT_BITS) | (self.value_type as u8) + } + + /// Return `true` if the control byte represents a container start (struct, array or list). + #[inline(always)] + pub fn is_container_start(&self) -> bool { + self.value_type.is_container_start() + } + + /// Return `true` if the control byte represents a container end. + #[inline(always)] + pub fn is_container_end(&self) -> bool { + matches!(self.tag_type, TLVTagType::Anonymous) && self.value_type.is_container_end() + } + + /// Return an error if the control byte does not represent a container start. + #[inline(always)] + pub fn confirm_container_end(&self) -> Result<(), Error> { + if !self.is_container_end() { + return Err(ErrorCode::InvalidData.into()); + } + + Ok(()) + } +} + +impl fmt::Display for TLVControl { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Control({} {})", self.tag_type, self.value_type) + } +} + +/// A high-level representation of a TLV tag + value. +/// +/// Amongsat other things, it is a convenient way to emit TLV byte sequences. +/// +/// A `TLV` can be constructed programmatically, or returned from a `TLVElement`. +/// +/// Unlike a `TLVElement` however, a `TLV` does not represent a complete container, +/// but rather, its beginning or end. +/// +/// I.e. +/// ```rust +/// use rs_matter::tlv::{TLV, TLVTag, TLVValue}; +/// +/// let tlvs = &[ +/// TLV::new(TLVTag::Anonymous, TLVValue::Struct), +/// TLV::new(TLVTag::Context(0), TLVValue::Utf8l("Hello, World!")), +/// TLV::new(TLVTag::Anonymous, TLVValue::EndCnt), +/// ]; +/// +/// let bytes_iter = tlvs.iter().flat_map(|tlv| tlv.bytes_iter()); +/// for byte in bytes_iter { +/// println!("{:02X}", byte); +/// } +/// ``` +#[derive(Debug, Clone, PartialEq)] +pub struct TLV<'a> { + pub tag: TLVTag, + pub value: TLVValue<'a>, +} + +impl<'a> TLV<'a> { + /// Create a new TLV instance with the provided tag and value. + pub const fn new(tag: TLVTag, value: TLVValue<'a>) -> Self { + Self { tag, value } + } + + /// Create a TLV with the given tag and the provided value as an S8 TLV. + pub const fn i8(tag: TLVTag, value: i8) -> Self { + Self::new(tag, TLVValue::i8(value)) + } + + /// Create a TLV with the given tag and the provided value as an S8 or S16 TLV, + /// depending on whether the value is small enough to fit in an S8 TLV. + pub const fn i16(tag: TLVTag, value: i16) -> Self { + Self::new(tag, TLVValue::i16(value)) + } + + /// Create a TLV with the given tag and the provided value as an S8, S16, or S32 TLV, + /// depending on whether the value is small enough to fit in an S8 or S16 TLV. + pub const fn i32(tag: TLVTag, value: i32) -> Self { + Self::new(tag, TLVValue::i32(value)) + } + + /// Create a TLV with the given tag and the provided value as an S8, S16, S32, or S64 TLV, + /// depending on whether the value is small enough to fit in an S8, S16, or S32 TLV. + pub const fn i64(tag: TLVTag, value: i64) -> Self { + Self::new(tag, TLVValue::i64(value)) + } + + /// Create a TLV with the given tag and the provided value as a U8 TLV. + pub const fn u8(tag: TLVTag, value: u8) -> Self { + Self::new(tag, TLVValue::u8(value)) + } + + /// Create a TLV with the given tag and the provided value as a U8 or U16 TLV, + /// depending on whether the value is small enough to fit in a U8 TLV. + pub const fn u16(tag: TLVTag, value: u16) -> Self { + Self::new(tag, TLVValue::u16(value)) + } + + /// Create a TLV with the given tag and the provided value as a U8, U16, or U32 TLV, + /// depending on whether the value is small enough to fit in a U8 or U16 TLV. + pub const fn u32(tag: TLVTag, value: u32) -> Self { + Self::new(tag, TLVValue::u32(value)) + } + + /// Create a TLV with the given tag and the provided value as a U8, U16, U32, or U64 TLV, + /// depending on whether the value is small enough to fit in a U8, U16, or U32 TLV. + pub const fn u64(tag: TLVTag, value: u64) -> Self { + Self::new(tag, TLVValue::u64(value)) + } + + /// Create a TLV with the given tag and the provided value as a F32 TLV. + pub const fn f32(tag: TLVTag, value: f32) -> Self { + Self::new(tag, TLVValue::f32(value)) + } + + /// Create a TLV with the given tag and the provided value as a F64 TLV. + pub const fn f64(tag: TLVTag, value: f64) -> Self { + Self::new(tag, TLVValue::f64(value)) + } + + /// Create a TLV with the given tag and the provided value as a UTF-8 TLV. + /// The length of the string is encoded as 1, 2, 4 or 8 octets, + /// depending on the length of the string. + pub const fn utf8(tag: TLVTag, value: &'a str) -> Self { + Self::new(tag, TLVValue::utf8(value)) + } + + /// Create a TLV with the given tag and the provided value as an octet string TLV. + /// The length of the string is encoded as 1, 2, 4 or 8 octets, + /// depending on the length of the string. + pub const fn str(tag: TLVTag, value: &'a [u8]) -> Self { + Self::new(tag, TLVValue::str(value)) + } + + /// Create a TLV with the given tag which will have a value of type Struct (start). + pub const fn r#struct(tag: TLVTag) -> Self { + Self::new(tag, TLVValue::r#struct()) + } + + /// Create a TLV with the given tag which will have a value of type Struct (start). + pub const fn structure(tag: TLVTag) -> Self { + Self::new(tag, TLVValue::structure()) + } + + /// Create a TLV with the given tag which will have a value of type Array (start). + pub const fn array(tag: TLVTag) -> Self { + Self::new(tag, TLVValue::array()) + } + + /// Create a TLV with the given tag which will have a value of type List (start). + pub const fn list(tag: TLVTag) -> Self { + Self::new(tag, TLVValue::list()) + } + + /// Create a TLV with the given tag which will have a value of type EndCnt (container end). + pub const fn end_container() -> Self { + Self::new(TLVTag::Anonymous, TLVValue::end_container()) + } + + /// Create a TLV with the given tag which will have a value of type Null. + pub const fn null(tag: TLVTag) -> Self { + Self::new(tag, TLVValue::null()) + } + + /// Create a TLV with the given tag which will have a value of type True. + pub const fn bool(tag: TLVTag, value: bool) -> Self { + Self::new(tag, TLVValue::bool(value)) + } + + /// Converts the TLV into an iterator with a single item - the TLV. + pub fn into_tlv_iter(self) -> OnceTLVIter<'a> { + core::iter::once(Ok(self)) + } + + /// Returns an iterator over the bytes of the TLV. + pub fn bytes_iter(&self) -> TLVBytesIter<'a, &TLVTag, &TLVValue<'a>> { + TLVBytesIter { + control: core::iter::once( + TLVControl::new(self.tag.tag_type(), self.value.value_type()).as_raw(), + ), + tag: self.tag.iter(), + value: self.value.iter(), + } + } + + /// Converts the TLV into an iterator over its bytes. + pub fn into_bytes_iter(self) -> TLVBytesIter<'a, TLVTag, TLVValue<'a>> { + TLVBytesIter { + control: core::iter::once( + TLVControl::new(self.tag.tag_type(), self.value.value_type()).as_raw(), + ), + tag: self.tag.into_iterator(), + value: self.value.into_iterator(), + } + } + + /// Converts the provided result into an iterator over the bytes of the TLV. + pub fn result_into_bytes_iter( + result: Result, + ) -> TLVResultBytesIter<'a, TLVTag, TLVValue<'a>> { + TLVResultBytesIter::new(result) + } +} + +impl<'a> IntoIterator for TLV<'a> { + type Item = u8; + type IntoIter = TLVBytesIter<'a, TLVTag, TLVValue<'a>>; + + fn into_iter(self) -> Self::IntoIter { + TLV::into_bytes_iter(self) + } +} + +impl<'s, 'a> IntoIterator for &'s TLV<'a> { + type Item = u8; + type IntoIter = TLVBytesIter<'a, &'s TLVTag, &'s TLVValue<'a>>; + + fn into_iter(self) -> Self::IntoIter { + TLV::bytes_iter(self) + } +} + +/// An iterator over the bytes of a TLV that might return an error. +pub enum TLVResultBytesIter<'a, T, V> +where + T: Borrow, + V: Borrow>, +{ + Ok(TLVBytesIter<'a, T, V>), + Err(core::iter::Once>), +} + +impl<'a> TLVResultBytesIter<'a, TLVTag, TLVValue<'a>> { + pub fn new(result: Result, Error>) -> Self { + match result { + Ok(tlv) => Self::Ok(tlv.into_bytes_iter()), + Err(err) => Self::Err(core::iter::once(Err(err))), + } + } +} + +impl<'a, T, V> Iterator for TLVResultBytesIter<'a, T, V> +where + T: Borrow, + V: Borrow>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + match self { + Self::Ok(iter) => iter.next().map(Ok), + Self::Err(iter) => iter.next(), + } + } +} + +/// An iterator over the bytes of a TLV. +pub struct TLVBytesIter<'a, T, V> +where + T: Borrow, + V: Borrow>, +{ + control: Once, + tag: TLVTagIter, + value: TLVValueIter<'a, V>, +} + +impl<'a, T, V> Iterator for TLVBytesIter<'a, T, V> +where + T: Borrow, + V: Borrow>, +{ + type Item = u8; + + fn next(&mut self) -> Option { + self.control + .next() + .or_else(|| self.tag.next()) + .or_else(|| self.value.next()) + } +} + +/// The iterator type for a TLV that returns the TLV itself. +pub type OnceTLVIter<'s> = core::iter::Once, Error>>; + +/// For backwards compatibility +pub type TagType = TLVTag; + +/// A high-level representation of a TLV tag (tag type and tag value). +/// +/// A `TLVTag` can be constructed programmatically, or returned from a `TLVElement`. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum TLVTag { + Anonymous, + Context(u8), + CommonPrf16(u16), + CommonPrf32(u32), + ImplPrf16(u16), + ImplPrf32(u32), + FullQual48 { + vendor_id: u16, + profile: u16, + tag: u16, + }, + FullQual64 { + vendor_id: u16, + profile: u16, + tag: u32, + }, +} + +impl TLVTag { + /// Return the tag type of the TLV tag. + pub const fn tag_type(&self) -> TLVTagType { + match self { + Self::Anonymous => TLVTagType::Anonymous, + Self::Context(_) => TLVTagType::Context, + Self::CommonPrf16(_) => TLVTagType::CommonPrf16, + Self::CommonPrf32(_) => TLVTagType::CommonPrf32, + Self::ImplPrf16(_) => TLVTagType::ImplPrf16, + Self::ImplPrf32(_) => TLVTagType::ImplPrf32, + Self::FullQual48 { .. } => TLVTagType::FullQual48, + Self::FullQual64 { .. } => TLVTagType::FullQual64, + } + } + + /// Return an iterator over the bytes of the TLV tag. + pub fn iter(&self) -> TLVTagIter<&Self> { + TLVTagIter { + value: self, + index: 0, + } + } + + /// Converts itself into an iterator over the bytes of the TLV tag. + pub fn into_iterator(self) -> TLVTagIter { + TLVTagIter { + value: self, + index: 0, + } + } + + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TLVTag::Anonymous => Ok(()), + TLVTag::Context(tag) => write!(f, "{tag}"), + TLVTag::CommonPrf16(tag) => write!(f, "CommonPrf16({tag})"), + TLVTag::CommonPrf32(tag) => write!(f, "CommonPrf32({tag})"), + TLVTag::ImplPrf16(tag) => write!(f, "ImplPrf16({tag})"), + TLVTag::ImplPrf32(tag) => write!(f, "ImplPrf32({tag})"), + TLVTag::FullQual48 { + vendor_id, + profile, + tag, + } => write!(f, "FullQual48(VID:{vendor_id} PRF:{profile} {tag})"), + TLVTag::FullQual64 { + vendor_id, + profile, + tag, + } => write!(f, "FullQual64(VID:{vendor_id} PRF:{profile} {tag})"), + } + } +} + +impl IntoIterator for TLVTag { + type Item = u8; + type IntoIter = TLVTagIter; + + fn into_iter(self) -> Self::IntoIter { + TLVTag::into_iterator(self) + } +} + +impl<'a> IntoIterator for &'a TLVTag { + type Item = u8; + type IntoIter = TLVTagIter; + + fn into_iter(self) -> Self::IntoIter { + TLVTag::iter(self) + } +} + +/// An iterator over the bytes of a TLV tag. +pub struct TLVTagIter +where + T: Borrow, +{ + value: T, + index: usize, +} + +impl TLVTagIter +where + T: Borrow, +{ + fn next_byte(&mut self, bytes: &[u8]) -> Option { + self.next_byte_offset(0, bytes) + } + + fn next_byte_offset(&mut self, offset: usize, bytes: &[u8]) -> Option { + if self.index - offset < bytes.len() { + let byte = bytes[self.index - offset]; + + self.index += 1; + + Some(byte) + } else { + None + } + } +} + +impl Iterator for TLVTagIter +where + T: Borrow, +{ + type Item = u8; + + fn next(&mut self) -> Option { + match self.value.borrow() { + TLVTag::Anonymous => None, + TLVTag::Context(tag) => self.next_byte(&tag.to_le_bytes()), + TLVTag::CommonPrf16(tag) => self.next_byte(&tag.to_le_bytes()), + TLVTag::CommonPrf32(tag) => self.next_byte(&tag.to_le_bytes()), + TLVTag::ImplPrf16(tag) => self.next_byte(&tag.to_le_bytes()), + TLVTag::ImplPrf32(tag) => self.next_byte(&tag.to_le_bytes()), + TLVTag::FullQual48 { + vendor_id, + profile, + tag, + } => { + if self.index < 2 { + self.next_byte(&vendor_id.to_le_bytes()) + } else if self.index < 4 { + self.next_byte_offset(2, &profile.to_le_bytes()) + } else { + self.next_byte_offset(4, &tag.to_le_bytes()) + } + } + TLVTag::FullQual64 { + vendor_id, + profile, + tag, + } => { + if self.index < 2 { + self.next_byte(&vendor_id.to_le_bytes()) + } else if self.index < 4 { + self.next_byte_offset(2, &profile.to_le_bytes()) + } else { + self.next_byte_offset(4, &tag.to_le_bytes()) + } + } + } + } +} + +impl fmt::Display for TLVTag { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TLVTag::Anonymous => write!(f, "Anonymous"), + TLVTag::Context(tag) => write!(f, "Context({})", tag), + _ => self.fmt(f), + } + } +} + +/// For backwards compatibility +pub type ElementType<'a> = TLVValue<'a>; + +/// A high-level representation of a TLV value. +/// +/// Combined with `TLVTag` into a `TLV` struct it is a convenient way +/// to emit TLV byte sequences. +/// +/// A `TLVValue` can be constructed programmatically, or returned from a `TLVElement`. +/// +/// Unlike a `TLVElement` however, a `TLVValue` does not represent a complete container, +/// but rather, its beginning or end. +#[derive(Debug, Clone, PartialEq)] +pub enum TLVValue<'a> { + S8(i8), + S16(i16), + S32(i32), + S64(i64), + U8(u8), + U16(u16), + U32(u32), + U64(u64), + False, + True, + F32(f32), + F64(f64), + Utf8l(&'a str), + Utf16l(&'a str), + Utf32l(&'a str), + Utf64l(&'a str), + Str8l(&'a [u8]), + Str16l(&'a [u8]), + Str32l(&'a [u8]), + Str64l(&'a [u8]), + Null, + Struct, + Array, + List, + EndCnt, +} + +impl<'a> TLVValue<'a> { + /// Return the value type of the TLV value. + pub const fn value_type(&self) -> TLVValueType { + match self { + Self::S8(_) => TLVValueType::S8, + Self::S16(_) => TLVValueType::S16, + Self::S32(_) => TLVValueType::S32, + Self::S64(_) => TLVValueType::S64, + Self::U8(_) => TLVValueType::U8, + Self::U16(_) => TLVValueType::U16, + Self::U32(_) => TLVValueType::U32, + Self::U64(_) => TLVValueType::U64, + Self::False => TLVValueType::False, + Self::True => TLVValueType::True, + Self::F32(_) => TLVValueType::F32, + Self::F64(_) => TLVValueType::F64, + Self::Utf8l(_) => TLVValueType::Utf8l, + Self::Utf16l(_) => TLVValueType::Utf16l, + Self::Utf32l(_) => TLVValueType::Utf32l, + Self::Utf64l(_) => TLVValueType::Utf64l, + Self::Str8l(_) => TLVValueType::Str8l, + Self::Str16l(_) => TLVValueType::Str16l, + Self::Str32l(_) => TLVValueType::Str32l, + Self::Str64l(_) => TLVValueType::Str64l, + Self::Null => TLVValueType::Null, + Self::Struct => TLVValueType::Struct, + Self::Array => TLVValueType::Array, + Self::List => TLVValueType::List, + Self::EndCnt => TLVValueType::EndCnt, + } + } + + /// Create a TLV value as an S8 TLV value. + pub const fn i8(value: i8) -> Self { + Self::S8(value) + } + + /// Create a TLV value as an S8 or S16 TLV value, + /// depending on whether the value is small enough to fit in an S8 TLV value. + pub const fn i16(value: i16) -> Self { + if value >= i8::MIN as _ && value <= i8::MAX as _ { + Self::i8(value as i8) + } else { + Self::S16(value) + } + } + + /// Create a TLV value as an S8, S16, or S32 TLV value, + /// depending on whether the value is small enough to fit in an S8 or S16 TLV value. + pub const fn i32(value: i32) -> Self { + if value >= i16::MIN as _ && value <= i16::MAX as _ { + Self::i16(value as i16) + } else { + Self::S32(value) + } + } + + /// Create a TLV value as an S8, S16, S32, or S64 TLV value, + /// depending on whether the value is small enough to fit in an S8, S16, or S32 TLV value. + pub const fn i64(value: i64) -> Self { + if value >= i32::MIN as _ && value <= i32::MAX as _ { + Self::i32(value as i32) + } else { + Self::S64(value) + } + } + + /// Create a TLV value as a U8 TLV value. + pub const fn u8(value: u8) -> Self { + Self::U8(value) + } + + /// Create a TLV value as a U8 or U16 TLV value, + /// depending on whether the value is small enough to fit in a U8 TLV value. + pub const fn u16(value: u16) -> Self { + if value <= u8::MAX as _ { + Self::u8(value as u8) + } else { + Self::U16(value) + } + } + + /// Create a TLV value as a U8, U16, or U32 TLV value, + /// depending on whether the value is small enough to fit in a U8 or U16 TLV value. + pub const fn u32(value: u32) -> Self { + if value <= u16::MAX as _ { + Self::u16(value as u16) + } else { + Self::U32(value) + } + } + + /// Create a TLV value as a U8, U16, U32, or U64 TLV value, + /// depending on whether the value is small enough to fit in a U8, U16, or U32 TLV value. + pub const fn u64(value: u64) -> Self { + if value <= u32::MAX as _ { + Self::u32(value as u32) + } else { + Self::U64(value) + } + } + + /// Create a TLV value as an F32 TLV value. + pub const fn f32(value: f32) -> Self { + Self::F32(value) + } + + /// Create a TLV value as an F64 TLV value. + pub const fn f64(value: f64) -> Self { + Self::F64(value) + } + + /// Create a TLV value as a UTF-8 TLV value. + /// The length of the string is encoded as 1, 2, 4 or 8 octets, + /// depending on the length of the string. + pub const fn utf8(value: &'a str) -> Self { + let len = value.len(); + + if len <= u8::MAX as _ { + Self::Utf8l(value) + } else if len <= u16::MAX as _ { + Self::Utf16l(value) + } else if len <= u32::MAX as _ { + Self::Utf32l(value) + } else { + Self::Utf64l(value) + } + } + + /// Create a TLV value as an octet string TLV value. + /// The length of the string is encoded as 1, 2, 4 or 8 octets, + /// depending on the length of the string. + pub const fn str(value: &'a [u8]) -> Self { + let len = value.len(); + + if len <= u8::MAX as _ { + Self::Str8l(value) + } else if len <= u16::MAX as _ { + Self::Str16l(value) + } else if len <= u32::MAX as _ { + Self::Str32l(value) + } else { + Self::Str64l(value) + } + } + + /// Create a TLV value of type Struct (start). + pub const fn r#struct() -> Self { + Self::Struct + } + + /// Create a TLV value of type Struct (start). + pub const fn structure() -> Self { + Self::Struct + } + + /// Create a TLV value of type Array (start). + pub const fn array() -> Self { + Self::Array + } + + /// Create a TLV value of type List (start). + pub const fn list() -> Self { + Self::List + } + + /// Create a TLV value of type EndCnt (container end). + pub const fn end_container() -> Self { + Self::EndCnt + } + + /// Create a TLV value of type Null. + pub const fn null() -> Self { + Self::Null + } + + /// Create a TLV value of type boolean (True or False). + pub const fn bool(value: bool) -> Self { + if value { + Self::True + } else { + Self::False + } + } + + /// Return an iterator over the bytes of the TLV value. + pub fn iter(&self) -> TLVValueIter<'a, &Self> { + TLVValueIter { + value: self, + _p: PhantomData, + index: 0, + } + } + + /// Converts itself into an iterator over the bytes of the TLV value. + pub fn into_iterator(self) -> TLVValueIter<'a, Self> { + TLVValueIter { + value: self, + _p: PhantomData, + index: 0, + } + } + + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::S8(a) => write!(f, "S8({a})"), + Self::S16(a) => write!(f, "S16({a})"), + Self::S32(a) => write!(f, "S32({a})"), + Self::S64(a) => write!(f, "S64({a})"), + Self::U8(a) => write!(f, "U8({a})"), + Self::U16(a) => write!(f, "U16({a})"), + Self::U32(a) => write!(f, "U32({a})"), + Self::U64(a) => write!(f, "U64({a})"), + Self::F32(a) => write!(f, "F32({a})"), + Self::F64(a) => write!(f, "F64({a})"), + Self::Null => write!(f, "Null"), + Self::Struct => write!(f, "{{"), + Self::Array => write!(f, "["), + Self::List => write!(f, "("), + Self::True => write!(f, "True"), + Self::False => write!(f, "False"), + Self::Utf8l(a) | Self::Utf16l(a) | Self::Utf32l(a) | Self::Utf64l(a) => { + write!(f, "\"{a}\"") + } + Self::Str8l(a) | Self::Str16l(a) | Self::Str32l(a) | Self::Str64l(a) => { + write!(f, "({}){a:02X?}", a.len()) + } + Self::EndCnt => write!(f, ">"), + } + } +} + +impl<'a> IntoIterator for TLVValue<'a> { + type Item = u8; + type IntoIter = TLVValueIter<'a, Self>; + + fn into_iter(self) -> Self::IntoIter { + TLVValue::into_iterator(self) + } +} + +impl<'s, 'a> IntoIterator for &'s TLVValue<'a> { + type Item = u8; + type IntoIter = TLVValueIter<'a, Self>; + + fn into_iter(self) -> Self::IntoIter { + TLVValue::iter(self) + } +} + +/// An iterator over the bytes of a TLV value. +pub struct TLVValueIter<'a, T> +where + T: Borrow>, +{ + value: T, + _p: PhantomData<&'a ()>, + index: usize, +} + +impl<'a, T> TLVValueIter<'a, T> +where + T: Borrow>, +{ + fn variable_len_len(&self) -> usize { + match self.value.borrow() { + TLVValue::Utf8l(_) | TLVValue::Str8l(_) => 1, + TLVValue::Utf16l(_) | TLVValue::Str16l(_) => 2, + TLVValue::Utf32l(_) | TLVValue::Str32l(_) => 4, + TLVValue::Utf64l(_) | TLVValue::Str64l(_) => 8, + _ => 0, + } + } + + fn next_byte(&mut self, bytes: &[u8]) -> Option { + self.next_byte_offset(0, bytes) + } + + fn next_byte_offset(&mut self, offset: usize, bytes: &[u8]) -> Option { + if self.index - offset < bytes.len() { + let byte = bytes[self.index - offset]; + + self.index += 1; + + Some(byte) + } else { + None + } + } +} + +impl<'a, T> Iterator for TLVValueIter<'a, T> +where + T: Borrow>, +{ + type Item = u8; + + fn next(&mut self) -> Option { + match self.value.borrow() { + TLVValue::S8(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::S16(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::S32(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::S64(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::U8(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::U16(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::U32(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::U64(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::F32(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::F64(a) => self.next_byte(&a.to_le_bytes()), + TLVValue::Utf8l(a) | TLVValue::Utf16l(a) | TLVValue::Utf32l(a) => { + let len_len = self.variable_len_len(); + if self.index < len_len { + self.next_byte(&a.len().to_le_bytes()) + } else { + self.next_byte_offset(len_len, a.as_bytes()) + } + } + TLVValue::Str8l(a) | TLVValue::Str16l(a) | TLVValue::Str32l(a) => { + let len_len = self.variable_len_len(); + if self.index < len_len { + self.next_byte(&a.len().to_le_bytes()) + } else { + self.next_byte_offset(len_len, a) + } + } + _ => None, + } + } +} + +impl<'a> fmt::Display for TLVValue<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.fmt(f) + } +} + +pub(crate) fn pad(ident: usize, f: &mut fmt::Formatter<'_>) -> fmt::Result { + for _ in 0..ident { + write!(f, " ")?; + } + + Ok(()) +} + +/// For backwards compatibility +pub fn get_root_node_struct(data: &[u8]) -> Result, Error> { + // TODO: Check for trailing data + let element = TLVElement::new(data); + + element.structure()?; + + Ok(element) +} diff --git a/rs-matter/src/tlv/mod.rs b/rs-matter/src/tlv/mod.rs deleted file mode 100644 index 09b32ba8..00000000 --- a/rs-matter/src/tlv/mod.rs +++ /dev/null @@ -1,53 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Tag Types */ -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum TagType { - Anonymous, - Context(u8), - CommonPrf16(u16), - CommonPrf32(u32), - ImplPrf16(u16), - ImplPrf32(u32), - FullQual48(u64), - FullQual64(u64), -} -pub const TAG_SHIFT_BITS: u8 = 5; -pub const TAG_MASK: u8 = 0xe0; -pub const TYPE_MASK: u8 = 0x1f; -pub const MAX_TAG_INDEX: usize = 8; - -pub static TAG_SIZE_MAP: [usize; MAX_TAG_INDEX] = [ - 0, // Anonymous - 1, // Context - 2, // CommonPrf16 - 4, // CommonPrf32 - 2, // ImplPrf16 - 4, // ImplPrf32 - 6, // FullQual48 - 8, // FullQual64 -]; - -mod parser; -mod traits; -mod writer; - -pub use parser::*; -pub use rs_matter_macros::{FromTLV, ToTLV}; -pub use traits::*; -pub use writer::*; diff --git a/rs-matter/src/tlv/parser.rs b/rs-matter/src/tlv/parser.rs deleted file mode 100644 index 826653d5..00000000 --- a/rs-matter/src/tlv/parser.rs +++ /dev/null @@ -1,1215 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use crate::error::{Error, ErrorCode}; - -use byteorder::{ByteOrder, LittleEndian}; -use core::fmt::{self, Display}; -use log::{error, info}; - -use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK}; - -pub struct TLVList<'a> { - buf: &'a [u8], -} - -impl<'a> TLVList<'a> { - pub fn new(buf: &'a [u8]) -> TLVList<'a> { - TLVList { buf } - } -} - -impl<'a> Display for TLVList<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let tlvlist = self; - - const MAX_DEPTH: usize = 9; - const SPACE_BUF: &str = " "; - - let space: [&str; MAX_DEPTH] = [ - &SPACE_BUF[0..0], - &SPACE_BUF[0..4], - &SPACE_BUF[0..8], - &SPACE_BUF[0..12], - &SPACE_BUF[0..16], - &SPACE_BUF[0..20], - &SPACE_BUF[0..24], - &SPACE_BUF[0..28], - &SPACE_BUF[0..32], - ]; - - let mut stack: [char; MAX_DEPTH] = [' '; MAX_DEPTH]; - let mut index = 0_usize; - let iter = tlvlist.iter(); - for a in iter { - match a.element_type { - ElementType::Struct(_) => { - if index < MAX_DEPTH { - writeln!(f, "{}{}", space[index], a)?; - stack[index] = '}'; - index += 1; - } else { - writeln!(f, "<>")?; - } - } - ElementType::Array(_) | ElementType::List(_) => { - if index < MAX_DEPTH { - writeln!(f, "{}{}", space[index], a)?; - stack[index] = ']'; - index += 1; - } else { - writeln!(f, "<>")?; - } - } - ElementType::EndCnt => { - if index > 0 { - index -= 1; - writeln!(f, "{}{}", space[index], stack[index])?; - } else { - writeln!(f, "<>")?; - } - } - _ => writeln!(f, "{}{}", space[index], a)?, - } - } - - Ok(()) - } -} - -#[derive(Debug, Clone, PartialEq)] -pub enum ElementType<'a> { - S8(i8), - S16(i16), - S32(i32), - S64(i64), - U8(u8), - U16(u16), - U32(u32), - U64(u64), - False, - True, - F32(f32), - F64(f64), - Utf8l(&'a [u8]), - Utf16l(&'a [u8]), - Utf32l, - Utf64l, - Str8l(&'a [u8]), - Str16l(&'a [u8]), - Str32l, - Str64l, - Null, - Struct(&'a [u8]), - Array(&'a [u8]), - List(&'a [u8]), - EndCnt, - Last, -} - -const MAX_VALUE_INDEX: usize = 25; - -// This is a function that takes a TLVListIterator and returns the tag type -type ExtractTag = for<'a> fn(&TLVListIterator<'a>) -> TagType; -static TAG_EXTRACTOR: [ExtractTag; 8] = [ - // Anonymous 0 - |_t| TagType::Anonymous, - // Context 1 - |t| TagType::Context(t.buf[t.current]), - // CommonPrf16 2 - |t| TagType::CommonPrf16(LittleEndian::read_u16(&t.buf[t.current..])), - // CommonPrf32 3 - |t| TagType::CommonPrf32(LittleEndian::read_u32(&t.buf[t.current..])), - // ImplPrf16 4 - |t| TagType::ImplPrf16(LittleEndian::read_u16(&t.buf[t.current..])), - // ImplPrf32 5 - |t| TagType::ImplPrf32(LittleEndian::read_u32(&t.buf[t.current..])), - // FullQual48 6 - |t| TagType::FullQual48(LittleEndian::read_u48(&t.buf[t.current..])), - // FullQual64 7 - |t| TagType::FullQual64(LittleEndian::read_u64(&t.buf[t.current..])), -]; - -// This is a function that takes a TLVListIterator and returns the element type -// Some elements (like strings), also consume additional size, than that mentioned -// if this is the case, the additional size is returned -type ExtractValue = for<'a> fn(&TLVListIterator<'a>) -> (usize, ElementType<'a>); - -static VALUE_EXTRACTOR: [ExtractValue; MAX_VALUE_INDEX] = [ - // S8 0 - { |t| (0, ElementType::S8(t.buf[t.current] as i8)) }, - // S16 1 - { - |t| { - ( - 0, - ElementType::S16(LittleEndian::read_i16(&t.buf[t.current..])), - ) - } - }, - // S32 2 - { - |t| { - ( - 0, - ElementType::S32(LittleEndian::read_i32(&t.buf[t.current..])), - ) - } - }, - // S64 3 - { - |t| { - ( - 0, - ElementType::S64(LittleEndian::read_i64(&t.buf[t.current..])), - ) - } - }, - // U8 4 - { |t| (0, ElementType::U8(t.buf[t.current])) }, - // U16 5 - { - |t| { - ( - 0, - ElementType::U16(LittleEndian::read_u16(&t.buf[t.current..])), - ) - } - }, - // U32 6 - { - |t| { - ( - 0, - ElementType::U32(LittleEndian::read_u32(&t.buf[t.current..])), - ) - } - }, - // U64 7 - { - |t| { - ( - 0, - ElementType::U64(LittleEndian::read_u64(&t.buf[t.current..])), - ) - } - }, - // False 8 - { |_t| (0, ElementType::False) }, - // True 9 - { |_t| (0, ElementType::True) }, - // F32 10 - { |_t| (0, ElementType::Last) }, - // F64 11 - { |_t| (0, ElementType::Last) }, - // Utf8l 12 - { - |t| match read_length_value(1, t) { - Err(_) => (0, ElementType::Last), - Ok((size, string)) => (size, ElementType::Utf8l(string)), - } - }, - // Utf16l 13 - { - |t| match read_length_value(2, t) { - Err(_) => (0, ElementType::Last), - Ok((size, string)) => (size, ElementType::Utf16l(string)), - } - }, - // Utf32l 14 - { |_t| (0, ElementType::Last) }, - // Utf64l 15 - { |_t| (0, ElementType::Last) }, - // Str8l 16 - { - |t| match read_length_value(1, t) { - Err(_) => (0, ElementType::Last), - Ok((size, string)) => (size, ElementType::Str8l(string)), - } - }, - // Str16l 17 - { - |t| match read_length_value(2, t) { - Err(_) => (0, ElementType::Last), - Ok((size, string)) => (size, ElementType::Str16l(string)), - } - }, - // Str32l 18 - { |_t| (0, ElementType::Last) }, - // Str64l 19 - { |_t| (0, ElementType::Last) }, - // Null 20 - { |_t| (0, ElementType::Null) }, - // Struct 21 - { |t| (0, ElementType::Struct(&t.buf[t.current..])) }, - // Array 22 - { |t| (0, ElementType::Array(&t.buf[t.current..])) }, - // List 23 - { |t| (0, ElementType::List(&t.buf[t.current..])) }, - // EndCnt 24 - { |_t| (0, ElementType::EndCnt) }, -]; - -// The array indices here correspond to the numeric value of the Element Type as defined in the Matter Spec -static VALUE_SIZE_MAP: [usize; MAX_VALUE_INDEX] = [ - 1, // S8 0 - 2, // S16 1 - 4, // S32 2 - 8, // S64 3 - 1, // U8 4 - 2, // U16 5 - 4, // U32 6 - 8, // U64 7 - 0, // False 8 - 0, // True 9 - 4, // F32 10 - 8, // F64 11 - 1, // Utf8l 12 - 2, // Utf16l 13 - 4, // Utf32l 14 - 8, // Utf64l 15 - 1, // Str8l 16 - 2, // Str16l 17 - 4, // Str32l 18 - 8, // Str64l 19 - 0, // Null 20 - 0, // Struct 21 - 0, // Array 22 - 0, // List 23 - 0, // EndCnt 24 -]; - -fn read_length_value<'a>( - size_of_length_field: usize, - t: &TLVListIterator<'a>, -) -> Result<(usize, &'a [u8]), Error> { - // The current offset is the string size - let length: usize = LittleEndian::read_uint(&t.buf[t.current..], size_of_length_field) as usize; - // We'll consume the current offset (len) + the entire string - if length + size_of_length_field > t.buf.len() - t.current { - // Return Error - Err(ErrorCode::NoSpace.into()) - } else { - Ok(( - // return the additional size only - length, - &t.buf[(t.current + size_of_length_field)..(t.current + size_of_length_field + length)], - )) - } -} - -#[derive(Debug, Clone)] -pub struct TLVElement<'a> { - tag_type: TagType, - element_type: ElementType<'a>, -} - -impl<'a> PartialEq for TLVElement<'a> { - fn eq(&self, other: &Self) -> bool { - match self.element_type { - ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => { - let mut our_iter = TLVListIterator::from_buf(buf); - let mut their = match other.element_type { - ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => { - TLVListIterator::from_buf(buf) - } - _ => { - // If we are a container, the other must be a container, else this is a mismatch - return false; - } - }; - let mut nest_level = 0_u8; - loop { - let ours = our_iter.next(); - let theirs = their.next(); - if core::mem::discriminant(&ours) != core::mem::discriminant(&theirs) { - // One of us reached end of list, but the other didn't, that's a mismatch - return false; - } - if ours.is_none() { - // End of list - break; - } - // guaranteed to work - let ours = ours.unwrap(); - let theirs = theirs.unwrap(); - - if let ElementType::EndCnt = ours.element_type { - if nest_level == 0 { - break; - } - nest_level -= 1; - } else { - if is_container(&ours.element_type) { - nest_level += 1; - // Only compare the discriminants in case of array/list/structures, - // instead of actual element values. Those will be subsets within this same - // list that will get validated anyway - if core::mem::discriminant(&ours.element_type) - != core::mem::discriminant(&theirs.element_type) - { - return false; - } - } else if ours.element_type != theirs.element_type { - return false; - } - - if ours.tag_type != theirs.tag_type { - return false; - } - } - } - true - } - _ => self.tag_type == other.tag_type && self.element_type == other.element_type, - } - } -} - -impl<'a> TLVElement<'a> { - pub fn enter(&self) -> Option> { - let buf = match self.element_type { - ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => buf, - _ => return None, - }; - let list_iter = TLVListIterator { buf, current: 0 }; - Some(TLVContainerIterator { - list_iter, - prev_container: false, - iterator_consumed: false, - }) - } - - pub fn new(tag: TagType, value: ElementType<'a>) -> Self { - Self { - tag_type: tag, - element_type: value, - } - } - - pub fn i8(&self) -> Result { - match self.element_type { - ElementType::S8(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn u8(&self) -> Result { - match self.element_type { - ElementType::U8(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn i16(&self) -> Result { - match self.element_type { - ElementType::S8(a) => Ok(a.into()), - ElementType::S16(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn u16(&self) -> Result { - match self.element_type { - ElementType::U8(a) => Ok(a.into()), - ElementType::U16(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn i32(&self) -> Result { - match self.element_type { - ElementType::S8(a) => Ok(a.into()), - ElementType::S16(a) => Ok(a.into()), - ElementType::S32(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn u32(&self) -> Result { - match self.element_type { - ElementType::U8(a) => Ok(a.into()), - ElementType::U16(a) => Ok(a.into()), - ElementType::U32(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn i64(&self) -> Result { - match self.element_type { - ElementType::S8(a) => Ok(a.into()), - ElementType::S16(a) => Ok(a.into()), - ElementType::S32(a) => Ok(a.into()), - ElementType::S64(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn u64(&self) -> Result { - match self.element_type { - ElementType::U8(a) => Ok(a.into()), - ElementType::U16(a) => Ok(a.into()), - ElementType::U32(a) => Ok(a.into()), - ElementType::U64(a) => Ok(a), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn slice(&self) -> Result<&'a [u8], Error> { - match self.element_type { - ElementType::Str8l(s) - | ElementType::Utf8l(s) - | ElementType::Str16l(s) - | ElementType::Utf16l(s) => Ok(s), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn str(&self) -> Result<&'a str, Error> { - match self.element_type { - ElementType::Str8l(s) - | ElementType::Utf8l(s) - | ElementType::Str16l(s) - | ElementType::Utf16l(s) => { - Ok(core::str::from_utf8(s).map_err(|_| Error::from(ErrorCode::InvalidData))?) - } - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn bool(&self) -> Result { - match self.element_type { - ElementType::False => Ok(false), - ElementType::True => Ok(true), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn null(&self) -> Result<(), Error> { - match self.element_type { - ElementType::Null => Ok(()), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn confirm_struct(&self) -> Result<&TLVElement<'a>, Error> { - match self.element_type { - ElementType::Struct(_) => Ok(self), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn confirm_array(&self) -> Result<&TLVElement<'a>, Error> { - match self.element_type { - ElementType::Array(_) => Ok(self), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn confirm_list(&self) -> Result<&TLVElement<'a>, Error> { - match self.element_type { - ElementType::List(_) => Ok(self), - _ => Err(ErrorCode::TLVTypeMismatch.into()), - } - } - - pub fn find_tag(&self, tag: u32) -> Result, Error> { - let match_tag: TagType = TagType::Context(tag as u8); - - let iter = self.enter().ok_or(ErrorCode::TLVTypeMismatch)?; - for a in iter { - if match_tag == a.tag_type { - return Ok(a); - } - } - Err(ErrorCode::NoTagFound.into()) - } - - pub fn get_tag(&self) -> TagType { - self.tag_type - } - - pub fn check_ctx_tag(&self, tag: u8) -> bool { - if let TagType::Context(our_tag) = self.tag_type { - if our_tag == tag { - return true; - } - } - false - } - - pub fn get_element_type(&self) -> &ElementType { - &self.element_type - } -} - -impl<'a> fmt::Display for TLVElement<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.tag_type { - TagType::Anonymous => (), - TagType::Context(tag) => write!(f, "{}: ", tag)?, - _ => write!(f, "Other Context Tag")?, - } - match self.element_type { - ElementType::Struct(_) => write!(f, "{{"), - ElementType::Array(_) => write!(f, "["), - ElementType::List(_) => write!(f, "["), - ElementType::EndCnt => write!(f, ">"), - ElementType::True => write!(f, "True"), - ElementType::False => write!(f, "False"), - ElementType::Str8l(a) - | ElementType::Utf8l(a) - | ElementType::Str16l(a) - | ElementType::Utf16l(a) => { - if let Ok(s) = core::str::from_utf8(a) { - write!(f, "len[{}]\"{}\"", s.len(), s) - } else { - write!(f, "len[{}]{:x?}", a.len(), a) - } - } - _ => write!(f, "{:?}", self.element_type), - } - } -} - -// This is a TLV List iterator, it only iterates over the individual TLVs in a TLV list -#[derive(Clone, Debug, PartialEq)] -pub struct TLVListIterator<'a> { - buf: &'a [u8], - current: usize, -} - -impl<'a> TLVListIterator<'a> { - fn from_buf(buf: &'a [u8]) -> Self { - Self { buf, current: 0 } - } - - fn advance(&mut self, len: usize) { - self.current += len; - } - - // Caller should ensure they are reading the _right_ tag at the _right_ place - fn read_this_tag(&mut self, tag_type: u8) -> Option { - if tag_type as usize >= MAX_TAG_INDEX { - return None; - } - let tag_size = TAG_SIZE_MAP[tag_type as usize]; - if tag_size > self.buf.len() - self.current { - return None; - } - let tag = (TAG_EXTRACTOR[tag_type as usize])(self); - self.advance(tag_size); - Some(tag) - } - - fn read_this_value(&mut self, element_type: u8) -> Option> { - if element_type as usize >= MAX_VALUE_INDEX { - return None; - } - let mut size = VALUE_SIZE_MAP[element_type as usize]; - if size > self.buf.len() - self.current { - error!( - "Invalid value found: {} self {:?} size {}", - element_type, self, size - ); - return None; - } - - let (extra_size, element) = (VALUE_EXTRACTOR[element_type as usize])(self); - if element != ElementType::Last { - size += extra_size; - self.advance(size); - Some(element) - } else { - None - } - } -} - -impl<'a> Iterator for TLVListIterator<'a> { - type Item = TLVElement<'a>; - /* Code for going to the next Element */ - fn next(&mut self) -> Option> { - if self.buf.len() - self.current < 1 { - return None; - } - /* Read Control */ - let control = self.buf[self.current]; - let tag_type = (control & TAG_MASK) >> TAG_SHIFT_BITS; - let element_type = control & TYPE_MASK; - self.advance(1); - - /* Consume Tag */ - let tag_type = self.read_this_tag(tag_type)?; - - /* Consume Value */ - let element_type = self.read_this_value(element_type)?; - - Some(TLVElement { - tag_type, - element_type, - }) - } -} - -impl<'a> TLVList<'a> { - pub fn iter(&self) -> TLVListIterator<'a> { - TLVListIterator { - current: 0, - buf: self.buf, - } - } -} - -fn is_container(element_type: &ElementType) -> bool { - matches!( - element_type, - ElementType::Struct(_) | ElementType::Array(_) | ElementType::List(_) - ) -} - -// This is a Container iterator, it iterates over containers in a TLV list -#[derive(Debug, PartialEq)] -pub struct TLVContainerIterator<'a> { - list_iter: TLVListIterator<'a>, - prev_container: bool, - iterator_consumed: bool, -} - -impl<'a> TLVContainerIterator<'a> { - fn skip_to_end_of_container(&mut self) -> Option> { - let mut nest_level = 0; - while let Some(element) = self.list_iter.next() { - // We know we are already in a container, we have to keep looking for end-of-container - // println!("Skip: element: {:x?} nest_level: {}", element, nest_level); - match element.element_type { - ElementType::EndCnt => { - if nest_level == 0 { - // Return the element following this element - // println!("Returning"); - // The final next() may be the end of the top-level container itself, if so, we must return None - let last_elem = self.list_iter.next()?; - match last_elem.element_type { - ElementType::EndCnt => { - self.iterator_consumed = true; - return None; - } - _ => return Some(last_elem), - } - } - nest_level -= 1; - } - _ => { - if is_container(&element.element_type) { - nest_level += 1; - } - } - } - } - None - } -} - -impl<'a> Iterator for TLVContainerIterator<'a> { - type Item = TLVElement<'a>; - /* Code for going to the next Element */ - fn next(&mut self) -> Option> { - // This iterator may be consumed, but the underlying might not. This protects it from such occurrences - if self.iterator_consumed { - return None; - } - let element: TLVElement = if self.prev_container { - // println!("Calling skip to end of container"); - self.skip_to_end_of_container()? - } else { - self.list_iter.next()? - }; - // println!("Found element: {:x?}", element); - /* If we found end of container, that means our own container is over */ - if element.element_type == ElementType::EndCnt { - self.iterator_consumed = true; - return None; - } - - self.prev_container = is_container(&element.element_type); - Some(element) - } -} - -pub fn get_root_node(b: &[u8]) -> Result { - Ok(TLVList::new(b) - .iter() - .next() - .ok_or(ErrorCode::InvalidData)?) -} - -pub fn get_root_node_struct(b: &[u8]) -> Result { - let root = TLVList::new(b) - .iter() - .next() - .ok_or(ErrorCode::InvalidData)?; - - root.confirm_struct()?; - - Ok(root) -} - -pub fn get_root_node_list(b: &[u8]) -> Result { - let root = TLVList::new(b) - .iter() - .next() - .ok_or(ErrorCode::InvalidData)?; - - root.confirm_list()?; - - Ok(root) -} - -pub fn print_tlv_list(b: &[u8]) { - info!("TLV list:\n{}\n---------", TLVList::new(b)); -} - -#[cfg(test)] -mod tests { - use log::info; - - use super::{ - get_root_node_list, get_root_node_struct, ElementType, TLVElement, TLVList, TagType, - }; - use crate::error::ErrorCode; - - #[test] - fn test_short_length_tag() { - // The 0x36 is an array with a tag, but we leave out the tag field - let b = [0x15, 0x36]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!(tlv_iter.next(), None); - } - - #[test] - fn test_invalid_value_type() { - // The 0x24 is a a tagged integer, here we leave out the integer value - let b = [0x15, 0x1f, 0x0]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!(tlv_iter.next(), None); - } - - #[test] - fn test_short_length_value_immediate() { - // The 0x24 is a a tagged integer, here we leave out the integer value - let b = [0x15, 0x24, 0x0]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!(tlv_iter.next(), None); - } - - #[test] - fn test_short_length_value_string() { - // This is a tagged string, with tag 0 and length 0xb, but we only have 4 bytes in the string - let b = [0x15, 0x30, 0x00, 0x0b, 0x73, 0x6d, 0x61, 0x72]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!(tlv_iter.next(), None); - } - - #[test] - fn test_valid_tag() { - // The 0x36 is an array with a tag, here tag is 0 - let b = [0x15, 0x36, 0x0]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!( - tlv_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(0), - element_type: ElementType::Array(&[]), - }) - ); - } - - #[test] - fn test_valid_value_immediate() { - // The 0x24 is a a tagged integer, here the integer is 2 - let b = [0x15, 0x24, 0x1, 0x2]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!( - tlv_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(1), - element_type: ElementType::U8(2), - }) - ); - } - - #[test] - fn test_valid_value_string() { - // This is a tagged string, with tag 0 and length 4, and we have 4 bytes in the string - let b = [0x15, 0x30, 0x5, 0x04, 0x73, 0x6d, 0x61, 0x72]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!( - tlv_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(5), - element_type: ElementType::Str8l(&[0x73, 0x6d, 0x61, 0x72]), - }) - ); - } - - #[test] - fn test_valid_value_string16() { - // This is a tagged string, with tag 0 and length 4, and we have 4 bytes in the string - let b = [ - 0x15, 0x31, 0x1, 0xd8, 0x1, 0x30, 0x82, 0x1, 0xd4, 0x30, 0x82, 0x1, 0x7a, 0xa0, 0x3, - 0x2, 0x1, 0x2, 0x2, 0x8, 0x3e, 0x6c, 0xe6, 0x50, 0x9a, 0xd8, 0x40, 0xcd, 0x30, 0xa, - 0x6, 0x8, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x4, 0x3, 0x2, 0x30, 0x30, 0x31, 0x18, 0x30, - 0x16, 0x6, 0x3, 0x55, 0x4, 0x3, 0xc, 0xf, 0x4d, 0x61, 0x74, 0x74, 0x65, 0x72, 0x20, - 0x54, 0x65, 0x73, 0x74, 0x20, 0x50, 0x41, 0x41, 0x31, 0x14, 0x30, 0x12, 0x6, 0xa, 0x2b, - 0x6, 0x1, 0x4, 0x1, 0x82, 0xa2, 0x7c, 0x2, 0x1, 0xc, 0x4, 0x46, 0x46, 0x46, 0x31, 0x30, - 0x20, 0x17, 0xd, 0x32, 0x31, 0x30, 0x36, 0x32, 0x38, 0x31, 0x34, 0x32, 0x33, 0x34, - 0x33, 0x5a, 0x18, 0xf, 0x39, 0x39, 0x39, 0x39, 0x31, 0x32, 0x33, 0x31, 0x32, 0x33, - 0x35, 0x39, 0x35, 0x39, 0x5a, 0x30, 0x46, 0x31, 0x18, 0x30, 0x16, 0x6, 0x3, 0x55, 0x4, - 0x3, 0xc, 0xf, 0x4d, 0x61, 0x74, 0x74, 0x65, 0x72, 0x20, 0x54, 0x65, 0x73, 0x74, 0x20, - 0x50, 0x41, 0x49, 0x31, 0x14, 0x30, 0x12, 0x6, 0xa, 0x2b, 0x6, 0x1, 0x4, 0x1, 0x82, - 0xa2, 0x7c, 0x2, 0x1, 0xc, 0x4, 0x46, 0x46, 0x46, 0x31, 0x31, 0x14, 0x30, 0x12, 0x6, - 0xa, 0x2b, 0x6, 0x1, 0x4, 0x1, 0x82, 0xa2, 0x7c, 0x2, 0x2, 0xc, 0x4, 0x38, 0x30, 0x30, - 0x30, 0x30, 0x59, 0x30, 0x13, 0x6, 0x7, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x2, 0x1, 0x6, - 0x8, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x3, 0x1, 0x7, 0x3, 0x42, 0x0, 0x4, 0x80, 0xdd, - 0xf1, 0x1b, 0x22, 0x8f, 0x3e, 0x31, 0xf6, 0x3b, 0xcf, 0x57, 0x98, 0xda, 0x14, 0x62, - 0x3a, 0xeb, 0xbd, 0xe8, 0x2e, 0xf3, 0x78, 0xee, 0xad, 0xbf, 0xb1, 0x8f, 0xe1, 0xab, - 0xce, 0x31, 0xd0, 0x8e, 0xd4, 0xb2, 0x6, 0x4, 0xb6, 0xcc, 0xc6, 0xd9, 0xb5, 0xfa, 0xb6, - 0x4e, 0x7d, 0xe1, 0xc, 0xb7, 0x4b, 0xe0, 0x17, 0xc9, 0xec, 0x15, 0x16, 0x5, 0x6d, 0x70, - 0xf2, 0xcd, 0xb, 0x22, 0xa3, 0x66, 0x30, 0x64, 0x30, 0x12, 0x6, 0x3, 0x55, 0x1d, 0x13, - 0x1, 0x1, 0xff, 0x4, 0x8, 0x30, 0x6, 0x1, 0x1, 0xff, 0x2, 0x1, 0x0, 0x30, 0xe, 0x6, - 0x3, 0x55, 0x1d, 0xf, 0x1, 0x1, 0xff, 0x4, 0x4, 0x3, 0x2, 0x1, 0x6, 0x30, 0x1d, 0x6, - 0x3, 0x55, 0x1d, 0xe, 0x4, 0x16, 0x4, 0x14, 0xaf, 0x42, 0xb7, 0x9, 0x4d, 0xeb, 0xd5, - 0x15, 0xec, 0x6e, 0xcf, 0x33, 0xb8, 0x11, 0x15, 0x22, 0x5f, 0x32, 0x52, 0x88, 0x30, - 0x1f, 0x6, 0x3, 0x55, 0x1d, 0x23, 0x4, 0x18, 0x30, 0x16, 0x80, 0x14, 0x6a, 0xfd, 0x22, - 0x77, 0x1f, 0x51, 0x1f, 0xec, 0xbf, 0x16, 0x41, 0x97, 0x67, 0x10, 0xdc, 0xdc, 0x31, - 0xa1, 0x71, 0x7e, 0x30, 0xa, 0x6, 0x8, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x4, 0x3, 0x2, - 0x3, 0x48, 0x0, 0x30, 0x45, 0x2, 0x21, 0x0, 0x96, 0xc9, 0xc8, 0xcf, 0x2e, 0x1, 0x88, - 0x60, 0x5, 0xd8, 0xf5, 0xbc, 0x72, 0xc0, 0x7b, 0x75, 0xfd, 0x9a, 0x57, 0x69, 0x5a, - 0xc4, 0x91, 0x11, 0x31, 0x13, 0x8b, 0xea, 0x3, 0x3c, 0xe5, 0x3, 0x2, 0x20, 0x25, 0x54, - 0x94, 0x3b, 0xe5, 0x7d, 0x53, 0xd6, 0xc4, 0x75, 0xf7, 0xd2, 0x3e, 0xbf, 0xcf, 0xc2, - 0x3, 0x6c, 0xd2, 0x9b, 0xa6, 0x39, 0x3e, 0xc7, 0xef, 0xad, 0x87, 0x14, 0xab, 0x71, - 0x82, 0x19, 0x26, 0x2, 0x3e, 0x0, 0x0, 0x0, - ]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!( - tlv_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(1), - element_type: ElementType::Str16l(&[ - 0x30, 0x82, 0x1, 0xd4, 0x30, 0x82, 0x1, 0x7a, 0xa0, 0x3, 0x2, 0x1, 0x2, 0x2, - 0x8, 0x3e, 0x6c, 0xe6, 0x50, 0x9a, 0xd8, 0x40, 0xcd, 0x30, 0xa, 0x6, 0x8, 0x2a, - 0x86, 0x48, 0xce, 0x3d, 0x4, 0x3, 0x2, 0x30, 0x30, 0x31, 0x18, 0x30, 0x16, 0x6, - 0x3, 0x55, 0x4, 0x3, 0xc, 0xf, 0x4d, 0x61, 0x74, 0x74, 0x65, 0x72, 0x20, 0x54, - 0x65, 0x73, 0x74, 0x20, 0x50, 0x41, 0x41, 0x31, 0x14, 0x30, 0x12, 0x6, 0xa, - 0x2b, 0x6, 0x1, 0x4, 0x1, 0x82, 0xa2, 0x7c, 0x2, 0x1, 0xc, 0x4, 0x46, 0x46, - 0x46, 0x31, 0x30, 0x20, 0x17, 0xd, 0x32, 0x31, 0x30, 0x36, 0x32, 0x38, 0x31, - 0x34, 0x32, 0x33, 0x34, 0x33, 0x5a, 0x18, 0xf, 0x39, 0x39, 0x39, 0x39, 0x31, - 0x32, 0x33, 0x31, 0x32, 0x33, 0x35, 0x39, 0x35, 0x39, 0x5a, 0x30, 0x46, 0x31, - 0x18, 0x30, 0x16, 0x6, 0x3, 0x55, 0x4, 0x3, 0xc, 0xf, 0x4d, 0x61, 0x74, 0x74, - 0x65, 0x72, 0x20, 0x54, 0x65, 0x73, 0x74, 0x20, 0x50, 0x41, 0x49, 0x31, 0x14, - 0x30, 0x12, 0x6, 0xa, 0x2b, 0x6, 0x1, 0x4, 0x1, 0x82, 0xa2, 0x7c, 0x2, 0x1, - 0xc, 0x4, 0x46, 0x46, 0x46, 0x31, 0x31, 0x14, 0x30, 0x12, 0x6, 0xa, 0x2b, 0x6, - 0x1, 0x4, 0x1, 0x82, 0xa2, 0x7c, 0x2, 0x2, 0xc, 0x4, 0x38, 0x30, 0x30, 0x30, - 0x30, 0x59, 0x30, 0x13, 0x6, 0x7, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x2, 0x1, 0x6, - 0x8, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x3, 0x1, 0x7, 0x3, 0x42, 0x0, 0x4, 0x80, - 0xdd, 0xf1, 0x1b, 0x22, 0x8f, 0x3e, 0x31, 0xf6, 0x3b, 0xcf, 0x57, 0x98, 0xda, - 0x14, 0x62, 0x3a, 0xeb, 0xbd, 0xe8, 0x2e, 0xf3, 0x78, 0xee, 0xad, 0xbf, 0xb1, - 0x8f, 0xe1, 0xab, 0xce, 0x31, 0xd0, 0x8e, 0xd4, 0xb2, 0x6, 0x4, 0xb6, 0xcc, - 0xc6, 0xd9, 0xb5, 0xfa, 0xb6, 0x4e, 0x7d, 0xe1, 0xc, 0xb7, 0x4b, 0xe0, 0x17, - 0xc9, 0xec, 0x15, 0x16, 0x5, 0x6d, 0x70, 0xf2, 0xcd, 0xb, 0x22, 0xa3, 0x66, - 0x30, 0x64, 0x30, 0x12, 0x6, 0x3, 0x55, 0x1d, 0x13, 0x1, 0x1, 0xff, 0x4, 0x8, - 0x30, 0x6, 0x1, 0x1, 0xff, 0x2, 0x1, 0x0, 0x30, 0xe, 0x6, 0x3, 0x55, 0x1d, 0xf, - 0x1, 0x1, 0xff, 0x4, 0x4, 0x3, 0x2, 0x1, 0x6, 0x30, 0x1d, 0x6, 0x3, 0x55, 0x1d, - 0xe, 0x4, 0x16, 0x4, 0x14, 0xaf, 0x42, 0xb7, 0x9, 0x4d, 0xeb, 0xd5, 0x15, 0xec, - 0x6e, 0xcf, 0x33, 0xb8, 0x11, 0x15, 0x22, 0x5f, 0x32, 0x52, 0x88, 0x30, 0x1f, - 0x6, 0x3, 0x55, 0x1d, 0x23, 0x4, 0x18, 0x30, 0x16, 0x80, 0x14, 0x6a, 0xfd, - 0x22, 0x77, 0x1f, 0x51, 0x1f, 0xec, 0xbf, 0x16, 0x41, 0x97, 0x67, 0x10, 0xdc, - 0xdc, 0x31, 0xa1, 0x71, 0x7e, 0x30, 0xa, 0x6, 0x8, 0x2a, 0x86, 0x48, 0xce, - 0x3d, 0x4, 0x3, 0x2, 0x3, 0x48, 0x0, 0x30, 0x45, 0x2, 0x21, 0x0, 0x96, 0xc9, - 0xc8, 0xcf, 0x2e, 0x1, 0x88, 0x60, 0x5, 0xd8, 0xf5, 0xbc, 0x72, 0xc0, 0x7b, - 0x75, 0xfd, 0x9a, 0x57, 0x69, 0x5a, 0xc4, 0x91, 0x11, 0x31, 0x13, 0x8b, 0xea, - 0x3, 0x3c, 0xe5, 0x3, 0x2, 0x20, 0x25, 0x54, 0x94, 0x3b, 0xe5, 0x7d, 0x53, - 0xd6, 0xc4, 0x75, 0xf7, 0xd2, 0x3e, 0xbf, 0xcf, 0xc2, 0x3, 0x6c, 0xd2, 0x9b, - 0xa6, 0x39, 0x3e, 0xc7, 0xef, 0xad, 0x87, 0x14, 0xab, 0x71, 0x82, 0x19 - ]), - }) - ); - assert_eq!( - tlv_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(2), - element_type: ElementType::U32(62), - }) - ); - } - - #[test] - fn test_no_iterator_for_int() { - // The 0x24 is a a tagged integer, here the integer is 2 - let b = [0x15, 0x24, 0x1, 0x2]; - let tlvlist = TLVList::new(&b); - let mut tlv_iter = tlvlist.iter(); - // Skip the 0x15 - tlv_iter.next(); - assert_eq!(tlv_iter.next().unwrap().enter(), None); - } - - #[test] - fn test_struct_iteration_with_mix_values() { - // This is a struct with 3 valid values - let b = [ - 0x15, 0x24, 0x0, 0x2, 0x26, 0x2, 0x4e, 0x10, 0x02, 0x00, 0x30, 0x3, 0x04, 0x73, 0x6d, - 0x61, 0x72, - ]; - let mut root_iter = get_root_node_struct(&b).unwrap().enter().unwrap(); - assert_eq!( - root_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(0), - element_type: ElementType::U8(2), - }) - ); - assert_eq!( - root_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(2), - element_type: ElementType::U32(135246), - }) - ); - assert_eq!( - root_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(3), - element_type: ElementType::Str8l(&[0x73, 0x6d, 0x61, 0x72]), - }) - ); - } - - #[test] - fn test_struct_find_element_mix_values() { - // This is a struct with 3 valid values - let b = [ - 0x15, 0x30, 0x3, 0x04, 0x73, 0x6d, 0x61, 0x72, 0x24, 0x0, 0x2, 0x26, 0x2, 0x4e, 0x10, - 0x02, 0x00, - ]; - let root = get_root_node_struct(&b).unwrap(); - - assert_eq!( - root.find_tag(0).unwrap(), - TLVElement { - tag_type: TagType::Context(0), - element_type: ElementType::U8(2), - } - ); - assert_eq!( - root.find_tag(2).unwrap(), - TLVElement { - tag_type: TagType::Context(2), - element_type: ElementType::U32(135246), - } - ); - assert_eq!( - root.find_tag(3).unwrap(), - TLVElement { - tag_type: TagType::Context(3), - element_type: ElementType::Str8l(&[0x73, 0x6d, 0x61, 0x72]), - } - ); - } - - #[test] - fn test_list_iteration_with_mix_values() { - // This is a list with 3 valid values - let b = [ - 0x17, 0x24, 0x0, 0x2, 0x26, 0x2, 0x4e, 0x10, 0x02, 0x00, 0x30, 0x3, 0x04, 0x73, 0x6d, - 0x61, 0x72, - ]; - let mut root_iter = get_root_node_list(&b).unwrap().enter().unwrap(); - assert_eq!( - root_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(0), - element_type: ElementType::U8(2), - }) - ); - assert_eq!( - root_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(2), - element_type: ElementType::U32(135246), - }) - ); - assert_eq!( - root_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(3), - element_type: ElementType::Str8l(&[0x73, 0x6d, 0x61, 0x72]), - }) - ); - } - - #[test] - fn test_complex_structure_invoke_cmd() { - // This is what we typically get in an invoke command - let b = [ - 0x15, 0x36, 0x0, 0x15, 0x37, 0x0, 0x25, 0x0, 0x2, 0x0, 0x26, 0x1, 0x6, 0x0, 0x0, 0x0, - 0x26, 0x2, 0x1, 0x0, 0x0, 0x0, 0x18, 0x35, 0x1, 0x18, 0x18, 0x18, 0x18, - ]; - - let root = get_root_node_struct(&b).unwrap(); - - let mut cmd_list_iter = root - .find_tag(0) - .unwrap() - .confirm_array() - .unwrap() - .enter() - .unwrap(); - info!("Command list iterator: {:?}", cmd_list_iter); - - // This is an array of CommandDataIB, but we'll only use the first element - let cmd_data_ib = cmd_list_iter.next().unwrap(); - - let cmd_path = cmd_data_ib.find_tag(0).unwrap(); - let cmd_path = cmd_path.confirm_list().unwrap(); - assert_eq!( - cmd_path.find_tag(0).unwrap(), - TLVElement { - tag_type: TagType::Context(0), - element_type: ElementType::U16(2), - } - ); - assert_eq!( - cmd_path.find_tag(1).unwrap(), - TLVElement { - tag_type: TagType::Context(1), - element_type: ElementType::U32(6), - } - ); - assert_eq!( - cmd_path.find_tag(2).unwrap(), - TLVElement { - tag_type: TagType::Context(2), - element_type: ElementType::U32(1), - } - ); - assert_eq!( - cmd_path.find_tag(3).map_err(|e| e.code()), - Err(ErrorCode::NoTagFound) - ); - - // This is the variable of the invoke command - assert_eq!( - cmd_data_ib.find_tag(1).unwrap().enter().unwrap().next(), - None - ); - } - - #[test] - fn test_read_past_end_of_container() { - let b = [0x15, 0x35, 0x0, 0x24, 0x1, 0x2, 0x18, 0x24, 0x0, 0x2, 0x18]; - - let mut sub_root_iter = get_root_node_struct(&b) - .unwrap() - .find_tag(0) - .unwrap() - .enter() - .unwrap(); - assert_eq!( - sub_root_iter.next(), - Some(TLVElement { - tag_type: TagType::Context(1), - element_type: ElementType::U8(2), - }) - ); - assert_eq!(sub_root_iter.next(), None); - // Call next, even after the first next returns None - assert_eq!(sub_root_iter.next(), None); - assert_eq!(sub_root_iter.next(), None); - } - - #[test] - fn test_basic_list_iterator() { - // This is the input we have - let b = [ - 0x15, 0x36, 0x0, 0x15, 0x37, 0x0, 0x24, 0x0, 0x2, 0x24, 0x2, 0x6, 0x24, 0x3, 0x1, 0x18, - 0x35, 0x1, 0x18, 0x18, 0x18, 0x18, - ]; - - let dummy_pointer = &b[1..]; - // These are the decoded elements that we expect from this input - let verify_matrix: [(TagType, ElementType); 13] = [ - (TagType::Anonymous, ElementType::Struct(dummy_pointer)), - (TagType::Context(0), ElementType::Array(dummy_pointer)), - (TagType::Anonymous, ElementType::Struct(dummy_pointer)), - (TagType::Context(0), ElementType::List(dummy_pointer)), - (TagType::Context(0), ElementType::U8(2)), - (TagType::Context(2), ElementType::U8(6)), - (TagType::Context(3), ElementType::U8(1)), - (TagType::Anonymous, ElementType::EndCnt), - (TagType::Context(1), ElementType::Struct(dummy_pointer)), - (TagType::Anonymous, ElementType::EndCnt), - (TagType::Anonymous, ElementType::EndCnt), - (TagType::Anonymous, ElementType::EndCnt), - (TagType::Anonymous, ElementType::EndCnt), - ]; - - let mut list_iter = TLVList::new(&b).iter(); - let mut index = 0; - loop { - let element = list_iter.next(); - match element { - None => break, - Some(a) => { - assert_eq!(a.tag_type, verify_matrix[index].0); - assert_eq!( - core::mem::discriminant(&a.element_type), - core::mem::discriminant(&verify_matrix[index].1) - ); - } - } - index += 1; - } - // After the end, purposefully try a few more next - assert_eq!(list_iter.next(), None); - assert_eq!(list_iter.next(), None); - } -} diff --git a/rs-matter/src/tlv/read.rs b/rs-matter/src/tlv/read.rs new file mode 100644 index 00000000..995c2258 --- /dev/null +++ b/rs-matter/src/tlv/read.rs @@ -0,0 +1,2003 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::{cmp::Ordering, fmt}; + +use crate::error::{Error, ErrorCode}; + +use super::{pad, TLVControl, TLVTag, TLVTagType, TLVValue, TLVValueType, TLV}; + +/// A newtype for reading TLV-encoded data from Rust `&[u8]` slices. +/// +/// Semantically, a `TLVElement` is just a byte slice of TLV-encoded data/stream, and the methods provided by this therefore +/// allow to parse - on the fly - the byte slice as TLV. +/// +/// Note also, that - as per the Matter Core Spec: +/// - A valid TLV stream always represents a SINGLE TLV element (hence why this type is named `TLVElement` and why we claim +/// that it represents also a whole TLV stream) +/// - If there is a need to encode more than one TLV element, they should be encoded in a TLV container (array, list or struct), +/// hence we end up again with a single TLV element, which represents the whole container. +/// +/// Parsing/reading/validating the TLV of the slice represented by a `TLVElement` is done on-demand. What this means is that: +/// - `TLVElement::new(slice)` always succeeds, even when the passed slice contains invalid TLV data +/// - As the various methods of `TLVElement` type are called, the data in the slice is parsed and validated on the fly. Hence why all methods +/// on `TLVElement` except `is_empty` are fallible. +/// +/// A TLV element can currently be constructed from an empty `&[]` slice, but the empty slice does not actually represent a TLV element, +/// so all methods except `TLVElement::is_empty` would fail on a `TLVElement` constructed from an empty slice. The only reason why empty slices +/// are currently allowed is to simplify the `FromTLV` trait a bit by representing data which was not found (i.e. optional data in TLV structures) +/// as a TLVElement with an empty slice. +/// +/// The design approach from above (on-demand parsing/validation) trades memory efficiency for extra computations, in that by simply decorating +/// a Rust `&[u8]` slice anbd post-poning everything else post-construction it ensures the size of a `TLVElement` is equal to the size of the wrapped +/// `&[u8]` slice - i.e., a regular Rust fat pointer (8 bytes on 32 bit archs and 16 bytes on 64 bit archs). +/// +/// Furthermore, all accompanying types of `TLVElement`, like `TLVSequence`, `TLVContainerIter` and `TLVArray` are also just newtypes over byte slices +/// and therefore just as small. +/// +/// (Keeping interim data is still optionally possible, by using the `TLV::tag` and `TLV::value` +/// methods to read the tag and value of a TLV as enums.) +/// +/// As for representing the encoded TLV stream itself as a raw `&[u8]` slice - this trivializes the traversal of the stream +/// as the stream traversal is represented as returning sub-slices of the original slice. It also allows `FromTLV` implementations where +/// the data is borrowed directly from the `&[u8]` slice representing the encoded TLV stream without any data moves. Types that implement +/// such borrowing are e.g.: +/// - `&str` (used to represent borrowed TLV UTF-8 strings) +/// - `Bytes<'a>` (a newtype over `&'a [u8]` - used to represent TLV octet strings) +/// - `TLVArray` +/// - `TLVSequence` - discussed below +/// +/// Also, this representation naturally allows random-access to the TLV stream, which is necessary for a number of reasons: +/// - Deserialization of TLV structs into Rust structs (with the `FromTLV` derive macro) where the order of the TLV elements +/// of the struct is not known in advance +/// - Delayed in-place initialization of large Rust types with `FromTLV::init_from_tlv` which requires random access for reasons +/// beyond the possible unordering of the TLV struct elements. +/// +/// In practice, random access - and in general - representation of the TLV stream as a `&[u8]` slice should be natural and +/// convenient, as the TLV stream usually comes from the network UDP/TCP memory buffers of the Matter transport protocol, and +/// these can and are borrowed as `&[u8]` slices in the upper-layer code for direct reads. +#[derive(Clone, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct TLVElement<'a>(TLVSequence<'a>); + +impl<'a> TLVElement<'a> { + /// Create a new `TLVElement` from a byte slice, where the byte slice contains an encoded TLV stream (a TLV element). + #[inline(always)] + pub fn new(data: &'a [u8]) -> Self { + Self(TLVSequence(data)) + } + + /// Return `true` if the wrapped byte slice is the empty `&[]` slice. + /// Empty byte slices do not represent valid TLV data, as the TLV data should be a valid TLV element, + /// yet they are useful when implementing the `FromTLV` trait. + #[inline(always)] + pub fn is_empty(&self) -> bool { + self.0 .0.is_empty() + } + + /// Return `Some(self)` if the wrapped byte slice is not empty, `None` otherwise. + pub fn non_empty(&self) -> Option<&TLVElement<'a>> { + if self.is_empty() { + None + } else { + Some(self) + } + } + + /// Return a copy of the wrapped TLV byte slice. + #[inline(always)] + pub const fn raw_data(&self) -> &'a [u8] { + self.0 .0 + } + + /// Return the TLV control byte of the first TLV in the slice. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the first byte of the slice does + /// not represent a valid TLV control byte or if the wrapped byte slice is empty. + #[inline(always)] + pub fn control(&self) -> Result { + self.0.control() + } + + /// Return a sub-slice of the wrapped byte slice that designates the encoded value + /// of this `TLVElement` (i.e. the raw "value" aspect of the Tag-Length-Value encoding) + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// For getting a parsed value, use `value` or any of the other helper methods that + /// retrieve a value of a certain type. + #[inline(always)] + pub fn raw_value(&self) -> Result<&'a [u8], Error> { + self.0.raw_value() + } + + /// Return a `TLV` struct representing the tag and value of this `TLVElement`. + /// This method is a convenience method that combines the `tag` and `value` methods. + pub fn tlv(&self) -> Result, Error> { + Ok(TLV { + tag: self.tag()?, + value: self.value()?, + }) + } + + /// Return a `TLVTag` enum representing the tag of this `TLVElement`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV + /// byte slice contains malformed TLV data. + #[inline(always)] + pub fn tag(&self) -> Result { + let tag_type = self.control()?.tag_type; + + let slice = self + .0 + .tag_start()? + .get(..tag_type.size()) + .ok_or(ErrorCode::TLVTypeMismatch)?; + + let tag = match tag_type { + TLVTagType::Anonymous => TLVTag::Anonymous, + TLVTagType::Context => TLVTag::Context(slice[0]), + TLVTagType::CommonPrf16 => { + TLVTag::CommonPrf16(u16::from_le_bytes(slice.try_into().unwrap())) + } + TLVTagType::CommonPrf32 => { + TLVTag::CommonPrf32(u32::from_le_bytes(slice.try_into().unwrap())) + } + TLVTagType::ImplPrf16 => { + TLVTag::ImplPrf16(u16::from_le_bytes(slice.try_into().unwrap())) + } + TLVTagType::ImplPrf32 => { + TLVTag::ImplPrf32(u32::from_le_bytes(slice.try_into().unwrap())) + } + TLVTagType::FullQual48 => TLVTag::FullQual48 { + vendor_id: u16::from_le_bytes([slice[0], slice[1]]), + profile: u16::from_le_bytes([slice[2], slice[3]]), + tag: u16::from_le_bytes([slice[4], slice[5]]), + }, + TLVTagType::FullQual64 => TLVTag::FullQual64 { + vendor_id: u16::from_le_bytes([slice[0], slice[1]]), + profile: u16::from_le_bytes([slice[2], slice[3]]), + tag: u32::from_le_bytes([slice[4], slice[5], slice[6], slice[7]]), + }, + }; + + Ok(tag) + } + + /// Return a `TLVValue` enum representing the value of this `TLVElement`. + /// + /// Note that if the TLV element is a container, the return `TLV` value would only deisgnate + /// the container type (struct, array or list) and not the actual content of the container. + pub fn value(&self) -> Result, Error> { + let control = self.control()?; + + let slice = self.0.container_value(control)?; + + let value = match control.value_type { + TLVValueType::S8 => TLVValue::S8(i8::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::S16 => TLVValue::S16(i16::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::S32 => TLVValue::S32(i32::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::S64 => TLVValue::S64(i64::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::U8 => TLVValue::U8(u8::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::U16 => TLVValue::U16(u16::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::U32 => TLVValue::U32(u32::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::U64 => TLVValue::U64(u64::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::False => TLVValue::False, + TLVValueType::True => TLVValue::True, + TLVValueType::F32 => TLVValue::F32(f32::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::F64 => TLVValue::F64(f64::from_le_bytes(slice.try_into().unwrap())), + TLVValueType::Utf8l => TLVValue::Utf8l( + core::str::from_utf8(slice).map_err(|_| ErrorCode::TLVTypeMismatch)?, + ), + TLVValueType::Utf16l => TLVValue::Utf16l( + core::str::from_utf8(slice).map_err(|_| ErrorCode::TLVTypeMismatch)?, + ), + TLVValueType::Utf32l => TLVValue::Utf32l( + core::str::from_utf8(slice).map_err(|_| ErrorCode::TLVTypeMismatch)?, + ), + TLVValueType::Utf64l => TLVValue::Utf64l( + core::str::from_utf8(slice).map_err(|_| ErrorCode::TLVTypeMismatch)?, + ), + TLVValueType::Str8l => TLVValue::Str8l(slice), + TLVValueType::Str16l => TLVValue::Str16l(slice), + TLVValueType::Str32l => TLVValue::Str32l(slice), + TLVValueType::Str64l => TLVValue::Str64l(slice), + TLVValueType::Null => TLVValue::Null, + TLVValueType::Struct => TLVValue::Struct, + TLVValueType::Array => TLVValue::Array, + TLVValueType::List => TLVValue::List, + TLVValueType::EndCnt => TLVValue::EndCnt, + }; + + Ok(value) + } + + pub fn len(&self) -> Result { + todo!() + } + + /// Return the value of this TLV element as an `i8`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV S8 value. + pub fn i8(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::S8) { + Ok(i8::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Return the value of this TLV element as a `u8`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV U8 value. + pub fn u8(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::U8) { + Ok(u8::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Return the value of this TLV element as an `i16`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV S8 or S16 value. + pub fn i16(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::S16) { + Ok(i16::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + self.i8().map(|a| a.into()) + } + } + + /// Return the value of this TLV element as a `u16`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV U8 or U16 value. + pub fn u16(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::U16) { + Ok(u16::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + self.u8().map(|a| a.into()) + } + } + + /// Return the value of this TLV element as an `i32`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV S8, S16 or S32 value. + pub fn i32(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::S32) { + Ok(i32::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + self.i16().map(|a| a.into()) + } + } + + /// Return the value of this TLV element as a `u32`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV U8, U16 or U32 value. + pub fn u32(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::U32) { + Ok(u32::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + self.u16().map(|a| a.into()) + } + } + + /// Return the value of this TLV element as an `i64`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV S8, S16, S32 or S64 value. + pub fn i64(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::S64) { + Ok(i64::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + self.i32().map(|a| a.into()) + } + } + + /// Return the value of this TLV element as a `u64`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV U8, U16, U32 or U64 value. + pub fn u64(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::U64) { + Ok(u64::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + self.u32().map(|a| a.into()) + } + } + + /// Return the value of this TLV element as an `f32`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV F32 value. + pub fn f32(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::F32) { + Ok(f32::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Return the value of this TLV element as an `f64`. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV F64 value. + pub fn f64(&self) -> Result { + let control = self.control()?; + + if matches!(control.value_type, TLVValueType::F64) { + Ok(f64::from_le_bytes( + self.0 + .value(control)? + .try_into() + .map_err(|_| ErrorCode::InvalidData)?, + )) + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Return the value of this TLV element as a byte slice. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV Octet String. + pub fn str(&self) -> Result<&'a [u8], Error> { + let control = self.control()?; + + if !control.value_type.is_str() { + Err(ErrorCode::Invalid)?; + } + + self.0.value(control) + } + + /// Return the value of this TLV element as a UTF-8 string. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV UTF-8 String. + pub fn utf8(&self) -> Result<&'a str, Error> { + let control = self.control()?; + + if !control.value_type.is_utf8() { + Err(ErrorCode::Invalid)?; + } + + core::str::from_utf8(self.0.value(control)?).map_err(|_| ErrorCode::InvalidData.into()) + } + + /// Return the value of this TLV element as a UTF-16 string. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV UTF-8 String or a TLV octet string. + pub fn octets(&self) -> Result<&'a [u8], Error> { + let control = self.control()?; + + if control.value_type.variable_size_len() == 0 { + Err(ErrorCode::Invalid)?; + } + + self.0.value(control) + } + + /// Return the value of this TLV element as a UTF-16 string. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV boolean. + pub fn bool(&self) -> Result { + let control = self.control()?; + + match control.value_type { + TLVValueType::False => Ok(false), + TLVValueType::True => Ok(true), + _ => Err(ErrorCode::TLVTypeMismatch.into()), + } + } + + /// Return `true` if this TLV element is as a container (i.e., a struct, array or list). + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + pub fn is_container(&self) -> Result { + Ok(self.control()?.value_type.is_container()) + } + + /// Confirm that this TLV element contains a TLV null value. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV null value. + pub fn null(&self) -> Result<(), Error> { + if matches!(self.control()?.value_type, TLVValueType::Null) { + Ok(()) + } else { + Err(ErrorCode::InvalidData.into()) + } + } + + /// Return the content of the struct container represented by this TLV element. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV struct. + pub fn structure(&self) -> Result, Error> { + self.r#struct() + } + + /// Return the content of the struct container represented by this TLV element. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV struct. + /// + /// (Same as method `structure` but with a special name to ease the `FromTLV` trait derivation for + /// user types.) + pub fn r#struct(&self) -> Result, Error> { + if matches!(self.control()?.value_type, TLVValueType::Struct) { + self.0.next_enter() + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Return the content of the array container represented by this TLV element. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV array. + pub fn array(&self) -> Result, Error> { + if matches!(self.control()?.value_type, TLVValueType::Array) { + self.0.next_enter() + } else { + Err(ErrorCode::InvalidData.into()) + } + } + + /// Return the content of the list container represented by this TLV element. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV list. + pub fn list(&self) -> Result, Error> { + if matches!(self.control()?.value_type, TLVValueType::List) { + self.0.next_enter() + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Return the content of the container (array, struct or list) represented by this TLV element. + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the value of the TLV element is not + /// a TLV container. + pub fn container(&self) -> Result, Error> { + if matches!( + self.control()?.value_type, + TLVValueType::List | TLVValueType::Array | TLVValueType::Struct + ) { + self.0.next_enter() + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Confirm that this TLV element is tagged with the anonymous tag (`TLVTag::Anonymous`). + /// + /// Returns an error with code `ErrorCode::TLVTypeMismatch` if the wrapped TLV byte slice + /// contains malformed TLV data. + /// + /// Returns an error with code `ErrorCode::InvalidData` if the tag of the TLV element is not + /// the anonymous tag. + pub fn confirm_anon(&self) -> Result<(), Error> { + if matches!(self.control()?.tag_type, TLVTagType::Anonymous) { + Ok(()) + } else { + Err(ErrorCode::TLVTypeMismatch.into()) + } + } + + /// Retrieve the context ID of the element. + /// If element is not tagged with a context tag, the method will return an error. + pub fn ctx(&self) -> Result { + Ok(self.try_ctx()?.ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Retrieve the context ID of the element. + /// If element is not tagged with a context tag, the method will return `None`. + pub fn try_ctx(&self) -> Result, Error> { + let control = self.control()?; + + if matches!(control.tag_type, TLVTagType::Context) { + Ok(Some( + *self + .0 + .tag(control.tag_type)? + .first() + .ok_or(ErrorCode::TLVTypeMismatch)?, + )) + } else { + Ok(None) + } + } + + fn fmt(&self, indent: usize, f: &mut fmt::Formatter) -> fmt::Result { + pad(indent, f)?; + + let tag = self.tag().map_err(|_| fmt::Error)?; + + tag.fmt(f)?; + + if !matches!(tag.tag_type(), TLVTagType::Anonymous) { + write!(f, ": ")?; + } + + let value = self.value().map_err(|_| fmt::Error)?; + + value.fmt(f)?; + + if value.value_type().is_container() { + let mut empty = true; + + for (index, elem) in self.container().map_err(|_| fmt::Error)?.iter().enumerate() { + if index > 0 { + writeln!(f, ",")?; + } else { + writeln!(f)?; + } + + elem.map_err(|_| fmt::Error)?.fmt(indent + 2, f)?; + + empty = false; + } + + if !empty { + writeln!(f)?; + pad(indent, f)?; + } + + match value.value_type() { + TLVValueType::Struct => write!(f, "}}"), + TLVValueType::Array => write!(f, "]"), + TLVValueType::List => write!(f, ")"), + _ => unreachable!(), + }?; + } + + Ok(()) + } +} + +impl<'a> fmt::Debug for TLVElement<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.fmt(0, f) + } +} + +impl<'a> fmt::Display for TLVElement<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.fmt(0, f) + } +} + +/// A newtype for iterating over the `TLVElement` "child" instances contained in `TLVElement` which is a TLV container +/// (array, struct or list). +/// (Internally, `TLVSequence` might be used for other purposes, but the external contract is only the one from above.) +/// +/// Just like `TLVElement`, `TLVSequence` is a newtype over a byte slice - the byte sub-slice of the parent `TLVElement` +/// container where its value starts. +/// +/// Unlike `TLVElement`, `TLVSequence` - as the name suggests - represents a sequence of 0, 1 or more `TLVElements`. +/// The only public API of `TLVSequence` however is the `iter` method which returns a `TLVContainerIter` iterator over +/// the `TLVElement` instances in the sequence. +#[derive(Clone, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct TLVSequence<'a>(pub(crate) &'a [u8]); + +impl<'a> TLVSequence<'a> { + const EMPTY: Self = Self(&[]); + + /// Return an iterator over the `TLVElement` instances in this `TLVSequence`. + #[inline(always)] + pub fn iter(&self) -> TLVSequenceIter<'a> { + TLVSequenceIter::new(self.clone()) + } + + /// Return an iterator over the `TLV` instances in this `TLVSequence`. + /// + /// The difference with `iter` is that for container elements, `tlv_iter` + /// will return separate `TLV` instances for the container start, the container + /// elements and the container end, where if an element in the container is + /// itself a container, the algorithm will be applied recursively to the inner container. + pub fn tlv_iter(&self) -> TLVSequenceTLVIter<'a> { + TLVSequenceTLVIter::new(self.clone()) + } + + /// A convenience utility that returns the first `TLVElement` in the sequence + /// which is tagged with a context tag (`TLVTag::Context`) where the context ID + /// is matching the ID passed in the `ctx` parameter. + /// + /// If there is no TLV element tagged with a context tag with the matching ID, the method + /// will return an error. + pub fn ctx(&self, ctx: u8) -> Result, Error> { + let element = self.find_ctx(ctx)?; + + if element.is_empty() { + Err(ErrorCode::NotFound.into()) + } else { + Ok(element) + } + } + + /// A convenience utility that returns the first `TLVElement` in the sequence + /// which is tagged with a context tag (`TLVTag::Context`) where the context ID + /// is matching the ID passed in the `ctx` parameter. + /// + /// If there is no TLV element tagged with a context tag with the matching ID, the method + /// will return an empty `TLVElement`. + pub fn find_ctx(&self, ctx: u8) -> Result, Error> { + for elem in self.iter() { + let elem = elem?; + + if let Some(elem_ctx) = elem.try_ctx()? { + if elem_ctx == ctx { + return Ok(elem); + } + } + } + + Ok(TLVElement(Self::EMPTY)) + } + + /// A convenience utility that returns the first `TLVElement` in the sequence + /// which is tagged with a context tag (`TLVTag::Context`) where the context ID + /// is equal to the ID passed in the `ctx` parameter. + /// + /// If there is no TLV element tagged with a context tag with the matching ID, the method + /// will return an empty TLV element. + /// + /// As a side effect of calling this method, the `TLVSequence` instance will be updated + /// to point to the next element after the found element, or if an element with the + /// provided context ID does not exist, to the first element with a bigger context ID than + /// the one we are looking for. + pub fn scan_ctx(&mut self, ctx: u8) -> Result, Error> { + self.scan_map(move |elem| { + if elem.is_empty() { + return Ok(Some(elem)); + } + + if let Some(elem_ctx) = elem.try_ctx()? { + match elem_ctx.cmp(&ctx) { + Ordering::Equal => return Ok(Some(elem)), + Ordering::Greater => return Ok(Some(TLVElement(Self::EMPTY))), + _ => (), + } + } + + Ok(None) + }) + } + + /// A convenience utility that returns scans the elements in the sequence, + /// in-order and stops scanning once the provided mapping closure `f` + /// returns a non-empty result. + /// + /// As a side effect of calling this method, the `TLVSequence` instance will be updated + /// to point to the next element after the one on which the provided closure + /// returned a non-empty result. + /// + /// Note that the closure _must_ ultimately return a non-empty result - if for nothing else + /// then for the empty element that is passed to it when the sequence is exhausted, + /// or else the method would loop forever. + pub fn scan_map(&mut self, mut f: F) -> Result + where + F: FnMut(TLVElement<'a>) -> Result, Error>, + { + loop { + if let Some(elem) = f(self.current()?)? { + return Ok(elem); + } + + *self = self.container_next()?; + } + } + + /// Return a raw byte sub-slice representing the TLV-encoded elements and only those + /// elements that belong to the TLV container whose elements are represented by this `TLVSequence` instance. + /// + /// This method is necessary, because both `TLVElement` instances, as well as `TLVSequence` instances - for optimization purposes - + /// might be constructed during iteration on slices which are technically longer than the actual TLV-encoded data + /// they represent. + /// + /// So in case the user is need of the actual, exact raw representation of a TLV container **value**, this method is provided. + #[inline(always)] + pub fn raw_value(&self) -> Result<&'a [u8], Error> { + let control = self.control()?; + + self.container_value(control) + } + + /// Return a sub-sequence representing the TLV-encoded elements after the first one on the sequence. + /// + /// As the name suggests, if the first TLV element in the sequence is a container, this method will return a sub-sequence + /// which corresponds to the first element INSIDE the container. + /// + /// If the sequence is empty, or the sequence contains just one element, the method will return an empty `TLVSequence`. + /// + /// Note also that this method will also return sub-sequences where the first element might be a TLV `TLVValueType::EndCnt` marker, + /// which - formally speaking - is not a TLVElement, but a TLV control byte that marks the end of a container. + fn next_enter(&self) -> Result { + if self.0.is_empty() { + return Ok(Self::EMPTY); + } + + let control = self.control()?; + + Ok(Self(self.next_start(control)?)) + } + + /// Return a sub-sequence representing the TLV-encoded elements after the first one on the sequence. + /// + /// As the name suggests, if the first TLV element in the sequence is a container, this method will return a sub-sequence + /// which corresponds to the elements AFTER the container element (i.e., the method "skips over" the elements of the container element). + /// + /// If the sequence is empty or the sequence starts with a container-end control byte, the method will return the current sequence. + fn container_next(&self) -> Result { + if self.0.is_empty() { + return Ok(Self::EMPTY); + } + + let control = self.control()?; + + if control.value_type.is_container_end() { + control.confirm_container_end()?; + + return Ok(self.clone()); + } + + let mut next = self.next_enter()?; + + if control.value_type.is_container() { + let mut level = 1; + + while level > 0 { + let control = next.control()?; + + if control.value_type.is_container_end() { + control.confirm_container_end()?; + level -= 1; + } else if control.value_type.is_container() { + level += 1; + } + + next = next.next_enter()?; + } + } + + Ok(next) + } + + /// Return the first TLV element in the sequence. + /// If the sequence is empty, or if the sequence starts with a container-end TLV, + /// an empty element will be returned. + fn current(&self) -> Result, Error> { + if self.0.is_empty() { + return Ok(TLVElement(Self::EMPTY)); + } + + let control = self.control()?; + + if control.value_type.is_container_end() { + control.confirm_container_end()?; + + return Ok(TLVElement(Self::EMPTY)); + } + + return Ok(TLVElement::new(self.0)); + } + + /// Return the TLV control byte of the first TLV in the sequence. + /// If the sequence is empty, an error will be returned. + #[inline(always)] + fn control(&self) -> Result { + TLVControl::parse(*self.0.first().ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Return a sub-slice of the wrapped byte slice that designates the START of the tag payload + /// of the first TLV in the sequence. + /// + /// If there is no tag payload (i.e., the tag is of type `TLVTagType::Anonymous`), the returned sub-slice + /// will designate the start of the TLV element value or value length. + #[inline(always)] + fn tag_start(&self) -> Result<&'a [u8], Error> { + Ok(self.0.get(1..).ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Return a sub-slice of the wrapped byte slice that designates the exact raw slice representing the tag payload + /// of the first TLV in the sequence. + /// + /// If there is no tag payload (i.e., the tag is of type `TLVTagType::Anonymous`), the returned sub-slice + /// will be the empty slice. + #[inline(always)] + fn tag(&self, tag_type: TLVTagType) -> Result<&'a [u8], Error> { + Ok(self + .tag_start()? + .get(..tag_type.size()) + .ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Return a sub-slice of the wrapped byte slice that designates the START of the value length field + /// of the first TLV in the sequence. + /// + /// The value length field is the field that designates the length of the value of the TLV element. + /// If the TLV element control byte designates an element with a fixed size or a container element, + /// the returned sub-slice will designate the start of the value field. + #[inline(always)] + fn value_len_start(&self, tag_type: TLVTagType) -> Result<&'a [u8], Error> { + Ok(self + .tag_start() + .unwrap() + .get(tag_type.size()..) + .ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Return a sub-slice of the wrapped byte slice that designates the START of the value field of + /// the first TLV in the sequence. + /// + /// The value field is the field that designates the actual value of the TLV element. + #[inline(always)] + fn value_start(&self, control: TLVControl) -> Result<&'a [u8], Error> { + Ok(self + .value_len_start(control.tag_type)? + .get(control.value_type.variable_size_len()..) + .ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Return a sub-slice of the wrapped byte slice that designates the exact raw slice representing the value payload + /// of the first TLV element in the sequence. + /// + /// For container elements, this method will return the empty slice. Use `container_value` (a more computationally expensive method) + /// to get the exact taw slice of the first TLV element value that also works for containers. + #[inline(always)] + fn value(&self, control: TLVControl) -> Result<&'a [u8], Error> { + let value_len = self.value_len(control)?; + + Ok(self + .value_start(control)? + .get(..value_len) + .ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Return a sub-slice of the wrapped byte slice that designates the exact raw slice representing the value payload + /// of the first TLV element in the sequence. + #[inline(always)] + fn container_value(&self, control: TLVControl) -> Result<&'a [u8], Error> { + let value_len = self.container_value_len(control)?; + + Ok(self + .value_start(control)? + .get(..value_len) + .ok_or(ErrorCode::TLVTypeMismatch)?) + } + + /// Return the length of the value field of the first TLV element in the sequence. + /// + /// - For elements that do have a fixed size, the fixed size will be returned. + /// - For UTF-8 and octet strings, the actual string length will be returned. + /// - For containers, a length of 0 will be returned. Use `container_value_len` + /// (much more computationally expensive method) to get the exact length of the container. + #[inline(always)] + fn value_len(&self, control: TLVControl) -> Result { + if let Some(fixed_size) = control.value_type.fixed_size() { + return Ok(fixed_size); + } + + let size_len = control.value_type.variable_size_len(); + + let value_len_slice = self + .value_len_start(control.tag_type)? + .get(..size_len) + .ok_or(ErrorCode::TLVTypeMismatch)?; + + let len = match size_len { + 1 => u8::from_be_bytes(value_len_slice.try_into().unwrap()) as usize, + 2 => u16::from_le_bytes(value_len_slice.try_into().unwrap()) as usize, + 4 => u32::from_le_bytes(value_len_slice.try_into().unwrap()) as usize, + 8 => u64::from_le_bytes(value_len_slice.try_into().unwrap()) as usize, + _ => unreachable!(), + }; + + Ok(len) + } + + /// Return the length of the value field of the first TLV element in the sequence, regardless of the + /// element type (fixed size, variable size, or container). + #[inline(always)] + fn container_value_len(&self, control: TLVControl) -> Result { + if control.value_type.is_container() { + let mut next = self.clone(); + let mut len = 0; + let mut level = 1; + + while level > 0 { + next = next.next_enter()?; + len += next.len()?; + + let control = next.control()?; + + if control.value_type.is_container_end() { + control.confirm_container_end()?; + level -= 1; + } else if control.value_type.is_container() { + level += 1; + } + } + + Ok(len) + } else { + self.value_len(control) + } + } + + /// Return the length of the first TLV element in the sequence. + /// + /// For containers, the return length will NOT include the elements contained inside + /// the container, nor the one-byte `EndCnt` marker. + #[inline(always)] + fn len(&self) -> Result { + let control = self.control()?; + + self.value_len(control).map(|value_len| { + 1 + control.tag_type.size() + control.value_type.variable_size_len() + value_len + }) + } + + /// Returns a sub-slice representing the start of the next TLV element in the sequence. + /// If the sequence contains just one element, the method will return an empty slice. + /// If the sequence contains no elements, the method will return an error with code `ErrorCode::TLVTypeMismatch`. + /// + /// Just like `next_enter` (wich is based on `next_start`) this method does "enter" container elements, + /// and might return a sub-slice where the first element is the special `EndCnt` marker. + #[inline(always)] + fn next_start(&self, control: TLVControl) -> Result<&'a [u8], Error> { + let value_len = self.value_len(control)?; + + Ok(self + .value_start(control)? + .get(value_len..) + .ok_or(ErrorCode::TLVTypeMismatch)?) + } + + pub(crate) fn fmt(&self, indent: usize, f: &mut fmt::Formatter) -> fmt::Result { + let mut first = true; + + for elem in self.iter() { + if first { + first = false; + } else { + writeln!(f, ",")?; + } + + let elem = elem.map_err(|_| fmt::Error)?; + + elem.fmt(indent, f)?; + } + + if !first { + writeln!(f)?; + } + + Ok(()) + } +} + +impl<'a> fmt::Debug for TLVSequence<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.fmt(0, f) + } +} + +impl<'a> fmt::Display for TLVSequence<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.fmt(0, f) + } +} + +/// A type representing an iterator over the elements of a `TLVSequence` returning `TLV` instances. +#[derive(Clone)] +pub struct TLVSequenceTLVIter<'a> { + seq: TLVSequence<'a>, + nesting: usize, +} + +impl<'a> TLVSequenceTLVIter<'a> { + /// Create a new `TLVContainerIter` instance. + const fn new(seq: TLVSequence<'a>) -> Self { + Self { seq, nesting: 0 } + } + + fn try_next(&mut self) -> Result>, Error> { + let current = self.seq.current()?; + if current.is_empty() { + return Ok(None); + } + + self.advance()?; + + Ok(Some(TLV::new(current.tag()?, current.value()?))) + } + + fn advance(&mut self) -> Result<(), Error> { + if self.nesting > 0 || !self.seq.0.is_empty() && !self.seq.control()?.is_container_end() { + self.seq = self.seq.next_enter()?; + + let control = self.seq.control()?; + + if control.is_container_start() { + self.nesting += 1; + } else if control.is_container_end() { + self.nesting -= 1; + } + } + + Ok(()) + } +} + +impl<'a> Iterator for TLVSequenceTLVIter<'a> { + type Item = Result, Error>; + + fn next(&mut self) -> Option { + self.try_next().transpose() + } +} + +/// A type representing an iterator over the elements of a `TLVSequence`. +#[derive(Clone)] +#[repr(transparent)] +pub struct TLVSequenceIter<'a>(TLVSequence<'a>); + +impl<'a> TLVSequenceIter<'a> { + /// Create a new `TLVContainerIter` instance. + const fn new(seq: TLVSequence<'a>) -> Self { + Self(seq) + } + + fn advance(&mut self) -> Result<(), Error> { + self.0 = self.0.container_next()?; + + Ok(()) + } +} + +impl<'a> Iterator for TLVSequenceIter<'a> { + type Item = Result, Error>; + + fn next(&mut self) -> Option { + self.0 + .current() + .and_then(|current| self.advance().map(|_| current)) + .map(|elem| (!elem.is_empty()).then_some(elem)) + .transpose() + } +} + +impl<'a> fmt::Debug for TLVSequenceIter<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(0, f) + } +} + +impl<'a> fmt::Display for TLVSequenceIter<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(0, f) + } +} + +#[cfg(test)] +mod tests { + use core::{f32, f64}; + + use super::TLVElement; + use crate::tlv::{TLVArray, TLVList, TLVSequence, TLVStruct, TLVTag, TLVValue, TLV}; + + #[test] + fn test_no_container_for_int() { + // The 0x24 is a a tagged integer, here the integer is 2 + let data = &[0x15, 0x24, 0x1, 0x2]; + let seq = TLVSequence(data); + // Skip the 0x15 + let seq = seq.next_enter().unwrap(); + + let elem = TLVElement(seq); + assert!(elem.container().is_err()); + } + + #[test] + fn test_struct_iteration_with_mix_values() { + // This is a struct with 3 valid values + let data = &[ + 0x15, 0x24, 0x0, 0x2, 0x26, 0x2, 0x4e, 0x10, 0x02, 0x00, 0x30, 0x3, 0x04, 0x73, 0x6d, + 0x61, 0x72, + ]; + + let mut root_iter = TLVElement::new(data).structure().unwrap().iter(); + assert_eq!( + root_iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(0), + value: TLVValue::U8(2), + } + ); + assert_eq!( + root_iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(2), + value: TLVValue::U32(135246), + } + ); + assert_eq!( + root_iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(3), + value: TLVValue::Str8l(&[0x73, 0x6d, 0x61, 0x72]), + } + ); + } + + #[test] + fn test_struct_find_element_mix_values() { + // This is a struct with 3 valid values + let data = &[ + 0x15, 0x30, 0x3, 0x04, 0x73, 0x6d, 0x61, 0x72, 0x24, 0x0, 0x2, 0x26, 0x2, 0x4e, 0x10, + 0x02, 0x00, + ]; + let root = TLVElement::new(data).structure().unwrap(); + + assert_eq!( + root.find_ctx(0).unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(0), + value: TLVValue::U8(2), + } + ); + assert_eq!(root.find_ctx(2).unwrap().tag().unwrap(), TLVTag::Context(2)); + assert_eq!(root.find_ctx(2).unwrap().u64().unwrap(), 135246); + + assert_eq!(root.find_ctx(3).unwrap().tag().unwrap(), TLVTag::Context(3)); + assert_eq!( + root.find_ctx(3).unwrap().str().unwrap(), + &[0x73, 0x6d, 0x61, 0x72] + ); + } + + #[test] + fn test_list_iteration_with_mix_values() { + // This is a list with 3 valid values + let data = &[ + 0x17, 0x24, 0x0, 0x2, 0x26, 0x2, 0x4e, 0x10, 0x02, 0x00, 0x30, 0x3, 0x04, 0x73, 0x6d, + 0x61, 0x72, + ]; + let mut root_iter = TLVElement::new(data).list().unwrap().iter(); + assert_eq!( + root_iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(0), + value: TLVValue::U8(2), + } + ); + assert_eq!( + root_iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(2), + value: TLVValue::U32(135246), + } + ); + assert_eq!( + root_iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(3), + value: TLVValue::Str8l(&[0x73, 0x6d, 0x61, 0x72]), + } + ); + } + + #[test] + fn test_read_past_end_of_container() { + let data = &[0x15, 0x35, 0x0, 0x24, 0x1, 0x2, 0x18, 0x24, 0x0, 0x2, 0x18]; + + let mut struct2_iter = TLVElement::new(data) + .structure() + .unwrap() + .find_ctx(0) + .unwrap() + .structure() + .unwrap() + .iter(); + + assert_eq!( + struct2_iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Context(1), + value: TLVValue::U8(2), + } + ); + assert!(struct2_iter.next().is_none()); + // Call next, even after the first next returns None + assert!(struct2_iter.next().is_none()); + assert!(struct2_iter.next().is_none()); + } + + #[test] + fn test_iteration() { + // This is the input we have + // { + // 0: [ + // { + // 0: L[ 0: 2, 2: 6, 3: 1], + // 1: {}, + // }, + // ], + // } + + let data = &[ + 0x15, 0x36, 0x0, 0x15, 0x37, 0x0, 0x24, 0x0, 0x2, 0x24, 0x2, 0x6, 0x24, 0x3, 0x1, 0x18, + 0x35, 0x1, 0x18, 0x18, 0x18, 0x18, + ]; + + let struct0 = TLVStruct::::new(TLVElement::new(data)).unwrap(); + + assert_eq!( + struct0.element().tlv().unwrap(), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Struct, + } + ); + assert_eq!(struct0.iter().count(), 1); + + let array = TLVArray::::new(struct0.iter().next().unwrap().unwrap()).unwrap(); + + assert_eq!( + array.element().tlv().unwrap(), + TLV { + tag: TLVTag::Context(0), + value: TLVValue::Array, + } + ); + assert_eq!(array.iter().count(), 1); + + let struct1 = TLVStruct::::new(array.iter().next().unwrap().unwrap()).unwrap(); + assert_eq!( + struct1.element().tlv().unwrap(), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Struct, + } + ); + assert_eq!(struct1.iter().count(), 2); + + let mut struct1_iter = struct1.iter(); + + let list = TLVList::::new(struct1_iter.next().unwrap().unwrap()).unwrap(); + assert_eq!( + list.element().tlv().unwrap(), + TLV { + tag: TLVTag::Context(0), + value: TLVValue::List, + } + ); + assert_eq!(list.iter().count(), 3); + + let mut list_iter = list.iter(); + + let le1 = list_iter.next().unwrap().unwrap(); + assert_eq!( + le1.tlv().unwrap(), + TLV { + tag: TLVTag::Context(0), + value: TLVValue::U8(2) + } + ); + + let le2 = list_iter.next().unwrap().unwrap(); + assert_eq!( + le2.tlv().unwrap(), + TLV { + tag: TLVTag::Context(2), + value: TLVValue::U8(6) + } + ); + + let le3 = list_iter.next().unwrap().unwrap(); + assert_eq!( + le3.tlv().unwrap(), + TLV { + tag: TLVTag::Context(3), + value: TLVValue::U8(1) + } + ); + + assert!(list_iter.next().is_none()); + + let struct2 = TLVStruct::::new(struct1_iter.next().unwrap().unwrap()).unwrap(); + assert_eq!( + struct2.element().tlv().unwrap(), + TLV { + tag: TLVTag::Context(1), + value: TLVValue::Struct, + } + ); + assert_eq!(struct2.iter().count(), 0); + } + + #[test] + fn test_matter_spec_examples() { + let tlv = |slice| TLVElement::new(slice).tlv().unwrap(); + + // Boolean false + + assert_eq!( + tlv(&[0x08]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::False, + } + ); + + // Boolean true + + assert_eq!( + tlv(&[0x09]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::True, + } + ); + + // Signed Integer, 1-octet, value 42 + + assert_eq!( + tlv(&[0x00, 0x2a]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S8(42), + } + ); + + // Signed Integer, 1-octet, value -17 + + assert_eq!( + tlv(&[0x00, 0xef]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S8(-17), + } + ); + + // Unsigned Integer, 1-octet, value 42U + + assert_eq!( + tlv(&[0x04, 0x2a]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::U8(42), + } + ); + + // Signed Integer, 2-octet, value 42 + + assert_eq!( + tlv(&[0x01, 0x2a, 0x00]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S16(42), + } + ); + + // Signed Integer, 4-octet, value -170000 + + assert_eq!( + tlv(&[0x02, 0xf0, 0x67, 0xfd, 0xff]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S32(-170000), + } + ); + + // Signed Integer, 8-octet, value 40000000000 + + assert_eq!( + tlv(&[0x03, 0x00, 0x90, 0x2f, 0x50, 0x09, 0x00, 0x00, 0x00]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S64(40000000000), + } + ); + + // UTF-8 String, 1-octet length, "Hello!" + + assert_eq!( + tlv(&[0x0c, 0x06, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x21]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Utf8l("Hello!"), + } + ); + + // UTF-8 String, 1-octet length, "Tschüs" + + assert_eq!( + tlv(&[0x0c, 0x07, 0x54, 0x73, 0x63, 0x68, 0xc3, 0xbc, 0x73]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Utf8l("Tschüs"), + } + ); + + // Octet String, 1-octet length, octets 00 01 02 03 04 + + assert_eq!( + tlv(&[0x10, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Str8l(&[0x00, 0x01, 0x02, 0x03, 0x04]), + } + ); + + // Null + + assert_eq!( + tlv(&[0x14]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Null, + } + ); + + // Single precision floating point 0.0 + + assert_eq!( + tlv(&[0x0a, 0x00, 0x00, 0x00, 0x00]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F32(0.0), + } + ); + + // Single precision floating point (1.0 / 3.0) + + assert_eq!( + tlv(&[0x0a, 0xab, 0xaa, 0xaa, 0x3e]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F32(1.0 / 3.0), + } + ); + + // Single precision floating point 17.9 + + assert_eq!( + tlv(&[0x0a, 0x33, 0x33, 0x8f, 0x41]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F32(17.9), + } + ); + + // Single precision floating point infinity + + assert_eq!( + tlv(&[0x0a, 0x00, 0x00, 0x80, 0x7f]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F32(f32::INFINITY), + } + ); + + // Single precision floating point negative infinity + + assert_eq!( + tlv(&[0x0a, 0x00, 0x00, 0x80, 0xff]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F32(f32::NEG_INFINITY), + } + ); + + // Double precision floating point 0.0 + + assert_eq!( + tlv(&[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F64(0.0), + } + ); + + // Double precision floating point (1.0 / 3.0) + + assert_eq!( + tlv(&[0x0b, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xd5, 0x3f]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F64(1.0 / 3.0), + } + ); + + // Double precision floating point 17.9 + + assert_eq!( + tlv(&[0x0b, 0x66, 0x66, 0x66, 0x66, 0x66, 0xe6, 0x31, 0x40]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F64(17.9), + } + ); + + // Double precision floating point infinity (∞) + + assert_eq!( + tlv(&[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x7f]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F64(f64::INFINITY), + } + ); + + // Double precision floating point negative infinity + + assert_eq!( + tlv(&[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xff]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F64(f64::NEG_INFINITY), + } + ); + + // Empty Structure, {} + + assert_eq!( + tlv(&[0x15, 0x18]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Struct, + } + ); + + assert!(TLVElement::new(&[0x15, 0x18]) + .structure() + .unwrap() + .iter() + .next() + .is_none()); + + // Empty Array, [] + + assert_eq!( + tlv(&[0x16, 0x18]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Array, + } + ); + + assert!(TLVElement::new(&[0x16, 0x18]) + .array() + .unwrap() + .iter() + .next() + .is_none()); + + // Empty List, [] + + assert_eq!( + tlv(&[0x17, 0x18]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::List, + } + ); + + assert!(TLVElement::new(&[0x17, 0x18]) + .list() + .unwrap() + .iter() + .next() + .is_none()); + + // Structure, two context specific tags, Signed Intger, 1 octet values, {0 = 42, 1 = -17} + + let data = &[0x15, 0x20, 0x00, 0x2a, 0x20, 0x01, 0xef, 0x18]; + + assert_eq!( + tlv(data), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Struct, + } + ); + + let mut iter = TLVElement::new(data).structure().unwrap().iter(); + + let s1 = iter.next().unwrap().unwrap(); + assert_eq!(s1.tag().unwrap(), TLVTag::Context(0)); + assert_eq!(s1.i32().unwrap(), 42); + + let s2 = iter.next().unwrap().unwrap(); + assert_eq!(s2.tag().unwrap(), TLVTag::Context(1)); + assert_eq!(s2.i16().unwrap(), -17); + + assert!(iter.next().is_none()); + + // Array, Signed Integer, 1-octet values, [0, 1, 2, 3, 4] + + let data = &[ + 0x16, 0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04, 0x18, + ]; + + assert_eq!( + tlv(data), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Array, + } + ); + + let iter = TLVElement::new(data).array().unwrap().iter().enumerate(); + + for (index, elem) in iter { + let elem = elem.unwrap(); + + assert_eq!(elem.tag().unwrap(), TLVTag::Anonymous); + assert_eq!(elem.i8().unwrap(), index as i8); + } + + // List, mix of anonymous and context tags, Signed Integer, 1 octet values, [[1, 0 = 42, 2, 3, 0 = -17]] + + let data = &[ + 0x17, 0x00, 0x01, 0x20, 0x00, 0x2a, 0x00, 0x02, 0x00, 0x03, 0x20, 0x00, 0xef, 0x18, + ]; + + assert_eq!( + tlv(data), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::List, + } + ); + + let expected = &[ + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S8(1), + }, + TLV { + tag: TLVTag::Context(0), + value: TLVValue::S8(42), + }, + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S8(2), + }, + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S8(3), + }, + TLV { + tag: TLVTag::Context(0), + value: TLVValue::S8(-17), + }, + ]; + + let mut iter = TLVElement::new(data).list().unwrap().iter(); + + for elem in expected { + assert_eq!(iter.next().unwrap().unwrap().tlv().unwrap(), *elem); + } + + assert!(iter.next().is_none()); + + // Array, mix of element types, [42, -170000, {}, 17.9, "Hello!"] + + let data = &[ + 0x16, 0x00, 0x2a, 0x02, 0xf0, 0x67, 0xfd, 0xff, 0x15, 0x18, 0x0a, 0x33, 0x33, 0x8f, + 0x41, 0x0c, 0x06, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x18, + ]; + + assert_eq!( + tlv(data), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Array, + } + ); + + let mut iter = TLVElement::new(data).array().unwrap().iter(); + + assert_eq!( + iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S8(42), + } + ); + + assert_eq!( + iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::S32(-170000), + } + ); + + assert_eq!( + iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Struct, + } + ); + + assert_eq!( + iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::F32(17.9), + } + ); + + assert_eq!( + iter.next().unwrap().unwrap().tlv().unwrap(), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::Utf8l("Hello!"), + } + ); + + // Anonymous tag, Unsigned Integer, 1-octet value, 42U + + assert_eq!( + tlv(&[0x04, 0x2a]), + TLV { + tag: TLVTag::Anonymous, + value: TLVValue::U8(42), + } + ); + + // Context tag 1, Unsigned Integer, 1-octet value, 1 = 42U + + assert_eq!( + tlv(&[0x24, 0x01, 0x2a]), + TLV { + tag: TLVTag::Context(1), + value: TLVValue::U8(42), + } + ); + + // Common profile tag 1, Unsigned Integer, 1-octet value, Matter::1 = 42U + + assert_eq!( + tlv(&[0x44, 0x01, 0x00, 0x2a]), + TLV { + tag: TLVTag::CommonPrf16(1), + value: TLVValue::U8(42), + } + ); + + // Common profile tag 100000, Unsigned Integer, 1-octet value, Matter::100000 = 42U + + assert_eq!( + tlv(&[0x64, 0xa0, 0x86, 0x01, 0x00, 0x2a]), + TLV { + tag: TLVTag::CommonPrf32(100000), + value: TLVValue::U8(42), + } + ); + + // Fully qualified tag, Vendor ID 0xFFF1/65521, pro­file number 0xDEED/57069, + // 2-octet tag 1, Unsigned Integer, 1-octet value 42, 65521::57069:1 = 42U + + assert_eq!( + tlv(&[0xc4, 0xf1, 0xff, 0xed, 0xde, 0x01, 0x00, 0x2a]), + TLV { + tag: TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 1, + }, + value: TLVValue::U8(42), + } + ); + + // Fully qualified tag, Vendor ID 0xFFF1/65521, pro­file number 0xDEED/57069, + // 4-octet tag 0xAA55FEED/2857762541, Unsigned Integer, 1-octet value 42, 65521::57069:2857762541 = 42U + + assert_eq!( + tlv(&[0xe4, 0xf1, 0xff, 0xed, 0xde, 0xed, 0xfe, 0x55, 0xaa, 0x2a]), + TLV { + tag: TLVTag::FullQual64 { + vendor_id: 65521, + profile: 57069, + tag: 2857762541, + }, + value: TLVValue::U8(42), + } + ); + + // Structure with the fully qualified tag, Vendor ID 0xFFF1/65521, profile number 0xDEED/57069, + // 2-octet tag 1. The structure contains a single ele­ment labeled using a fully qualified tag under + // the same profile, with 2-octet tag 0xAA55/43605. 65521::57069:1 = {65521::57069:43605 = 42U} + + let data = &[ + 0xd5, 0xf1, 0xff, 0xed, 0xde, 0x01, 0x00, 0xc4, 0xf1, 0xff, 0xed, 0xde, 0x55, 0xaa, + 0x2a, 0x18, + ]; + + assert_eq!( + tlv(data), + TLV { + tag: TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 1, + }, + value: TLVValue::Struct, + } + ); + + let mut iter = TLVElement::new(data).structure().unwrap().iter(); + + let u1 = iter.next().unwrap().unwrap(); + + assert_eq!( + u1.tag().unwrap(), + TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 43605, + } + ); + + assert_eq!(u1.u8().unwrap(), 42); + + assert!(iter.next().is_none()); + } +} diff --git a/rs-matter/src/tlv/toiter.rs b/rs-matter/src/tlv/toiter.rs new file mode 100644 index 00000000..3134a3c4 --- /dev/null +++ b/rs-matter/src/tlv/toiter.rs @@ -0,0 +1,750 @@ +use core::iter::{Chain, Once}; + +use crate::error::{Error, ErrorCode}; + +use super::{OnceTLVIter, TLVTag, TLVValue, TLVValueType, TLV}; + +type TLVResult<'a> = Result, Error>; +type ChainedTLVIter<'a, C> = Chain>; + +/// A decorator trait for serializing data as TLV in the form of an +/// `Iterator` of `Result, Error>` bytes. +/// +/// The trait provides additional combinators on top of the standard `Iterator` +/// trait combinators (e.g. `map`, `filter`, `flat_map`, etc.) that allow for serializing TLV elements. +/// +/// The trait is already implemented for any `Iterator` that yields items of type `Result, Error>`, +/// so users are not expected to provide implementations of it. +/// +/// Using an Iterator approach to TLV serialization is useful when the data is not serialized to its +/// final location (be it in the storage or in an outgoing network packet) - but rather - is serialized +/// so that it is afterwards consumed as a stream of bytes by another component - say - a hash signature +/// algorithm that operates on the TLV representation of the data. +/// +/// This way, the need for an interim buffer for the serialized TLV data might be avoided. +/// +/// NOTE: +/// Keep in mind that the resulting iterator might quickly become rather large if the serialized +/// TLV data contains many small TLV elements, as each TLV element is represented as multiple compositions +/// of the Rust `Iterator` combinators (e.g. `chain`, `map`, `flat_map`, etc.), and - moreover - +/// the size of each `TLV` itself is rather large (~ 32 bytes on 32bit archs). +/// +/// Therefore, the iterator TLV serialization is only useful when the serialized TLV data contains few but +/// large non-container TLV elements, like octet strings or utf8 strings +/// (which is typical for e.g. TLV-encoded certificates). +/// +/// For other cases, allocating a temporary memory buffer and serializing into it with `TLVWrite` might result +/// in less memory overhead (and better performance when reading the raw serialized TLV data) by the code that +/// operates on it. +pub trait TLVIter<'a>: Iterator> + Sized { + fn flatten(value: Result) -> EitherIter>> { + match value { + Ok(value) => EitherIter::First(value), + Err(err) => EitherIter::Second(core::iter::once(Err(err))), + } + } + + /// Serialize a TLV tag and value. + fn tlv(self, tag: TLVTag, value: TLVValue<'a>) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::new(tag, value).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as an S8 TLV value. + fn i8(self, tag: TLVTag, data: i8) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::i8(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as a U8 TLV value. + fn u8(self, tag: TLVTag, data: u8) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::u8(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as an S16 TLV value, + /// or as an S8 TLV value if the provided data can fit in the S8 domain range. + fn i16(self, tag: TLVTag, data: i16) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::i16(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as a U16 TLV value, + /// or as a U8 TLV value if the provided data can fit in the U8 domain range. + fn u16(self, tag: TLVTag, data: u16) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::u16(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as an S32 TLV value, + /// or as an S16 / S8 TLV value if the provided data can fit in a smaller domain range. + fn i32(self, tag: TLVTag, data: i32) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::i32(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as a U32 TLV value, + /// or as a U16 / U8 TLV value if the provided data can fit in a smaller domain range. + fn u32(self, tag: TLVTag, data: u32) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::u32(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as an S64 TLV value, + /// or as an S32 / S16 / S8 TLV value if the provided data can fit in a smaller domain range. + fn i64(self, tag: TLVTag, data: i64) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::i64(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as a U64 TLV value, + /// or as a U32 / U16 / U8 TLV value if the provided data can fit in a smaller domain range. + fn u64(self, tag: TLVTag, data: u64) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::u64(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as an F32 TLV value. + fn f32(self, tag: TLVTag, data: f32) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::f32(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as an F64 TLV value. + fn f64(self, tag: TLVTag, data: f64) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::f64(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as a TLV Octet String. + /// + /// The exact octet string type (Str8l, Str16l, Str32l, or Str64l) is chosen based on the length of the data, + /// whereas the smallest type filling the provided data length is chosen. + fn str(self, tag: TLVTag, data: &'a [u8]) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::str(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and the provided value as a TLV UTF-8 String. + /// + /// The exact UTF-8 string type (Utf8l, Utf16l, Utf32l, or Utf64l) is chosen based on the length of the data, + /// whereas the smallest type filling the provided data length is chosen. + fn utf8(self, tag: TLVTag, data: &'a str) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::utf8(tag, data).into_tlv_iter()) + } + + /// Serialize the given tag and a value indicating the start of a Struct TLV container. + /// + /// NOTE: The user must call `end_container` after serializing all the Struct fields + /// to close the Struct container or else the generated TLV stream will be invalid. + fn start_struct(self, tag: TLVTag) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::structure(tag).into_tlv_iter()) + } + + /// Serialize the given tag and a value indicating the start of an Array TLV container. + /// + /// NOTE: The user must call `end_container` after serializing all the Array elements + /// to close the Array container or else the generated TLV stream will be invalid. + fn start_array(self, tag: TLVTag) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::array(tag).into_tlv_iter()) + } + + /// Serialize the given tag and a value indicating the start of a List TLV container. + /// + /// NOTE: The user must call `end_container` after serializing all the List elements + /// to close the List container or else the generated TLV stream will be invalid. + fn start_list(self, tag: TLVTag) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::list(tag).into_tlv_iter()) + } + + /// Serialize the given tag and a value indicating the start of a TLV container. + /// + /// NOTE: The user must call `end_container` after serializing all the container fields + /// to close the Struct container or else the generated TLV stream will be invalid. + fn start_container(self, tag: TLVTag, container_type: TLVValueType) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + match container_type { + TLVValueType::Struct => self.start_struct(tag), + TLVValueType::Array => self.start_array(tag), + TLVValueType::List => self.start_list(tag), + _ => self.chain(core::iter::once(Err(ErrorCode::TLVTypeMismatch.into()))), + } + } + + /// Serialize a value indicating the end of a Struct, Array, or List TLV container. + /// + /// NOTE: This method must be called only when the corresponding container has been opened + /// using `start_struct`, `start_array`, or `start_list`, or else the generated TLV stream will be invalid. + fn end_container(self) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::end_container().into_tlv_iter()) + } + + /// Serialize the given tag and a value indicating a Null TLV value. + fn null(self, tag: TLVTag) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::null(tag).into_tlv_iter()) + } + + /// Serialize the given tag and a value indicating a True or False TLV value. + fn bool(self, tag: TLVTag, data: bool) -> ChainedTLVIter<'a, Self> + where + Self: 'a, + { + self.chain(TLV::bool(tag, data).into_tlv_iter()) + } +} + +impl<'a, T> TLVIter<'a> for T where T: Iterator> {} + +/// A decorator enum type wrapping two iterators and implementing +/// the `Iterator` trait. +/// +/// Useful when the "to-tlv-iter" implementation needs to return +/// one of two iterators based on some condition. +pub enum EitherIter { + First(F), + Second(S), +} + +impl Iterator for EitherIter +where + F: Iterator, + S: Iterator, +{ + type Item = ::Item; + + fn next(&mut self) -> Option { + match self { + Self::First(i) => i.next(), + Self::Second(i) => i.next(), + } + } +} + +/// A decorator enum type wrapping three iterators and implementing +/// the `Iterator` trait. +/// +/// Useful when the "to-tlv-iter" implementation needs to return +/// one of three iterators based on some condition. +pub enum Either3Iter { + First(F), + Second(S), + Third(T), +} + +impl Iterator for Either3Iter +where + F: Iterator, + S: Iterator, + T: Iterator, +{ + type Item = ::Item; + + fn next(&mut self) -> Option { + match self { + Self::First(i) => i.next(), + Self::Second(i) => i.next(), + Self::Third(i) => i.next(), + } + } +} + +/// A decorator enum type wrapping four iterators and implementing +/// the `Iterator` trait. +/// +/// Useful when the "to-tlv-iter" implementation needs to return +/// one of four iterators based on some condition. +pub enum Either4Iter { + First(F), + Second(S), + Third(T), + Fourth(U), +} + +impl Iterator for Either4Iter +where + F: Iterator, + S: Iterator, + T: Iterator, + U: Iterator, +{ + type Item = ::Item; + + fn next(&mut self) -> Option { + match self { + Self::First(i) => i.next(), + Self::Second(i) => i.next(), + Self::Third(i) => i.next(), + Self::Fourth(i) => i.next(), + } + } +} + +/// A decorator enum type wrapping five iterators and implementing +/// the `Iterator` trait. +/// +/// Useful when the "to-tlv-iter" implementation needs to return +/// one of five iterators based on some condition. +pub enum Either5Iter { + First(F), + Second(S), + Third(T), + Fourth(U), + Fifth(I), +} + +impl Iterator for Either5Iter +where + F: Iterator, + S: Iterator, + T: Iterator, + U: Iterator, + I: Iterator, +{ + type Item = ::Item; + + fn next(&mut self) -> Option { + match self { + Self::First(i) => i.next(), + Self::Second(i) => i.next(), + Self::Third(i) => i.next(), + Self::Fourth(i) => i.next(), + Self::Fifth(i) => i.next(), + } + } +} + +/// A decorator enum type wrapping six iterators and implementing +/// the `Iterator` trait. +/// +/// Useful when the "to-tlv-iter" implementation needs to return +/// one of six iterators based on some condition. +pub enum Either6Iter { + First(F), + Second(S), + Third(T), + Fourth(U), + Fifth(I), + Sixth(X), +} + +impl Iterator for Either6Iter +where + F: Iterator, + S: Iterator, + T: Iterator, + U: Iterator, + I: Iterator, + X: Iterator, +{ + type Item = ::Item; + + fn next(&mut self) -> Option { + match self { + Self::First(i) => i.next(), + Self::Second(i) => i.next(), + Self::Third(i) => i.next(), + Self::Fourth(i) => i.next(), + Self::Fifth(i) => i.next(), + Self::Sixth(i) => i.next(), + } + } +} + +#[cfg(test)] +mod tests { + use core::{f32, iter::empty}; + + use crate::tlv::TLV; + + use super::{TLVIter, TLVResult, TLVTag}; + + fn expect<'a, I>(iter: I, expected: &[u8]) + where + I: Iterator>, + { + let mut iter = iter.map(|r| r.unwrap()).flat_map(TLV::into_bytes_iter); + let mut expected = expected.iter().copied(); + + loop { + match (iter.next(), expected.next()) { + (Some(a), Some(b)) => assert_eq!(a, b), + (None, None) => break, + (Some(_), None) => panic!("Iterator has more bytes than expected"), + (None, Some(_)) => panic!("Iterator has fewer bytes than expected"), + } + } + } + + #[test] + fn test_write_success() { + expect( + empty() + .start_struct(TLVTag::Anonymous) + .u8(TLVTag::Anonymous, 12) + .u8(TLVTag::Context(1), 13) + .u16(TLVTag::Anonymous, 0x1212) + .u16(TLVTag::Context(2), 0x1313) + .start_array(TLVTag::Context(3)) + .bool(TLVTag::Anonymous, true) + .end_container() + .end_container(), + &[ + 21, 4, 12, 36, 1, 13, 5, 0x12, 0x012, 37, 2, 0x13, 0x13, 54, 3, 9, 24, 24, + ], + ); + } + + #[test] + fn test_put_str8() { + expect( + empty() + .u8(TLVTag::Context(1), 13) + .str(TLVTag::Anonymous, &[10, 11, 12, 13, 14]) + .u16(TLVTag::Context(2), 0x1313) + .str(TLVTag::Context(3), &[20, 21, 22]), + &[ + 36, 1, 13, 16, 5, 10, 11, 12, 13, 14, 37, 2, 0x13, 0x13, 48, 3, 3, 20, 21, 22, + ], + ); + } + + #[test] + fn test_matter_spec_examples() { + // Boolean false + + expect(empty().bool(TLVTag::Anonymous, false), &[0x08]); + + // Boolean true + + expect(empty().bool(TLVTag::Anonymous, true), &[0x09]); + + // Signed Integer, 1-octet, value 42 + + expect(empty().i8(TLVTag::Anonymous, 42), &[0x00, 0x2a]); + + // Signed Integer, 1-octet, value -17 + + expect(empty().i8(TLVTag::Anonymous, -17), &[0x00, 0xef]); + + // Unsigned Integer, 1-octet, value 42U + + expect(empty().u8(TLVTag::Anonymous, 42), &[0x04, 0x2a]); + + // Signed Integer, 2-octet, value 422 + + expect(empty().i16(TLVTag::Anonymous, 422), &[0x01, 0xa6, 0x01]); + + // Signed Integer, 4-octet, value -170000 + + expect( + empty().i64(TLVTag::Anonymous, -170000), + &[0x02, 0xf0, 0x67, 0xfd, 0xff], + ); + + // Signed Integer, 8-octet, value 40000000000 + + expect( + empty().i64(TLVTag::Anonymous, 40000000000), + &[0x03, 0x00, 0x90, 0x2f, 0x50, 0x09, 0x00, 0x00, 0x00], + ); + + // UTF-8 String, 1-octet length, "Hello!" + + expect( + empty().utf8(TLVTag::Anonymous, "Hello!"), + &[0x0c, 0x06, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x21], + ); + + // UTF-8 String, 1-octet length, "Tschüs" + + expect( + empty().utf8(TLVTag::Anonymous, "Tschüs"), + &[0x0c, 0x07, 0x54, 0x73, 0x63, 0x68, 0xc3, 0xbc, 0x73], + ); + + // Octet String, 1-octet length, octets 00 01 02 03 04 + + expect( + empty().str(TLVTag::Anonymous, &[0x00, 0x01, 0x02, 0x03, 0x04]), + &[0x10, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04], + ); + + // Null + + expect(empty().null(TLVTag::Anonymous), &[0x14]); + + // Single precision floating point 0.0 + + expect( + empty().f32(TLVTag::Anonymous, 0.0), + &[0x0a, 0x00, 0x00, 0x00, 0x00], + ); + + // Single precision floating point (1.0 / 3.0) + + expect( + empty().f32(TLVTag::Anonymous, 1.0 / 3.0), + &[0x0a, 0xab, 0xaa, 0xaa, 0x3e], + ); + + // Single precision floating point 17.9 + + expect( + empty().f32(TLVTag::Anonymous, 17.9), + &[0x0a, 0x33, 0x33, 0x8f, 0x41], + ); + + // Single precision floating point infinity + + expect( + empty().f32(TLVTag::Anonymous, f32::INFINITY), + &[0x0a, 0x00, 0x00, 0x80, 0x7f], + ); + + // Single precision floating point negative infinity + + expect( + empty().f32(TLVTag::Anonymous, f32::NEG_INFINITY), + &[0x0a, 0x00, 0x00, 0x80, 0xff], + ); + + // Double precision floating point 0.0 + + expect( + empty().f64(TLVTag::Anonymous, 0.0), + &[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], + ); + + // Double precision floating point (1.0 / 3.0) + + expect( + empty().f64(TLVTag::Anonymous, 1.0 / 3.0), + &[0x0b, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xd5, 0x3f], + ); + + // Double precision floating point 17.9 + + expect( + empty().f64(TLVTag::Anonymous, 17.9), + &[0x0b, 0x66, 0x66, 0x66, 0x66, 0x66, 0xe6, 0x31, 0x40], + ); + + // Double precision floating point infinity (∞) + + expect( + empty().f64(TLVTag::Anonymous, f64::INFINITY), + &[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x7f], + ); + + // Double precision floating point negative infinity + + expect( + empty().f64(TLVTag::Anonymous, f64::NEG_INFINITY), + &[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xff], + ); + + // Empty Structure, {} + + expect( + empty().start_struct(TLVTag::Anonymous).end_container(), + &[0x15, 0x18], + ); + + // Empty Array, [] + + expect( + empty().start_array(TLVTag::Anonymous).end_container(), + &[0x16, 0x18], + ); + + // Empty List, [] + + expect( + empty().start_list(TLVTag::Anonymous).end_container(), + &[0x17, 0x18], + ); + + // Structure, two context specific tags, Signed Integer, 1 octet values, {0 = 42, 1 = -17} + + expect( + empty() + .start_struct(TLVTag::Anonymous) + .i8(TLVTag::Context(0), 42) + .i32(TLVTag::Context(1), -17) + .end_container(), + &[0x15, 0x20, 0x00, 0x2a, 0x20, 0x01, 0xef, 0x18], + ); + + // Array, Signed Integer, 1-octet values, [0, 1, 2, 3, 4] + + expect( + empty() + .start_array(TLVTag::Anonymous) + .i8(TLVTag::Anonymous, 0) + .i8(TLVTag::Anonymous, 1) + .i8(TLVTag::Anonymous, 2) + .i8(TLVTag::Anonymous, 3) + .i8(TLVTag::Anonymous, 4) + .end_container(), + &[ + 0x16, 0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04, 0x18, + ], + ); + + // List, mix of anonymous and context tags, Signed Integer, 1 octet values, [[1, 0 = 42, 2, 3, 0 = -17]] + + expect( + empty() + .start_list(TLVTag::Anonymous) + .i64(TLVTag::Anonymous, 1) + .i16(TLVTag::Context(0), 42) + .i8(TLVTag::Anonymous, 2) + .i8(TLVTag::Anonymous, 3) + .i32(TLVTag::Context(0), -17) + .end_container(), + &[ + 0x17, 0x00, 0x01, 0x20, 0x00, 0x2a, 0x00, 0x02, 0x00, 0x03, 0x20, 0x00, 0xef, 0x18, + ], + ); + + // Array, mix of element types, [42, -170000, {}, 17.9, "Hello!"] + + expect( + empty() + .start_array(TLVTag::Anonymous) + .i64(TLVTag::Anonymous, 42) + .i64(TLVTag::Anonymous, -170000) + .start_struct(TLVTag::Anonymous) + .end_container() + .f32(TLVTag::Anonymous, 17.9) + .utf8(TLVTag::Anonymous, "Hello!") + .end_container(), + &[ + 0x16, 0x00, 0x2a, 0x02, 0xf0, 0x67, 0xfd, 0xff, 0x15, 0x18, 0x0a, 0x33, 0x33, 0x8f, + 0x41, 0x0c, 0x06, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x18, + ], + ); + + // Anonymous tag, Unsigned Integer, 1-octet value, 42U + + expect(empty().u64(TLVTag::Anonymous, 42), &[0x04, 0x2a]); + + // Context tag 1, Unsigned Integer, 1-octet value, 1 = 42U + + expect(empty().u16(TLVTag::Context(1), 42), &[0x24, 0x01, 0x2a]); + + // Common profile tag 1, Unsigned Integer, 1-octet value, Matter::1 = 42U + + expect( + empty().u16(TLVTag::CommonPrf16(1), 42), + &[0x44, 0x01, 0x00, 0x2a], + ); + + // Common profile tag 100000, Unsigned Integer, 1-octet value, Matter::100000 = 42U + + expect( + empty().u16(TLVTag::CommonPrf32(100000), 42), + &[0x64, 0xa0, 0x86, 0x01, 0x00, 0x2a], + ); + + // Fully qualified tag, Vendor ID 0xFFF1/65521, pro­file number 0xDEED/57069, + // 2-octet tag 1, Unsigned Integer, 1-octet value 42, 65521::57069:1 = 42U + + expect( + empty().u16( + TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 1, + }, + 42, + ), + &[0xc4, 0xf1, 0xff, 0xed, 0xde, 0x01, 0x00, 0x2a], + ); + + // Fully qualified tag, Vendor ID 0xFFF1/65521, pro­file number 0xDEED/57069, + // 4-octet tag 0xAA55FEED/2857762541, Unsigned Integer, 1-octet value 42, 65521::57069:2857762541 = 42U + + expect( + empty().u16( + TLVTag::FullQual64 { + vendor_id: 65521, + profile: 57069, + tag: 2857762541, + }, + 42, + ), + &[0xe4, 0xf1, 0xff, 0xed, 0xde, 0xed, 0xfe, 0x55, 0xaa, 0x2a], + ); + + // Structure with the fully qualified tag, Vendor ID 0xFFF1/65521, profile number 0xDEED/57069, + // 2-octet tag 1. The structure contains a single ele­ment labeled using a fully qualified tag under + // the same profile, with 2-octet tag 0xAA55/43605. 65521::57069:1 = {65521::57069:43605 = 42U} + + expect( + empty() + .start_struct(TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 1, + }) + .u64( + TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 43605, + }, + 42, + ) + .end_container(), + &[ + 0xd5, 0xf1, 0xff, 0xed, 0xde, 0x01, 0x00, 0xc4, 0xf1, 0xff, 0xed, 0xde, 0x55, 0xaa, + 0x2a, 0x18, + ], + ); + } +} diff --git a/rs-matter/src/tlv/traits.rs b/rs-matter/src/tlv/traits.rs index 2839fb06..36bdb4d6 100644 --- a/rs-matter/src/tlv/traits.rs +++ b/rs-matter/src/tlv/traits.rs @@ -15,587 +15,222 @@ * limitations under the License. */ -use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; -use crate::error::{Error, ErrorCode}; -use core::fmt::Debug; -use core::slice::Iter; -use log::error; - -pub trait FromTLV<'a> { - fn from_tlv(t: &TLVElement<'a>) -> Result - where - Self: Sized; - - // I don't think anybody except Option will define this - fn tlv_not_found() -> Result - where - Self: Sized, - { - Err(ErrorCode::TLVNotFound.into()) - } -} - -impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { - fn from_tlv(t: &TLVElement<'a>) -> Result - where - Self: Sized, - { - t.confirm_array()?; - - let mut a = heapless::Vec::::new(); - if let Some(tlv_iter) = t.enter() { - for element in tlv_iter { - a.push(T::from_tlv(&element)?) - .map_err(|_| ErrorCode::NoSpace)?; - } - } - - // TODO: This was the old behavior before rebasing the - // implementation on top of heapless::Vec (to avoid requiring Copy) - // Not sure why we actually need that yet, but without it unit tests fail - while a.len() < N { - a.push(Default::default()).map_err(|_| ErrorCode::NoSpace)?; - } - - a.into_array().map_err(|_| ErrorCode::Invalid.into()) - } -} - -pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>( - vec: &mut heapless::Vec, - t: &TLVElement<'a>, -) -> Result<(), Error> { - vec.clear(); - - t.confirm_array()?; - - if let Some(tlv_iter) = t.enter() { - for element in tlv_iter { - vec.push(T::from_tlv(&element)?) - .map_err(|_| ErrorCode::NoSpace)?; - } - } - - Ok(()) -} - -pub fn vec_from_tlv<'a, T: FromTLV<'a>, const N: usize>( - vec: &mut crate::utils::storage::Vec, - t: &TLVElement<'a>, -) -> Result<(), Error> { - vec.clear(); - - t.confirm_array()?; - - if let Some(tlv_iter) = t.enter() { - for element in tlv_iter { - vec.push(T::from_tlv(&element)?) - .map_err(|_| ErrorCode::NoSpace)?; +use crate::error::Error; +use crate::utils::init; + +use super::{ + EitherIter, TLVElement, TLVSequenceIter, TLVSequenceTLVIter, TLVTag, TLVValue, TLVValueType, + TLVWrite, TLV, +}; + +pub use container::*; +pub use maybe::*; +pub use octets::*; +pub use slice::*; +pub use str::*; + +mod array; +mod bitflags; +mod container; +mod maybe; +mod octets; +mod primitive; +mod slice; +mod str; +mod vec; + +/// A trait representing Rust types that can deserialize themselves from +/// a TLV-encoded byte slice. +pub trait FromTLV<'a>: Sized + 'a { + /// Deserialize the type from a TLV-encoded element. + fn from_tlv(element: &TLVElement<'a>) -> Result; + + /// Generate an in-place initializer for the type that initializes + /// the type from a TLV-encoded element. + fn init_from_tlv(element: TLVElement<'a>) -> impl init::Init { + unsafe { + init::init_from_closure(move |slot| { + core::ptr::write(slot, Self::from_tlv(&element)?); + + Ok(()) + }) } } - - Ok(()) -} - -macro_rules! fromtlv_for { - ($($t:ident)*) => { - $( - impl<'a> FromTLV<'a> for $t { - fn from_tlv(t: &TLVElement) -> Result { - t.$t() - } - } - )* - }; -} - -macro_rules! fromtlv_for_nonzero { - ($($t:ident:$n:ty)*) => { - $( - impl<'a> FromTLV<'a> for $n { - fn from_tlv(t: &TLVElement) -> Result { - <$n>::new(t.$t()?).ok_or_else(|| ErrorCode::Invalid.into()) - } - } - )* - }; } -fromtlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); -fromtlv_for_nonzero!(i8:core::num::NonZeroI8 u8:core::num::NonZeroU8 i16:core::num::NonZeroI16 u16:core::num::NonZeroU16 i32:core::num::NonZeroI32 u32:core::num::NonZeroU32 i64:core::num::NonZeroI64 u64:core::num::NonZeroU64); - +/// A trait representing Rust types that can serialize themselves to +/// a TLV-encoded stream. pub trait ToTLV { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>; + /// Serialize the type to a TLV-encoded stream. + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error>; + + /// Serialize the type as an iterator of `TLV` instances by potentially borrowing + /// data from the type. + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator>; } impl ToTLV for &T where T: ToTLV, { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - (**self).to_tlv(tw, tag) - } -} - -macro_rules! totlv_for { - ($($t:ident)*) => { - $( - impl ToTLV for $t { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.$t(tag, *self) - } - } - )* - }; -} - -macro_rules! totlv_for_nonzero { - ($($t:ident:$n:ty)*) => { - $( - impl ToTLV for $n { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.$t(tag, self.get()) - } - } - )* - }; -} - -impl ToTLV for [T; N] { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.start_array(tag)?; - for i in self { - i.to_tlv(tw, TagType::Anonymous)?; - } - tw.end_container() - } -} - -impl<'a, T: ToTLV> ToTLV for &'a [T] { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.start_array(tag)?; - for i in *self { - i.to_tlv(tw, TagType::Anonymous)?; - } - tw.end_container() - } -} - -// Generate ToTLV for standard data types -totlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); -totlv_for_nonzero!(i8:core::num::NonZeroI8 u8:core::num::NonZeroU8 i16:core::num::NonZeroI16 u16:core::num::NonZeroU16 i32:core::num::NonZeroI32 u32:core::num::NonZeroU32 i64:core::num::NonZeroI64 u64:core::num::NonZeroU64); - -// We define a few common data types that will be required here -// -// - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec -// - These only have references into the original list -// - heapless::String, Vheapless::ec: Is the owned version of utfstr and ostr, data is cloned into this -// - heapless::String is only partially implemented -// -// - TLVArray: Is an array of entries, with reference within the original list - -/// Implements UTFString from the spec -#[derive(Debug, Copy, Clone, PartialEq, Default, Hash, Eq)] -pub struct UtfStr<'a>(pub &'a [u8]); - -impl<'a> UtfStr<'a> { - pub const fn new(str: &'a [u8]) -> Self { - Self(str) - } - - pub fn as_str(&self) -> Result<&str, Error> { - core::str::from_utf8(self.0).map_err(|_| ErrorCode::Invalid.into()) - } -} - -impl<'a> ToTLV for UtfStr<'a> { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.utf16(tag, self.0) - } -} - -impl<'a> FromTLV<'a> for UtfStr<'a> { - fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { - t.slice().map(UtfStr) - } -} - -/// Implements OctetString from the spec -#[derive(Debug, Copy, Clone, PartialEq, Default, Hash, Eq)] -pub struct OctetStr<'a>(pub &'a [u8]); - -impl<'a> OctetStr<'a> { - pub fn new(str: &'a [u8]) -> Self { - Self(str) + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + (*self).to_tlv(tag, tw) } -} -impl<'a> FromTLV<'a> for OctetStr<'a> { - fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { - t.slice().map(OctetStr) + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + (*self).tlv_iter(tag) } } -impl<'a> ToTLV for OctetStr<'a> { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.str16(tag, self.0) +impl<'a> FromTLV<'a> for TLVElement<'a> { + fn from_tlv(element: &TLVElement<'a>) -> Result { + Ok(element.clone()) } } -/// Implements the Owned version of Octet String -impl FromTLV<'_> for heapless::Vec { - fn from_tlv(t: &TLVElement) -> Result, Error> { - heapless::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into()) +impl<'a> ToTLV for TLVElement<'a> { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + if self.is_empty() { + // Special-case serializing empty TLV elements to nothing + // Useful in tests + Ok(()) + } else { + tw.raw_value(tag, self.control()?.value_type, self.raw_value()?) + } } -} -impl ToTLV for heapless::Vec { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.str16(tag, self.as_slice()) + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + TLVElementTLVIter::Start(tag, self.clone()) } } -/// Implements the Owned version of Octet String -impl FromTLV<'_> for crate::utils::storage::Vec { - fn from_tlv(t: &TLVElement) -> Result, Error> { - crate::utils::storage::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into()) - } +enum TLVElementTLVIter<'a> { + Start(TLVTag, TLVElement<'a>), + Seq(TLVSequenceTLVIter<'a>), + Finished, } -impl ToTLV for crate::utils::storage::Vec { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.str16(tag, self.as_slice()) - } -} +impl<'a> Iterator for TLVElementTLVIter<'a> { + type Item = Result, Error>; -/// Implements the Owned version of UTF String -impl FromTLV<'_> for heapless::String { - fn from_tlv(t: &TLVElement) -> Result, Error> { - let mut string = heapless::String::new(); + fn next(&mut self) -> Option { + match core::mem::replace(self, Self::Finished) { + TLVElementTLVIter::Start(tag, elem) => { + if elem.is_empty() { + // Special-case serializing empty TLV elements to nothing + // Useful in tests + None + } else { + let value = elem.value().map(|value| TLV::new(tag, value)); - string - .push_str(core::str::from_utf8(t.slice()?)?) - .map_err(|_| ErrorCode::NoSpace)?; + if let Ok(seq) = elem.container() { + *self = Self::Seq(seq.tlv_iter()); + } else { + *self = TLVElementTLVIter::Finished; + } - Ok(string) + Some(value) + } + } + TLVElementTLVIter::Seq(mut iter) => { + if let Some(value) = iter.next() { + *self = TLVElementTLVIter::Seq(iter); + Some(value) + } else { + Some(Ok(TLV::end_container())) + } + } + TLVElementTLVIter::Finished => None, + } } } -impl ToTLV for heapless::String { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.utf16(tag, self.as_bytes()) +impl<'a> FromTLV<'a> for TLVValue<'a> { + fn from_tlv(element: &TLVElement<'a>) -> Result { + element.value() } } -/// Applies to all the Option<> Processing -impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option { - fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { - Ok(Some(T::from_tlv(t)?)) - } - - fn tlv_not_found() -> Result - where - Self: Sized, - { - Ok(None) +impl<'a> ToTLV for TLVValue<'a> { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + tw.tlv(tag, self) } -} -impl ToTLV for Option { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - match self { - Some(s) => s.to_tlv(tw, tag), - None => Ok(()), - } + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + TLV::new(tag, self.clone()).into_tlv_iter() } } -/// Represent a nullable value -/// -/// The value may be null or a valid value -/// Note: Null is different from Option. If the value is optional, include Option<> too. For -/// example, Option> -#[derive(Copy, Clone, PartialEq, Debug, Hash, Eq)] -pub enum Nullable { - Null, - NotNull(T), -} - -impl Nullable { - pub fn as_mut(&mut self) -> Nullable<&mut T> { - match self { - Nullable::Null => Nullable::Null, - Nullable::NotNull(t) => Nullable::NotNull(t), - } - } +#[cfg(test)] +mod tests { + use core::fmt::Debug; + use core::mem::MaybeUninit; - pub fn as_ref(&self) -> Nullable<&T> { - match self { - Nullable::Null => Nullable::Null, - Nullable::NotNull(t) => Nullable::NotNull(t), - } - } + use rs_matter_macros::{FromTLV, ToTLV}; - pub fn is_null(&self) -> bool { - match self { - Nullable::Null => true, - Nullable::NotNull(_) => false, - } - } + use crate::tlv::{Octets, TLVElement, TLVWriter, TLV}; + use crate::utils::init::InitMaybeUninit; + use crate::utils::storage::WriteBuf; - pub fn notnull(self) -> Option { - match self { - Nullable::Null => None, - Nullable::NotNull(t) => Some(t), - } - } -} + use super::{FromTLV, OctetStr, TLVTag, ToTLV}; -impl<'a, T: FromTLV<'a>> FromTLV<'a> for Nullable { - fn from_tlv(t: &TLVElement<'a>) -> Result, Error> { - match t.get_element_type() { - ElementType::Null => Ok(Nullable::Null), - _ => Ok(Nullable::NotNull(T::from_tlv(t)?)), - } - } -} + fn test_from_tlv<'a, T: FromTLV<'a> + PartialEq + Debug>(data: &'a [u8], expected: T) { + let root = TLVElement::new(data); + let test = T::from_tlv(&root).unwrap(); + assert_eq!(test, expected); -impl ToTLV for Nullable { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - match self { - Nullable::Null => tw.null(tag), - Nullable::NotNull(s) => s.to_tlv(tw, tag), - } - } -} + let test_init = T::init_from_tlv(root); -#[derive(Clone)] -pub enum TLVArray<'a, T> { - // This is used for the to-tlv path - Slice(&'a [T]), - // This is used for the from-tlv path - Ptr(TLVElement<'a>), -} + let mut test = MaybeUninit::::uninit(); -pub enum TLVArrayIter<'a, T> { - Slice(Iter<'a, T>), - Ptr(Option>), -} + let test = test.try_init_with(test_init).unwrap(); -impl<'a, T: ToTLV> TLVArray<'a, T> { - pub fn new(slice: &'a [T]) -> Self { - Self::Slice(slice) + assert_eq!(*test, expected); } - pub fn iter(&self) -> TLVArrayIter<'a, T> { - match self { - Self::Slice(s) => TLVArrayIter::Slice(s.iter()), - Self::Ptr(p) => TLVArrayIter::Ptr(p.enter()), - } - } -} + fn test_to_tlv(t: T, expected: &[u8]) { + let mut buf = [0; 20]; + let mut writebuf = WriteBuf::new(&mut buf); + let mut tw = TLVWriter::new(&mut writebuf); -impl<'a, T: ToTLV + FromTLV<'a> + Clone> TLVArray<'a, T> { - pub fn get_index(&self, index: usize) -> T { - for (curr, element) in self.iter().enumerate() { - if curr == index { - return element; - } - } - panic!("Out of bounds"); - } -} + t.to_tlv(&TLVTag::Anonymous, &mut tw).unwrap(); -impl<'a, T: FromTLV<'a> + Clone> Iterator for TLVArrayIter<'a, T> { - type Item = T; - /* Code for going to the next Element */ - fn next(&mut self) -> Option { - match self { - Self::Slice(s_iter) => s_iter.next().cloned(), - Self::Ptr(p_iter) => { - if let Some(tlv_iter) = p_iter.as_mut() { - let e = tlv_iter.next(); - if let Some(element) = e { - T::from_tlv(&element).ok() - } else { - None - } - } else { - None - } - } - } - } -} + assert_eq!(writebuf.as_slice(), expected); -impl<'a, 'b, T> PartialEq> for TLVArray<'a, T> -where - T: ToTLV + FromTLV<'a> + Clone + PartialEq, - 'b: 'a, -{ - fn eq(&self, other: &TLVArray<'b, T>) -> bool { - let mut iter1 = self.iter(); - let mut iter2 = other.iter(); - loop { - match (iter1.next(), iter2.next()) { - (None, None) => return true, - (Some(x), Some(y)) => { - if x != y { - return false; - } - } - _ => return false, - } - } - } -} + writebuf.reset(); -impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T> -where - T: ToTLV + FromTLV<'a> + Clone + PartialEq, -{ - fn eq(&self, other: &&[T]) -> bool { - let mut iter1 = self.iter(); - let mut iter2 = other.iter(); + let mut iter = t + .tlv_iter(TLVTag::Anonymous) + .flat_map(TLV::result_into_bytes_iter); loop { - match (iter1.next(), iter2.next()) { - (None, None) => return true, - (Some(x), Some(y)) => { - if x != *y { - return false; - } - } - _ => return false, - } - } - } -} - -impl<'a, T: FromTLV<'a> + Clone + ToTLV> ToTLV for TLVArray<'a, T> { - fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { - tw.start_array(tag_type)?; - for a in self.iter() { - a.to_tlv(tw, TagType::Anonymous)?; - } - tw.end_container() - // match *self { - // Self::Slice(s) => { - // tw.start_array(tag_type)?; - // for a in s { - // a.to_tlv(tw, TagType::Anonymous)?; - // } - // tw.end_container() - // } - // Self::Ptr(t) => t.to_tlv(tw, tag_type), <-- TODO: this fails the unit tests of Cert from/to TLV - // } - } -} - -impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { - fn from_tlv(t: &TLVElement<'a>) -> Result { - t.confirm_array()?; - Ok(Self::Ptr(t.clone())) - } -} - -impl<'a, T: Debug + ToTLV + FromTLV<'a> + Clone> Debug for TLVArray<'a, T> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "TLVArray [")?; - let mut first = true; - for i in self.iter() { - if !first { - write!(f, ", ")?; + match iter.next() { + Some(Ok(byte)) => writebuf.append(&[byte]).unwrap(), + None => break, + _ => panic!("Error in iterator"), } - - write!(f, "{:?}", i)?; - first = false; } - write!(f, "]") - } -} -impl<'a> ToTLV for TLVElement<'a> { - fn to_tlv(&self, tw: &mut TLVWriter, _tag_type: TagType) -> Result<(), Error> { - match self.get_element_type() { - ElementType::S8(v) => v.to_tlv(tw, self.get_tag()), - ElementType::U8(v) => v.to_tlv(tw, self.get_tag()), - ElementType::U16(v) => v.to_tlv(tw, self.get_tag()), - ElementType::S16(v) => v.to_tlv(tw, self.get_tag()), - ElementType::U32(v) => v.to_tlv(tw, self.get_tag()), - ElementType::S32(v) => v.to_tlv(tw, self.get_tag()), - ElementType::U64(v) => v.to_tlv(tw, self.get_tag()), - ElementType::S64(v) => v.to_tlv(tw, self.get_tag()), - ElementType::False => tw.bool(self.get_tag(), false), - ElementType::True => tw.bool(self.get_tag(), true), - ElementType::Utf8l(v) | ElementType::Utf16l(v) => tw.utf16(self.get_tag(), v), - ElementType::Str8l(v) | ElementType::Str16l(v) => tw.str16(self.get_tag(), v), - ElementType::Null => tw.null(self.get_tag()), - ElementType::Struct(_) => tw.start_struct(self.get_tag()), - ElementType::Array(_) => tw.start_array(self.get_tag()), - ElementType::List(_) => tw.start_list(self.get_tag()), - ElementType::EndCnt => tw.end_container(), - _ => { - error!("ToTLV Not supported"); - Err(ErrorCode::Invalid.into()) - } - } + assert_eq!(writebuf.as_slice(), expected); } -} - -/// Implements to/from TLV for the given enumeration that was -/// created using `bitflags!` -/// -/// NOTE: -/// - bitflgs are generally unrestricted. The provided implementations -/// do NOT attempt to validate flags for validity and the entire -/// range of flags will be marshalled (including unknown flags) -#[macro_export] -macro_rules! bitflags_tlv { - ($enum_name:ident, $type:ident) => { - impl FromTLV<'_> for $enum_name { - fn from_tlv(t: &TLVElement) -> Result { - Ok(Self::from_bits_retain(t.$type()?)) - } - } - - impl ToTLV for $enum_name { - fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { - tw.$type(tag, self.bits()) - } - } - }; -} - -#[cfg(test)] -mod tests { - use super::{FromTLV, OctetStr, TLVWriter, TagType, ToTLV}; - use crate::tlv::TLVList; - use crate::utils::storage::WriteBuf; - use rs_matter_macros::{FromTLV, ToTLV}; #[derive(ToTLV)] struct TestDerive { a: u16, b: u32, } + #[test] fn test_derive_totlv() { - let mut buf = [0; 20]; - let mut writebuf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writebuf); - - let abc = TestDerive { - a: 0x1010, - b: 0x20202020, - }; - abc.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!( - buf, - [21, 37, 0, 0x10, 0x10, 38, 1, 0x20, 0x20, 0x20, 0x20, 24, 0, 0, 0, 0, 0, 0, 0, 0] + test_to_tlv( + TestDerive { + a: 0x1010, + b: 0x20202020, + }, + &[21, 37, 0, 0x10, 0x10, 38, 1, 0x20, 0x20, 0x20, 0x20, 24], ); } - #[derive(FromTLV)] + #[derive(FromTLV, Debug, PartialEq)] struct TestDeriveSimple { a: u16, b: u32, @@ -603,16 +238,13 @@ mod tests { #[test] fn test_derive_fromtlv() { - let b = [ - 21, 37, 0, 10, 0, 38, 1, 20, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, - ]; - let root = TLVList::new(&b).iter().next().unwrap(); - let test = TestDeriveSimple::from_tlv(&root).unwrap(); - assert_eq!(test.a, 10); - assert_eq!(test.b, 20); + test_from_tlv( + &[21, 37, 0, 10, 0, 38, 1, 20, 0, 0, 0, 24], + TestDeriveSimple { a: 10, b: 20 }, + ); } - #[derive(FromTLV)] + #[derive(FromTLV, Debug, PartialEq)] #[tlvargs(lifetime = "'a")] struct TestDeriveStr<'a> { a: u16, @@ -621,14 +253,16 @@ mod tests { #[test] fn test_derive_fromtlv_str() { - let b = [21, 37, 0, 10, 0, 0x30, 0x01, 0x03, 10, 11, 12, 0]; - let root = TLVList::new(&b).iter().next().unwrap(); - let test = TestDeriveStr::from_tlv(&root).unwrap(); - assert_eq!(test.a, 10); - assert_eq!(test.b, OctetStr(&[10, 11, 12])); + test_from_tlv( + &[21, 37, 0, 10, 0, 0x30, 0x01, 0x03, 10, 11, 12, 0], + TestDeriveStr { + a: 10, + b: Octets(&[10, 11, 12]), + }, + ); } - #[derive(FromTLV, Debug)] + #[derive(FromTLV, Debug, PartialEq)] struct TestDeriveOption { a: u16, b: Option, @@ -637,41 +271,36 @@ mod tests { #[test] fn test_derive_fromtlv_option() { - let b = [21, 37, 0, 10, 0, 37, 2, 11, 0]; - let root = TLVList::new(&b).iter().next().unwrap(); - let test = TestDeriveOption::from_tlv(&root).unwrap(); - assert_eq!(test.a, 10); - assert_eq!(test.b, None); - assert_eq!(test.c, Some(11)); + test_from_tlv( + &[21, 37, 0, 10, 0, 37, 2, 11, 0], + TestDeriveOption { + a: 10, + b: None, + c: Some(11), + }, + ); } - #[derive(FromTLV, ToTLV, Debug)] + #[derive(FromTLV, ToTLV, Debug, PartialEq)] struct TestDeriveFabScoped { a: u16, #[tagval(0xFE)] fab_idx: u16, } + #[test] fn test_derive_fromtlv_fab_scoped() { - let b = [21, 37, 0, 10, 0, 37, 0xFE, 11, 0]; - let root = TLVList::new(&b).iter().next().unwrap(); - let test = TestDeriveFabScoped::from_tlv(&root).unwrap(); - assert_eq!(test.a, 10); - assert_eq!(test.fab_idx, 11); + test_from_tlv( + &[21, 37, 0, 10, 0, 37, 0xFE, 11, 0], + TestDeriveFabScoped { a: 10, fab_idx: 11 }, + ); } #[test] fn test_derive_totlv_fab_scoped() { - let mut buf = [0; 20]; - let mut writebuf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writebuf); - - let abc = TestDeriveFabScoped { a: 20, fab_idx: 3 }; - - abc.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!( - buf, - [21, 36, 0, 20, 36, 0xFE, 3, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + test_to_tlv( + TestDeriveFabScoped { a: 20, fab_idx: 3 }, + &[21, 36, 0, 20, 36, 0xFE, 3, 24], ); } @@ -684,23 +313,9 @@ mod tests { #[test] fn test_derive_from_to_tlv_enum() { // Test FromTLV - let b = [21, 36, 0, 100, 24, 0]; - let root = TLVList::new(&b).iter().next().unwrap(); - let mut enum_val = TestDeriveEnum::from_tlv(&root).unwrap(); - assert_eq!(enum_val, TestDeriveEnum::ValueA(100)); - - // Modify the value and test ToTLV - enum_val = TestDeriveEnum::ValueB(10); + test_from_tlv(&[21, 36, 0, 100, 24, 0], TestDeriveEnum::ValueA(100)); // Test ToTLV - let mut buf = [0; 20]; - let mut writebuf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writebuf); - - enum_val.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!( - buf, - [21, 36, 1, 10, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] - ); + test_to_tlv(TestDeriveEnum::ValueB(10), &[21, 36, 1, 10, 24]); } } diff --git a/rs-matter/src/tlv/traits/array.rs b/rs-matter/src/tlv/traits/array.rs new file mode 100644 index 00000000..1570b03d --- /dev/null +++ b/rs-matter/src/tlv/traits/array.rs @@ -0,0 +1,70 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for Rust built-in arrays. +//! Rust bilt-in arrays are serialized and deserialized as TLV arrays. +//! +//! The deserialization support requires `T` to implement `Default`, or else +//! the deserialization will not work for the cases where the deserialized TLV array +//! turns out to be shorter than the Rust array into which we deserialize. +//! +//! Note that the implementation below CANNOT efficiently in-place initialize the arrays, +//! as that would imply that the array elements should implement the unsafe `Zeroed` trait +//! instead of `Default`. +//! Since that would restrict the use-cases where built-in arrays can be utilized, +//! the implementation below requires `Default` instead for the array elements. +//! +//! Therefore, use `Vec` instead of built-in arrays if you need to efficiently in-place initialize +//! (potentially large) arrays. + +use crate::error::{Error, ErrorCode}; +use crate::utils::storage::Vec; + +use super::{tlv_array_iter, FromTLV, TLVArray, TLVElement, TLVTag, TLVWrite, ToTLV, TLV}; + +impl<'a, T, const N: usize> FromTLV<'a> for [T; N] +where + T: FromTLV<'a> + Default, +{ + fn from_tlv(element: &TLVElement<'a>) -> Result { + let mut vec = Vec::::new(); + + for item in TLVArray::new(element.clone())? { + vec.push(item?).map_err(|_| ErrorCode::NoSpace)?; + } + + while !vec.is_full() { + vec.push(Default::default()) + .map_err(|_| ErrorCode::NoSpace)?; + } + + Ok(vec.into_array().map_err(|_| ErrorCode::NoSpace).unwrap()) + } +} + +impl ToTLV for [T; N] +where + T: ToTLV, +{ + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + self.as_slice().to_tlv(tag, tw) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + tlv_array_iter(tag, self.iter()) + } +} diff --git a/rs-matter/src/tlv/traits/bitflags.rs b/rs-matter/src/tlv/traits/bitflags.rs new file mode 100644 index 00000000..eb9cb861 --- /dev/null +++ b/rs-matter/src/tlv/traits/bitflags.rs @@ -0,0 +1,56 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for `bitflags!`. +//! Bitflags are serialized and deserialized as TLV enumerations. + +/// Implements to/from TLV for the given enumeration that was +/// created using `bitflags!` +/// +/// NOTE: +/// - bitflgs are generally unrestricted. The provided implementations +/// do NOT attempt to validate flags for validity and the entire +/// range of flags will be marshalled (including unknown flags) +#[macro_export] +macro_rules! bitflags_tlv { + ($enum_name:ident, $type:ident) => { + impl<'a> $crate::tlv::FromTLV<'a> for $enum_name { + fn from_tlv(element: &$crate::tlv::TLVElement<'a>) -> Result { + Ok(Self::from_bits_retain($crate::tlv::TLVElement::$type( + element, + )?)) + } + } + + impl $crate::tlv::ToTLV for $enum_name { + fn to_tlv( + &self, + tag: &$crate::tlv::TLVTag, + mut tw: W, + ) -> Result<(), Error> { + tw.$type(tag, self.bits()) + } + + fn tlv_iter( + &self, + tag: $crate::tlv::TLVTag, + ) -> impl Iterator> { + $crate::tlv::TLV::$type(tag, self.bits()).into_tlv_iter() + } + } + }; +} diff --git a/rs-matter/src/tlv/traits/container.rs b/rs-matter/src/tlv/traits/container.rs new file mode 100644 index 00000000..c6c2bc08 --- /dev/null +++ b/rs-matter/src/tlv/traits/container.rs @@ -0,0 +1,454 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! A container type (`TLVContainer`) and an iterator type (`TLVContainerIter`) that represent and iterate directly over serialized TLV containers. +//! As such, the memory prepresentation of `TLVContainer` and `TLVContainerIter` is just a byte slice (`&[u8]`), +//! and the container elements are materialized (with `FromTLV`) only when the container is iterated over. +//! +//! The difference between `TLVContainer` and `TLVContainerIter` on one side, and `TLVElement`, `TLVSequence` and `TLVSequenceIter` on the other +//! is that the former are generified by type `T: FromTLV<'_>` and can directly yield values of type `T` when iterated over, +//! while iterating over a `TLVSequence` with a `TLVSequenceIter` always yields elements of type `TLVElement`. +//! +//! Thus, a `TLVContainer, ()`> is equivalent to a `TLVElement` which represents a container and +//! `TLVContainerIter>` is equivalent to a `TLVSequenceIter<'_>` that is obtained by `element.container()?.iter()`. + +use core::fmt; +use core::marker::PhantomData; + +use crate::error::Error; +use crate::utils::init; + +use super::{EitherIter, FromTLV, TLVElement, TLVSequenceIter, TLVTag, TLVWrite, ToTLV, TLV}; + +/// A type-state that indicates that the container can be any type of container (array, list or struct). +pub type AnyContainer = (); + +/// A type-state that indicates that the container should be an array. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ArrayContainer; + +/// A type-state that indicates that the container should be a list. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ListContainer; + +/// A type-state that indicates that the container should be a struct. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct StructContainer; + +/// A type alias for an array TLV container. +pub type TLVArray<'a, T> = TLVContainer<'a, T, ArrayContainer>; +/// A type alias for a list TLV container. +pub type TLVList<'a, T> = TLVContainer<'a, T, ListContainer>; +/// A type alias for a struct TLV container. +pub type TLVStruct<'a, T> = TLVContainer<'a, T, StructContainer>; + +/// `TLVContainer` is an efficient (memory-wise) way to represent a serialized TLV container, in that +/// it does not materialize the container elements until the container is iterated over. +/// +/// Therefore, `TLVContainer` is just a wrapper (newtype) of the serialized TLV container `&[u8]` slice. +#[derive(Clone, PartialEq, Eq, Hash)] +#[repr(transparent)] +pub struct TLVContainer<'a, T, C = AnyContainer> { + element: TLVElement<'a>, + _type: PhantomData T>, + _container_type: PhantomData, +} + +impl<'a, T, C> TLVContainer<'a, T, C> +where + T: FromTLV<'a>, +{ + /// Creates a new `TLVContainer` from a TLV element. + /// The constructor does not check whether the passed slice is a valid TLV container. + pub const fn new_unchecked(element: TLVElement<'a>) -> Self { + Self { + element, + _type: PhantomData, + _container_type: PhantomData, + } + } + + pub fn element(&self) -> &TLVElement<'a> { + &self.element + } + + /// Returns an iterator over the elements of the container. + pub fn iter(&self) -> TLVContainerIter<'a, T> { + TLVContainerIter::new(self.element.container().unwrap().iter()) + } +} + +impl<'a, T> TLVContainer<'a, T, AnyContainer> +where + T: FromTLV<'a>, +{ + /// Creates a new `TLVContainer` from a TLV element that can be any container. + pub fn new(element: TLVElement<'a>) -> Result { + if !element.is_empty() { + element.container()?; + } + + Ok(Self::new_unchecked(element)) + } +} + +impl<'a, T> TLVContainer<'a, T, ArrayContainer> +where + T: FromTLV<'a>, +{ + /// Creates a new `TLVContainer` from a TLV element that is expected to be of type array. + pub fn new(element: TLVElement<'a>) -> Result { + if !element.is_empty() { + element.array()?; + } + + Ok(Self::new_unchecked(element)) + } +} + +impl<'a, T> TLVContainer<'a, T, ListContainer> +where + T: FromTLV<'a>, +{ + /// Creates a new `TLVContainer` from a TLV element that is expected to be of type list. + pub fn new(element: TLVElement<'a>) -> Result { + if !element.is_empty() { + element.list()?; + } + + Ok(Self::new_unchecked(element)) + } +} + +impl<'a, T> TLVContainer<'a, T, StructContainer> +where + T: FromTLV<'a>, +{ + /// Creates a new `TLVContainer` from a TLV element that is expected to be of type struct. + pub fn new(element: TLVElement<'a>) -> Result { + if !element.is_empty() { + element.structure()?; + } + + Ok(Self::new_unchecked(element)) + } +} + +impl<'a, T, C> IntoIterator for TLVContainer<'a, T, C> +where + T: FromTLV<'a>, +{ + type Item = Result; + type IntoIter = TLVContainerIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T, C> IntoIterator for &TLVContainer<'a, T, C> +where + T: FromTLV<'a>, +{ + type Item = Result; + type IntoIter = TLVContainerIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl<'a, T, C> fmt::Debug for TLVContainer<'a, T, C> +where + T: FromTLV<'a> + fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + + let mut first = true; + + for elem in self.iter() { + if first { + first = false; + } else { + write!(f, ", ")?; + } + + write!(f, "{elem:?}")?; + } + + write!(f, "]") + } +} + +impl<'a, T, C> FromTLV<'a> for TLVContainer<'a, T, C> +where + T: FromTLV<'a>, + C: 'a, +{ + fn from_tlv(element: &TLVElement<'a>) -> Result { + Ok(Self::new_unchecked(element.clone())) + } +} + +impl<'a, T, C> ToTLV for TLVContainer<'a, T, C> { + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + self.element.to_tlv(tag, tw) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + self.element.tlv_iter(tag) + } +} + +/// An iterator over a serialized TLV container. +#[repr(transparent)] +pub struct TLVContainerIter<'a, T> { + iter: TLVSequenceIter<'a>, + _type: PhantomData T>, +} + +impl<'a, T> TLVContainerIter<'a, T> +where + T: FromTLV<'a>, +{ + /// Create a new `TLVContainerIter` from a TLV sequence iterator. + pub const fn new(iter: TLVSequenceIter<'a>) -> Self { + Self { + iter, + _type: PhantomData, + } + } + + pub fn try_next(&mut self) -> Option> { + let tlv = self.iter.next()?; + + Some(tlv.and_then(|tlv| T::from_tlv(&tlv))) + } + + pub fn try_next_init(&mut self) -> Option + 'a, Error>> { + let tlv = self.iter.next()?; + + Some(tlv.map(|tlv| T::init_from_tlv(tlv))) + } +} + +impl<'a, T> Iterator for TLVContainerIter<'a, T> +where + T: FromTLV<'a>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + self.try_next() + } +} + +/// A container type that can represent either a serialized TLV array or a slice of elements. +/// +/// Necessary for the few cases in the code where deserialized TLV structures are mutated - +/// post deserialization - with custom array data. +#[derive(Debug, Clone)] +pub enum TLVArrayOrSlice<'a, T> +where + T: FromTLV<'a>, +{ + Array(TLVArray<'a, T>), + Slice(&'a [T]), +} + +impl<'a, T> TLVArrayOrSlice<'a, T> +where + T: FromTLV<'a>, +{ + /// Creates a new `TLVArrayOrSlice` from a TLV slice. + pub const fn new_array(array: TLVArray<'a, T>) -> Self { + Self::Array(array) + } + + /// Creates a new `TLVArrayOrSlice` from a slice. + pub const fn new_slice(slice: &'a [T]) -> Self { + Self::Slice(slice) + } + + /// Returns an iterator over the elements of the array. + pub fn iter(&self) -> Result, Error> { + match self { + Self::Array(array) => Ok(TLVArrayOrSliceIter::Array(array.iter())), + Self::Slice(slice) => Ok(TLVArrayOrSliceIter::Slice(slice.iter())), + } + } +} + +impl<'a, T> FromTLV<'a> for TLVArrayOrSlice<'a, T> +where + T: FromTLV<'a>, +{ + fn from_tlv(element: &TLVElement<'a>) -> Result { + Ok(Self::new_array(TLVArray::new(element.clone())?)) + } +} + +impl<'a, T> ToTLV for TLVArrayOrSlice<'a, T> +where + T: FromTLV<'a>, + T: ToTLV, +{ + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + match self { + Self::Array(array) => array.to_tlv(tag, tw), + Self::Slice(slice) => slice.to_tlv(tag, tw), + } + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + match self { + Self::Array(array) => EitherIter::First(array.tlv_iter(tag)), + Self::Slice(slice) => EitherIter::Second(slice.tlv_iter(tag)), + } + } +} + +/// An iterator over the `TLVArrayOrSlice` elements. +pub enum TLVArrayOrSliceIter<'a, T> { + Array(TLVContainerIter<'a, T>), + Slice(core::slice::Iter<'a, T>), +} + +impl<'a, T> Iterator for TLVArrayOrSliceIter<'a, T> +where + T: FromTLV<'a> + Clone, +{ + type Item = Result; + + fn next(&mut self) -> Option { + match self { + Self::Array(array) => array.next(), + Self::Slice(slice) => slice.next().cloned().map(|t| Ok(t)), + } + } +} + +// impl<'a, T: ToTLV + FromTLV<'a> + Clone> TLVArray<'a, T> { +// pub fn get_index(&self, index: usize) -> T { +// for (curr, element) in self.iter().enumerate() { +// if curr == index { +// return element; +// } +// } +// panic!("Out of bounds"); +// } +// } + +// // impl<'a, 'b, T> PartialEq> for TLVArray<'a, T> +// // where +// // T: ToTLV + FromTLV<'a> + Clone + PartialEq, +// // 'b: 'a, +// // { +// // fn eq(&self, other: &TLVArray<'b, T>) -> bool { +// // let mut iter1 = self.iter(); +// // let mut iter2 = other.iter(); +// // loop { +// // match (iter1.next(), iter2.next()) { +// // (None, None) => return true, +// // (Some(x), Some(y)) => { +// // if x != y { +// // return false; +// // } +// // } +// // _ => return false, +// // } +// // } +// // } +// // } + +// // impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T> +// // where +// // T: ToTLV + FromTLV<'a> + Clone + PartialEq, +// // { +// // fn eq(&self, other: &&[T]) -> bool { +// // let mut iter1 = self.iter(); +// // let mut iter2 = other.iter(); +// // loop { +// // match (iter1.next(), iter2.next()) { +// // (None, None) => return true, +// // (Some(x), Some(y)) => { +// // if x != *y { +// // return false; +// // } +// // } +// // _ => return false, +// // } +// // } +// // } +// // } + +// impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { +// fn from_tlv(t: TLVElement<'a>) -> Result { +// TLVArray::new(t) +// } +// } + +// impl<'a, T> ToTLV for TLVArray<'a, T> { +// fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { +// tw.start_array(tag_type)?; +// for a in self.iter() { +// a.to_tlv(tw, TagType::Anonymous)?; +// } +// tw.end_container() +// // match *self { +// // Self::Slice(s) => { +// // tw.start_array(tag_type)?; +// // for a in s { +// // a.to_tlv(tw, TagType::Anonymous)?; +// // } +// // tw.end_container() +// // } +// // Self::Ptr(t) => t.to_tlv(tw, tag_type), <-- TODO: this fails the unit tests of Cert from/to TLV +// // } +// } + +// fn tlv_iter(&self, tag: TagType) -> impl Iterator + '_ { +// empty() +// .start_array(tag) +// .chain(self.iter().flat_map(move |i| i.into_tlv_iter(TagType::Anonymous))) +// .end_container() +// } + +// fn into_tlv_iter(self, tag: TagType) -> impl Iterator where Self: Sized { +// empty() +// .start_array(tag) +// .chain(self.into_iter().flat_map(move |i| i.into_tlv_iter(TagType::Anonymous))) +// .end_container() +// } +// } + +// impl<'a, T: Debug + ToTLV + FromTLV<'a> + Clone> Debug for TLVArray<'a, T> { +// fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { +// write!(f, "TLVArray [")?; +// let mut first = true; +// for i in self.iter() { +// if !first { +// write!(f, ", ")?; +// } + +// write!(f, "{:?}", i)?; +// first = false; +// } +// write!(f, "]") +// } +// } diff --git a/rs-matter/src/tlv/traits/maybe.rs b/rs-matter/src/tlv/traits/maybe.rs new file mode 100644 index 00000000..18cfdd7c --- /dev/null +++ b/rs-matter/src/tlv/traits/maybe.rs @@ -0,0 +1,168 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for TLV optional values and TLV nullable types via `Maybe` and `Option`. +//! - `Option` and `Optional` both represent an optional value in a TLV struct +//! - `Nullable` represents a nullable TLV type, where `T` is the non-nullable subdomain of the type. +//! i.e. `Nullable` represents the nullable variation of the TLV `U8` type. +//! +//! To elaborate, `null` and optional are two different notions in the TLV spec: +//! - Optional values apply only to TLV structs, and have the semantics +//! that the value might not be provided in the TLV stream for that struct +//! - `null` is a property of the type _domain_ and therefore applies to all TLV types, +//! and has the semantics that the value is provided, but is null +//! +//! Therefore, e.g. `Optional>` is completely valid (in the context of a struct member) +//! and means that this struct member is optional, but additionally - when provided - can be null. +//! +//! In terms of memory optimizations: +//! - Use `Option` only when the optional T value is small, as `Option` cannot be in-place initialized; +//! otherwise, use `Optional` (which is equivalent to `Maybe` and `Maybe`). +//! - Use `Nullable` (which is equivalent to `Maybe`) to represent +//! the nullable variations of the TLV types. This type can always be initialized in-place. +//! +//! Using `Optional` (or `Option`) **outside** of struct members has no TLV meaning but won't fail either: +//! - During deserialization, a stream containing a value of type `T` would be deserialized as `Some(T)` if the user has +//! provided an `Option` or an `Optional` type declaration instead of just `T` +//! - During serialization, a value of `Some(T)` would be serialized as `T`, while a value `None` would simply not be serialized + +use core::fmt::Debug; +use core::iter::empty; + +use crate::error::Error; +use crate::utils::init; +use crate::utils::maybe::Maybe; + +use super::{EitherIter, FromTLV, TLVElement, TLVTag, TLVValueType, TLVWrite, ToTLV, TLV}; + +/// A tag for `Maybe` that makes it behave as an optional struct value per the TLV spec. +pub type AsOptional = (); + +/// A tag for `Maybe` that makes it behave as a nullable type per the TLV spec. +#[derive(Debug)] +pub struct AsNullable; + +/// Represents optional values as per the TLV spec. +/// +/// Note that `Option` also represents optional values, but `Option` +/// cannot be created in-place, which is necessary when large values are involved. +/// +/// Therefore, using `Optional` is recommended over `Option` when the optional value is large. +pub type Optional = Maybe; + +/// Represents nullable values as per the TLV spec. +pub type Nullable = Maybe; + +impl<'a, T: FromTLV<'a>> FromTLV<'a> for Maybe { + fn from_tlv(element: &TLVElement<'a>) -> Result { + match element.control()?.value_type { + TLVValueType::Null => Ok(Maybe::none()), + _ => Ok(Maybe::some(T::from_tlv(element)?)), + } + } + + fn init_from_tlv(element: TLVElement<'a>) -> impl init::Init { + unsafe { + init::init_from_closure(move |slot| { + let init = match element.control()?.value_type { + TLVValueType::Null => None, + _ => Some(T::init_from_tlv(element)), + }; + + init::Init::__init(Maybe::init(init), slot) + }) + } + } +} + +impl ToTLV for Maybe { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + match self.as_ref() { + None => tw.null(tag), + Some(s) => s.to_tlv(tag, tw), + } + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + match self.as_ref() { + None => EitherIter::First(TLV::null(tag).into_tlv_iter()), + Some(s) => EitherIter::Second(s.tlv_iter(tag)), + } + } +} + +impl<'a, T: FromTLV<'a> + 'a> FromTLV<'a> for Maybe { + fn from_tlv(element: &TLVElement<'a>) -> Result { + if element.is_empty() { + Ok(Self::none()) + } else { + Ok(Self::some(T::from_tlv(element)?)) + } + } + + fn init_from_tlv(element: TLVElement<'a>) -> impl init::Init { + if element.is_empty() { + Self::init(None) + } else { + Self::init(Some(T::init_from_tlv(element))) + } + } +} + +impl ToTLV for Maybe { + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + match self.as_ref() { + None => Ok(()), + Some(s) => s.to_tlv(tag, tw), + } + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + use crate::tlv::EitherIter; + + match self.as_ref() { + None => EitherIter::First(empty()), + Some(s) => EitherIter::Second(s.tlv_iter(tag)), + } + } +} + +impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option { + fn from_tlv(element: &TLVElement<'a>) -> Result { + if element.is_empty() { + return Ok(None); + } + + Ok(Some(T::from_tlv(element)?)) + } +} + +impl ToTLV for Option { + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + match self.as_ref() { + None => Ok(()), + Some(s) => s.to_tlv(tag, tw), + } + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + match self.as_ref() { + None => EitherIter::First(empty()), + Some(s) => EitherIter::Second(s.tlv_iter(tag)), + } + } +} diff --git a/rs-matter/src/tlv/traits/octets.rs b/rs-matter/src/tlv/traits/octets.rs new file mode 100644 index 00000000..b3c53d99 --- /dev/null +++ b/rs-matter/src/tlv/traits/octets.rs @@ -0,0 +1,167 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for octet strings (i.e. byte arrays). +//! +//! Support is provided via two dedicated newtypes: +//! - `Octets<'a>` newtype which wraps an ordinary `&[u8]` - for borrowed byte arrays +//! - `OctetsOwned` newtype which wraps a `Vec` for owned byte arrays of fixed length N +//! +//! Newtype wrapping is necessary because naked Rust slices, arrays and the naked `Vec` type +//! serialize and deserialize as TLV arrays, rather than as octet strings. +//! +//! I.e. serializing `[0; 3]` will result in a TLV array with 3 elements of type u8 and value 0, rather than a TLV +//! octet string containing 3 zero bytes. + +use core::borrow::{Borrow, BorrowMut}; +use core::fmt::Debug; +use core::hash::Hash; +use core::ops::{Deref, DerefMut}; + +use crate::error::{Error, ErrorCode}; +use crate::utils::init::{self, init, IntoFallibleInit}; +use crate::utils::storage::Vec; + +use super::{FromTLV, TLVElement, TLVTag, TLVWrite, ToTLV, TLV}; + +/// For backwards compatibility +pub type OctetStr<'a> = Octets<'a>; + +/// For backwards compatibility +pub type OctetStrOwned = OctetsOwned; + +/// Newtype for borrowed byte arrays +/// +/// When deserializing, this type grabs the octet slice directly from the `TLVElement`. +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +#[repr(transparent)] +pub struct Octets<'a>(pub &'a [u8]); + +impl<'a> Octets<'a> { + pub const fn new(slice: &'a [u8]) -> Self { + Self(slice) + } +} + +impl<'a> Deref for Octets<'a> { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + +impl<'a> FromTLV<'a> for Octets<'a> { + fn from_tlv(element: &TLVElement<'a>) -> Result { + Ok(Octets(element.str()?)) + } +} + +impl<'a> ToTLV for Octets<'a> { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + tw.str(tag, self.0) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + TLV::str(tag, self.0).into_tlv_iter() + } +} + +/// Newtype for owned byte arrays with a fixed maximum length +/// (represented by a `Vec`) +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[repr(transparent)] +pub struct OctetsOwned { + pub vec: Vec, +} + +impl Default for OctetsOwned { + fn default() -> Self { + Self::new() + } +} + +impl OctetsOwned { + /// Create a new empty `OctetsOwned` instance + pub const fn new() -> Self { + Self { + vec: Vec::::new(), + } + } + + /// Create an in-place initializer for an empty `OctetsOwned` instance + pub fn init() -> impl init::Init { + init!(Self { + vec <- Vec::::init(), + }) + } +} + +impl Borrow<[u8]> for OctetsOwned { + fn borrow(&self) -> &[u8] { + &self.vec + } +} + +impl BorrowMut<[u8]> for OctetsOwned { + fn borrow_mut(&mut self) -> &mut [u8] { + &mut self.vec + } +} + +impl Deref for OctetsOwned { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.vec + } +} + +impl DerefMut for OctetsOwned { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.vec + } +} + +impl<'a, const N: usize> FromTLV<'a> for OctetsOwned { + fn from_tlv(element: &TLVElement<'a>) -> Result { + Ok(Self { + vec: element.str()?.try_into().map_err(|_| ErrorCode::NoSpace)?, + }) + } + + fn init_from_tlv(element: TLVElement<'a>) -> impl init::Init { + init::Init::chain(OctetsOwned::init().into_fallible(), move |bytes| { + bytes + .vec + .extend_from_slice(element.str()?) + .map_err(|_| ErrorCode::NoSpace)?; + + Ok(()) + }) + } +} + +impl ToTLV for OctetsOwned { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + tw.str(tag, &self.vec) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + TLV::str(tag, self.vec.as_slice()).into_tlv_iter() + } +} diff --git a/rs-matter/src/tlv/traits/primitive.rs b/rs-matter/src/tlv/traits/primitive.rs new file mode 100644 index 00000000..43376f29 --- /dev/null +++ b/rs-matter/src/tlv/traits/primitive.rs @@ -0,0 +1,82 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for Rust primitive types. + +use crate::error::{Error, ErrorCode}; + +macro_rules! fromtlv_for { + ($($t:ident)*) => { + $( + impl<'a> $crate::tlv::FromTLV<'a> for $t { + fn from_tlv(element: &$crate::tlv::TLVElement<'a>) -> Result { + element.$t() + } + } + )* + }; +} + +macro_rules! fromtlv_for_nonzero { + ($($t:ident:$n:ty)*) => { + $( + impl<'a> $crate::tlv::FromTLV<'a> for $n { + fn from_tlv(element: &$crate::tlv::TLVElement<'a>) -> Result { + <$n>::new(element.$t()?).ok_or_else(|| ErrorCode::Invalid.into()) + } + } + )* + }; +} + +macro_rules! totlv_for { + ($($t:ident)*) => { + $( + impl $crate::tlv::ToTLV for $t { + fn to_tlv(&self, tag: &$crate::tlv::TLVTag, mut tw: W) -> Result<(), Error> { + tw.$t(tag, *self) + } + + fn tlv_iter(&self, tag: $crate::tlv::TLVTag) -> impl Iterator> { + $crate::tlv::TLV::$t(tag, *self).into_tlv_iter() + } + } + )* + }; +} + +macro_rules! totlv_for_nonzero { + ($($t:ident:$n:ty)*) => { + $( + impl $crate::tlv::ToTLV for $n { + fn to_tlv(&self, tag: &$crate::tlv::TLVTag, mut tw: W) -> Result<(), Error> { + tw.$t(tag, self.get()) + } + + fn tlv_iter(&self, tag: $crate::tlv::TLVTag) -> impl Iterator> { + $crate::tlv::TLV::$t(tag, self.get()).into_tlv_iter() + } + } + )* + }; +} + +fromtlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); +fromtlv_for_nonzero!(i8:core::num::NonZeroI8 u8:core::num::NonZeroU8 i16:core::num::NonZeroI16 u16:core::num::NonZeroU16 i32:core::num::NonZeroI32 u32:core::num::NonZeroU32 i64:core::num::NonZeroI64 u64:core::num::NonZeroU64); + +totlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool); +totlv_for_nonzero!(i8:core::num::NonZeroI8 u8:core::num::NonZeroU8 i16:core::num::NonZeroI16 u16:core::num::NonZeroU16 i32:core::num::NonZeroI32 u32:core::num::NonZeroU32 i64:core::num::NonZeroI64 u64:core::num::NonZeroU64); diff --git a/rs-matter/src/tlv/traits/slice.rs b/rs-matter/src/tlv/traits/slice.rs new file mode 100644 index 00000000..869b2059 --- /dev/null +++ b/rs-matter/src/tlv/traits/slice.rs @@ -0,0 +1,105 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for Rust slices `&[T]`. +//! Rust slices are serialized as TLV arrays. +//! +//! Note that only serialization `(trait `ToTLV`) is supported for Rust slices, +//! because deserialization (`FromTLV`) requires the deserialized Rust type +//! to be `Sized`, which slices aren't. +//! +//! (Deserializing strings as `&str` and octets as `Bytes<'a>` (which is really a newtype over +//! `&'a [u8]`) is supported, but that's because their deserialization works by borrowing their +//! content 1:1 from inside the byte slice of the `TLVElement`, which is not possible for a generic +//! `T` and only possible when `T` is a `u8`.) + +use crate::error::Error; + +use super::{TLVTag, TLVValue, TLVWrite, ToTLV, TLV}; + +/// This type alias is necessary, because `FromTLV` / `ToTLV` do not (yet) support +/// members that are slices. +/// +/// Therefore, use `Slice<'a, T>` instead of `&'a [T]` as a syntax in your structs. +pub type Slice<'a, T> = &'a [T]; + +impl<'a, T: ToTLV> ToTLV for &'a [T] +where + T: ToTLV, +{ + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + to_tlv_array(tag, self.iter(), tw) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator, Error>> { + tlv_array_iter(tag, self.iter()) + } +} + +// TODO: Uncomment once `feature(impl_trait_in_assoc_type)` is stable +// pub struct IntoTLVIter<'a, T>(pub &'a TLVTag, pub T); + +// impl<'a, T> IntoIterator for IntoTLVIter<'a, &'a [T]> +// where +// T: ToTLV + 'a, +// { +// type Item = Result, Error>; +// type IntoIter = impl Iterator; + +// fn into_iter(self) -> Self::IntoIter { +// tlv_array_iter(self.0.clone(), self.1.iter()) +// } +// } + +pub(crate) fn to_tlv_array(tag: &TLVTag, iter: I, mut tw: W) -> Result<(), Error> +where + I: Iterator, + I::Item: ToTLV, + W: TLVWrite, +{ + tw.start_array(tag)?; + + for i in iter { + i.to_tlv(&TLVTag::Anonymous, &mut tw)?; + } + + tw.end_container() +} + +pub(crate) fn tlv_array_iter<'s, I, T>( + tag: TLVTag, + iter: I, +) -> impl Iterator, Error>> +where + I: Iterator + 's, + T: ToTLV + 's, +{ + tlv_container_iter(TLV::new(tag, TLVValue::Array), iter) +} + +pub(crate) fn tlv_container_iter<'s, I, T>( + tlv: TLV<'s>, + iter: I, +) -> impl Iterator, Error>> + 's +where + I: Iterator + 's, + T: ToTLV + 's, +{ + tlv.into_tlv_iter() + .chain(iter.flat_map(|t| t.tlv_iter(TLVTag::Anonymous))) + .chain(TLV::end_container().into_tlv_iter()) +} diff --git a/rs-matter/src/tlv/traits/str.rs b/rs-matter/src/tlv/traits/str.rs new file mode 100644 index 00000000..54519680 --- /dev/null +++ b/rs-matter/src/tlv/traits/str.rs @@ -0,0 +1,79 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for octets representing valid utf8 sequences (i.e. utf8 strings). +//! +//! - `&str` is used for serializing and deserializing borrowed utf8 strings +//! - `String` (from `heapless`) is used for serializing and deserializing owned strings of fixed length N +//! +//! Note that (for now) `String` has no efficient in-place initialization, so it should not be used for +//! holding large strings, or else a stack overflow might occur. + +use heapless::String; + +use crate::error::{Error, ErrorCode}; + +use super::{FromTLV, TLVElement, TLVTag, TLVWrite, ToTLV, TLV}; + +/// For (partial) backwards compatibility +/// +/// Partial because `UtfStr` used to be a newtype rather than a type alias, +/// and - furthermore - used to expose the Utf8 octets as raw bytes +/// rather than as the native Rust `str` type. The reason for that is probably +/// a misundersatanding that Utf16l, Utf32l and Utf64l are not UTF-8 strings, +/// while they actually are. Simply their length prefix is encoded variably. +pub type UtfStr<'a> = Utf8Str<'a>; + +/// Necessary because the `FromTLV` proc macro impl currently cannot handle +/// reference types. +/// +/// This restriction might be lifted in the future. +pub type Utf8Str<'a> = &'a str; + +impl<'a> FromTLV<'a> for &'a str { + fn from_tlv(element: &TLVElement<'a>) -> Result { + element.utf8() + } +} + +impl ToTLV for &str { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + tw.utf8(tag, self) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator, Error>> { + TLV::utf8(tag, self).into_tlv_iter() + } +} + +impl<'a, const N: usize> FromTLV<'a> for String { + fn from_tlv(element: &TLVElement<'a>) -> Result, Error> { + element + .utf8() + .and_then(|s| s.try_into().map_err(|_| ErrorCode::NoSpace.into())) + } +} + +impl ToTLV for String { + fn to_tlv(&self, tag: &TLVTag, mut tw: W) -> Result<(), Error> { + tw.utf8(tag, self) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator, Error>> { + TLV::utf8(tag, self.as_str()).into_tlv_iter() + } +} diff --git a/rs-matter/src/tlv/traits/vec.rs b/rs-matter/src/tlv/traits/vec.rs new file mode 100644 index 00000000..211d330e --- /dev/null +++ b/rs-matter/src/tlv/traits/vec.rs @@ -0,0 +1,76 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! TLV support for the `Vec` type. +//! `Vec` is serialized and deserialized as a TLV array. +//! +//! Unlike Rust `[T; N]` arrays, the `Vec` type can be efficiently deserialized in-place, so use it +//! when the array holds large structures (like fabrics, certificates, sessions and so on). +//! +//! Of course, the `Vec` type is always owned (even if the deserialized elements `T` do borrow from the +//! deserializer), so it might consume more memory than necessary, as its memory is statically allocated +//! to be N * size_of(T) bytes. +//! +//! For cases where the array does not need to be owned and instantiating `T` elements on the fly when +//! traversing the array is tolerable (i.e. `T` is small enough), prefer `TLVArray`, which operates +//! directly on the borrowed, encoded TLV representation of the whole array. + +use crate::error::{Error, ErrorCode}; +use crate::utils::init::{self, IntoFallibleInit}; +use crate::utils::storage::Vec; + +use super::{slice::tlv_array_iter, FromTLV, TLVArray, TLVElement, TLVTag, TLVWrite, ToTLV, TLV}; + +impl<'a, T, const N: usize> FromTLV<'a> for Vec +where + T: FromTLV<'a> + 'a, +{ + fn from_tlv(element: &TLVElement<'a>) -> Result { + let mut vec = Vec::::new(); + + for item in TLVArray::new(element.clone())? { + vec.push(item?).map_err(|_| ErrorCode::NoSpace)?; + } + + Ok(vec) + } + + fn init_from_tlv(tlv: TLVElement<'a>) -> impl init::Init { + init::Init::chain(Vec::::init().into_fallible(), move |vec| { + let mut iter = TLVArray::new(tlv)?.iter(); + + while let Some(item) = iter.try_next_init() { + vec.push_init(item?, || ErrorCode::NoSpace.into())?; + } + + Ok(()) + }) + } +} + +impl ToTLV for Vec +where + T: ToTLV, +{ + fn to_tlv(&self, tag: &TLVTag, tw: W) -> Result<(), Error> { + self.as_slice().to_tlv(tag, tw) + } + + fn tlv_iter(&self, tag: TLVTag) -> impl Iterator> { + tlv_array_iter(tag, self.iter()) + } +} diff --git a/rs-matter/src/tlv/write.rs b/rs-matter/src/tlv/write.rs new file mode 100644 index 00000000..9e8ce7cb --- /dev/null +++ b/rs-matter/src/tlv/write.rs @@ -0,0 +1,909 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use num_traits::ToBytes; + +use crate::error::{Error, ErrorCode}; +use crate::utils::storage::WriteBuf; + +use super::{TLVControl, TLVTag, TLVTagType, TLVValue, TLVValueType}; + +/// For backwards compatibility +pub struct TLVWriter<'a, 'b>(&'a mut WriteBuf<'b>); + +impl<'a, 'b> TLVWriter<'a, 'b> { + pub fn new(buf: &'a mut WriteBuf<'b>) -> Self { + Self(buf) + } + + /// Write a tag and a TLV Octet String to the TLV stream, where the Octet String is a slice of u8 bytes. + /// + /// The writing is done via a user-supplied callback `cb`, that is expected to fill the provided buffer with the data + /// and to return the length of the written data. + /// + /// This method is useful when the data to be written needs to be computed first, and the computation needs a buffer where + /// to operate. + /// + /// Note that this method always uses a Str16l value type to write the data, which restricts the data length to no more than + /// 65535 bytes. + pub fn str_cb( + &mut self, + tag: &TLVTag, + cb: impl FnOnce(&mut [u8]) -> Result, + ) -> Result<(), Error> { + self.0.str_cb(tag, cb) + } + + /// Write a tag and a TLV UTF-8 String to the TLV stream, where the UTF-8 String is a str. + /// + /// The writing is done via a user-supplied callback `cb`, that is expected to fill the provided buffer with the data + /// and to return the length of the written data. + /// + /// This method is useful when the data to be written needs to be computed first, and the computation needs a buffer where + /// to operate. + /// + /// Note that this method always uses a Utf16l value type to write the data, which restricts the data length to no more than + /// 65535 bytes. + pub fn utf8_cb( + &mut self, + tag: &TLVTag, + cb: impl FnOnce(&mut [u8]) -> Result, + ) -> Result<(), Error> { + self.0.utf8_cb(tag, cb) + } +} + +impl<'a, 'b> TLVWrite for TLVWriter<'a, 'b> { + type Position = usize; + + fn write(&mut self, byte: u8) -> Result<(), Error> { + WriteBuf::append(self.0, &[byte]) + } + + fn get_tail(&self) -> Self::Position { + WriteBuf::get_tail(self.0) + } + + fn rewind_to(&mut self, pos: Self::Position) { + WriteBuf::rewind_tail_to(self.0, pos) + } +} + +/// A trait representing a storage where data can be serialized as a TLV stream. +/// by synchronously emitting bytes to the storage. +/// +/// The one method that needs to be implemented by user code is `write`. +/// +/// The trait operates in an append-only manner without requiring access to the serialized +/// TLV data, so it can be implemented with an in-memory storage, or a file storage, or anything +/// that can output a byte to somewhere (like the `Write` Rust traits). +/// +/// With that said, the trait has two additional methods that (optionally) allow for "rewinding" +/// the storage. Implementing these is optional, and they currently exist only for backwards +/// compatibility with code implemented prior to the introduction of this trait. +/// +/// For iterator-style TLV serialization look at the `ToTLVIter` trait. +pub trait TLVWrite { + type Position; + + /// Write a TLV tag and value to the TLV stream. + fn tlv(&mut self, tag: &TLVTag, value: &TLVValue) -> Result<(), Error> { + self.raw_value(tag, value.value_type(), &[])?; + + match value { + TLVValue::Str8l(a) => self.write_raw_data((a.len() as u8).to_le_bytes()), + TLVValue::Str16l(a) => self.write_raw_data((a.len() as u16).to_le_bytes()), + TLVValue::Str32l(a) => self.write_raw_data((a.len() as u32).to_le_bytes()), + TLVValue::Str64l(a) => self.write_raw_data((a.len() as u64).to_le_bytes()), + TLVValue::Utf8l(a) => self.write_raw_data((a.len() as u8).to_le_bytes()), + TLVValue::Utf16l(a) => self.write_raw_data((a.len() as u16).to_le_bytes()), + TLVValue::Utf32l(a) => self.write_raw_data((a.len() as u32).to_le_bytes()), + TLVValue::Utf64l(a) => self.write_raw_data((a.len() as u64).to_le_bytes()), + _ => Ok(()), + }?; + + match value { + TLVValue::S8(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::S16(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::S32(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::S64(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::U8(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::U16(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::U32(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::U64(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::False => Ok(()), + TLVValue::True => Ok(()), + TLVValue::F32(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::F64(a) => self.write_raw_data(a.to_le_bytes()), + TLVValue::Utf8l(a) + | TLVValue::Utf16l(a) + | TLVValue::Utf32l(a) + | TLVValue::Utf64l(a) => self.write_raw_data(a.as_bytes().iter().copied()), + TLVValue::Str8l(a) + | TLVValue::Str16l(a) + | TLVValue::Str32l(a) + | TLVValue::Str64l(a) => self.write_raw_data(a.iter().copied()), + TLVValue::Null + | TLVValue::Struct + | TLVValue::Array + | TLVValue::List + | TLVValue::EndCnt => Ok(()), + } + } + + /// Write a tag and a TLV S8 value to the TLV stream. + fn i8(&mut self, tag: &TLVTag, data: i8) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::S8, &data.to_le_bytes()) + } + + /// Write a tag and a TLV U8 value to the TLV stream. + fn u8(&mut self, tag: &TLVTag, data: u8) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::U8, &data.to_le_bytes()) + } + + /// Write a tag and a TLV S16 or (if the data is small enough) S8 value to the TLV stream. + fn i16(&mut self, tag: &TLVTag, data: i16) -> Result<(), Error> { + if data >= i8::MIN as i16 && data <= i8::MAX as i16 { + self.i8(tag, data as i8) + } else { + self.raw_value(tag, TLVValueType::S16, &data.to_le_bytes()) + } + } + + /// Write a tag and a TLV U16 or (if the data is small enough) U8 value to the TLV stream. + fn u16(&mut self, tag: &TLVTag, data: u16) -> Result<(), Error> { + if data <= u8::MAX as u16 { + self.u8(tag, data as u8) + } else { + self.raw_value(tag, TLVValueType::U16, &data.to_le_bytes()) + } + } + + /// Write a tag and a TLV S32 or (if the data is small enough) S16 or S8 value to the TLV stream. + fn i32(&mut self, tag: &TLVTag, data: i32) -> Result<(), Error> { + if data >= i16::MIN as i32 && data <= i16::MAX as i32 { + self.i16(tag, data as i16) + } else { + self.raw_value(tag, TLVValueType::S32, &data.to_le_bytes()) + } + } + + /// Write a tag and a TLV U32 or (if the data is small enough) U16 or U8 value to the TLV stream. + fn u32(&mut self, tag: &TLVTag, data: u32) -> Result<(), Error> { + if data <= u16::MAX as u32 { + self.u16(tag, data as u16) + } else { + self.raw_value(tag, TLVValueType::U32, &data.to_le_bytes()) + } + } + + /// Write a tag and a TLV S64 or (if the data is small enough) S32, S16, or S8 value to the TLV stream. + fn i64(&mut self, tag: &TLVTag, data: i64) -> Result<(), Error> { + if data >= i32::MIN as i64 && data <= i32::MAX as i64 { + self.i32(tag, data as i32) + } else { + self.raw_value(tag, TLVValueType::S64, &data.to_le_bytes()) + } + } + + /// Write a tag and a TLV U64 or (if the data is small enough) U32, U16, or U8 value to the TLV stream. + fn u64(&mut self, tag: &TLVTag, data: u64) -> Result<(), Error> { + if data <= u32::MAX as u64 { + self.u32(tag, data as u32) + } else { + self.raw_value(tag, TLVValueType::U64, &data.to_le_bytes()) + } + } + + /// Write a tag and a TLV F32 to the TLV stream. + fn f32(&mut self, tag: &TLVTag, data: f32) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::F32, &data.to_le_bytes()) + } + + /// Write a tag and a TLV F64 to the TLV stream. + fn f64(&mut self, tag: &TLVTag, data: f64) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::F64, &data.to_le_bytes()) + } + + /// Write a tag and a TLV Octet String to the TLV stream, where the Octet String is a slice of u8 bytes. + /// + /// The exact octet string type (Str8l, Str16l, Str32l, or Str64l) is chosen based on the length of the data, + /// whereas the smallest type filling the provided data length is chosen. + fn str(&mut self, tag: &TLVTag, data: &[u8]) -> Result<(), Error> { + self.stri(tag, data.len(), data.iter().copied()) + } + + /// Write a tag and a TLV Octet String to the TLV stream, where the Octet String is + /// anything that can be turned into an iterator of u8 bytes. + /// + /// The exact octet string type (Str8l, Str16l, Str32l, or Str64l) is chosen based on the length of the data, + /// whereas the smallest type filling the provided data length is chosen. + /// + /// NOTE: The length of the Octet String must be provided by the user and it must match the + /// number of bytes returned by the provided iterator, or else the generated TLV stream will be invalid. + fn stri(&mut self, tag: &TLVTag, len: usize, data: I) -> Result<(), Error> + where + I: IntoIterator, + { + if len <= u8::MAX as usize { + self.raw_value(tag, TLVValueType::Str8l, &(len as u8).to_le_bytes())?; + } else if len <= u16::MAX as usize { + self.raw_value(tag, TLVValueType::Str16l, &(len as u16).to_le_bytes())?; + } else if len <= u32::MAX as usize { + self.raw_value(tag, TLVValueType::Str32l, &(len as u32).to_le_bytes())?; + } else { + self.raw_value(tag, TLVValueType::Str64l, &(len as u64).to_le_bytes())?; + } + + self.write_raw_data(data) + } + + /// Write a tag and a TLV UTF-8 String to the TLV stream, where the UTF-8 String is a str. + /// + /// The exact UTF-8 string type (Utf8l, Utf16l, Utf32l, or Utf64l) is chosen based on the length of the data, + /// whereas the smallest type filling the provided data length is chosen. + fn utf8(&mut self, tag: &TLVTag, data: &str) -> Result<(), Error> { + self.utf8i(tag, data.len(), data.as_bytes().iter().copied()) + } + + /// Write a tag and a TLV UTF-8 String to the TLV stream, where the UTF-8 String is + /// anything that can be turned into an iterator of u8 bytes. + /// + /// The exact UTF-8 string type (Utf8l, Utf16l, Utf32l, or Utf64l) is chosen based on the length of the data, + /// whereas the smallest type filling the provided data length is chosen. + /// + /// NOTE 1: The length of the UTF-8 String must be provided by the user and it must match the + /// number of bytes returned by the provided iterator, or else the generated TLV stream will be invalid. + /// + /// NOTE 2: The provided iterator must return valid UTF-8 bytes, or else the generated TLV stream will be invalid. + fn utf8i(&mut self, tag: &TLVTag, len: usize, data: I) -> Result<(), Error> + where + I: IntoIterator, + { + if len <= u8::MAX as usize { + self.raw_value(tag, TLVValueType::Utf8l, &(len as u8).to_le_bytes())?; + } else if len <= u16::MAX as usize { + self.raw_value(tag, TLVValueType::Utf16l, &(len as u16).to_le_bytes())?; + } else if len <= u32::MAX as usize { + self.raw_value(tag, TLVValueType::Utf32l, &(len as u32).to_le_bytes())?; + } else { + self.raw_value(tag, TLVValueType::Utf64l, &(len as u64).to_le_bytes())?; + } + + self.write_raw_data(data) + } + + /// Write a tag and a value indicating the start of a Struct TLV container. + /// + /// NOTE: The user must call `end_container` after writing all the Struct fields + /// to close the Struct container or else the generated TLV stream will be invalid. + fn start_struct(&mut self, tag: &TLVTag) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::Struct, &[]) + } + + /// Write a tag and a value indicating the start of an Array TLV container. + /// + /// NOTE: The user must call `end_container` after writing all the Array elements + /// to close the Array container or else the generated TLV stream will be invalid. + fn start_array(&mut self, tag: &TLVTag) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::Array, &[]) + } + + /// Write a tag and a value indicating the start of a List TLV container. + /// + /// NOTE: The user must call `end_container` after writing all the List elements + /// to close the List container or else the generated TLV stream will be invalid. + fn start_list(&mut self, tag: &TLVTag) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::List, &[]) + } + + /// Write a tag and a value indicating the start of a Struct TLV container. + /// + /// NOTE: The user must call `end_container` after writing all the Struct fields + /// to close the Struct container or else the generated TLV stream will be invalid. + fn start_container(&mut self, tag: &TLVTag, container_type: TLVValueType) -> Result<(), Error> { + if !container_type.is_container() { + Err(ErrorCode::TLVTypeMismatch)?; + } + + self.raw_value(tag, container_type, &[]) + } + + /// Write a value indicating the end of a Struct, Array, or List TLV container. + /// + /// NOTE: This method must be called only when the corresponding container has been opened + /// using `start_struct`, `start_array`, or `start_list`, or else the generated TLV stream will be invalid. + fn end_container(&mut self) -> Result<(), Error> { + self.write(TLVControl::new(TLVTagType::Anonymous, TLVValueType::EndCnt).as_raw()) + } + + /// Write a tag and a TLV Null value to the TLV stream. + fn null(&mut self, tag: &TLVTag) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::Null, &[]) + } + + /// Write a tag and a TLV True or False value to the TLV stream. + fn bool(&mut self, tag: &TLVTag, val: bool) -> Result<(), Error> { + self.raw_value( + tag, + if val { + TLVValueType::True + } else { + TLVValueType::False + }, + &[], + ) + } + + /// Write a tag and a raw, already-encoded TLV value represented as a byte slice. + fn raw_value( + &mut self, + tag: &TLVTag, + value_type: TLVValueType, + value_payload: &[u8], + ) -> Result<(), Error> { + self.write(TLVControl::new(tag.tag_type(), value_type).as_raw())?; + + match tag { + TLVTag::Anonymous => Ok(()), + TLVTag::Context(v) => self.write_raw_data(v.to_le_bytes()), + TLVTag::CommonPrf16(v) | TLVTag::ImplPrf16(v) => self.write_raw_data(v.to_le_bytes()), + TLVTag::CommonPrf32(v) | TLVTag::ImplPrf32(v) => self.write_raw_data(v.to_le_bytes()), + TLVTag::FullQual48 { + vendor_id, + profile, + tag, + } => { + self.write_raw_data(vendor_id.to_le_bytes())?; + self.write_raw_data(profile.to_le_bytes())?; + self.write_raw_data(tag.to_le_bytes()) + } + TLVTag::FullQual64 { + vendor_id, + profile, + tag, + } => { + self.write_raw_data(vendor_id.to_le_bytes())?; + self.write_raw_data(profile.to_le_bytes())?; + self.write_raw_data(tag.to_le_bytes()) + } + }?; + + self.write_raw_data(value_payload.iter().copied()) + } + + /// Append multiple raw bytes to the TLV stream. + fn write_raw_data(&mut self, bytes: I) -> Result<(), Error> + where + I: IntoIterator, + { + for byte in bytes { + self.write(byte)?; + } + + Ok(()) + } + + fn write(&mut self, byte: u8) -> Result<(), Error>; + + fn get_tail(&self) -> Self::Position; + + fn rewind_to(&mut self, _pos: Self::Position); +} + +impl TLVWrite for &mut T +where + T: TLVWrite, +{ + type Position = T::Position; + + fn write(&mut self, byte: u8) -> Result<(), Error> { + (**self).write(byte) + } + + fn get_tail(&self) -> Self::Position { + (**self).get_tail() + } + + fn rewind_to(&mut self, pos: Self::Position) { + (**self).rewind_to(pos) + } +} + +impl<'a> TLVWrite for WriteBuf<'a> { + type Position = usize; + + fn write(&mut self, byte: u8) -> Result<(), Error> { + WriteBuf::append(self, &[byte]) + } + + fn get_tail(&self) -> Self::Position { + WriteBuf::get_tail(self) + } + + fn rewind_to(&mut self, pos: Self::Position) { + WriteBuf::rewind_tail_to(self, pos) + } +} + +impl<'a> WriteBuf<'a> { + /// Write a tag and a TLV Octet String to the TLV stream, where the Octet String is a slice of u8 bytes. + /// + /// The writing is done via a user-supplied callback `cb`, that is expected to fill the provided buffer with the data + /// and to return the length of the written data. + /// + /// This method is useful when the data to be written needs to be computed first, and the computation needs a buffer where + /// to operate. + /// + /// Note that this method always uses a Str16l value type to write the data, which restricts the data length to no more than + /// 65535 bytes. + pub fn str_cb( + &mut self, + tag: &TLVTag, + cb: impl FnOnce(&mut [u8]) -> Result, + ) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::Str16l, &0_u16.to_le_bytes())?; + + let value_offset = self.get_tail(); + + let len = self.append_with_buf(cb)?; + + self.buf[value_offset - 2..value_offset].copy_from_slice(&(len as u16).to_le_bytes()); + + Ok(()) + } + + /// Write a tag and a TLV UTF-8 String to the TLV stream, where the UTF-8 String is a str. + /// + /// The writing is done via a user-supplied callback `cb`, that is expected to fill the provided buffer with the data + /// and to return the length of the written data. + /// + /// This method is useful when the data to be written needs to be computed first, and the computation needs a buffer where + /// to operate. + /// + /// Note that this method always uses a Utf16l value type to write the data, which restricts the data length to no more than + /// 65535 bytes. + pub fn utf8_cb( + &mut self, + tag: &TLVTag, + cb: impl FnOnce(&mut [u8]) -> Result, + ) -> Result<(), Error> { + self.raw_value(tag, TLVValueType::Utf16l, &0_u16.to_le_bytes())?; + + let value_offset = self.get_tail(); + + let len = self.append_with_buf(cb)?; + + self.buf[value_offset - 2..value_offset].copy_from_slice(&(len as u16).to_le_bytes()); + + Ok(()) + } +} + +/// A TLVWrite implementation that counts the number of bytes written. +impl TLVWrite for usize { + type Position = usize; + + fn write(&mut self, _byte: u8) -> Result<(), Error> { + *self += 1; + + Ok(()) + } + + fn get_tail(&self) -> Self::Position { + *self + } + + fn rewind_to(&mut self, pos: Self::Position) { + *self = pos; + } +} + +#[cfg(test)] +mod tests { + use core::f32; + + use super::{TLVTag, TLVWrite}; + use crate::{tlv::TLVValue, utils::storage::WriteBuf}; + + #[test] + fn test_write_success() { + let mut buf = [0; 20]; + let mut tw = WriteBuf::new(&mut buf); + + tw.start_struct(&TLVTag::Anonymous).unwrap(); + tw.u8(&TLVTag::Anonymous, 12).unwrap(); + tw.u8(&TLVTag::Context(1), 13).unwrap(); + tw.u16(&TLVTag::Anonymous, 0x1212).unwrap(); + tw.u16(&TLVTag::Context(2), 0x1313).unwrap(); + tw.start_array(&TLVTag::Context(3)).unwrap(); + tw.bool(&TLVTag::Anonymous, true).unwrap(); + tw.end_container().unwrap(); + tw.end_container().unwrap(); + assert_eq!( + buf, + [21, 4, 12, 36, 1, 13, 5, 0x12, 0x012, 37, 2, 0x13, 0x13, 54, 3, 9, 24, 24, 0, 0] + ); + } + + #[test] + fn test_write_overflow() { + let mut buf = [0; 6]; + let mut tw = WriteBuf::new(&mut buf); + + tw.u8(&TLVTag::Anonymous, 12).unwrap(); + tw.u8(&TLVTag::Context(1), 13).unwrap(); + if tw.u16(&TLVTag::Anonymous, 12).is_ok() { + panic!("This should have returned error") + } + if tw.u16(&TLVTag::Context(2), 13).is_ok() { + panic!("This should have returned error") + } + assert_eq!(buf, [4, 12, 36, 1, 13, 4]); + } + + #[test] + fn test_put_str8() { + let mut buf = [0; 20]; + let mut tw = WriteBuf::new(&mut buf); + + tw.u8(&TLVTag::Context(1), 13).unwrap(); + tw.str(&TLVTag::Anonymous, &[10, 11, 12, 13, 14]).unwrap(); + tw.u16(&TLVTag::Context(2), 0x1313).unwrap(); + tw.str(&TLVTag::Context(3), &[20, 21, 22]).unwrap(); + assert_eq!( + buf, + [36, 1, 13, 16, 5, 10, 11, 12, 13, 14, 37, 2, 0x13, 0x13, 48, 3, 3, 20, 21, 22] + ); + } + + #[test] + fn test_matter_spec_examples() { + let mut buf = [0; 200]; + let mut tw = WriteBuf::new(&mut buf); + + // Boolean false + + tw.bool(&TLVTag::Anonymous, false).unwrap(); + assert_eq!(&[0x08], tw.as_slice()); + + // Boolean true + + tw.reset(); + tw.bool(&TLVTag::Anonymous, true).unwrap(); + assert_eq!(&[0x09], tw.as_slice()); + + // Signed Integer, 1-octet, value 42 + + tw.reset(); + tw.i8(&TLVTag::Anonymous, 42).unwrap(); + assert_eq!(&[0x00, 0x2a], tw.as_slice()); + + // Signed Integer, 1-octet, value -17 + + tw.reset(); + tw.i32(&TLVTag::Anonymous, -17).unwrap(); + assert_eq!(&[0x00, 0xef], tw.as_slice()); + + // Unsigned Integer, 1-octet, value 42U + + tw.reset(); + tw.u8(&TLVTag::Anonymous, 42).unwrap(); + assert_eq!(&[0x04, 0x2a], tw.as_slice()); + + // Signed Integer, 2-octet, value 422 + + tw.reset(); + tw.i16(&TLVTag::Anonymous, 422).unwrap(); + assert_eq!(&[0x01, 0xa6, 0x01], tw.as_slice()); + + // Signed Integer, 4-octet, value -170000 + + tw.reset(); + tw.i32(&TLVTag::Anonymous, -170000).unwrap(); + assert_eq!(&[0x02, 0xf0, 0x67, 0xfd, 0xff], tw.as_slice()); + + // Signed Integer, 8-octet, value 40000000000 + + tw.reset(); + tw.i64(&TLVTag::Anonymous, 40000000000).unwrap(); + assert_eq!( + &[0x03, 0x00, 0x90, 0x2f, 0x50, 0x09, 0x00, 0x00, 0x00], + tw.as_slice() + ); + + // UTF-8 String, 1-octet length, "Hello!" + + tw.reset(); + tw.utf8(&TLVTag::Anonymous, "Hello!").unwrap(); + assert_eq!( + &[0x0c, 0x06, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x21], + tw.as_slice() + ); + + // UTF-8 String, 1-octet length, "Tschüs" + + tw.reset(); + tw.utf8i( + &TLVTag::Anonymous, + "Tschüs".len(), + "Tschüs".as_bytes().iter().copied(), + ) + .unwrap(); + assert_eq!( + &[0x0c, 0x07, 0x54, 0x73, 0x63, 0x68, 0xc3, 0xbc, 0x73], + tw.as_slice() + ); + + // Octet String, 1-octet length, octets 00 01 02 03 04 + + tw.reset(); + tw.str(&TLVTag::Anonymous, &[0x00, 0x01, 0x02, 0x03, 0x04]) + .unwrap(); + assert_eq!(&[0x10, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04], tw.as_slice()); + + // Null + + tw.reset(); + tw.tlv(&TLVTag::Anonymous, &TLVValue::Null).unwrap(); + assert_eq!(&[0x14], tw.as_slice()); + + // Single precision floating point 0.0 + + tw.reset(); + tw.tlv(&TLVTag::Anonymous, &TLVValue::F32(0.0)).unwrap(); + assert_eq!(&[0x0a, 0x00, 0x00, 0x00, 0x00], tw.as_slice()); + + // Single precision floating point (1.0 / 3.0) + + tw.reset(); + tw.f32(&TLVTag::Anonymous, 1.0 / 3.0).unwrap(); + assert_eq!(&[0x0a, 0xab, 0xaa, 0xaa, 0x3e], tw.as_slice()); + + // Single precision floating point 17.9 + + tw.reset(); + tw.f32(&TLVTag::Anonymous, 17.9).unwrap(); + assert_eq!(&[0x0a, 0x33, 0x33, 0x8f, 0x41], tw.as_slice()); + + // Single precision floating point infinity + + tw.reset(); + tw.f32(&TLVTag::Anonymous, f32::INFINITY).unwrap(); + assert_eq!(&[0x0a, 0x00, 0x00, 0x80, 0x7f], tw.as_slice()); + + // Single precision floating point negative infinity + + tw.reset(); + tw.f32(&TLVTag::Anonymous, f32::NEG_INFINITY).unwrap(); + assert_eq!(&[0x0a, 0x00, 0x00, 0x80, 0xff], tw.as_slice()); + + // Double precision floating point 0.0 + + tw.reset(); + tw.f64(&TLVTag::Anonymous, 0.0).unwrap(); + assert_eq!( + &[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], + tw.as_slice() + ); + + // Double precision floating point (1.0 / 3.0) + + tw.reset(); + tw.f64(&TLVTag::Anonymous, 1.0 / 3.0).unwrap(); + assert_eq!( + &[0x0b, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0xd5, 0x3f], + tw.as_slice() + ); + + // Double precision floating point 17.9 + + tw.reset(); + tw.f64(&TLVTag::Anonymous, 17.9).unwrap(); + assert_eq!( + &[0x0b, 0x66, 0x66, 0x66, 0x66, 0x66, 0xe6, 0x31, 0x40], + tw.as_slice() + ); + + // Double precision floating point infinity (∞) + + tw.reset(); + tw.f64(&TLVTag::Anonymous, f64::INFINITY).unwrap(); + assert_eq!( + &[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0x7f], + tw.as_slice() + ); + + // Double precision floating point negative infinity + + tw.reset(); + tw.f64(&TLVTag::Anonymous, f64::NEG_INFINITY).unwrap(); + assert_eq!( + &[0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xf0, 0xff], + tw.as_slice() + ); + + // Empty Structure, {} + + tw.reset(); + tw.start_struct(&TLVTag::Anonymous).unwrap(); + tw.end_container().unwrap(); + assert_eq!(&[0x15, 0x18], tw.as_slice()); + + // Empty Array, [] + + tw.reset(); + tw.start_array(&TLVTag::Anonymous).unwrap(); + tw.end_container().unwrap(); + assert_eq!(&[0x16, 0x18], tw.as_slice()); + + // Empty List, [] + + tw.reset(); + tw.start_list(&TLVTag::Anonymous).unwrap(); + tw.end_container().unwrap(); + assert_eq!(&[0x17, 0x18], tw.as_slice()); + + // Structure, two context specific tags, Signed Integer, 1 octet values, {0 = 42, 1 = -17} + + tw.reset(); + tw.start_struct(&TLVTag::Anonymous).unwrap(); + tw.i8(&TLVTag::Context(0), 42).unwrap(); + tw.i32(&TLVTag::Context(1), -17).unwrap(); + tw.end_container().unwrap(); + assert_eq!( + &[0x15, 0x20, 0x00, 0x2a, 0x20, 0x01, 0xef, 0x18], + tw.as_slice() + ); + + // Array, Signed Integer, 1-octet values, [0, 1, 2, 3, 4] + + tw.reset(); + tw.start_array(&TLVTag::Anonymous).unwrap(); + for i in 0..5 { + tw.i8(&TLVTag::Anonymous, i as i8).unwrap(); + } + tw.end_container().unwrap(); + assert_eq!( + &[0x16, 0x00, 0x00, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04, 0x18], + tw.as_slice() + ); + + // List, mix of anonymous and context tags, Signed Integer, 1 octet values, [[1, 0 = 42, 2, 3, 0 = -17]] + + tw.reset(); + tw.start_list(&TLVTag::Anonymous).unwrap(); + tw.i64(&TLVTag::Anonymous, 1).unwrap(); + tw.i16(&TLVTag::Context(0), 42).unwrap(); + tw.i8(&TLVTag::Anonymous, 2).unwrap(); + tw.i8(&TLVTag::Anonymous, 3).unwrap(); + tw.i32(&TLVTag::Context(0), -17).unwrap(); + tw.end_container().unwrap(); + assert_eq!( + &[0x17, 0x00, 0x01, 0x20, 0x00, 0x2a, 0x00, 0x02, 0x00, 0x03, 0x20, 0x00, 0xef, 0x18], + tw.as_slice() + ); + + // Array, mix of element types, [42, -170000, {}, 17.9, "Hello!"] + + tw.reset(); + tw.start_array(&TLVTag::Anonymous).unwrap(); + tw.i64(&TLVTag::Anonymous, 42).unwrap(); + tw.i64(&TLVTag::Anonymous, -170000).unwrap(); + tw.start_struct(&TLVTag::Anonymous).unwrap(); + tw.end_container().unwrap(); + tw.f32(&TLVTag::Anonymous, 17.9).unwrap(); + tw.utf8(&TLVTag::Anonymous, "Hello!").unwrap(); + tw.end_container().unwrap(); + assert_eq!( + &[ + 0x16, 0x00, 0x2a, 0x02, 0xf0, 0x67, 0xfd, 0xff, 0x15, 0x18, 0x0a, 0x33, 0x33, 0x8f, + 0x41, 0x0c, 0x06, 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x21, 0x18, + ], + tw.as_slice() + ); + + // Anonymous tag, Unsigned Integer, 1-octet value, 42U + + tw.reset(); + tw.u64(&TLVTag::Anonymous, 42).unwrap(); + assert_eq!(&[0x04, 0x2a], tw.as_slice()); + + // Context tag 1, Unsigned Integer, 1-octet value, 1 = 42U + + tw.reset(); + tw.u64(&TLVTag::Context(1), 42).unwrap(); + assert_eq!(&[0x24, 0x01, 0x2a], tw.as_slice()); + + // Common profile tag 1, Unsigned Integer, 1-octet value, Matter::1 = 42U + + tw.reset(); + tw.u64(&TLVTag::CommonPrf16(1), 42).unwrap(); + assert_eq!(&[0x44, 0x01, 0x00, 0x2a], tw.as_slice()); + + // Common profile tag 100000, Unsigned Integer, 1-octet value, Matter::100000 = 42U + + tw.reset(); + tw.u64(&TLVTag::CommonPrf32(100000), 42).unwrap(); + assert_eq!(&[0x64, 0xa0, 0x86, 0x01, 0x00, 0x2a], tw.as_slice()); + + // Fully qualified tag, Vendor ID 0xFFF1/65521, pro­file number 0xDEED/57069, + // 2-octet tag 1, Unsigned Integer, 1-octet value 42, 65521::57069:1 = 42U + + tw.reset(); + tw.u64( + &TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 1, + }, + 42, + ) + .unwrap(); + assert_eq!( + &[0xc4, 0xf1, 0xff, 0xed, 0xde, 0x01, 0x00, 0x2a], + tw.as_slice() + ); + + // Fully qualified tag, Vendor ID 0xFFF1/65521, pro­file number 0xDEED/57069, + // 4-octet tag 0xAA55FEED/2857762541, Unsigned Integer, 1-octet value 42, 65521::57069:2857762541 = 42U + + tw.reset(); + tw.u64( + &TLVTag::FullQual64 { + vendor_id: 65521, + profile: 57069, + tag: 2857762541, + }, + 42, + ) + .unwrap(); + assert_eq!( + &[0xe4, 0xf1, 0xff, 0xed, 0xde, 0xed, 0xfe, 0x55, 0xaa, 0x2a], + tw.as_slice() + ); + + // Structure with the fully qualified tag, Vendor ID 0xFFF1/65521, profile number 0xDEED/57069, + // 2-octet tag 1. The structure contains a single ele­ment labeled using a fully qualified tag under + // the same profile, with 2-octet tag 0xAA55/43605. 65521::57069:1 = {65521::57069:43605 = 42U} + + tw.reset(); + tw.start_struct(&TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 1, + }) + .unwrap(); + tw.u64( + &TLVTag::FullQual48 { + vendor_id: 65521, + profile: 57069, + tag: 43605, + }, + 42, + ) + .unwrap(); + tw.end_container().unwrap(); + assert_eq!( + &[ + 0xd5, 0xf1, 0xff, 0xed, 0xde, 0x01, 0x00, 0xc4, 0xf1, 0xff, 0xed, 0xde, 0x55, 0xaa, + 0x2a, 0x18, + ], + tw.as_slice() + ); + } +} diff --git a/rs-matter/src/tlv/writer.rs b/rs-matter/src/tlv/writer.rs deleted file mode 100644 index e6af4efe..00000000 --- a/rs-matter/src/tlv/writer.rs +++ /dev/null @@ -1,356 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use log::error; - -use crate::error::*; -use crate::utils::storage::WriteBuf; - -use super::{TagType, TAG_SHIFT_BITS, TAG_SIZE_MAP}; - -#[allow(dead_code)] -enum WriteElementType { - S8 = 0, - S16 = 1, - S32 = 2, - S64 = 3, - U8 = 4, - U16 = 5, - U32 = 6, - U64 = 7, - False = 8, - True = 9, - F32 = 10, - F64 = 11, - Utf8l = 12, - Utf16l = 13, - Utf32l = 14, - Utf64l = 15, - Str8l = 16, - Str16l = 17, - Str32l = 18, - Str64l = 19, - Null = 20, - Struct = 21, - Array = 22, - List = 23, - EndCnt = 24, - Last, -} - -pub struct TLVWriter<'a, 'b> { - buf: &'a mut WriteBuf<'b>, -} - -impl<'a, 'b> TLVWriter<'a, 'b> { - pub fn new(buf: &'a mut WriteBuf<'b>) -> Self { - TLVWriter { buf } - } - - // TODO: The current method of using writebuf's put methods force us to do - // at max 3 checks while writing a single TLV (once for control, once for tag, - // once for value), so do a single check and write the whole thing. - #[inline(always)] - fn put_control_tag( - &mut self, - tag_type: TagType, - val_type: WriteElementType, - ) -> Result<(), Error> { - let (tag_id, tag_val) = match tag_type { - TagType::Anonymous => (0_u8, 0), - TagType::Context(v) => (1, v as u64), - TagType::CommonPrf16(v) => (2, v as u64), - TagType::CommonPrf32(v) => (3, v as u64), - TagType::ImplPrf16(v) => (4, v as u64), - TagType::ImplPrf32(v) => (5, v as u64), - TagType::FullQual48(v) => (6, v), - TagType::FullQual64(v) => (7, v), - }; - self.buf - .le_u8(((tag_id) << TAG_SHIFT_BITS) | (val_type as u8))?; - if tag_type != TagType::Anonymous { - self.buf.le_uint(TAG_SIZE_MAP[tag_id as usize], tag_val)?; - } - Ok(()) - } - - pub fn i8(&mut self, tag_type: TagType, data: i8) -> Result<(), Error> { - self.put_control_tag(tag_type, WriteElementType::S8)?; - self.buf.le_i8(data) - } - - pub fn u8(&mut self, tag_type: TagType, data: u8) -> Result<(), Error> { - self.put_control_tag(tag_type, WriteElementType::U8)?; - self.buf.le_u8(data) - } - - pub fn i16(&mut self, tag_type: TagType, data: i16) -> Result<(), Error> { - if data >= i8::MIN as i16 && data <= i8::MAX as i16 { - self.i8(tag_type, data as i8) - } else { - self.put_control_tag(tag_type, WriteElementType::S16)?; - self.buf.le_i16(data) - } - } - - pub fn u16(&mut self, tag_type: TagType, data: u16) -> Result<(), Error> { - if data <= 0xff { - self.u8(tag_type, data as u8) - } else { - self.put_control_tag(tag_type, WriteElementType::U16)?; - self.buf.le_u16(data) - } - } - - pub fn i32(&mut self, tag_type: TagType, data: i32) -> Result<(), Error> { - if data >= i8::MIN as i32 && data <= i8::MAX as i32 { - self.i8(tag_type, data as i8) - } else if data >= i16::MIN as i32 && data <= i16::MAX as i32 { - self.i16(tag_type, data as i16) - } else { - self.put_control_tag(tag_type, WriteElementType::S32)?; - self.buf.le_i32(data) - } - } - - pub fn u32(&mut self, tag_type: TagType, data: u32) -> Result<(), Error> { - if data <= 0xff { - self.u8(tag_type, data as u8) - } else if data <= 0xffff { - self.u16(tag_type, data as u16) - } else { - self.put_control_tag(tag_type, WriteElementType::U32)?; - self.buf.le_u32(data) - } - } - - pub fn i64(&mut self, tag_type: TagType, data: i64) -> Result<(), Error> { - if data >= i8::MIN as i64 && data <= i8::MAX as i64 { - self.i8(tag_type, data as i8) - } else if data >= i16::MIN as i64 && data <= i16::MAX as i64 { - self.i16(tag_type, data as i16) - } else if data >= i32::MIN as i64 && data <= i32::MAX as i64 { - self.i32(tag_type, data as i32) - } else { - self.put_control_tag(tag_type, WriteElementType::S64)?; - self.buf.le_i64(data) - } - } - - pub fn u64(&mut self, tag_type: TagType, data: u64) -> Result<(), Error> { - if data <= 0xff { - self.u8(tag_type, data as u8) - } else if data <= 0xffff { - self.u16(tag_type, data as u16) - } else if data <= 0xffffffff { - self.u32(tag_type, data as u32) - } else { - self.put_control_tag(tag_type, WriteElementType::U64)?; - self.buf.le_u64(data) - } - } - - pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { - if data.len() > 256 { - error!("use str16() instead"); - return Err(ErrorCode::Invalid.into()); - } - self.put_control_tag(tag_type, WriteElementType::Str8l)?; - self.buf.le_u8(data.len() as u8)?; - self.buf.copy_from_slice(data) - } - - pub fn str16(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { - if data.len() <= 0xff { - self.str8(tag_type, data) - } else { - self.put_control_tag(tag_type, WriteElementType::Str16l)?; - self.buf.le_u16(data.len() as u16)?; - self.buf.copy_from_slice(data) - } - } - - // This is quite hacky - pub fn str16_as(&mut self, tag_type: TagType, data_gen: F) -> Result<(), Error> - where - F: FnOnce(&mut [u8]) -> Result, - { - let anchor = self.buf.get_tail(); - self.put_control_tag(tag_type, WriteElementType::Str16l)?; - - let wb = self.buf.empty_as_mut_slice(); - // Reserve 2 spaces for the control and length - let str = &mut wb[2..]; - let len = data_gen(str).unwrap_or_default(); - if len <= 0xff { - // Shift everything by 1 - let str = &mut wb[1..]; - for i in 0..len { - str[i] = str[i + 1]; - } - self.buf.rewind_tail_to(anchor); - self.put_control_tag(tag_type, WriteElementType::Str8l)?; - self.buf.le_u8(len as u8)?; - } else { - self.buf.le_u16(len as u16)?; - } - self.buf.forward_tail_by(len); - Ok(()) - } - - pub fn utf8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { - self.put_control_tag(tag_type, WriteElementType::Utf8l)?; - self.buf.le_u8(data.len() as u8)?; - self.buf.copy_from_slice(data) - } - - pub fn utf16(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { - if data.len() <= 0xff { - self.utf8(tag_type, data) - } else { - self.put_control_tag(tag_type, WriteElementType::Utf16l)?; - self.buf.le_u16(data.len() as u16)?; - self.buf.copy_from_slice(data) - } - } - - fn no_val(&mut self, tag_type: TagType, element: WriteElementType) -> Result<(), Error> { - self.put_control_tag(tag_type, element) - } - - pub fn start_struct(&mut self, tag_type: TagType) -> Result<(), Error> { - self.no_val(tag_type, WriteElementType::Struct) - } - - pub fn start_array(&mut self, tag_type: TagType) -> Result<(), Error> { - self.no_val(tag_type, WriteElementType::Array) - } - - pub fn start_list(&mut self, tag_type: TagType) -> Result<(), Error> { - self.no_val(tag_type, WriteElementType::List) - } - - pub fn end_container(&mut self) -> Result<(), Error> { - self.no_val(TagType::Anonymous, WriteElementType::EndCnt) - } - - pub fn null(&mut self, tag_type: TagType) -> Result<(), Error> { - self.no_val(tag_type, WriteElementType::Null) - } - - pub fn bool(&mut self, tag_type: TagType, val: bool) -> Result<(), Error> { - if val { - self.no_val(tag_type, WriteElementType::True) - } else { - self.no_val(tag_type, WriteElementType::False) - } - } - - pub fn get_tail(&self) -> usize { - self.buf.get_tail() - } - - pub fn rewind_to(&mut self, anchor: usize) { - self.buf.rewind_tail_to(anchor); - } - - pub fn get_buf(&mut self) -> &mut WriteBuf<'b> { - self.buf - } -} - -#[cfg(test)] -mod tests { - use super::{TLVWriter, TagType}; - use crate::utils::storage::WriteBuf; - - #[test] - fn test_write_success() { - let mut buf = [0; 20]; - let mut writebuf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writebuf); - - tw.start_struct(TagType::Anonymous).unwrap(); - tw.u8(TagType::Anonymous, 12).unwrap(); - tw.u8(TagType::Context(1), 13).unwrap(); - tw.u16(TagType::Anonymous, 0x1212).unwrap(); - tw.u16(TagType::Context(2), 0x1313).unwrap(); - tw.start_array(TagType::Context(3)).unwrap(); - tw.bool(TagType::Anonymous, true).unwrap(); - tw.end_container().unwrap(); - tw.end_container().unwrap(); - assert_eq!( - buf, - [21, 4, 12, 36, 1, 13, 5, 0x12, 0x012, 37, 2, 0x13, 0x13, 54, 3, 9, 24, 24, 0, 0] - ); - } - - #[test] - fn test_write_overflow() { - let mut buf = [0; 6]; - let mut writebuf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writebuf); - - tw.u8(TagType::Anonymous, 12).unwrap(); - tw.u8(TagType::Context(1), 13).unwrap(); - if tw.u16(TagType::Anonymous, 12).is_ok() { - panic!("This should have returned error") - } - if tw.u16(TagType::Context(2), 13).is_ok() { - panic!("This should have returned error") - } - assert_eq!(buf, [4, 12, 36, 1, 13, 4]); - } - - #[test] - fn test_put_str8() { - let mut buf = [0; 20]; - let mut writebuf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writebuf); - - tw.u8(TagType::Context(1), 13).unwrap(); - tw.str8(TagType::Anonymous, &[10, 11, 12, 13, 14]).unwrap(); - tw.u16(TagType::Context(2), 0x1313).unwrap(); - tw.str8(TagType::Context(3), &[20, 21, 22]).unwrap(); - assert_eq!( - buf, - [36, 1, 13, 16, 5, 10, 11, 12, 13, 14, 37, 2, 0x13, 0x13, 48, 3, 3, 20, 21, 22] - ); - } - - #[test] - fn test_put_str16_as() { - let mut buf = [0; 20]; - let mut writebuf = WriteBuf::new(&mut buf); - let mut tw = TLVWriter::new(&mut writebuf); - - tw.u8(TagType::Context(1), 13).unwrap(); - tw.str8(TagType::Context(2), &[10, 11, 12, 13, 14]).unwrap(); - tw.str16_as(TagType::Context(3), |buf| { - buf[0] = 10; - buf[1] = 11; - Ok(2) - }) - .unwrap(); - tw.u8(TagType::Context(4), 13).unwrap(); - - assert_eq!( - buf, - [36, 1, 13, 48, 2, 5, 10, 11, 12, 13, 14, 48, 3, 2, 10, 11, 36, 4, 13, 0] - ); - } -} diff --git a/rs-matter/src/transport/core.rs b/rs-matter/src/transport/core.rs index b1e386bd..04c0a7a1 100644 --- a/rs-matter/src/transport/core.rs +++ b/rs-matter/src/transport/core.rs @@ -30,7 +30,7 @@ use crate::error::{Error, ErrorCode}; use crate::mdns::{MdnsImpl, MdnsService}; use crate::secure_channel::common::{sc_write, OpCode, SCStatusCodes, PROTO_ID_SECURE_CHANNEL}; use crate::secure_channel::status_report::StatusReport; -use crate::tlv::TLVList; +use crate::tlv::TLVElement; use crate::utils::cell::RefCell; use crate::utils::epoch::Epoch; use crate::utils::init::{init, Init}; @@ -1122,7 +1122,7 @@ impl Packet { write!( f, "; TLV:\n----------------\n{}\n----------------\n", - TLVList::new(buf) + TLVElement::new(buf) )?; } else { write!( diff --git a/rs-matter/src/utils/init.rs b/rs-matter/src/utils/init.rs index e3bedfc1..68b12c7d 100644 --- a/rs-matter/src/utils/init.rs +++ b/rs-matter/src/utils/init.rs @@ -21,6 +21,14 @@ use core::{cell::UnsafeCell, mem::MaybeUninit}; /// Re-export `pinned-init` because its API is very unstable currently (0.0.x) pub use pinned_init::*; +/// Convert a closure returning `Result, E>` into an `Init`. +pub fn into_init>(f: F) -> impl Init +where + F: FnOnce() -> Result, +{ + unsafe { init_from_closure(move |slot| f()?.__init(slot)) } +} + /// An extension trait for converting `Init` to a fallible `Init`. /// Useful when chaining an infallible initializer with a fallible chained initialization function. pub trait IntoFallibleInit: Init { @@ -77,7 +85,16 @@ pub trait InitMaybeUninit { self.try_init_with(init).unwrap() } + /// Try to initialize Self with the given fallible in-place initializer. fn try_init_with, E>(&mut self, init: I) -> Result<&mut T, E>; + + /// Initialize Self with all-zeroes + fn init_zeroed(&mut self) -> &mut T + where + T: Zeroable, + { + self.init_with(pinned_init::zeroed()) + } } impl InitMaybeUninit for MaybeUninit { diff --git a/rs-matter/src/utils/iter.rs b/rs-matter/src/utils/iter.rs new file mode 100644 index 00000000..0dfe7159 --- /dev/null +++ b/rs-matter/src/utils/iter.rs @@ -0,0 +1,48 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/// An extension trait for `Iterator` implementing several utility methods. +pub trait TryFindIterator: Iterator> + Sized { + /// Find the first element that satisfies the supplied `predicate`. + /// + /// Method name is `do_try_find` to avoid collissions with `Iterator::try_find` + /// once it gets stabilized. + fn do_try_find

(self, mut predicate: P) -> Result, E> + where + P: FnMut(&T) -> Result, + { + for val in self { + let val = val?; + + let result = predicate(&val); + match result { + Ok(matches) => { + if matches { + return Ok(Some(val)); + } + } + Err(err) => { + return Err(err); + } + } + } + + Ok(None) + } +} + +impl TryFindIterator for I where I: Iterator> {} diff --git a/rs-matter/src/utils/mod.rs b/rs-matter/src/utils/mod.rs index da766956..2b06762a 100644 --- a/rs-matter/src/utils/mod.rs +++ b/rs-matter/src/utils/mod.rs @@ -18,6 +18,7 @@ pub mod cell; pub mod epoch; pub mod init; +pub mod iter; pub mod maybe; pub mod rand; pub mod select; diff --git a/rs-matter/src/utils/storage/writebuf.rs b/rs-matter/src/utils/storage/writebuf.rs index 0dcb873e..3ecd78ec 100644 --- a/rs-matter/src/utils/storage/writebuf.rs +++ b/rs-matter/src/utils/storage/writebuf.rs @@ -20,7 +20,7 @@ use byteorder::{ByteOrder, LittleEndian}; #[derive(Debug)] pub struct WriteBuf<'a> { - buf: &'a mut [u8], + pub(crate) buf: &'a mut [u8], buf_size: usize, start: usize, end: usize, @@ -137,6 +137,16 @@ impl<'a> WriteBuf<'a> { }) } + pub fn append_with_buf(&mut self, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> Result, + { + let len = f(self.empty_as_mut_slice())?; + self.end += len; + + Ok(len) + } + pub fn append_with(&mut self, size: usize, f: F) -> Result<(), Error> where F: FnOnce(&mut Self), diff --git a/rs-matter/tests/common/commands.rs b/rs-matter/tests/common/commands.rs deleted file mode 100644 index abd27ba2..00000000 --- a/rs-matter/tests/common/commands.rs +++ /dev/null @@ -1,98 +0,0 @@ -/* - * - * Copyright (c) 2023 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use rs_matter::{ - data_model::objects::EncodeValue, - interaction_model::{ - messages::ib::{CmdPath, CmdStatus, InvResp}, - messages::msg, - }, -}; - -pub enum ExpectedInvResp { - Cmd(CmdPath, u8), - Status(CmdStatus), -} - -pub fn assert_inv_response(resp: &msg::InvResp, expected: &[ExpectedInvResp]) { - let mut index = 0; - for inv_response in resp.inv_responses.as_ref().unwrap().iter() { - println!("Validating index {}", index); - match &expected[index] { - ExpectedInvResp::Cmd(e_c, e_d) => match inv_response { - InvResp::Cmd(c) => { - assert_eq!(e_c, &c.path); - match c.data { - EncodeValue::Tlv(t) => { - assert_eq!(*e_d, t.find_tag(0).unwrap().u8().unwrap()) - } - _ => panic!("Incorrect CmdDataType"), - } - } - _ => { - panic!("Invalid response, expected InvResponse::Cmd"); - } - }, - ExpectedInvResp::Status(e_status) => match inv_response { - InvResp::Status(status) => { - assert_eq!(e_status, &status); - } - _ => { - panic!("Invalid response, expected InvResponse::Status"); - } - }, - } - println!("Index {} success", index); - index += 1; - } - assert_eq!(index, expected.len()); -} - -#[macro_export] -macro_rules! cmd_data { - ($path:expr, $data:literal) => { - CmdData::new($path, EncodeValue::Value(&($data as u32))) - }; -} - -#[macro_export] -macro_rules! echo_req { - ($endpoint:literal, $data:literal) => { - CmdData::new( - CmdPath::new( - Some($endpoint), - Some(echo_cluster::ID), - Some(echo_cluster::Commands::EchoReq as u32), - ), - EncodeValue::Value(&($data as u32)), - ) - }; -} - -#[macro_export] -macro_rules! echo_resp { - ($endpoint:literal, $data:literal) => { - ExpectedInvResp::Cmd( - CmdPath::new( - Some($endpoint), - Some(echo_cluster::ID), - Some(echo_cluster::RespCommands::EchoResp as u32), - ), - $data, - ) - }; -} diff --git a/rs-matter/tests/common/e2e.rs b/rs-matter/tests/common/e2e.rs new file mode 100644 index 00000000..e0936231 --- /dev/null +++ b/rs-matter/tests/common/e2e.rs @@ -0,0 +1,300 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use core::num::NonZeroU8; + +use embassy_futures::select::select3; +use embassy_sync::{ + blocking_mutex::raw::NoopRawMutex, + zerocopy_channel::{Channel, Receiver, Sender}, +}; + +use rs_matter::acl::{AclEntry, AuthMode}; +use rs_matter::data_model::cluster_basic_information::BasicInfoConfig; +use rs_matter::data_model::core::{DataModel, IMBuffer}; +use rs_matter::data_model::objects::{AsyncHandler, AsyncMetadata, Privilege}; +use rs_matter::data_model::sdm::dev_att::{DataType, DevAttDataFetcher}; +use rs_matter::data_model::subscriptions::Subscriptions; +use rs_matter::error::Error; +use rs_matter::mdns::MdnsService; +use rs_matter::respond::Responder; +use rs_matter::transport::exchange::Exchange; +use rs_matter::transport::network::{ + Address, NetworkReceive, NetworkSend, MAX_RX_PACKET_SIZE, MAX_TX_PACKET_SIZE, +}; +use rs_matter::transport::session::{NocCatIds, ReservedSession, SessionMode}; +use rs_matter::utils::select::Coalesce; +use rs_matter::utils::storage::pooled::PooledBuffers; +use rs_matter::{Matter, MATTER_PORT}; + +pub mod im; +pub mod test; +pub mod tlv; + +// For backwards compatibility +pub type ImEngine = E2eRunner; + +// For backwards compatibility +pub const IM_ENGINE_PEER_ID: u64 = E2eRunner::PEER_ID; + +/// A test runner for end-to-end tests. +/// +/// The runner works by instantiating two `Matter` instances, one for the local node and one for the +/// remote node which is being tested. The instances are connected over a fake UDP network. +/// +/// The runner then pre-set a single session between the two nodes and runs all tests in the context +/// of a single exchange per test run. +/// +/// All transport-related state is reset between test runs. +pub struct E2eRunner { + pub matter: Matter<'static>, + matter_client: Matter<'static>, + buffers: PooledBuffers<10, NoopRawMutex, IMBuffer>, + subscriptions: Subscriptions<1>, + cat_ids: NocCatIds, +} + +impl E2eRunner { + const ADDR: Address = Address::Udp(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0))); + + const BASIC_INFO: BasicInfoConfig<'static> = BasicInfoConfig { + vid: 1, + pid: 1, + hw_ver: 1, + sw_ver: 1, + sw_ver_str: "1", + serial_no: "E2E", + device_name: "E2E", + product_name: "E2E", + vendor_name: "E2E", + }; + + /// The ID of the local Matter instance + pub const PEER_ID: u64 = 445566; + + /// The ID of the remote (tested) Matter instance + pub const REMOTE_PEER_ID: u64 = 123456; + + /// Create a new runner with default category IDs. + pub fn new_default() -> Self { + Self::new(NocCatIds::default()) + } + + /// Create a new runner with the given category IDs. + pub fn new(cat_ids: NocCatIds) -> Self { + Self { + matter: Self::new_matter(), + matter_client: Self::new_matter(), + buffers: PooledBuffers::new(0), + subscriptions: Subscriptions::new(), + cat_ids, + } + } + + /// Initialize the local and remote (tested) Matter instances + /// that the runner owns + pub fn init(&self) -> Result<(), Error> { + Self::init_matter( + &self.matter, + Self::REMOTE_PEER_ID, + Self::PEER_ID, + &self.cat_ids, + )?; + + Self::init_matter( + &self.matter_client, + Self::PEER_ID, + Self::REMOTE_PEER_ID, + &self.cat_ids, + ) + } + + /// Get the Matter instance for the local node (the test driver). + pub fn matter_client(&self) -> &Matter<'static> { + &self.matter_client + } + + /// Add a default ACL entry to the remote (tested) Matter instance. + pub fn add_default_acl(&self) { + // Only allow the standard peer node id of the IM Engine + let mut default_acl = + AclEntry::new(NonZeroU8::new(1).unwrap(), Privilege::ADMIN, AuthMode::Case); + default_acl.add_subject(Self::PEER_ID).unwrap(); + self.matter.acl_mgr.borrow_mut().add(default_acl).unwrap(); + } + + /// Initiates a new exchange on the local Matter instance + pub async fn initiate_exchange(&self) -> Result, Error> { + Exchange::initiate( + self.matter_client(), + 1, /*just one fabric in tests*/ + Self::REMOTE_PEER_ID, + true, + ) + .await + } + + /// Runs both the local and the remote (tested) Matter instances, + /// by connecting them with a fake UDP network. + /// + /// The remote (tested) Matter instance will run with the provided DM handler. + /// + /// The local Matter instance does not have a DM handler as it is only used to + /// drive the tests (i.e. it does not have any server clusters and such). + pub async fn run(&self, handler: H) -> Result<(), Error> + where + H: AsyncHandler + AsyncMetadata, + { + self.init()?; + + let mut buf1 = [heapless::Vec::new(); 1]; + let mut buf2 = [heapless::Vec::new(); 1]; + + let mut pipe1 = NetworkPipe::::new(&mut buf1); + let mut pipe2 = NetworkPipe::::new(&mut buf2); + + let (send_remote, recv_local) = pipe1.split(); + let (send_local, recv_remote) = pipe2.split(); + + let matter_client = &self.matter_client; + + let responder = Responder::new( + "Default", + DataModel::new(&self.buffers, &self.subscriptions, handler), + &self.matter, + 0, + ); + + select3( + matter_client + .transport_mgr + .run(NetworkSendImpl(send_local), NetworkReceiveImpl(recv_local)), + self.matter.transport_mgr.run( + NetworkSendImpl(send_remote), + NetworkReceiveImpl(recv_remote), + ), + responder.run::<4>(), + ) + .coalesce() + .await + } + + fn new_matter() -> Matter<'static> { + #[cfg(feature = "std")] + use rs_matter::utils::epoch::sys_epoch as epoch; + + #[cfg(not(feature = "std"))] + use rs_matter::utils::epoch::dummy_epoch as epoch; + + #[cfg(feature = "std")] + use rs_matter::utils::rand::sys_rand as rand; + + #[cfg(not(feature = "std"))] + use rs_matter::utils::rand::dummy_rand as rand; + + let matter = Matter::new( + &Self::BASIC_INFO, + &E2eDummyDevAtt, + MdnsService::Disabled, + epoch, + rand, + MATTER_PORT, + ); + + matter.initialize_transport_buffers().unwrap(); + + matter + } + + fn init_matter( + matter: &Matter, + local_nodeid: u64, + remote_nodeid: u64, + cat_ids: &NocCatIds, + ) -> Result<(), Error> { + matter.transport_mgr.reset()?; + + let mut session = ReservedSession::reserve_now(matter)?; + + session.update( + local_nodeid, + remote_nodeid, + 1, + 1, + Self::ADDR, + SessionMode::Case { + fab_idx: NonZeroU8::new(1).unwrap(), + cat_ids: *cat_ids, + }, + None, + None, + None, + )?; + + session.complete(); + + Ok(()) + } +} + +/// A dummy device attribute data fetcher that always returns the same hard-coded test data. +struct E2eDummyDevAtt; + +impl DevAttDataFetcher for E2eDummyDevAtt { + fn get_devatt_data(&self, _data_type: DataType, _data: &mut [u8]) -> Result { + Ok(2) + } +} + +type NetworkPipe<'a, const N: usize> = Channel<'a, NoopRawMutex, heapless::Vec>; + +struct NetworkReceiveImpl<'a, const N: usize>(Receiver<'a, NoopRawMutex, heapless::Vec>); + +impl<'a, const N: usize> NetworkSend for NetworkSendImpl<'a, N> { + async fn send_to(&mut self, data: &[u8], _addr: Address) -> Result<(), Error> { + let vec = self.0.send().await; + + vec.clear(); + vec.extend_from_slice(data).unwrap(); + + self.0.send_done(); + + Ok(()) + } +} + +struct NetworkSendImpl<'a, const N: usize>(Sender<'a, NoopRawMutex, heapless::Vec>); + +impl<'a, const N: usize> NetworkReceive for NetworkReceiveImpl<'a, N> { + async fn wait_available(&mut self) -> Result<(), Error> { + self.0.receive().await; + + Ok(()) + } + + async fn recv_from(&mut self, buffer: &mut [u8]) -> Result<(usize, Address), Error> { + let vec = self.0.receive().await; + + buffer[..vec.len()].copy_from_slice(vec); + let len = vec.len(); + + self.0.receive_done(); + + Ok((len, E2eRunner::ADDR)) + } +} diff --git a/rs-matter/tests/common/e2e/im.rs b/rs-matter/tests/common/e2e/im.rs new file mode 100644 index 00000000..924f2d81 --- /dev/null +++ b/rs-matter/tests/common/e2e/im.rs @@ -0,0 +1,668 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use bitflags::bitflags; + +use rs_matter::error::Error; +use rs_matter::interaction_model::core::{OpCode, PROTO_ID_INTERACTION_MODEL}; +use rs_matter::interaction_model::messages::ib::{ + AttrPath, AttrResp, AttrStatus, DataVersionFilter, EventFilter, EventPath, +}; +use rs_matter::interaction_model::messages::msg::{ReportDataMsg, WriteReqTag}; +use rs_matter::tlv::{FromTLV, Slice, TLVElement, TLVTag, TLVWrite, TLVWriter, ToTLV}; +use rs_matter::transport::exchange::MessageMeta; +use rs_matter::utils::storage::WriteBuf; + +use super::tlv::{TLVTest, TestToTLV}; + +use attributes::{TestAttrData, TestAttrResp}; +use commands::{TestCmdData, TestCmdResp}; + +pub mod attributes; +pub mod commands; +pub mod echo_cluster; +pub mod handler; + +/// A `ReadReq` alternative more suitable for testing. +/// +/// Unlike `ReadReq`, `TestReadReq` uses regular Rust slices where +/// `ReadReq` uses `TLVArray` instances. +#[derive(Debug, Default, Clone, ToTLV)] +pub struct TestReadReq<'a> { + pub attr_requests: Option<&'a [AttrPath]>, + pub event_requests: Option<&'a [EventPath]>, + pub event_filters: Option<&'a [EventFilter]>, + pub fabric_filtered: bool, + pub dataver_filters: Option<&'a [DataVersionFilter]>, +} + +impl<'a> TestReadReq<'a> { + /// Create a new `TestReadReq` instance. + pub const fn new() -> Self { + Self { + attr_requests: None, + event_requests: None, + event_filters: None, + fabric_filtered: false, + dataver_filters: None, + } + } + + /// Create a new `TestReadReq` instance with the provided attribute requests. + pub const fn reqs(reqs: &'a [AttrPath]) -> Self { + Self { + attr_requests: Some(reqs), + ..Self::new() + } + } +} + +/// A `ReadResp` alternative more suitable for testing. +/// +/// Unlike `ReadResp`, `TestReadResp` uses regular Rust slices where +/// `ReadResp` uses `TLVArray` instances. Also, it utilizes `TestReadData` +/// for the write requests, where `ReadResp` uses `AttrData` instances. +#[derive(Debug, Default, Clone)] +pub struct TestWriteReq<'a> { + pub suppress_response: Option, + pub timed_request: Option, + pub write_requests: &'a [TestAttrData<'a>], + pub more_chunked: Option, +} + +impl<'a> TestWriteReq<'a> { + /// Create a new `TestWriteReq` instance. + pub const fn new() -> Self { + Self { + suppress_response: None, + timed_request: None, + write_requests: &[], + more_chunked: None, + } + } + + /// Create a new `TestWriteReq` instance with the provided write requests. + pub const fn reqs(reqs: &'a [TestAttrData<'a>]) -> Self { + Self { + write_requests: reqs, + suppress_response: Some(true), + ..Self::new() + } + } +} + +impl TestToTLV for TestWriteReq<'_> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + if let Some(supress_response) = self.suppress_response { + tw.bool( + &TLVTag::Context(WriteReqTag::SuppressResponse as _), + supress_response, + )?; + } + + if let Some(timed_request) = self.timed_request { + tw.bool( + &TLVTag::Context(WriteReqTag::TimedRequest as _), + timed_request, + )?; + } + + tw.start_array(&TLVTag::Context(WriteReqTag::WriteRequests as _))?; + for write_request in self.write_requests { + write_request.test_to_tlv(&TLVTag::Anonymous, tw)?; + } + tw.end_container()?; + + if let Some(more_chunked) = self.more_chunked { + tw.bool( + &TLVTag::Context(WriteReqTag::MoreChunked as _), + more_chunked, + )?; + } + + tw.end_container()?; + + Ok(()) + } +} + +/// A `WriteResp` alternative more suitable for testing. +/// +/// Unlike `WriteResp`, `TestWriteResp` uses regular Rust slices where +/// `WriteResp` uses `TLVArray` instances. +#[derive(ToTLV, Debug, Default, Clone)] +#[tlvargs(lifetime = "'a")] +pub struct TestWriteResp<'a> { + pub write_responses: Slice<'a, AttrStatus>, +} + +impl<'a> TestWriteResp<'a> { + /// Create a new `TestWriteResp` instance with the provided write responses. + pub const fn resp(write_responses: &'a [AttrStatus]) -> Self { + Self { write_responses } + } +} + +/// A `SubscribeResp` alternative more suitable for testing. +/// +/// Unlike `SubscribeResp`, `TestSubscribeResp` uses regular Rust slices where +/// `SubscribeResp` uses `TLVArray` instances. +#[derive(Debug, Default, Clone, ToTLV)] +pub struct TestSubscribeReq<'a> { + pub keep_subs: bool, + pub min_int_floor: u16, + pub max_int_ceil: u16, + pub attr_requests: Option<&'a [AttrPath]>, + pub event_requests: Option<&'a [EventPath]>, + pub event_filters: Option<&'a [EventFilter]>, + // The Context Tags are discontiguous for some reason + pub _dummy: Option, + pub fabric_filtered: bool, + pub dataver_filters: Option<&'a [DataVersionFilter]>, +} + +impl<'a> TestSubscribeReq<'a> { + /// Create a new `TestSubscribeReq` instance. + pub const fn new() -> Self { + Self { + keep_subs: false, + min_int_floor: 0, + max_int_ceil: 0, + attr_requests: None, + event_requests: None, + event_filters: None, + _dummy: None, + fabric_filtered: false, + dataver_filters: None, + } + } + + /// Create a new `TestSubscribeReq` instance with the provided attribute requests. + pub const fn reqs(reqs: &'a [AttrPath]) -> Self { + Self { + attr_requests: Some(reqs), + ..Self::new() + } + } +} + +/// A `ReportDataMsg` alternative more suitable for testing. +/// +/// Unlike `ReportDataMsg`, `TestReportDataMsg` uses regular Rust slices where +/// `ReportDataMsg` uses `TLVArray` instances. Also, it utilizes `TestAttrResp` +/// for the attribute reports, where `ReportDataMsg` uses `AttrResp` instances. +#[derive(Debug, Default, Clone)] +pub struct TestReportDataMsg<'a> { + pub subscription_id: Option, + pub attr_reports: Option<&'a [TestAttrResp<'a>]>, + // TODO + pub event_reports: Option, + pub more_chunks: Option, + pub suppress_response: Option, +} + +impl<'a> TestReportDataMsg<'a> { + /// Create a new `TestReportDataMsg` instance. + pub const fn new() -> Self { + Self { + subscription_id: None, + attr_reports: None, + event_reports: None, + more_chunks: None, + suppress_response: None, + } + } + + /// Create a new `TestReportDataMsg` instance with the provided attribute reports. + pub const fn reports(reports: &'a [TestAttrResp<'a>]) -> Self { + Self { + attr_reports: Some(reports), + suppress_response: Some(true), + ..Self::new() + } + } +} + +impl<'a> TestToTLV for TestReportDataMsg<'a> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + if let Some(subscription_id) = self.subscription_id { + tw.u32(&TLVTag::Context(0), subscription_id)?; + } + + if let Some(attr_reports) = self.attr_reports { + tw.start_array(&TLVTag::Context(1))?; + for attr_report in attr_reports { + attr_report.test_to_tlv(&TLVTag::Anonymous, tw)?; + } + tw.end_container()?; + } + + if let Some(event_reports) = self.event_reports { + tw.bool(&TLVTag::Context(2), event_reports)?; + } + + if let Some(more_chunks) = self.more_chunks { + tw.bool(&TLVTag::Context(3), more_chunks)?; + } + + if let Some(suppress_response) = self.suppress_response { + tw.bool(&TLVTag::Context(4), suppress_response)?; + } + + tw.end_container()?; + + Ok(()) + } +} + +/// A `InvReq` alternative more suitable for testing. +/// +/// Unlike `InvReq`, `TestInvReq` uses regular Rust slices where +/// `InvReq` uses `TLVArray` instances. Also, it utilizes `TestCmdData` +/// for the invocation requests, where `InvReq` uses `CmdData` instances. +#[derive(Debug, Default, Clone)] +pub struct TestInvReq<'a> { + pub suppress_response: Option, + pub timed_request: Option, + pub inv_requests: Option<&'a [TestCmdData<'a>]>, +} + +impl<'a> TestInvReq<'a> { + /// Create a new `TestInvReq` instance. + pub const fn new() -> Self { + Self { + suppress_response: None, + timed_request: None, + inv_requests: None, + } + } + + /// Create a new `TestInvReq` instance with the provided command requests. + pub const fn reqs(reqs: &'a [TestCmdData<'a>]) -> Self { + Self { + inv_requests: Some(reqs), + ..Self::new() + } + } +} + +impl<'a> TestToTLV for TestInvReq<'a> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + if let Some(suppress_response) = self.suppress_response { + tw.bool(&TLVTag::Context(0), suppress_response)?; + } + + if let Some(timed_request) = self.timed_request { + tw.bool(&TLVTag::Context(1), timed_request)?; + } + + if let Some(inv_requests) = self.inv_requests { + tw.start_array(&TLVTag::Context(2))?; + for inv_request in inv_requests { + inv_request.test_to_tlv(&TLVTag::Anonymous, tw)?; + } + tw.end_container()?; + } + + tw.end_container()?; + + Ok(()) + } +} + +/// An `InvResp` alternative more suitable for testing. +/// +/// Unlike `InvResp`, `TestInvResp` uses regular Rust slices where +/// `InvResp` uses `TLVArray` instances. Also, it utilizes `TestCmdResp` +/// for the invocation responses, where `InvResp` uses `CmdResp` instances. +#[derive(Debug, Default, Clone)] +pub struct TestInvResp<'a> { + pub suppress_response: Option, + pub inv_responses: Option<&'a [TestCmdResp<'a>]>, +} + +impl<'a> TestInvResp<'a> { + /// Create a new `TestInvResp` instance with the provided command responses. + pub const fn resp(inv_responses: &'a [TestCmdResp<'a>]) -> Self { + Self { + suppress_response: Some(false), + inv_responses: Some(inv_responses), + } + } +} + +impl<'a> TestToTLV for TestInvResp<'a> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + if let Some(suppress_response) = self.suppress_response { + tw.bool(&TLVTag::Context(0), suppress_response)?; + } + + if let Some(inv_responses) = self.inv_responses { + tw.start_array(&TLVTag::Context(1))?; + for inv_response in inv_responses { + inv_response.test_to_tlv(&TLVTag::Anonymous, tw)?; + } + tw.end_container()?; + } + + tw.end_container()?; + + Ok(()) + } +} + +bitflags! { + /// Flags for trimming data from reply payloads. + /// + /// Useful when the E2E tests do now want to assert on e.g. + /// dataver, and/or concrete data returned by the Matter server. + /// + /// Currently, only trimming IM `ReportData` payloads is supported, + /// but if the end-to-end tests grow, this could be expanded to other IM messages. + #[repr(transparent)] + #[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub struct ReplyProcessor: u8 { + const REMOVE_ATTRDATA_DATAVER = 0b01; + const REMOVE_ATTRDATA_VALUE = 0b10; + } +} + +impl ReplyProcessor { + /// Remove the dataver and/or the data value from the `AttrData` payload, if so requested + pub fn process(&self, element: &TLVElement, buf: &mut [u8]) -> Result { + let mut wb = WriteBuf::new(buf); + let mut tw = TLVWriter::new(&mut wb); + + if self.is_empty() { + element.to_tlv(&TLVTag::Anonymous, &mut tw)?; + + return Ok(wb.get_tail()); + } + + let report_data = ReportDataMsg::from_tlv(element)?; + + tw.start_struct(&TLVTag::Anonymous)?; + + if let Some(subscription_id) = report_data.subscription_id { + tw.u32(&TLVTag::Context(0), subscription_id)?; + } + + if let Some(attr_reports) = report_data.attr_reports { + tw.start_array(&TLVTag::Context(1))?; + + for attr_report in attr_reports { + let mut attr_report = attr_report?; + + if let AttrResp::Data(data) = &mut attr_report { + if self.contains(Self::REMOVE_ATTRDATA_DATAVER) { + data.data_ver = None; + } + + if self.contains(Self::REMOVE_ATTRDATA_VALUE) { + data.data = TLVElement::new(&[]); + } + } + + attr_report.to_tlv(&TLVTag::Anonymous, &mut tw)?; + } + + tw.end_container()?; + } + + if let Some(event_reports) = report_data.event_reports { + tw.bool(&TLVTag::Context(2), event_reports)?; + } + + if let Some(more_chunks) = report_data.more_chunks { + tw.bool(&TLVTag::Context(3), more_chunks)?; + } + + if let Some(suppress_response) = report_data.suppress_response { + tw.bool(&TLVTag::Context(4), suppress_response)?; + } + + tw.end_container()?; + + Ok(wb.get_tail()) + } + + /// Process the supplied element without removing any data + pub fn none(element: &TLVElement, buf: &mut [u8]) -> Result { + Self::empty().process(element, buf) + } + + /// Process the supplied element with removing the dataver from the `AttrData` payload + pub fn remove_attr_dataver(element: &TLVElement, buf: &mut [u8]) -> Result { + Self::REMOVE_ATTRDATA_DATAVER.process(element, buf) + } + + /// Process the supplied element with removing the data value from the `AttrData` payload + pub fn remove_attr_data<'a>(element: &TLVElement, buf: &mut [u8]) -> Result { + (Self::REMOVE_ATTRDATA_VALUE | Self::REMOVE_ATTRDATA_DATAVER).process(element, buf) + } +} + +impl TLVTest +where + F: Fn(&TLVElement, &mut [u8]) -> Result, +{ + /// Create a new TLV test instance with input payload being the IM `ReadRequest` message + /// and the expected payload being the IM `ReportData` message. + pub const fn read(input_payload: I, expected_payload: E, process_reply: F) -> Self { + Self { + input_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::ReadRequest as _, + true, + ), + input_payload, + expected_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::ReportData as _, + true, + ), + expected_payload, + process_reply, + delay_ms: None, + } + } + + /// Create a new TLV test instance with input payload being the IM `StatusResponse` message + /// and the expected payload being the IM `ReportData` message. + pub const fn continue_report(input_payload: I, expected_payload: E, process_reply: F) -> Self { + Self { + input_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::StatusResponse as _, + true, + ), + input_payload, + expected_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::ReportData as _, + true, + ), + expected_payload, + process_reply, + delay_ms: None, + } + } + + /// Create a new TLV test instance with input payload being the IM `WriteRequest` message + /// and the expected payload being the IM `WriteResponse` message. + pub const fn write(input_payload: I, expected_payload: E, process_reply: F) -> Self { + Self { + input_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::WriteRequest as _, + true, + ), + input_payload, + expected_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::WriteResponse as _, + true, + ), + expected_payload, + process_reply, + delay_ms: None, + } + } + + /// Create a new TLV test instance with input payload being the IM `SubscribeRequest` message + /// and the expected payload being the IM `ReportData` message. + pub const fn subscribe(input_payload: I, expected_payload: E, process_reply: F) -> Self { + Self { + input_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::SubscribeRequest as _, + true, + ), + input_payload, + expected_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::ReportData as _, + true, + ), + expected_payload, + process_reply, + delay_ms: None, + } + } + + /// Create a new TLV test instance with input payload being the IM `StatusResponse` message + /// and the expected payload being the IM `SubscribeResponse` message. + pub const fn subscribe_final(input_payload: I, expected_payload: E, process_reply: F) -> Self { + Self { + input_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::StatusResponse as _, + true, + ), + input_payload, + expected_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::SubscribeResponse as _, + true, + ), + expected_payload, + process_reply, + delay_ms: None, + } + } + + /// Create a new TLV test instance with input payload being the IM `InvokeRequest` message + /// and the expected payload being the IM `InvokeResponse` message. + pub const fn invoke(input_payload: I, expected_payload: E, process_reply: F) -> Self { + Self { + input_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::InvokeRequest as _, + true, + ), + input_payload, + expected_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::InvokeResponse as _, + true, + ), + expected_payload, + process_reply, + delay_ms: None, + } + } + + /// Create a new TLV test instance with input payload being the IM `TimedRequest` message + /// and the expected payload being the IM `StatusResponse` message. + pub const fn timed(input_payload: I, expected_payload: E, process_reply: F) -> Self { + Self { + input_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::TimedRequest as _, + true, + ), + input_payload, + expected_meta: MessageMeta::new( + PROTO_ID_INTERACTION_MODEL, + OpCode::StatusResponse as _, + true, + ), + expected_payload, + process_reply, + delay_ms: None, + } + } +} + +impl<'a> + TLVTest< + TestReadReq<'a>, + TestReportDataMsg<'a>, + fn(&TLVElement, &mut [u8]) -> Result, + > +{ + /// Create a new TLV test instance with input payload being the IM `ReadRequest` message + /// and the expected payload being the IM `ReportData` message and the input payload and the + /// expected payload being the provided attribute requests and responses. + /// + /// The reply will be processed to remove the data version from the `AttrData` payload. + pub const fn read_attrs(input: &'a [AttrPath], expected: &'a [TestAttrResp<'a>]) -> Self { + Self::read( + TestReadReq::reqs(input), + TestReportDataMsg::reports(expected), + ReplyProcessor::remove_attr_dataver, + ) + } +} + +impl<'a> + TLVTest, TestWriteResp<'a>, fn(&TLVElement, &mut [u8]) -> Result> +{ + /// Create a new TLV test instance with input payload being the IM `WriteRequest` message + /// and the expected payload being the IM `WriteResponse` message and the input payload and the + /// expected payload being the provided write requests and responses. + pub const fn write_attrs(input: &'a [TestAttrData<'a>], expected: &'a [AttrStatus]) -> Self { + Self::write( + TestWriteReq::reqs(input), + TestWriteResp::resp(expected), + ReplyProcessor::none, + ) + } +} + +impl<'a> + TLVTest, TestInvResp<'a>, fn(&TLVElement, &mut [u8]) -> Result> +{ + /// Create a new TLV test instance with input payload being the IM `InvokeRequest` message + /// and the expected payload being the IM `InvokeResponse` message and the input payload and the + /// expected payload being the provided command requests and responses. + pub const fn inv_cmds(input: &'a [TestCmdData<'a>], expected: &'a [TestCmdResp<'a>]) -> Self { + Self::invoke( + TestInvReq::reqs(input), + TestInvResp::resp(expected), + ReplyProcessor::none, + ) + } +} diff --git a/rs-matter/tests/common/e2e/im/attributes.rs b/rs-matter/tests/common/e2e/im/attributes.rs new file mode 100644 index 00000000..0b751e31 --- /dev/null +++ b/rs-matter/tests/common/e2e/im/attributes.rs @@ -0,0 +1,179 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use rs_matter::data_model::objects::{AsyncHandler, AsyncMetadata}; +use rs_matter::error::Error; +use rs_matter::interaction_model::messages::ib::{AttrPath, AttrStatus}; +use rs_matter::interaction_model::messages::GenericPath; +use rs_matter::tlv::{TLVTag, TLVWrite, TLVWriter}; + +use crate::common::e2e::tlv::{TLVTest, TestToTLV}; +use crate::common::e2e::E2eRunner; + +/// A macro for creating a `TestAttrResp` instance of variant `Status`. +#[macro_export] +macro_rules! attr_status { + ($path:expr, $status:expr) => { + $crate::common::e2e::im::attributes::TestAttrResp::AttrStatus( + rs_matter::interaction_model::messages::ib::AttrStatus::new($path, $status, 0), + ) + }; +} + +/// A macro for creating a `TestAttrResp` instance of variant `AttrData` taking +/// a `GenericPath` instance and data. +#[macro_export] +macro_rules! attr_data_path { + ($path:expr, $data:expr) => { + $crate::common::e2e::im::attributes::TestAttrResp::AttrData( + $crate::common::e2e::im::attributes::TestAttrData { + data_ver: None, + path: rs_matter::interaction_model::messages::ib::AttrPath::new(&$path), + data: $data, + }, + ) + }; +} + +/// A macro for creating a `TestAttrResp` instance of variant `AttrData` taking +/// an endpoint, cluster, attribute, and data. +/// +/// Unlike the `attr_data_path` variant, this one does not support wildcards, +/// but has a shorter syntax. +#[macro_export] +macro_rules! attr_data { + ($endpoint:expr, $cluster:expr, $attr: expr, $data:expr) => { + $crate::attr_data_path!( + rs_matter::interaction_model::messages::GenericPath::new( + Some($endpoint), + Some($cluster), + Some($attr as _) + ), + $data + ) + }; +} + +/// An `AttrData` altenrative more suitable for testing. +/// +/// The main difference is that `TestAttrData::data` implements `TestToTLV`, whereas +/// `AttrData::data` is a `TLVElement`. +#[derive(Debug, Clone)] +pub struct TestAttrData<'a> { + pub data_ver: Option, + pub path: AttrPath, + pub data: Option<&'a dyn TestToTLV>, +} + +impl<'a> TestAttrData<'a> { + /// Create a new `TestAttrData` instance. + pub const fn new(data_ver: Option, path: AttrPath, data: &'a dyn TestToTLV) -> Self { + Self { + data_ver, + path, + data: Some(data), + } + } +} + +impl<'a> TestToTLV for TestAttrData<'a> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + if let Some(data_ver) = self.data_ver { + tw.u32(&TLVTag::Context(0), data_ver)?; + } + + self.path.test_to_tlv(&TLVTag::Context(1), tw)?; + + if let Some(data) = self.data { + data.test_to_tlv(&TLVTag::Context(2), tw)?; + } + + tw.end_container()?; + + Ok(()) + } +} + +/// An `AttrResp` alternative more suitable for testing, in that the +/// `TestAttrResp::AttrData` variant uses `TestAttrData` instead of `AttrData`. +#[derive(Debug)] +pub enum TestAttrResp<'a> { + AttrStatus(AttrStatus), + AttrData(TestAttrData<'a>), +} + +impl<'a> TestAttrResp<'a> { + /// Create a new `TestAttrResp` instance with an `AttrData` value. + pub fn data(path: &GenericPath, data: &'a dyn TestToTLV) -> Self { + Self::AttrData(TestAttrData::new(None, AttrPath::new(path), data)) + } +} + +impl<'a> TestToTLV for TestAttrResp<'a> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + match self { + TestAttrResp::AttrStatus(status) => status.test_to_tlv(&TLVTag::Context(0), tw), + TestAttrResp::AttrData(data) => data.test_to_tlv(&TLVTag::Context(1), tw), + }?; + + tw.end_container() + } +} + +impl E2eRunner { + /// For backwards compatibility. + pub fn read_reqs<'a>(input: &'a [AttrPath], expected: &'a [TestAttrResp<'a>]) { + let runner = Self::new_default(); + runner.add_default_acl(); + runner.handle_read_reqs(runner.handler(), input, expected) + } + + /// For backwards compatibility. + pub fn write_reqs<'a>(input: &'a [TestAttrData<'a>], expected: &'a [AttrStatus]) { + let runner = Self::new_default(); + runner.add_default_acl(); + runner.handle_write_reqs(runner.handler(), input, expected) + } + + /// For backwards compatibility. + pub fn handle_read_reqs<'a, H>( + &self, + handler: H, + input: &'a [AttrPath], + expected: &'a [TestAttrResp<'a>], + ) where + H: AsyncHandler + AsyncMetadata, + { + self.test_one(handler, TLVTest::read_attrs(input, expected)) + } + + /// For backwards compatibility. + pub fn handle_write_reqs<'a, H>( + &self, + handler: H, + input: &'a [TestAttrData<'a>], + expected: &'a [AttrStatus], + ) where + H: AsyncHandler + AsyncMetadata, + { + self.test_one(handler, TLVTest::write_attrs(input, expected)) + } +} diff --git a/rs-matter/tests/common/e2e/im/commands.rs b/rs-matter/tests/common/e2e/im/commands.rs new file mode 100644 index 00000000..568dded3 --- /dev/null +++ b/rs-matter/tests/common/e2e/im/commands.rs @@ -0,0 +1,137 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use rs_matter::data_model::objects::{AsyncHandler, AsyncMetadata}; +use rs_matter::error::Error; +use rs_matter::interaction_model::messages::ib::{CmdPath, CmdStatus}; +use rs_matter::tlv::{TLVTag, TLVWrite, TLVWriter}; + +use crate::common::e2e::tlv::{TLVTest, TestToTLV}; +use crate::common::e2e::E2eRunner; + +/// A macro for creating a `TestCmdData` instance by using literal values for data. +#[macro_export] +macro_rules! cmd_data { + ($path:expr, $data:literal) => { + $crate::common::e2e::im::commands::TestCmdData::new($path, &($data as u32)) + }; +} + +#[macro_export] +macro_rules! echo_req { + ($endpoint:literal, $data:literal) => { + $crate::common::e2e::im::commands::TestCmdData::new( + rs_matter::interaction_model::messages::ib::CmdPath::new( + Some($endpoint), + Some($crate::common::e2e::im::echo_cluster::ID), + Some($crate::common::e2e::im::echo_cluster::Commands::EchoReq as u32), + ), + &($data as u32), + ) + }; +} + +#[macro_export] +macro_rules! echo_resp { + ($endpoint:literal, $data:literal) => { + $crate::common::e2e::im::commands::TestCmdResp::Cmd( + $crate::common::e2e::im::commands::TestCmdData::new( + rs_matter::interaction_model::messages::ib::CmdPath::new( + Some($endpoint), + Some($crate::common::e2e::im::echo_cluster::ID), + Some($crate::common::e2e::im::echo_cluster::RespCommands::EchoResp as u32), + ), + &($data as u32), + ), + ) + }; +} + +/// A `TestCmdData` alternative more suitable for testing. +/// +/// The main difference is that `TestCmdData::data` implements `TestToTLV`, whereas +/// `CmdData::data` is a `TLVElement`. +#[derive(Debug, Clone)] +pub struct TestCmdData<'a> { + pub path: CmdPath, + pub data: &'a dyn TestToTLV, +} + +impl<'a> TestCmdData<'a> { + /// Create a new `TestCmdData` instance. + pub const fn new(path: CmdPath, data: &'a dyn TestToTLV) -> Self { + Self { path, data } + } +} + +impl<'a> TestToTLV for TestCmdData<'a> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + self.path.test_to_tlv(&TLVTag::Context(0), tw)?; + + self.data.test_to_tlv(&TLVTag::Context(1), tw)?; + + tw.end_container()?; + + Ok(()) + } +} + +/// A `TestCmdResp` alternative more suitable for testing. +/// +/// The main difference is that `TestCmdResp::data` implements `TestToTLV`, whereas +/// `CmdResp::data` is a `TLVElement`. +#[derive(Debug, Clone)] +pub enum TestCmdResp<'a> { + Cmd(TestCmdData<'a>), + Status(CmdStatus), +} + +impl<'a> TestToTLV for TestCmdResp<'a> { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + tw.start_struct(tag)?; + + match self { + TestCmdResp::Cmd(data) => data.test_to_tlv(&TLVTag::Context(0), tw), + TestCmdResp::Status(status) => status.test_to_tlv(&TLVTag::Context(1), tw), + }?; + + tw.end_container() + } +} + +impl E2eRunner { + /// For backwards compatibility. + pub fn commands<'a>(input: &'a [TestCmdData<'a>], expected: &'a [TestCmdResp<'a>]) { + let runner = Self::new_default(); + runner.add_default_acl(); + runner.handle_commands(runner.handler(), input, expected) + } + + /// For backwards compatibility. + pub fn handle_commands<'a, H>( + &self, + handler: H, + input: &'a [TestCmdData<'a>], + expected: &'a [TestCmdResp<'a>], + ) where + H: AsyncHandler + AsyncMetadata, + { + self.test_one(handler, TLVTest::inv_cmds(input, expected)) + } +} diff --git a/rs-matter/tests/common/echo_cluster.rs b/rs-matter/tests/common/e2e/im/echo_cluster.rs similarity index 87% rename from rs-matter/tests/common/echo_cluster.rs rename to rs-matter/tests/common/e2e/im/echo_cluster.rs index 43117133..17a96948 100644 --- a/rs-matter/tests/common/echo_cluster.rs +++ b/rs-matter/tests/common/e2e/im/echo_cluster.rs @@ -16,23 +16,30 @@ */ use core::cell::Cell; +use core::fmt::Debug; + use std::sync::{Arc, Mutex, Once}; use num_derive::FromPrimitive; -use rs_matter::{ - attribute_enum, command_enum, - data_model::objects::{ - Access, AttrData, AttrDataEncoder, AttrDataWriter, AttrDetails, AttrType, Attribute, - Cluster, CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, NonBlockingHandler, - Quality, ATTRIBUTE_LIST, FEATURE_MAP, - }, - error::{Error, ErrorCode}, - interaction_model::messages::ib::{attr_list_write, ListOperation}, - tlv::{TLVElement, TagType}, - transport::exchange::Exchange, -}; + use strum::{EnumDiscriminants, FromRepr}; +use rs_matter::data_model::objects::{ + Access, AttrDataEncoder, AttrDataWriter, AttrDetails, AttrType, Attribute, Cluster, + CmdDataEncoder, CmdDataWriter, CmdDetails, Dataver, Handler, NonBlockingHandler, Quality, + ATTRIBUTE_LIST, FEATURE_MAP, +}; +use rs_matter::error::{Error, ErrorCode}; +use rs_matter::interaction_model::messages::ib::{attr_list_write, ListOperation}; +use rs_matter::tlv::{TLVElement, TLVTag, TLVWrite}; +use rs_matter::transport::exchange::Exchange; +use rs_matter::{attribute_enum, command_enum}; + +pub const WRITE_LIST_MAX: usize = 5; + +pub const ATTR_CUSTOM_VALUE: u32 = 0xcafebeef; +pub const ATTR_WRITE_DEFAULT_VALUE: u16 = 0xcafe; + pub const ID: u32 = 0xABCD; #[derive(FromRepr, EnumDiscriminants)] @@ -122,8 +129,7 @@ impl TestChecker { } } -pub const WRITE_LIST_MAX: usize = 5; - +/// A sample cluster that echoes back the input data. Useful for testing. pub struct EchoCluster { pub data_ver: Dataver, pub multiplier: u8, @@ -159,9 +165,9 @@ impl EchoCluster { let tc_handle = TestChecker::get().unwrap(); let tc = tc_handle.lock().unwrap(); - writer.start_array(AttrDataWriter::TAG)?; + writer.start_array(&AttrDataWriter::TAG)?; for i in tc.write_list.iter().flatten() { - writer.u16(TagType::Anonymous, *i)?; + writer.u16(&TLVTag::Anonymous, *i)?; } writer.end_container()?; @@ -174,7 +180,11 @@ impl EchoCluster { } } - pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + pub fn write( + &self, + attr: &AttrDetails, + data: rs_matter::data_model::objects::AttrData, + ) -> Result<(), Error> { let data = data.with_dataver(self.data_ver.get())?; match attr.attr_id.try_into()? { @@ -207,10 +217,8 @@ impl EchoCluster { let mut writer = encoder.with_command(RespCommands::EchoResp as _)?; - writer.start_struct(CmdDataWriter::TAG)?; // Echo = input * self.multiplier - writer.u8(TagType::Context(0), a * self.multiplier)?; - writer.end_container()?; + writer.u8(&CmdDataWriter::TAG, a * self.multiplier)?; writer.complete() } @@ -259,9 +267,6 @@ impl EchoCluster { } } -pub const ATTR_CUSTOM_VALUE: u32 = 0xcafebeef; -pub const ATTR_WRITE_DEFAULT_VALUE: u16 = 0xcafe; - impl Handler for EchoCluster { fn read( &self, @@ -272,7 +277,12 @@ impl Handler for EchoCluster { EchoCluster::read(self, attr, encoder) } - fn write(&self, _exchange: &Exchange, attr: &AttrDetails, data: AttrData) -> Result<(), Error> { + fn write( + &self, + _exchange: &Exchange, + attr: &AttrDetails, + data: rs_matter::data_model::objects::AttrData, + ) -> Result<(), Error> { EchoCluster::write(self, attr, data) } diff --git a/rs-matter/tests/common/e2e/im/handler.rs b/rs-matter/tests/common/e2e/im/handler.rs new file mode 100644 index 00000000..b8487a84 --- /dev/null +++ b/rs-matter/tests/common/e2e/im/handler.rs @@ -0,0 +1,207 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use rs_matter::data_model::cluster_basic_information; +use rs_matter::data_model::cluster_on_off::{self, OnOffCluster}; +use rs_matter::data_model::device_types::{DEV_TYPE_ON_OFF_LIGHT, DEV_TYPE_ROOT_NODE}; +use rs_matter::data_model::objects::{ + AsyncHandler, AsyncMetadata, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails, Dataver, + Endpoint, Handler, Metadata, Node, NonBlockingHandler, +}; +use rs_matter::data_model::root_endpoint::{self, EthRootEndpointHandler}; +use rs_matter::data_model::sdm::admin_commissioning; +use rs_matter::data_model::sdm::general_commissioning; +use rs_matter::data_model::sdm::noc; +use rs_matter::data_model::sdm::nw_commissioning; +use rs_matter::data_model::system_model::access_control; +use rs_matter::data_model::system_model::descriptor::{self, DescriptorCluster}; +use rs_matter::error::Error; +use rs_matter::handler_chain_type; +use rs_matter::tlv::TLVElement; +use rs_matter::transport::exchange::Exchange; +use rs_matter::Matter; + +use crate::common::e2e::E2eRunner; + +use super::echo_cluster::{self, EchoCluster}; + +/// A sample handler for E2E IM tests. +pub struct E2eTestHandler<'a>( + handler_chain_type!(OnOffCluster, EchoCluster, DescriptorCluster<'static>, EchoCluster | EthRootEndpointHandler<'a>), +); + +impl<'a> E2eTestHandler<'a> { + pub const NODE: Node<'static> = Node { + id: 0, + endpoints: &[ + Endpoint { + id: 0, + clusters: &[ + descriptor::CLUSTER, + cluster_basic_information::CLUSTER, + general_commissioning::CLUSTER, + nw_commissioning::ETH_CLUSTER, + admin_commissioning::CLUSTER, + noc::CLUSTER, + access_control::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ROOT_NODE, + }, + Endpoint { + id: 1, + clusters: &[ + descriptor::CLUSTER, + cluster_on_off::CLUSTER, + echo_cluster::CLUSTER, + ], + device_type: DEV_TYPE_ON_OFF_LIGHT, + }, + ], + }; + + pub fn new(matter: &'a Matter<'a>) -> Self { + let handler = root_endpoint::eth_handler(0, matter.rand()) + .chain( + 0, + echo_cluster::ID, + EchoCluster::new(2, Dataver::new_rand(matter.rand())), + ) + .chain( + 1, + descriptor::ID, + DescriptorCluster::new(Dataver::new_rand(matter.rand())), + ) + .chain( + 1, + echo_cluster::ID, + EchoCluster::new(3, Dataver::new_rand(matter.rand())), + ) + .chain( + 1, + cluster_on_off::ID, + OnOffCluster::new(Dataver::new_rand(matter.rand())), + ); + + Self(handler) + } + + pub fn echo_cluster(&self, endpoint: u16) -> &EchoCluster { + match endpoint { + 0 => &self.0.next.next.next.handler, + 1 => &self.0.next.handler, + _ => panic!(), + } + } +} + +impl<'a> Handler for E2eTestHandler<'a> { + fn read( + &self, + exchange: &Exchange, + attr: &AttrDetails, + encoder: AttrDataEncoder, + ) -> Result<(), Error> { + self.0.read(exchange, attr, encoder) + } + + fn write( + &self, + exchange: &Exchange, + attr: &AttrDetails, + data: rs_matter::data_model::objects::AttrData, + ) -> Result<(), Error> { + self.0.write(exchange, attr, data) + } + + fn invoke( + &self, + exchange: &Exchange, + cmd: &CmdDetails, + data: &TLVElement, + encoder: CmdDataEncoder, + ) -> Result<(), Error> { + self.0.invoke(exchange, cmd, data, encoder) + } +} + +impl<'a> NonBlockingHandler for E2eTestHandler<'a> {} + +impl<'a> AsyncHandler for E2eTestHandler<'a> { + async fn read( + &self, + exchange: &Exchange<'_>, + attr: &AttrDetails<'_>, + encoder: AttrDataEncoder<'_, '_, '_>, + ) -> Result<(), Error> { + self.0.read(exchange, attr, encoder) + } + + fn read_awaits(&self, _exchange: &Exchange, _attr: &AttrDetails) -> bool { + false + } + + fn write_awaits(&self, _exchange: &Exchange, _attr: &AttrDetails) -> bool { + false + } + + fn invoke_awaits(&self, _exchange: &Exchange, _cmd: &CmdDetails) -> bool { + false + } + + async fn write( + &self, + exchange: &Exchange<'_>, + attr: &AttrDetails<'_>, + data: rs_matter::data_model::objects::AttrData<'_>, + ) -> Result<(), Error> { + self.0.write(exchange, attr, data) + } + + async fn invoke( + &self, + exchange: &Exchange<'_>, + cmd: &CmdDetails<'_>, + data: &TLVElement<'_>, + encoder: CmdDataEncoder<'_, '_, '_>, + ) -> Result<(), Error> { + self.0.invoke(exchange, cmd, data, encoder) + } +} + +impl<'a> Metadata for E2eTestHandler<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; + + fn lock(&self) -> Self::MetadataGuard<'_> { + Self::NODE + } +} + +impl<'a> AsyncMetadata for E2eTestHandler<'a> { + type MetadataGuard<'g> = Node<'g> where Self: 'g; + + async fn lock(&self) -> Self::MetadataGuard<'_> { + Self::NODE + } +} + +impl E2eRunner { + // For backwards compatibility + pub fn handler(&self) -> E2eTestHandler<'_> { + E2eTestHandler::new(&self.matter) + } +} diff --git a/rs-matter/tests/common/e2e/test.rs b/rs-matter/tests/common/e2e/test.rs new file mode 100644 index 00000000..45f6dbad --- /dev/null +++ b/rs-matter/tests/common/e2e/test.rs @@ -0,0 +1,123 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use embassy_futures::block_on; +use embassy_futures::select::select; + +use embassy_time::{Duration, Timer}; + +use rs_matter::data_model::objects::{AsyncHandler, AsyncMetadata}; +use rs_matter::error::Error; +use rs_matter::transport::exchange::{Exchange, MessageMeta}; +use rs_matter::utils::select::Coalesce; +use rs_matter::utils::storage::WriteBuf; + +use super::E2eRunner; + +/// Represents an E2E test. +pub trait E2eTest { + /// Prepare the input message for the test. + fn fill_input(&self, message_buf: &mut WriteBuf) -> Result; + + /// Validate the message returned by the remote peer. + fn validate_result(&self, meta: MessageMeta, message: &[u8]) -> Result<(), Error>; + + /// Optionally return a delay in milliseconds to wait after receiving the message by the remote peer. + fn delay(&self) -> Option { + None + } +} + +impl E2eTest for &dyn E2eTest { + fn fill_input(&self, message_buf: &mut WriteBuf) -> Result { + (*self).fill_input(message_buf) + } + + fn validate_result(&self, meta: MessageMeta, message: &[u8]) -> Result<(), Error> { + (*self).validate_result(meta, message) + } + + fn delay(&self) -> Option { + (*self).delay() + } +} + +impl E2eRunner { + /// Run the provided test with the given handler and wait with blocking + /// until the test completes or fails. + pub fn test_one(&self, handler: H, test: T) + where + H: AsyncHandler + AsyncMetadata, + T: E2eTest, + { + self.test_all(handler, core::iter::once(test)) + } + + /// Run the provided tests with the given handler and wait with blocking + /// until all tests complete or the first one fails. + pub fn test_all(&self, handler: H, tests: I) + where + H: AsyncHandler + AsyncMetadata, + I: IntoIterator, + T: E2eTest, + { + block_on( + select(self.run(handler), async move { + let mut exchange = self.initiate_exchange().await?; + + for test in tests { + Self::execute_test(&mut exchange, test).await?; + } + + exchange.acknowledge().await?; + + Ok(()) + }) + .coalesce(), + ) + .unwrap() + } + + /// Execute the test via the provided exchange. + pub async fn execute_test(exchange: &mut Exchange<'_>, test: T) -> Result<(), Error> + where + T: E2eTest, + { + exchange + .send_with(|_, wb| { + let meta = test.fill_input(wb)?; + + Ok(Some(meta)) + }) + .await?; + + { + // In a separate block so that the RX message is dropped before we start waiting + + let rx = exchange.recv().await?; + + test.validate_result(rx.meta(), rx.payload())?; + } + + let delay = test.delay().unwrap_or(0); + if delay > 0 { + Timer::after(Duration::from_millis(delay as _)).await; + } + + Ok(()) + } +} diff --git a/rs-matter/tests/common/e2e/tlv.rs b/rs-matter/tests/common/e2e/tlv.rs new file mode 100644 index 00000000..abd37782 --- /dev/null +++ b/rs-matter/tests/common/e2e/tlv.rs @@ -0,0 +1,120 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use core::fmt::Debug; + +use rs_matter::error::Error; +use rs_matter::tlv::{TLVElement, TLVTag, TLVWriter, ToTLV}; +use rs_matter::transport::exchange::MessageMeta; +use rs_matter::utils::storage::WriteBuf; + +use super::test::E2eTest; + +/// A `ToTLV` trait variant useful for testing. +/// +/// Unlike `ToTLV`, `TestToTLV` is `dyn`-friendly, but therefore does +/// require a `TLVWriter` to be passed in. +pub trait TestToTLV: Debug + Sync { + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error>; +} + +impl TestToTLV for T +where + T: ToTLV + Debug + Sync, +{ + fn test_to_tlv(&self, tag: &TLVTag, tw: &mut TLVWriter) -> Result<(), Error> { + ToTLV::to_tlv(self, tag, tw) + } +} + +/// A concrete `E2eTest` implementation that assumes that the input and output payload +/// of the test are both TLV payloads. +/// +/// It validates the differences between the output and the expected payload using a Diff +/// algorithm, which provides a human readable output. +pub struct TLVTest { + pub input_meta: MessageMeta, + pub input_payload: I, + pub expected_meta: MessageMeta, + pub expected_payload: E, + pub process_reply: F, + pub delay_ms: Option, +} + +impl E2eTest for TLVTest +where + I: TestToTLV, + E: TestToTLV, + F: Fn(&TLVElement, &mut [u8]) -> Result, +{ + fn fill_input(&self, message_buf: &mut WriteBuf) -> Result { + self.input_payload + .test_to_tlv(&TLVTag::Anonymous, &mut TLVWriter::new(message_buf))?; + + Ok(self.input_meta) + } + + fn validate_result(&self, meta: MessageMeta, message: &[u8]) -> Result<(), Error> { + use core::fmt::Write; + + assert_eq!(self.expected_meta, meta); + + let mut buf = [0; 1500]; + let mut wb = WriteBuf::new(&mut buf); + + let mut tw = TLVWriter::new(&mut wb); + + self.expected_payload + .test_to_tlv(&TLVTag::Anonymous, &mut tw)?; + let expected_element = TLVElement::new(wb.as_slice()); + + let element = TLVElement::new(message); + + let mut buf2 = [0; 1500]; + let len = (self.process_reply)(&element, &mut buf2)?; + + let element = TLVElement::new(&buf2[..len]); + + if expected_element != element { + let expected_str = format!("{expected_element}"); + let actual_str = format!("{element}"); + + let diff = similar::TextDiff::from_lines(&expected_str, &actual_str); + + let mut diff_str = String::new(); + + // TODO: Color the diff output + for change in diff.iter_all_changes() { + let sign = match change.tag() { + similar::ChangeTag::Delete => "-", + similar::ChangeTag::Insert => "+", + similar::ChangeTag::Equal => " ", + }; + + write!(diff_str, "{}{}", sign, change).unwrap(); + } + + assert!(false, "Expected does not match actual:\n== Diff:\n{diff_str}\n== Expected:\n{expected_str}\n== Actual:\n{actual_str}"); + } + + Ok(()) + } + + fn delay(&self) -> Option { + self.delay_ms + } +} diff --git a/rs-matter/tests/common/handlers.rs b/rs-matter/tests/common/handlers.rs deleted file mode 100644 index 63d743c9..00000000 --- a/rs-matter/tests/common/handlers.rs +++ /dev/null @@ -1,300 +0,0 @@ -use log::{info, warn}; -use rs_matter::{ - error::ErrorCode, - interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{AttrData, AttrPath, AttrResp, AttrStatus, CmdData, DataVersionFilter}, - msg::{ - self, InvReq, ReadReq, ReportDataMsg, StatusResp, TimedReq, WriteReq, WriteResp, - WriteRespTag, - }, - }, - }, - tlv::{self, FromTLV, TLVArray, ToTLV}, -}; - -use super::{ - attributes::assert_attr_report, - commands::{assert_inv_response, ExpectedInvResp}, - im_engine::{ImEngine, ImEngineHandler, ImInput, ImOutput}, -}; - -pub enum WriteResponse<'a> { - TransactionError, - TransactionSuccess(&'a [AttrStatus]), -} - -pub enum TimedInvResponse<'a> { - TransactionError(IMStatusCode), - TransactionSuccess(&'a [ExpectedInvResp]), -} - -impl<'a> ImEngine<'a> { - pub fn read_reqs(input: &[AttrPath], expected: &[AttrResp]) { - let im = ImEngine::new_default(); - - im.add_default_acl(); - im.handle_read_reqs(&im.handler(), input, expected); - } - - // Helper for handling Read Req sequences for this file - pub fn handle_read_reqs( - &self, - handler: &ImEngineHandler, - input: &[AttrPath], - expected: &[AttrResp], - ) { - let mut out = heapless::Vec::<_, 1>::new(); - let received = self.gen_read_reqs_output(handler, input, None, &mut out); - assert_attr_report(&received, expected) - } - - pub fn gen_read_reqs_output<'c, const N: usize>( - &self, - handler: &ImEngineHandler, - input: &[AttrPath], - dataver_filters: Option>, - out: &'c mut heapless::Vec, - ) -> ReportDataMsg<'c> { - let mut read_req = ReadReq::new(true).set_attr_requests(input); - read_req.dataver_filters = dataver_filters; - - let input = ImInput::new(OpCode::ReadRequest, &read_req); - - self.process(handler, &[&input], out).unwrap(); - - for o in &*out { - tlv::print_tlv_list(&o.data); - } - - let root = tlv::get_root_node_struct(&out[0].data).unwrap(); - ReportDataMsg::from_tlv(&root).unwrap() - } - - pub fn write_reqs(input: &[AttrData], expected: &[AttrStatus]) { - let im = ImEngine::new_default(); - - im.add_default_acl(); - im.handle_write_reqs(&im.handler(), input, expected); - } - - pub fn handle_write_reqs( - &self, - handler: &ImEngineHandler, - input: &[AttrData], - expected: &[AttrStatus], - ) { - let write_req = WriteReq::new(false, input); - - let input = ImInput::new(OpCode::WriteRequest, &write_req); - let mut out = heapless::Vec::<_, 1>::new(); - self.process(handler, &[&input], &mut out).unwrap(); - - for o in &out { - tlv::print_tlv_list(&o.data); - } - - let root = tlv::get_root_node_struct(&out[0].data).unwrap(); - - let mut index = 0; - let response_iter = root - .find_tag(WriteRespTag::WriteResponses as u32) - .unwrap() - .confirm_array() - .unwrap() - .enter() - .unwrap(); - - for response in response_iter { - info!("Validating index {}", index); - let status = AttrStatus::from_tlv(&response).unwrap(); - assert_eq!(expected[index], status); - info!("Index {} success", index); - index += 1; - } - assert_eq!(index, expected.len()); - } - - pub fn commands(input: &[CmdData], expected: &[ExpectedInvResp]) { - let im = ImEngine::new_default(); - - im.add_default_acl(); - im.handle_commands(&im.handler(), input, expected) - } - - // Helper for handling Invoke Command sequences - pub fn handle_commands( - &self, - handler: &ImEngineHandler, - input: &[CmdData], - expected: &[ExpectedInvResp], - ) { - let req = InvReq { - suppress_response: Some(false), - timed_request: Some(false), - inv_requests: Some(TLVArray::Slice(input)), - }; - - let input = ImInput::new(OpCode::InvokeRequest, &req); - - let mut out = heapless::Vec::<_, 1>::new(); - self.process(handler, &[&input], &mut out).unwrap(); - - for o in &out { - tlv::print_tlv_list(&o.data); - } - - let root = tlv::get_root_node_struct(&out[0].data).unwrap(); - let resp = msg::InvResp::from_tlv(&root).unwrap(); - assert_inv_response(&resp, expected) - } - - fn gen_timed_reqs_output( - &self, - handler: &ImEngineHandler, - opcode: OpCode, - request: &dyn ToTLV, - timeout: u16, - delay: u16, - out: &mut heapless::Vec, - ) { - let mut inp = heapless::Vec::<_, 2>::new(); - - let timed_req = TimedReq { timeout }; - let im_input = ImInput::new_delayed(OpCode::TimedRequest, &timed_req, Some(delay)); - - if timeout != 0 { - // Send Timed Req - inp.push(&im_input).map_err(|_| ErrorCode::NoSpace).unwrap(); - } else { - warn!("Skipping timed request"); - } - - // Send Write Req - let input = ImInput::new(opcode, request); - inp.push(&input).map_err(|_| ErrorCode::NoSpace).unwrap(); - - self.process(handler, &inp, out).unwrap(); - - drop(inp); - - for o in out { - tlv::print_tlv_list(&o.data); - } - } - - pub fn timed_write_reqs( - input: &[AttrData], - expected: &WriteResponse, - timeout: u16, - delay: u16, - ) { - let im = ImEngine::new_default(); - - im.add_default_acl(); - im.handle_timed_write_reqs(&im.handler(), input, expected, timeout, delay); - } - - // Helper for handling Write Attribute sequences - pub fn handle_timed_write_reqs( - &self, - handler: &ImEngineHandler, - input: &[AttrData], - expected: &WriteResponse, - timeout: u16, - delay: u16, - ) { - let mut out = heapless::Vec::<_, 2>::new(); - let mut write_req = WriteReq::new(false, input); - write_req.timed_request = Some(true); - - self.gen_timed_reqs_output( - handler, - OpCode::WriteRequest, - &write_req, - timeout, - delay, - &mut out, - ); - - let out = &out[out.len() - 1]; - let root = tlv::get_root_node_struct(&out.data).unwrap(); - - match *expected { - WriteResponse::TransactionSuccess(t) => { - assert_eq!(out.action, OpCode::WriteResponse); - let resp = WriteResp::from_tlv(&root).unwrap(); - assert_eq!(resp.write_responses, t); - } - WriteResponse::TransactionError => { - assert_eq!(out.action, OpCode::StatusResponse); - let status_resp = StatusResp::from_tlv(&root).unwrap(); - assert_eq!(status_resp.status, IMStatusCode::Timeout); - } - } - } - - pub fn timed_commands( - input: &[CmdData], - expected: &TimedInvResponse, - timeout: u16, - delay: u16, - set_timed_request: bool, - ) { - let im = ImEngine::new_default(); - - im.add_default_acl(); - im.handle_timed_commands( - &im.handler(), - input, - expected, - timeout, - delay, - set_timed_request, - ); - } - - // Helper for handling Invoke Command sequences - pub fn handle_timed_commands( - &self, - handler: &ImEngineHandler, - input: &[CmdData], - expected: &TimedInvResponse, - timeout: u16, - delay: u16, - set_timed_request: bool, - ) { - let mut out = heapless::Vec::<_, 2>::new(); - let req = InvReq { - suppress_response: Some(false), - timed_request: Some(set_timed_request), - inv_requests: Some(TLVArray::Slice(input)), - }; - - self.gen_timed_reqs_output( - handler, - OpCode::InvokeRequest, - &req, - timeout, - delay, - &mut out, - ); - - let out = &out[out.len() - 1]; - let root = tlv::get_root_node_struct(&out.data).unwrap(); - - match expected { - TimedInvResponse::TransactionSuccess(t) => { - assert_eq!(out.action, OpCode::InvokeResponse); - let resp = msg::InvResp::from_tlv(&root).unwrap(); - assert_inv_response(&resp, t) - } - TimedInvResponse::TransactionError(e) => { - assert_eq!(out.action, OpCode::StatusResponse); - let status_resp = StatusResp::from_tlv(&root).unwrap(); - assert_eq!(status_resp.status, *e); - } - } - } -} diff --git a/rs-matter/tests/common/mod.rs b/rs-matter/tests/common/mod.rs index 052f8077..b468e05e 100644 --- a/rs-matter/tests/common/mod.rs +++ b/rs-matter/tests/common/mod.rs @@ -15,11 +15,7 @@ * limitations under the License. */ -pub mod attributes; -pub mod commands; -pub mod echo_cluster; -pub mod handlers; -pub mod im_engine; +pub mod e2e; pub fn init_env_logger() { #[cfg(all(feature = "std", not(target_os = "espidf")))] diff --git a/rs-matter/tests/data_model/acl_and_dataver.rs b/rs-matter/tests/data_model/acl_and_dataver.rs index 7cedf18e..48a23b6f 100644 --- a/rs-matter/tests/data_model/acl_and_dataver.rs +++ b/rs-matter/tests/data_model/acl_and_dataver.rs @@ -17,29 +17,21 @@ use core::num::NonZeroU8; -use rs_matter::{ - acl::{gen_noc_cat, AclEntry, AuthMode, Target}, - data_model::{ - objects::{EncodeValue, Privilege}, - system_model::access_control, - }, - interaction_model::{ - core::IMStatusCode, - messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus, ClusterPath, DataVersionFilter}, - messages::GenericPath, - }, - tlv::{ElementType, TLVArray, TLVElement, TLVWriter, TagType}, +use rs_matter::acl::{gen_noc_cat, AclEntry, AuthMode, Target}; +use rs_matter::data_model::{objects::Privilege, system_model::access_control}; +use rs_matter::interaction_model::core::IMStatusCode; +use rs_matter::interaction_model::messages::ib::{ + AttrPath, AttrStatus, ClusterPath, DataVersionFilter, }; +use rs_matter::interaction_model::messages::GenericPath; -use crate::{ - attr_data_path, attr_status, - common::{ - attributes::*, - echo_cluster::{self, ATTR_WRITE_DEFAULT_VALUE}, - im_engine::{ImEngine, IM_ENGINE_PEER_ID}, - init_env_logger, - }, -}; +use crate::common::e2e::im::attributes::{TestAttrData, TestAttrResp}; +use crate::common::e2e::im::echo_cluster::ATTR_WRITE_DEFAULT_VALUE; +use crate::common::e2e::im::{echo_cluster, ReplyProcessor, TestReadReq, TestReportDataMsg}; +use crate::common::e2e::tlv::TLVTest; +use crate::common::e2e::{ImEngine, IM_ENGINE_PEER_ID}; +use crate::common::init_env_logger; +use crate::{attr_data, attr_data_path, attr_status}; const FAB_1: NonZeroU8 = match NonZeroU8::new(1) { Some(f) => f, @@ -72,9 +64,7 @@ fn wc_read_attribute() { let handler = im.handler(); // Test1: Empty Response as no ACL matches - let input = &[AttrPath::new(&wc_att1)]; - let expected = &[]; - im.handle_read_reqs(&handler, input, expected); + im.handle_read_reqs(&handler, &[AttrPath::new(&wc_att1)], &[]); // Add ACL to allow our peer to only access endpoint 0 let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); @@ -83,9 +73,11 @@ fn wc_read_attribute() { im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test2: Only Single response as only single endpoint is allowed - let input = &[AttrPath::new(&wc_att1)]; - let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; - im.handle_read_reqs(&handler, input, expected); + im.handle_read_reqs( + &handler, + &[AttrPath::new(&wc_att1)], + &[TestAttrResp::data(&ep0_att1, &0x1234u16)], + ); // Add ACL to allow our peer to also access endpoint 1 let mut acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); @@ -94,12 +86,14 @@ fn wc_read_attribute() { im.matter.acl_mgr.borrow_mut().add(acl).unwrap(); // Test3: Both responses are valid - let input = &[AttrPath::new(&wc_att1)]; - let expected = &[ - attr_data_path!(ep0_att1, ElementType::U16(0x1234)), - attr_data_path!(ep1_att1, ElementType::U16(0x1234)), - ]; - im.handle_read_reqs(&handler, input, expected); + im.handle_read_reqs( + &handler, + &[AttrPath::new(&wc_att1)], + &[ + TestAttrResp::data(&ep0_att1, &0x1234u16), + TestAttrResp::data(&ep1_att1, &0x1234u16), + ], + ); } #[test] @@ -113,18 +107,13 @@ fn exact_read_attribute() { Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); - let ep0_att1 = GenericPath::new( - Some(0), - Some(echo_cluster::ID), - Some(echo_cluster::AttributesDiscriminants::Att1 as u32), - ); let im = ImEngine::new_default(); let handler = im.handler(); // Test1: Unsupported Access error as no ACL matches let input = &[AttrPath::new(&wc_att1)]; - let expected = &[attr_status!(&ep0_att1, IMStatusCode::UnsupportedAccess)]; + let expected = &[attr_status!(&wc_att1, IMStatusCode::UnsupportedAccess)]; im.handle_read_reqs(&handler, input, expected); // Add ACL to allow our peer to access any endpoint @@ -134,7 +123,7 @@ fn exact_read_attribute() { // Test2: Only Single response as only single endpoint is allowed let input = &[AttrPath::new(&wc_att1)]; - let expected = &[attr_data_path!(ep0_att1, ElementType::U16(0x1234))]; + let expected = &[attr_data_path!(wc_att1, Some(&0x1234u16))]; im.handle_read_reqs(&handler, input, expected); } @@ -145,12 +134,6 @@ fn wc_write_attribute() { init_env_logger(); let val0 = 10; let val1 = 20; - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; - let attr_data1 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val1); - }; let wc_att = GenericPath::new( None, @@ -168,16 +151,8 @@ fn wc_write_attribute() { Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); - let input0 = &[AttrData::new( - None, - AttrPath::new(&wc_att), - EncodeValue::Closure(&attr_data0), - )]; - let input1 = &[AttrData::new( - None, - AttrPath::new(&wc_att), - EncodeValue::Closure(&attr_data1), - )]; + let input0 = &[TestAttrData::new(None, AttrPath::new(&wc_att), &val0 as _)]; + let input1 = &[TestAttrData::new(None, AttrPath::new(&wc_att), &val1 as _)]; let im = ImEngine::new_default(); let handler = im.handler(); @@ -235,9 +210,6 @@ fn wc_write_attribute() { fn exact_write_attribute() { init_env_logger(); let val0 = 10; - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; let ep0_att = GenericPath::new( Some(0), @@ -245,11 +217,7 @@ fn exact_write_attribute() { Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); - let input = &[AttrData::new( - None, - AttrPath::new(&ep0_att), - EncodeValue::Closure(&attr_data0), - )]; + let input = &[TestAttrData::new(None, AttrPath::new(&ep0_att), &val0 as _)]; let expected_fail = &[AttrStatus::new( &ep0_att, IMStatusCode::UnsupportedAccess, @@ -286,9 +254,6 @@ fn exact_write_attribute() { fn exact_write_attribute_noc_cat() { init_env_logger(); let val0 = 10; - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; let ep0_att = GenericPath::new( Some(0), @@ -296,11 +261,7 @@ fn exact_write_attribute_noc_cat() { Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); - let input = &[AttrData::new( - None, - AttrPath::new(&ep0_att), - EncodeValue::Closure(&attr_data0), - )]; + let input = &[TestAttrData::new(None, AttrPath::new(&ep0_att), &val0 as _)]; let expected_fail = &[AttrStatus::new( &ep0_att, IMStatusCode::UnsupportedAccess, @@ -339,19 +300,12 @@ fn exact_write_attribute_noc_cat() { fn insufficient_perms_write() { init_env_logger(); let val0 = 10; - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); - let input0 = &[AttrData::new( - None, - AttrPath::new(&ep0_att), - EncodeValue::Closure(&attr_data0), - )]; + let input0 = &[TestAttrData::new(None, AttrPath::new(&ep0_att), &val0 as _)]; let im = ImEngine::new_default(); let handler = im.handler(); @@ -404,19 +358,12 @@ fn write_with_runtime_acl_add() { let handler = im.handler(); let val0 = 10; - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; let ep0_att = GenericPath::new( Some(0), Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); - let input0 = AttrData::new( - None, - AttrPath::new(&ep0_att), - EncodeValue::Closure(&attr_data0), - ); + let input0 = TestAttrData::new(None, AttrPath::new(&ep0_att), &val0 as _); // Create ACL to allow our peer ADMIN on everything let mut allow_acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); @@ -427,11 +374,7 @@ fn write_with_runtime_acl_add() { Some(access_control::ID), Some(access_control::AttributesDiscriminants::Acl as u32), ); - let acl_input = AttrData::new( - None, - AttrPath::new(&acl_att), - EncodeValue::Value(&allow_acl), - ); + let acl_input = TestAttrData::new(None, AttrPath::new(&acl_att), &allow_acl); // Create ACL that only allows write to the ACL Cluster let mut basic_acl = AclEntry::new(FAB_1, Privilege::ADMIN, AuthMode::Case); @@ -480,40 +423,26 @@ fn test_read_data_ver() { let input = &[AttrPath::new(&wc_ep_att1)]; let expected = &[ - attr_data_path!( - GenericPath::new( - Some(0), - Some(echo_cluster::ID), - Some(echo_cluster::AttributesDiscriminants::Att1 as u32) - ), - ElementType::U16(0x1234) + attr_data!( + 0, + echo_cluster::ID, + echo_cluster::AttributesDiscriminants::Att1, + Some(&0x1234u16) ), - attr_data_path!( - GenericPath::new( - Some(1), - Some(echo_cluster::ID), - Some(echo_cluster::AttributesDiscriminants::Att1 as u32) - ), - ElementType::U16(0x1234) + attr_data!( + 1, + echo_cluster::ID, + echo_cluster::AttributesDiscriminants::Att1, + Some(&0x1234u16) ), ]; - let mut out = heapless::Vec::new(); - - // Test 1: Simple read to retrieve the current Data Version of Cluster at Endpoint 0 - let received = im.gen_read_reqs_output::<1>(&handler, input, None, &mut out); - assert_attr_report(&received, expected); + // Test 1: Simple read without any data version filters + im.test_one(&handler, TLVTest::read_attrs(input, expected)); - let data_ver_cluster_at_0 = received - .attr_reports - .as_ref() - .unwrap() - .get_index(0) - .unwrap_data() - .data_ver - .unwrap(); + let data_ver_cluster_at_0 = handler.echo_cluster(0).data_ver.get(); - let dataver_filter = [DataVersionFilter { + let dataver_filter = &[DataVersionFilter { path: ClusterPath { node: None, endpoint: 0, @@ -523,23 +452,24 @@ fn test_read_data_ver() { }]; // Test 2: Add Dataversion filter for cluster at endpoint 0 only single entry should be retrieved - let mut out = heapless::Vec::new(); - let received = im.gen_read_reqs_output::<1>( - &handler, - input, - Some(TLVArray::Slice(&dataver_filter)), - &mut out, - ); - let expected_only_one = &[attr_data_path!( - GenericPath::new( - Some(1), - Some(echo_cluster::ID), - Some(echo_cluster::AttributesDiscriminants::Att1 as u32) - ), - ElementType::U16(0x1234) + let expected_only_one = &[attr_data!( + 1, + echo_cluster::ID, + echo_cluster::AttributesDiscriminants::Att1, + Some(&0x1234u16) )]; - assert_attr_report(&received, expected_only_one); + im.test_one( + &handler, + TLVTest::read( + TestReadReq { + dataver_filters: Some(dataver_filter), + ..TestReadReq::reqs(input) + }, + TestReportDataMsg::reports(expected_only_one), + ReplyProcessor::remove_attr_dataver, + ), + ); // Test 3: Exact read attribute let ep0_att1 = GenericPath::new( @@ -548,15 +478,17 @@ fn test_read_data_ver() { Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ); let input = &[AttrPath::new(&ep0_att1)]; - let received = im.gen_read_reqs_output( + im.test_one( &handler, - input, - Some(TLVArray::Slice(&dataver_filter)), - &mut out, + TLVTest::read( + TestReadReq { + dataver_filters: Some(dataver_filter), + ..TestReadReq::reqs(input) + }, + TestReportDataMsg::reports(&[]), + ReplyProcessor::none, + ), ); - let expected_error = &[]; - - assert_attr_report(&received, expected_error); } #[test] @@ -589,16 +521,14 @@ fn test_write_data_ver() { let val0 = 10u16; let val1 = 11u16; - let attr_data0 = EncodeValue::Value(&val0); - let attr_data1 = EncodeValue::Value(&val1); let initial_data_ver = handler.echo_cluster(0).data_ver.get(); // Test 1: Write with correct dataversion should succeed - let input_correct_dataver = &[AttrData::new( + let input_correct_dataver = &[TestAttrData::new( Some(initial_data_ver), AttrPath::new(&ep0_attwrite), - attr_data0, + &val0 as _, )]; im.handle_write_reqs( &handler, @@ -609,10 +539,10 @@ fn test_write_data_ver() { // Test 2: Write with incorrect dataversion should fail // Now the data version would have incremented due to the previous write - let input_correct_dataver = &[AttrData::new( + let input_correct_dataver = &[TestAttrData::new( Some(initial_data_ver), AttrPath::new(&ep0_attwrite), - attr_data1.clone(), + &val1 as _, )]; im.handle_write_reqs( &handler, @@ -630,10 +560,10 @@ fn test_write_data_ver() { // data version would not match let new_data_ver = handler.echo_cluster(0).data_ver.get(); - let input_correct_dataver = &[AttrData::new( + let input_correct_dataver = &[TestAttrData::new( Some(new_data_ver), AttrPath::new(&wc_ep_attwrite), - attr_data1, + &val1 as _, )]; im.handle_write_reqs( &handler, diff --git a/rs-matter/tests/data_model/attribute_lists.rs b/rs-matter/tests/data_model/attribute_lists.rs index 33006810..98098c51 100644 --- a/rs-matter/tests/data_model/attribute_lists.rs +++ b/rs-matter/tests/data_model/attribute_lists.rs @@ -15,21 +15,15 @@ * limitations under the License. */ -use rs_matter::{ - data_model::objects::EncodeValue, - interaction_model::{ - core::IMStatusCode, - messages::ib::{AttrData, AttrPath, AttrStatus}, - messages::GenericPath, - }, - tlv::Nullable, -}; - -use crate::common::{ - echo_cluster::{self, TestChecker}, - im_engine::ImEngine, - init_env_logger, -}; +use rs_matter::interaction_model::core::IMStatusCode; +use rs_matter::interaction_model::messages::ib::{AttrPath, AttrStatus}; +use rs_matter::interaction_model::messages::GenericPath; +use rs_matter::tlv::{Nullable, TLVValue}; + +use crate::common::e2e::im::attributes::TestAttrData; +use crate::common::e2e::im::echo_cluster::{self, TestChecker}; +use crate::common::e2e::ImEngine; +use crate::common::init_env_logger; // Helper for handling Write Attribute sequences #[test] @@ -42,13 +36,8 @@ fn attr_list_ops() { init_env_logger(); - let delete_item = EncodeValue::Closure(&|tag, t| { - let _ = t.null(tag); - }); - let delete_all = EncodeValue::Closure(&|tag, t| { - let _ = t.start_array(tag); - let _ = t.end_container(); - }); + let delete_item = TLVValue::null(); + let delete_all: &[u32] = &[]; let att_data = GenericPath::new( Some(0), @@ -58,11 +47,7 @@ fn attr_list_ops() { let mut att_path = AttrPath::new(&att_data); // Test 1: Add Operation - add val0 - let input = &[AttrData::new( - None, - att_path.clone(), - EncodeValue::Value(&val0), - )]; + let input = &[TestAttrData::new(None, att_path.clone(), &val0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::write_reqs(input, expected); @@ -72,11 +57,7 @@ fn attr_list_ops() { } // Test 2: Another Add Operation - add val1 - let input = &[AttrData::new( - None, - att_path.clone(), - EncodeValue::Value(&val1), - )]; + let input = &[TestAttrData::new(None, att_path.clone(), &val1)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::write_reqs(input, expected); @@ -86,12 +67,8 @@ fn attr_list_ops() { } // Test 3: Edit Operation - edit val1 to val0 - att_path.list_index = Some(Nullable::NotNull(1)); - let input = &[AttrData::new( - None, - att_path.clone(), - EncodeValue::Value(&val0), - )]; + att_path.list_index = Some(Nullable::some(1)); + let input = &[TestAttrData::new(None, att_path.clone(), &val0)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::write_reqs(input, expected); @@ -101,8 +78,8 @@ fn attr_list_ops() { } // Test 4: Delete Operation - delete index 0 - att_path.list_index = Some(Nullable::NotNull(0)); - let input = &[AttrData::new(None, att_path.clone(), delete_item)]; + att_path.list_index = Some(Nullable::some(0)); + let input = &[TestAttrData::new(None, att_path.clone(), &delete_item)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::write_reqs(input, expected); @@ -114,11 +91,7 @@ fn attr_list_ops() { // Test 5: Overwrite Operation - overwrite first 2 entries let overwrite_val: [u32; 2] = [20, 21]; att_path.list_index = None; - let input = &[AttrData::new( - None, - att_path.clone(), - EncodeValue::Value(&overwrite_val), - )]; + let input = &[TestAttrData::new(None, att_path.clone(), &overwrite_val)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::write_reqs(input, expected); @@ -129,7 +102,7 @@ fn attr_list_ops() { // Test 6: Overwrite Operation - delete whole list att_path.list_index = None; - let input = &[AttrData::new(None, att_path, delete_all)]; + let input = &[TestAttrData::new(None, att_path, &delete_all)]; let expected = &[AttrStatus::new(&att_data, IMStatusCode::Success, 0)]; ImEngine::write_reqs(input, expected); diff --git a/rs-matter/tests/data_model/attributes.rs b/rs-matter/tests/data_model/attributes.rs index 1e641af4..07d675b0 100644 --- a/rs-matter/tests/data_model/attributes.rs +++ b/rs-matter/tests/data_model/attributes.rs @@ -15,23 +15,15 @@ * limitations under the License. */ -use rs_matter::{ - data_model::{ - cluster_on_off, - objects::{EncodeValue, GlobalElements}, - }, - interaction_model::{ - core::IMStatusCode, - messages::ib::{AttrData, AttrPath, AttrResp, AttrStatus}, - messages::GenericPath, - }, - tlv::{ElementType, TLVElement, TLVWriter, TagType}, -}; - -use crate::{ - attr_data, attr_data_path, attr_status, - common::{attributes::*, echo_cluster, im_engine::ImEngine, init_env_logger}, -}; +use rs_matter::data_model::{cluster_on_off, objects::GlobalElements}; +use rs_matter::interaction_model::core::IMStatusCode; +use rs_matter::interaction_model::messages::ib::{AttrPath, AttrStatus}; +use rs_matter::interaction_model::messages::GenericPath; + +use crate::common::e2e::im::{attributes::TestAttrData, echo_cluster}; +use crate::common::e2e::ImEngine; +use crate::common::init_env_logger; +use crate::{attr_data, attr_data_path, attr_status}; #[test] fn test_read_success() { @@ -62,12 +54,9 @@ fn test_read_success() { AttrPath::new(&ep1_attcustom), ]; let expected = &[ - attr_data_path!(ep0_att1, ElementType::U16(0x1234)), - attr_data_path!(ep1_att2, ElementType::U16(0x5678)), - attr_data_path!( - ep1_attcustom, - ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) - ), + attr_data_path!(ep0_att1, Some(&0x1234u16)), + attr_data_path!(ep1_att2, Some(&0x5678u16)), + attr_data_path!(ep1_attcustom, Some(&echo_cluster::ATTR_CUSTOM_VALUE)), ]; ImEngine::read_reqs(input, expected); } @@ -138,13 +127,13 @@ fn test_read_wc_endpoint_all_have_clusters() { 0, echo_cluster::ID, echo_cluster::AttributesDiscriminants::Att1, - ElementType::U16(0x1234) + Some(&0x1234u16) ), attr_data!( 1, echo_cluster::ID, echo_cluster::AttributesDiscriminants::Att1, - ElementType::U16(0x1234) + Some(&0x1234u16) ), ]; ImEngine::read_reqs(input, expected); @@ -164,13 +153,11 @@ fn test_read_wc_endpoint_only_1_has_cluster() { ); let input = &[AttrPath::new(&wc_ep_onoff)]; - let expected = &[attr_data_path!( - GenericPath::new( - Some(1), - Some(cluster_on_off::ID), - Some(cluster_on_off::AttributesDiscriminants::OnOff as u32) - ), - ElementType::False + let expected = &[attr_data!( + 1, + cluster_on_off::ID, + cluster_on_off::AttributesDiscriminants::OnOff, + Some(&false) )]; ImEngine::read_reqs(input, expected); } @@ -184,35 +171,23 @@ fn test_read_wc_endpoint_wc_attribute() { let wc_ep_wc_attr = GenericPath::new(None, Some(echo_cluster::ID), None); let input = &[AttrPath::new(&wc_ep_wc_attr)]; - let attr_list = TLVHolder::new_array( - 2, - &[ - GlobalElements::FeatureMap as u16, - GlobalElements::AttributeList as u16, - echo_cluster::AttributesDiscriminants::Att1 as u16, - echo_cluster::AttributesDiscriminants::Att2 as u16, - echo_cluster::AttributesDiscriminants::AttWrite as u16, - echo_cluster::AttributesDiscriminants::AttCustom as u16, - ], - ); - let attr_list_tlv = attr_list.to_tlv(); + let attr_list: &[u16] = &[ + GlobalElements::FeatureMap as u16, + GlobalElements::AttributeList as u16, + echo_cluster::AttributesDiscriminants::Att1 as u16, + echo_cluster::AttributesDiscriminants::Att2 as u16, + echo_cluster::AttributesDiscriminants::AttWrite as u16, + echo_cluster::AttributesDiscriminants::AttCustom as u16, + echo_cluster::AttributesDiscriminants::AttWriteList as u16, + ]; let expected = &[ - attr_data_path!( - GenericPath::new( - Some(0), - Some(echo_cluster::ID), - Some(GlobalElements::FeatureMap as u32), - ), - ElementType::U8(0) - ), - attr_data_path!( - GenericPath::new( - Some(0), - Some(echo_cluster::ID), - Some(GlobalElements::AttributeList as u32), - ), - attr_list_tlv.get_element_type().clone() + attr_data!(0, echo_cluster::ID, GlobalElements::FeatureMap, Some(&0u8)), + attr_data!( + 0, + echo_cluster::ID, + GlobalElements::AttributeList, + Some(&attr_list) ), attr_data_path!( GenericPath::new( @@ -220,7 +195,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ), - ElementType::U16(0x1234) + Some(&0x1234u16) ), attr_data_path!( GenericPath::new( @@ -228,7 +203,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ), - ElementType::U16(0x5678) + Some(&0x5678u16) ), attr_data_path!( GenericPath::new( @@ -236,7 +211,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ), - ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) + Some(&echo_cluster::ATTR_CUSTOM_VALUE) ), attr_data_path!( GenericPath::new( @@ -244,7 +219,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(GlobalElements::FeatureMap as u32), ), - ElementType::U8(0) + Some(&0u8) ), attr_data_path!( GenericPath::new( @@ -252,7 +227,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(GlobalElements::AttributeList as u32), ), - attr_list_tlv.get_element_type().clone() + Some(&attr_list) ), attr_data_path!( GenericPath::new( @@ -260,7 +235,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::Att1 as u32), ), - ElementType::U16(0x1234) + Some(&0x1234u16) ), attr_data_path!( GenericPath::new( @@ -268,7 +243,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::Att2 as u32), ), - ElementType::U16(0x5678) + Some(&0x5678u16) ), attr_data_path!( GenericPath::new( @@ -276,7 +251,7 @@ fn test_read_wc_endpoint_wc_attribute() { Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::AttCustom as u32), ), - ElementType::U32(echo_cluster::ATTR_CUSTOM_VALUE) + Some(&echo_cluster::ATTR_CUSTOM_VALUE) ), ]; ImEngine::read_reqs(input, expected); @@ -290,12 +265,6 @@ fn test_write_success() { let val0 = 10; let val1 = 15; init_env_logger(); - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; - let attr_data1 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val1); - }; let ep0_att = GenericPath::new( Some(0), @@ -309,16 +278,8 @@ fn test_write_success() { ); let input = &[ - AttrData::new( - None, - AttrPath::new(&ep0_att), - EncodeValue::Closure(&attr_data0), - ), - AttrData::new( - None, - AttrPath::new(&ep1_att), - EncodeValue::Closure(&attr_data1), - ), + TestAttrData::new(None, AttrPath::new(&ep0_att), &val0 as _), + TestAttrData::new(None, AttrPath::new(&ep1_att), &val1 as _), ]; let expected = &[ AttrStatus::new(&ep0_att, IMStatusCode::Success, 0), @@ -341,20 +302,13 @@ fn test_write_wc_endpoint() { // - wildcard endpoint, AttWrite let val0 = 10; init_env_logger(); - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; let ep_att = GenericPath::new( None, Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); - let input = &[AttrData::new( - None, - AttrPath::new(&ep_att), - EncodeValue::Closure(&attr_data0), - )]; + let input = &[TestAttrData::new(None, AttrPath::new(&ep_att), &val0 as _)]; let ep0_att = GenericPath::new( Some(0), @@ -394,9 +348,6 @@ fn test_write_unsupported_fields() { init_env_logger(); let val0 = 50; - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; let invalid_endpoint = GenericPath::new( Some(4), @@ -424,41 +375,21 @@ fn test_write_unsupported_fields() { let wc_attribute = GenericPath::new(Some(0), Some(echo_cluster::ID), None); let input = &[ - AttrData::new( - None, - AttrPath::new(&invalid_endpoint), - EncodeValue::Closure(&attr_data0), - ), - AttrData::new( - None, - AttrPath::new(&invalid_cluster), - EncodeValue::Closure(&attr_data0), - ), - AttrData::new( - None, - AttrPath::new(&invalid_attribute), - EncodeValue::Closure(&attr_data0), - ), - AttrData::new( + TestAttrData::new(None, AttrPath::new(&invalid_endpoint), &val0 as _), + TestAttrData::new(None, AttrPath::new(&invalid_cluster), &val0 as _), + TestAttrData::new(None, AttrPath::new(&invalid_attribute), &val0 as _), + TestAttrData::new( None, AttrPath::new(&wc_endpoint_invalid_cluster), - EncodeValue::Closure(&attr_data0), + &val0 as _, ), - AttrData::new( + TestAttrData::new( None, AttrPath::new(&wc_endpoint_invalid_attribute), - EncodeValue::Closure(&attr_data0), - ), - AttrData::new( - None, - AttrPath::new(&wc_cluster), - EncodeValue::Closure(&attr_data0), - ), - AttrData::new( - None, - AttrPath::new(&wc_attribute), - EncodeValue::Closure(&attr_data0), + &val0 as _, ), + TestAttrData::new(None, AttrPath::new(&wc_cluster), &val0 as _), + TestAttrData::new(None, AttrPath::new(&wc_attribute), &val0 as _), ]; let expected = &[ AttrStatus::new(&invalid_endpoint, IMStatusCode::UnsupportedEndpoint, 0), diff --git a/rs-matter/tests/data_model/commands.rs b/rs-matter/tests/data_model/commands.rs index b02a545f..b391f01d 100644 --- a/rs-matter/tests/data_model/commands.rs +++ b/rs-matter/tests/data_model/commands.rs @@ -15,19 +15,15 @@ * limitations under the License. */ -use crate::{ - cmd_data, - common::{commands::*, echo_cluster, im_engine::ImEngine, init_env_logger}, - echo_req, echo_resp, -}; +use rs_matter::data_model::cluster_on_off; +use rs_matter::interaction_model::core::IMStatusCode; +use rs_matter::interaction_model::messages::ib::{CmdPath, CmdStatus}; -use rs_matter::{ - data_model::{cluster_on_off, objects::EncodeValue}, - interaction_model::{ - core::IMStatusCode, - messages::ib::{CmdData, CmdPath, CmdStatus}, - }, -}; +use crate::common::e2e::im::commands::TestCmdResp; +use crate::common::e2e::im::echo_cluster; +use crate::common::e2e::ImEngine; +use crate::common::init_env_logger; +use crate::{cmd_data, echo_req, echo_resp}; #[test] fn test_invoke_cmds_success() { @@ -77,17 +73,17 @@ fn test_invoke_cmds_unsupported_fields() { ]; let expected = &[ - ExpectedInvResp::Status(CmdStatus::new( + TestCmdResp::Status(CmdStatus::new( invalid_endpoint, IMStatusCode::UnsupportedEndpoint, 0, )), - ExpectedInvResp::Status(CmdStatus::new( + TestCmdResp::Status(CmdStatus::new( invalid_cluster, IMStatusCode::UnsupportedCluster, 0, )), - ExpectedInvResp::Status(CmdStatus::new( + TestCmdResp::Status(CmdStatus::new( invalid_command, IMStatusCode::UnsupportedCommand, 0, @@ -128,7 +124,7 @@ fn test_invoke_cmd_wc_endpoint_only_1_has_cluster() { Some(cluster_on_off::CommandsDiscriminants::On as u32), ); let input = &[cmd_data!(target, 1)]; - let expected = &[ExpectedInvResp::Status(CmdStatus::new( + let expected = &[TestCmdResp::Status(CmdStatus::new( expected_path, IMStatusCode::Success, 0, diff --git a/rs-matter/tests/data_model/long_reads.rs b/rs-matter/tests/data_model/long_reads.rs index b3515258..63f72bdb 100644 --- a/rs-matter/tests/data_model/long_reads.rs +++ b/rs-matter/tests/data_model/long_reads.rs @@ -15,391 +15,198 @@ * limitations under the License. */ -use rs_matter::{ - data_model::{ - cluster_basic_information as basic_info, cluster_on_off as onoff, - objects::{EncodeValue, GlobalElements}, - sdm::{ - admin_commissioning as adm_comm, general_commissioning as gen_comm, noc, - nw_commissioning, - }, - system_model::{access_control as acl, descriptor}, - }, - interaction_model::{ - core::{IMStatusCode, OpCode}, - messages::{ - ib::{AttrData, AttrPath, AttrResp}, - msg::{ReadReq, ReportDataMsg, StatusResp, SubscribeResp}, - }, - messages::{msg::SubscribeReq, GenericPath}, - }, - tlv::{self, ElementType, FromTLV, TLVElement, TagType}, +use rs_matter::data_model::objects::GlobalElements; +use rs_matter::data_model::sdm::{ + admin_commissioning as adm_comm, general_commissioning as gen_comm, noc, nw_commissioning, }; +use rs_matter::data_model::system_model::{access_control as acl, descriptor}; +use rs_matter::data_model::{cluster_basic_information as basic_info, cluster_on_off as onoff}; +use rs_matter::interaction_model::core::IMStatusCode; +use rs_matter::interaction_model::messages::ib::AttrPath; +use rs_matter::interaction_model::messages::msg::{StatusResp, SubscribeResp}; +use rs_matter::interaction_model::messages::GenericPath; -use crate::{ - attr_data, - common::{ - attributes::*, - echo_cluster as echo, - im_engine::{ImEngine, ImInput}, - init_env_logger, - }, -}; - -fn wildcard_read_resp(part: u8) -> Vec> { - // For brevity, we only check the AttrPath, not the actual 'data' - let dont_care = ElementType::U8(0); - let part1 = vec![ - attr_data!(0, 29, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(0, 29, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 0, - 29, - descriptor::Attributes::DeviceTypeList, - dont_care.clone() - ), - attr_data!(0, 29, descriptor::Attributes::ServerList, dont_care.clone()), - attr_data!(0, 29, descriptor::Attributes::PartsList, dont_care.clone()), - attr_data!(0, 29, descriptor::Attributes::ClientList, dont_care.clone()), - attr_data!(0, 40, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(0, 40, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::DMRevision, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::VendorName, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::VendorId, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::ProductName, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::ProductId, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::NodeLabel, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::HwVer, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::SwVer, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::SwVerString, - dont_care.clone() - ), - attr_data!( - 0, - 40, - basic_info::AttributesDiscriminants::SerialNo, - dont_care.clone() - ), - attr_data!(0, 48, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(0, 48, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 0, - 48, - gen_comm::AttributesDiscriminants::BreadCrumb, - dont_care.clone() - ), - attr_data!( - 0, - 48, - gen_comm::AttributesDiscriminants::RegConfig, - dont_care.clone() - ), - attr_data!( - 0, - 48, - gen_comm::AttributesDiscriminants::LocationCapability, - dont_care.clone() - ), - attr_data!( - 0, - 48, - gen_comm::AttributesDiscriminants::BasicCommissioningInfo, - dont_care.clone() - ), - attr_data!( - 0, - 48, - gen_comm::AttributesDiscriminants::SupportsConcurrentConnection, - dont_care.clone() - ), - attr_data!(0, 49, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(0, 49, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 0, - 49, - nw_commissioning::Attributes::MaxNetworks, - dont_care.clone() - ), - attr_data!( - 0, - 49, - nw_commissioning::Attributes::Networks, - dont_care.clone() - ), - attr_data!( - 0, - 49, - nw_commissioning::Attributes::ConnectMaxTimeSecs, - dont_care.clone() - ), - attr_data!( - 0, - 49, - nw_commissioning::Attributes::InterfaceEnabled, - dont_care.clone() - ), - attr_data!( - 0, - 49, - nw_commissioning::Attributes::LastNetworkingStatus, - dont_care.clone() - ), - attr_data!( - 0, - 49, - nw_commissioning::Attributes::LastNetworkID, - dont_care.clone() - ), - attr_data!( - 0, - 49, - nw_commissioning::Attributes::LastConnectErrorValue, - dont_care.clone() - ), - attr_data!(0, 60, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(0, 60, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 0, - 60, - adm_comm::AttributesDiscriminants::WindowStatus, - dont_care.clone() - ), - ]; +use crate::attr_data; +use crate::common::e2e::im::attributes::TestAttrResp; +use crate::common::e2e::im::{echo_cluster as echo, ReplyProcessor, TestSubscribeReq}; +use crate::common::e2e::im::{TestReadReq, TestReportDataMsg}; +use crate::common::e2e::test::E2eTest; +use crate::common::e2e::tlv::TLVTest; +use crate::common::e2e::ImEngine; +use crate::common::init_env_logger; - let part2 = vec![ - attr_data!( - 0, - 60, - adm_comm::AttributesDiscriminants::AdminFabricIndex, - dont_care.clone() - ), - attr_data!( - 0, - 60, - adm_comm::AttributesDiscriminants::AdminVendorId, - dont_care.clone() - ), - attr_data!(0, 62, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(0, 62, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 0, - 62, - noc::AttributesDiscriminants::CurrentFabricIndex, - dont_care.clone() - ), - attr_data!( - 0, - 62, - noc::AttributesDiscriminants::Fabrics, - dont_care.clone() - ), - attr_data!( - 0, - 62, - noc::AttributesDiscriminants::SupportedFabrics, - dont_care.clone() - ), - attr_data!( - 0, - 62, - noc::AttributesDiscriminants::CommissionedFabrics, - dont_care.clone() - ), - attr_data!(0, 31, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(0, 31, GlobalElements::AttributeList, dont_care.clone()), - attr_data!(0, 31, acl::AttributesDiscriminants::Acl, dont_care.clone()), - attr_data!( - 0, - 31, - acl::AttributesDiscriminants::Extension, - dont_care.clone() - ), - attr_data!( - 0, - 31, - acl::AttributesDiscriminants::SubjectsPerEntry, - dont_care.clone() - ), - attr_data!( - 0, - 31, - acl::AttributesDiscriminants::TargetsPerEntry, - dont_care.clone() - ), - attr_data!( - 0, - 31, - acl::AttributesDiscriminants::EntriesPerFabric, - dont_care.clone() - ), - attr_data!(0, echo::ID, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!( - 0, - echo::ID, - GlobalElements::AttributeList, - dont_care.clone() - ), - attr_data!( - 0, - echo::ID, - echo::AttributesDiscriminants::Att1, - dont_care.clone() - ), - attr_data!( - 0, - echo::ID, - echo::AttributesDiscriminants::Att2, - dont_care.clone() - ), - attr_data!( - 0, - echo::ID, - echo::AttributesDiscriminants::AttCustom, - dont_care.clone() - ), - attr_data!(1, 29, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(1, 29, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 1, - 29, - descriptor::Attributes::DeviceTypeList, - dont_care.clone() - ), - attr_data!(1, 29, descriptor::Attributes::ServerList, dont_care.clone()), - attr_data!(1, 29, descriptor::Attributes::PartsList, dont_care.clone()), - attr_data!(1, 29, descriptor::Attributes::ClientList, dont_care.clone()), - attr_data!(1, 6, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!(1, 6, GlobalElements::AttributeList, dont_care.clone()), - attr_data!( - 1, - 6, - onoff::AttributesDiscriminants::OnOff, - dont_care.clone() - ), - attr_data!(1, echo::ID, GlobalElements::FeatureMap, dont_care.clone()), - attr_data!( - 1, - echo::ID, - GlobalElements::AttributeList, - dont_care.clone() - ), - attr_data!( - 1, - echo::ID, - echo::AttributesDiscriminants::Att1, - dont_care.clone() - ), - attr_data!( - 1, - echo::ID, - echo::AttributesDiscriminants::Att2, - dont_care.clone() - ), - attr_data!( - 1, - echo::ID, - echo::AttributesDiscriminants::AttCustom, - dont_care - ), - ]; +static PART_1: &[TestAttrResp<'static>] = &[ + attr_data!(0, 29, GlobalElements::FeatureMap, None), + attr_data!(0, 29, GlobalElements::AttributeList, None), + attr_data!(0, 29, descriptor::Attributes::DeviceTypeList, None), + attr_data!(0, 29, descriptor::Attributes::ServerList, None), + attr_data!(0, 29, descriptor::Attributes::PartsList, None), + attr_data!(0, 29, descriptor::Attributes::ClientList, None), + attr_data!(0, 40, GlobalElements::FeatureMap, None), + attr_data!(0, 40, GlobalElements::AttributeList, None), + attr_data!(0, 40, basic_info::AttributesDiscriminants::DMRevision, None), + attr_data!(0, 40, basic_info::AttributesDiscriminants::VendorName, None), + attr_data!(0, 40, basic_info::AttributesDiscriminants::VendorId, None), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::ProductName, + None + ), + attr_data!(0, 40, basic_info::AttributesDiscriminants::ProductId, None), + attr_data!(0, 40, basic_info::AttributesDiscriminants::NodeLabel, None), + attr_data!(0, 40, basic_info::AttributesDiscriminants::HwVer, None), + attr_data!(0, 40, basic_info::AttributesDiscriminants::SwVer, None), + attr_data!( + 0, + 40, + basic_info::AttributesDiscriminants::SwVerString, + None + ), + attr_data!(0, 40, basic_info::AttributesDiscriminants::SerialNo, None), + attr_data!(0, 48, GlobalElements::FeatureMap, None), + attr_data!(0, 48, GlobalElements::AttributeList, None), + attr_data!(0, 48, gen_comm::AttributesDiscriminants::BreadCrumb, None), + attr_data!(0, 48, gen_comm::AttributesDiscriminants::RegConfig, None), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::LocationCapability, + None + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::BasicCommissioningInfo, + None + ), + attr_data!( + 0, + 48, + gen_comm::AttributesDiscriminants::SupportsConcurrentConnection, + None + ), + attr_data!(0, 49, GlobalElements::FeatureMap, None), + attr_data!(0, 49, GlobalElements::AttributeList, None), + attr_data!(0, 49, nw_commissioning::Attributes::MaxNetworks, None), + attr_data!(0, 49, nw_commissioning::Attributes::Networks, None), + attr_data!( + 0, + 49, + nw_commissioning::Attributes::ConnectMaxTimeSecs, + None + ), + attr_data!(0, 49, nw_commissioning::Attributes::InterfaceEnabled, None), + attr_data!( + 0, + 49, + nw_commissioning::Attributes::LastNetworkingStatus, + None + ), + attr_data!(0, 49, nw_commissioning::Attributes::LastNetworkID, None), + attr_data!( + 0, + 49, + nw_commissioning::Attributes::LastConnectErrorValue, + None + ), + attr_data!(0, 60, GlobalElements::FeatureMap, None), + attr_data!(0, 60, GlobalElements::AttributeList, None), + attr_data!(0, 60, adm_comm::AttributesDiscriminants::WindowStatus, None), + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::AdminFabricIndex, + None + ), +]; - if part == 1 { - part1 - } else { - part2 - } -} +static PART_2: &[TestAttrResp<'static>] = &[ + attr_data!( + 0, + 60, + adm_comm::AttributesDiscriminants::AdminVendorId, + None + ), + attr_data!(0, 62, GlobalElements::FeatureMap, None), + attr_data!(0, 62, GlobalElements::AttributeList, None), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::CurrentFabricIndex, + None + ), + attr_data!(0, 62, noc::AttributesDiscriminants::Fabrics, None), + attr_data!(0, 62, noc::AttributesDiscriminants::SupportedFabrics, None), + attr_data!( + 0, + 62, + noc::AttributesDiscriminants::CommissionedFabrics, + None + ), + attr_data!(0, 31, GlobalElements::FeatureMap, None), + attr_data!(0, 31, GlobalElements::AttributeList, None), + attr_data!(0, 31, acl::AttributesDiscriminants::Acl, None), + attr_data!(0, 31, acl::AttributesDiscriminants::Extension, None), + attr_data!(0, 31, acl::AttributesDiscriminants::SubjectsPerEntry, None), + attr_data!(0, 31, acl::AttributesDiscriminants::TargetsPerEntry, None), + attr_data!(0, 31, acl::AttributesDiscriminants::EntriesPerFabric, None), + attr_data!(0, echo::ID, GlobalElements::FeatureMap, None), + attr_data!(0, echo::ID, GlobalElements::AttributeList, None), + attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att1, None), + attr_data!(0, echo::ID, echo::AttributesDiscriminants::Att2, None), + attr_data!(0, echo::ID, echo::AttributesDiscriminants::AttCustom, None), + attr_data!(1, 29, GlobalElements::FeatureMap, None), + attr_data!(1, 29, GlobalElements::AttributeList, None), + attr_data!(1, 29, descriptor::Attributes::DeviceTypeList, None), + attr_data!(1, 29, descriptor::Attributes::ServerList, None), + attr_data!(1, 29, descriptor::Attributes::PartsList, None), + attr_data!(1, 29, descriptor::Attributes::ClientList, None), + attr_data!(1, 6, GlobalElements::FeatureMap, None), + attr_data!(1, 6, GlobalElements::AttributeList, None), + attr_data!(1, 6, onoff::AttributesDiscriminants::OnOff, None), + attr_data!(1, echo::ID, GlobalElements::FeatureMap, None), + attr_data!(1, echo::ID, GlobalElements::AttributeList, None), + attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att1, None), + attr_data!(1, echo::ID, echo::AttributesDiscriminants::Att2, None), + attr_data!(1, echo::ID, echo::AttributesDiscriminants::AttCustom, None), +]; #[test] fn test_long_read_success() { // Read the entire attribute database, which requires 2 reads to complete init_env_logger(); - let mut out = heapless::Vec::<_, 3>::new(); let im = ImEngine::new_default(); let handler = im.handler(); im.add_default_acl(); - let wc_path = GenericPath::new(None, None, None); - - let read_all = [AttrPath::new(&wc_path)]; - let read_req = ReadReq::new(true).set_attr_requests(&read_all); - let expected_part1 = wildcard_read_resp(1); - - let status_report = StatusResp { - status: IMStatusCode::Success, - }; - let expected_part2 = wildcard_read_resp(2); - - im.process( + im.test_all( &handler, - &[ - &ImInput::new(OpCode::ReadRequest, &read_req), - &ImInput::new(OpCode::StatusResponse, &status_report), + [ + &TLVTest::read( + TestReadReq::reqs(&[AttrPath::new(&GenericPath::new(None, None, None))]), + TestReportDataMsg { + attr_reports: Some(PART_1), + more_chunks: Some(true), + ..Default::default() + }, + ReplyProcessor::remove_attr_data, + ) as &dyn E2eTest, + &TLVTest::continue_report( + StatusResp { + status: IMStatusCode::Success, + }, + TestReportDataMsg { + attr_reports: Some(PART_2), + suppress_response: Some(true), + ..Default::default() + }, + ReplyProcessor::remove_attr_data, + ), ], - &mut out, - ) - .unwrap(); - - assert_eq!(out.len(), 2); - - assert_eq!(out[0].action, OpCode::ReportData); - - let root = tlv::get_root_node_struct(&out[0].data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part1); - assert_eq!(report_data.more_chunks, Some(true)); - - assert_eq!(out[1].action, OpCode::ReportData); - - let root = tlv::get_root_node_struct(&out[1].data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part2); - assert_eq!(report_data.more_chunks, None); + ); } #[test] @@ -407,53 +214,50 @@ fn test_long_read_subscription_success() { // Subscribe to the entire attribute database, which requires 2 reads to complete init_env_logger(); - let mut out = heapless::Vec::<_, 3>::new(); let im = ImEngine::new_default(); let handler = im.handler(); im.add_default_acl(); - let wc_path = GenericPath::new(None, None, None); - - let read_all = [AttrPath::new(&wc_path)]; - let subs_req = SubscribeReq::new(true, 1, 20).set_attr_requests(&read_all); - let expected_part1 = wildcard_read_resp(1); - - let status_report = StatusResp { - status: IMStatusCode::Success, - }; - let expected_part2 = wildcard_read_resp(2); - - im.process( + im.test_all( &handler, - &[ - &ImInput::new(OpCode::SubscribeRequest, &subs_req), - &ImInput::new(OpCode::StatusResponse, &status_report), - &ImInput::new(OpCode::StatusResponse, &status_report), + [ + &TLVTest::subscribe( + TestSubscribeReq { + min_int_floor: 1, + max_int_ceil: 10, + ..TestSubscribeReq::reqs(&[AttrPath::new(&GenericPath::new(None, None, None))]) + }, + TestReportDataMsg { + subscription_id: Some(1), + attr_reports: Some(PART_1), + more_chunks: Some(true), + ..Default::default() + }, + ReplyProcessor::remove_attr_data, + ) as &dyn E2eTest, + &TLVTest::continue_report( + StatusResp { + status: IMStatusCode::Success, + }, + TestReportDataMsg { + subscription_id: Some(1), + attr_reports: Some(PART_2), + ..Default::default() + }, + ReplyProcessor::remove_attr_data, + ), + &TLVTest::subscribe_final( + StatusResp { + status: IMStatusCode::Success, + }, + SubscribeResp { + subs_id: 1, + max_int: 40, + ..Default::default() + }, + ReplyProcessor::none, + ), ], - &mut out, - ) - .unwrap(); - - assert_eq!(out.len(), 3); - - assert_eq!(out[0].action, OpCode::ReportData); - - let root = tlv::get_root_node_struct(&out[0].data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part1); - assert_eq!(report_data.more_chunks, Some(true)); - - assert_eq!(out[1].action, OpCode::ReportData); - - let root = tlv::get_root_node_struct(&out[1].data).unwrap(); - let report_data = ReportDataMsg::from_tlv(&root).unwrap(); - assert_attr_report_skip_data(&report_data, &expected_part2); - assert_eq!(report_data.more_chunks, None); - - assert_eq!(out[2].action, OpCode::SubscribeResponse); - - let root = tlv::get_root_node_struct(&out[2].data).unwrap(); - let subs_resp = SubscribeResp::from_tlv(&root).unwrap(); - assert_eq!(subs_resp.subs_id, 1); + ); } diff --git a/rs-matter/tests/data_model/mod.rs b/rs-matter/tests/data_model/mod.rs new file mode 100644 index 00000000..85e37cca --- /dev/null +++ b/rs-matter/tests/data_model/mod.rs @@ -0,0 +1,23 @@ +/* + * + * Copyright (c) 2020-2022 Project CHIP Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +mod acl_and_dataver; +mod attribute_lists; +mod attributes; +mod commands; +mod long_reads; +mod timed_requests; diff --git a/rs-matter/tests/data_model/timed_requests.rs b/rs-matter/tests/data_model/timed_requests.rs index 1cce0a9b..f5af2077 100644 --- a/rs-matter/tests/data_model/timed_requests.rs +++ b/rs-matter/tests/data_model/timed_requests.rs @@ -15,47 +15,36 @@ * limitations under the License. */ -use rs_matter::{ - data_model::objects::EncodeValue, - interaction_model::{ - core::IMStatusCode, - messages::ib::{AttrData, AttrPath, AttrStatus}, - messages::{ib::CmdData, ib::CmdPath, GenericPath}, - }, - tlv::TLVWriter, -}; +use rs_matter::interaction_model::core::{IMStatusCode, OpCode, PROTO_ID_INTERACTION_MODEL}; +use rs_matter::interaction_model::messages::ib::{AttrPath, AttrStatus}; +use rs_matter::interaction_model::messages::msg::{StatusResp, TimedReq}; +use rs_matter::interaction_model::messages::GenericPath; +use rs_matter::transport::exchange::MessageMeta; -use crate::{ - common::{ - commands::*, - echo_cluster, - handlers::{TimedInvResponse, WriteResponse}, - im_engine::ImEngine, - init_env_logger, - }, - echo_req, echo_resp, +use crate::common::e2e::im::attributes::TestAttrData; +use crate::common::e2e::im::{ + echo_cluster, ReplyProcessor, TestInvReq, TestInvResp, TestWriteReq, TestWriteResp, }; +use crate::common::e2e::test::E2eTest; +use crate::common::e2e::tlv::TLVTest; +use crate::common::e2e::ImEngine; +use crate::common::init_env_logger; +use crate::{echo_req, echo_resp}; #[test] fn test_timed_write_fail_and_success() { + // - 2 Timed Attr Write Transactions should fail due to timeout mismatch // - 1 Timed Attr Write Transaction should fail due to timeout // - 1 Timed Attr Write Transaction should succeed let val0 = 10; init_env_logger(); - let attr_data0 = |tag, t: &mut TLVWriter| { - let _ = t.u16(tag, val0); - }; let ep_att = GenericPath::new( None, Some(echo_cluster::ID), Some(echo_cluster::AttributesDiscriminants::AttWrite as u32), ); - let input = &[AttrData::new( - None, - AttrPath::new(&ep_att), - EncodeValue::Closure(&attr_data0), - )]; + let input = &[TestAttrData::new(None, AttrPath::new(&ep_att), &val0 as _)]; let ep0_att = GenericPath::new( Some(0), @@ -73,20 +62,134 @@ fn test_timed_write_fail_and_success() { AttrStatus::new(&ep1_att, IMStatusCode::Success, 0), ]; - // Test with incorrect handling - ImEngine::timed_write_reqs(input, &WriteResponse::TransactionError, 100, 500); - - // Test with correct handling let im = ImEngine::new_default(); let handler = im.handler(); im.add_default_acl(); - im.handle_timed_write_reqs( + + // Test with timeout mismatch (timeout not set, but the following write req is timed) + im.test_one( + &handler, + TLVTest { + delay_ms: None, + input_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::WriteRequest as _, + reliable: true, + }, + input_payload: TestWriteReq { + timed_request: Some(true), + ..TestWriteReq::reqs(input) + }, + expected_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::StatusResponse as _, + reliable: true, + }, + expected_payload: StatusResp { + status: IMStatusCode::TimedRequestMisMatch, + }, + process_reply: ReplyProcessor::none, + }, + ); + + // Test with timeout mismatch (timeout set, but the write req is not timed) + im.test_all( + &handler, + [ + &TLVTest { + delay_ms: Some(100), + ..TLVTest::timed( + TimedReq { timeout: 1 }, + StatusResp { + status: IMStatusCode::Success, + }, + ReplyProcessor::none, + ) + } as &dyn E2eTest, + &TLVTest { + delay_ms: None, + input_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::WriteRequest as _, + reliable: true, + }, + input_payload: TestWriteReq::reqs(input), + expected_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::StatusResponse as _, + reliable: true, + }, + expected_payload: StatusResp { + status: IMStatusCode::TimedRequestMisMatch, + }, + process_reply: ReplyProcessor::none, + }, + ], + ); + + // Test with incorrect handling + im.test_all( + &handler, + [ + &TLVTest { + delay_ms: Some(100), + ..TLVTest::timed( + TimedReq { timeout: 1 }, + StatusResp { + status: IMStatusCode::Success, + }, + ReplyProcessor::none, + ) + } as &dyn E2eTest, + &TLVTest { + delay_ms: None, + input_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::WriteRequest as _, + reliable: true, + }, + input_payload: TestWriteReq { + timed_request: Some(true), + ..TestWriteReq::reqs(input) + }, + expected_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::StatusResponse as _, + reliable: true, + }, + expected_payload: StatusResp { + status: IMStatusCode::Timeout, + }, + process_reply: ReplyProcessor::none, + }, + ], + ); + + // Test with correct handling + im.test_all( &handler, - input, - &WriteResponse::TransactionSuccess(expected), - 400, - 0, + [ + &TLVTest { + delay_ms: None, + ..TLVTest::timed( + TimedReq { timeout: 500 }, + StatusResp { + status: IMStatusCode::Success, + }, + ReplyProcessor::none, + ) + } as &dyn E2eTest, + &TLVTest::write( + TestWriteReq { + timed_request: Some(true), + ..TestWriteReq::reqs(input) + }, + TestWriteResp::resp(expected), + ReplyProcessor::none, + ), + ], ); + assert_eq!(val0, handler.echo_cluster(0).att_write.get()); } @@ -97,50 +200,155 @@ fn test_timed_cmd_success() { let input = &[echo_req!(0, 5), echo_req!(1, 10)]; let expected = &[echo_resp!(0, 10), echo_resp!(1, 30)]; - ImEngine::timed_commands( - input, - &TimedInvResponse::TransactionSuccess(expected), - 2000, - 0, - true, + + let im = ImEngine::new_default(); + let handler = im.handler(); + im.add_default_acl(); + + // Test with correct handling + im.test_all( + &handler, + [ + &TLVTest { + delay_ms: None, + ..TLVTest::timed( + TimedReq { timeout: 2000 }, + StatusResp { + status: IMStatusCode::Success, + }, + ReplyProcessor::none, + ) + } as &dyn E2eTest, + &TLVTest::invoke( + TestInvReq { + timed_request: Some(true), + ..TestInvReq::reqs(input) + }, + TestInvResp::resp(expected), + ReplyProcessor::none, + ), + ], ); } #[test] fn test_timed_cmd_timeout() { - // A timed request that is executed after t imeout + // A timed request that is executed after a timeout init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::timed_commands( - input, - &TimedInvResponse::TransactionError(IMStatusCode::Timeout), - 100, - 500, - true, + + let im = ImEngine::new_default(); + let handler = im.handler(); + im.add_default_acl(); + + im.test_all( + &handler, + [ + &TLVTest { + delay_ms: Some(500), + ..TLVTest::timed( + TimedReq { timeout: 1 }, + StatusResp { + status: IMStatusCode::Success, + }, + ReplyProcessor::none, + ) + } as &dyn E2eTest, + &TLVTest { + delay_ms: None, + input_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::InvokeRequest as _, + reliable: true, + }, + input_payload: TestInvReq { + timed_request: Some(true), + ..TestInvReq::reqs(input) + }, + expected_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::StatusResponse as _, + reliable: true, + }, + expected_payload: StatusResp { + status: IMStatusCode::Timeout, + }, + process_reply: ReplyProcessor::none, + }, + ], ); } #[test] -fn test_timed_cmd_timedout_mismatch() { +fn test_timed_cmd_timeout_mismatch() { // A timed request with timeout mismatch init_env_logger(); let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::timed_commands( - input, - &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), - 2000, - 0, - false, + + let im = ImEngine::new_default(); + let handler = im.handler(); + im.add_default_acl(); + + // Test with timeout mismatch (timeout not set, but the following write req is timed) + im.test_one( + &handler, + TLVTest { + delay_ms: None, + input_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::InvokeRequest as _, + reliable: true, + }, + input_payload: TestInvReq { + timed_request: Some(true), + ..TestInvReq::reqs(input) + }, + expected_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::StatusResponse as _, + reliable: true, + }, + expected_payload: StatusResp { + status: IMStatusCode::TimedRequestMisMatch, + }, + process_reply: ReplyProcessor::none, + }, ); - let input = &[echo_req!(0, 5), echo_req!(1, 10)]; - ImEngine::timed_commands( - input, - &TimedInvResponse::TransactionError(IMStatusCode::TimedRequestMisMatch), - 0, - 0, - true, + // Test with timeout mismatch (timeout set, but the following write req is timed) + im.test_all( + &handler, + [ + &TLVTest { + delay_ms: None, + ..TLVTest::timed( + TimedReq { timeout: 1 }, + StatusResp { + status: IMStatusCode::Success, + }, + ReplyProcessor::none, + ) + } as &dyn E2eTest, + &TLVTest { + delay_ms: None, + input_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::InvokeRequest as _, + reliable: true, + }, + input_payload: TestInvReq::reqs(input), + expected_meta: MessageMeta { + proto_id: PROTO_ID_INTERACTION_MODEL, + proto_opcode: OpCode::StatusResponse as _, + reliable: true, + }, + expected_payload: StatusResp { + status: IMStatusCode::TimedRequestMisMatch, + }, + process_reply: ReplyProcessor::none, + }, + ], ); } diff --git a/rs-matter/tests/data_model_tests.rs b/rs-matter/tests/data_model_tests.rs index 392909fa..be359ead 100644 --- a/rs-matter/tests/data_model_tests.rs +++ b/rs-matter/tests/data_model_tests.rs @@ -16,12 +16,4 @@ */ mod common; - -mod data_model { - mod acl_and_dataver; - mod attribute_lists; - mod attributes; - mod commands; - mod long_reads; - mod timed_requests; -} +mod data_model; diff --git a/rs-matter/tests/tlv_encoding.rs b/rs-matter/tests/tlv_encoding.rs index 99a0b003..98659db9 100644 --- a/rs-matter/tests/tlv_encoding.rs +++ b/rs-matter/tests/tlv_encoding.rs @@ -19,7 +19,7 @@ mod tlv_encoding_tests { use bitflags::bitflags; use rs_matter::bitflags_tlv; use rs_matter::error::Error; - use rs_matter::tlv::{get_root_node, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; + use rs_matter::tlv::{FromTLV, TLVElement, TLVTag, TLVWriter, ToTLV}; use rs_matter::utils::storage::WriteBuf; #[derive(PartialEq, Debug, ToTLV, FromTLV)] @@ -48,14 +48,13 @@ mod tlv_encoding_tests { let mut output_buffer = [0u8; MAX_OUTPUT_SIZE]; let mut write_buf = WriteBuf::new(&mut output_buffer); let mut writer = TLVWriter::new(&mut write_buf); - what.to_tlv(&mut writer, TagType::Anonymous)?; + what.to_tlv(&TLVTag::Anonymous, &mut writer)?; Ok(Vec::from(write_buf.as_slice())) } fn decode_from_tlv<'a, T: FromTLV<'a>>(data: &'a [u8]) -> Result { - let node = get_root_node(data)?; - T::from_tlv(&node) + T::from_tlv(&TLVElement::new(data)) } macro_rules! asserted_ok { diff --git a/tools/tlv/src/main.rs b/tools/tlv/src/main.rs index caf5fc15..281bbbc1 100644 --- a/tools/tlv/src/main.rs +++ b/tools/tlv/src/main.rs @@ -67,15 +67,18 @@ fn main() { let tlv_list = base.parse_list(m.value_of("tlvs").unwrap(), ','); // println!("Decoding: {:x?}", tlv_list.as_slice()); + + let tlv = tlv::TLVElement::new(tlv_list.as_slice()); + if m.is_present("cert") { - let cert = cert::Cert::new(tlv_list.as_slice()).unwrap(); + let cert = cert::CertRef::new(tlv); println!("{}", cert); } else if m.is_present("as-asn1") { let mut asn1_cert = [0_u8; 1024]; - let cert = cert::Cert::new(tlv_list.as_slice()).unwrap(); + let cert = cert::CertRef::new(tlv); let len = cert.as_asn1(&mut asn1_cert).unwrap(); println!("{:02x?}", &asn1_cert[..len]); } else { - tlv::print_tlv_list(tlv_list.as_slice()); + println!("TLV {tlv}"); } }