Skip to content

Commit

Permalink
codegen for check support classes
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Oct 11, 2024
1 parent d9ea796 commit 7f90614
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 26 deletions.
53 changes: 27 additions & 26 deletions engine/language_client_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,32 +275,6 @@ pub fn type_check_attributes(
ir: &IntermediateRepr
) -> HashSet<TypeCheckAttributes> {

// A field_type can contain 0 or 1 sets of check attributes.
fn field_type_attributes(field_type: &FieldType) -> Option<TypeCheckAttributes> {
match field_type {
FieldType::Constrained {base, constraints} => {
let direct_sub_attributes = field_type_attributes(base);
let mut check_names =
TypeCheckAttributes(
constraints
.iter()
.filter_map(|Constraint {label, level, ..}|
if matches!(level, ConstraintLevel::Check) {
Some(label.clone().expect("TODO"))
} else { None }
).collect());
if let Some(sub_attrs) = direct_sub_attributes {
check_names.extend(&sub_attrs);
}
if !check_names.is_empty() {
Some(check_names)
} else {
None
}
},
_ => None
}
}

let mut all_types_in_ir: Vec<&FieldType> = Vec::new();
for class in ir.walk_classes() {
Expand All @@ -321,6 +295,33 @@ pub fn type_check_attributes(

}

/// The set of Check names associated with a type.
fn field_type_attributes(field_type: &FieldType) -> Option<TypeCheckAttributes> {
match field_type {
FieldType::Constrained {base, constraints} => {
let direct_sub_attributes = field_type_attributes(base);
let mut check_names =
TypeCheckAttributes(
constraints
.iter()
.filter_map(|Constraint {label, level, ..}|
if matches!(level, ConstraintLevel::Check) {
Some(label.clone().expect("TODO"))
} else { None }
).collect());
if let Some(sub_attrs) = direct_sub_attributes {
check_names.extend(&sub_attrs);
}
if !check_names.is_empty() {
Some(check_names)
} else {
None
}
},
_ => None
}
}

#[cfg(test)]
mod tests {
use internal_baml_core::ir::repr::make_test_ir;
Expand Down
26 changes: 26 additions & 0 deletions engine/language_client_codegen/src/python/generate_types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use anyhow::Result;

use crate::TypeCheckAttributes;

use super::python_language_features::ToPython;
use internal_baml_core::ir::{
repr::IntermediateRepr, ClassWalker, EnumWalker, FieldType, IRHelper,
Expand Down Expand Up @@ -157,6 +159,23 @@ pub fn add_default_value(node: &FieldType, type_str: &String) -> String {
}
}

fn type_name_for_checks(checks: &TypeCheckAttributes) -> String {
let mut name = "Checks_".to_string();
for check_name in checks.0.iter() {
name.push_str(check_name);
}
name
}

fn type_def_for_checks(checks: &TypeCheckAttributes) -> String {
let mut source_code = format!("class {}(BaseModel):\n", type_name_for_checks(checks));
for check_name in checks.0.iter() {
source_code.push_str(&format!(" {check_name}: baml_py.Check\n"));
}
source_code.push_str("\n");
source_code
}

trait ToTypeReferenceInTypeDefinition {
fn to_type_ref(&self, ir: &IntermediateRepr) -> String;
fn to_partial_type_ref(&self, ir: &IntermediateRepr, wrapped: bool) -> String;
Expand Down Expand Up @@ -200,6 +219,13 @@ impl ToTypeReferenceInTypeDefinition for FieldType {
),
FieldType::Optional(inner) => format!("Optional[{}]", inner.to_type_ref(ir)),
FieldType::Constrained{base,constraints} => {
match field_type_attributes(self) {
Some(checks) => {
let base_type_ref = base.to_type_ref(ir);
let checks_type_ref = type_name_for_checks(checks);
format!("baml_py.Checked[{base_type_ref},{checks_type_ref}]")
}
}
if !constraints.is_empty() {
format!("baml_py.Checked[{}]", base.to_type_ref(ir))
} else {
Expand Down

0 comments on commit 7f90614

Please sign in to comment.