diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs index 67e872061..a9e654104 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs @@ -10,8 +10,8 @@ use crate::{ Class, Client, Enum, EnumValue, Field, FunctionNode, RetryPolicy, TemplateString, TestCase, }, }; -use anyhow::Result; -use baml_types::{BamlMap, BamlValue}; +use anyhow::{Context, Result}; +use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, FieldType, TypeValue}; pub use to_baml_arg::ArgCoercer; use super::repr; @@ -44,6 +44,7 @@ pub trait IRHelper { params: &BamlMap, coerce_settings: ArgCoercer, ) -> Result; + fn distribute_type<'a>(&'a self, value: BamlValue, field_type: &'a FieldType) -> Result>; } impl IRHelper for IntermediateRepr { @@ -184,4 +185,80 @@ impl IRHelper for IntermediateRepr { Ok(BamlValue::Map(baml_arg_map)) } } + + /// For some `BamlValue` with type `FieldType`, walk the structure of both the value + /// and the type simultaneously, associating each node in the `BamlValue` with its + /// `FieldType`. + fn distribute_type<'a>( + &'a self, + value: BamlValue, + field_type: &'a FieldType, + ) -> anyhow::Result> { + let (unconstrained_type, _) = field_type.distribute_constraints(); + match (value, unconstrained_type) { + + (BamlValue::String(s), FieldType::Primitive(TypeValue::String)) => Ok(BamlValueWithMeta::String(s, field_type)), + (BamlValue::String(_), _) => anyhow::bail!("Could not unify Strinig with {:?}", field_type), + + (BamlValue::Int(i), FieldType::Primitive(TypeValue::Int)) => Ok(BamlValueWithMeta::Int(i, field_type)), + (BamlValue::Int(_), _) => anyhow::bail!("Could not unify Int with {:?}", field_type), + + (BamlValue::Float(f), FieldType::Primitive(TypeValue::Float)) => Ok(BamlValueWithMeta::Float(f, field_type)), + (BamlValue::Float(_), _) => anyhow::bail!("Could not unify Float with {:?}", field_type), + + (BamlValue::Bool(b), FieldType::Primitive(TypeValue::Bool)) => Ok(BamlValueWithMeta::Bool(b, field_type)), + (BamlValue::Bool(_), _) => anyhow::bail!("Could not unify Bool with {:?}", field_type), + + (BamlValue::Null, FieldType::Primitive(TypeValue::Null)) => Ok(BamlValueWithMeta::Null(field_type)), + (BamlValue::Null, _) => anyhow::bail!("Could not unify Null with {:?}", field_type), + + (BamlValue::Map(pairs), FieldType::Map(k,val_type)) => { + let mapped_fields: BamlMap> = + pairs + .into_iter() + .map(|(key, val)| { + let sub_value = self.distribute_type(val, val_type.as_ref())?; + Ok((key, sub_value)) + }) + .collect::>>>()?; + Ok(BamlValueWithMeta::Map( mapped_fields, field_type )) + }, + (BamlValue::Map(_), _) => anyhow::bail!("Could not unify Map with {:?}", field_type), + + (BamlValue::List(items), FieldType::List(item_type)) => { + let mapped_items: Vec> = + items + .into_iter() + .map(|i| self.distribute_type(i, item_type)) + .collect::>>()?; + Ok(BamlValueWithMeta::List(mapped_items, field_type)) + } + (BamlValue::List(_), _) => anyhow::bail!("Could not unify List with {:?}", field_type), + + (BamlValue::Media(m), FieldType::Primitive(TypeValue::Media(_))) => Ok(BamlValueWithMeta::Media(m, field_type)), + (BamlValue::Media(_), _) => anyhow::bail!("Could not unify Media with {:?}", field_type), + + (BamlValue::Enum(name, val), FieldType::Enum(type_name)) => if name == *type_name { + Ok(BamlValueWithMeta::Enum(name, val, field_type)) + } else { + Err(anyhow::anyhow!("Could not unify Enum {name} with Enum type {type_name}")) + } + (BamlValue::Enum(enum_name,_), _) => anyhow::bail!("Could not unify Enum {enum_name} with {:?}", field_type), + + (BamlValue::Class(name, fields), FieldType::Class(type_name)) => if name == *type_name { + let class_type = &self.find_class(type_name)?.item.elem; + let class_fields: BamlMap<&str, &FieldType> = class_type.static_fields.iter().map(|field_node| (field_node.elem.name.as_ref(), &field_node.elem.r#type.elem)).collect(); + let mapped_fields = fields.into_iter().map(|(k,v)| { + let field_type = class_fields.get(k.as_str()).context("Could not find field {k} in class {name}")?; + let mapped_field = self.distribute_type(v, field_type)?; + Ok((k, mapped_field)) + }).collect::>>>()?; + Ok(BamlValueWithMeta::Class(name, mapped_fields, field_type)) + } else { + Err(anyhow::anyhow!("Could not unify Class {name} with Class type {type_name}")) + } + (BamlValue::Class(class_name,_), _) => anyhow::bail!("Could not unify Class {class_name} with {:?}", field_type), + + } + } } diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index 46d4d67cb..b8efcdd41 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -1,5 +1,5 @@ use baml_types::{ - BamlMap, BamlValue, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue + BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue }; use core::result::Result; use std::path::PathBuf; @@ -329,35 +329,53 @@ impl ArgCoercer { } } FieldType::Constrained { base, constraints } => { + // to_baml_arg(base) let val = self.coerce_arg(ir, base, value, scope)?; - for c@Constraint { - level, - expression, - label, - } in constraints.iter() - { - let constraint_ok = - evaluate_predicate(&val, &expression).unwrap_or_else(|err| { - scope.push_error(format!( - "Error while evaluating check {c:?}: {:?}", - err - )); - false - }); - if !constraint_ok { + let search_for_failures_result = first_failing_assert_nested(ir, &val, field_type).map_err(|e| { + scope.push_error(format!("Failed to evaluate assert: {:?}", e)); + () + })?; + match search_for_failures_result { + Some(Constraint {label, expression, ..}) => { let msg = label.as_ref().unwrap_or(&expression.0); - match level { - ConstraintLevel::Check => { - scope.push_warning(format!("Failed check: {msg}")); - } - ConstraintLevel::Assert => { - scope.push_error(format!("Failed assert: {msg}")); - } - } + scope.push_error(format!("Failed assert: {msg}")); + Ok(val) } + None => Ok(val) } - Ok(val) } } } } + +/// Search a potentially deeply-nested `BamlValue` for any failing asserts, +/// returning the first one encountered. +fn first_failing_assert_nested<'a>( + ir: &'a IntermediateRepr, + baml_value: &BamlValue, + field_type: &'a FieldType +) -> anyhow::Result> { + let value_with_types = ir.distribute_type(baml_value.clone(), field_type)?; + let first_failure = value_with_types + .iter() + .map(|value_node| { + let (_, constraints) = value_node.meta().distribute_constraints(); + constraints.into_iter().filter_map(|c| { + let constraint = c.clone(); + let baml_value: BamlValue = value_node.into(); + let result = evaluate_predicate(&&baml_value, &c.expression).map_err(|e| { + anyhow::anyhow!(format!("Error evaluating constraint: {:?}", e)) + }); + match result { + Ok(false) => if c.level == ConstraintLevel::Assert {Some(Ok(constraint))} else { None }, + Ok(true) => None, + Err(e) => Some(Err(e)) + + } + }) + }) + .flatten() + .next(); + first_failure.transpose() + +} diff --git a/engine/baml-lib/baml-types/src/baml_value.rs b/engine/baml-lib/baml-types/src/baml_value.rs index a4f4cae4b..0cdb3a9be 100644 --- a/engine/baml-lib/baml-types/src/baml_value.rs +++ b/engine/baml-lib/baml-types/src/baml_value.rs @@ -1,11 +1,12 @@ use std::collections::HashMap; use std::{collections::{HashSet, VecDeque}, fmt}; +use indexmap::IndexMap; use serde::ser::{SerializeMap, SerializeSeq}; use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; use crate::media::BamlMediaType; -use crate::{BamlMap, BamlMedia, ResponseCheck}; +use crate::{BamlMap, BamlMedia, FieldType, ResponseCheck, TypeValue}; #[derive(Clone, Debug, PartialEq)] pub enum BamlValue { diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index becc743aa..7d45d8996 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -161,7 +161,7 @@ impl FieldType { /// /// If the function encounters directly nested Constrained types, /// (i.e. `FieldType::Constrained { base: FieldType::Constrained { .. }, .. } `) - /// then the constraints of the two levels will be combined into a single vector. + /// then the constraints of the two levels will be flattened into a single vector. /// So, we always return a base type that is not FieldType::Constrained. pub fn distribute_constraints(self: &FieldType) -> (&FieldType, Vec) { diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs index 473f707c8..b9f91ce96 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs @@ -223,6 +223,11 @@ pub trait DefaultValue { } /// Run all checks and asserts for a value at a given type. +/// This function only runs checks on the top-level node of the `BamlValue`. +/// Checks on nested fields, list items etc. are not run here. +/// +/// For a function that traverses a whole `BamlValue` looking for failed asserts, +/// see `first_failing_assert_nested`. pub fn run_user_checks( baml_value: &BamlValue, type_: &FieldType, diff --git a/engine/language_client_python/python_src/baml_py/constraints.py b/engine/language_client_python/python_src/baml_py/constraints.py index 45ecf8aeb..4637e60fb 100644 --- a/engine/language_client_python/python_src/baml_py/constraints.py +++ b/engine/language_client_python/python_src/baml_py/constraints.py @@ -5,10 +5,10 @@ K = TypeVar('K') class Check(BaseModel): - name: Optional[str] + name: str expression: str status: str class Checked(BaseModel, Generic[T,K]): value: T - checks: K \ No newline at end of file + checks: K diff --git a/engine/language_client_python/src/types/function_results.rs b/engine/language_client_python/src/types/function_results.rs index 8802ed8a2..e9bde2194 100644 --- a/engine/language_client_python/src/types/function_results.rs +++ b/engine/language_client_python/src/types/function_results.rs @@ -1,7 +1,7 @@ use anyhow::Context; use baml_types::{BamlValue, BamlValueWithMeta, ResponseCheck}; use pyo3::prelude::{pymethods, PyResult}; -use pyo3::types::{PyAnyMethods, PyListMethods, PyModule}; +use pyo3::types::{PyAnyMethods, PyListMethods, PyModule, PyType}; use pyo3::{Bound, IntoPy, PyObject, Python}; use pythonize::pythonize; @@ -46,7 +46,8 @@ fn pythonize_checks( checks: &Vec ) -> PyResult { - fn type_name_for_checks(checks: &Vec) -> String { + fn type_name_for_checks(py: Python<'_>, checks: &Vec) -> PyResult { + let mut name = "Checks".to_string(); let mut check_names: Vec = checks.iter().map(|ResponseCheck{name,..}| name).cloned().collect(); check_names.sort(); @@ -54,18 +55,20 @@ fn pythonize_checks( name.push_str("__"); name.push_str(check_name); } - name + Ok(name) } - let checks_class_name = type_name_for_checks(checks); + let baml_py = py.import_bound("baml_py")?; + let check_class = baml_py.getattr("Check")?; + + let checks_class_name = type_name_for_checks(py, checks)?; let checks_class = cls_module.getattr(checks_class_name.as_str())?; let properties_dict = pyo3::types::PyDict::new_bound(py); checks.iter().try_for_each(|ResponseCheck{name, expression, status}| { // Construct the Check. - let check_class = cls_module.getattr("Check")?; let check_properties_dict = pyo3::types::PyDict::new_bound(py); check_properties_dict.set_item("name", name)?; - check_properties_dict.set_item("expr", expression)?; + check_properties_dict.set_item("expression", expression)?; check_properties_dict.set_item("status", status)?; let check_instance = check_class.call_method("model_validate", (check_properties_dict,), None)?; @@ -96,13 +99,16 @@ fn pythonize_strict( *parsed.meta_mut() = vec![]; let python_value = pythonize_strict(py, parsed, enum_module, cls_module)?; + // Assemble the pythonized checks and pythonized value into a `Checked[T,K]`. let properties_dict = pyo3::types::PyDict::new_bound(py); properties_dict.set_item("value", python_value)?; properties_dict.set_item("checks", python_checks)?; - let class_checked_type = cls_module.getattr("Checked")?; + dbg!(&properties_dict); + eprintln!("{:?}", properties_dict); + let baml_py = py.import_bound("baml_py")?; + let class_checked_type = baml_py.getattr("Checked")?; let checked_instance = class_checked_type.call_method("model_validate", (properties_dict,), None)?; - Ok(checked_instance.into()) } else { match parsed { diff --git a/integ-tests/openapi/baml_client/openapi.yaml b/integ-tests/openapi/baml_client/openapi.yaml index e0c6a0802..cb5bcde8b 100644 --- a/integ-tests/openapi/baml_client/openapi.yaml +++ b/integ-tests/openapi/baml_client/openapi.yaml @@ -3306,16 +3306,16 @@ components: checks: type: object properties: + regex_bad: + $ref: '#components/schemas/Check' trivial: $ref: '#components/schemas/Check' regex_good: $ref: '#components/schemas/Check' - regex_bad: - $ref: '#components/schemas/Check' required: + - regex_bad - trivial - regex_good - - regex_bad additionalProperties: false required: - value diff --git a/integ-tests/python/baml_client/partial_types.py b/integ-tests/python/baml_client/partial_types.py index e2d68a357..0623d6d95 100644 --- a/integ-tests/python/baml_client/partial_types.py +++ b/integ-tests/python/baml_client/partial_types.py @@ -21,7 +21,7 @@ from . import types -from .types import Checks__young_enough, Checks__too_big, Checks__valid_phone_number, Checks__valid_email, Checks__unreasonably_certain, Checks__earth_aged__no_infants, Checks__regex_bad__regex_good__trivial +from .types import Checks__valid_phone_number, Checks__regex_bad__regex_good__trivial, Checks__too_big, Checks__earth_aged__no_infants, Checks__valid_email, Checks__young_enough, Checks__unreasonably_certain ############################################################################### diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index 86a8f381c..cc55f4c53 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -124,17 +124,17 @@ class TestEnum(str, Enum): F = "F" G = "G" +class Checks__valid_email(BaseModel): + valid_email: baml_py.Check + class Checks__unreasonably_certain(BaseModel): unreasonably_certain: baml_py.Check -class Checks__too_big(BaseModel): - too_big: baml_py.Check - class Checks__young_enough(BaseModel): young_enough: baml_py.Check -class Checks__valid_email(BaseModel): - valid_email: baml_py.Check +class Checks__valid_phone_number(BaseModel): + valid_phone_number: baml_py.Check class Checks__regex_bad__regex_good__trivial(BaseModel): regex_good: baml_py.Check @@ -142,11 +142,11 @@ class Checks__regex_bad__regex_good__trivial(BaseModel): regex_bad: baml_py.Check class Checks__earth_aged__no_infants(BaseModel): - no_infants: baml_py.Check earth_aged: baml_py.Check + no_infants: baml_py.Check -class Checks__valid_phone_number(BaseModel): - valid_phone_number: baml_py.Check +class Checks__too_big(BaseModel): + too_big: baml_py.Check class BigNumbers(BaseModel): diff --git a/integ-tests/ruby/baml_client/types.rb b/integ-tests/ruby/baml_client/types.rb index aee245da8..d31b069c8 100644 --- a/integ-tests/ruby/baml_client/types.rb +++ b/integ-tests/ruby/baml_client/types.rb @@ -191,13 +191,13 @@ class TestOutputClass < T::Struct; end class TwoStoriesOneTitle < T::Struct; end class UnionTest_ReturnType < T::Struct; end class WithReasoning < T::Struct; end - class Checks__regex_bad__regex_good__trivial < T::Struct; end class Checks__valid_email < T::Struct; end - class Checks__earth_aged__no_infants < T::Struct; end class Checks__young_enough < T::Struct; end - class Checks__valid_phone_number < T::Struct; end - class Checks__too_big < T::Struct; end class Checks__unreasonably_certain < T::Struct; end + class Checks__too_big < T::Struct; end + class Checks__earth_aged__no_infants < T::Struct; end + class Checks__regex_bad__regex_good__trivial < T::Struct; end + class Checks__valid_phone_number < T::Struct; end class BigNumbers < T::Struct include Baml::Sorbet::Struct const :a, Integer @@ -1028,91 +1028,91 @@ def initialize(props) @props = props end end - class Checks__regex_bad__regex_good__trivial < T::Struct + class Checks__valid_email < T::Struct include Baml::Sorbet::Struct - const :regex_bad, Baml::Check - const :regex_good, Baml::Check - const :trivial, Baml::Check + const :valid_email, Baml::Check def initialize(props) super( - regex_bad: props[:regex_bad], - regex_good: props[:regex_good], - trivial: props[:trivial], + valid_email: props[:valid_email], ) @props = props end end - class Checks__valid_email < T::Struct + class Checks__young_enough < T::Struct include Baml::Sorbet::Struct - const :valid_email, Baml::Check + const :young_enough, Baml::Check def initialize(props) super( - valid_email: props[:valid_email], + young_enough: props[:young_enough], ) @props = props end end - class Checks__earth_aged__no_infants < T::Struct + class Checks__unreasonably_certain < T::Struct include Baml::Sorbet::Struct - const :earth_aged, Baml::Check - const :no_infants, Baml::Check + const :unreasonably_certain, Baml::Check def initialize(props) super( - earth_aged: props[:earth_aged], - no_infants: props[:no_infants], + unreasonably_certain: props[:unreasonably_certain], ) @props = props end end - class Checks__young_enough < T::Struct + class Checks__too_big < T::Struct include Baml::Sorbet::Struct - const :young_enough, Baml::Check + const :too_big, Baml::Check def initialize(props) super( - young_enough: props[:young_enough], + too_big: props[:too_big], ) @props = props end end - class Checks__valid_phone_number < T::Struct + class Checks__earth_aged__no_infants < T::Struct include Baml::Sorbet::Struct - const :valid_phone_number, Baml::Check + const :no_infants, Baml::Check + const :earth_aged, Baml::Check def initialize(props) super( - valid_phone_number: props[:valid_phone_number], + no_infants: props[:no_infants], + earth_aged: props[:earth_aged], ) @props = props end end - class Checks__too_big < T::Struct + class Checks__regex_bad__regex_good__trivial < T::Struct include Baml::Sorbet::Struct - const :too_big, Baml::Check + const :regex_bad, Baml::Check + const :regex_good, Baml::Check + const :trivial, Baml::Check def initialize(props) super( - too_big: props[:too_big], + regex_bad: props[:regex_bad], + regex_good: props[:regex_good], + trivial: props[:trivial], ) @props = props end end - class Checks__unreasonably_certain < T::Struct + class Checks__valid_phone_number < T::Struct include Baml::Sorbet::Struct - const :unreasonably_certain, Baml::Check + const :valid_phone_number, Baml::Check def initialize(props) super( - unreasonably_certain: props[:unreasonably_certain], + valid_phone_number: props[:valid_phone_number], ) @props = props diff --git a/integ-tests/typescript/baml_client/types.ts b/integ-tests/typescript/baml_client/types.ts index 86a1e299f..3432f45cf 100644 --- a/integ-tests/typescript/baml_client/types.ts +++ b/integ-tests/typescript/baml_client/types.ts @@ -121,30 +121,26 @@ export enum TestEnum { G = "G", } -export interface Checks__valid_phone_number { - valid_phone_number: Check +export interface Checks__valid_email { + valid_email: Check } -export interface Checks__too_big { - too_big: Check +export interface Checks__valid_phone_number { + valid_phone_number: Check } -export interface Checks__young_enough { - young_enough: Check +export interface Checks__regex_bad__regex_good__trivial { + regex_good: Check + regex_bad: Check + trivial: Check } export interface Checks__unreasonably_certain { unreasonably_certain: Check } -export interface Checks__regex_bad__regex_good__trivial { - trivial: Check - regex_bad: Check - regex_good: Check -} - -export interface Checks__valid_email { - valid_email: Check +export interface Checks__too_big { + too_big: Check } export interface Checks__earth_aged__no_infants { @@ -152,6 +148,10 @@ export interface Checks__earth_aged__no_infants { no_infants: Check } +export interface Checks__young_enough { + young_enough: Check +} + export interface BigNumbers { a: number b: number