Skip to content

Commit

Permalink
Utilities for scraping check name sets from IR
Browse files Browse the repository at this point in the history
  • Loading branch information
imalsogreg committed Oct 11, 2024
1 parent 95e2b4c commit d9ea796
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 7 deletions.
25 changes: 25 additions & 0 deletions engine/baml-lib/baml-core/src/ir/repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,10 +650,16 @@ impl WithRepr<Field> for FieldWalker<'_> {

/// Identifier used for classes; currently a plain `String` alias.
type ClassId = String;

/// A BAML Class.
///
/// Derives `serde::Serialize` and `Debug`, so it can be serialized and
/// printed for diagnostics.
#[derive(serde::Serialize, Debug)]
pub struct Class {
/// User defined class name.
pub name: ClassId,

/// Fields of the class, each wrapped in a `Node`.
pub static_fields: Vec<Node<Field>>,

/// Parameters to the class definition, stored as `(String, FieldType)` pairs.
/// NOTE(review): the `String` is presumably the parameter name — confirm.
pub inputs: Vec<(String, FieldType)>,
}

Expand Down Expand Up @@ -1094,3 +1100,22 @@ impl WithRepr<Prompt> for PromptAst<'_> {
})
}
}

/// Generate an IntermediateRepr from a single block of BAML source code.
/// This is useful for generating IR test fixtures.
pub fn make_test_ir(source_code: &str) -> anyhow::Result<IntermediateRepr> {
use std::path::PathBuf;
use internal_baml_diagnostics::SourceFile;
use crate::ValidatedSchema;
use crate::validate;

let path: PathBuf = "fake_file.baml".into();
let source_file: SourceFile = (path.clone(), source_code).into();
let validated_schema: ValidatedSchema = validate(&path, vec![source_file]);
let diagnostics = &validated_schema.diagnostics;
if diagnostics.has_errors() {
return Err(anyhow::anyhow!("Source code was invalid: \n{:?}", diagnostics.errors()))
}
let ir = IntermediateRepr::from_parser_database(&validated_schema.db, validated_schema.configuration)?;
Ok(ir)
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,27 @@ fn validate_type_allowed(ctx: &mut Context<'_>, field_type: &FieldType) {

fn validate_type_constraints(ctx: &mut Context<'_>, field_type: &FieldType) {
let constraint_attrs = field_type.attributes().iter().filter(|attr| ["assert", "check"].contains(&attr.name.name())).collect::<Vec<_>>();
for Attribute { arguments, span, .. } in constraint_attrs.iter() {
for Attribute { arguments, span, name, .. } in constraint_attrs.iter() {
let arg_expressions = arguments.arguments.iter().map(|Argument{value,..}| value).collect::<Vec<_>>();

match arg_expressions.as_slice() {
[Expression::JinjaExpressionValue(_, _), Expression::StringValue(_,_)] => {},
[Expression::JinjaExpressionValue(_, _)] => {},
[Expression::JinjaExpressionValue(_, _), Expression::StringValue(s,_)] => {
// TODO: (Greg) use a real identifier parser. This is a temporary hack.
if !s.chars().all(|c| c.is_alphanumeric() || c == '_') {
ctx.push_error(DatamodelError::new_validation_error(
"Constraint names must be valid identifiers - only alphanumeric characters and underscores",
span.clone()
))
}
},
[Expression::JinjaExpressionValue(_, _)] => {
if name.to_string() == "check" {
ctx.push_error(DatamodelError::new_validation_error(
"Check constraints must have a name.",
span.clone()
))
}
},
_ => {
ctx.push_error(DatamodelError::new_validation_error(
"A constraint must have one Jinja argument such as {{ expr }}, and optionally one String label",
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/jsonish/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ test_deserializer!(
#[test]
fn test_nested_constraint_distribution() {
fn mk_constraint(s: &str) -> Constraint {
Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression(s.to_string()), label: s.to_string() }
Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression(s.to_string()), label: Some(s.to_string()) }
}

let input = FieldType::Constrained {
Expand Down
2 changes: 1 addition & 1 deletion engine/baml-lib/parser-database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub use types::{
};

use self::{context::Context, interner::StringId, types::Types};
use internal_baml_diagnostics::{DatamodelError, DatamodelWarning, Diagnostics};
use internal_baml_diagnostics::{DatamodelError, Diagnostics};
use names::Names;

/// ParserDatabase is a container for a Schema AST, together with information
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ mod tests {
#[test]
fn test_nested_constraint_distribution() {
fn mk_constraint(s: &str) -> Constraint {
Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression(s.to_string()), label: s.to_string() }
Constraint { level: ConstraintLevel::Assert, expression: JinjaExpression(s.to_string()), label: Some(s.to_string()) }
}

let input = FieldType::Constrained {
Expand Down
152 changes: 151 additions & 1 deletion engine/language_client_codegen/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use anyhow::{Context, Result};
use baml_types::{Constraint, ConstraintLevel, FieldType};
use indexmap::IndexMap;
use internal_baml_core::{
configuration::{GeneratorDefaultClientMode, GeneratorOutputType},
ir::repr::IntermediateRepr,
};
use std::{collections::BTreeMap, path::PathBuf};
use std::{collections::{BTreeMap, HashSet}, path::PathBuf};
use version_check::{check_version, GeneratorType, VersionCheckMode};

mod dir_writer;
Expand Down Expand Up @@ -219,3 +220,152 @@ impl GenerateClient for GeneratorOutputType {
})
}
}

/// A set of names of `@check` attributes attached to a single type.
///
/// This set determines the name of the Python class or TypeScript interface
/// that holds the results of running these checks. See TODO (Docs) for
/// details on the support types generated from checks.
///
/// Equality and hashing both ignore element order, so two values holding the
/// same names behave as the same key in a `HashSet` or `HashMap`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TypeCheckAttributes(pub HashSet<String>);

impl std::hash::Hash for TypeCheckAttributes {
    /// Hash the contained names in sorted order.
    ///
    /// `HashSet` iteration order is unspecified and may differ between two
    /// sets with identical contents. Feeding the hasher in iteration order
    /// would therefore let equal values produce different hashes, violating
    /// the `a == b` implies `hash(a) == hash(b)` contract that hashed
    /// collections rely on. Sorting first makes the hash depend only on the
    /// set's contents.
    fn hash<H>(&self, state: &mut H)
    where
        H: std::hash::Hasher,
    {
        let mut names: Vec<&String> = self.0.iter().collect();
        names.sort_unstable();
        // Hashing the `Vec` also mixes in its length, keeping the encoding
        // unambiguous between sets of different sizes.
        std::hash::Hash::hash(&names, state);
    }
}

impl TypeCheckAttributes {
    /// Merge every check name from `other` into `self`; `other` is left
    /// unchanged.
    pub fn extend(&mut self, other: &TypeCheckAttributes) {
        let incoming = other.0.iter().cloned();
        self.0.extend(incoming)
    }

    /// Report whether this set contains no check names.
    pub fn is_empty(&self) -> bool {
        let names = &self.0;
        names.is_empty()
    }
}

/// Search the IR for all types with checks, combining the checks on each type
/// into a `TypeCheckAttributes` (a HashSet of the check names). Return a
/// HashSet of these HashSets.
///
/// The types inspected are every class field type, every function parameter
/// type, and every function return type. Only `FieldType::Constrained`
/// wrappers at the root of such a type (including directly nested
/// `Constrained` bases) contribute; any other type yields no set. Constraints
/// whose level is not `ConstraintLevel::Check` are ignored.
///
/// For example, consider this IR defining two classes:
///
/// ``` baml
/// class Foo {
///   int @check("a") @check("b")
///   string @check("a")
/// }
///
/// class Bar {
///   bool @check("a")
/// }
/// ```
///
/// It contains two distinct `TypeCheckAttributes`:
///  - ["a"]
///  - ["a", "b"]
///
/// We will need to construct two distinct support types:
/// `Classes_a` and `Classes_a_b`.
///
/// # Panics
///
/// Panics if a check-level constraint has no label.
pub fn type_check_attributes(ir: &IntermediateRepr) -> HashSet<TypeCheckAttributes> {
    // A field_type can contain 0 or 1 sets of check attributes.
    fn field_type_attributes(field_type: &FieldType) -> Option<TypeCheckAttributes> {
        match field_type {
            FieldType::Constrained { base, constraints } => {
                // A constrained type may itself wrap another constrained type.
                let direct_sub_attributes = field_type_attributes(base);
                let mut check_names = TypeCheckAttributes(
                    constraints
                        .iter()
                        .filter(|constraint| matches!(constraint.level, ConstraintLevel::Check))
                        .map(|constraint| {
                            constraint.label.clone().expect(
                                "schema validation rejects @check constraints without a name, \
                                 so every check constraint in the IR must carry a label",
                            )
                        })
                        .collect(),
                );
                if let Some(sub_attrs) = direct_sub_attributes {
                    check_names.extend(&sub_attrs);
                }
                // A type carrying only asserts produces no support type.
                Some(check_names).filter(|names| !names.is_empty())
            }
            _ => None,
        }
    }

    let mut all_types_in_ir: Vec<&FieldType> = Vec::new();
    for class in ir.walk_classes() {
        for field in class.item.elem.static_fields.iter() {
            let field_type = &field.elem.r#type.elem;
            all_types_in_ir.push(field_type);
        }
    }
    for function in ir.walk_functions() {
        for (_param_name, parameter) in function.item.elem.inputs.iter() {
            all_types_in_ir.push(parameter);
        }
        let return_type = &function.item.elem.output;
        all_types_in_ir.push(return_type);
    }

    all_types_in_ir
        .into_iter()
        .filter_map(field_type_attributes)
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use internal_baml_core::ir::repr::make_test_ir;

    /// Build a `TypeCheckAttributes` fixture from a list of check names.
    fn mk_tc_attrs(names: &[&str]) -> TypeCheckAttributes {
        let set = names.iter().map(|name| name.to_string()).collect();
        TypeCheckAttributes(set)
    }

    /// Every distinct set of check names in the fixture is found, and types
    /// carrying only asserts contribute nothing.
    #[test]
    fn find_type_check_attributes() {
        let ir = make_test_ir(
            r##"
client<llm> GPT4 {
provider openai
options {
model gpt-4o
api_key env.OPENAI_API_KEY
}
}
function Go(a: int @check({{ this < 0 }}, "c")) -> Foo {
client GPT4
prompt #""#
}
class Foo {
ab int @check({{this}}, "a") @check({{this}}, "b")
a int @check({{this}}, "a")
}
class Bar {
cb int @check({{this}}, "c") @check({{this}}, "b")
nil int @description("no checks") @assert({{this}}, "a") @assert({{this}}, "d")
}
"##).expect("Valid source");

        let found = type_check_attributes(&ir);
        assert_eq!(found.len(), 4);

        // One entry per distinct set of check names in the fixture.
        let expected_present: [&[&str]; 4] = [&["c"], &["a", "b"], &["a"], &["c", "b"]];
        for names in expected_present.iter() {
            assert!(found.contains(&mk_tc_attrs(names)));
        }

        // `nil` carries only asserts, so its labels must not form a set.
        assert!(!found.contains(&mk_tc_attrs(&["a", "d"])));
    }
}

0 comments on commit d9ea796

Please sign in to comment.