diff --git a/engine/Cargo.lock b/engine/Cargo.lock index f257dff47..28c00c32d 100644 --- a/engine/Cargo.lock +++ b/engine/Cargo.lock @@ -2345,6 +2345,7 @@ dependencies = [ "env_logger", "indexmap 2.2.6", "internal-baml-core", + "itertools 0.13.0", "log", "pathdiff", "semver", @@ -2413,6 +2414,7 @@ dependencies = [ "indexmap 2.2.6", "log", "minijinja", + "regex", "serde", "serde_json", "strsim 0.11.1", @@ -2509,6 +2511,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs index 0f8182874..a88db813a 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/scope_diagnostics.rs @@ -148,4 +148,8 @@ impl ScopeStack { pub fn push_error(&mut self, error: String) { self.scopes.last_mut().unwrap().errors.push(error); } + + pub fn push_warning(&mut self, warning: String) { + self.scopes.last_mut().unwrap().warnings.push(warning); + } } diff --git a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs index 3192854ba..1505e6f2f 100644 --- a/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs +++ b/engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs @@ -1,10 +1,13 @@ -use baml_types::{BamlMap, BamlMediaType, BamlValue, FieldType, LiteralValue, TypeValue}; +use baml_types::{ + BamlMap, BamlValue, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue +}; use core::result::Result; use std::path::PathBuf; use crate::ir::IntermediateRepr; use super::{scope_diagnostics::ScopeStack, IRHelper}; +use 
internal_baml_jinja::evaluate_predicate; #[derive(Default)] pub struct ParameterError { @@ -325,6 +328,36 @@ impl ArgCoercer { } } } + FieldType::Constrained { base, constraints } => { + let val = self.coerce_arg(ir, base, value, scope)?; + for c@Constraint { + level, + expression, + label, + } in constraints.iter() + { + let constraint_ok = + evaluate_predicate(&val, &expression).unwrap_or_else(|err| { + scope.push_error(format!( + "Error while evaluating check {c:?}: {:?}", + err + )); + false + }); + if !constraint_ok { + let msg = label.as_ref().unwrap_or(&expression.0); + match level { + ConstraintLevel::Check => { + scope.push_warning(format!("Failed check: {msg}")); + } + ConstraintLevel::Assert => { + scope.push_error(format!("Failed assert: {msg}")); + } + } + } + } + Ok(val) + } } } } diff --git a/engine/baml-lib/baml-core/src/ir/json_schema.rs b/engine/baml-lib/baml-core/src/ir/json_schema.rs index c2bd4d99b..6d218651b 100644 --- a/engine/baml-lib/baml-core/src/ir/json_schema.rs +++ b/engine/baml-lib/baml-core/src/ir/json_schema.rs @@ -236,6 +236,7 @@ impl<'db> WithJsonSchema for FieldType { } } } + FieldType::Constrained { base, .. 
} => base.json_schema(), } } } diff --git a/engine/baml-lib/baml-core/src/ir/repr.rs b/engine/baml-lib/baml-core/src/ir/repr.rs index d4ad06dac..f5409a78e 100644 --- a/engine/baml-lib/baml-core/src/ir/repr.rs +++ b/engine/baml-lib/baml-core/src/ir/repr.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; -use anyhow::{anyhow, Context, Result}; -use baml_types::FieldType; +use anyhow::{anyhow, Result}; +use baml_types::{Constraint, ConstraintLevel, FieldType}; use either::Either; use indexmap::IndexMap; use internal_baml_parser_database::{ @@ -13,6 +13,7 @@ use internal_baml_parser_database::{ }; use internal_baml_schema_ast::ast::SubType; +use baml_types::JinjaExpression; use internal_baml_schema_ast::ast::{self, FieldArity, WithName, WithSpan}; use serde::Serialize; @@ -197,6 +198,8 @@ pub struct NodeAttributes { #[serde(with = "indexmap::map::serde_seq")] meta: IndexMap, + constraints: Vec, + // Spans #[serde(skip)] pub span: Option, @@ -208,39 +211,69 @@ impl NodeAttributes { } } -fn to_ir_attributes( - db: &ParserDatabase, - maybe_ast_attributes: Option<&Attributes>, -) -> IndexMap { - let mut attributes = IndexMap::new(); - - if let Some(Attributes { - description, - alias, - dynamic_type, - skip, - }) = maybe_ast_attributes - { - if let Some(true) = dynamic_type { - attributes.insert("dynamic_type".to_string(), Expression::Bool(true)); - } - if let Some(v) = alias { - attributes.insert("alias".to_string(), Expression::String(db[*v].to_string())); - } - if let Some(d) = description { - let ir_expr = match d { - ast::Expression::StringValue(s, _) => Expression::String(s.clone()), - ast::Expression::RawStringValue(s) => Expression::RawString(s.value().to_string()), - _ => panic!("Couldn't deal with description: {:?}", d), - }; - attributes.insert("description".to_string(), ir_expr); - } - if let Some(true) = skip { - attributes.insert("skip".to_string(), Expression::Bool(true)); +impl Default for NodeAttributes { + fn default() -> Self { + NodeAttributes { + meta: 
IndexMap::new(), + constraints: Vec::new(), + span: None, } } +} - attributes +fn to_ir_attributes( + db: &ParserDatabase, + maybe_ast_attributes: Option<&Attributes>, +) -> (IndexMap, Vec) { + let null_result = (IndexMap::new(), Vec::new()); + maybe_ast_attributes.map_or(null_result, |attributes| { + let Attributes { + description, + alias, + dynamic_type, + skip, + constraints, + } = attributes; + let description = description.as_ref().and_then(|d| { + let name = "description".to_string(); + match d { + ast::Expression::StringValue(s, _) => Some((name, Expression::String(s.clone()))), + ast::Expression::RawStringValue(s) => { + Some((name, Expression::RawString(s.value().to_string()))) + } + ast::Expression::JinjaExpressionValue(j, _) => { + Some((name, Expression::JinjaExpression(j.clone()))) + } + _ => { + eprintln!("Warning, encountered an unexpected description attribute"); + None + } + } + }); + let alias = alias + .as_ref() + .map(|v| ("alias".to_string(), Expression::String(db[*v].to_string()))); + let dynamic_type = dynamic_type.as_ref().and_then(|v| { + if *v { + Some(("dynamic_type".to_string(), Expression::Bool(true))) + } else { + None + } + }); + let skip = skip.as_ref().and_then(|v| { + if *v { + Some(("skip".to_string(), Expression::Bool(true))) + } else { + None + } + }); + + let meta = vec![description, alias, dynamic_type, skip] + .into_iter() + .filter_map(|s| s) + .collect(); + (meta, constraints.clone()) + }) } /// Nodes allow attaching metadata to a given IR entity: attributes, source location, etc @@ -256,6 +289,7 @@ pub trait WithRepr { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: IndexMap::new(), + constraints: Vec::new(), span: None, } } @@ -278,8 +312,46 @@ fn type_with_arity(t: FieldType, arity: &FieldArity) -> FieldType { } impl WithRepr for ast::FieldType { + + // TODO: (Greg) This code only extracts constraints, and ignores any + // other types of attributes attached to the type directly. 
+ fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { + let constraints = self + .attributes() + .iter() + .filter_map(|attr| { + let level = match attr.name.to_string().as_str() { + "assert" => Some(ConstraintLevel::Assert), + "check" => Some(ConstraintLevel::Check), + _ => None + }?; + let (expression, label) = match attr.arguments.arguments.as_slice() { + [arg1, arg2] => match (arg1.clone().value, arg2.clone().value) { + (ast::Expression::JinjaExpressionValue(j,_), ast::Expression::StringValue(s,_)) => Some((j,Some(s))), + _ => None + }, + [arg1] => match arg1.clone().value { + ast::Expression::JinjaExpressionValue(JinjaExpression(j),_) => Some((JinjaExpression(j.clone()),None)), + _ => None + } + _ => None, + }?; + Some(Constraint{ level, expression, label }) + }) + .collect::>(); + let attributes = NodeAttributes { + meta: IndexMap::new(), + constraints, + span: Some(self.span().clone()), + }; + + attributes + } + fn repr(&self, db: &ParserDatabase) -> Result { - Ok(match self { + let constraints = WithRepr::attributes(self, db).constraints; + let has_constraints = constraints.len() > 0; + let base = match self { ast::FieldType::Primitive(arity, typeval, ..) 
=> { let repr = FieldType::Primitive(typeval.clone()); if arity.is_optional() { @@ -347,7 +419,14 @@ impl WithRepr for ast::FieldType { FieldType::Tuple(t.iter().map(|ft| ft.repr(db)).collect::>>()?), arity, ), - }) + }; + + let with_constraints = if has_constraints { + FieldType::Constrained { base: Box::new(base.clone()), constraints } + } else { + base + }; + Ok(with_constraints) } } @@ -384,6 +463,7 @@ pub enum Expression { RawString(String), List(Vec), Map(Vec<(Expression, Expression)>), + JinjaExpression(JinjaExpression), } impl Expression { @@ -411,6 +491,9 @@ impl WithRepr for ast::Expression { ast::Expression::NumericValue(val, _) => Expression::Numeric(val.clone()), ast::Expression::StringValue(val, _) => Expression::String(val.clone()), ast::Expression::RawStringValue(val) => Expression::RawString(val.value().to_string()), + ast::Expression::JinjaExpressionValue(val, _) => { + Expression::JinjaExpression(val.clone()) + } ast::Expression::Identifier(idn) => match idn { ast::Identifier::ENV(k, _) => { Ok(Expression::Identifier(Identifier::ENV(k.clone()))) @@ -459,7 +542,7 @@ impl WithRepr for TemplateStringWalker<'_> { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: Default::default(), - + constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -480,7 +563,6 @@ impl WithRepr for TemplateStringWalker<'_> { .ok() }) .collect::>(), - _ => vec![], }), content: self.template_string().to_string(), }) @@ -499,8 +581,10 @@ pub struct Enum { impl WithRepr for EnumValueWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { + let (meta, constraints) = to_ir_attributes(db, self.get_default_attributes()); let attributes = NodeAttributes { - meta: to_ir_attributes(db, self.get_default_attributes()), + meta, + constraints, span: Some(self.span().clone()), }; @@ -514,8 +598,10 @@ impl WithRepr for EnumValueWalker<'_> { impl WithRepr for EnumWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> 
NodeAttributes { + let (meta, constraints) = to_ir_attributes(db, self.get_default_attributes(SubType::Enum)); let attributes = NodeAttributes { - meta: to_ir_attributes(db, self.get_default_attributes(SubType::Enum)), + meta, + constraints, span: Some(self.span().clone()), }; @@ -541,8 +627,10 @@ pub struct Field { impl WithRepr for FieldWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { + let (meta, constraints) = to_ir_attributes(db, self.get_default_attributes()); let attributes = NodeAttributes { - meta: to_ir_attributes(db, self.get_default_attributes()), + meta, + constraints, span: Some(self.span().clone()), }; @@ -570,18 +658,26 @@ impl WithRepr for FieldWalker<'_> { type ClassId = String; +/// A BAML Class. #[derive(serde::Serialize, Debug)] pub struct Class { + /// User defined class name. pub name: ClassId, + + /// Fields of the class. pub static_fields: Vec>, + + /// Parameters to the class definition. pub inputs: Vec<(String, FieldType)>, } impl WithRepr for ClassWalker<'_> { fn attributes(&self, db: &ParserDatabase) -> NodeAttributes { let default_attributes = self.get_default_attributes(SubType::Class); + let (meta, constraints) = to_ir_attributes(db, default_attributes); let attributes = NodeAttributes { - meta: to_ir_attributes(db, default_attributes), + meta, + constraints, span: Some(self.span().clone()), }; @@ -799,6 +895,7 @@ impl WithRepr for FunctionWalker<'_> { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: Default::default(), + constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -855,6 +952,7 @@ impl WithRepr for ClientWalker<'_> { fn attributes(&self, _: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: IndexMap::new(), + constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -895,6 +993,7 @@ impl WithRepr for ConfigurationWalker<'_> { fn attributes(&self, _db: &ParserDatabase) -> NodeAttributes { NodeAttributes { meta: IndexMap::new(), + 
constraints: Vec::new(), span: Some(self.span().clone()), } } @@ -936,12 +1035,12 @@ impl WithRepr for (&ConfigurationWalker<'_>, usize) { let span = self.0.test_case().functions[self.1].1.clone(); NodeAttributes { meta: IndexMap::new(), - + constraints: Vec::new(), span: Some(span), } } - fn repr(&self, db: &ParserDatabase) -> Result { + fn repr(&self, _db: &ParserDatabase) -> Result { Ok(TestCaseFunction( self.0.test_case().functions[self.1].0.clone(), )) @@ -953,6 +1052,7 @@ impl WithRepr for ConfigurationWalker<'_> { NodeAttributes { meta: IndexMap::new(), span: Some(self.span().clone()), + constraints: Vec::new(), } } @@ -1008,3 +1108,22 @@ impl WithRepr for PromptAst<'_> { }) } } + +/// Generate an IntermediateRepr from a single block of BAML source code. +/// This is useful for generating IR test fixtures. +pub fn make_test_ir(source_code: &str) -> anyhow::Result { + use std::path::PathBuf; + use internal_baml_diagnostics::SourceFile; + use crate::ValidatedSchema; + use crate::validate; + + let path: PathBuf = "fake_file.baml".into(); + let source_file: SourceFile = (path.clone(), source_code).into(); + let validated_schema: ValidatedSchema = validate(&path, vec![source_file]); + let diagnostics = &validated_schema.diagnostics; + if diagnostics.has_errors() { + return Err(anyhow::anyhow!("Source code was invalid: \n{:?}", diagnostics.errors())) + } + let ir = IntermediateRepr::from_parser_database(&validated_schema.db, validated_schema.configuration)?; + Ok(ir) +} diff --git a/engine/baml-lib/baml-core/src/ir/walker.rs b/engine/baml-lib/baml-core/src/ir/walker.rs index 9810201c5..709517e7d 100644 --- a/engine/baml-lib/baml-core/src/ir/walker.rs +++ b/engine/baml-lib/baml-core/src/ir/walker.rs @@ -2,6 +2,7 @@ use anyhow::Result; use baml_types::BamlValue; use indexmap::IndexMap; +use internal_baml_jinja::render_expression; use internal_baml_parser_database::RetryPolicyStrategy; use std::collections::HashMap; @@ -214,6 +215,15 @@ impl Expression { 
anyhow::bail!("Invalid numeric value: {}", n) } } + Expression::JinjaExpression(expr) => { + // TODO: do not coerce all context values to strings. + let jinja_context: HashMap = env_values + .iter() + .map(|(k, v)| (k.clone(), BamlValue::String(v.clone()))) + .collect(); + let res_string = render_expression(&expr, &jinja_context)?; + Ok(BamlValue::String(res_string)) + } } } } @@ -407,7 +417,13 @@ impl<'a> Walker<'a, &'a Field> { self.item .attributes .get("description") - .map(|v| v.as_string_value(env_values)) + .map(|v| { + let normalized = v.normalize(env_values)?; + let baml_value = normalized + .as_str() + .ok_or(anyhow::anyhow!("Unexpected: Evaluated to non-string value"))?; + Ok(String::from(baml_value)) + }) .transpose() } @@ -415,3 +431,22 @@ impl<'a> Walker<'a, &'a Field> { self.item.attributes.span.as_ref() } } + +#[cfg(test)] +mod tests { + use super::*; + use baml_types::JinjaExpression; + + #[test] + fn basic_jinja_normalization() { + let expr = Expression::JinjaExpression(JinjaExpression("this == 'hello'".to_string())); + let env = vec![("this".to_string(), "hello".to_string())] + .into_iter() + .collect(); + let normalized = expr.normalize(&env).unwrap(); + match normalized { + BamlValue::String(s) => assert_eq!(&s, "true"), + _ => panic!("Expected String Expression"), + } + } +} diff --git a/engine/baml-lib/baml-core/src/lib.rs b/engine/baml-lib/baml-core/src/lib.rs index d14c4772f..d15eee019 100644 --- a/engine/baml-lib/baml-core/src/lib.rs +++ b/engine/baml-lib/baml-core/src/lib.rs @@ -41,7 +41,7 @@ impl std::fmt::Debug for ValidatedSchema { } } -/// The most general API for dealing with Prisma schemas. It accumulates what analysis and +/// The most general API for dealing with BAML source code. It accumulates what analysis and /// validation information it can, and returns it along with any error and warning diagnostics. 
pub fn validate(root_path: &PathBuf, files: Vec) -> ValidatedSchema { let mut diagnostics = Diagnostics::new(root_path.clone()); diff --git a/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs b/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs index 37e8b2f65..a75162253 100644 --- a/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs +++ b/engine/baml-lib/baml-core/src/validate/generator_loader/v1.rs @@ -138,21 +138,21 @@ pub(crate) fn parse_generator( }; match parse_required_key(&args, "test_command", ast_generator.span()) { - Ok(name) => (), + Ok(_name) => (), Err(err) => { errors.push(err); } }; match parse_required_key(&args, "install_command", ast_generator.span()) { - Ok(name) => (), + Ok(_name) => (), Err(err) => { errors.push(err); } }; match parse_required_key(&args, "package_version_command", ast_generator.span()) { - Ok(name) => (), + Ok(_name) => (), Err(err) => { errors.push(err); } diff --git a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs index cc26b9ed0..fed6df8a2 100644 --- a/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs +++ b/engine/baml-lib/baml-core/src/validate/validation_pipeline/validations/types.rs @@ -1,6 +1,6 @@ use baml_types::TypeValue; use internal_baml_diagnostics::DatamodelError; -use internal_baml_schema_ast::ast::{FieldArity, FieldType, Identifier, WithName, WithSpan}; +use internal_baml_schema_ast::ast::{Argument, Attribute, Expression, FieldArity, FieldType, Identifier, WithName, WithSpan}; use crate::validate::validation_pipeline::context::Context; @@ -22,6 +22,7 @@ fn errors_with_names<'a>(ctx: &'a mut Context<'_>, idn: &Identifier) { pub(crate) fn validate_type(ctx: &mut Context<'_>, field_type: &FieldType) { validate_type_exists(ctx, field_type); validate_type_allowed(ctx, field_type); + validate_type_constraints(ctx, field_type); } fn 
validate_type_exists(ctx: &mut Context<'_>, field_type: &FieldType) -> bool { @@ -46,7 +47,7 @@ fn validate_type_exists(ctx: &mut Context<'_>, field_type: &FieldType) -> bool { fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) { match field_type { FieldType::Map(arity, kv_types, ..) => { - if (arity.is_optional()) { + if arity.is_optional() { ctx.push_error(DatamodelError::new_validation_error( format!("Maps are not allowed to be optional").as_str(), field_type.span().clone(), @@ -70,7 +71,7 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) { FieldType::Symbol(..) => {} FieldType::List(arity, field_type, ..) => { - if (arity.is_optional()) { + if arity.is_optional() { ctx.push_error(DatamodelError::new_validation_error( format!("Lists are not allowed to be optional").as_str(), field_type.span().clone(), @@ -85,3 +86,36 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) { } } } + +fn validate_type_constraints(ctx: &mut Context<'_>, field_type: &FieldType) { + let constraint_attrs = field_type.attributes().iter().filter(|attr| ["assert", "check"].contains(&attr.name.name())).collect::>(); + for Attribute { arguments, span, name, .. } in constraint_attrs.iter() { + let arg_expressions = arguments.arguments.iter().map(|Argument{value,..}| value).collect::>(); + + match arg_expressions.as_slice() { + [Expression::JinjaExpressionValue(_, _), Expression::StringValue(s,_)] => { + // TODO: (Greg) use a real identifier parser. This is a temporary hack. 
+ if !s.chars().all(|c| c.is_alphanumeric() || c == '_') { + ctx.push_error(DatamodelError::new_validation_error( + "Constraint names must be valid identifiers - only alphanumeric characters and underscores", + span.clone() + )) + } + }, + [Expression::JinjaExpressionValue(_, _)] => { + if name.to_string() == "check" { + ctx.push_error(DatamodelError::new_validation_error( + "Check constraints must have a name.", + span.clone() + )) + } + }, + _ => { + ctx.push_error(DatamodelError::new_validation_error( + "A constraint must have one Jinja argument such as {{ expr }}, and optionally one String label", + span.clone() + )); + } + } + } +} diff --git a/engine/baml-lib/baml-types/Cargo.toml b/engine/baml-lib/baml-types/Cargo.toml index 0cd8f0285..6a390a29e 100644 --- a/engine/baml-lib/baml-types/Cargo.toml +++ b/engine/baml-lib/baml-types/Cargo.toml @@ -19,7 +19,6 @@ workspace = true optional = true [dependencies.minijinja] -optional = true version = "1.0.16" default-features = false features = [ @@ -43,4 +42,3 @@ features = [ [features] default = ["stable_sort"] stable_sort = ["indexmap"] -mini-jinja = ["minijinja"] diff --git a/engine/baml-lib/baml-types/src/baml_value.rs b/engine/baml-lib/baml-types/src/baml_value.rs index 07a82ba44..8835bb9f6 100644 --- a/engine/baml-lib/baml-types/src/baml_value.rs +++ b/engine/baml-lib/baml-types/src/baml_value.rs @@ -1,9 +1,11 @@ +use std::collections::HashMap; use std::{collections::HashSet, fmt}; -use serde::{de::Visitor, Deserialize, Deserializer}; +use serde::ser::{SerializeMap, SerializeSeq}; +use serde::{de::Visitor, Deserialize, Deserializer, Serialize, Serializer}; use crate::media::BamlMediaType; -use crate::{BamlMap, BamlMedia}; +use crate::{BamlMap, BamlMedia, ResponseCheck}; #[derive(Clone, Debug, PartialEq)] pub enum BamlValue { @@ -141,6 +143,13 @@ impl BamlValue { _ => None, } } + + pub fn as_list_owned(self) -> Option> { + match self { + BamlValue::List(vals) => Some(vals), + _ => None, + } + } } impl 
std::fmt::Display for BamlValue { @@ -336,3 +345,191 @@ impl<'de> Visitor<'de> for BamlValueVisitor { Ok(BamlValue::Map(values)) } } + +/// A BamlValue with associated metadata. +/// This type is used to flexibly carry additional information. +/// It is used as a base type for situations where we want to represent +/// a BamlValue with additional information per node, such as a score, +/// or a constraint result. +#[derive(Clone, Debug, PartialEq)] +pub enum BamlValueWithMeta { + String(String, T), + Int(i64, T), + Float(f64, T), + Bool(bool, T), + Map(BamlMap>, T), + List(Vec>, T), + Media(BamlMedia, T), + Enum(String, String, T), + Class(String, BamlMap>, T), + Null(T), +} + +impl BamlValueWithMeta { + + pub fn r#type(&self) -> String { + let plain_value: BamlValue = self.into(); + plain_value.r#type() + } + + pub fn value(self) -> BamlValue { + match self { + BamlValueWithMeta::String(v, _) => BamlValue::String(v), + BamlValueWithMeta::Int(v, _) => BamlValue::Int(v), + BamlValueWithMeta::Float(v, _) => BamlValue::Float(v), + BamlValueWithMeta::Bool(v, _) => BamlValue::Bool(v), + BamlValueWithMeta::Map(v, _) => { + BamlValue::Map(v.into_iter().map(|(k, v)| (k, v.value())).collect()) + } + BamlValueWithMeta::List(v, _) => { + BamlValue::List(v.into_iter().map(|v| v.value()).collect()) + } + BamlValueWithMeta::Media(v, _) => BamlValue::Media(v), + BamlValueWithMeta::Enum(v, w, _) => BamlValue::Enum(v, w), + BamlValueWithMeta::Class(n, fs, _) => { + BamlValue::Class(n, fs.into_iter().map(|(k, v)| (k, v.value())).collect()) + } + BamlValueWithMeta::Null(_) => BamlValue::Null, + } + } + + pub fn meta(&self) -> &T { + match self { + BamlValueWithMeta::String(_, m) => m, + BamlValueWithMeta::Int(_, m) => m, + BamlValueWithMeta::Float(_, m) => m, + BamlValueWithMeta::Bool(_, m) => m, + BamlValueWithMeta::Map(_, m) => m, + BamlValueWithMeta::List(_, m) => m, + BamlValueWithMeta::Media(_, m) => m, + BamlValueWithMeta::Enum(_, _, m) => m, + BamlValueWithMeta::Class(_, _, m) 
=> m, + BamlValueWithMeta::Null(m) => m, + } + } + + pub fn map_meta(self, f: F) -> BamlValueWithMeta + where + F: Fn(T) -> U + Copy, + { + match self { + BamlValueWithMeta::String(v, m) => BamlValueWithMeta::String(v, f(m)), + BamlValueWithMeta::Int(v, m) => BamlValueWithMeta::Int(v, f(m)), + BamlValueWithMeta::Float(v, m) => BamlValueWithMeta::Float(v, f(m)), + BamlValueWithMeta::Bool(v, m) => BamlValueWithMeta::Bool(v, f(m)), + BamlValueWithMeta::Map(v, m) => BamlValueWithMeta::Map( + v.into_iter().map(|(k, v)| (k, v.map_meta(f))).collect(), + f(m), + ), + BamlValueWithMeta::List(v, m) => { + BamlValueWithMeta::List(v.into_iter().map(|v| v.map_meta(f)).collect(), f(m)) + } + BamlValueWithMeta::Media(v, m) => BamlValueWithMeta::Media(v, f(m)), + BamlValueWithMeta::Enum(v, e, m) => BamlValueWithMeta::Enum(v, e, f(m)), + BamlValueWithMeta::Class(n, fs, m) => BamlValueWithMeta::Class( + n, + fs.into_iter().map(|(k, v)| (k, v.map_meta(f))).collect(), + f(m), + ), + BamlValueWithMeta::Null(m) => BamlValueWithMeta::Null(f(m)), + } + } +} + +impl From<&BamlValueWithMeta> for BamlValue { + fn from(baml_value: &BamlValueWithMeta) -> BamlValue { + use BamlValueWithMeta::*; + match baml_value { + String(v, _) => BamlValue::String(v.clone()), + Int(v, _) => BamlValue::Int(v.clone()), + Float(v, _) => BamlValue::Float(v.clone()), + Bool(v, _) => BamlValue::Bool(v.clone()), + Map(v, _) => BamlValue::Map(v.into_iter().map(|(k,v)| (k.clone(), v.into())).collect()), + List(v, _) => BamlValue::List(v.into_iter().map(|v| v.into()).collect()), + Media(v, _) => BamlValue::Media(v.clone()), + Enum(enum_name, v, _) => BamlValue::Enum(enum_name.clone(), v.clone()), + Class(class_name, v, _) => BamlValue::Class(class_name.clone(), v.into_iter().map(|(k,v)| (k.clone(), v.into())).collect()), + Null(_) => BamlValue::Null, + } + } +} + +impl From> for BamlValue { + fn from(baml_value: BamlValueWithMeta) -> BamlValue { + use BamlValueWithMeta::*; + match baml_value { + String(v, _) => 
BamlValue::String(v), + Int(v, _) => BamlValue::Int(v), + Float(v, _) => BamlValue::Float(v), + Bool(v, _) => BamlValue::Bool(v), + Map(v, _) => BamlValue::Map(v.into_iter().map(|(k,v)| (k, v.into())).collect()), + List(v, _) => BamlValue::List(v.into_iter().map(|v| v.into()).collect()), + Media(v, _) => BamlValue::Media(v), + Enum(enum_name, v, _) => BamlValue::Enum(enum_name, v), + Class(class_name, v, _) => BamlValue::Class(class_name, v.into_iter().map(|(k,v)| (k, v.into())).collect()), + Null(_) => BamlValue::Null, + } + } +} + +impl Serialize for BamlValueWithMeta> { + fn serialize(&self, serializer: S) -> Result + where S: Serializer, + { + match self { + BamlValueWithMeta::String(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Int(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Float(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Bool(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Map(v, cr) => { + let mut map = serializer.serialize_map(None)?; + for (key, value) in v { + map.serialize_entry(key, value)?; + } + add_checks(&mut map, cr)?; + map.end() + }, + BamlValueWithMeta::List(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Media(v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Enum(_enum_name, v, cr) => serialize_with_checks(v, cr, serializer), + BamlValueWithMeta::Class(_class_name, v, cr) => { + let mut map = serializer.serialize_map(None)?; + for (key, value) in v { + map.serialize_entry(key, value)?; + } + add_checks(&mut map, cr)?; + map.end() + }, + BamlValueWithMeta::Null(cr) => serialize_with_checks(&(), cr, serializer), + } + } +} + +fn serialize_with_checks( + value: &T, + checks: &Vec, + serializer:S, + +) -> Result + where S: Serializer, +{ + if !checks.is_empty() { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("value", value)?; + add_checks(&mut map, checks)?; + 
map.end() + } else { + value.serialize(serializer) + } +} + +fn add_checks<'a, S: SerializeMap>( + map: &'a mut S, + checks: &'a Vec, +) -> Result<(), S::Error> { + if !checks.is_empty() { + let checks_map: HashMap<_,_> = checks.iter().map(|check| (check.name.clone(), check)).collect(); + map.serialize_entry("checks", &checks_map)?; + } + Ok(()) +} diff --git a/engine/baml-lib/baml-types/src/constraint.rs b/engine/baml-lib/baml-types/src/constraint.rs new file mode 100644 index 000000000..15ab3ae01 --- /dev/null +++ b/engine/baml-lib/baml-types/src/constraint.rs @@ -0,0 +1,38 @@ +use crate::JinjaExpression; + +#[derive(Clone, Debug, serde::Serialize, PartialEq)] +pub struct Constraint { + pub level: ConstraintLevel, + pub expression: JinjaExpression, + pub label: Option, +} + +#[derive(Clone, Debug, PartialEq, serde::Serialize)] +pub enum ConstraintLevel { + Check, + Assert, +} + +/// The user-visible schema for a failed check. +#[derive(Clone, Debug, serde::Serialize)] +pub struct ResponseCheck { + pub name: Option, + pub expression: String, + pub status: String, +} + +impl ResponseCheck { + pub fn from_constraint_result((Constraint{ level, expression, label }, succeeded): (Constraint, bool)) -> Option { + match level { + ConstraintLevel::Check => { + let status = if succeeded { "succeeded".to_string() } else { "failed".to_string() }; + Some( ResponseCheck { + name: label, + expression: expression.0, + status + }) + }, + _ => None, + } + } +} diff --git a/engine/baml-lib/baml-types/src/field_type/mod.rs b/engine/baml-lib/baml-types/src/field_type/mod.rs index bde7184ab..4b1093920 100644 --- a/engine/baml-lib/baml-types/src/field_type/mod.rs +++ b/engine/baml-lib/baml-types/src/field_type/mod.rs @@ -1,4 +1,5 @@ use crate::BamlMediaType; +use crate::Constraint; mod builder; @@ -69,7 +70,7 @@ impl std::fmt::Display for LiteralValue { } /// FieldType represents the type of either a class field or a function arg. 
-#[derive(serde::Serialize, Debug, Clone)] +#[derive(serde::Serialize, Debug, Clone, PartialEq)] pub enum FieldType { Primitive(TypeValue), Enum(String), @@ -80,6 +81,7 @@ pub enum FieldType { Union(Vec), Tuple(Vec), Optional(Box), + Constrained{ base: Box, constraints: Vec }, } // Impl display for FieldType @@ -116,6 +118,7 @@ impl std::fmt::Display for FieldType { FieldType::Map(k, v) => write!(f, "map<{}, {}>", k.to_string(), v.to_string()), FieldType::List(t) => write!(f, "{}[]", t.to_string()), FieldType::Optional(t) => write!(f, "{}?", t.to_string()), + FieldType::Constrained{base,..} => base.fmt(f), } } } @@ -126,6 +129,7 @@ impl FieldType { FieldType::Primitive(_) => true, FieldType::Optional(t) => t.is_primitive(), FieldType::List(t) => t.is_primitive(), + FieldType::Constrained{base,..} => base.is_primitive(), _ => false, } } @@ -134,8 +138,8 @@ impl FieldType { match self { FieldType::Optional(_) => true, FieldType::Primitive(TypeValue::Null) => true, - FieldType::Union(types) => types.iter().any(FieldType::is_optional), + FieldType::Constrained{base,..} => base.is_optional(), _ => false, } } @@ -144,6 +148,7 @@ impl FieldType { match self { FieldType::Primitive(TypeValue::Null) => true, FieldType::Optional(t) => t.is_null(), + FieldType::Constrained{base,..} => base.is_null(), _ => false, } } diff --git a/engine/baml-lib/baml-types/src/lib.rs b/engine/baml-lib/baml-types/src/lib.rs index 21e6cf0e8..fb721ff8d 100644 --- a/engine/baml-lib/baml-types/src/lib.rs +++ b/engine/baml-lib/baml-types/src/lib.rs @@ -1,14 +1,16 @@ +mod constraint; mod map; mod media; -#[cfg(feature = "mini-jinja")] mod minijinja; mod baml_value; mod field_type; mod generator; -pub use baml_value::BamlValue; +pub use baml_value::{BamlValue, BamlValueWithMeta}; +pub use constraint::*; pub use field_type::{FieldType, LiteralValue, TypeValue}; pub use generator::{GeneratorDefaultClientMode, GeneratorOutputType}; pub use map::Map as BamlMap; pub use media::{BamlMedia, BamlMediaContent, 
BamlMediaType, MediaBase64, MediaUrl}; +pub use minijinja::JinjaExpression; diff --git a/engine/baml-lib/baml-types/src/minijinja.rs b/engine/baml-lib/baml-types/src/minijinja.rs index e1c3f168d..36aa7a5a4 100644 --- a/engine/baml-lib/baml-types/src/minijinja.rs +++ b/engine/baml-lib/baml-types/src/minijinja.rs @@ -1,5 +1,19 @@ +use std::fmt; use crate::{BamlMedia, BamlValue}; +/// A wrapper around a jinja expression. The inner `String` should not contain +/// the interpolation brackets `{{ }}`; it should be a bare expression like +/// `"this|length < something"`. +#[derive(Clone, Debug, PartialEq, serde::Serialize)] +pub struct JinjaExpression(pub String); + + +impl fmt::Display for JinjaExpression { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + impl From for minijinja::Value { fn from(arg: BamlValue) -> minijinja::Value { match arg { diff --git a/engine/baml-lib/baml/tests/validation_files/constraints/constraints_everywhere.baml b/engine/baml-lib/baml/tests/validation_files/constraints/constraints_everywhere.baml new file mode 100644 index 000000000..c3b156bde --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/constraints/constraints_everywhere.baml @@ -0,0 +1,12 @@ +client Bar { + provider baml-openai-chat +} + +class Foo { + age int @check({{this > 10}}, "old enough") +} + +function FooToInt(foo: Foo, a: Foo @assert({{this.age > 20}}, "really old")) -> int @check({{ this < 10 }}) { + client Bar + prompt #"fa"# +} diff --git a/engine/baml-lib/baml/tests/validation_files/constraints/misspelled.baml b/engine/baml-lib/baml/tests/validation_files/constraints/misspelled.baml new file mode 100644 index 000000000..33b8cc612 --- /dev/null +++ b/engine/baml-lib/baml/tests/validation_files/constraints/misspelled.baml @@ -0,0 +1,11 @@ +class Foo { + // A constraint that didn't use Jinja Expression syntax. 
+ age int @check("this < 10") +} + +// error: Error validating: A constraint must have one Jinja argument such as {{ expr }}, and optionally one String label +// --> constraints/misspelled.baml:3 +// | +// 2 | // A constraint that didn't use Jinja Expression syntax. +// 3 | age int @check("this < 10") +// | diff --git a/engine/baml-lib/jinja/Cargo.toml b/engine/baml-lib/jinja/Cargo.toml index 93e6d37fb..0f0bb78fb 100644 --- a/engine/baml-lib/jinja/Cargo.toml +++ b/engine/baml-lib/jinja/Cargo.toml @@ -9,7 +9,7 @@ license-file.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -baml-types = { path = "../baml-types", features = ["mini-jinja"] } +baml-types = { path = "../baml-types" } # TODO: disable imports, etc minijinja = { version = "1.0.16", default-features = false, features = [ "macros", @@ -38,6 +38,7 @@ serde_json.workspace = true strum.workspace = true strsim = "0.11.1" colored = "2.1.0" +regex.workspace = true [dev-dependencies] env_logger = "0.11.3" diff --git a/engine/baml-lib/jinja/src/lib.rs b/engine/baml-lib/jinja/src/lib.rs index 4d49dcc9d..8d914a289 100644 --- a/engine/baml-lib/jinja/src/lib.rs +++ b/engine/baml-lib/jinja/src/lib.rs @@ -1,4 +1,4 @@ -use baml_types::{BamlMedia, BamlValue}; +use baml_types::{BamlMedia, BamlValue, JinjaExpression}; use colored::*; mod chat_message_part; mod evaluate_type; @@ -12,6 +12,7 @@ pub use evaluate_type::{PredefinedTypes, Type, TypeError}; use minijinja::{self, value::Kwargs}; use minijinja::{context, ErrorKind, Value}; use output_format::types::OutputFormatContent; +use regex::Regex; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::HashMap; @@ -24,9 +25,17 @@ fn get_env<'a>() -> minijinja::Environment<'a> { env.set_debug(true); env.set_trim_blocks(true); env.set_lstrip_blocks(true); + env.add_filter("regex_match", regex_match); env } +fn regex_match(value: String, regex: String) -> bool { + match 
Regex::new(&regex) { + Err(_) => false, + Ok(re) => re.is_match(&value) + } +} + #[derive(Debug)] pub struct ValidationError { pub errors: Vec, @@ -80,6 +89,10 @@ pub struct RenderContext_Client { pub default_role: String, } +/// A collection of values about the rendering context that will be made +/// available to a prompt via `{{ ctx }}`. For example `{{ ctx.client.name }}` +/// used in a prompt string will resolve to the name of the client, e.g. +/// "openai". #[derive(Debug)] pub struct RenderContext { pub client: RenderContext_Client, @@ -487,12 +500,42 @@ pub fn render_prompt( } } +/// Render a bare minijinja expression with the given context. +/// E.g. `"a|length > 2"` with context `{"a": [1, 2, 3]}` will return `"true"`. +pub fn render_expression( + expression: &JinjaExpression, + ctx: &HashMap, +) -> anyhow::Result { + let env = get_env(); + // In Rust format strings, `{` is escaped as `{{`. + // So producing the string `{{}}` requires writing the literal `"{{{{}}}}"` + let template = format!(r#"{{{{ {} }}}}"#, expression.0); + let args_dict = minijinja::Value::from_serialize(ctx); + eprintln!("{}", &template); + Ok(env.render_str(&template, &args_dict)?) +} + +// TODO: (Greg) better error handling. +// TODO: (Greg) Upstream, typecheck the expression.
+pub fn evaluate_predicate( + this: &BamlValue, + predicate_expression: &JinjaExpression, +) -> Result { + let ctx: HashMap = + [("this".to_string(), this.clone())].into_iter().collect(); + match render_expression(&predicate_expression, &ctx)?.as_ref() { + "true" => Ok(true), + "false" => Ok(false), + _ => Err(anyhow::anyhow!("TODO")), + } +} + #[cfg(test)] mod render_tests { use super::*; - use baml_types::{BamlMap, BamlMediaType}; + use baml_types::{BamlMap, BamlMediaType, JinjaExpression}; use env_logger; use std::sync::Once; @@ -1107,4 +1150,45 @@ mod render_tests { Ok(()) } + + #[test] + fn test_render_expressions() { + let ctx = vec![( + "a".to_string(), + BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into()) + ), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))] + .into_iter() + .collect(); + + assert_eq!( + render_expression(&JinjaExpression("1".to_string()), &ctx).unwrap(), + "1" + ); + assert_eq!( + render_expression(&JinjaExpression("1 + 1".to_string()), &ctx).unwrap(), + "2" + ); + assert_eq!( + render_expression(&JinjaExpression("a|length > 2".to_string()), &ctx).unwrap(), + "true" + ); + } + + #[test] + fn test_render_regex_match() { + let ctx = vec![( + "a".to_string(), + BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into()) + ), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))] + .into_iter() + .collect(); + assert_eq!( + render_expression(&JinjaExpression(r##"b|regex_match("123")"##.to_string()), &ctx).unwrap(), + "true" + ); + assert_eq!( + render_expression(&JinjaExpression(r##"b|regex_match("\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}")"##.to_string()), &ctx).unwrap(), + "true" + ) + } } diff --git a/engine/baml-lib/jinja/src/output_format/mod.rs b/engine/baml-lib/jinja/src/output_format/mod.rs index f2abfc716..59efe6f6e 100644 --- a/engine/baml-lib/jinja/src/output_format/mod.rs +++ b/engine/baml-lib/jinja/src/output_format/mod.rs @@ -9,6 +9,7 @@ 
use crate::{types::RenderOptions, RenderContext}; use self::types::OutputFormatContent; +// TODO: Rename the field to `content`. #[derive(Debug)] pub struct OutputFormat { text: OutputFormatContent, diff --git a/engine/baml-lib/jinja/src/output_format/types.rs b/engine/baml-lib/jinja/src/output_format/types.rs index 1125ef58b..fc795a69e 100644 --- a/engine/baml-lib/jinja/src/output_format/types.rs +++ b/engine/baml-lib/jinja/src/output_format/types.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use anyhow::Result; -use baml_types::{FieldType, LiteralValue, TypeValue}; +use baml_types::{FieldType, LiteralValue, TypeValue, Constraint}; use indexmap::{IndexMap, IndexSet}; #[derive(Debug)] @@ -34,18 +34,23 @@ impl Name { } } +// TODO: (Greg) Enum needs to carry its constraints. #[derive(Debug)] pub struct Enum { pub name: Name, // name and description pub values: Vec<(Name, Option)>, + pub constraints: Vec, } +/// The components of a Class needed to render `OutputFormatContent`. +/// This type is also used by `jsonish` to drive flexible parsing. #[derive(Debug)] pub struct Class { pub name: Name, - // type and description + // fields have name, type and description. 
pub fields: Vec<(Name, FieldType, Option)>, + pub constraints: Vec, } #[derive(Debug, Clone)] @@ -227,10 +232,8 @@ impl OutputFormatContent { } fn prefix<'a>(&self, options: &'a RenderOptions) -> Option<&'a str> { - match &options.prefix { - RenderSetting::Always(prefix) => Some(prefix.as_str()), - RenderSetting::Never => None, - RenderSetting::Auto => match &self.target { + fn auto_prefix(ft: &FieldType) -> Option<&'static str> { + match ft { FieldType::Primitive(TypeValue::String) => None, FieldType::Primitive(_) => Some("Answer as a: "), FieldType::Literal(_) => Some("Answer using this specific value:\n"), @@ -241,7 +244,13 @@ impl OutputFormatContent { FieldType::Optional(_) => Some("Answer in JSON using this schema:\n"), FieldType::Map(_, _) => Some("Answer in JSON using this schema:\n"), FieldType::Tuple(_) => None, - }, + FieldType::Constrained { base, .. } => auto_prefix(base), + } + } + match &options.prefix { + RenderSetting::Always(prefix) => Some(prefix.as_str()), + RenderSetting::Never => None, + RenderSetting::Auto => auto_prefix(&self.target), } } @@ -287,6 +296,9 @@ impl OutputFormatContent { LiteralValue::Int(i) => i.to_string(), LiteralValue::Bool(b) => b.to_string(), }, + FieldType::Constrained { base, .. } => { + self.inner_type_render(options, base, render_state, group_hoisted_literals)? 
+ } FieldType::Enum(e) => { let Some(enm) = self.enums.get(e) else { return Err(minijinja::Error::new( @@ -523,6 +535,7 @@ mod tests { (Name::new("Green".to_string()), None), (Name::new("Blue".to_string()), None), ], + constraints: Vec::new(), }); let content = OutputFormatContent::new(enums, vec![], FieldType::Enum("Color".to_string())); @@ -553,6 +566,7 @@ mod tests { Some("The person's age".to_string()), ), ], + constraints: Vec::new(), }); let content = @@ -589,6 +603,7 @@ mod tests { None, ), ], + constraints: Vec::new(), }); let content = diff --git a/engine/baml-lib/jinja/src/render_context.rs b/engine/baml-lib/jinja/src/render_context.rs index a1f8861c4..ba33725b2 100644 --- a/engine/baml-lib/jinja/src/render_context.rs +++ b/engine/baml-lib/jinja/src/render_context.rs @@ -22,6 +22,8 @@ impl std::fmt::Display for RenderContext_Client { } } +// TODO: (Greg) This type is duplicated in `src/lib.rs`. Are they both +// needed? If not, delete one. #[derive(Debug)] pub struct RenderContext { client: RenderContext_Client, diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs index 8dde401fb..bab1d87a3 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/field_type.rs @@ -3,7 +3,7 @@ use baml_types::BamlMap; use internal_baml_core::{ir::FieldType, ir::TypeValue}; use crate::deserializer::{ - coercer::{DefaultValue, TypeCoercer}, + coercer::{run_user_checks, DefaultValue, TypeCoercer}, deserialize_flags::{DeserializerConditions, Flag}, types::BamlValueWithFlags, }; @@ -84,6 +84,19 @@ impl TypeCoercer for FieldType { FieldType::Optional(_) => coerce_optional(ctx, self, value), FieldType::Map(_, _) => coerce_map(ctx, self, value), FieldType::Tuple(_) => Err(ctx.error_internal("Tuple not supported")), + FieldType::Constrained { base, .. 
} => { + let mut coerced_value = base.coerce(ctx, base, value)?; + let constraint_results = + run_user_checks(&coerced_value.clone().into(), &self).map_err( + |e| ParsingError { + reason: format!("Failed to evaluate constraints: {:?}", e), + scope: ctx.scope.clone(), + causes: Vec::new(), + }, + )?; + coerced_value.add_flag(Flag::ConstraintResults(constraint_results)); + Ok(coerced_value) + } }, } } @@ -100,7 +113,7 @@ impl DefaultValue for FieldType { match self { FieldType::Enum(e) => None, FieldType::Literal(_) => None, - FieldType::Class(c) => None, + FieldType::Class(_) => None, FieldType::List(_) => Some(BamlValueWithFlags::List(get_flags(), Vec::new())), FieldType::Union(items) => items.iter().find_map(|i| i.default_value(error)), FieldType::Primitive(TypeValue::Null) | FieldType::Optional(_) => { @@ -119,6 +132,8 @@ impl DefaultValue for FieldType { } } FieldType::Primitive(_) => None, + // If it has constraints, we can't assume our defaults meet them. + FieldType::Constrained { .. } => None, } } } diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs index 4ebf500f5..5b273f72d 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/ir_ref/coerce_class.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use baml_types::BamlMap; +use baml_types::{BamlMap, Constraint}; use internal_baml_core::ir::FieldType; use internal_baml_jinja::types::{Class, Name}; @@ -11,7 +11,7 @@ use crate::deserializer::{ use super::ParsingContext; -// Name, type, description +// Name, type, description, constraints. 
type FieldValue = (Name, FieldType, Option); impl TypeCoercer for Class { diff --git a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs index d07bd2257..9c63089d1 100644 --- a/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs +++ b/engine/baml-lib/jsonish/src/deserializer/coercer/mod.rs @@ -10,7 +10,9 @@ mod ir_ref; mod match_string; use anyhow::Result; -use internal_baml_jinja::types::OutputFormatContent; + +use baml_types::{BamlValue, Constraint}; +use internal_baml_jinja::{evaluate_predicate, types::OutputFormatContent}; use internal_baml_core::ir::FieldType; @@ -125,7 +127,7 @@ impl ParsingContext<'_> { &self, unparsed: Vec<(String, &ParsingError)>, missing: Vec, - item: Option<&crate::jsonish::Value>, + _item: Option<&crate::jsonish::Value>, ) -> ParsingError { ParsingError { reason: format!( @@ -136,7 +138,7 @@ impl ParsingContext<'_> { scope: self.scope.clone(), causes: missing .into_iter() - .map(|(k)| ParsingError { + .map(|k| ParsingError { scope: self.scope.clone(), reason: format!("Missing required field: {}", k), causes: vec![], @@ -219,3 +221,22 @@ pub trait TypeCoercer { pub trait DefaultValue { fn default_value(&self, error: Option<&ParsingError>) -> Option; } + +/// Run all checks and asserts for a value at a given type. +pub fn run_user_checks( + baml_value: &BamlValue, + type_: &FieldType, +) -> Result> { + match type_ { + FieldType::Constrained { constraints, .. 
} => { + constraints.iter().map(|constraint| { + let result = + evaluate_predicate(baml_value, &constraint.expression).map_err(|e| { + anyhow::anyhow!(format!("Error evaluating constraint: {:?}", e)) + })?; + Ok((constraint.clone(), result)) + }).collect::>>() + } + _ => Ok(vec![]), + } +} diff --git a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs index a05b85399..04bdebfe4 100644 --- a/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs +++ b/engine/baml-lib/jsonish/src/deserializer/deserialize_flags.rs @@ -1,4 +1,5 @@ use super::{coercer::ParsingError, types::BamlValueWithFlags}; +use baml_types::Constraint; #[derive(Debug, Clone)] pub enum Flag { @@ -42,6 +43,9 @@ pub enum Flag { // X -> Object convertions. NoFields(Option), + + // Constraint results. + ConstraintResults(Vec<(Constraint, bool)>), } #[derive(Clone)] @@ -90,9 +94,18 @@ impl DeserializerConditions { Flag::NoFields(_) => None, Flag::UnionMatch(_idx, _) => None, Flag::DefaultButHadUnparseableValue(e) => Some(e.clone()), + Flag::ConstraintResults(_) => None, }) .collect::>() } + + pub fn constraint_results(&self) -> Vec<(Constraint, bool)> { + self.flags.iter().filter_map(|flag| match flag { + Flag::ConstraintResults(cs) => Some(cs.clone()), + _ => None, + }).flatten().collect() + } + } impl std::fmt::Debug for DeserializerConditions { @@ -229,6 +242,13 @@ impl std::fmt::Display for Flag { writeln!(f, "")?; } } + Flag::ConstraintResults(cs) => { + for (Constraint{ label, level, expression }, succeeded) in cs.iter() { + let msg = label.as_ref().unwrap_or(&expression.0); + let f_result = if *succeeded { "Succeeded" } else { "Failed" }; + writeln!(f, "{level:?} {msg} {f_result}")?; + } + } } Ok(()) } diff --git a/engine/baml-lib/jsonish/src/deserializer/score.rs b/engine/baml-lib/jsonish/src/deserializer/score.rs index cba62ce0c..bf25dc39b 100644 --- a/engine/baml-lib/jsonish/src/deserializer/score.rs +++ 
b/engine/baml-lib/jsonish/src/deserializer/score.rs @@ -1,3 +1,5 @@ +use baml_types::{Constraint, ConstraintLevel}; + use super::{ deserialize_flags::{DeserializerConditions, Flag}, types::{BamlValueWithFlags, ValueWithFlags}, @@ -62,6 +64,18 @@ impl WithScore for Flag { Flag::StringToChar(_) => 1, Flag::FloatToInt(_) => 1, Flag::NoFields(_) => 1, + Flag::ConstraintResults(cs) => { + cs + .iter() + .map(|(Constraint{ level,.. }, succeeded)| + if *succeeded { 0 } else { + match level { + ConstraintLevel::Check => 5, + ConstraintLevel::Assert => 50, + } + }) + .sum() + } } } } diff --git a/engine/baml-lib/jsonish/src/deserializer/types.rs b/engine/baml-lib/jsonish/src/deserializer/types.rs index e1dc4681f..b82657e0a 100644 --- a/engine/baml-lib/jsonish/src/deserializer/types.rs +++ b/engine/baml-lib/jsonish/src/deserializer/types.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use baml_types::{BamlMap, BamlMedia, BamlValue}; +use baml_types::{BamlMap, BamlMedia, BamlValue, BamlValueWithMeta, Constraint}; use serde_json::json; use strsim::jaro; @@ -229,7 +229,7 @@ impl BamlValueWithFlags { #[derive(Debug, Clone)] pub struct ValueWithFlags { - value: T, + pub value: T, pub(super) flags: DeserializerConditions, } @@ -440,3 +440,33 @@ impl std::fmt::Display for BamlValueWithFlags { Ok(()) } } + +impl From for BamlValueWithMeta> { + fn from(baml_value: BamlValueWithFlags) -> Self { + use BamlValueWithFlags::*; + let c = baml_value.conditions().constraint_results(); + match baml_value { + String(ValueWithFlags { value, .. }) => BamlValueWithMeta::String(value, c), + Int(ValueWithFlags { value, .. }) => BamlValueWithMeta::Int(value, c), + Float(ValueWithFlags { value, .. }) => BamlValueWithMeta::Float(value, c), + Bool(ValueWithFlags { value, .. 
}) => BamlValueWithMeta::Bool(value, c), + Map(_, values) => BamlValueWithMeta::Map( + values.into_iter().map(|(k, v)| (k, v.1.into())).collect(), + c, + ), // TODO: (Greg) I discard the DeserializerConditions tupled up with the value of the BamlMap. I'm not sure why BamlMap value is (DeserializerConditions, BamlValueWithFlags) in the first place. + List(_, values) => { + BamlValueWithMeta::List(values.into_iter().map(|v| v.into()).collect(), c) + } + Media(ValueWithFlags { value, .. }) => BamlValueWithMeta::Media(value, c), + Enum(enum_name, ValueWithFlags { value, .. }) => { + BamlValueWithMeta::Enum(enum_name, value, c) + } + Class(class_name, _, fields) => BamlValueWithMeta::Class( + class_name, + fields.into_iter().map(|(k, v)| (k, v.into())).collect(), + c, + ), + Null(_) => BamlValueWithMeta::Null(c), + } + } +} diff --git a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs index 19a7ed99d..ab7a5058c 100644 --- a/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs +++ b/engine/baml-lib/jsonish/src/jsonish/parser/fixing_parser/json_parse_state.rs @@ -215,7 +215,7 @@ impl JsonParseState { log::debug!("Closing due to: new key after space + comma"); return Some(idx); } - x => { + _x => { break; } } diff --git a/engine/baml-lib/jsonish/src/lib.rs b/engine/baml-lib/jsonish/src/lib.rs index e89408c68..8325a01a5 100644 --- a/engine/baml-lib/jsonish/src/lib.rs +++ b/engine/baml-lib/jsonish/src/lib.rs @@ -2,7 +2,7 @@ mod tests; use anyhow::Result; -mod deserializer; +pub mod deserializer; mod jsonish; use baml_types::FieldType; diff --git a/engine/baml-lib/jsonish/src/tests/macros.rs b/engine/baml-lib/jsonish/src/tests/macros.rs index 2a8703437..7c0dd8281 100644 --- a/engine/baml-lib/jsonish/src/tests/macros.rs +++ b/engine/baml-lib/jsonish/src/tests/macros.rs @@ -16,6 +16,12 @@ macro_rules!
test_failing_deserializer { }; } +/// Arguments: +/// name: name of test function to generate. +/// file_content: a BAML schema. +/// raw_string: an example payload coming from an LLM to parse. +/// target_type: The type to try to parse raw_string into. +/// json: The expected JSON encoding that the parser should return. macro_rules! test_deserializer { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $($json:tt)+) => { #[test_log::test] @@ -45,6 +51,25 @@ macro_rules! test_deserializer { }; } +macro_rules! test_deserializer_with_expected_score { + ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $target_score:expr) => { + #[test_log::test] + fn $name() { + let ir = load_test_ir($file_content); + let target = render_output_format(&ir, &$target_type, &Default::default()).unwrap(); + + let result = from_str(&target, &$target_type, $raw_string, false); + + assert!(result.is_ok(), "Failed to parse: {:?}", result); + + let value = result.unwrap(); + dbg!(&value); + log::trace!("Score: {}", value.score()); + assert_eq!(value.score(), $target_score); + } + }; +} + macro_rules! 
test_partial_deserializer { ($name:ident, $file_content:expr, $raw_string:expr, $target_type:expr, $($json:tt)+) => { #[test_log::test] diff --git a/engine/baml-lib/jsonish/src/tests/mod.rs b/engine/baml-lib/jsonish/src/tests/mod.rs index 8b39002c6..8350bc0c8 100644 --- a/engine/baml-lib/jsonish/src/tests/mod.rs +++ b/engine/baml-lib/jsonish/src/tests/mod.rs @@ -6,6 +6,7 @@ pub mod macros; mod test_basics; mod test_class; +mod test_constraints; mod test_enum; mod test_lists; mod test_literals; @@ -18,7 +19,7 @@ use std::{ path::PathBuf, }; -use baml_types::BamlValue; +use baml_types::{BamlValue, Constraint, ConstraintLevel, JinjaExpression}; use internal_baml_core::{ internal_baml_diagnostics::SourceFile, ir::{repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, TypeValue}, @@ -105,20 +106,53 @@ fn find_enum_value( Ok(Some((name, desc))) } +/// Eliminate the `FieldType::Constrained` variant by searching for it, and stripping +/// it off of its base type, returning a tuple of the base type and any constraints found +/// (if called on an argument that is not Constrained, the returned constraints Vec is +/// empty). +/// +/// If the function encounters directly nested Constrained types, +/// (i.e. `FieldType::Constrained { base: FieldType::Constrained { .. }, .. } `) +/// then the constraints of the two levels will be combined into a single vector. +/// So, we always return a base type that is not FieldType::Constrained. +fn distribute_constraints(field_type: &FieldType) -> (&FieldType, Vec) { + + match field_type { + // Check the first level to see if it's constrained. + FieldType::Constrained { base, constraints } => { + match base.as_ref() { + // If so, we must check the second level to see if we need to combine + // constraints across levels. + // The recursion here means that arbitrarily nested `FieldType::Constrained`s + // will be collapsed before the function returns.
+ FieldType::Constrained{..} => { + let (sub_base, sub_constraints) = distribute_constraints(base); + let combined_constraints = vec![constraints.clone(), sub_constraints].into_iter().flatten().collect(); + (sub_base, combined_constraints) + }, + _ => (base, constraints.clone()), + } + }, + _ => (field_type, Vec::new()), + } +} + +// TODO: (Greg) Is the use of `String` as a hash key safe? Is there some way to +// get a collision that results in some type not getting put onto the stack? fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, env_values: &HashMap, ) -> Result<(Vec, Vec)> { - let mut checked_types = HashSet::new(); + let mut checked_types: HashSet = HashSet::new(); let mut enums = Vec::new(); - let mut classes = Vec::new(); + let mut classes: Vec = Vec::new(); let mut start: Vec = vec![output.clone()]; while !start.is_empty() { let output = start.pop().unwrap(); - match &output { - FieldType::Enum(enm) => { + match distribute_constraints(&output) { + (FieldType::Enum(enm), constraints) => { if checked_types.insert(output.to_string()) { let walker = ir.find_enum(enm); @@ -140,15 +174,16 @@ fn relevant_data_models<'a>( enums.push(Enum { name: Name::new_with_alias(enm.to_string(), walker?.alias(env_values)?), values, + constraints, }); } } - FieldType::List(inner) | FieldType::Optional(inner) => { + (FieldType::List(inner), _constraints) | (FieldType::Optional(inner), _constraints) => { if !checked_types.contains(&inner.to_string()) { start.push(inner.as_ref().clone()); } } - FieldType::Map(k, v) => { + (FieldType::Map(k, v), _constraints) => { if checked_types.insert(output.to_string()) { if !checked_types.contains(&k.to_string()) { start.push(k.as_ref().clone()); @@ -158,7 +193,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Tuple(options) | FieldType::Union(options) => { + (FieldType::Tuple(options), _constraints) | (FieldType::Union(options), _constraints) => { if checked_types.insert((&output).to_string()) { for inner in 
options { if !checked_types.contains(&inner.to_string()) { @@ -167,7 +202,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Class(cls) => { + (FieldType::Class(cls), constraints) => { if checked_types.insert(output.to_string()) { let walker = ir.find_class(&cls); @@ -192,11 +227,15 @@ fn relevant_data_models<'a>( classes.push(Class { name: Name::new_with_alias(cls.to_string(), walker?.alias(env_values)?), fields, + constraints, }); } } - FieldType::Primitive(_) => {} - FieldType::Literal(_) => {} + (FieldType::Literal(_), _) => {} + (FieldType::Primitive(_), _constraints) => {} + (FieldType::Constrained{..}, _) => { + unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") + } } } @@ -749,4 +788,29 @@ fn partial_int_not_deleted() { let baml_value: BamlValue = res.into(); // Note: This happens to parse as a List, but Null also seems appropriate. assert_eq!(baml_value, BamlValue::List(vec![])); + +#[test] +fn test_nested_constraint_distribution() { + fn mk_constraint(s: &str) -> Constraint { + Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression(s.to_string()), label: Some(s.to_string()) } + } + + let input = FieldType::Constrained { + constraints: vec![mk_constraint("a")], + base: Box::new(FieldType::Constrained { + constraints: vec![mk_constraint("b")], + base: Box::new(FieldType::Constrained { + constraints: vec![mk_constraint("c")], + base: Box::new(FieldType::Primitive(TypeValue::Int)), + }) + }) + }; + + let expected_base = FieldType::Primitive(TypeValue::Int); + let expected_constraints = vec![mk_constraint("a"),mk_constraint("b"), mk_constraint("c")]; + + let (base, constraints) = distribute_constraints(&input); + + assert_eq!(base, &expected_base); + assert_eq!(constraints, expected_constraints); } diff --git a/engine/baml-lib/jsonish/src/tests/test_constraints.rs b/engine/baml-lib/jsonish/src/tests/test_constraints.rs new file mode 100644 index 000000000..8c9e52e5b --- /dev/null 
+++ b/engine/baml-lib/jsonish/src/tests/test_constraints.rs @@ -0,0 +1,129 @@ +use super::*; + +const CLASS_FOO_INT_STRING: &str = r#" +class Foo { + age int + @check({{this < 10}}, "age less than 10") + @check({{this < 20}}, "age less than 20") + @assert({{this >= 0}}, "nonnegative") + name string + @assert({{this|length > 0}}, "Nonempty name") +} +"#; + +test_deserializer_with_expected_score!( + test_class_failing_one_check, + CLASS_FOO_INT_STRING, + r#"{"age": 11, "name": "Greg"}"#, + FieldType::Class("Foo".to_string()), + 5 +); + +test_deserializer_with_expected_score!( + test_class_failing_two_checks, + CLASS_FOO_INT_STRING, + r#"{"age": 21, "name": "Grog"}"#, + FieldType::Class("Foo".to_string()), + 10 +); + +test_deserializer_with_expected_score!( + test_class_failing_assert, + CLASS_FOO_INT_STRING, + r#"{"age": -1, "name": "Sam"}"#, + FieldType::Class("Foo".to_string()), + 50 +); + +test_deserializer_with_expected_score!( + test_class_multiple_failing_asserts, + CLASS_FOO_INT_STRING, + r#"{"age": -1, "name": ""}"#, + FieldType::Class("Foo".to_string()), + 100 +); + +const UNION_WITH_CHECKS: &str = r#" +class Thing1 { + bar int @check({{ this < 10 }}, "bar small") +} + +class Thing2 { + bar int @check({{ this > 20 }}, "bar big") +} + +class Either { + bar Thing1 | Thing2 + things (Thing1 | Thing2)[] @assert({{this|length < 4}}, "list not too long") +} +"#; + +test_deserializer_with_expected_score!( + test_union_decision_from_check, + UNION_WITH_CHECKS, + r#"{"bar": 5, "things":[]}"#, + FieldType::Class("Either".to_string()), + 2 +); + +test_deserializer_with_expected_score!( + test_union_decision_from_check_no_good_answer, + UNION_WITH_CHECKS, + r#"{"bar": 15, "things":[]}"#, + FieldType::Class("Either".to_string()), + 7 +); + +test_deserializer_with_expected_score!( + test_union_decision_in_list, + UNION_WITH_CHECKS, + r#"{"bar": 1, "things":[{"bar": 25}, {"bar": 35}, {"bar": 15}, {"bar": 15}]}"#, + FieldType::Class("Either".to_string()), + 62 +); + +const 
MAP_WITH_CHECKS: &str = r#" +class Foo { + foo map @check({{ this["hello"] == 10 }}, "hello is 10") +} +"#; + +test_deserializer_with_expected_score!( + test_map_with_check, + MAP_WITH_CHECKS, + r#"{"foo": {"hello": 10, "there":13}}"#, + FieldType::Class("Foo".to_string()), + 1 +); + +test_deserializer_with_expected_score!( + test_map_with_check_fails, + MAP_WITH_CHECKS, + r#"{"foo": {"hello": 11, "there":13}}"#, + FieldType::Class("Foo".to_string()), + 6 +); + +const NESTED_CLASS_CONSTRAINTS: &str = r#" +class Outer { + inner Inner +} + +class Inner { + value int @check({{ this < 10 }}) +} +"#; + +test_deserializer_with_expected_score!( + test_nested_class_constraints, + NESTED_CLASS_CONSTRAINTS, + r#"{"inner": {"value": 15}}"#, + FieldType::Class("Outer".to_string()), + 5 +); + +const MISSPELLED_CONSTRAINT: &str = r#" +class Foo { + foo int @description("hi") @check({{this == 1}},"hi") +} +"#; diff --git a/engine/baml-lib/jsonish/src/tests/test_unions.rs b/engine/baml-lib/jsonish/src/tests/test_unions.rs index 9d2784b3d..40316411d 100644 --- a/engine/baml-lib/jsonish/src/tests/test_unions.rs +++ b/engine/baml-lib/jsonish/src/tests/test_unions.rs @@ -247,3 +247,35 @@ test_deserializer!( ] } ); + +const CONTACT_INFO: &str = r#" +class PhoneNumber { + value string @check({{this|regex_match("\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")}}, "valid_phone_number") + foo int? // A nullable marker indicating PhoneNumber was chosen. +} + +class EmailAddress { + value string @check({{this|regex_match("^[_]*([a-z0-9]+(\.|_*)?)+@([a-z][a-z0-9-]+(\.|-*\.))+[a-z]{2,6}$")}}, "valid_email") + bar int? // A nullable marker indicating EmailAddress was chosen. 
+} + +class ContactInfo { + primary PhoneNumber | EmailAddress +} +"#; + +test_deserializer!( + test_check1, + CONTACT_INFO, + r#"{"primary": {"value": "908-797-8281"}}"#, + FieldType::Class("ContactInfo".to_string()), + {"primary": {"value": "908-797-8281", "foo": null}} +); + +test_deserializer!( + test_check2, + CONTACT_INFO, + r#"{"primary": {"value": "help@boundaryml.com"}}"#, + FieldType::Class("ContactInfo".to_string()), + {"primary": {"value": "help@boundaryml.com", "bar": null}} +); diff --git a/engine/baml-lib/parser-database/src/attributes/constraint.rs b/engine/baml-lib/parser-database/src/attributes/constraint.rs new file mode 100644 index 000000000..993944cf7 --- /dev/null +++ b/engine/baml-lib/parser-database/src/attributes/constraint.rs @@ -0,0 +1,38 @@ +use baml_types::{Constraint, ConstraintLevel}; +use internal_baml_schema_ast::ast::Expression; + +use crate::{context::Context, types::Attributes}; + +pub(super) fn visit_constraint_attributes( + attribute_name: String, + attributes: &mut Attributes, + ctx: &mut Context<'_>, +) { + let expression_arg = ctx.visit_default_arg_with_idx("expression").map_err(|err| { + ctx.push_error(err); + }); + let label = ctx.visit_default_arg_with_idx("name"); + let label = match label { + Ok((_, Expression::StringValue(descr, _))) => Some(descr.clone()), + _ => None, + }; + match expression_arg { + Ok((_, Expression::JinjaExpressionValue(expression, _))) => { + let level = match attribute_name.as_str() { + "assert" => ConstraintLevel::Assert, + "check" => ConstraintLevel::Check, + _ => { + panic!("Internal error: Only \"assert\" and \"check\" are valid attribute names in this context."); + } + }; + attributes.constraints.push(Constraint { + level, + expression: expression.clone(), + label, + }); + } + _ => panic!( + "The impossible happened: Reached arguments that are ruled out by the tokenizer." 
+ ), + } +} diff --git a/engine/baml-lib/parser-database/src/attributes/mod.rs b/engine/baml-lib/parser-database/src/attributes/mod.rs index 0b0efc708..5e4518985 100644 --- a/engine/baml-lib/parser-database/src/attributes/mod.rs +++ b/engine/baml-lib/parser-database/src/attributes/mod.rs @@ -1,10 +1,12 @@ use internal_baml_schema_ast::ast::{Top, TopId, TypeExpId, TypeExpressionBlock}; mod alias; +mod constraint; mod description; mod to_string_attribute; use crate::interner::StringId; use crate::{context::Context, types::ClassAttributes, types::EnumAttributes}; +use baml_types::Constraint; use internal_baml_schema_ast::ast::{Expression, SubType}; /// @@ -21,6 +23,9 @@ pub struct Attributes { /// Whether the node should be skipped during prompt rendering and parsing. pub skip: Option, + + /// @check and @assert attributes attached to the node. + pub constraints: Vec, } impl Attributes { @@ -63,7 +68,6 @@ impl Attributes { pub fn set_skip(&mut self) { self.skip.replace(true); } - } pub(super) fn resolve_attributes(ctx: &mut Context<'_>) { for top in ctx.ast.iter_tops() { @@ -90,7 +94,7 @@ fn resolve_type_exp_block_attributes<'db>( let mut enum_attributes = EnumAttributes::default(); for (value_idx, _value) in ast_typexpr.iter_fields() { - ctx.visit_attributes((type_id, value_idx).into()); + ctx.assert_all_attributes_processed((type_id, value_idx).into()); if let Some(attrs) = to_string_attribute::visit(ctx, false) { enum_attributes.value_serilizers.insert(value_idx, attrs); } @@ -98,7 +102,7 @@ fn resolve_type_exp_block_attributes<'db>( } // Now validate the enum attributes. 
- ctx.visit_attributes(type_id.into()); + ctx.assert_all_attributes_processed(type_id.into()); enum_attributes.serilizer = to_string_attribute::visit(ctx, true); ctx.validate_visited_attributes(); @@ -107,16 +111,18 @@ fn resolve_type_exp_block_attributes<'db>( SubType::Class => { let mut class_attributes = ClassAttributes::default(); + dbg!(&ast_typexpr); for (field_idx, _field) in ast_typexpr.iter_fields() { - ctx.visit_attributes((type_id, field_idx).into()); + ctx.assert_all_attributes_processed((type_id, field_idx).into()); if let Some(attrs) = to_string_attribute::visit(ctx, false) { + dbg!(&attrs); class_attributes.field_serilizers.insert(field_idx, attrs); } ctx.validate_visited_attributes(); } // Now validate the class attributes. - ctx.visit_attributes(type_id.into()); + ctx.assert_all_attributes_processed(type_id.into()); class_attributes.serilizer = to_string_attribute::visit(ctx, true); ctx.validate_visited_attributes(); diff --git a/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs b/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs index c9fa3d4b7..70567efa2 100644 --- a/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs +++ b/engine/baml-lib/parser-database/src/attributes/to_string_attribute.rs @@ -1,6 +1,7 @@ use crate::{context::Context, types::Attributes}; use super::alias::visit_alias_attribute; +use super::constraint::visit_constraint_attributes; use super::description::visit_description_attribute; pub(super) fn visit(ctx: &mut Context<'_>, as_block: bool) -> Option { @@ -26,6 +27,13 @@ pub(super) fn visit(ctx: &mut Context<'_>, as_block: bool) -> Option ctx.validate_visited_arguments(); } + if let Some(attribute_name) = ctx.visit_repeated_attr_from_names(&["assert", "check"]) { + panic!("HERE"); + visit_constraint_attributes(attribute_name, &mut attributes, ctx); + modified = true; + ctx.validate_visited_arguments(); + } + if as_block { if ctx.visit_optional_single_attr("dynamic") { 
attributes.set_dynamic_type(); diff --git a/engine/baml-lib/parser-database/src/context/attributes.rs b/engine/baml-lib/parser-database/src/context/attributes.rs index 4ddecc8d4..3d725f440 100644 --- a/engine/baml-lib/parser-database/src/context/attributes.rs +++ b/engine/baml-lib/parser-database/src/context/attributes.rs @@ -10,7 +10,7 @@ pub(super) struct AttributesValidationState { /// The attribute being validated. pub(super) attribute: Option, - pub(super) args: VecDeque, // the _remaining_ arguments of `attribute` + pub(super) args: VecDeque, // the _remaining_ arguments of `attribute` } impl AttributesValidationState { diff --git a/engine/baml-lib/parser-database/src/context/mod.rs b/engine/baml-lib/parser-database/src/context/mod.rs index 04faf7832..5b7a95ae1 100644 --- a/engine/baml-lib/parser-database/src/context/mod.rs +++ b/engine/baml-lib/parser-database/src/context/mod.rs @@ -1,5 +1,5 @@ use internal_baml_diagnostics::DatamodelWarning; -use internal_baml_schema_ast::ast::ArguementId; +use internal_baml_schema_ast::ast::ArgumentId; use crate::{ ast, ast::WithName, interner::StringInterner, names::Names, types::Types, DatamodelError, @@ -83,10 +83,13 @@ impl<'db> Context<'db> { /// /// - When you are done validating an attribute, you must call `discard_arguments()` or /// `validate_visited_arguments()`. Otherwise, Context will helpfully panic. 
- pub(super) fn visit_attributes(&mut self, ast_attributes: ast::AttributeContainer) { + pub(super) fn assert_all_attributes_processed( + &mut self, + ast_attributes: ast::AttributeContainer, + ) { if self.attributes.attributes.is_some() || !self.attributes.unused_attributes.is_empty() { panic!( - "`ctx.visit_attributes() called with {:?} while the Context is still validating previous attribute set on {:?}`", + "`ctx.assert_all_attributes_processed() called with {:?} while the Context is still validating previous attribute set on {:?}`", ast_attributes, self.attributes.attributes ); @@ -98,7 +101,7 @@ impl<'db> Context<'db> { /// Extract an attribute that can occur zero or more times. Example: @@index on models. /// /// Returns `true` as long as a next attribute is found. - pub(crate) fn visit_repeated_attr(&mut self, name: &'static str) -> bool { + pub(crate) fn _visit_repeated_attr(&mut self, name: &'static str) -> bool { let mut has_valid_attribute = false; while !has_valid_attribute { @@ -117,6 +120,37 @@ impl<'db> Context<'db> { has_valid_attribute } + /// Extract an attribute that can occur zero or more times. Example: @assert on types. + /// Argument is a list of names that are all valid for this attribute. + /// + /// Returns Some(name_match) if name_match is the attribute name and is in the + /// `names` argument. 
+ pub(crate) fn visit_repeated_attr_from_names( + &mut self, + names: &'static [&'static str], + ) -> Option { + let mut has_valid_attribute = false; + let mut matching_name: Option = None; + + let all_attributes = + iter_attributes(self.attributes.attributes.as_ref(), self.ast).collect::>(); + while !has_valid_attribute { + let first_attr = iter_attributes(self.attributes.attributes.as_ref(), self.ast) + .filter(|(_, attr)| names.contains(&attr.name.name())) + .find(|(attr_id, _)| self.attributes.unused_attributes.contains(attr_id)); + let (attr_id, attr) = if let Some(first_attr) = first_attr { + first_attr + } else { + break; + }; + self.attributes.unused_attributes.remove(&attr_id); + has_valid_attribute = self.set_attribute(attr_id, attr); + matching_name = Some(attr.name.name().to_string()); + } + + matching_name + } + /// Validate an _optional_ attribute that should occur only once. Returns whether the attribute /// is defined. #[must_use] @@ -155,7 +189,7 @@ impl<'db> Context<'db> { pub(crate) fn visit_default_arg_with_idx( &mut self, name: &str, - ) -> Result<(ArguementId, &'db ast::Expression), DatamodelError> { + ) -> Result<(ArgumentId, &'db ast::Expression), DatamodelError> { match self.attributes.args.pop_front() { Some(arg_idx) => { let arg = self.arg_at(arg_idx); @@ -186,7 +220,7 @@ impl<'db> Context<'db> { self.discard_arguments(); } - /// Counterpart to visit_attributes(). This must be called at the end of the validation of the + /// Counterpart to assert_all_attributes_processed(). This must be called at the end of the validation of the /// attribute set. The Drop impl will helpfully panic otherwise. 
pub(crate) fn validate_visited_attributes(&mut self) { if !self.attributes.args.is_empty() || self.attributes.attribute.is_some() { @@ -216,7 +250,7 @@ impl<'db> Context<'db> { &self.ast[id] } - fn arg_at(&self, idx: ArguementId) -> &'db ast::Argument { + fn arg_at(&self, idx: ArgumentId) -> &'db ast::Argument { &self.current_attribute().arguments[idx] } diff --git a/engine/baml-lib/parser-database/src/lib.rs b/engine/baml-lib/parser-database/src/lib.rs index 5fe20ef7a..38cdeb663 100644 --- a/engine/baml-lib/parser-database/src/lib.rs +++ b/engine/baml-lib/parser-database/src/lib.rs @@ -47,7 +47,7 @@ pub use types::{ }; use self::{context::Context, interner::StringId, types::Types}; -use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Diagnostics}; +use internal_baml_diagnostics::{DatamodelError, Diagnostics}; use names::Names; /// ParserDatabase is a container for a Schema AST, together with information diff --git a/engine/baml-lib/schema-ast/src/ast.rs b/engine/baml-lib/schema-ast/src/ast.rs index df322df8f..1c31fe8b3 100644 --- a/engine/baml-lib/schema-ast/src/ast.rs +++ b/engine/baml-lib/schema-ast/src/ast.rs @@ -18,7 +18,7 @@ mod type_expression_block; mod value_expression_block; pub(crate) use self::comment::Comment; -pub use argument::{ArguementId, Argument, ArgumentsList}; +pub use argument::{ArgumentId, Argument, ArgumentsList}; pub use attribute::{Attribute, AttributeContainer, AttributeId}; pub use config::ConfigBlockProperty; pub use expression::{Expression, RawString}; @@ -32,7 +32,7 @@ pub use top::Top; pub use traits::{WithAttributes, WithDocumentation, WithIdentifier, WithName, WithSpan}; pub use type_expression_block::{FieldId, SubType, TypeExpressionBlock}; pub use value_expression_block::{ - ArgumentId, BlockArg, BlockArgs, ValueExprBlock, ValueExprBlockType, + BlockArg, BlockArgs, ValueExprBlock, ValueExprBlockType, }; /// AST representation of a prisma schema. 
diff --git a/engine/baml-lib/schema-ast/src/ast/argument.rs b/engine/baml-lib/schema-ast/src/ast/argument.rs index 265ba4985..5fc04d019 100644 --- a/engine/baml-lib/schema-ast/src/ast/argument.rs +++ b/engine/baml-lib/schema-ast/src/ast/argument.rs @@ -4,19 +4,19 @@ use std::fmt::{Display, Formatter}; /// An opaque identifier for a value in an AST enum. Use the /// `r#enum[enum_value_id]` syntax to resolve the id to an `ast::EnumValue`. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ArguementId(pub u32); +pub struct ArgumentId(pub u32); -impl ArguementId { +impl ArgumentId { /// Used for range bounds when iterating over BTreeMaps. - pub const MIN: ArguementId = ArguementId(0); + pub const MIN: ArgumentId = ArgumentId(0); /// Used for range bounds when iterating over BTreeMaps. - pub const MAX: ArguementId = ArguementId(u32::MAX); + pub const MAX: ArgumentId = ArgumentId(u32::MAX); } -impl std::ops::Index for ArgumentsList { +impl std::ops::Index for ArgumentsList { type Output = Argument; - fn index(&self, index: ArguementId) -> &Self::Output { + fn index(&self, index: ArgumentId) -> &Self::Output { &self.arguments[index.0 as usize] } } @@ -34,11 +34,11 @@ pub struct ArgumentsList { } impl ArgumentsList { - pub fn iter(&self) -> impl ExactSizeIterator { + pub fn iter(&self) -> impl ExactSizeIterator { self.arguments .iter() .enumerate() - .map(|(idx, field)| (ArguementId(idx as u32), field)) + .map(|(idx, field)| (ArgumentId(idx as u32), field)) } } diff --git a/engine/baml-lib/schema-ast/src/ast/attribute.rs b/engine/baml-lib/schema-ast/src/ast/attribute.rs index 896d92ee2..0ccd549d1 100644 --- a/engine/baml-lib/schema-ast/src/ast/attribute.rs +++ b/engine/baml-lib/schema-ast/src/ast/attribute.rs @@ -1,4 +1,4 @@ -use super::{ArguementId, ArgumentsList, Identifier, Span, WithIdentifier, WithSpan}; +use super::{ArgumentId, ArgumentsList, Identifier, Span, WithIdentifier, WithSpan}; use std::ops::Index; /// An attribute (following `@` 
or `@@``) on a model, model field, enum, enum value or composite @@ -29,7 +29,7 @@ pub struct Attribute { impl Attribute { /// Try to find the argument and return its span. - pub fn span_for_argument(&self, argument: ArguementId) -> Span { + pub fn span_for_argument(&self, argument: ArgumentId) -> Span { self.arguments[argument].span.clone() } diff --git a/engine/baml-lib/schema-ast/src/ast/expression.rs b/engine/baml-lib/schema-ast/src/ast/expression.rs index 0c4892e92..b7677beb2 100644 --- a/engine/baml-lib/schema-ast/src/ast/expression.rs +++ b/engine/baml-lib/schema-ast/src/ast/expression.rs @@ -4,6 +4,7 @@ use crate::ast::Span; use std::fmt; use super::{Identifier, WithName, WithSpan}; +use baml_types::JinjaExpression; #[derive(Debug, Clone)] pub struct RawString { @@ -159,6 +160,8 @@ pub enum Expression { Array(Vec, Span), /// A mapping function. Map(Vec<(Expression, Expression)>, Span), + /// A JinjaExpression. e.g. "this|length > 5". + JinjaExpressionValue(JinjaExpression, Span), } impl fmt::Display for Expression { @@ -171,6 +174,7 @@ impl fmt::Display for Expression { Expression::RawStringValue(val, ..) => { write!(f, "{}", crate::string_literal(val.value())) } + Expression::JinjaExpressionValue(val,..) 
=> fmt::Display::fmt(val, f), Expression::Array(vals, _) => { let vals = vals .iter() @@ -293,6 +297,7 @@ impl Expression { Self::NumericValue(_, span) => span, Self::StringValue(_, span) => span, Self::RawStringValue(r) => r.span(), + Self::JinjaExpressionValue(_,span) => span, Self::Identifier(id) => id.span(), Self::Map(_, span) => span, Self::Array(_, span) => span, @@ -310,6 +315,7 @@ impl Expression { Expression::NumericValue(_, _) => "numeric", Expression::StringValue(_, _) => "string", Expression::RawStringValue(_) => "raw_string", + Expression::JinjaExpressionValue(_, _) => "jinja_expression", Expression::Identifier(id) => match id { Identifier::String(_, _) => "string", Identifier::Local(_, _) => "local_type", @@ -354,6 +360,8 @@ impl Expression { (StringValue(_,_), _) => panic!("Types do not match: {:?} and {:?}", self, other), (RawStringValue(s1), RawStringValue(s2)) => s1.assert_eq_up_to_span(s2), (RawStringValue(_), _) => panic!("Types do not match: {:?} and {:?}", self, other), + (JinjaExpressionValue(j1, _), JinjaExpressionValue(j2, _)) => assert_eq!(j1, j2), + (JinjaExpressionValue(_,_), _) => panic!("Types do not match: {:?} and {:?}", self, other), (Array(xs,_), Array(ys,_)) => { assert_eq!(xs.len(), ys.len()); xs.iter().zip(ys).for_each(|(x,y)| { x.assert_eq_up_to_span(y); }) diff --git a/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs b/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs index b6314d1b8..87d5147c9 100644 --- a/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/ast/type_expression_block.rs @@ -30,21 +30,19 @@ pub enum SubType { Other(String), } -/// An enum declaration. Enumeration can either be in the database schema, or completely a Prisma level concept. -/// -/// PostgreSQL stores enums in a schema, while in MySQL the information is in -/// the table definition. On MongoDB the enumerations are handled in the Query -/// Engine. +/// A class or enum declaration. 
#[derive(Debug, Clone)] pub struct TypeExpressionBlock { - /// The name of the enum. + /// The name of the class or enum. /// /// ```ignore /// enum Foo { ... } /// ^^^ + /// class Bar { ... } + /// ^^^ /// ``` pub name: Identifier, - /// The values of the enum. + /// The values of the enum, or fields of the class. /// /// ```ignore /// enum Foo { diff --git a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs index 62ca4b6ae..ccebc25fa 100644 --- a/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs +++ b/engine/baml-lib/schema-ast/src/ast/value_expression_block.rs @@ -2,12 +2,9 @@ use super::{ traits::WithAttributes, Attribute, Comment, Expression, Field, FieldType, Identifier, Span, WithDocumentation, WithIdentifier, WithSpan, }; +use super::argument::ArgumentId; use std::fmt::Display; use std::fmt::Formatter; -/// An opaque identifier for a value in an AST enum. Use the -/// `r#enum[enum_value_id]` syntax to resolve the id to an `ast::EnumValue`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ArgumentId(pub u32); /// An opaque identifier for a field in an AST model. Use the /// `model[field_id]` syntax to resolve the id to an `ast::Field`. @@ -29,13 +26,6 @@ impl std::ops::Index for ValueExprBlock { } } -impl ArgumentId { - /// Used for range bounds when iterating over BTreeMaps. - pub const MIN: ArgumentId = ArgumentId(0); - /// Used for range bounds when iterating over BTreeMaps. 
- pub const MAX: ArgumentId = ArgumentId(u32::MAX); -} - impl std::ops::Index for BlockArgs { type Output = (Identifier, BlockArg); diff --git a/engine/baml-lib/schema-ast/src/parser/datamodel.pest b/engine/baml-lib/schema-ast/src/parser/datamodel.pest index f615316bc..3a906850b 100644 --- a/engine/baml-lib/schema-ast/src/parser/datamodel.pest +++ b/engine/baml-lib/schema-ast/src/parser/datamodel.pest @@ -94,7 +94,8 @@ map_entry = { (comment_block | empty_lines)* ~ map_key ~ (expression | ENTRY_CAT splitter = _{ ("," ~ NEWLINE?) | NEWLINE } map_expression = { "{" ~ empty_lines? ~ (map_entry ~ (splitter ~ map_entry)*)? ~ (comment_block | empty_lines)* ~ "}" } array_expression = { "[" ~ empty_lines? ~ ((expression | ARRAY_CATCH_ALL) ~ trailing_comment? ~ (splitter ~ (comment_block | empty_lines)* ~ (expression | ARRAY_CATCH_ALL) ~ trailing_comment?)*)? ~ (comment_block | empty_lines)* ~ splitter? ~ "]" } -expression = { map_expression | array_expression | numeric_literal | string_literal | identifier } +jinja_expression = { "{{" ~ (!("}}" | "{{") ~ ANY)* ~ "}}" } +expression = { jinja_expression | map_expression | array_expression | numeric_literal | string_literal | identifier } ARRAY_CATCH_ALL = { !"]" ~ CATCH_ALL } ENTRY_CATCH_ALL = { field_attribute | BLOCK_LEVEL_CATCH_ALL } // ###################################### diff --git a/engine/baml-lib/schema-ast/src/parser/parse_expression.rs b/engine/baml-lib/schema-ast/src/parser/parse_expression.rs index 2a0edb3f0..caac48b55 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_expression.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_expression.rs @@ -3,6 +3,7 @@ use super::{ parse_identifier::parse_identifier, Rule, }; +use baml_types::JinjaExpression; use crate::{assert_correct_parser, ast::*, unreachable_rule}; use internal_baml_diagnostics::Diagnostics; @@ -17,6 +18,7 @@ pub(crate) fn parse_expression( Rule::string_literal => Some(parse_string_literal(first_child, diagnostics)), Rule::map_expression => 
Some(parse_map(first_child, diagnostics)), Rule::array_expression => Some(parse_array(first_child, diagnostics)), + Rule::jinja_expression => Some(parse_jinja_expression(first_child, diagnostics)), Rule::identifier => Some(Expression::Identifier(parse_identifier( first_child, @@ -245,10 +247,32 @@ fn unescape_string(val: &str) -> String { result } +/// Parse a `JinjaExpression` from raw source. Escape backslashes, +/// because we want the user's backslash intent to be preserved in +/// the string backing the `JinjaExpression`. In other words, control +/// sequences like `\n` are intended to be forwarded to the Jinja +/// processing engine, not to break a Jinja Expression into two lines, +/// therefor the backing string should be contain "\\n". +pub fn parse_jinja_expression(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Expression { + assert_correct_parser!(token, Rule::jinja_expression); + let mut inner_text = String::new(); + for c in token.as_str()[2..token.as_str().len() - 2].chars() { + match c { + // When encountering a single backslash, produce two backslashes. + '\\' => inner_text.push_str("\\\\"), + // Otherwise, just copy the character. 
+ _ => inner_text.push(c), + } + } + Expression::JinjaExpressionValue(JinjaExpression(inner_text), diagnostics.span(token.as_span())) +} + #[cfg(test)] mod tests { + use super::*; use super::super::{BAMLParser, Rule}; - use pest::{consumes_to, parses_to}; + use pest::{Parser, parses_to, consumes_to}; + use internal_baml_diagnostics::{Diagnostics, SourceFile}; #[test] fn array_trailing_comma() { @@ -287,4 +311,24 @@ mod tests { ])] }; } + + #[test] + fn test_parse_jinja_expression() { + let input = "{{ 1 + 1 }}"; + let root_path = "test_file.baml"; + let source = SourceFile::new_static(root_path.into(), input); + let mut diagnostics = Diagnostics::new(root_path.into()); + diagnostics.set_source(&source); + + let pair = BAMLParser::parse(Rule::jinja_expression, input) + .unwrap() + .next() + .unwrap(); + let expr = parse_jinja_expression(pair, &mut diagnostics); + match expr { + Expression::JinjaExpressionValue(JinjaExpression(s), _) => assert_eq!(s, " 1 + 1 "), + _ => panic!("Expected JinjaExpression, got {:?}", expr), + } + } + } diff --git a/engine/baml-lib/schema-ast/src/parser/parse_field.rs b/engine/baml-lib/schema-ast/src/parser/parse_field.rs index 2644eb562..eabff4c3e 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_field.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_field.rs @@ -60,6 +60,18 @@ pub(crate) fn parse_value_expr( } } +fn reassociate_type_attributes( + field_attributes: &mut Vec, + field_type: &mut FieldType, +) { + let mut all_attrs = field_type.attributes().to_owned(); + all_attrs.append(field_attributes); + let (attrs_for_type, attrs_for_field): (Vec, Vec) = + all_attrs.into_iter().partition(|attr| ["assert", "check"].contains(&attr.name())); + field_type.set_attributes(attrs_for_type); + *field_attributes = attrs_for_field; +} + pub(crate) fn parse_type_expr( model_name: &Option, container_type: &'static str, @@ -90,11 +102,18 @@ pub(crate) fn parse_type_expr( } } + // Strip certain attributes from the field and attach them to 
the type. + match field_type.as_mut() { + None => {}, + Some(ft) => reassociate_type_attributes(&mut field_attributes, ft), + } + match (name, &field_type) { + // Class field. (Some(name), Some(field_type)) => Ok(Field { expr: Some(field_type.clone()), name, - attributes: field_type.clone().attributes().to_vec(), + attributes: field_attributes, documentation: comment, span: diagnostics.span(pair_span), }), diff --git a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs index 79fcd479c..c8e00dd6c 100644 --- a/engine/baml-lib/schema-ast/src/parser/parse_schema.rs +++ b/engine/baml-lib/schema-ast/src/parser/parse_schema.rs @@ -178,6 +178,7 @@ mod tests { let input = r#" class MyClass { myProperty string[] @description("This is a description") @alias("MP") + prop2 string @description({{ "a " + "b" }}) } "#; @@ -192,11 +193,13 @@ mod tests { assert_eq!(schema_ast.tops.len(), 1); match &schema_ast.tops[0] { - Top::Class(model) => { - assert_eq!(model.name.name(), "MyClass"); - assert_eq!(model.fields.len(), 1); - assert_eq!(model.fields[0].name.name(), "myProperty"); - assert_eq!(model.fields[0].attributes.len(), 2) + Top::Class(TypeExpressionBlock { name, fields, .. 
}) => { + assert_eq!(name.name(), "MyClass"); + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name.name(), "myProperty"); + assert_eq!(fields[1].name.name(), "prop2"); + assert_eq!(fields[0].attributes.len(), 2); + assert_eq!(fields[1].attributes.len(), 1); } _ => panic!("Expected a model declaration"), } diff --git a/engine/baml-runtime/src/cli/mod.rs b/engine/baml-runtime/src/cli/mod.rs index 6cabc917b..458d569ea 100644 --- a/engine/baml-runtime/src/cli/mod.rs +++ b/engine/baml-runtime/src/cli/mod.rs @@ -1,5 +1,5 @@ mod dev; -mod generate; +pub mod generate; mod init; mod serve; diff --git a/engine/baml-runtime/src/cli/serve/mod.rs b/engine/baml-runtime/src/cli/serve/mod.rs index 6cc67af64..90ef713b8 100644 --- a/engine/baml-runtime/src/cli/serve/mod.rs +++ b/engine/baml-runtime/src/cli/serve/mod.rs @@ -33,7 +33,9 @@ use tokio::{net::TcpListener, sync::RwLock}; use tokio_stream::StreamExt; use crate::{ - client_registry::ClientRegistry, errors::ExposedError, internal::llm_client::LLMResponse, + client_registry::ClientRegistry, + errors::ExposedError, + internal::llm_client::{LLMResponse, ResponseBamlValue}, BamlRuntime, FunctionResult, RuntimeContextManager, }; use internal_baml_codegen::openapi::OpenApiSchema; @@ -367,7 +369,7 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` LLMResponse::Success(_) => match function_result.parsed_content() { // Just because the LLM returned 2xx doesn't mean that it returned parse-able content! Ok(parsed) => { - (StatusCode::OK, Json::(parsed.into())).into_response() + (StatusCode::OK, Json::(parsed.clone())).into_response() } Err(e) => { if let Some(ExposedError::ValidationError { @@ -478,8 +480,10 @@ Tip: test that the server is up using `curl http://localhost:{}/_debug/ping` Ok(function_result) => match function_result.llm_response() { LLMResponse::Success(_) => match function_result.parsed_content() { // Just because the LLM returned 2xx doesn't mean that it returned parse-able content! 
- Ok(parsed) => (StatusCode::OK, Json::(parsed.into())) - .into_response(), + Ok(parsed) => { + (StatusCode::OK, Json::(parsed.clone())) + .into_response() + } Err(e) => { log::debug!("Error parsing content: {:?}", e); diff --git a/engine/baml-runtime/src/internal/llm_client/mod.rs b/engine/baml-runtime/src/internal/llm_client/mod.rs index 934526e17..2bafebbcb 100644 --- a/engine/baml-runtime/src/internal/llm_client/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/mod.rs @@ -1,6 +1,5 @@ use std::collections::{HashMap, HashSet}; -use base64::write; use colored::*; pub mod llm_provider; pub mod orchestrator; @@ -12,10 +11,11 @@ pub mod traits; use anyhow::Result; +use baml_types::{BamlValueWithMeta, Constraint, ConstraintLevel, ResponseCheck}; use internal_baml_core::ir::ClientWalker; -use internal_baml_jinja::{ChatMessagePart, RenderedChatMessage, RenderedPrompt}; +use internal_baml_jinja::RenderedPrompt; +use jsonish::BamlValueWithFlags; use serde::{Deserialize, Serialize}; -use serde_json::Map; use std::error::Error; use reqwest::StatusCode; @@ -23,6 +23,32 @@ use reqwest::StatusCode; #[cfg(target_arch = "wasm32")] use wasm_bindgen::JsValue; +pub type ResponseBamlValue = BamlValueWithMeta>; + +/// Validate a parsed value, checking asserts and checks. +pub fn parsed_value_to_response(baml_value: BamlValueWithFlags) -> Result { + let baml_value_with_meta: BamlValueWithMeta> = baml_value.into(); + let first_failing_assert: Option = baml_value_with_meta + .meta() + .iter() + .filter_map(|(c @ Constraint { level, .. 
}, succeeded)| { + if !succeeded && level == &ConstraintLevel::Assert { + Some(c.clone()) + } else { + None + } + }) + .next(); + match first_failing_assert { + Some(err) => Err(anyhow::anyhow!("Failed assertion: {:?}", err)), + None => Ok(baml_value_with_meta.map_meta(|cs| { + cs.into_iter() + .filter_map(|res| ResponseCheck::from_constraint_result(res)) + .collect() + })), + } +} + #[derive(Clone, Copy, PartialEq)] pub enum ResolveMediaUrls { // there are 5 input formats: diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs index f16408694..04817bf92 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs @@ -7,8 +7,7 @@ use web_time::Duration; use crate::{ internal::{ llm_client::{ - traits::{WithPrompt, WithSingleCallable}, - LLMResponse, + parsed_value_to_response, traits::{WithPrompt, WithSingleCallable}, LLMResponse, ResponseBamlValue }, prompt_renderer::PromptRenderer, }, @@ -28,7 +27,7 @@ pub async fn orchestrate( Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, Duration, ) { @@ -50,7 +49,13 @@ pub async fn orchestrate( }; let sleep_duration = node.error_sleep_duration().cloned(); - results.push((node.scope, response, parsed_response)); + let response_with_constraints: Option> = + parsed_response.map( + |r| r.and_then( + |v| parsed_value_to_response(v) + ) + ); + results.push((node.scope, response, response_with_constraints)); // Currently, we break out of the loop if an LLM responded, even if we couldn't parse the result. 
if results diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs index c8069a961..81fa7542a 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs @@ -83,7 +83,7 @@ impl OrchestratorNode { } } -#[derive(Default, Clone, Serialize)] +#[derive(Debug, Default, Clone, Serialize)] pub struct OrchestrationScope { pub scope: Vec, } @@ -138,7 +138,7 @@ impl OrchestrationScope { } } -#[derive(Clone, Serialize)] +#[derive(Clone, Debug, Serialize)] pub enum ExecutionScope { Direct(String), // PolicyName, RetryCount, RetryDelayMs diff --git a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs index ecf0ac5fb..ccda0b29d 100644 --- a/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs +++ b/engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs @@ -8,8 +8,7 @@ use web_time::Duration; use crate::{ internal::{ llm_client::{ - traits::{WithPrompt, WithStreamable}, - LLMErrorResponse, LLMResponse, + parsed_value_to_response, traits::{WithPrompt, WithStreamable}, LLMErrorResponse, LLMResponse, ResponseBamlValue }, prompt_renderer::PromptRenderer, }, @@ -31,7 +30,7 @@ pub async fn orchestrate_stream( Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, Duration, ) @@ -60,10 +59,12 @@ where match &stream_part { LLMResponse::Success(s) => { let parsed = partial_parse_fn(&s.content); + let response_value: Result = + parsed.and_then(|v| parsed_value_to_response(v)); on_event(FunctionResult::new( node.scope.clone(), LLMResponse::Success(s.clone()), - Some(parsed), + Some(response_value), )); } _ => {} @@ -92,8 +93,9 @@ where LLMResponse::Success(s) => Some(parse_fn(&s.content)), _ => None, }; + let response_value: Option> = parsed_response.map(|r| r.and_then(|v| 
parsed_value_to_response(v))); let sleep_duration = node.error_sleep_duration().cloned(); - results.push((node.scope, final_response, parsed_response)); + results.push((node.scope, final_response, response_value)); // Currently, we break out of the loop if an LLM responded, even if we couldn't parse the result. if results diff --git a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs index 81b7b21e6..fe200495d 100644 --- a/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs +++ b/engine/baml-runtime/src/internal/prompt_renderer/render_output_format.rs @@ -1,7 +1,7 @@ use std::collections::HashSet; use anyhow::Result; -use baml_types::BamlValue; +use baml_types::{BamlValue, Constraint}; use indexmap::IndexSet; use internal_baml_core::ir::{ repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper, @@ -66,7 +66,7 @@ fn find_new_class_field<'a>( field_name: &str, class_walker: &Result>, overrides: &'a RuntimeClassOverride, - ctx: &RuntimeContext, + _ctx: &RuntimeContext, ) -> Result<(Name, FieldType, Option)> { let Some(field_overrides) = overrides.new_fields.get(field_name) else { anyhow::bail!("Class {} does not have a field: {}", class_name, field_name); @@ -194,6 +194,37 @@ fn find_enum_value( Ok(Some((name, desc))) } +/// Eliminate the `FieldType::Constrained` variant by searching for it, and stripping +/// it off of its base type, returning a tulpe of the base type and any constraints found +/// (if called on an argument that is not Constrained, the returned constraints Vec is +/// empty). +/// +/// If the function encounters directly nested Constrained types, +/// (i.e. `FieldType::Constrained { base: FieldType::Constrained { .. }, .. } `) +/// then the constraints of the two levels will be combined into a single vector. +/// So, we always return a base type that is not FieldType::Constrained. 
+fn distribute_constraints(field_type: &FieldType) -> (&FieldType, Vec) { + + match field_type { + // Check the first level to see if it's constrained. + FieldType::Constrained { base, constraints } => { + match base.as_ref() { + // If so, we must check the second level to see if we need to combine + // constraints across levels. + // The recursion here means that arbitrarily nested `FieldType::Constrained`s + // will be collapsed before the function returns. + FieldType::Constrained{..} => { + let (sub_base, sub_constraints) = distribute_constraints(base); + let combined_constraints = vec![constraints.clone(), sub_constraints].into_iter().flatten().collect(); + (sub_base, combined_constraints) + }, + _ => (base, constraints.clone()), + } + }, + _ => (field_type, Vec::new()), + } +} + fn relevant_data_models<'a>( ir: &'a IntermediateRepr, output: &'a FieldType, @@ -205,8 +236,8 @@ fn relevant_data_models<'a>( let mut start: Vec = vec![output.clone()]; while let Some(output) = start.pop() { - match &output { - FieldType::Enum(enm) => { + match distribute_constraints(&output) { + (FieldType::Enum(enm), constraints) => { if checked_types.insert(output.to_string()) { let overrides = ctx.enum_overrides.get(enm); let walker = ir.find_enum(enm); @@ -246,15 +277,16 @@ fn relevant_data_models<'a>( enums.push(Enum { name: Name::new_with_alias(enm.to_string(), alias.value()), values, + constraints, }); } } - FieldType::List(inner) | FieldType::Optional(inner) => { + (FieldType::List(inner), _) | (FieldType::Optional(inner), _) => { if !checked_types.contains(&inner.to_string()) { start.push(inner.as_ref().clone()); } } - FieldType::Map(k, v) => { + (FieldType::Map(k, v), _) => { if checked_types.insert(output.to_string()) { if !checked_types.contains(&k.to_string()) { start.push(k.as_ref().clone()); @@ -264,7 +296,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Tuple(options) | FieldType::Union(options) => { + (FieldType::Tuple(options), _) | (FieldType::Union(options), 
_) => { if checked_types.insert((&output).to_string()) { for inner in options { if !checked_types.contains(&inner.to_string()) { @@ -273,7 +305,7 @@ fn relevant_data_models<'a>( } } } - FieldType::Class(cls) => { + (FieldType::Class(cls), constraints) => { if checked_types.insert(output.to_string()) { let overrides = ctx.class_override.get(cls); let walker = ir.find_class(&cls); @@ -330,11 +362,15 @@ fn relevant_data_models<'a>( classes.push(Class { name: Name::new_with_alias(cls.to_string(), alias.value()), fields, + constraints, }); } } - FieldType::Primitive(_) => {} - FieldType::Literal(_) => {} + (FieldType::Literal(_), _) => {} + (FieldType::Primitive(_), _) => {} + (FieldType::Constrained{..}, _)=> { + unreachable!("It is guaranteed that a call to distribute_constraints will not return FieldType::Constrained") + }, } } @@ -343,6 +379,10 @@ fn relevant_data_models<'a>( #[cfg(test)] mod tests { + use std::collections::HashMap; + use baml_types::{ConstraintLevel, JinjaExpression, TypeValue}; + + use crate::BamlRuntime; use super::*; use crate::BamlRuntime; use std::collections::HashMap; @@ -372,4 +412,30 @@ mod tests { assert_eq!(foo_enum.values[0].0.real_name(), "Bar".to_string()); assert_eq!(foo_enum.values.len(), 1); } + + #[test] + fn test_nested_constraint_distribution() { + fn mk_constraint(s: &str) -> Constraint { + Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression(s.to_string()), label: Some(s.to_string()) } + } + + let input = FieldType::Constrained { + constraints: vec![mk_constraint("a")], + base: Box::new(FieldType::Constrained { + constraints: vec![mk_constraint("b")], + base: Box::new(FieldType::Constrained { + constraints: vec![mk_constraint("c")], + base: Box::new(FieldType::Primitive(TypeValue::Int)), + }) + }) + }; + + let expected_base = FieldType::Primitive(TypeValue::Int); + let expected_constraints = vec![mk_constraint("a"),mk_constraint("b"), mk_constraint("c")]; + + let (base, constraints) = 
distribute_constraints(&input); + + assert_eq!(base, &expected_base); + assert_eq!(constraints, expected_constraints); + } } diff --git a/engine/baml-runtime/src/types/expression_helper.rs b/engine/baml-runtime/src/types/expression_helper.rs index dd05b724b..16df949e5 100644 --- a/engine/baml-runtime/src/types/expression_helper.rs +++ b/engine/baml-runtime/src/types/expression_helper.rs @@ -51,6 +51,7 @@ pub fn to_value(ctx: &RuntimeContext, expr: &Expression) -> Result>>()?; json!(res) - } + }, + Expression::JinjaExpression(_) => anyhow::bail!("Unable to normalize jinja expression to a value without a context."), }) } diff --git a/engine/baml-runtime/src/types/response.rs b/engine/baml-runtime/src/types/response.rs index 9f6ac4017..9a4e03fc7 100644 --- a/engine/baml-runtime/src/types/response.rs +++ b/engine/baml-runtime/src/types/response.rs @@ -1,16 +1,16 @@ pub use crate::internal::llm_client::LLMResponse; -use crate::{errors::ExposedError, internal::llm_client::orchestrator::OrchestrationScope}; +use crate::{errors::ExposedError, internal::llm_client::{orchestrator::OrchestrationScope, ResponseBamlValue}}; use anyhow::Result; use colored::*; use baml_types::BamlValue; -use jsonish::BamlValueWithFlags; +#[derive(Debug)] pub struct FunctionResult { event_chain: Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, } @@ -27,7 +27,6 @@ impl std::fmt::Display for FunctionResult { writeln!(f, "{}", self.llm_response())?; match &self.parsed() { Some(Ok(val)) => { - let val: BamlValue = val.into(); writeln!( f, "{}", @@ -48,10 +47,10 @@ impl FunctionResult { pub fn new( scope: OrchestrationScope, response: LLMResponse, - parsed: Option>, + baml_value: Option>, ) -> Self { Self { - event_chain: vec![(scope, response, parsed)], + event_chain: vec![(scope, response, baml_value)], } } @@ -60,7 +59,7 @@ impl FunctionResult { ) -> &Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )> { &self.event_chain } @@ -69,7 +68,7 @@ impl FunctionResult { chain: 
Vec<( OrchestrationScope, LLMResponse, - Option>, + Option>, )>, ) -> Result { if chain.is_empty() { @@ -91,11 +90,11 @@ impl FunctionResult { &self.event_chain.last().unwrap().0 } - pub fn parsed(&self) -> &Option> { + pub fn parsed(&self) -> &Option> { &self.event_chain.last().unwrap().2 } - pub fn parsed_content(&self) -> Result<&BamlValueWithFlags> { + pub fn parsed_content(&self) -> Result<&ResponseBamlValue> { self.parsed() .as_ref() .map(|res| { diff --git a/engine/language_client_codegen/Cargo.toml b/engine/language_client_codegen/Cargo.toml index d78fd5853..7a90dc4b3 100644 --- a/engine/language_client_codegen/Cargo.toml +++ b/engine/language_client_codegen/Cargo.toml @@ -25,3 +25,4 @@ sugar_path = "1.2.0" walkdir.workspace = true semver = "1.0.23" colored = "2.1.0" +itertools = "0.13.0" diff --git a/engine/language_client_codegen/src/lib.rs b/engine/language_client_codegen/src/lib.rs index 4f6d152e4..e0fe7d8b5 100644 --- a/engine/language_client_codegen/src/lib.rs +++ b/engine/language_client_codegen/src/lib.rs @@ -1,10 +1,11 @@ use anyhow::{Context, Result}; +use baml_types::{Constraint, ConstraintLevel, FieldType}; use indexmap::IndexMap; use internal_baml_core::{ configuration::{GeneratorDefaultClientMode, GeneratorOutputType}, ir::repr::IntermediateRepr, }; -use std::{collections::BTreeMap, path::PathBuf}; +use std::{collections::{BTreeMap, HashSet}, path::PathBuf}; use version_check::{check_version, GeneratorType, VersionCheckMode}; mod dir_writer; @@ -219,3 +220,153 @@ impl GenerateClient for GeneratorOutputType { }) } } + +/// A set of names of @check attributes. This set determines the +/// way name of a Python Class or TypeScript Interface that holds +/// the results of running these checks. See TODO (Docs) for details on +/// the support types generated from checks. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct TypeCheckAttributes(pub HashSet); + +impl <'a> std::hash::Hash for TypeCheckAttributes { + fn hash(&self, state: &mut H) + where H: std::hash::Hasher + { + self.0.iter().for_each(|s| s.hash(state)) + } + +} + +impl TypeCheckAttributes { + /// Extend one set of attributes with the contents of another. + pub fn extend(&mut self, other: &TypeCheckAttributes) { + self.0.extend(other.0.clone()) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } +} + +/// Search the IR for all types with checks, combining the checks on each type +/// into a `TypeCheckAttributes` (a HashSet of the check names). Return a HashSet +/// of these HashSets. +/// +/// For example, consider this IR defining two classes: +/// +/// ``` baml +/// class Foo { +/// int @check("a") @check("b") +/// string @check("a") +/// } +/// +/// class Bar { +/// bool @check("a") +/// } +/// ```` +/// +/// It contains two distinct `TypeCheckAttributes`: +/// - ["a"] +/// - ["a", "b"] +/// +/// We will need to construct two district support types: +/// `Classes_a` and `Classes_a_b`. +pub fn type_check_attributes( + ir: &IntermediateRepr +) -> HashSet { + + + let mut all_types_in_ir: Vec<&FieldType> = Vec::new(); + for class in ir.walk_classes() { + for field in class.item.elem.static_fields.iter() { + let field_type = &field.elem.r#type.elem; + all_types_in_ir.push(field_type); + } + } + for function in ir.walk_functions() { + for (_param_name, parameter) in function.item.elem.inputs.iter() { + all_types_in_ir.push(parameter); + } + let return_type = &function.item.elem.output; + all_types_in_ir.push(return_type); + } + + all_types_in_ir.into_iter().filter_map(field_type_attributes).collect() + +} + +/// The set of Check names associated with a type. 
+fn field_type_attributes<'a>(field_type: &FieldType) -> Option { + match field_type { + FieldType::Constrained {base, constraints} => { + let direct_sub_attributes = field_type_attributes(base); + let mut check_names = + TypeCheckAttributes( + constraints + .iter() + .filter_map(|Constraint {label, level, ..}| + if matches!(level, ConstraintLevel::Check) { + Some(label.clone().expect("TODO")) + } else { None } + ).collect::>()); + if let Some(ref sub_attrs) = direct_sub_attributes { + check_names.extend(&sub_attrs); + } + if !check_names.is_empty() { + Some(check_names) + } else { + None + } + }, + _ => None + } +} + +#[cfg(test)] +mod tests { + use internal_baml_core::ir::repr::make_test_ir; + use super::*; + + /// Utility function for creating test fixtures. + fn mk_tc_attrs(names: &[&str]) -> TypeCheckAttributes { + TypeCheckAttributes(names.into_iter().map(|s| s.to_string()).collect()) + } + + #[test] + fn find_type_check_attributes() { + let ir = make_test_ir( + r##" +client GPT4 { + provider openai + options { + model gpt-4o + api_key env.OPENAI_API_KEY + } +} + +function Go(a: int @check({{ this < 0 }}, "c")) -> Foo { + client GPT4 + prompt #""# +} + +class Foo { + ab int @check({{this}}, "a") @check({{this}}, "b") + a int @check({{this}}, "a") +} + +class Bar { + cb int @check({{this}}, "c") @check({{this}}, "b") + nil int @description("no checks") @assert({{this}}, "a") @assert({{this}}, "d") +} + + "##).expect("Valid source"); + + let attrs = type_check_attributes(&ir); + assert_eq!(attrs.len(), 4); + assert!(attrs.contains( &mk_tc_attrs(&["c"]) )); + assert!(attrs.contains( &mk_tc_attrs(&["a","b"]) )); + assert!(attrs.contains( &mk_tc_attrs(&["a"]) )); + assert!(attrs.contains( &mk_tc_attrs(&["c", "b"]) )); + assert!(!attrs.contains( &mk_tc_attrs(&["a", "d"]) )); + } +} diff --git a/engine/language_client_codegen/src/openapi.rs b/engine/language_client_codegen/src/openapi.rs index 90d4f1bbf..7c4b3bf08 100644 --- 
a/engine/language_client_codegen/src/openapi.rs +++ b/engine/language_client_codegen/src/openapi.rs @@ -1,5 +1,4 @@ -use std::collections::HashMap; -use std::{path::PathBuf, process::Command}; +use std::path::PathBuf; use anyhow::{Context, Result}; use baml_types::{BamlMediaType, FieldType, LiteralValue, TypeValue}; @@ -8,7 +7,7 @@ use internal_baml_core::ir::{ repr::{Function, IntermediateRepr, Node, Walker}, ClassWalker, EnumWalker, }; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use serde_json::json; use crate::dir_writer::{FileCollector, LanguageFeatures, RemoveDirBehavior}; @@ -71,46 +70,6 @@ impl Serialize for OpenApiSchema<'_> { &self, serializer: S, ) -> core::result::Result { - let baml_image_schema = TypeSpecWithMeta { - meta: TypeMetadata { - title: Some("BamlImage".to_string()), - r#enum: None, - r#const: None, - nullable: false, - }, - type_spec: TypeSpec::Inline(TypeDef::Class { - properties: vec![ - ( - "base64".to_string(), - TypeSpecWithMeta { - meta: TypeMetadata { - title: None, - r#enum: None, - r#const: None, - nullable: false, - }, - type_spec: TypeSpec::Inline(TypeDef::String), - }, - ), - ( - "media_type".to_string(), - TypeSpecWithMeta { - meta: TypeMetadata { - title: None, - r#enum: None, - r#const: None, - nullable: true, - }, - type_spec: TypeSpec::Inline(TypeDef::String), - }, - ), - ] - .into_iter() - .collect(), - required: vec!["base64".to_string()], - additional_properties: false, - }), - }; let schemas = match self .schemas .iter() @@ -638,6 +597,7 @@ impl<'ir> ToTypeReferenceInTypeDefinition<'ir> for FieldType { // something i saw suggested doing this type_spec } + FieldType::Constrained{base,..} => base.to_type_spec(ir)?, }) } } diff --git a/engine/language_client_codegen/src/python/generate_types.rs b/engine/language_client_codegen/src/python/generate_types.rs index 829bbdbb6..7800b80d5 100644 --- a/engine/language_client_codegen/src/python/generate_types.rs +++ 
b/engine/language_client_codegen/src/python/generate_types.rs @@ -1,4 +1,8 @@ use anyhow::Result; +use itertools::join; +use std::borrow::Cow; + +use crate::{field_type_attributes, type_check_attributes, TypeCheckAttributes}; use super::python_language_features::ToPython; use internal_baml_core::ir::{ @@ -10,6 +14,7 @@ use internal_baml_core::ir::{ pub(crate) struct PythonTypes<'ir> { enums: Vec>, classes: Vec>, + checks_classes: Vec> } #[derive(askama::Template)] @@ -17,6 +22,7 @@ pub(crate) struct PythonTypes<'ir> { pub(crate) struct TypeBuilder<'ir> { enums: Vec>, classes: Vec>, + checks_classes: Vec>, } struct PythonEnum<'ir> { @@ -26,15 +32,16 @@ struct PythonEnum<'ir> { } struct PythonClass<'ir> { - name: &'ir str, + name: Cow<'ir, str>, // the name, and the type of the field - fields: Vec<(&'ir str, String)>, + fields: Vec<(Cow<'ir, str>, String)>, dynamic: bool, } #[derive(askama::Template)] #[template(path = "partial_types.py.j2", escape = "none")] pub(crate) struct PythonStreamTypes<'ir> { + check_type_names: String, partial_classes: Vec>, } @@ -52,9 +59,15 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonT fn try_from( (ir, _): (&'ir IntermediateRepr, &'_ crate::GeneratorArgs), ) -> Result> { + let checks_classes = + type_check_attributes(ir) + .into_iter() + .map(|checks| type_def_for_checks(checks)) + .collect::>(); Ok(PythonTypes { enums: ir.walk_enums().map(PythonEnum::from).collect::>(), classes: ir.walk_classes().map(PythonClass::from).collect::>(), + checks_classes, }) } } @@ -65,9 +78,15 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for TypeBui fn try_from( (ir, _): (&'ir IntermediateRepr, &'_ crate::GeneratorArgs), ) -> Result> { + let checks_classes = + type_check_attributes(ir) + .into_iter() + .map(|checks| type_def_for_checks(checks)) + .collect::>(); Ok(TypeBuilder { enums: ir.walk_enums().map(PythonEnum::from).collect::>(), classes: 
ir.walk_classes().map(PythonClass::from).collect::>(), + checks_classes, }) } } @@ -91,7 +110,7 @@ impl<'ir> From> for PythonEnum<'ir> { impl<'ir> From> for PythonClass<'ir> { fn from(c: ClassWalker<'ir>) -> Self { PythonClass { - name: c.name(), + name: Cow::Borrowed(c.name()), dynamic: c.item.attributes.get("dynamic_type").is_some(), fields: c .item @@ -100,7 +119,7 @@ impl<'ir> From> for PythonClass<'ir> { .iter() .map(|f| { ( - f.elem.name.as_str(), + Cow::Borrowed(f.elem.name.as_str()), add_default_value( &f.elem.r#type.elem, &f.elem.r#type.elem.to_type_ref(&c.db), @@ -116,7 +135,13 @@ impl<'ir> TryFrom<(&'ir IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonS type Error = anyhow::Error; fn try_from((ir, _): (&'ir IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result { + let check_type_names = + join(type_check_attributes(ir) + .into_iter() + .map(|checks| type_name_for_checks(&checks)), + ", "); Ok(Self { + check_type_names, partial_classes: ir .walk_classes() .map(PartialPythonClass::from) @@ -157,6 +182,25 @@ pub fn add_default_value(node: &FieldType, type_str: &String) -> String { } } +pub fn type_name_for_checks(checks: &TypeCheckAttributes) -> String { + let mut name = "Checks".to_string(); + let mut names: Vec<&String> = checks.0.iter().collect(); + names.sort(); + for check_name in names.iter() { + name.push_str("__"); + name.push_str(check_name); + } + name +} + +fn type_def_for_checks(checks: TypeCheckAttributes) -> PythonClass<'static> { + PythonClass { + name: Cow::Owned(type_name_for_checks(&checks)), + fields: checks.0.into_iter().map(|check_name| (Cow::Owned(check_name), "baml_py.Check".to_string())).collect(), + dynamic: false + } +} + trait ToTypeReferenceInTypeDefinition { fn to_type_ref(&self, ir: &IntermediateRepr) -> String; fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String; @@ -200,6 +244,18 @@ impl ToTypeReferenceInTypeDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => 
format!("Optional[{}]", inner.to_type_ref(ir)), + FieldType::Constrained{base, ..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_type_ref(ir); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]") + } + None => { + base.to_type_ref(ir) + } + } + }, } } @@ -250,6 +306,17 @@ impl ToTypeReferenceInTypeDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => inner.to_partial_type_ref(ir, false), + FieldType::Constrained{base,..} => { + let base_type_ref = base.to_partial_type_ref(ir, false); + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_partial_type_ref(ir, false); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]") + } + None => base_type_ref + } + }, } } } diff --git a/engine/language_client_codegen/src/python/mod.rs b/engine/language_client_codegen/src/python/mod.rs index a89e03866..688358cdc 100644 --- a/engine/language_client_codegen/src/python/mod.rs +++ b/engine/language_client_codegen/src/python/mod.rs @@ -4,6 +4,7 @@ mod python_language_features; use std::path::PathBuf; use anyhow::Result; +use generate_types::type_name_for_checks; use indexmap::IndexMap; use internal_baml_core::{ configuration::GeneratorDefaultClientMode, @@ -11,7 +12,7 @@ use internal_baml_core::{ }; use self::python_language_features::{PythonLanguageFeatures, ToPython}; -use crate::dir_writer::FileCollector; +use crate::{dir_writer::FileCollector, field_type_attributes}; #[derive(askama::Template)] #[template(path = "async_client.py.j2", escape = "none")] @@ -109,7 +110,7 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonInit { impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonGlobals { type Error = anyhow::Error; - fn try_from((_, args): (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result { + fn 
try_from((_, _args): (&'_ IntermediateRepr, &'_ crate::GeneratorArgs)) -> Result { Ok(PythonGlobals {}) } } @@ -157,12 +158,12 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonClient let (_function, _impl_) = c.item; Ok(PythonFunction { name: f.name().to_string(), - partial_return_type: f.elem().output().to_partial_type_ref(ir), - return_type: f.elem().output().to_type_ref(ir), + partial_return_type: f.elem().output().to_partial_type_ref(ir, true), + return_type: f.elem().output().to_type_ref(ir, true), args: f .inputs() .iter() - .map(|(name, r#type)| (name.to_string(), r#type.to_type_ref(ir))) + .map(|(name, r#type)| (name.to_string(), r#type.to_type_ref(ir, false))) .collect(), }) }) @@ -178,13 +179,13 @@ impl TryFrom<(&'_ IntermediateRepr, &'_ crate::GeneratorArgs)> for PythonClient } trait ToTypeReferenceInClientDefinition { - fn to_type_ref(&self, ir: &IntermediateRepr) -> String; + fn to_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> String; - fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String; + fn to_partial_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> String; } impl ToTypeReferenceInClientDefinition for FieldType { - fn to_type_ref(&self, ir: &IntermediateRepr) -> String { + fn to_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> String { match self { FieldType::Enum(name) => { if ir @@ -199,16 +200,16 @@ impl ToTypeReferenceInClientDefinition for FieldType { } FieldType::Literal(value) => format!("Literal[{}]", value), FieldType::Class(name) => format!("types.{name}"), - FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir)), + FieldType::List(inner) => format!("List[{}]", inner.to_type_ref(ir, with_checked)), FieldType::Map(key, value) => { - format!("Dict[{}, {}]", key.to_type_ref(ir), value.to_type_ref(ir)) + format!("Dict[{}, {}]", key.to_type_ref(ir, with_checked), value.to_type_ref(ir, with_checked)) } FieldType::Primitive(r#type) => r#type.to_python(), 
FieldType::Union(inner) => format!( "Union[{}]", inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, with_checked)) .collect::>() .join(", ") ), @@ -216,15 +217,27 @@ impl ToTypeReferenceInClientDefinition for FieldType { "Tuple[{}]", inner .iter() - .map(|t| t.to_type_ref(ir)) + .map(|t| t.to_type_ref(ir, with_checked)) .collect::>() .join(", ") ), - FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir)), + FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir, with_checked)), + FieldType::Constrained{base, ..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_type_ref(ir, with_checked); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]") + } + None => { + base.to_type_ref(ir, with_checked) + } + } + }, } } - fn to_partial_type_ref(&self, ir: &IntermediateRepr) -> String { + fn to_partial_type_ref(&self, ir: &IntermediateRepr, with_checked: bool) -> String { match self { FieldType::Enum(name) => { if ir @@ -239,12 +252,12 @@ impl ToTypeReferenceInClientDefinition for FieldType { } FieldType::Class(name) => format!("partial_types.{name}"), FieldType::Literal(value) => format!("Literal[{}]", value), - FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir)), + FieldType::List(inner) => format!("List[{}]", inner.to_partial_type_ref(ir, with_checked)), FieldType::Map(key, value) => { format!( "Dict[{}, {}]", - key.to_type_ref(ir), - value.to_partial_type_ref(ir) + key.to_type_ref(ir, with_checked), + value.to_partial_type_ref(ir, with_checked) ) } FieldType::Primitive(r#type) => format!("Optional[{}]", r#type.to_python()), @@ -252,7 +265,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { "Optional[Union[{}]]", inner .iter() - .map(|t| t.to_partial_type_ref(ir)) + .map(|t| t.to_partial_type_ref(ir, with_checked)) .collect::>() .join(", ") ), @@ -260,11 +273,23 @@ 
impl ToTypeReferenceInClientDefinition for FieldType { "Optional[Tuple[{}]]", inner .iter() - .map(|t| t.to_partial_type_ref(ir)) + .map(|t| t.to_partial_type_ref(ir, with_checked)) .collect::>() .join(", ") ), - FieldType::Optional(inner) => inner.to_partial_type_ref(ir), + FieldType::Optional(inner) => inner.to_partial_type_ref(ir, with_checked), + FieldType::Constrained{base, ..} => { + match field_type_attributes(self) { + Some(checks) => { + let base_type_ref = base.to_partial_type_ref(ir, with_checked); + let checks_type_ref = type_name_for_checks(&checks); + format!("baml_py.Checked[{base_type_ref},types.{checks_type_ref}]") + } + None => { + base.to_partial_type_ref(ir, with_checked) + } + } + }, } } } diff --git a/engine/language_client_codegen/src/python/templates/partial_types.py.j2 b/engine/language_client_codegen/src/python/templates/partial_types.py.j2 index 3638f4553..5ccd64725 100644 --- a/engine/language_client_codegen/src/python/templates/partial_types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/partial_types.py.j2 @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict from typing import Dict, List, Optional, Union, Literal from . 
import types +from .types import {{check_type_names}} ############################################################################### # diff --git a/engine/language_client_codegen/src/python/templates/types.py.j2 b/engine/language_client_codegen/src/python/templates/types.py.j2 index cc3973cf3..9d1797eff 100644 --- a/engine/language_client_codegen/src/python/templates/types.py.j2 +++ b/engine/language_client_codegen/src/python/templates/types.py.j2 @@ -13,6 +13,15 @@ class {{enum.name}}(str, Enum): {%- endfor %} {% endfor %} +{#- Checks Classes -#} +{% for cls in checks_classes %} +class {{cls.name}}(BaseModel): + + {%- for (name, type) in cls.fields %} + {{name}}: {{type}} + {%- endfor %} +{% endfor %} + {#- Classes -#} {% for cls in classes %} class {{cls.name}}(BaseModel): diff --git a/engine/language_client_codegen/src/ruby/expression.rs b/engine/language_client_codegen/src/ruby/expression.rs index c23976112..567e58381 100644 --- a/engine/language_client_codegen/src/ruby/expression.rs +++ b/engine/language_client_codegen/src/ruby/expression.rs @@ -36,6 +36,7 @@ impl ToRuby for Expression { Expression::RawString(val) => format!("`{}`", val.replace('`', "\\`")), Expression::Numeric(val) => val.clone(), Expression::Bool(val) => val.to_string(), + Expression::JinjaExpression(val) => val.to_string(), } } } diff --git a/engine/language_client_codegen/src/ruby/field_type.rs b/engine/language_client_codegen/src/ruby/field_type.rs index 43622c6be..a23ae64cd 100644 --- a/engine/language_client_codegen/src/ruby/field_type.rs +++ b/engine/language_client_codegen/src/ruby/field_type.rs @@ -47,6 +47,7 @@ impl ToRuby for FieldType { .join(", ") ), FieldType::Optional(inner) => format!("T.nilable({})", inner.to_ruby()), + FieldType::Constrained{base,..} => base.to_ruby(), } } } diff --git a/engine/language_client_codegen/src/ruby/generate_types.rs b/engine/language_client_codegen/src/ruby/generate_types.rs index 329a6a58e..47ae0b1f3 100644 --- 
a/engine/language_client_codegen/src/ruby/generate_types.rs +++ b/engine/language_client_codegen/src/ruby/generate_types.rs @@ -163,6 +163,7 @@ impl ToTypeReferenceInTypeDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => inner.to_partial_type_ref(), + FieldType::Constrained{base,..} => base.to_partial_type_ref(), } } } diff --git a/engine/language_client_codegen/src/typescript/mod.rs b/engine/language_client_codegen/src/typescript/mod.rs index 75e35a08b..c75b1e147 100644 --- a/engine/language_client_codegen/src/typescript/mod.rs +++ b/engine/language_client_codegen/src/typescript/mod.rs @@ -4,7 +4,6 @@ mod typescript_language_features; use std::path::PathBuf; use anyhow::Result; -use either::Either; use indexmap::IndexMap; use internal_baml_core::{ configuration::GeneratorDefaultClientMode, @@ -295,6 +294,7 @@ impl ToTypeReferenceInClientDefinition for FieldType { .join(", ") ), FieldType::Optional(inner) => format!("{} | null", inner.to_type_ref(ir)), + FieldType::Constrained{base,..} => base.to_type_ref(ir), } } } diff --git a/engine/language_client_python/python_src/baml_py/__init__.py b/engine/language_client_python/python_src/baml_py/__init__.py index f5b32b514..49735260c 100644 --- a/engine/language_client_python/python_src/baml_py/__init__.py +++ b/engine/language_client_python/python_src/baml_py/__init__.py @@ -18,6 +18,7 @@ ) from .stream import BamlStream, BamlSyncStream from .ctx_manager import CtxManager as BamlCtxManager +from .constraints import Check, Checked __all__ = [ "BamlRuntime", diff --git a/engine/language_client_python/python_src/baml_py/constraints.py b/engine/language_client_python/python_src/baml_py/constraints.py new file mode 100644 index 000000000..45ecf8aeb --- /dev/null +++ b/engine/language_client_python/python_src/baml_py/constraints.py @@ -0,0 +1,14 @@ +from typing import Generic, Optional, TypeVar +from pydantic import BaseModel + +T = TypeVar('T') +K = TypeVar('K') + +class Check(BaseModel): + name: 
Optional[str] + expression: str + status: str + +class Checked(BaseModel, Generic[T,K]): + value: T + checks: K \ No newline at end of file diff --git a/engine/language_client_python/src/types/function_results.rs b/engine/language_client_python/src/types/function_results.rs index 427538360..f890dd559 100644 --- a/engine/language_client_python/src/types/function_results.rs +++ b/engine/language_client_python/src/types/function_results.rs @@ -23,6 +23,6 @@ impl FunctionResult { .parsed_content() .map_err(BamlError::from_anyhow)?; - Ok(pythonize(py, &BamlValue::from(parsed))?) + Ok(pythonize(py, &parsed)?) } } diff --git a/engine/language_client_typescript/src/runtime.rs b/engine/language_client_typescript/src/runtime.rs index 570872a4f..76fdc561f 100644 --- a/engine/language_client_typescript/src/runtime.rs +++ b/engine/language_client_typescript/src/runtime.rs @@ -346,7 +346,7 @@ impl BamlRuntime { } #[napi] - pub fn flush(&mut self, env: Env) -> napi::Result<()> { + pub fn flush(&mut self, _env: Env) -> napi::Result<()> { self.inner.flush().map_err(|e| from_anyhow_error(e)) } diff --git a/integ-tests/baml_src/test-files/constraints/constraints.baml b/integ-tests/baml_src/test-files/constraints/constraints.baml new file mode 100644 index 000000000..302dabde1 --- /dev/null +++ b/integ-tests/baml_src/test-files/constraints/constraints.baml @@ -0,0 +1,68 @@ +// These classes and functions test several properties of +// constrains: +// +// - The ability for constrains on fields to pass or fail. +// - The ability for constraints on bare args and return types to pass or fail. +// - The ability of constraints to influence which variant of a union is chosen +// by the parser, when the structure is not sufficient to decide. 
+ + +class Martian { + age int @check({{ this < 30 }}, "young_enough") +} + +class Earthling { + age int @check({{this < 200 and this > 0}}, "earth_aged") @check({{this >1}}, "no_infants") +} + + +class FooAny { + planetary_age Martian | Earthling + certainty int @check({{this == 102931}}, "unreasonably_certain") + species string @check({{this == "Homo sapiens"}}, "trivial") @check({{this|regex_match("Homo")}}, "regex_good") @check({{this|regex_match("neanderthalensis")}}, "regex_bad") +} + + +class InputWithConstraint { + name string @assert({{this|length > 1}}, "nonempty") + amount int @check({{ this < 1 }}, "small") +} + +function PredictAge(name: string) -> FooAny { + client GPT35 + prompt #" + Using your understanding of the historical popularity + of names, predict the age of a person with the name + {{ name }} in years. Also predict their genus and + species. It's Homo sapiens (with exactly that spelling + and capitalization). I'll give you a hint: If the name + is "Greg", his age is 41. + + {{ctx.output_format}} + "# +} + + +function PredictAgeComplex(inp: InputWithConstraint) -> InputWithConstraint { + client GPT35 + prompt #" + Using your understanding of the historical popularity + of names, predict the age of a person with the name + {{ inp.name }} in years. Also predict their genus and + species. It's Homo sapiens (with exactly that spelling). + + {{ctx.output_format}} + "# +} + +function PredictAgeBare(inp: string @assert({{this|length > 1}}, "big_enough")) -> int @check({{this == 10102}}, "too_big") { + client GPT35 + prompt #" + Using your understanding of the historical popularity + of names, predict the age of a person with the name + {{ inp.name }} in years. Also predict their genus and + species. It's Homo sapiens (with exactly that spelling). 
+ + {{ctx.output_format}} + "# +} diff --git a/integ-tests/baml_src/test-files/constraints/contact-info.baml b/integ-tests/baml_src/test-files/constraints/contact-info.baml new file mode 100644 index 000000000..21e6939f0 --- /dev/null +++ b/integ-tests/baml_src/test-files/constraints/contact-info.baml @@ -0,0 +1,22 @@ +class PhoneNumber { + value string @check({{this|regex_match("\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")}}, "valid_phone_number") +} + +class EmailAddress { + value string @check({{this|regex_match("^[_]*([a-z0-9]+(\.|_*)?)+@([a-z][a-z0-9-]+(\.|-*\.))+[a-z]{2,6}$")}}, "valid_email") +} + +class ContactInfo { + primary PhoneNumber | EmailAddress +} + +function ExtractContactInfo(document: string) -> ContactInfo { + client GPT35 + prompt #" + Extract a primary contact info from this document: + + {{ document }} + + {{ ctx.output_format }} + "# +} diff --git a/integ-tests/openapi/baml_client/openapi.yaml b/integ-tests/openapi/baml_client/openapi.yaml index d9f81ccb0..23631bc1e 100644 --- a/integ-tests/openapi/baml_client/openapi.yaml +++ b/integ-tests/openapi/baml_client/openapi.yaml @@ -232,6 +232,19 @@ paths: title: ExpectFailureResponse type: string operationId: ExpectFailure + /call/ExtractContactInfo: + post: + requestBody: + $ref: '#/components/requestBodies/ExtractContactInfo' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: ExtractContactInfoResponse + $ref: '#/components/schemas/ContactInfo' + operationId: ExtractContactInfo /call/ExtractNames: post: requestBody: @@ -556,6 +569,45 @@ paths: items: $ref: '#/components/schemas/OptionalTest_ReturnType' operationId: OptionalTest_Function + /call/PredictAge: + post: + requestBody: + $ref: '#/components/requestBodies/PredictAge' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: PredictAgeResponse + $ref: '#/components/schemas/FooAny' + operationId: PredictAge + /call/PredictAgeBare: + post: 
+ requestBody: + $ref: '#/components/requestBodies/PredictAgeBare' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: PredictAgeBareResponse + type: integer + operationId: PredictAgeBare + /call/PredictAgeComplex: + post: + requestBody: + $ref: '#/components/requestBodies/PredictAgeComplex' + responses: + '200': + description: Successful operation + content: + application/json: + schema: + title: PredictAgeComplexResponse + $ref: '#/components/schemas/InputWithConstraint' + operationId: PredictAgeComplex /call/PromptTestClaude: post: requestBody: @@ -1380,6 +1432,22 @@ components: $ref: '#/components/schemas/BamlOptions' required: [] additionalProperties: false + ExtractContactInfo: + required: true + content: + application/json: + schema: + title: ExtractContactInfoRequest + type: object + properties: + document: + type: string + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - document + additionalProperties: false ExtractNames: required: true content: @@ -1770,6 +1838,54 @@ components: required: - input additionalProperties: false + PredictAge: + required: true + content: + application/json: + schema: + title: PredictAgeRequest + type: object + properties: + name: + type: string + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - name + additionalProperties: false + PredictAgeBare: + required: true + content: + application/json: + schema: + title: PredictAgeBareRequest + type: object + properties: + inp: + type: string + __baml_options__: + nullable: true + $ref: '#/components/schemas/BamlOptions' + required: + - inp + additionalProperties: false + PredictAgeComplex: + required: true + content: + application/json: + schema: + title: PredictAgeComplexRequest + type: object + properties: + inp: + $ref: '#/components/schemas/InputWithConstraint' + __baml_options__: + nullable: true + $ref: 
'#/components/schemas/BamlOptions' + required: + - inp + additionalProperties: false PromptTestClaude: required: true content: @@ -2719,6 +2835,16 @@ components: - big_nums - another additionalProperties: false + ContactInfo: + type: object + properties: + primary: + oneOf: + - $ref: '#/components/schemas/PhoneNumber' + - $ref: '#/components/schemas/EmailAddress' + required: + - primary + additionalProperties: false CustomTaskResult: type: object properties: @@ -2776,6 +2902,14 @@ components: properties: {} required: [] additionalProperties: false + Earthling: + type: object + properties: + age: + type: integer + required: + - age + additionalProperties: false Education: type: object properties: @@ -2811,6 +2945,14 @@ components: - body - from_address additionalProperties: false + EmailAddress: + type: object + properties: + value: + type: string + required: + - value + additionalProperties: false Event: type: object properties: @@ -2856,6 +2998,22 @@ components: - arrivalTime - seatNumber additionalProperties: false + FooAny: + type: object + properties: + planetary_age: + oneOf: + - $ref: '#/components/schemas/Martian' + - $ref: '#/components/schemas/Earthling' + certainty: + type: integer + species: + type: string + required: + - planetary_age + - certainty + - species + additionalProperties: false GroceryReceipt: type: object properties: @@ -2903,6 +3061,25 @@ components: - prop2 - prop3 additionalProperties: false + InputWithConstraint: + type: object + properties: + name: + type: string + amount: + type: integer + required: + - name + - amount + additionalProperties: false + Martian: + type: object + properties: + age: + type: integer + required: + - age + additionalProperties: false NamedArgsSingleClass: type: object properties: @@ -2988,6 +3165,14 @@ components: $ref: '#/components/schemas/Color' required: [] additionalProperties: false + PhoneNumber: + type: object + properties: + value: + type: string + required: + - value + additionalProperties: false 
Quantity: type: object properties: diff --git a/integ-tests/python/baml_client/async_client.py b/integ-tests/python/baml_client/async_client.py index f3a0d202a..61ba841d1 100644 --- a/integ-tests/python/baml_client/async_client.py +++ b/integ-tests/python/baml_client/async_client.py @@ -443,6 +443,30 @@ async def ExpectFailure( mdl = create_model("ExpectFailureReturnType", inner=(str, ...)) return coerce(mdl, raw.parsed()) + async def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> types.ContactInfo: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "ExtractContactInfo", + { + "document": document, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + return coerce(mdl, raw.parsed()) + async def ExtractNames( self, input: str, @@ -1019,6 +1043,78 @@ async def OptionalTest_Function( mdl = create_model("OptionalTest_FunctionReturnType", inner=(List[Optional[types.OptionalTest_ReturnType]], ...)) return coerce(mdl, raw.parsed()) + async def PredictAge( + self, + name: str, + baml_options: BamlCallOptions = {}, + ) -> types.FooAny: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "PredictAge", + { + "name": name, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + return coerce(mdl, raw.parsed()) + + async def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.Checked[int,types.Checks__too_big]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = 
baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "PredictAgeBare", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeBareReturnType", inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + return coerce(mdl, raw.parsed()) + + async def PredictAgeComplex( + self, + inp: types.InputWithConstraint, + baml_options: BamlCallOptions = {}, + ) -> types.InputWithConstraint: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = await self.__runtime.call_function( + "PredictAgeComplex", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeComplexReturnType", inner=(types.InputWithConstraint, ...)) + return coerce(mdl, raw.parsed()) + async def PromptTestClaude( self, input: str, @@ -2568,6 +2664,39 @@ def ExpectFailure( self.__ctx_manager.get(), ) + def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.ContactInfo, types.ContactInfo]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "ExtractContactInfo", + { + "document": document, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + partial_mdl = create_model("ExtractContactInfoPartialReturnType", inner=(partial_types.ContactInfo, ...)) + + return baml_py.BamlStream[partial_types.ContactInfo, types.ContactInfo]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def ExtractNames( self, input: str, @@ -3362,6 +3491,105 @@ def OptionalTest_Function( self.__ctx_manager.get(), ) + def PredictAge( + self, 
+ name: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.FooAny, types.FooAny]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "PredictAge", + { + "name": name, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + partial_mdl = create_model("PredictAgePartialReturnType", inner=(partial_types.FooAny, ...)) + + return baml_py.BamlStream[partial_types.FooAny, types.FooAny]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + + def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "PredictAgeBare", + { + "inp": inp, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeBareReturnType", inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + partial_mdl = create_model("PredictAgeBarePartialReturnType", inner=(baml_py.Checked[Optional[int],types.Checks__too_big], ...)) + + return baml_py.BamlStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + + def PredictAgeComplex( + self, + inp: types.InputWithConstraint, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlStream[partial_types.InputWithConstraint, types.InputWithConstraint]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = 
__tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function( + "PredictAgeComplex", + { + "inp": inp, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeComplexReturnType", inner=(types.InputWithConstraint, ...)) + partial_mdl = create_model("PredictAgeComplexPartialReturnType", inner=(partial_types.InputWithConstraint, ...)) + + return baml_py.BamlStream[partial_types.InputWithConstraint, types.InputWithConstraint]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def PromptTestClaude( self, input: str, diff --git a/integ-tests/python/baml_client/inlinedbaml.py b/integ-tests/python/baml_client/inlinedbaml.py index 0c06b842b..04b4a35a8 100644 --- a/integ-tests/python/baml_client/inlinedbaml.py +++ b/integ-tests/python/baml_client/inlinedbaml.py @@ -29,6 +29,8 @@ "test-files/aliases/classes.baml": "class TestClassAlias {\n key string @alias(\"key-dash\") @description(#\"\n This is a description for key\n af asdf\n \"#)\n key2 string @alias(\"key21\")\n key3 string @alias(\"key with space\")\n key4 string //unaliased\n key5 string @alias(\"key.with.punctuation/123\")\n}\n\nfunction FnTestClassAlias(input: string) -> TestClassAlias {\n client GPT35\n prompt #\"\n {{ctx.output_format}}\n \"#\n}\n\ntest FnTestClassAlias {\n functions [FnTestClassAlias]\n args {\n input \"example input\"\n }\n}\n", "test-files/aliases/enums.baml": "enum TestEnum {\n A @alias(\"k1\") @description(#\"\n User is angry\n \"#)\n B @alias(\"k22\") @description(#\"\n User is happy\n \"#)\n // tests whether k1 doesnt incorrectly get matched with k11\n C @alias(\"k11\") @description(#\"\n User is sad\n \"#)\n D @alias(\"k44\") @description(\n User is confused\n )\n E @description(\n User is excited\n )\n F @alias(\"k5\") // only alias\n \n G @alias(\"k6\") @description(#\"\n User is bored\n With a long description\n \"#)\n \n 
@@alias(\"Category\")\n}\n\nfunction FnTestAliasedEnumOutput(input: string) -> TestEnum {\n client GPT35\n prompt #\"\n Classify the user input into the following category\n \n {{ ctx.output_format }}\n\n {{ _.role('user') }}\n {{input}}\n\n {{ _.role('assistant') }}\n Category ID:\n \"#\n}\n\ntest FnTestAliasedEnumOutput {\n functions [FnTestAliasedEnumOutput]\n args {\n input \"mehhhhh\"\n }\n}", "test-files/comments/comments.baml": "// add some functions, classes, enums etc with comments all over.", + "test-files/constraints/constraints.baml": "// These classes and functions test several properties of\n// constrains:\n//\n// - The ability for constrains on fields to pass or fail.\n// - The ability for constraints on bare args and return types to pass or fail.\n// - The ability of constraints to influence which variant of a union is chosen\n// by the parser, when the structure is not sufficient to decide.\n\n\nclass Martian {\n age int @check({{ this < 30 }}, \"young_enough\")\n}\n\nclass Earthling {\n age int @check({{this < 200 and this > 0}}, \"earth_aged\") @check({{this >1}}, \"no_infants\")\n}\n\n\nclass FooAny {\n planetary_age Martian | Earthling\n certainty int @check({{this == 102931}}, \"unreasonably_certain\")\n species string @check({{this == \"Homo sapiens\"}}, \"trivial\") @check({{this|regex_match(\"Homo\")}}, \"regex_good\") @check({{this|regex_match(\"neanderthalensis\")}}, \"regex_bad\")\n}\n\n\nclass InputWithConstraint {\n name string @assert({{this|length > 1}}, \"nonempty\")\n amount int @check({{ this < 1 }}, \"small\")\n}\n\nfunction PredictAge(name: string) -> FooAny {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling\n and capitalization). 
I'll give you a hint: If the name\n is \"Greg\", his age is 41.\n\n {{ctx.output_format}}\n \"#\n}\n\n\nfunction PredictAgeComplex(inp: InputWithConstraint) -> InputWithConstraint {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n\nfunction PredictAgeBare(inp: string @assert({{this|length > 1}}, \"big_enough\")) -> int @check({{this == 10102}}, \"too_big\") {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n", + "test-files/constraints/contact-info.baml": "class PhoneNumber {\n value string @check({{this|regex_match(\"\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}\")}}, \"valid_phone_number\")\n}\n\nclass EmailAddress {\n value string @check({{this|regex_match(\"^[_]*([a-z0-9]+(\\.|_*)?)+@([a-z][a-z0-9-]+(\\.|-*\\.))+[a-z]{2,6}$\")}}, \"valid_email\")\n}\n\nclass ContactInfo {\n primary PhoneNumber | EmailAddress\n}\n\nfunction ExtractContactInfo(document: string) -> ContactInfo {\n client GPT35\n prompt #\"\n Extract a primary contact info from this document:\n\n {{ document }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/descriptions/descriptions.baml": "\nclass Nested {\n prop3 string | null @description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 
Nested | string @description(#\"\n write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}", "test-files/dynamic/client-registry.baml": "// Intentionally use a bad key\nclient BadClient {\n provider openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n", "test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ 
ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}", diff --git a/integ-tests/python/baml_client/partial_types.py b/integ-tests/python/baml_client/partial_types.py index 26b0b2f19..213ecb9bc 100644 --- a/integ-tests/python/baml_client/partial_types.py +++ b/integ-tests/python/baml_client/partial_types.py @@ -19,6 +19,7 @@ from typing import Dict, List, Optional, Union, Literal from . import types +from .types import Checks__unreasonably_certain, Checks__small, Checks__earth_aged__no_infants, Checks__valid_email, Checks__regex_bad__regex_good__trivial, Checks__too_big, Checks__valid_phone_number, Checks__young_enough ############################################################################### # @@ -74,6 +75,11 @@ class CompoundBigNumbers(BaseModel): big_nums: List["BigNumbers"] another: Optional["BigNumbers"] = None +class ContactInfo(BaseModel): + + + primary: Optional[Union["PhoneNumber", "EmailAddress"]] = None + class CustomTaskResult(BaseModel): @@ -112,6 +118,11 @@ class DynamicOutput(BaseModel): model_config = ConfigDict(extra='allow') +class Earthling(BaseModel): + + + age: baml_py.Checked[Optional[int],Checks__earth_aged__no_infants] + class Education(BaseModel): @@ -128,6 +139,11 @@ class Email(BaseModel): body: Optional[str] = None from_address: Optional[str] = None +class EmailAddress(BaseModel): + + + value: baml_py.Checked[Optional[str],Checks__valid_email] + class Event(BaseModel): @@ -150,6 +166,13 @@ class FlightConfirmation(BaseModel): arrivalTime: Optional[str] = None seatNumber: Optional[str] = None +class FooAny(BaseModel): + + + planetary_age: 
Optional[Union["Martian", "Earthling"]] = None + certainty: baml_py.Checked[Optional[int],Checks__unreasonably_certain] + species: baml_py.Checked[Optional[str],Checks__regex_bad__regex_good__trivial] + class GroceryReceipt(BaseModel): @@ -171,6 +194,17 @@ class InnerClass2(BaseModel): prop2: Optional[int] = None prop3: Optional[float] = None +class InputWithConstraint(BaseModel): + + + name: Optional[str] = None + amount: baml_py.Checked[Optional[int],Checks__small] + +class Martian(BaseModel): + + + age: baml_py.Checked[Optional[int],Checks__young_enough] + class NamedArgsSingleClass(BaseModel): @@ -218,6 +252,11 @@ class Person(BaseModel): name: Optional[str] = None hair_color: Optional[Union[types.Color, str]] = None +class PhoneNumber(BaseModel): + + + value: baml_py.Checked[Optional[str],Checks__valid_phone_number] + class Quantity(BaseModel): diff --git a/integ-tests/python/baml_client/sync_client.py b/integ-tests/python/baml_client/sync_client.py index 2d127ac97..5c99b6b44 100644 --- a/integ-tests/python/baml_client/sync_client.py +++ b/integ-tests/python/baml_client/sync_client.py @@ -441,6 +441,30 @@ def ExpectFailure( mdl = create_model("ExpectFailureReturnType", inner=(str, ...)) return coerce(mdl, raw.parsed()) + def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> types.ContactInfo: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "ExtractContactInfo", + { + "document": document, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + return coerce(mdl, raw.parsed()) + def ExtractNames( self, input: str, @@ -1017,6 +1041,78 @@ def OptionalTest_Function( mdl = create_model("OptionalTest_FunctionReturnType", inner=(List[Optional[types.OptionalTest_ReturnType]], ...)) return 
coerce(mdl, raw.parsed()) + def PredictAge( + self, + name: str, + baml_options: BamlCallOptions = {}, + ) -> types.FooAny: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "PredictAge", + { + "name": name, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + return coerce(mdl, raw.parsed()) + + def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.Checked[int,types.Checks__too_big]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "PredictAgeBare", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeBareReturnType", inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + return coerce(mdl, raw.parsed()) + + def PredictAgeComplex( + self, + inp: types.InputWithConstraint, + baml_options: BamlCallOptions = {}, + ) -> types.InputWithConstraint: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.call_function_sync( + "PredictAgeComplex", + { + "inp": inp, + }, + self.__ctx_manager.get(), + tb, + __cr__, + ) + mdl = create_model("PredictAgeComplexReturnType", inner=(types.InputWithConstraint, ...)) + return coerce(mdl, raw.parsed()) + def PromptTestClaude( self, input: str, @@ -2567,6 +2663,39 @@ def ExpectFailure( self.__ctx_manager.get(), ) + def ExtractContactInfo( + self, + document: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.ContactInfo, types.ContactInfo]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = 
__tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "ExtractContactInfo", + { + "document": document, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("ExtractContactInfoReturnType", inner=(types.ContactInfo, ...)) + partial_mdl = create_model("ExtractContactInfoPartialReturnType", inner=(partial_types.ContactInfo, ...)) + + return baml_py.BamlSyncStream[partial_types.ContactInfo, types.ContactInfo]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def ExtractNames( self, input: str, @@ -3361,6 +3490,105 @@ def OptionalTest_Function( self.__ctx_manager.get(), ) + def PredictAge( + self, + name: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.FooAny, types.FooAny]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "PredictAge", + { + "name": name, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeReturnType", inner=(types.FooAny, ...)) + partial_mdl = create_model("PredictAgePartialReturnType", inner=(partial_types.FooAny, ...)) + + return baml_py.BamlSyncStream[partial_types.FooAny, types.FooAny]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + + def PredictAgeBare( + self, + inp: str, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "PredictAgeBare", + { + "inp": inp, + }, + None, + 
self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeBareReturnType", inner=(baml_py.Checked[int,types.Checks__too_big], ...)) + partial_mdl = create_model("PredictAgeBarePartialReturnType", inner=(baml_py.Checked[Optional[int],types.Checks__too_big], ...)) + + return baml_py.BamlSyncStream[baml_py.Checked[Optional[int],types.Checks__too_big], baml_py.Checked[int,types.Checks__too_big]]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + + def PredictAgeComplex( + self, + inp: types.InputWithConstraint, + baml_options: BamlCallOptions = {}, + ) -> baml_py.BamlSyncStream[partial_types.InputWithConstraint, types.InputWithConstraint]: + __tb__ = baml_options.get("tb", None) + if __tb__ is not None: + tb = __tb__._tb + else: + tb = None + __cr__ = baml_options.get("client_registry", None) + + raw = self.__runtime.stream_function_sync( + "PredictAgeComplex", + { + "inp": inp, + }, + None, + self.__ctx_manager.get(), + tb, + __cr__, + ) + + mdl = create_model("PredictAgeComplexReturnType", inner=(types.InputWithConstraint, ...)) + partial_mdl = create_model("PredictAgeComplexPartialReturnType", inner=(partial_types.InputWithConstraint, ...)) + + return baml_py.BamlSyncStream[partial_types.InputWithConstraint, types.InputWithConstraint]( + raw, + lambda x: coerce(partial_mdl, x), + lambda x: coerce(mdl, x), + self.__ctx_manager.get(), + ) + def PromptTestClaude( self, input: str, diff --git a/integ-tests/python/baml_client/type_builder.py b/integ-tests/python/baml_client/type_builder.py index e54c2b50a..ce7f12499 100644 --- a/integ-tests/python/baml_client/type_builder.py +++ b/integ-tests/python/baml_client/type_builder.py @@ -19,7 +19,7 @@ class TypeBuilder(_TypeBuilder): def __init__(self): super().__init__(classes=set( - 
["BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Education","Email","Event","FakeImage","FlightConfirmation","GroceryReceipt","InnerClass","InnerClass2","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","UnionTest_ReturnType","WithReasoning",] + ["BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","GroceryReceipt","InnerClass","InnerClass2","InputWithConstraint","Martian","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","UnionTest_ReturnType","WithReasoning",] ), enums=set( ["Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum",] )) diff --git a/integ-tests/python/baml_client/types.py b/integ-tests/python/baml_client/types.py index e470b5127..98b7ec260 100644 --- a/integ-tests/python/baml_client/types.py +++ b/integ-tests/python/baml_client/types.py @@ -119,6 +119,33 @@ class TestEnum(str, Enum): F = "F" G = "G" 
+class Checks__valid_email(BaseModel): + valid_email: baml_py.Check + +class Checks__valid_phone_number(BaseModel): + valid_phone_number: baml_py.Check + +class Checks__small(BaseModel): + small: baml_py.Check + +class Checks__unreasonably_certain(BaseModel): + unreasonably_certain: baml_py.Check + +class Checks__earth_aged__no_infants(BaseModel): + earth_aged: baml_py.Check + no_infants: baml_py.Check + +class Checks__regex_bad__regex_good__trivial(BaseModel): + regex_bad: baml_py.Check + trivial: baml_py.Check + regex_good: baml_py.Check + +class Checks__too_big(BaseModel): + too_big: baml_py.Check + +class Checks__young_enough(BaseModel): + young_enough: baml_py.Check + class BigNumbers(BaseModel): @@ -165,6 +192,11 @@ class CompoundBigNumbers(BaseModel): big_nums: List["BigNumbers"] another: "BigNumbers" +class ContactInfo(BaseModel): + + + primary: Union["PhoneNumber", "EmailAddress"] + class CustomTaskResult(BaseModel): @@ -203,6 +235,11 @@ class DynamicOutput(BaseModel): model_config = ConfigDict(extra='allow') +class Earthling(BaseModel): + + + age: baml_py.Checked[int,Checks__earth_aged__no_infants] + class Education(BaseModel): @@ -219,6 +256,11 @@ class Email(BaseModel): body: str from_address: str +class EmailAddress(BaseModel): + + + value: baml_py.Checked[str,Checks__valid_email] + class Event(BaseModel): @@ -241,6 +283,13 @@ class FlightConfirmation(BaseModel): arrivalTime: str seatNumber: str +class FooAny(BaseModel): + + + planetary_age: Union["Martian", "Earthling"] + certainty: baml_py.Checked[int,Checks__unreasonably_certain] + species: baml_py.Checked[str,Checks__regex_bad__regex_good__trivial] + class GroceryReceipt(BaseModel): @@ -262,6 +311,17 @@ class InnerClass2(BaseModel): prop2: int prop3: float +class InputWithConstraint(BaseModel): + + + name: str + amount: baml_py.Checked[int,Checks__small] + +class Martian(BaseModel): + + + age: baml_py.Checked[int,Checks__young_enough] + class NamedArgsSingleClass(BaseModel): @@ -309,6 +369,11 @@ 
class Person(BaseModel): name: Optional[str] = None hair_color: Optional[Union["Color", str]] = None +class PhoneNumber(BaseModel): + + + value: baml_py.Checked[str,Checks__valid_phone_number] + class Quantity(BaseModel): diff --git a/integ-tests/python/tests/test_functions.py b/integ-tests/python/tests/test_functions.py index d11607f5d..c38417c5b 100644 --- a/integ-tests/python/tests/test_functions.py +++ b/integ-tests/python/tests/test_functions.py @@ -21,11 +21,13 @@ from ..baml_client import partial_types from ..baml_client.types import ( DynInputOutput, + FooAny, NamedArgsSingleEnumList, NamedArgsSingleClass, StringToClassEntry, CompoundBigNumbers, ) +import baml_client.types as types from ..baml_client.tracing import trace, set_tags, flush, on_log_event from ..baml_client.type_builder import TypeBuilder from ..baml_client import reset_baml_env_vars @@ -59,6 +61,21 @@ async def test_single_string_list(self): res = await b.TestFnNamedArgsSingleStringList(["a", "b", "c"]) assert "a" in res and "b" in res and "c" in res + @pytest.mark.asyncio + async def test_constraints(self): + res = await b.PredictAge("Greg") + assert res.certainty.checks.unreasonably_certain.status == "failed" + + @pytest.mark.asyncio + async def test_union_variant_checking(self): + res = await b.ExtractContactInfo("Reach me at 123-456-7890") + assert res.primary.value is not None + assert res.primary.value.checks.valid_phone_number.status == "succeeded" + + res = await b.ExtractContactInfo("Reach me at help@boundaryml.com") + assert res.primary.value is not None + assert res.primary.value.checks.valid_email.status == "succeeded" + @pytest.mark.asyncio async def test_single_class(self): res = await b.TestFnNamedArgsSingleClass( diff --git a/integ-tests/ruby/baml_client/client.rb b/integ-tests/ruby/baml_client/client.rb index 63cee85fe..0fb9e2594 100644 --- a/integ-tests/ruby/baml_client/client.rb +++ b/integ-tests/ruby/baml_client/client.rb @@ -562,6 +562,38 @@ def ExpectFailure( 
(raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + document: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::ContactInfo) + } + def ExtractContactInfo( + *varargs, + document:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ExtractContactInfo may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "ExtractContactInfo", + { + document: document, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -1330,6 +1362,102 @@ def OptionalTest_Function( (raw.parsed_using_types(Baml::Types)) end + sig { + params( + varargs: T.untyped, + name: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::FooAny) + } + def PredictAge( + *varargs, + name:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAge may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "PredictAge", + { + name: name, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + inp: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Integer) + } + def PredictAgeBare( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAgeBare may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "PredictAgeBare", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + + sig { + params( + varargs: T.untyped, + inp: Baml::Types::InputWithConstraint, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::Types::InputWithConstraint) + } + def PredictAgeComplex( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAgeComplex may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.call_function( + "PredictAgeComplex", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + (raw.parsed_using_types(Baml::Types)) + end + sig { params( varargs: T.untyped, @@ -3247,6 +3375,41 @@ def ExpectFailure( ) end + sig { + params( + varargs: T.untyped, + document: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::ContactInfo]) + } + def ExtractContactInfo( + *varargs, + document:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("ExtractContactInfo may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "ExtractContactInfo", + { + document: document, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::ContactInfo, Baml::Types::ContactInfo].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, @@ -4087,6 +4250,111 @@ def OptionalTest_Function( ) end + sig { + params( + varargs: T.untyped, + name: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::FooAny]) + } + def PredictAge( + *varargs, + name:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAge may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "PredictAge", + { + name: name, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::FooAny, Baml::Types::FooAny].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + inp: String, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Integer]) + } + def PredictAgeBare( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAgeBare may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? + raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "PredictAgeBare", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[T.nilable(Integer), Integer].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + + sig { + params( + varargs: T.untyped, + inp: Baml::Types::InputWithConstraint, + baml_options: T::Hash[Symbol, T.any(Baml::TypeBuilder, Baml::ClientRegistry)] + ).returns(Baml::BamlStream[Baml::Types::InputWithConstraint]) + } + def PredictAgeComplex( + *varargs, + inp:, + baml_options: {} + ) + if varargs.any? + + raise ArgumentError.new("PredictAgeComplex may only be called with keyword arguments") + end + if (baml_options.keys - [:client_registry, :tb]).any? 
+ raise ArgumentError.new("Received unknown keys in baml_options (valid keys: :client_registry, :tb): #{baml_options.keys - [:client_registry, :tb]}") + end + + raw = @runtime.stream_function( + "PredictAgeComplex", + { + inp: inp, + }, + @ctx_manager, + baml_options[:tb]&.instance_variable_get(:@registry), + baml_options[:client_registry], + ) + Baml::BamlStream[Baml::PartialTypes::InputWithConstraint, Baml::Types::InputWithConstraint].new( + ffi_stream: raw, + ctx_manager: @ctx_manager + ) + end + sig { params( varargs: T.untyped, diff --git a/integ-tests/ruby/baml_client/inlined.rb b/integ-tests/ruby/baml_client/inlined.rb index 5db069735..16b1f6cca 100644 --- a/integ-tests/ruby/baml_client/inlined.rb +++ b/integ-tests/ruby/baml_client/inlined.rb @@ -29,6 +29,8 @@ module Inlined "test-files/aliases/classes.baml" => "class TestClassAlias {\n key string @alias(\"key-dash\") @description(#\"\n This is a description for key\n af asdf\n \"#)\n key2 string @alias(\"key21\")\n key3 string @alias(\"key with space\")\n key4 string //unaliased\n key5 string @alias(\"key.with.punctuation/123\")\n}\n\nfunction FnTestClassAlias(input: string) -> TestClassAlias {\n client GPT35\n prompt #\"\n {{ctx.output_format}}\n \"#\n}\n\ntest FnTestClassAlias {\n functions [FnTestClassAlias]\n args {\n input \"example input\"\n }\n}\n", "test-files/aliases/enums.baml" => "enum TestEnum {\n A @alias(\"k1\") @description(#\"\n User is angry\n \"#)\n B @alias(\"k22\") @description(#\"\n User is happy\n \"#)\n // tests whether k1 doesnt incorrectly get matched with k11\n C @alias(\"k11\") @description(#\"\n User is sad\n \"#)\n D @alias(\"k44\") @description(\n User is confused\n )\n E @description(\n User is excited\n )\n F @alias(\"k5\") // only alias\n \n G @alias(\"k6\") @description(#\"\n User is bored\n With a long description\n \"#)\n \n @@alias(\"Category\")\n}\n\nfunction FnTestAliasedEnumOutput(input: string) -> TestEnum {\n client GPT35\n prompt #\"\n Classify the user input into 
the following category\n \n {{ ctx.output_format }}\n\n {{ _.role('user') }}\n {{input}}\n\n {{ _.role('assistant') }}\n Category ID:\n \"#\n}\n\ntest FnTestAliasedEnumOutput {\n functions [FnTestAliasedEnumOutput]\n args {\n input \"mehhhhh\"\n }\n}", "test-files/comments/comments.baml" => "// add some functions, classes, enums etc with comments all over.", + "test-files/constraints/constraints.baml" => "// These classes and functions test several properties of\n// constrains:\n//\n// - The ability for constrains on fields to pass or fail.\n// - The ability for constraints on bare args and return types to pass or fail.\n// - The ability of constraints to influence which variant of a union is chosen\n// by the parser, when the structure is not sufficient to decide.\n\n\nclass Martian {\n age int @check({{ this < 30 }}, \"young_enough\")\n}\n\nclass Earthling {\n age int @check({{this < 200 and this > 0}}, \"earth_aged\") @check({{this >1}}, \"no_infants\")\n}\n\n\nclass FooAny {\n planetary_age Martian | Earthling\n certainty int @check({{this == 102931}}, \"unreasonably_certain\")\n species string @check({{this == \"Homo sapiens\"}}, \"trivial\") @check({{this|regex_match(\"Homo\")}}, \"regex_good\") @check({{this|regex_match(\"neanderthalensis\")}}, \"regex_bad\")\n}\n\n\nclass InputWithConstraint {\n name string @assert({{this|length > 1}}, \"nonempty\")\n amount int @check({{ this < 1 }}, \"small\")\n}\n\nfunction PredictAge(name: string) -> FooAny {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling\n and capitalization). 
I'll give you a hint: If the name\n is \"Greg\", his age is 41.\n\n {{ctx.output_format}}\n \"#\n}\n\n\nfunction PredictAgeComplex(inp: InputWithConstraint) -> InputWithConstraint {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n\nfunction PredictAgeBare(inp: string @assert({{this|length > 1}}, \"big_enough\")) -> int @check({{this == 10102}}, \"too_big\") {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n", + "test-files/constraints/contact-info.baml" => "class PhoneNumber {\n value string @check({{this|regex_match(\"\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}\")}}, \"valid_phone_number\")\n}\n\nclass EmailAddress {\n value string @check({{this|regex_match(\"^[_]*([a-z0-9]+(\\.|_*)?)+@([a-z][a-z0-9-]+(\\.|-*\\.))+[a-z]{2,6}$\")}}, \"valid_email\")\n}\n\nclass ContactInfo {\n primary PhoneNumber | EmailAddress\n}\n\nfunction ExtractContactInfo(document: string) -> ContactInfo {\n client GPT35\n prompt #\"\n Extract a primary contact info from this document:\n\n {{ document }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/descriptions/descriptions.baml" => "\nclass Nested {\n prop3 string | null @description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 
Nested | string @description(#\"\n write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}", "test-files/dynamic/client-registry.baml" => "// Intentionally use a bad key\nclient BadClient {\n provider openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n", "test-files/dynamic/dynamic.baml" => "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ 
ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}", diff --git a/integ-tests/ruby/baml_client/partial-types.rb b/integ-tests/ruby/baml_client/partial-types.rb index 4bb8df926..5c0c55216 100644 --- a/integ-tests/ruby/baml_client/partial-types.rb +++ b/integ-tests/ruby/baml_client/partial-types.rb @@ -27,20 +27,26 @@ class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end class ClassWithImage < T::Struct; end class CompoundBigNumbers < T::Struct; end + class ContactInfo < T::Struct; end class CustomTaskResult < T::Struct; end class DummyOutput < T::Struct; end class DynInputOutput < T::Struct; end class DynamicClassOne < T::Struct; end class DynamicClassTwo < T::Struct; end class DynamicOutput < T::Struct; end + class Earthling < T::Struct; end class Education < T::Struct; end class Email < T::Struct; end + class EmailAddress < T::Struct; end class Event < T::Struct; end class FakeImage < T::Struct; end class FlightConfirmation < T::Struct; end + class FooAny < T::Struct; end class GroceryReceipt < T::Struct; end class InnerClass < T::Struct; end class InnerClass2 < T::Struct; end + class InputWithConstraint < T::Struct; end + class Martian < T::Struct; end class NamedArgsSingleClass < T::Struct; end class Nested < T::Struct; end class Nested2 < T::Struct; end @@ -48,6 +54,7 @@ class OptionalTest_Prop1 < T::Struct; end class OptionalTest_ReturnType < T::Struct; end class OrderInfo < T::Struct; end class Person < T::Struct; end + class PhoneNumber < T::Struct; end class Quantity < T::Struct; end class RaysData < T::Struct; end class ReceiptInfo < 
T::Struct; end @@ -170,6 +177,18 @@ def initialize(props) @props = props end end + class ContactInfo < T::Struct + include Baml::Sorbet::Struct + const :primary, T.nilable(T.any(Baml::PartialTypes::PhoneNumber, Baml::PartialTypes::EmailAddress)) + + def initialize(props) + super( + primary: props[:primary], + ) + + @props = props + end + end class CustomTaskResult < T::Struct include Baml::Sorbet::Struct const :bookOrder, T.nilable(T.any(Baml::PartialTypes::BookOrder, T.nilable(NilClass))) @@ -248,6 +267,18 @@ def initialize(props) @props = props end end + class Earthling < T::Struct + include Baml::Sorbet::Struct + const :age, T.nilable(Integer) + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class Education < T::Struct include Baml::Sorbet::Struct const :institution, T.nilable(String) @@ -284,6 +315,18 @@ def initialize(props) @props = props end end + class EmailAddress < T::Struct + include Baml::Sorbet::Struct + const :value, T.nilable(String) + + def initialize(props) + super( + value: props[:value], + ) + + @props = props + end + end class Event < T::Struct include Baml::Sorbet::Struct const :title, T.nilable(String) @@ -334,6 +377,22 @@ def initialize(props) @props = props end end + class FooAny < T::Struct + include Baml::Sorbet::Struct + const :planetary_age, T.nilable(T.any(Baml::PartialTypes::Martian, Baml::PartialTypes::Earthling)) + const :certainty, T.nilable(Integer) + const :species, T.nilable(String) + + def initialize(props) + super( + planetary_age: props[:planetary_age], + certainty: props[:certainty], + species: props[:species], + ) + + @props = props + end + end class GroceryReceipt < T::Struct include Baml::Sorbet::Struct const :receiptId, T.nilable(String) @@ -382,6 +441,32 @@ def initialize(props) @props = props end end + class InputWithConstraint < T::Struct + include Baml::Sorbet::Struct + const :name, T.nilable(String) + const :amount, T.nilable(Integer) + + def initialize(props) + super( + name: 
props[:name], + amount: props[:amount], + ) + + @props = props + end + end + class Martian < T::Struct + include Baml::Sorbet::Struct + const :age, T.nilable(Integer) + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class NamedArgsSingleClass < T::Struct include Baml::Sorbet::Struct const :key, T.nilable(String) @@ -488,6 +573,18 @@ def initialize(props) @props = props end end + class PhoneNumber < T::Struct + include Baml::Sorbet::Struct + const :value, T.nilable(String) + + def initialize(props) + super( + value: props[:value], + ) + + @props = props + end + end class Quantity < T::Struct include Baml::Sorbet::Struct const :amount, T.nilable(T.any(T.nilable(Integer), T.nilable(Float))) diff --git a/integ-tests/ruby/baml_client/type-registry.rb b/integ-tests/ruby/baml_client/type-registry.rb index 24d9a20ce..051da8f6b 100644 --- a/integ-tests/ruby/baml_client/type-registry.rb +++ b/integ-tests/ruby/baml_client/type-registry.rb @@ -18,7 +18,7 @@ module Baml class TypeBuilder def initialize @registry = Baml::Ffi::TypeBuilder.new - @classes = Set[ "BigNumbers", "Blah", "BookOrder", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassWithImage", "CompoundBigNumbers", "CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Education", "Email", "Event", "FakeImage", "FlightConfirmation", "GroceryReceipt", "InnerClass", "InnerClass2", "NamedArgsSingleClass", "Nested", "Nested2", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "Person", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", "TestOutputClass", "UnionTest_ReturnType", "WithReasoning", ] + @classes = Set[ "BigNumbers", "Blah", "BookOrder", "ClassOptionalOutput", "ClassOptionalOutput2", "ClassWithImage", "CompoundBigNumbers", "ContactInfo", 
"CustomTaskResult", "DummyOutput", "DynInputOutput", "DynamicClassOne", "DynamicClassTwo", "DynamicOutput", "Earthling", "Education", "Email", "EmailAddress", "Event", "FakeImage", "FlightConfirmation", "FooAny", "GroceryReceipt", "InnerClass", "InnerClass2", "InputWithConstraint", "Martian", "NamedArgsSingleClass", "Nested", "Nested2", "OptionalTest_Prop1", "OptionalTest_ReturnType", "OrderInfo", "Person", "PhoneNumber", "Quantity", "RaysData", "ReceiptInfo", "ReceiptItem", "Recipe", "Resume", "Schema", "SearchParams", "SomeClassNestedDynamic", "StringToClassEntry", "TestClassAlias", "TestClassNested", "TestClassWithEnum", "TestOutputClass", "UnionTest_ReturnType", "WithReasoning", ] @enums = Set[ "Category", "Category2", "Category3", "Color", "DataType", "DynEnumOne", "DynEnumTwo", "EnumInClass", "EnumOutput", "Hobby", "NamedArgsSingleEnum", "NamedArgsSingleEnumList", "OptionalTest_CategoryType", "OrderStatus", "Tag", "TestEnum", ] end diff --git a/integ-tests/ruby/baml_client/types.rb b/integ-tests/ruby/baml_client/types.rb index 989ec5aae..1118cccfd 100644 --- a/integ-tests/ruby/baml_client/types.rb +++ b/integ-tests/ruby/baml_client/types.rb @@ -137,20 +137,26 @@ class ClassOptionalOutput < T::Struct; end class ClassOptionalOutput2 < T::Struct; end class ClassWithImage < T::Struct; end class CompoundBigNumbers < T::Struct; end + class ContactInfo < T::Struct; end class CustomTaskResult < T::Struct; end class DummyOutput < T::Struct; end class DynInputOutput < T::Struct; end class DynamicClassOne < T::Struct; end class DynamicClassTwo < T::Struct; end class DynamicOutput < T::Struct; end + class Earthling < T::Struct; end class Education < T::Struct; end class Email < T::Struct; end + class EmailAddress < T::Struct; end class Event < T::Struct; end class FakeImage < T::Struct; end class FlightConfirmation < T::Struct; end + class FooAny < T::Struct; end class GroceryReceipt < T::Struct; end class InnerClass < T::Struct; end class InnerClass2 < T::Struct; end + 
class InputWithConstraint < T::Struct; end + class Martian < T::Struct; end class NamedArgsSingleClass < T::Struct; end class Nested < T::Struct; end class Nested2 < T::Struct; end @@ -158,6 +164,7 @@ class OptionalTest_Prop1 < T::Struct; end class OptionalTest_ReturnType < T::Struct; end class OrderInfo < T::Struct; end class Person < T::Struct; end + class PhoneNumber < T::Struct; end class Quantity < T::Struct; end class RaysData < T::Struct; end class ReceiptInfo < T::Struct; end @@ -280,6 +287,18 @@ def initialize(props) @props = props end end + class ContactInfo < T::Struct + include Baml::Sorbet::Struct + const :primary, T.any(Baml::Types::PhoneNumber, Baml::Types::EmailAddress) + + def initialize(props) + super( + primary: props[:primary], + ) + + @props = props + end + end class CustomTaskResult < T::Struct include Baml::Sorbet::Struct const :bookOrder, T.any(Baml::Types::BookOrder, T.nilable(NilClass)) @@ -358,6 +377,18 @@ def initialize(props) @props = props end end + class Earthling < T::Struct + include Baml::Sorbet::Struct + const :age, Integer + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class Education < T::Struct include Baml::Sorbet::Struct const :institution, String @@ -394,6 +425,18 @@ def initialize(props) @props = props end end + class EmailAddress < T::Struct + include Baml::Sorbet::Struct + const :value, String + + def initialize(props) + super( + value: props[:value], + ) + + @props = props + end + end class Event < T::Struct include Baml::Sorbet::Struct const :title, String @@ -444,6 +487,22 @@ def initialize(props) @props = props end end + class FooAny < T::Struct + include Baml::Sorbet::Struct + const :planetary_age, T.any(Baml::Types::Martian, Baml::Types::Earthling) + const :certainty, Integer + const :species, String + + def initialize(props) + super( + planetary_age: props[:planetary_age], + certainty: props[:certainty], + species: props[:species], + ) + + @props = props + end + end class 
GroceryReceipt < T::Struct include Baml::Sorbet::Struct const :receiptId, String @@ -492,6 +551,32 @@ def initialize(props) @props = props end end + class InputWithConstraint < T::Struct + include Baml::Sorbet::Struct + const :name, String + const :amount, Integer + + def initialize(props) + super( + name: props[:name], + amount: props[:amount], + ) + + @props = props + end + end + class Martian < T::Struct + include Baml::Sorbet::Struct + const :age, Integer + + def initialize(props) + super( + age: props[:age], + ) + + @props = props + end + end class NamedArgsSingleClass < T::Struct include Baml::Sorbet::Struct const :key, String @@ -598,6 +683,18 @@ def initialize(props) @props = props end end + class PhoneNumber < T::Struct + include Baml::Sorbet::Struct + const :value, String + + def initialize(props) + super( + value: props[:value], + ) + + @props = props + end + end class Quantity < T::Struct include Baml::Sorbet::Struct const :amount, T.any(Integer, Float) diff --git a/integ-tests/typescript/baml_client/async_client.ts b/integ-tests/typescript/baml_client/async_client.ts index 41ea1ac5e..281f90e99 100644 --- a/integ-tests/typescript/baml_client/async_client.ts +++ b/integ-tests/typescript/baml_client/async_client.ts @@ -16,7 +16,7 @@ $ pnpm add @boundaryml/baml // @ts-nocheck // biome-ignore format: autogenerated code import { BamlRuntime, FunctionResult, BamlCtxManager, BamlStream, Image, ClientRegistry, BamlValidationError, createBamlValidationError } from "@boundaryml/baml" -import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Education, Email, Event, FakeImage, FlightConfirmation, GroceryReceipt, InnerClass, InnerClass2, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, 
SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, GroceryReceipt, InnerClass, InnerClass2, InputWithConstraint, Martian, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" @@ -442,6 +442,31 @@ export class BamlAsyncClient { } } + async ExtractContactInfo( + document: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "ExtractContactInfo", + { + "document": document + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as ContactInfo + } catch (error: 
any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + async ExtractNames( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1042,6 +1067,81 @@ export class BamlAsyncClient { } } + async PredictAge( + name: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "PredictAge", + { + "name": name + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as FooAny + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async PredictAgeBare( + inp: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "PredictAgeBare", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + async PredictAgeComplex( + inp: InputWithConstraint, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): Promise { + try { + const raw = await this.runtime.callFunction( + "PredictAgeComplex", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as InputWithConstraint + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + 
throw error; + } + } + } + async PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -2626,6 +2726,39 @@ class BamlStreamClient { } } + ExtractContactInfo( + document: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, ContactInfo> { + try { + const raw = this.runtime.streamFunction( + "ExtractContactInfo", + { + "document": document + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, ContactInfo>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is ContactInfo => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + ExtractNames( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -3418,6 +3551,105 @@ class BamlStreamClient { } } + PredictAge( + name: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, FooAny> { + try { + const raw = this.runtime.streamFunction( + "PredictAge", + { + "name": name + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, FooAny>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is FooAny => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + PredictAgeBare( + inp: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, number> { + try { + const raw = 
this.runtime.streamFunction( + "PredictAgeBare", + { + "inp": inp + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, number>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is number => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + + PredictAgeComplex( + inp: InputWithConstraint, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): BamlStream, InputWithConstraint> { + try { + const raw = this.runtime.streamFunction( + "PredictAgeComplex", + { + "inp": inp + }, + undefined, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return new BamlStream, InputWithConstraint>( + raw, + (a): a is RecursivePartialNull => a, + (a): a is InputWithConstraint => a, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + ) + } catch (error) { + if (error instanceof Error) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } + } + throw error; + } + } + PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/inlinedbaml.ts b/integ-tests/typescript/baml_client/inlinedbaml.ts index 185c840b8..cd623f543 100644 --- a/integ-tests/typescript/baml_client/inlinedbaml.ts +++ b/integ-tests/typescript/baml_client/inlinedbaml.ts @@ -30,6 +30,8 @@ const fileMap = { "test-files/aliases/classes.baml": "class TestClassAlias {\n key string @alias(\"key-dash\") @description(#\"\n This is a description for key\n af asdf\n \"#)\n key2 string @alias(\"key21\")\n key3 string @alias(\"key with 
space\")\n key4 string //unaliased\n key5 string @alias(\"key.with.punctuation/123\")\n}\n\nfunction FnTestClassAlias(input: string) -> TestClassAlias {\n client GPT35\n prompt #\"\n {{ctx.output_format}}\n \"#\n}\n\ntest FnTestClassAlias {\n functions [FnTestClassAlias]\n args {\n input \"example input\"\n }\n}\n", "test-files/aliases/enums.baml": "enum TestEnum {\n A @alias(\"k1\") @description(#\"\n User is angry\n \"#)\n B @alias(\"k22\") @description(#\"\n User is happy\n \"#)\n // tests whether k1 doesnt incorrectly get matched with k11\n C @alias(\"k11\") @description(#\"\n User is sad\n \"#)\n D @alias(\"k44\") @description(\n User is confused\n )\n E @description(\n User is excited\n )\n F @alias(\"k5\") // only alias\n \n G @alias(\"k6\") @description(#\"\n User is bored\n With a long description\n \"#)\n \n @@alias(\"Category\")\n}\n\nfunction FnTestAliasedEnumOutput(input: string) -> TestEnum {\n client GPT35\n prompt #\"\n Classify the user input into the following category\n \n {{ ctx.output_format }}\n\n {{ _.role('user') }}\n {{input}}\n\n {{ _.role('assistant') }}\n Category ID:\n \"#\n}\n\ntest FnTestAliasedEnumOutput {\n functions [FnTestAliasedEnumOutput]\n args {\n input \"mehhhhh\"\n }\n}", "test-files/comments/comments.baml": "// add some functions, classes, enums etc with comments all over.", + "test-files/constraints/constraints.baml": "// These classes and functions test several properties of\n// constrains:\n//\n// - The ability for constrains on fields to pass or fail.\n// - The ability for constraints on bare args and return types to pass or fail.\n// - The ability of constraints to influence which variant of a union is chosen\n// by the parser, when the structure is not sufficient to decide.\n\n\nclass Martian {\n age int @check({{ this < 30 }}, \"young_enough\")\n}\n\nclass Earthling {\n age int @check({{this < 200 and this > 0}}, \"earth_aged\") @check({{this >1}}, \"no_infants\")\n}\n\n\nclass FooAny {\n planetary_age Martian | 
Earthling\n certainty int @check({{this == 102931}}, \"unreasonably_certain\")\n species string @check({{this == \"Homo sapiens\"}}, \"trivial\") @check({{this|regex_match(\"Homo\")}}, \"regex_good\") @check({{this|regex_match(\"neanderthalensis\")}}, \"regex_bad\")\n}\n\n\nclass InputWithConstraint {\n name string @assert({{this|length > 1}}, \"nonempty\")\n amount int @check({{ this < 1 }}, \"small\")\n}\n\nfunction PredictAge(name: string) -> FooAny {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling\n and capitalization). I'll give you a hint: If the name\n is \"Greg\", his age is 41.\n\n {{ctx.output_format}}\n \"#\n}\n\n\nfunction PredictAgeComplex(inp: InputWithConstraint) -> InputWithConstraint {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n\nfunction PredictAgeBare(inp: string @assert({{this|length > 1}}, \"big_enough\")) -> int @check({{this == 10102}}, \"too_big\") {\n client GPT35\n prompt #\"\n Using your understanding of the historical popularity\n of names, predict the age of a person with the name\n {{ inp.name }} in years. Also predict their genus and\n species. 
It's Homo sapiens (with exactly that spelling).\n\n {{ctx.output_format}}\n \"#\n}\n", + "test-files/constraints/contact-info.baml": "class PhoneNumber {\n value string @check({{this|regex_match(\"\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}\")}}, \"valid_phone_number\")\n}\n\nclass EmailAddress {\n value string @check({{this|regex_match(\"^[_]*([a-z0-9]+(\\.|_*)?)+@([a-z][a-z0-9-]+(\\.|-*\\.))+[a-z]{2,6}$\")}}, \"valid_email\")\n}\n\nclass ContactInfo {\n primary PhoneNumber | EmailAddress\n}\n\nfunction ExtractContactInfo(document: string) -> ContactInfo {\n client GPT35\n prompt #\"\n Extract a primary contact info from this document:\n\n {{ document }}\n\n {{ ctx.output_format }}\n \"#\n}\n", "test-files/descriptions/descriptions.baml": "\nclass Nested {\n prop3 string | null @description(#\"\n write \"three\"\n \"#)\n prop4 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n prop20 Nested2\n}\n\nclass Nested2 {\n prop11 string | null @description(#\"\n write \"three\"\n \"#)\n prop12 string | null @description(#\"\n write \"four\"\n \"#) @alias(\"blah\")\n}\n\nclass Schema {\n prop1 string | null @description(#\"\n write \"one\"\n \"#)\n prop2 Nested | string @description(#\"\n write \"two\"\n \"#)\n prop5 (string | null)[] @description(#\"\n write \"hi\"\n \"#)\n prop6 string | Nested[] @alias(\"blah\") @description(#\"\n write the string \"blah\" regardless of the other types here\n \"#)\n nested_attrs (string | null | Nested)[] @description(#\"\n write the string \"nested\" regardless of other types\n \"#)\n parens (string | null) @description(#\"\n write \"parens1\"\n \"#)\n other_group (string | (int | string)) @description(#\"\n write \"other\"\n \"#) @alias(other)\n}\n\n\nfunction SchemaDescriptions(input: string) -> Schema {\n client GPT4o\n prompt #\"\n Return a schema with this format:\n\n {{ctx.output_format}}\n \"#\n}", "test-files/dynamic/client-registry.baml": "// Intentionally use a bad key\nclient BadClient {\n provider 
openai\n options {\n model \"gpt-3.5-turbo\"\n api_key \"sk-invalid\"\n }\n}\n\nfunction ExpectFailure() -> string {\n client BadClient\n\n prompt #\"\n What is the capital of England?\n \"#\n}\n", "test-files/dynamic/dynamic.baml": "class DynamicClassOne {\n @@dynamic\n}\n\nenum DynEnumOne {\n @@dynamic\n}\n\nenum DynEnumTwo {\n @@dynamic\n}\n\nclass SomeClassNestedDynamic {\n hi string\n @@dynamic\n\n}\n\nclass DynamicClassTwo {\n hi string\n some_class SomeClassNestedDynamic\n status DynEnumOne\n @@dynamic\n}\n\nfunction DynamicFunc(input: DynamicClassOne) -> DynamicClassTwo {\n client GPT35\n prompt #\"\n Please extract the schema from \n {{ input }}\n\n {{ ctx.output_format }}\n \"#\n}\n\nclass DynInputOutput {\n testKey string\n @@dynamic\n}\n\nfunction DynamicInputOutput(input: DynInputOutput) -> DynInputOutput {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\nfunction DynamicListInputOutput(input: DynInputOutput[]) -> DynInputOutput[] {\n client GPT35\n prompt #\"\n Here is some input data:\n ----\n {{ input }}\n ----\n\n Extract the information.\n {{ ctx.output_format }}\n \"#\n}\n\n\n\nclass DynamicOutput {\n @@dynamic\n}\n \nfunction MyFunc(input: string) -> DynamicOutput {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}\n\nfunction ClassifyDynEnumTwo(input: string) -> DynEnumTwo {\n client GPT35\n prompt #\"\n Given a string, extract info using the schema:\n\n {{ input}}\n\n {{ ctx.output_format }}\n \"#\n}", diff --git a/integ-tests/typescript/baml_client/sync_client.ts b/integ-tests/typescript/baml_client/sync_client.ts index 25e577efa..12af4c915 100644 --- a/integ-tests/typescript/baml_client/sync_client.ts +++ b/integ-tests/typescript/baml_client/sync_client.ts @@ -16,7 +16,7 @@ $ pnpm add @boundaryml/baml // @ts-nocheck // biome-ignore format: autogenerated code import 
{ BamlRuntime, FunctionResult, BamlCtxManager, BamlSyncStream, Image, ClientRegistry } from "@boundaryml/baml" -import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Education, Email, Event, FakeImage, FlightConfirmation, GroceryReceipt, InnerClass, InnerClass2, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" +import {BigNumbers, Blah, BookOrder, ClassOptionalOutput, ClassOptionalOutput2, ClassWithImage, CompoundBigNumbers, ContactInfo, CustomTaskResult, DummyOutput, DynInputOutput, DynamicClassOne, DynamicClassTwo, DynamicOutput, Earthling, Education, Email, EmailAddress, Event, FakeImage, FlightConfirmation, FooAny, GroceryReceipt, InnerClass, InnerClass2, InputWithConstraint, Martian, NamedArgsSingleClass, Nested, Nested2, OptionalTest_Prop1, OptionalTest_ReturnType, OrderInfo, Person, PhoneNumber, Quantity, RaysData, ReceiptInfo, ReceiptItem, Recipe, Resume, Schema, SearchParams, SomeClassNestedDynamic, StringToClassEntry, TestClassAlias, TestClassNested, TestClassWithEnum, TestOutputClass, UnionTest_ReturnType, WithReasoning, Category, Category2, Category3, Color, DataType, DynEnumOne, DynEnumTwo, EnumInClass, EnumOutput, Hobby, NamedArgsSingleEnum, NamedArgsSingleEnumList, OptionalTest_CategoryType, OrderStatus, Tag, TestEnum} from "./types" import TypeBuilder from "./type_builder" import { 
DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_CTX, DO_NOT_USE_DIRECTLY_UNLESS_YOU_KNOW_WHAT_YOURE_DOING_RUNTIME } from "./globals" @@ -442,6 +442,31 @@ export class BamlSyncClient { } } + ExtractContactInfo( + document: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): ContactInfo { + try { + const raw = this.runtime.callFunctionSync( + "ExtractContactInfo", + { + "document": document + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as ContactInfo + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + ExtractNames( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } @@ -1042,6 +1067,81 @@ export class BamlSyncClient { } } + PredictAge( + name: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): FooAny { + try { + const raw = this.runtime.callFunctionSync( + "PredictAge", + { + "name": name + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as FooAny + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + + PredictAgeBare( + inp: string, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): number { + try { + const raw = this.runtime.callFunctionSync( + "PredictAgeBare", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as number + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + 
} + } + } + + PredictAgeComplex( + inp: InputWithConstraint, + __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } + ): InputWithConstraint { + try { + const raw = this.runtime.callFunctionSync( + "PredictAgeComplex", + { + "inp": inp + }, + this.ctx_manager.cloneContext(), + __baml_options__?.tb?.__tb(), + __baml_options__?.clientRegistry, + ) + return raw.parsed() as InputWithConstraint + } catch (error: any) { + const bamlError = createBamlValidationError(error); + if (bamlError instanceof BamlValidationError) { + throw bamlError; + } else { + throw error; + } + } + } + PromptTestClaude( input: string, __baml_options__?: { tb?: TypeBuilder, clientRegistry?: ClientRegistry } diff --git a/integ-tests/typescript/baml_client/type_builder.ts b/integ-tests/typescript/baml_client/type_builder.ts index 1c95e7578..a72dfed17 100644 --- a/integ-tests/typescript/baml_client/type_builder.ts +++ b/integ-tests/typescript/baml_client/type_builder.ts @@ -48,7 +48,7 @@ export default class TypeBuilder { constructor() { this.tb = new _TypeBuilder({ classes: new Set([ - "BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Education","Email","Event","FakeImage","FlightConfirmation","GroceryReceipt","InnerClass","InnerClass2","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","UnionTest_ReturnType","WithReasoning", + 
"BigNumbers","Blah","BookOrder","ClassOptionalOutput","ClassOptionalOutput2","ClassWithImage","CompoundBigNumbers","ContactInfo","CustomTaskResult","DummyOutput","DynInputOutput","DynamicClassOne","DynamicClassTwo","DynamicOutput","Earthling","Education","Email","EmailAddress","Event","FakeImage","FlightConfirmation","FooAny","GroceryReceipt","InnerClass","InnerClass2","InputWithConstraint","Martian","NamedArgsSingleClass","Nested","Nested2","OptionalTest_Prop1","OptionalTest_ReturnType","OrderInfo","Person","PhoneNumber","Quantity","RaysData","ReceiptInfo","ReceiptItem","Recipe","Resume","Schema","SearchParams","SomeClassNestedDynamic","StringToClassEntry","TestClassAlias","TestClassNested","TestClassWithEnum","TestOutputClass","UnionTest_ReturnType","WithReasoning", ]), enums: new Set([ "Category","Category2","Category3","Color","DataType","DynEnumOne","DynEnumTwo","EnumInClass","EnumOutput","Hobby","NamedArgsSingleEnum","NamedArgsSingleEnumList","OptionalTest_CategoryType","OrderStatus","Tag","TestEnum", diff --git a/integ-tests/typescript/baml_client/types.ts b/integ-tests/typescript/baml_client/types.ts index cc6c2b3d7..f47098ee7 100644 --- a/integ-tests/typescript/baml_client/types.ts +++ b/integ-tests/typescript/baml_client/types.ts @@ -162,6 +162,11 @@ export interface CompoundBigNumbers { } +export interface ContactInfo { + primary: PhoneNumber | EmailAddress + +} + export interface CustomTaskResult { bookOrder?: BookOrder | null | null flightConfirmation?: FlightConfirmation | null | null @@ -200,6 +205,11 @@ export interface DynamicOutput { [key: string]: any; } +export interface Earthling { + age: number + +} + export interface Education { institution: string location: string @@ -216,6 +226,11 @@ export interface Email { } +export interface EmailAddress { + value: string + +} + export interface Event { title: string date: string @@ -238,6 +253,13 @@ export interface FlightConfirmation { } +export interface FooAny { + planetary_age: Martian | Earthling + 
certainty: number + species: string + +} + export interface GroceryReceipt { receiptId: string storeName: string @@ -259,6 +281,17 @@ export interface InnerClass2 { } +export interface InputWithConstraint { + name: string + amount: number + +} + +export interface Martian { + age: number + +} + export interface NamedArgsSingleClass { key: string key_two: boolean @@ -306,6 +339,11 @@ export interface Person { [key: string]: any; } +export interface PhoneNumber { + value: string + +} + export interface Quantity { amount: number | number unit?: string | null