Skip to content

Commit

Permalink
Be more thorough when checking asserts of arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Oct 20, 2024
1 parent e010444 commit 5ada779
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 96 deletions.
81 changes: 79 additions & 2 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use crate::{
Class, Client, Enum, EnumValue, Field, FunctionNode, RetryPolicy, TemplateString, TestCase,
},
};
use anyhow::Result;
use baml_types::{BamlMap, BamlValue};
use anyhow::{Context, Result};
use baml_types::{BamlMap, BamlValue, BamlValueWithMeta, FieldType, TypeValue};
pub use to_baml_arg::ArgCoercer;

use super::repr;
Expand Down Expand Up @@ -44,6 +44,7 @@ pub trait IRHelper {
params: &BamlMap<String, BamlValue>,
coerce_settings: ArgCoercer,
) -> Result<BamlValue>;
fn distribute_type<'a>(&'a self, value: BamlValue, field_type: &'a FieldType) -> Result<BamlValueWithMeta<&'a FieldType>>;
}

impl IRHelper for IntermediateRepr {
Expand Down Expand Up @@ -184,4 +185,80 @@ impl IRHelper for IntermediateRepr {
Ok(BamlValue::Map(baml_arg_map))
}
}

/// For some `BamlValue` with type `FieldType`, walk the structure of both the value
/// and the type simultaneously, associating each node in the `BamlValue` with its
/// `FieldType`.
fn distribute_type<'a>(
&'a self,
value: BamlValue,
field_type: &'a FieldType,
) -> anyhow::Result<BamlValueWithMeta<&'a FieldType>> {
let (unconstrained_type, _) = field_type.distribute_constraints();
match (value, unconstrained_type) {

(BamlValue::String(s), FieldType::Primitive(TypeValue::String)) => Ok(BamlValueWithMeta::String(s, field_type)),
(BamlValue::String(_), _) => anyhow::bail!("Could not unify Strinig with {:?}", field_type),

(BamlValue::Int(i), FieldType::Primitive(TypeValue::Int)) => Ok(BamlValueWithMeta::Int(i, field_type)),
(BamlValue::Int(_), _) => anyhow::bail!("Could not unify Int with {:?}", field_type),

(BamlValue::Float(f), FieldType::Primitive(TypeValue::Float)) => Ok(BamlValueWithMeta::Float(f, field_type)),
(BamlValue::Float(_), _) => anyhow::bail!("Could not unify Float with {:?}", field_type),

(BamlValue::Bool(b), FieldType::Primitive(TypeValue::Bool)) => Ok(BamlValueWithMeta::Bool(b, field_type)),
(BamlValue::Bool(_), _) => anyhow::bail!("Could not unify Bool with {:?}", field_type),

(BamlValue::Null, FieldType::Primitive(TypeValue::Null)) => Ok(BamlValueWithMeta::Null(field_type)),
(BamlValue::Null, _) => anyhow::bail!("Could not unify Null with {:?}", field_type),

(BamlValue::Map(pairs), FieldType::Map(k,val_type)) => {
let mapped_fields: BamlMap<String, BamlValueWithMeta<&FieldType>> =
pairs
.into_iter()
.map(|(key, val)| {
let sub_value = self.distribute_type(val, val_type.as_ref())?;
Ok((key, sub_value))
})
.collect::<anyhow::Result<BamlMap<String,BamlValueWithMeta<&FieldType>>>>()?;
Ok(BamlValueWithMeta::Map( mapped_fields, field_type ))
},
(BamlValue::Map(_), _) => anyhow::bail!("Could not unify Map with {:?}", field_type),

(BamlValue::List(items), FieldType::List(item_type)) => {
let mapped_items: Vec<BamlValueWithMeta<&FieldType>> =
items
.into_iter()
.map(|i| self.distribute_type(i, item_type))
.collect::<anyhow::Result<Vec<_>>>()?;
Ok(BamlValueWithMeta::List(mapped_items, field_type))
}
(BamlValue::List(_), _) => anyhow::bail!("Could not unify List with {:?}", field_type),

(BamlValue::Media(m), FieldType::Primitive(TypeValue::Media(_))) => Ok(BamlValueWithMeta::Media(m, field_type)),
(BamlValue::Media(_), _) => anyhow::bail!("Could not unify Media with {:?}", field_type),

(BamlValue::Enum(name, val), FieldType::Enum(type_name)) => if name == *type_name {
Ok(BamlValueWithMeta::Enum(name, val, field_type))
} else {
Err(anyhow::anyhow!("Could not unify Enum {name} with Enum type {type_name}"))
}
(BamlValue::Enum(enum_name,_), _) => anyhow::bail!("Could not unify Enum {enum_name} with {:?}", field_type),

(BamlValue::Class(name, fields), FieldType::Class(type_name)) => if name == *type_name {
let class_type = &self.find_class(type_name)?.item.elem;
let class_fields: BamlMap<&str, &FieldType> = class_type.static_fields.iter().map(|field_node| (field_node.elem.name.as_ref(), &field_node.elem.r#type.elem)).collect();
let mapped_fields = fields.into_iter().map(|(k,v)| {
let field_type = class_fields.get(k.as_str()).context("Could not find field {k} in class {name}")?;
let mapped_field = self.distribute_type(v, field_type)?;
Ok((k, mapped_field))
}).collect::<anyhow::Result<BamlMap<String, BamlValueWithMeta<&FieldType>>>>()?;
Ok(BamlValueWithMeta::Class(name, mapped_fields, field_type))
} else {
Err(anyhow::anyhow!("Could not unify Class {name} with Class type {type_name}"))
}
(BamlValue::Class(class_name,_), _) => anyhow::bail!("Could not unify Class {class_name} with {:?}", field_type),

}
}
}
68 changes: 43 additions & 25 deletions engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use baml_types::{
BamlMap, BamlValue, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue
BamlMap, BamlValue, BamlValueWithMeta, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue
};
use core::result::Result;
use std::path::PathBuf;
Expand Down Expand Up @@ -329,35 +329,53 @@ impl ArgCoercer {
}
}
FieldType::Constrained { base, constraints } => {
// to_baml_arg(base)
let val = self.coerce_arg(ir, base, value, scope)?;
for c@Constraint {
level,
expression,
label,
} in constraints.iter()
{
let constraint_ok =
evaluate_predicate(&val, &expression).unwrap_or_else(|err| {
scope.push_error(format!(
"Error while evaluating check {c:?}: {:?}",
err
));
false
});
if !constraint_ok {
let search_for_failures_result = first_failing_assert_nested(ir, &val, field_type).map_err(|e| {
scope.push_error(format!("Failed to evaluate assert: {:?}", e));
()
})?;
match search_for_failures_result {
Some(Constraint {label, expression, ..}) => {
let msg = label.as_ref().unwrap_or(&expression.0);
match level {
ConstraintLevel::Check => {
scope.push_warning(format!("Failed check: {msg}"));
}
ConstraintLevel::Assert => {
scope.push_error(format!("Failed assert: {msg}"));
}
}
scope.push_error(format!("Failed assert: {msg}"));
Ok(val)
}
None => Ok(val)
}
Ok(val)
}
}
}
}

/// Search a potentially deeply-nested `BamlValue` for any failing asserts,
/// returning the first one encountered.
fn first_failing_assert_nested<'a>(
ir: &'a IntermediateRepr,
baml_value: &BamlValue,
field_type: &'a FieldType
) -> anyhow::Result<Option<Constraint>> {
let value_with_types = ir.distribute_type(baml_value.clone(), field_type)?;
let first_failure = value_with_types
.iter()
.map(|value_node| {
let (_, constraints) = value_node.meta().distribute_constraints();
constraints.into_iter().filter_map(|c| {
let constraint = c.clone();
let baml_value: BamlValue = value_node.into();
let result = evaluate_predicate(&&baml_value, &c.expression).map_err(|e| {
anyhow::anyhow!(format!("Error evaluating constraint: {:?}", e))
});
match result {
Ok(false) => if c.level == ConstraintLevel::Assert {Some(Ok(constraint))} else { None },
Ok(true) => None,
Err(e) => Some(Err(e))

}
})
})
.flatten()
.next();
first_failure.transpose()

}
3 changes: 2 additions & 1 deletion engine/baml-lib/baml-types/src/baml_value.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::collections::HashMap;
use std::{collections::{HashSet, VecDeque}, fmt};

use indexmap::IndexMap;
use serde::ser::{SerializeMap, SerializeSeq};
use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer};

use crate::media::BamlMediaType;
use crate::{BamlMap, BamlMedia, ResponseCheck};
use crate::{BamlMap, BamlMedia, FieldType, ResponseCheck, TypeValue};

#[derive(Clone, Debug, PartialEq)]
pub enum BamlValue {
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/baml-types/src/field_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ impl FieldType {
///
/// If the function encounters directly nested Constrained types,
/// (i.e. `FieldType::Constrained { base: FieldType::Constrained { .. }, .. } `)
/// then the constraints of the two levels will be combined into a single vector.
/// then the constraints of the two levels will be flattened into a single vector.
/// So, we always return a base type that is not FieldType::Constrained.
pub fn distribute_constraints(self: &FieldType) -> (&FieldType, Vec<Constraint>) {

Expand Down
5 changes: 5 additions & 0 deletions engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ pub trait DefaultValue {
}

/// Run all checks and asserts for a value at a given type.
/// This function only runs checks on the top-level node of the `BamlValue`.
/// Checks on nested fields, list items etc. are not run here.
///
/// For a function that traverses a whole `BamlValue` looking for failed asserts,
/// see `first_failing_assert_nested`.
pub fn run_user_checks(
baml_value: &BamlValue,
type_: &FieldType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
K = TypeVar('K')

class Check(BaseModel):
name: Optional[str]
name: str
expression: str
status: str

class Checked(BaseModel, Generic[T,K]):
value: T
checks: K
checks: K
22 changes: 14 additions & 8 deletions engine/language_client_python/src/types/function_results.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::Context;
use baml_types::{BamlValue, BamlValueWithMeta, ResponseCheck};
use pyo3::prelude::{pymethods, PyResult};
use pyo3::types::{PyAnyMethods, PyListMethods, PyModule};
use pyo3::types::{PyAnyMethods, PyListMethods, PyModule, PyType};
use pyo3::{Bound, IntoPy, PyObject, Python};
use pythonize::pythonize;

Expand Down Expand Up @@ -46,26 +46,29 @@ fn pythonize_checks(
checks: &Vec<ResponseCheck>
) -> PyResult<PyObject> {

fn type_name_for_checks(checks: &Vec<ResponseCheck>) -> String {
fn type_name_for_checks(py: Python<'_>, checks: &Vec<ResponseCheck>) -> PyResult<String> {

let mut name = "Checks".to_string();
let mut check_names: Vec<String> = checks.iter().map(|ResponseCheck{name,..}| name).cloned().collect();
check_names.sort();
for check_name in check_names.iter() {
name.push_str("__");
name.push_str(check_name);
}
name
Ok(name)
}

let checks_class_name = type_name_for_checks(checks);
let baml_py = py.import_bound("baml_py")?;
let check_class = baml_py.getattr("Check")?;

let checks_class_name = type_name_for_checks(py, checks)?;
let checks_class = cls_module.getattr(checks_class_name.as_str())?;
let properties_dict = pyo3::types::PyDict::new_bound(py);
checks.iter().try_for_each(|ResponseCheck{name, expression, status}| {
// Construct the Check.
let check_class = cls_module.getattr("Check")?;
let check_properties_dict = pyo3::types::PyDict::new_bound(py);
check_properties_dict.set_item("name", name)?;
check_properties_dict.set_item("expr", expression)?;
check_properties_dict.set_item("expression", expression)?;
check_properties_dict.set_item("status", status)?;
let check_instance = check_class.call_method("model_validate", (check_properties_dict,), None)?;

Expand Down Expand Up @@ -96,13 +99,16 @@ fn pythonize_strict(
*parsed.meta_mut() = vec![];
let python_value = pythonize_strict(py, parsed, enum_module, cls_module)?;

// Assemble the pythonized checks and pythonized value into a `Checked[T,K]`.
let properties_dict = pyo3::types::PyDict::new_bound(py);
properties_dict.set_item("value", python_value)?;
properties_dict.set_item("checks", python_checks)?;

let class_checked_type = cls_module.getattr("Checked")?;
dbg!(&properties_dict);
eprintln!("{:?}", properties_dict);
let baml_py = py.import_bound("baml_py")?;
let class_checked_type = baml_py.getattr("Checked")?;
let checked_instance = class_checked_type.call_method("model_validate", (properties_dict,), None)?;

Ok(checked_instance.into())
} else {
match parsed {
Expand Down
6 changes: 3 additions & 3 deletions integ-tests/openapi/baml_client/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3306,16 +3306,16 @@ components:
checks:
type: object
properties:
regex_bad:
$ref: '#components/schemas/Check'
trivial:
$ref: '#components/schemas/Check'
regex_good:
$ref: '#components/schemas/Check'
regex_bad:
$ref: '#components/schemas/Check'
required:
- regex_bad
- trivial
- regex_good
- regex_bad
additionalProperties: false
required:
- value
Expand Down
2 changes: 1 addition & 1 deletion integ-tests/python/baml_client/partial_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from . import types


from .types import Checks__young_enough, Checks__too_big, Checks__valid_phone_number, Checks__valid_email, Checks__unreasonably_certain, Checks__earth_aged__no_infants, Checks__regex_bad__regex_good__trivial
from .types import Checks__valid_phone_number, Checks__regex_bad__regex_good__trivial, Checks__too_big, Checks__earth_aged__no_infants, Checks__valid_email, Checks__young_enough, Checks__unreasonably_certain


###############################################################################
Expand Down
16 changes: 8 additions & 8 deletions integ-tests/python/baml_client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,29 @@ class TestEnum(str, Enum):
F = "F"
G = "G"

class Checks__valid_email(BaseModel):
valid_email: baml_py.Check

class Checks__unreasonably_certain(BaseModel):
unreasonably_certain: baml_py.Check

class Checks__too_big(BaseModel):
too_big: baml_py.Check

class Checks__young_enough(BaseModel):
young_enough: baml_py.Check

class Checks__valid_email(BaseModel):
valid_email: baml_py.Check
class Checks__valid_phone_number(BaseModel):
valid_phone_number: baml_py.Check

class Checks__regex_bad__regex_good__trivial(BaseModel):
regex_good: baml_py.Check
trivial: baml_py.Check
regex_bad: baml_py.Check

class Checks__earth_aged__no_infants(BaseModel):
no_infants: baml_py.Check
earth_aged: baml_py.Check
no_infants: baml_py.Check

class Checks__valid_phone_number(BaseModel):
valid_phone_number: baml_py.Check
class Checks__too_big(BaseModel):
too_big: baml_py.Check

class BigNumbers(BaseModel):

Expand Down
Loading

0 comments on commit 5ada779

Please sign in to comment.