Skip to content

Commit

Permalink
Add constraints via Jinja expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Oct 19, 2024
1 parent 830b0cb commit e010444
Show file tree
Hide file tree
Showing 125 changed files with 5,002 additions and 423 deletions.
12 changes: 12 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions engine/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ http-body = "1.0.0"
indexmap = { version = "2.1.0", features = ["serde"] }
indoc = "2.0.5"
log = "0.4.20"
# TODO: disable imports, etc
minijinja = { version = "1.0.16", default-features = false, features = [
"macros",
"builtins",
"debug",
"preserve_order",
"adjacent_loop_items",
"unicode",
"json",
"unstable_machinery",
"unstable_machinery_serde",
"custom_syntax",
"internal_debug",
"deserialization",
# We don't want to use these features:
# multi_template
# loader
#
] }
regex = "1.10.4"
scopeguard = "1.2.0"
serde_json = { version = "1", features = ["float_roundtrip", "preserve_order"] }
Expand Down
1 change: 1 addition & 0 deletions engine/baml-lib/baml-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ internal-baml-jinja-types = { path = "../jinja" }
internal-baml-parser-database = { path = "../parser-database" }
internal-baml-prompt-parser = { path = "../prompt-parser" }
internal-baml-schema-ast = { path = "../schema-ast" }
minijinja.workspace = true
rayon = "1.8.0"
regex = "1.10.3"
semver = "1.0.20"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,8 @@ impl ScopeStack {
pub fn push_error(&mut self, error: String) {
self.scopes.last_mut().unwrap().errors.push(error);
}

pub fn push_warning(&mut self, warning: String) {
self.scopes.last_mut().unwrap().warnings.push(warning);
}
}
35 changes: 34 additions & 1 deletion engine/baml-lib/baml-core/src/ir/ir_helpers/to_baml_arg.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use baml_types::{BamlMap, BamlMediaType, BamlValue, FieldType, LiteralValue, TypeValue};
use baml_types::{
BamlMap, BamlValue, Constraint, ConstraintLevel, FieldType, LiteralValue, TypeValue
};
use core::result::Result;
use std::path::PathBuf;

use crate::ir::IntermediateRepr;

use super::{scope_diagnostics::ScopeStack, IRHelper};
use crate::ir::jinja_helpers::evaluate_predicate;

#[derive(Default)]
pub struct ParameterError {
Expand Down Expand Up @@ -325,6 +328,36 @@ impl ArgCoercer {
}
}
}
FieldType::Constrained { base, constraints } => {
let val = self.coerce_arg(ir, base, value, scope)?;
for c@Constraint {
level,
expression,
label,
} in constraints.iter()
{
let constraint_ok =
evaluate_predicate(&val, &expression).unwrap_or_else(|err| {
scope.push_error(format!(
"Error while evaluating check {c:?}: {:?}",
err
));
false
});
if !constraint_ok {
let msg = label.as_ref().unwrap_or(&expression.0);
match level {
ConstraintLevel::Check => {
scope.push_warning(format!("Failed check: {msg}"));
}
ConstraintLevel::Assert => {
scope.push_error(format!("Failed assert: {msg}"));
}
}
}
}
Ok(val)
}
}
}
}
98 changes: 98 additions & 0 deletions engine/baml-lib/baml-core/src/ir/jinja_helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use std::collections::HashMap;

use baml_types::{BamlValue, JinjaExpression};
use regex::Regex;

pub fn get_env<'a>() -> minijinja::Environment<'a> {
let mut env = minijinja::Environment::new();
env.set_debug(true);
env.set_trim_blocks(true);
env.set_lstrip_blocks(true);
env.add_filter("regex_match", regex_match);
env
}

fn regex_match(value: String, regex: String) -> bool {
match Regex::new(&regex) {
Err(_) => false,
Ok(re) => re.is_match(&value)
}
}

/// Render a bare minijinaja expression with the given context.
/// E.g. `"a|length > 2"` with context `{"a": [1, 2, 3]}` will return `"true"`.
pub fn render_expression(
expression: &JinjaExpression,
ctx: &HashMap<String, BamlValue>,
) -> anyhow::Result<String> {
let env = get_env();
// In rust string literals, `{` is escaped as `{{`.
// So producing the string `{{}}` requires writing the literal `"{{{{}}}}"`
let template = format!(r#"{{{{ {} }}}}"#, expression.0);
let args_dict = minijinja::Value::from_serialize(ctx);
eprintln!("{}", &template);
Ok(env.render_str(&template, &args_dict)?)
}

// TODO: (Greg) better error handling.
// TODO: (Greg) Upstream, typecheck the expression.
pub fn evaluate_predicate(
this: &BamlValue,
predicate_expression: &JinjaExpression,
) -> Result<bool, anyhow::Error> {
let ctx: HashMap<String, BamlValue> =
[("this".to_string(), this.clone())].into_iter().collect();
match render_expression(&predicate_expression, &ctx)?.as_ref() {
"true" => Ok(true),
"false" => Ok(false),
_ => Err(anyhow::anyhow!("TODO")),
}
}

#[cfg(test)]
mod tests {
use baml_types::BamlValue;
use super::*;


#[test]
fn test_render_expressions() {
let ctx = vec![(
"a".to_string(),
BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into())
), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))]
.into_iter()
.collect();

assert_eq!(
render_expression(&JinjaExpression("1".to_string()), &ctx).unwrap(),
"1"
);
assert_eq!(
render_expression(&JinjaExpression("1 + 1".to_string()), &ctx).unwrap(),
"2"
);
assert_eq!(
render_expression(&JinjaExpression("a|length > 2".to_string()), &ctx).unwrap(),
"true"
);
}

#[test]
fn test_render_regex_match() {
let ctx = vec![(
"a".to_string(),
BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into())
), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))]
.into_iter()
.collect();
assert_eq!(
render_expression(&JinjaExpression(r##"b|regex_match("123")"##.to_string()), &ctx).unwrap(),
"true"
);
assert_eq!(
render_expression(&JinjaExpression(r##"b|regex_match("\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}")"##.to_string()), &ctx).unwrap(),
"true"
)
}
}
1 change: 1 addition & 0 deletions engine/baml-lib/baml-core/src/ir/json_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ impl<'db> WithJsonSchema for FieldType {
}
}
}
FieldType::Constrained { base, .. } => base.json_schema(),
}
}
}
1 change: 1 addition & 0 deletions engine/baml-lib/baml-core/src/ir/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod ir_helpers;
pub mod jinja_helpers;
mod json_schema;
pub mod repr;
mod walker;
Expand Down
Loading

0 comments on commit e010444

Please sign in to comment.