Skip to content

Commit

Permalink
Codegen for Check Types, fix union picker, regex filter
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Oct 14, 2024
1 parent 7f90614 commit 803cec8
Show file tree
Hide file tree
Showing 34 changed files with 755 additions and 129 deletions.
11 changes: 11 additions & 0 deletions engine/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 20 additions & 32 deletions engine/baml-lib/baml-types/src/baml_value.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashSet, fmt};

use serde::ser::{SerializeMap, SerializeSeq};
Expand Down Expand Up @@ -476,72 +477,59 @@ impl Serialize for BamlValueWithMeta<Vec<ResponseCheck>> {
where S: Serializer,
{
match self {
BamlValueWithMeta::String(v, cr) => serialize_with_constraints(v, cr, serializer),
BamlValueWithMeta::Int(v, cr) => serialize_with_constraints(v, cr, serializer),
BamlValueWithMeta::Float(v, cr) => serialize_with_constraints(v, cr, serializer),
BamlValueWithMeta::Bool(v, cr) => serialize_with_constraints(v, cr, serializer),
BamlValueWithMeta::String(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Int(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Float(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Bool(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Map(v, cr) => {
let mut map = serializer.serialize_map(None)?;
for (key, value) in v {
map.serialize_entry(key, value)?;
}
add_failed_checks(&mut map, cr)?;
map.end()
},
BamlValueWithMeta::List(v, cr) => {
if cr.len() > 0 {
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("value", v)?;
map.serialize_entry("checks", cr)?;
map.end()
} else {
v.serialize(serializer)
}
},
BamlValueWithMeta::Media(v, cr) => serialize_with_constraints(v, cr, serializer),
BamlValueWithMeta::Enum(_enum_name, v, cr) => {
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("value", v)?;
add_failed_checks(&mut map, cr)?;
add_checks(&mut map, cr)?;
map.end()
},
BamlValueWithMeta::List(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Media(v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Enum(_enum_name, v, cr) => serialize_with_checks(v, cr, serializer),
BamlValueWithMeta::Class(_class_name, v, cr) => {
let mut map = serializer.serialize_map(None)?;
for (key, value) in v {
map.serialize_entry(key, value)?;
}
add_failed_checks(&mut map, cr)?;
add_checks(&mut map, cr)?;
map.end()
},
BamlValueWithMeta::Null(cr) => serialize_with_constraints(&(), cr, serializer),
BamlValueWithMeta::Null(cr) => serialize_with_checks(&(), cr, serializer),
}
}
}

fn serialize_with_constraints<S, T: Serialize>(
fn serialize_with_checks<S, T: Serialize>(
value: &T,
constraints: &Vec<ResponseCheck>,
checks: &Vec<ResponseCheck>,
serializer:S,

) -> Result<S::Ok, S::Error>
where S: Serializer,
{
if constraints.len() > 0 {
if !checks.is_empty() {
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("value", value)?;
map.serialize_entry("checks", constraints)?;
add_checks(&mut map, checks)?;
map.end()
} else {
value.serialize(serializer)
}
}

fn add_failed_checks<'a, S: SerializeMap>(
fn add_checks<'a, S: SerializeMap>(
map: &'a mut S,
constraints: &'a Vec<ResponseCheck>,
checks: &'a Vec<ResponseCheck>,
) -> Result<(), S::Error> {
if constraints.len() > 0 {
map.serialize_entry("failed_checks", constraints)?;
if !checks.is_empty() {
let checks_map: HashMap<_,_> = checks.iter().map(|check| (check.name.clone(), check)).collect();
map.serialize_entry("checks", &checks_map)?;
}
Ok(())
}
1 change: 1 addition & 0 deletions engine/baml-lib/jinja/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ serde_json.workspace = true
strum.workspace = true
strsim = "0.11.1"
colored = "2.1.0"
regex.workspace = true

[dev-dependencies]
env_logger = "0.11.3"
32 changes: 30 additions & 2 deletions engine/baml-lib/jinja/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub use evaluate_type::{PredefinedTypes, Type, TypeError};
use minijinja::{self, value::Kwargs};
use minijinja::{context, ErrorKind, Value};
use output_format::types::OutputFormatContent;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::HashMap;
Expand All @@ -24,9 +25,17 @@ fn get_env<'a>() -> minijinja::Environment<'a> {
env.set_debug(true);
env.set_trim_blocks(true);
env.set_lstrip_blocks(true);
env.add_filter("regex_match", regex_match);
env
}

fn regex_match(value: String, regex: String) -> bool {
match Regex::new(&regex) {
Err(_) => false,
Ok(re) => re.is_match(&value)
}
}

#[derive(Debug)]
pub struct ValidationError {
pub errors: Vec<TypeError>,
Expand Down Expand Up @@ -502,6 +511,7 @@ pub fn render_expression(
// So producing the string `{{}}` requires writing the literal `"{{{{}}}}"`
let template = format!(r#"{{{{ {} }}}}"#, expression.0);
let args_dict = minijinja::Value::from_serialize(ctx);
eprintln!("{}", &template);
Ok(env.render_str(&template, &args_dict)?)
}

Expand Down Expand Up @@ -1145,8 +1155,8 @@ mod render_tests {
fn test_render_expressions() {
let ctx = vec![(
"a".to_string(),
BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into()),
)]
BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into())
), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))]
.into_iter()
.collect();

Expand All @@ -1163,4 +1173,22 @@ mod render_tests {
"true"
);
}

#[test]
fn test_render_regex_match() {
let ctx = vec![(
"a".to_string(),
BamlValue::List(vec![BamlValue::Int(1), BamlValue::Int(2), BamlValue::Int(3)].into())
), ("b".to_string(), BamlValue::String("(123)456-7890".to_string()))]
.into_iter()
.collect();
assert_eq!(
render_expression(&JinjaExpression(r##"b|regex_match("123")"##.to_string()), &ctx).unwrap(),
"true"
);
assert_eq!(
render_expression(&JinjaExpression(r##"b|regex_match("\\(?\\d{3}\\)?[-.\\s]?\\d{3}[-.\\s]?\\d{4}")"##.to_string()), &ctx).unwrap(),
"true"
)
}
}
32 changes: 32 additions & 0 deletions engine/baml-lib/jsonish/src/tests/test_unions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,35 @@ test_deserializer!(
]
}
);

const CONTACT_INFO: &str = r#"
class PhoneNumber {
value string @check({{this|regex_match("\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")}}, "valid_phone_number")
foo int? // A nullable marker indicating PhoneNumber was chosen.
}
class EmailAddress {
value string @check({{this|regex_match("^[_]*([a-z0-9]+(\.|_*)?)+@([a-z][a-z0-9-]+(\.|-*\.))+[a-z]{2,6}$")}}, "valid_email")
bar int? // A nullable marker indicating EmailAddress was chosen.
}
class ContactInfo {
primary PhoneNumber | EmailAddress
}
"#;

test_deserializer!(
test_check1,
CONTACT_INFO,
r#"{"primary": {"value": "908-797-8281"}}"#,
FieldType::Class("ContactInfo".to_string()),
{"primary": {"value": "908-797-8281", "foo": null}}
);

test_deserializer!(
test_check2,
CONTACT_INFO,
r#"{"primary": {"value": "[email protected]"}}"#,
FieldType::Class("ContactInfo".to_string()),
{"primary": {"value": "[email protected]", "bar": null}}
);
4 changes: 0 additions & 4 deletions engine/baml-lib/parser-database/src/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,24 +134,20 @@ impl<'db> Context<'db> {

let all_attributes =
iter_attributes(self.attributes.attributes.as_ref(), self.ast).collect::<Vec<_>>();
dbg!(all_attributes);
while !has_valid_attribute {
let first_attr = iter_attributes(self.attributes.attributes.as_ref(), self.ast)
.filter(|(_, attr)| names.contains(&attr.name.name()))
.find(|(attr_id, _)| self.attributes.unused_attributes.contains(attr_id));
dbg!(&first_attr);
let (attr_id, attr) = if let Some(first_attr) = first_attr {
first_attr
} else {
break;
};
dbg!(&first_attr);
self.attributes.unused_attributes.remove(&attr_id);
has_valid_attribute = self.set_attribute(attr_id, attr);
matching_name = Some(attr.name.name().to_string());
}

dbg!(&matching_name);
matching_name
}

Expand Down
16 changes: 15 additions & 1 deletion engine/baml-lib/schema-ast/src/parser/parse_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,23 @@ fn unescape_string(val: &str) -> String {
result
}

/// Parse a `JinjaExpression` from raw source. Escape backslashes,
/// because we want the user's backslash intent to be preserved in
/// the string backing the `JinjaExpression`. In other words, control
/// sequences like `\n` are intended to be forwarded to the Jinja
/// processing engine, not to break a Jinja Expression into two lines,
/// therefor the backing string should be contain "\\n".
pub fn parse_jinja_expression(token: Pair<'_>, diagnostics: &mut Diagnostics) -> Expression {
assert_correct_parser!(token, Rule::jinja_expression);
let inner_text = token.as_str()[2..token.as_str().len() - 2].to_string();
let mut inner_text = String::new();
for c in token.as_str()[2..token.as_str().len() - 2].chars() {
match c {
// When encountering a single backslash, produce two backslashes.
'\\' => inner_text.push_str("\\\\"),
// Otherwise, just copy the character.
_ => inner_text.push(c),
}
}
Expression::JinjaExpressionValue(JinjaExpression(inner_text), diagnostics.span(token.as_span()))
}

Expand Down
1 change: 1 addition & 0 deletions engine/language_client_codegen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ sugar_path = "1.2.0"
walkdir.workspace = true
semver = "1.0.23"
colored = "2.1.0"
itertools = "0.13.0"
8 changes: 4 additions & 4 deletions engine/language_client_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ impl GenerateClient for GeneratorOutputType {
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TypeCheckAttributes(pub HashSet<String>);

impl std::hash::Hash for TypeCheckAttributes {
impl <'a> std::hash::Hash for TypeCheckAttributes {
fn hash<H>(&self, state: &mut H)
where H: std::hash::Hasher
{
Expand Down Expand Up @@ -296,7 +296,7 @@ pub fn type_check_attributes(
}

/// The set of Check names associated with a type.
fn field_type_attributes(field_type: &FieldType) -> Option<TypeCheckAttributes> {
fn field_type_attributes<'a>(field_type: &FieldType) -> Option<TypeCheckAttributes> {
match field_type {
FieldType::Constrained {base, constraints} => {
let direct_sub_attributes = field_type_attributes(base);
Expand All @@ -308,8 +308,8 @@ fn field_type_attributes(field_type: &FieldType) -> Option<TypeCheckAttributes>
if matches!(level, ConstraintLevel::Check) {
Some(label.clone().expect("TODO"))
} else { None }
).collect());
if let Some(sub_attrs) = direct_sub_attributes {
).collect::<HashSet<String>>());
if let Some(ref sub_attrs) = direct_sub_attributes {
check_names.extend(&sub_attrs);
}
if !check_names.is_empty() {
Expand Down
Loading

0 comments on commit 803cec8

Please sign in to comment.