From f7ddc58f9afcda505017ed0559bde251f4d5e522 Mon Sep 17 00:00:00 2001 From: Chris Beck Date: Mon, 14 Oct 2024 10:56:21 -0600 Subject: [PATCH] allow setting multiple validation_predicate fixes #11 --- REFERENCE_derive_conf.md | 2 + .../src/proc_macro_options/struct_item.rs | 15 ++-- tests/test_validation_predicate.rs | 72 +++++++++++++++++++ 3 files changed, 81 insertions(+), 8 deletions(-) diff --git a/REFERENCE_derive_conf.md b/REFERENCE_derive_conf.md index 4e76dc9..8924442 100644 --- a/REFERENCE_derive_conf.md +++ b/REFERENCE_derive_conf.md @@ -697,6 +697,8 @@ works on that struct. Attributes that are not "top-level only" will still have a Creates a validation constraint that must be satisfied after parsing this struct succeeds, from a user-defined function. The function should have signature `fn(&T) -> Result<(), impl Display>`. + The `validation_prediate = ...` attribute is allowed to repeat multiple times, to set multiple validation prediates. + [^1]: Actually, the *tokens* of the type are used, so e.g. it must be `bool` and not an alias for `bool`. diff --git a/conf_derive/src/proc_macro_options/struct_item.rs b/conf_derive/src/proc_macro_options/struct_item.rs index 49b8236..b82e8f0 100644 --- a/conf_derive/src/proc_macro_options/struct_item.rs +++ b/conf_derive/src/proc_macro_options/struct_item.rs @@ -51,7 +51,7 @@ pub struct StructItem { pub env_prefix: Option, pub serde: Option, pub one_of_fields: Vec<(Ordering, List)>, - pub validation_predicate: Option, + pub validation_predicates: Vec, pub doc_string: Option, } @@ -66,7 +66,7 @@ impl StructItem { env_prefix: None, serde: None, one_of_fields: Vec::default(), - validation_predicate: None, + validation_predicates: Vec::default(), doc_string: None, }; @@ -99,11 +99,10 @@ impl StructItem { } else if path.is_ident("serde") { set_once(&path, &mut result.serde, Some(StructSerdeItem::new(meta)?)) } else if path.is_ident("validation_predicate") { - set_once( - &path, - &mut result.validation_predicate, - Some(parse_required_value::(meta)?), - ) + result + .validation_predicates + .push(parse_required_value::(meta)?); + Ok(()) } else if path.is_ident("one_of_fields") { let idents: List = meta.input.parse()?; if idents.elements.len() < 2 { @@ -288,7 +287,7 @@ impl StructItem { } // Apply user-provided validation predicate, if any - if let Some(user_validation_predicate) = self.validation_predicate.as_ref() { + for user_validation_predicate in self.validation_predicates.iter() { predicate_evaluations.push(quote! { { fn __validation_predicate__( diff --git a/tests/test_validation_predicate.rs b/tests/test_validation_predicate.rs index e5e5409..7db6b28 100644 --- a/tests/test_validation_predicate.rs +++ b/tests/test_validation_predicate.rs @@ -63,3 +63,75 @@ fn test_validate_predicate_two_of_parsing() { assert!(result.b); assert!(result.c); } + +#[derive(Conf, Debug)] +#[conf(validation_predicate = MultiConstraint::b_required_if, validation_predicate = MultiConstraint::c_required_if)] +struct MultiConstraint { + #[arg(short)] + a: Option, + #[arg(short)] + b: Option, + #[arg(short)] + c: Option, +} + +impl MultiConstraint { + fn b_required_if(&self) -> Result<(), &'static str> { + if self.a == Some("b".to_owned()) && self.b.is_none() { + return Err("b is required if a = 'b'"); + } + Ok(()) + } + + fn c_required_if(&self) -> Result<(), &'static str> { + if self.b == Some("c".to_owned()) && self.c.is_none() { + return Err("c is required if b = 'c'"); + } + Ok(()) + } +} + +#[test] +fn test_multiple_validate_predicates() { + let result = MultiConstraint::try_parse_from::<&str, &str, &str>(vec!["."], vec![]).unwrap(); + assert_eq!(result.a, None); + assert_eq!(result.b, None); + assert_eq!(result.c, None); + + let result = + MultiConstraint::try_parse_from::<&str, &str, &str>(vec![".", "-a", "x"], vec![]).unwrap(); + assert_eq!(result.a, Some("x".to_owned())); + assert_eq!(result.b, None); + assert_eq!(result.c, None); + + assert_error_contains_text!( + MultiConstraint::try_parse_from::<&str, &str, &str>(vec![".", "-a", "b"], vec![]), + ["b is required if a = 'b'"] + ); + + let result = MultiConstraint::try_parse_from::<&str, &str, &str>( + vec![".", "-a", "b", "-b", "x"], + vec![], + ) + .unwrap(); + assert_eq!(result.a, Some("b".to_owned())); + assert_eq!(result.b, Some("x".to_owned())); + assert_eq!(result.c, None); + + assert_error_contains_text!( + MultiConstraint::try_parse_from::<&str, &str, &str>( + vec![".", "-a", "b", "-b", "c"], + vec![] + ), + ["c is required if b = 'c'"] + ); + + let result = MultiConstraint::try_parse_from::<&str, &str, &str>( + vec![".", "-a", "b", "-b", "c", "-c", "x"], + vec![], + ) + .unwrap(); + assert_eq!(result.a, Some("b".to_owned())); + assert_eq!(result.b, Some("c".to_owned())); + assert_eq!(result.c, Some("x".to_owned())); +}